From ddfcbed12fe580594f586f3ab7c5a7663d7e8bfa Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 15 Jun 2025 11:41:15 -0400 Subject: [PATCH 001/258] [Cute] Set check_inf=True always, return smem_pipe_read --- flash_attn/cute/flash_fwd.py | 38 ++++++++++++++++++------------------ flash_attn/cute/softmax.py | 22 ++++++++++++++------- flash_attn/cute/utils.py | 1 - 3 files changed, 34 insertions(+), 27 deletions(-) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index d8ddd1ae443..e4178015743 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -837,7 +837,7 @@ def preprocess_Q(): smem_pipe_write = self.advance_pipeline(smem_pipe_write) # The remaining iterations have no masking for n_tile in cutlass.range_dynamic(n_block, unroll=1): - compute_one_n_block(n_block - n_tile - 1, smem_pipe_read, smem_pipe_write, check_inf=False) + compute_one_n_block(n_block - n_tile - 1, smem_pipe_read, smem_pipe_write, check_inf=True) smem_pipe_read = self.advance_pipeline(smem_pipe_read) smem_pipe_write = self.advance_pipeline(smem_pipe_write) @@ -869,7 +869,7 @@ def compute_one_n_block( scoremod_premask_fn: Callable, mask_fn: Optional[Callable] = None, is_first_n_block: cutlass.Constexpr = False, - check_inf: cutlass.Constexpr = False, + check_inf: cutlass.Constexpr = True, ): """Compute one n_block of S/O. @@ -1448,11 +1448,10 @@ def scoremod_premask_fn(acc_S): acc_O.fill(0.0) else: self.warp_scheduler_barrier_sync() - compute_one_n_block( + smem_pipe_read = compute_one_n_block( n_block, smem_pipe_read, tiled_mma_qk, tiled_mma_pv, is_first_n_block=True, check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=True) ) - smem_pipe_read.advance() # Next couple of iterations with causal masking if cutlass.const_expr(self.is_causal): n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( @@ -1461,18 +1460,16 @@ def scoremod_premask_fn(acc_S): # Currently we can't do loop with negative step https://github.com/NVIDIA/cutlass/issues/2326 for n_tile in cutlass.range_dynamic(n_block_min_causal_local_mask, n_block_max - 1, unroll=1): n_block = n_block_max - 2 - n_tile + n_block_min_causal_local_mask - compute_one_n_block( + smem_pipe_read = compute_one_n_block( n_block, smem_pipe_read, tiled_mma_qk_copy, tiled_mma_pv_copy, check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=False) ) - smem_pipe_read.advance() # The remaining iterations have no masking for n_tile in cutlass.range_dynamic(n_block, unroll=1): - compute_one_n_block( + smem_pipe_read = compute_one_n_block( n_block - n_tile - 1, smem_pipe_read, tiled_mma_qk_copy1, tiled_mma_pv_copy1, - check_inf=False, + check_inf=True, ) - smem_pipe_read.advance() # Last "half" iteration if cutlass.const_expr(self.intra_wg_overlap): pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read)) @@ -1519,7 +1516,7 @@ def compute_one_n_block( scoremod_premask_fn: Callable, mask_fn: Optional[Callable] = None, is_first_n_block: cutlass.Constexpr = False, - check_inf: cutlass.Constexpr = False, + check_inf: cutlass.Constexpr = True, ): acc_S = cute.make_fragment( tiled_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)), cutlass.Float32 @@ -1556,6 +1553,8 @@ def compute_one_n_block( zero_init=is_first_n_block, wg_wait=0 ) pipeline_v.consumer_release(smem_pipe_read) + smem_pipe_read.advance() + return smem_pipe_read @cute.jit def compute_one_n_block_intrawg_overlap( @@ -1571,29 +1570,29 @@ def compute_one_n_block_intrawg_overlap( softmax: Softmax, scoremod_premask_fn: Callable, mask_fn: Optional[Callable] = None, - check_inf: cutlass.Constexpr = False, + check_inf: cutlass.Constexpr = True, ): - smem_pipe_read_k = smem_pipe_read.clone() - smem_pipe_read_k.advance() + smem_pipe_read_v = smem_pipe_read.clone() + smem_pipe_read.advance() acc_S = cute.make_fragment( tiled_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)), cutlass.Float32 ) - pipeline_k.consumer_wait(smem_pipe_read_k, pipeline_k.consumer_try_wait(smem_pipe_read_k)) + pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read)) self.warp_scheduler_barrier_sync() sm90_utils.gemm( tiled_mma_qk, acc_S, mma_params.tSrQ, - mma_params.tSrK[None, None, None, smem_pipe_read_k.index], + mma_params.tSrK[None, None, None, smem_pipe_read.index], zero_init=True, wg_wait=-1 ) - pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read)) + pipeline_v.consumer_wait(smem_pipe_read_v, pipeline_v.consumer_try_wait(smem_pipe_read_v)) sm90_utils.gemm( tiled_mma_pv, mma_params.acc_O, mma_params.tOrP, - mma_params.tOrVt[None, None, None, smem_pipe_read.index], + mma_params.tOrVt[None, None, None, smem_pipe_read_v.index], zero_init=False, wg_wait=-1 ) self.warp_scheduler_barrier_arrive() warpgroup.wait_group(1) - pipeline_k.consumer_release(smem_pipe_read_k) + pipeline_k.consumer_release(smem_pipe_read) scoremod_premask_fn(acc_S) if cutlass.const_expr(mask_fn is not None): mask_fn(acc_S, n_block=n_block) @@ -1601,7 +1600,7 @@ def compute_one_n_block_intrawg_overlap( # cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) row_scale = softmax.online_softmax(acc_S, check_inf=check_inf) warpgroup.wait_group(0) - pipeline_v.consumer_release(smem_pipe_read) + pipeline_v.consumer_release(smem_pipe_read_v) rP = cute.make_fragment_like(acc_S, self.dtype) rP.store(acc_S.load().to(self.dtype)) # tOrP = cute.make_tensor(rP.iterator, utils.convert_layout_acc_frgA(rP.layout)) @@ -1611,6 +1610,7 @@ def compute_one_n_block_intrawg_overlap( # Fence and barrier to make sure smem store is visible to WGMMA cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV + return smem_pipe_read @cute.jit def mma_init(self): diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index a658d072585..a7bb2305955 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -11,10 +11,16 @@ class Softmax: - def __init__(self, scale_log2: cutlass.Float32, num_rows: cutlass.Constexpr[int]): + def __init__( + self, + scale_log2: cutlass.Float32, + num_rows: cutlass.Constexpr[int], + arch: cutlass.Constexpr[int] = 80, + ): self.scale_log2 = scale_log2 self.row_max = cute.make_fragment(num_rows, cutlass.Float32) self.row_sum = cute.make_fragment_like(self.row_max) + self.arch = arch def reset(self) -> None: self.row_max.fill(-cutlass.Float32.inf) @@ -40,20 +46,22 @@ def online_softmax( # Each iteration processes one row of acc_S for r in range(cute.size(self.row_max)): acc_S_row = acc_S_mn[r, None].load() # (n_block_size) - row_max_cur = acc_S_row.reduce(cute.ReductionOp.MAX, -cutlass.Float32.inf, 0) + row_max_cur = acc_S_row.reduce( + cute.ReductionOp.MAX, + -cutlass.Float32.inf if cutlass.const_expr(is_first) else self.row_max[r], + 0 + ) row_max_cur = utils.warp_reduce(row_max_cur, cute.arch.fmax, width=4) + if cutlass.const_expr(check_inf): + if row_max_cur == -cutlass.Float32.inf: + row_max_cur = 0.0 if cutlass.const_expr(is_first): - if check_inf: - row_max_cur = 0.0 if row_max_cur == -cutlass.Float32.inf else row_max_cur row_max_cur_scaled = row_max_cur * self.scale_log2 acc_S_row_exp = utils.exp2f(acc_S_row * self.scale_log2 - row_max_cur_scaled) acc_S_row_sum = acc_S_row_exp.reduce(cute.ReductionOp.ADD, cutlass.Float32.zero, 0) row_scale[r] = 1.0 else: row_max_prev = self.row_max[r] - row_max_cur = cute.arch.fmax(row_max_prev, row_max_cur) - if check_inf: - row_max_cur = 0.0 if row_max_cur == -cutlass.Float32.inf else row_max_cur row_max_cur_scaled = row_max_cur * self.scale_log2 acc_S_row_exp = utils.exp2f(acc_S_row * self.scale_log2 - row_max_cur_scaled) acc_S_row_sum = acc_S_row_exp.reduce(cute.ReductionOp.ADD, cutlass.Float32.zero, 0) diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 3768fa3a9a1..771045cb42e 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -84,7 +84,6 @@ def get_smem_store_atom(arch: cutlass.Constexpr[int], element_type: Type[cute.Nu ) - def max_constexpr( a: cutlass.Constexpr[cute.Numeric], b: cutlass.Constexpr[cute.Numeric] ) -> cutlass.Constexpr[cute.Numeric]: From 3733dbba37682e40ce04d584c5f3d415dcc7f4f4 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 15 Jun 2025 11:41:45 -0400 Subject: [PATCH 002/258] Set line-length for ruff --- flash_attn/pyproject.toml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/flash_attn/pyproject.toml b/flash_attn/pyproject.toml index 3201555763e..ce5eac916cd 100644 --- a/flash_attn/pyproject.toml +++ b/flash_attn/pyproject.toml @@ -1,3 +1,6 @@ [tool.black] line-length = 100 -target-version = ['py38'] \ No newline at end of file +target-version = 'py39' +[tool.ruff] +line-length = 100 +target-version = 'py39' \ No newline at end of file From ecccf022220df95ec2f10fb52415e6a2ef3d7acd Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 15 Jun 2025 15:27:00 -0400 Subject: [PATCH 003/258] [Cute] Refactor Softmax, add fmax_reduce and fadd_reduce --- flash_attn/cute/flash_bwd_postprocess.py | 4 +- flash_attn/cute/softmax.py | 67 ++++++++--- flash_attn/cute/utils.py | 140 ++++++++++++++++++++--- 3 files changed, 179 insertions(+), 32 deletions(-) diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index ccb33d2c026..3662de580a6 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -132,7 +132,7 @@ def __call__( self, mdQaccum: cute.Tensor, mdQ: cute.Tensor, - scale: cute.Float32, + scale: cutlass.Float32, stream: cuda.CUstream, ): # Get the data type and check if it is fp16 or bf16 @@ -185,7 +185,7 @@ def kernel( self, mdQaccum: cute.Tensor, mdQ: cute.Tensor, - scale: cute.Float32, + scale: cutlass.Float32, tiled_mma: cute.TiledMma, dQ_swapAB: cutlass.Constexpr, sdQaccum_layout: cute.Layout, diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index a7bb2305955..2273718aed8 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -2,9 +2,11 @@ import math import operator +from typing import Tuple import cutlass import cutlass.cute as cute +from cutlass import Float32 import flash_attn.cute.utils as utils @@ -13,19 +15,33 @@ class Softmax: def __init__( self, - scale_log2: cutlass.Float32, + scale_log2: Float32, num_rows: cutlass.Constexpr[int], arch: cutlass.Constexpr[int] = 80, ): self.scale_log2 = scale_log2 - self.row_max = cute.make_fragment(num_rows, cutlass.Float32) + self.row_max = cute.make_fragment(num_rows, Float32) self.row_sum = cute.make_fragment_like(self.row_max) self.arch = arch def reset(self) -> None: - self.row_max.fill(-cutlass.Float32.inf) + self.row_max.fill(-Float32.inf) self.row_sum.fill(0.0) + def _compute_row_max( + self, + acc_S_row: cute.TensorSSA, + init_val: float | Float32 = -Float32.inf + ) -> Float32: + return utils.fmax_reduce(acc_S_row, init_val, arch=self.arch) + + def _compute_row_sum( + self, + acc_S_row_exp: cute.TensorSSA, + init_val: float | Float32 = Float32.zero + ) -> Float32: + return utils.fadd_reduce(acc_S_row_exp, init_val, arch=self.arch) + @cute.jit def online_softmax( self, @@ -42,44 +58,42 @@ def online_softmax( """ # Change acc_S to M,N layout view. acc_S_mn = utils.make_acc_tensor_mn_view(acc_S) - row_scale = cute.make_fragment_like(self.row_max, cutlass.Float32) + row_scale = cute.make_fragment_like(self.row_max, Float32) # Each iteration processes one row of acc_S for r in range(cute.size(self.row_max)): acc_S_row = acc_S_mn[r, None].load() # (n_block_size) - row_max_cur = acc_S_row.reduce( - cute.ReductionOp.MAX, - -cutlass.Float32.inf if cutlass.const_expr(is_first) else self.row_max[r], - 0 + row_max_cur = self._compute_row_max( + acc_S_row, + init_val=-Float32.inf if cutlass.const_expr(is_first) else self.row_max[r], ) row_max_cur = utils.warp_reduce(row_max_cur, cute.arch.fmax, width=4) if cutlass.const_expr(check_inf): - if row_max_cur == -cutlass.Float32.inf: + if row_max_cur == -Float32.inf: row_max_cur = 0.0 if cutlass.const_expr(is_first): row_max_cur_scaled = row_max_cur * self.scale_log2 acc_S_row_exp = utils.exp2f(acc_S_row * self.scale_log2 - row_max_cur_scaled) - acc_S_row_sum = acc_S_row_exp.reduce(cute.ReductionOp.ADD, cutlass.Float32.zero, 0) + acc_S_row_sum = self._compute_row_sum(acc_S_row_exp) row_scale[r] = 1.0 else: row_max_prev = self.row_max[r] row_max_cur_scaled = row_max_cur * self.scale_log2 acc_S_row_exp = utils.exp2f(acc_S_row * self.scale_log2 - row_max_cur_scaled) - acc_S_row_sum = acc_S_row_exp.reduce(cute.ReductionOp.ADD, cutlass.Float32.zero, 0) # row_scale[r] = utils.exp2f(row_max_prev * self.scale_log2 - row_max_cur_scaled) row_scale[r] = utils.exp2f((row_max_prev - row_max_cur) * self.scale_log2) - acc_S_row_sum = acc_S_row_sum + self.row_sum[r] * row_scale[r] + acc_S_row_sum = self._compute_row_sum(acc_S_row_exp) + self.row_sum[r] * row_scale[r] self.row_max[r] = row_max_cur self.row_sum[r] = acc_S_row_sum acc_S_mn[r, None].store(acc_S_row_exp) return row_scale @cute.jit - def finalize(self, final_scale: cute.Float32 = 1.0) -> cute.Tensor: + def finalize(self, final_scale: Float32 = 1.0) -> cute.Tensor: """Finalize the online softmax by computing the scale and logsumexp. """ # quad reduction for row_sum as we didn't do it during each iteration of online softmax self.row_sum.store(utils.warp_reduce(self.row_sum.load(), operator.add, width=4)) - row_scale = cute.make_fragment_like(self.row_max, cutlass.Float32) + row_scale = cute.make_fragment_like(self.row_max, Float32) for r in range(cute.size(self.row_sum)): # if row_sum is zero or nan, set acc_O_mn_row to 1.0 acc_O_mn_row_is_zero_or_nan = self.row_sum[r] == 0.0 or self.row_sum[r] != self.row_sum[r] @@ -90,7 +104,7 @@ def finalize(self, final_scale: cute.Float32 = 1.0) -> cute.Tensor: LN2 = math.log(2.0) self.row_sum[r] = ( (self.row_max[r] * self.scale_log2 + utils.log2f(row_sum_cur)) * LN2 - if not acc_O_mn_row_is_zero_or_nan else -cutlass.Float32.inf + if not acc_O_mn_row_is_zero_or_nan else -Float32.inf ) return row_scale @@ -106,3 +120,26 @@ def rescale_O(self, acc_O: cute.Tensor, row_scale: cute.Tensor) -> None: assert cute.size(row_scale) == cute.size(acc_O_mn, mode=[0]) for r in range(cute.size(row_scale)): acc_O_mn[r, None].store(acc_O_mn[r, None].load() * row_scale[r]) + + +class SoftmaxSm100(Softmax): + + def __init__(self, scale_log2: Float32): + super().__init__(scale_log2, num_rows=1, arch=100) + + @cute.jit + def update_row_max(self, acc_S_row: cute.TensorSSA) -> Tuple[Float32, Float32]: + row_max_old = self.row_max[0] + row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old) + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2 + acc_scale = utils.exp2f(acc_scale_) + if acc_scale_ >= -8.0: + row_max_new = row_max_old + row_max_safe = row_max_old + acc_scale = 1.0 + self.row_max[0] = row_max_new + return row_max_safe, acc_scale + + def update_row_sum(self, acc_S_row_exp: cute.TensorSSA, row_scale: Float32) -> None: + self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=self.row_sum[0] * row_scale) diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 771045cb42e..4b0b0ce2e47 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -6,6 +6,7 @@ import cutlass import cutlass.cute as cute +from cutlass import Float32 from cutlass.cutlass_dsl import T, dsl_user_op from cutlass._mlir.dialects import nvvm, llvm from cutlass.cute.runtime import from_dlpack @@ -160,30 +161,45 @@ def transpose_view(a: cute.Tensor) -> cute.Tensor: return cute.composition(a, cute.make_ordered_layout(shape, order=order)) -def exp2f(x: cute.TensorSSA | cutlass.Float32) -> cute.TensorSSA | cutlass.Float32: +@dsl_user_op +def exp2f_asm(a: float | Float32, *, loc=None, ip=None) -> Float32: + return Float32( + llvm.inline_asm( + T.f32(), + [Float32(a).ir_value(loc=loc, ip=ip)], + "ex2.approx.ftz.f32 $0, $1;", + "=f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32: """exp2f calculation for both vector and scalar. :param x: input value - :type x: cute.TensorSSA or cutlass.Float32 + :type x: cute.TensorSSA or Float32 :return: exp2 value - :rtype: cute.TensorSSA or cutlass.Float32 + :rtype: cute.TensorSSA or Float32 """ if isinstance(x, cute.TensorSSA): - res = cute.make_fragment(x.shape, cutlass.Float32) + res = cute.make_fragment(x.shape, Float32) res.store(x) for i in range(cute.size(x.shape)): - res[i] = cute.arch.exp2(res[i]) + res[i] = exp2f_asm(res[i]) return res.load() else: - return cute.arch.exp2(x) + return exp2f_asm(x) @dsl_user_op -def log2f(a: float | cutlass.Float32, *, loc=None, ip=None) -> cutlass.Float32: - return cutlass.Float32( +def log2f(a: float | Float32, *, loc=None, ip=None) -> Float32: + return Float32( llvm.inline_asm( T.f32(), - [cutlass.Float32(a).ir_value(loc=loc, ip=ip)], + [Float32(a).ir_value(loc=loc, ip=ip)], "lg2.approx.ftz.f32 $0, $1;", "=f,f", has_side_effects=False, @@ -193,16 +209,110 @@ def log2f(a: float | cutlass.Float32, *, loc=None, ip=None) -> cutlass.Float32: ) +@dsl_user_op +def max3f(a: float | Float32, b: float | Float32, c: float | Float32, *, loc=None, ip=None) -> Float32: + return Float32( + llvm.inline_asm( + T.f32(), + [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip), Float32(c).ir_value(loc=loc, ip=ip)], + "max.f32 $0, $1, $2, $3;", + "=f,f,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +def fmax_reduce( + x: cute.TensorSSA, + init_val: float | Float32 = -Float32.inf, + arch: cutlass.Constexpr[int] = 80 +) -> Float32: + if cutlass.const_expr(arch < 100 or cute.size(x.shape) % 8 != 0): + return x.reduce(cute.ReductionOp.MAX, init_val, 0) + else: + # [2025-06-15] x.reduce only seems to use 50% 3-input max and 50% 2-input max + # We instead force the 3-input max by calling inline ptx. + res = cute.make_fragment(x.shape, Float32) + res.store(x) + local_max = [init_val, -Float32.inf, -Float32.inf, -Float32.inf] + for i in range(0, cute.size(x.shape), 8): + local_max[0] = max3f(local_max[0], res[i], res[i + 1]) + local_max[1] = max3f(local_max[1], res[i + 2], res[i + 3]) + local_max[2] = max3f(local_max[2], res[i + 4], res[i + 5]) + local_max[3] = max3f(local_max[3], res[i + 6], res[i + 7]) + local_max[0] = cute.arch.fmax(local_max[0], local_max[1]) + local_max[2] = cute.arch.fmax(local_max[2], local_max[3]) + return cute.arch.fmax(local_max[0], local_max[2]) + + # local_max = [cute.arch.fmax(res[0], res[1]), cute.arch.fmax(res[2], res[3]), + # cute.arch.fmax(res[4], res[5]), cute.arch.fmax(res[6], res[7])] + # for i in range(8, cute.size(x.shape), 8): + # local_max[0] = max3f(local_max[0], res[i], res[i + 1]) + # local_max[1] = max3f(local_max[1], res[i + 2], res[i + 3]) + # local_max[2] = max3f(local_max[2], res[i + 4], res[i + 5]) + # local_max[3] = max3f(local_max[3], res[i + 6], res[i + 7]) + # local_max[0] = max3f(init_val, local_max[0], local_max[1]) + # local_max[2] = cute.arch.fmax(local_max[2], local_max[3]) + # return cute.arch.fmax(local_max[0], local_max[2]) + + # local_max = [res[0], res[1], res[2], res[3]] + # for i in range(4, cute.size(x.shape), 8): + # local_max[0] = max3f(local_max[0], res[i], res[i + 1]) + # local_max[1] = max3f(local_max[1], res[i + 2], res[i + 3]) + # local_max[2] = max3f(local_max[2], res[i + 4], res[i + 5]) + # local_max[3] = max3f(local_max[3], res[i + 6], res[i + 7]) + # i_f = cutlass.const_expr(cute.size(x.shape) - 4) + # # local_max[0] = max3f(local_max[0], res[i_f], res[i_f + 1]) + # # local_max[1] = max3f(local_max[1], res[i_f + 2], res[i_f + 3]) + # # local_max[0] = max3f(local_max[0], local_max[1], local_max[2]) + # # return max3f(local_max[0], local_max[3], init_val) + # local_max[0] = cute.arch.fmax(local_max[0], res[i_f]) + # local_max[1] = cute.arch.fmax(local_max[1], res[i_f + 1]) + # local_max[2] = cute.arch.fmax(local_max[2], res[i_f + 2]) + # local_max[3] = cute.arch.fmax(local_max[3], res[i_f + 3]) + # local_max[0] = max3f(local_max[0], local_max[1], init_val) + # local_max[2] = cute.arch.fmax(local_max[2], local_max[3]) + # return cute.arch.fmax(local_max[0], local_max[2]) + + # local_max[0] = max3f(local_max[0], local_max[1], local_max[2]) + # return cute.arch.fmax(local_max[0], local_max[3]) + + +def fadd_reduce( + x: cute.TensorSSA, + init_val: float | Float32 = Float32.zero, + arch: cutlass.Constexpr[int] = 80 +) -> Float32: + if cutlass.const_expr(arch < 100 or cute.size(x.shape) % 8 != 0): + return x.reduce(cute.ReductionOp.ADD, init_val, 0) + else: + res = cute.make_fragment(x.shape, Float32) + res.store(x) + local_sum_0 = cute.arch.add_packed_f32x2((init_val, 0.0), (res[0], res[1])) + local_sum = [local_sum_0, (res[2], res[3]), (res[4], res[5]), (res[6], res[7])] + for i in range(8, cute.size(x.shape), 8): + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], (res[i], res[i + 1])) + local_sum[1] = cute.arch.add_packed_f32x2(local_sum[1], (res[i + 2], res[i + 3])) + local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], (res[i + 4], res[i + 5])) + local_sum[3] = cute.arch.add_packed_f32x2(local_sum[3], (res[i + 6], res[i + 7])) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[1]) + local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], local_sum[3]) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[2]) + return local_sum[0][0] + local_sum[0][1] + + @dsl_user_op def atomic_add_fp32( - a: float | cutlass.Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None + a: float | Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None ) -> None: # gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value() # # cache_hint = cutlass.Int64(0x12F0000000000000) # llvm.inline_asm( # None, - # [gmem_ptr_i64, cutlass.Float32(a).ir_value(loc=loc, ip=ip)], - # # [gmem_ptr_i64, cutlass.Float32(a).ir_value(loc=loc, ip=ip), cache_hint.ir_value()], + # [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip)], + # # [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip), cache_hint.ir_value()], # "red.global.add.f32 [$0], $1;", # # "red.global.add.L2::cache_hint.f32 [$0], $1, 0x12F0000000000000;", # # "red.global.add.L2::cache_hint.f32 [$0], $1, $2;", @@ -216,7 +326,7 @@ def atomic_add_fp32( res=T.f32(), op=nvvm.AtomicOpKind.FADD, ptr=gmem_ptr.llvm_ptr, - a=cutlass.Float32(a).ir_value() + a=Float32(a).ir_value() ) @@ -295,12 +405,12 @@ def canonical_warp_group_idx(sync: bool = True) -> cutlass.Int32: # @dsl_user_op -# def warp_vote_any_lt(a: float | cutlass.Float32, b: float | cutlass.Float32, *, loc=None, ip=None) -> cutlass.Boolean: +# def warp_vote_any_lt(a: float | Float32, b: float | Float32, *, loc=None, ip=None) -> cutlass.Boolean: # mask = cutlass.Int32(-1) # return cutlass.Boolean( # llvm.inline_asm( # T.i32(), -# [cutlass.Float32(a).ir_value(loc=loc, ip=ip), cutlass.Float32(b).ir_value(loc=loc, ip=ip), mask.ir_value(loc=loc, ip=ip)], +# [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip), mask.ir_value(loc=loc, ip=ip)], # ".pred p1, p2;\n" # "setp.lt.f32 p1, $1, $2;\n" # "vote.sync.any.pred p2, p1, $3;\n" From 6c5f5ba272f472a0a98baa44ff5c19d4b8758574 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 25 Jun 2025 15:59:36 -0400 Subject: [PATCH 004/258] [Cute] Move load and mma to separate functions --- flash_attn/cute/flash_fwd.py | 491 ++++++++++++++++++++------------- flash_attn/cute/seqlen_info.py | 2 + flash_attn/cute/softmax.py | 51 +++- flash_attn/cute/utils.py | 72 ++--- 4 files changed, 364 insertions(+), 252 deletions(-) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index e4178015743..46c36aa7027 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -1273,18 +1273,17 @@ def kernel( self.m_block_size, self.n_block_size, self.is_causal, self.qhead_per_kvhead if self.pack_gqa else 1, ) - seqlen = SeqlenInfo( - batch_idx, mQ.shape[0] if not self.pack_gqa else mQ.shape[0][1], mK.shape[0], mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK + SeqlenInfoCls = partial( + SeqlenInfo, seqlen_q_static=mQ.shape[0] if not self.pack_gqa else mQ.shape[0][1], + seqlen_k_static=mK.shape[0], + mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, + mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK, ) + seqlen = SeqlenInfoCls(batch_idx) # Can't early exit so we have to write it this way (under an if statement) if mCuSeqlensQ is None or m_block * self.n_block_size < seqlen.seqlen_q: if cutlass.const_expr(self.is_causal): # Longest tile first m_block = cute.ceil_div(seqlen.seqlen_q * (self.qhead_per_kvhead if self.pack_gqa else 1), self.m_block_size) - m_block - 1 - if cutlass.const_expr(mCuSeqlensQ is None): - mQ_cur = mQ[None, None, head_idx, batch_idx] - else: - mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, None, head_idx]) - n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) # TODO: return early if n_block_max == 0 # if self.is_causal: # if n_block_max <= 0: @@ -1292,55 +1291,22 @@ def kernel( if warp_idx < 4: # Producer cute.arch.warpgroup_reg_dealloc(self.num_producer_regs) - # /////////////////////////////////////////////////////////////////////////////// - # Get the appropriate tiles for this thread block. - # /////////////////////////////////////////////////////////////////////////////// - head_idx_kv = head_idx // self.qhead_per_kvhead if not self.pack_gqa else head_idx - if cutlass.const_expr(mCuSeqlensK is None): - mK_cur, mV_cur = [t[None, None, head_idx_kv, batch_idx] for t in (mK, mV)] - else: - mK_cur, mV_cur = [cute.domain_offset((seqlen.offset_k, 0), t[None, None, head_idx_kv]) for t in (mK, mV)] - gK = cute.local_tile(mK_cur, (self.n_block_size, self.head_dim_padded), (None, 0)) - gV = cute.local_tile(mV_cur, (self.n_block_size, self.head_dim_v_padded), (None, 0)) - if cutlass.const_expr(not self.pack_gqa): - gQ = cute.local_tile(mQ_cur, (self.m_block_size, self.head_dim_padded), (m_block, 0)) - tQsQ, tQgQ = cpasync.tma_partition( - tma_atom_Q, - 0, - cute.make_layout(1), - cute.group_modes(sQ, 0, 2), - cute.group_modes(gQ, 0, 2), - ) - tKsK, tKgK = cpasync.tma_partition( + self.load( + mQ, + mK, + mV, + sQ, + sK, + sV, + tma_atom_Q, tma_atom_K, - 0, - cute.make_layout(1), - cute.group_modes(sK, 0, 2), - cute.group_modes(gK, 0, 2), - ) - tVsV, tVgV = cpasync.tma_partition( tma_atom_V, - 0, - cute.make_layout(1), - cute.group_modes(sV, 0, 2), - cute.group_modes(gV, 0, 2), + pipeline_k, + pipeline_v, + mbar_ptr_Q, + block_info, + SeqlenInfoCls ) - smem_pipe_write = pipeline.make_pipeline_state( - cutlass.utils.PipelineUserType.Producer, self.num_stages - ) - load_K = partial(self.load_K, tma_atom_K, tKgK, tKsK, pipeline_k) - load_V = partial(self.load_K, tma_atom_V, tVgV, tVsV, pipeline_v) - if warp_idx == 0: # Producer - # load_Q - if cutlass.const_expr(not self.pack_gqa): - with cute.arch.elect_one(): - cute.arch.mbarrier_init_tx_bytes(mbar_ptr_Q, self.tma_copy_q_bytes) - cute.copy(tma_atom_Q, tQgQ, tQsQ, tma_bar_ptr=mbar_ptr_Q) - for n_tile in cutlass.range_dynamic(n_block_max, unroll=2): - n_block = n_block_max - n_tile - 1 - load_K(n_block, smem_pipe_write=smem_pipe_write) - load_V(n_block, smem_pipe_write=smem_pipe_write) - smem_pipe_write.advance() else: # Consumer cute.arch.warpgroup_reg_alloc(self.num_mma_regs) @@ -1348,152 +1314,38 @@ def kernel( # Tile MMA compute thread partitions and allocate accumulators # /////////////////////////////////////////////////////////////////////////////// tidx = tidx - 128 - warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) - warp_group_thread_layout = cute.make_layout( - self.num_mma_warp_groups, stride=self.num_threads_per_warp_group - ) - thr_mma_qk = tiled_mma_qk.get_slice(tidx) - wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx)) - wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx)) - tSrQ = tiled_mma_qk.make_fragment_A(wg_mma_qk.partition_A(sQ)) - tSrK = tiled_mma_qk.make_fragment_B(wg_mma_qk.partition_B(sK)) - tOrP = tiled_mma_pv.make_fragment_A(wg_mma_pv.partition_A(sP)) if cutlass.const_expr(sP is not None) else None - tOrVt = tiled_mma_pv.make_fragment_B(wg_mma_pv.partition_B(sVt)) acc_shape_O = tiled_mma_pv.partition_shape_C((self.m_block_size, self.head_dim_v_padded)) acc_O = cute.make_fragment(acc_shape_O, cutlass.Float32) - - # /////////////////////////////////////////////////////////////////////////////// - # Smem copy atom tiling - # /////////////////////////////////////////////////////////////////////////////// - smem_copy_atom_P = utils.get_smem_store_atom(self.arch, self.dtype) - smem_thr_copy_P = utils.make_tiled_copy_C(smem_copy_atom_P, tiled_mma_qk).get_slice(tidx) - # tPsP = smem_thr_copy_P.partition_D(sP_pi) if cutlass.const_expr(sP_pi is not None) else None - tPsP = smem_thr_copy_P.partition_D(sP) if cutlass.const_expr(sP is not None) else None - # if cute.arch.thread_idx()[0] == 0: - # cute.printf(sP_pi.layout, sP_pi.iterator) - # cute.printf(sP.layout, sP.iterator) - # cute.printf(tPsP.layout, tPsP.iterator) - - self.mma_init() - - # shape: (atom_v_m * rest_m) softmax = Softmax(softmax_scale_log2, num_rows=acc_O.shape[0][0] * acc_O.shape[1]) - softmax.reset() - # group parameters for compute_one_n_block - mma_params = SimpleNamespace(tSrQ=tSrQ, tSrK=tSrK, tOrP=tOrP, tOrVt=tOrVt, acc_O=acc_O) - smem_copy_params = SimpleNamespace(smem_thr_copy_P=smem_thr_copy_P, tPsP=tPsP) - - # Softcapping needs to happen before masking since if we apply after masking, softcapping can turn - # -inf to e.g. -50.0, which can affect the attention softmax. - def scoremod_premask_fn(acc_S): - if cutlass.const_expr(self.has_softcap): - acc_S.store(cute.math.tanh(acc_S.load() * softcap_val, fastmath=True)) - - mask = AttentionMask( - self.m_block_size, self.n_block_size, seqlen.seqlen_q, seqlen.seqlen_k, - self.qhead_per_kvhead if self.pack_gqa else 1 - ) - mask_fn = partial( - mask.apply_mask, m_block=m_block, thr_mma=thr_mma_qk, mask_causal=self.is_causal - ) - compute_one_n_block = partial( - self.compute_one_n_block_intrawg_overlap if cutlass.const_expr(self.intra_wg_overlap) else self.compute_one_n_block, - pipeline_k=pipeline_k, pipeline_v=pipeline_v, - mma_params=mma_params, smem_copy_params=smem_copy_params, - softmax=softmax, scoremod_premask_fn=scoremod_premask_fn, - ) - - # Load Q if PackGQA - if cutlass.const_expr(self.pack_gqa): - pack_gqa = PackGQA(self.m_block_size, self.head_dim_padded, self.check_hdim_oob, self.qhead_per_kvhead) - # gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx) - # gQ = cute.local_tile(mQ_cur, (self.m_block_size, self.head_dim_padded), (m_block, 0)) - # self.load_Q(gmem_thr_copy_Q, gQ, sQ, m_block, seqlen=seqlen.seqlen_q, - # headdim=mQ.shape[1]) - pack_gqa.load_Q(mQ_cur, sQ, gmem_tiled_copy_Q, tidx, m_block, seqlen.seqlen_q) - utils.cp_async_mbarrier_arrive_shared(mbar_ptr_Q, noinc=True) - - n_block = n_block_max - 1 - smem_pipe_read = pipeline.make_pipeline_state( - cutlass.utils.PipelineUserType.Consumer, self.num_stages + self.mma( + tiled_mma_qk, + tiled_mma_pv, + softmax, + acc_O, + mQ, + sQ, + sK, + sVt, + sP, + pipeline_k, + pipeline_v, + mbar_ptr_Q, + gmem_tiled_copy_Q, + tidx, + softcap_val, + block_info, + SeqlenInfoCls, + tiled_mma_qk_copy, + tiled_mma_pv_copy, + tiled_mma_qk_copy1, + tiled_mma_pv_copy1, ) - cute.arch.mbarrier_wait(mbar_ptr_Q, phase=0) - # For performance reason, we separate out two kinds of iterations: - # those that need masking on S, and those that don't. - # We need masking on S for the very last block when K and V has length not multiple of n_block_size. - # We also need masking on S if it's causal, for the last several blocks. - # First iteration with seqlen masking - if cutlass.const_expr(self.intra_wg_overlap): - acc_S = cute.make_fragment( - tiled_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)), cutlass.Float32 - ) - pipeline_k.consumer_wait(smem_pipe_read) - sm90_utils.gemm( - tiled_mma_qk, acc_S, tSrQ, tSrK[None, None, None, smem_pipe_read.index], - zero_init=True, wg_wait=0 - ) - pipeline_k.consumer_release(smem_pipe_read) - scoremod_premask_fn(acc_S) - mask_fn(acc_S, n_block=n_block, mask_seqlen=True) - softmax.online_softmax(acc_S, is_first=True, check_inf=True) - rP = cute.make_fragment_like(acc_S, self.dtype) - rP.store(acc_S.load().to(self.dtype)) - # tOrP = cute.make_tensor(rP.iterator, utils.convert_layout_acc_frgA(rP.layout)) - tPrP = smem_thr_copy_P.retile(rP) - cute.copy(smem_thr_copy_P, tPrP, tPsP) - # Fence and barrier to make sure smem store is visible to WGMMA - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) - cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV - # Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter - acc_O.fill(0.0) - else: - self.warp_scheduler_barrier_sync() - smem_pipe_read = compute_one_n_block( - n_block, smem_pipe_read, tiled_mma_qk, tiled_mma_pv, - is_first_n_block=True, check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=True) - ) - # Next couple of iterations with causal masking - if cutlass.const_expr(self.is_causal): - n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( - seqlen, m_block, n_block_min - ) - # Currently we can't do loop with negative step https://github.com/NVIDIA/cutlass/issues/2326 - for n_tile in cutlass.range_dynamic(n_block_min_causal_local_mask, n_block_max - 1, unroll=1): - n_block = n_block_max - 2 - n_tile + n_block_min_causal_local_mask - smem_pipe_read = compute_one_n_block( - n_block, smem_pipe_read, tiled_mma_qk_copy, tiled_mma_pv_copy, - check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=False) - ) - # The remaining iterations have no masking - for n_tile in cutlass.range_dynamic(n_block, unroll=1): - smem_pipe_read = compute_one_n_block( - n_block - n_tile - 1, smem_pipe_read, tiled_mma_qk_copy1, tiled_mma_pv_copy1, - check_inf=True, - ) - # Last "half" iteration - if cutlass.const_expr(self.intra_wg_overlap): - pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read)) - sm90_utils.gemm( - tiled_mma_pv, mma_params.acc_O, mma_params.tOrP, - mma_params.tOrVt[None, None, None, smem_pipe_read.index], - zero_init=False, wg_wait=-1 - ) - warpgroup.wait_group(0) - pipeline_v.consumer_release(smem_pipe_read) - smem_pipe_read.advance() - else: - self.warp_scheduler_barrier_arrive() - - # normalize acc_O by row_sum and calculate the lse - row_scale = softmax.finalize() - softmax.rescale_O(acc_O, row_scale) - # /////////////////////////////////////////////////////////////////////////////// # Epilogue # /////////////////////////////////////////////////////////////////////////////// # reuse sQ's data iterator sO_pi = cute.make_tensor(sQ.iterator, sO_layout) - # TODO: idk why using not using sO_pi is faster + # TODO: idk why not using sO_pi is faster sO = cute.make_tensor(cute.recast_ptr(sO_pi.iterator, sO_layout.inner, dtype=sO_pi.element_type), sO_layout.outer) self.epilogue( acc_O, softmax.row_sum, mO if not self.use_tma_O else mO_tma, mLSE, sO, seqlen, @@ -1502,7 +1354,254 @@ def scoremod_premask_fn(acc_S): ) @cute.jit - def compute_one_n_block( + def load( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + pipeline_k: cutlass.utils.PipelineAsync, + pipeline_v: cutlass.utils.PipelineAsync, + mbar_ptr_Q: cutlass.Pointer, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + ): + warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx() % 4) + m_block, head_idx, batch_idx = cute.arch.block_idx() + seqlen = SeqlenInfoCls(batch_idx) + if cutlass.const_expr(self.is_causal): # Longest tile first + m_block = cute.ceil_div(seqlen.seqlen_q * (self.qhead_per_kvhead if self.pack_gqa else 1), self.m_block_size) - m_block - 1 + if cutlass.const_expr(not seqlen.has_cu_seqlens_q): + mQ_cur = mQ[None, None, head_idx, batch_idx] + else: + mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, None, head_idx]) + head_idx_kv = head_idx // self.qhead_per_kvhead if not self.pack_gqa else head_idx + if cutlass.const_expr(not seqlen.has_cu_seqlens_k): + mK_cur, mV_cur = [t[None, None, head_idx_kv, batch_idx] for t in (mK, mV)] + else: + mK_cur, mV_cur = [cute.domain_offset((seqlen.offset_k, 0), t[None, None, head_idx_kv]) for t in (mK, mV)] + gK = cute.local_tile(mK_cur, (self.n_block_size, self.head_dim_padded), (None, 0)) + gV = cute.local_tile(mV_cur, (self.n_block_size, self.head_dim_v_padded), (None, 0)) + if cutlass.const_expr(not self.pack_gqa): + gQ = cute.local_tile(mQ_cur, (self.m_block_size, self.head_dim_padded), (m_block, 0)) + tQsQ, tQgQ = cpasync.tma_partition( + tma_atom_Q, + 0, + cute.make_layout(1), + cute.group_modes(sQ, 0, 2), + cute.group_modes(gQ, 0, 2), + ) + tKsK, tKgK = cpasync.tma_partition( + tma_atom_K, + 0, + cute.make_layout(1), + cute.group_modes(sK, 0, 2), + cute.group_modes(gK, 0, 2), + ) + tVsV, tVgV = cpasync.tma_partition( + tma_atom_V, + 0, + cute.make_layout(1), + cute.group_modes(sV, 0, 2), + cute.group_modes(gV, 0, 2), + ) + kv_producer_state = pipeline.make_pipeline_state( + cutlass.utils.PipelineUserType.Producer, self.num_stages + ) + load_K = partial(self.load_K, tma_atom_K, tKgK, tKsK, pipeline_k) + load_V = partial(self.load_K, tma_atom_V, tVgV, tVsV, pipeline_v) + if warp_idx_in_wg == 0: + # load_Q + if cutlass.const_expr(not self.pack_gqa): + with cute.arch.elect_one(): + cute.arch.mbarrier_init_tx_bytes(mbar_ptr_Q, self.tma_copy_q_bytes) + cute.copy(tma_atom_Q, tQgQ, tQsQ, tma_bar_ptr=mbar_ptr_Q) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + for i in cutlass.range_dynamic(n_block_max - n_block_min, unroll=2): + n_block = n_block_max - i - 1 + load_K(n_block, producer_state=kv_producer_state) + load_V(n_block, producer_state=kv_producer_state) + kv_producer_state.advance() + + @cute.jit + def mma( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + softmax: Softmax, + acc_O: cute.Tensor, + mQ: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sVt: cute.Tensor, + sP: cute.Tensor | None, + pipeline_k: cutlass.utils.PipelineAsync, + pipeline_v: cutlass.utils.PipelineAsync, + mbar_ptr_Q: cutlass.Pointer, + gmem_tiled_copy_Q: cute.TiledCopy, + tidx: cutlass.Int32, + softcap_val: cutlass.Float32, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + tiled_mma_qk_copy: cute.TiledMma, + tiled_mma_pv_copy: cute.TiledMma, + tiled_mma_qk_copy1: cute.TiledMma, + tiled_mma_pv_copy1: cute.TiledMma, + ): + warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) + warp_group_thread_layout = cute.make_layout( + self.num_mma_warp_groups, stride=self.num_threads_per_warp_group + ) + thr_mma_qk = tiled_mma_qk.get_slice(tidx) + wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx)) + wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx)) + tSrQ = tiled_mma_qk.make_fragment_A(wg_mma_qk.partition_A(sQ)) + tSrK = tiled_mma_qk.make_fragment_B(wg_mma_qk.partition_B(sK)) + tOrP = tiled_mma_pv.make_fragment_A(wg_mma_pv.partition_A(sP)) if cutlass.const_expr(sP is not None) else None + tOrVt = tiled_mma_pv.make_fragment_B(wg_mma_pv.partition_B(sVt)) + + # /////////////////////////////////////////////////////////////////////////////// + # Smem copy atom tiling + # /////////////////////////////////////////////////////////////////////////////// + smem_copy_atom_P = utils.get_smem_store_atom(self.arch, self.dtype) + smem_thr_copy_P = utils.make_tiled_copy_C(smem_copy_atom_P, tiled_mma_qk).get_slice(tidx) + # tPsP = smem_thr_copy_P.partition_D(sP_pi) if cutlass.const_expr(sP_pi is not None) else None + tPsP = smem_thr_copy_P.partition_D(sP) if cutlass.const_expr(sP is not None) else None + # if cute.arch.thread_idx()[0] == 0: + # cute.printf(sP_pi.layout, sP_pi.iterator) + # cute.printf(sP.layout, sP.iterator) + # cute.printf(tPsP.layout, tPsP.iterator) + + self.mma_init() + + # shape: (atom_v_m * rest_m) + # group parameters for mma_one_n_block + mma_params = SimpleNamespace(tSrQ=tSrQ, tSrK=tSrK, tOrP=tOrP, tOrVt=tOrVt, acc_O=acc_O) + smem_copy_params = SimpleNamespace(smem_thr_copy_P=smem_thr_copy_P, tPsP=tPsP) + + # Softcapping needs to happen before masking since if we apply after masking, softcapping can turn + # -inf to e.g. -50.0, which can affect the attention softmax. + def scoremod_premask_fn(acc_S): + if cutlass.const_expr(self.has_softcap): + acc_S.store(cute.math.tanh(acc_S.load() * softcap_val, fastmath=True)) + + mma_one_n_block = partial( + self.mma_one_n_block_intrawg_overlap if cutlass.const_expr(self.intra_wg_overlap) else self.mma_one_n_block, + pipeline_k=pipeline_k, pipeline_v=pipeline_v, + mma_params=mma_params, smem_copy_params=smem_copy_params, + softmax=softmax, scoremod_premask_fn=scoremod_premask_fn, + ) + + m_block, head_idx, batch_idx = cute.arch.block_idx() + seqlen = SeqlenInfoCls(batch_idx) + if cutlass.const_expr(self.is_causal): # Longest tile first + m_block = cute.ceil_div(seqlen.seqlen_q * (self.qhead_per_kvhead if self.pack_gqa else 1), self.m_block_size) - m_block - 1 + + mask = AttentionMask( + self.m_block_size, self.n_block_size, seqlen.seqlen_q, seqlen.seqlen_k, + self.qhead_per_kvhead if self.pack_gqa else 1 + ) + mask_fn = partial( + mask.apply_mask, m_block=m_block, thr_mma=thr_mma_qk, mask_causal=self.is_causal + ) + # Load Q if PackGQA + if cutlass.const_expr(self.pack_gqa): + pack_gqa = PackGQA(self.m_block_size, self.head_dim_padded, self.check_hdim_oob, self.qhead_per_kvhead) + if cutlass.const_expr(not seqlen.has_cu_seqlens_q): + mQ_cur = mQ[None, None, head_idx, batch_idx] + else: + mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, None, head_idx]) + # gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx) + # gQ = cute.local_tile(mQ_cur, (self.m_block_size, self.head_dim_padded), (m_block, 0)) + # self.load_Q(gmem_thr_copy_Q, gQ, sQ, m_block, seqlen=seqlen.seqlen_q, + # headdim=mQ.shape[1]) + pack_gqa.load_Q(mQ_cur, sQ, gmem_tiled_copy_Q, tidx, m_block, seqlen.seqlen_q) + utils.cp_async_mbarrier_arrive_shared(mbar_ptr_Q, noinc=True) + + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + n_block = n_block_max - 1 + consumer_state = pipeline.make_pipeline_state( + cutlass.utils.PipelineUserType.Consumer, self.num_stages + ) + cute.arch.mbarrier_wait(mbar_ptr_Q, phase=0) + softmax.reset() + # For performance reason, we separate out two kinds of iterations: + # those that need masking on S, and those that don't. + # We need masking on S for the very last block when K and V has length not multiple of n_block_size. + # We also need masking on S if it's causal, for the last several blocks. + # First iteration with seqlen masking + if cutlass.const_expr(self.intra_wg_overlap): + acc_S = cute.make_fragment( + tiled_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)), cutlass.Float32 + ) + pipeline_k.consumer_wait(consumer_state) + sm90_utils.gemm( + tiled_mma_qk, acc_S, tSrQ, tSrK[None, None, None, consumer_state.index], + zero_init=True, wg_wait=0 + ) + pipeline_k.consumer_release(consumer_state) + scoremod_premask_fn(acc_S) + mask_fn(acc_S, n_block=n_block, mask_seqlen=True) + softmax.online_softmax(acc_S, is_first=True, check_inf=True) + rP = cute.make_fragment_like(acc_S, self.dtype) + rP.store(acc_S.load().to(self.dtype)) + # tOrP = cute.make_tensor(rP.iterator, utils.convert_layout_acc_frgA(rP.layout)) + tPrP = smem_thr_copy_P.retile(rP) + cute.copy(smem_thr_copy_P, tPrP, tPsP) + # Fence and barrier to make sure smem store is visible to WGMMA + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV + # Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter + acc_O.fill(0.0) + else: + self.warp_scheduler_barrier_sync() + consumer_state = mma_one_n_block( + n_block, consumer_state, tiled_mma_qk, tiled_mma_pv, + is_first_n_block=True, check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=True) + ) + # Next couple of iterations with causal masking + if cutlass.const_expr(self.is_causal): + n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( + seqlen, m_block, n_block_min + ) + # Currently we can't do loop with negative step https://github.com/NVIDIA/cutlass/issues/2326 + for n_tile in cutlass.range_dynamic(n_block_min_causal_local_mask, n_block_max - 1, unroll=1): + n_block = n_block_max - 2 - n_tile + n_block_min_causal_local_mask + consumer_state = mma_one_n_block( + n_block, consumer_state, tiled_mma_qk_copy, tiled_mma_pv_copy, + check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=False) + ) + # The remaining iterations have no masking + for n_tile in cutlass.range_dynamic(n_block, unroll=1): + consumer_state = mma_one_n_block( + n_block - n_tile - 1, consumer_state, tiled_mma_qk_copy1, tiled_mma_pv_copy1, + check_inf=True, + ) + # Last "half" iteration + if cutlass.const_expr(self.intra_wg_overlap): + pipeline_v.consumer_wait(consumer_state, pipeline_v.consumer_try_wait(consumer_state)) + sm90_utils.gemm( + tiled_mma_pv, mma_params.acc_O, mma_params.tOrP, + mma_params.tOrVt[None, None, None, consumer_state.index], + zero_init=False, wg_wait=-1 + ) + warpgroup.wait_group(0) + pipeline_v.consumer_release(consumer_state) + consumer_state.advance() + else: + self.warp_scheduler_barrier_arrive() + + # normalize acc_O by row_sum and calculate the lse + row_scale = softmax.finalize() + softmax.rescale_O(acc_O, row_scale) + + @cute.jit + def mma_one_n_block( self, n_block: cutlass.Int32, smem_pipe_read: cutlass.utils.PipelineState | pipeline.PipelineStateSimple, @@ -1557,7 +1656,7 @@ def compute_one_n_block( return smem_pipe_read @cute.jit - def compute_one_n_block_intrawg_overlap( + def mma_one_n_block_intrawg_overlap( self, n_block: cutlass.Int32, smem_pipe_read: cutlass.utils.PipelineState | pipeline.PipelineStateSimple, @@ -1647,14 +1746,14 @@ def load_K( tKsK: cute.Tensor, pipeline: cutlass.utils.PipelineAsync, block: cutlass.Int32, - smem_pipe_write: cutlass.utils.PipelineState | pipeline.PipelineStateSimple, + producer_state: cutlass.utils.PipelineState | pipeline.PipelineStateSimple, ): # TODO: mcast # TODO check warp_idx if we have 128 producer threads - pipeline.producer_acquire(smem_pipe_write) + pipeline.producer_acquire(producer_state) cute.copy( tma_atom, tKgK[None, block], - tKsK[None, smem_pipe_write.index], - tma_bar_ptr=pipeline.producer_get_barrier(smem_pipe_write) + tKsK[None, producer_state.index], + tma_bar_ptr=pipeline.producer_get_barrier(producer_state) ) diff --git a/flash_attn/cute/seqlen_info.py b/flash_attn/cute/seqlen_info.py index d14bfb827f9..6316e5ee814 100644 --- a/flash_attn/cute/seqlen_info.py +++ b/flash_attn/cute/seqlen_info.py @@ -26,3 +26,5 @@ def __init__( self.seqlen_k = mSeqUsedK[batch_idx] else: self.seqlen_k = seqlen_k_static if cutlass.const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx + 1] - self.offset_k + self.has_cu_seqlens_q: int = mCuSeqlensQ is not None + self.has_cu_seqlens_k: int = mCuSeqlensK is not None diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index 2273718aed8..68f577f8d27 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -124,8 +124,9 @@ def rescale_O(self, acc_O: cute.Tensor, row_scale: cute.Tensor) -> None: class SoftmaxSm100(Softmax): - def __init__(self, scale_log2: Float32): + def __init__(self, scale_log2: Float32, rescale_threshold: cutlass.Constexpr[float] = 0.0): super().__init__(scale_log2, num_rows=1, arch=100) + self.rescale_threshold = rescale_threshold @cute.jit def update_row_max(self, acc_S_row: cute.TensorSSA) -> Tuple[Float32, Float32]: @@ -134,12 +135,52 @@ def update_row_max(self, acc_S_row: cute.TensorSSA) -> Tuple[Float32, Float32]: row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2 acc_scale = utils.exp2f(acc_scale_) - if acc_scale_ >= -8.0: - row_max_new = row_max_old - row_max_safe = row_max_old - acc_scale = 1.0 + if cutlass.const_expr(self.rescale_threshold > 0.0): + if acc_scale_ >= -self.rescale_threshold: + row_max_new = row_max_old + row_max_safe = row_max_old + acc_scale = 1.0 self.row_max[0] = row_max_new return row_max_safe, acc_scale def update_row_sum(self, acc_S_row_exp: cute.TensorSSA, row_scale: Float32) -> None: self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=self.row_sum[0] * row_scale) + + def scale_apply_exp2_convert( + self, + acc_S_row: cute.Tensor, + row_max: Float32, + acc_S_row_converted: cute.Tensor, + ): + minus_row_max_scaled = -row_max * self.scale_log2 + # assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" + # for i in range(0, cute.size(acc_S_row.shape), 2): + # acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( + # (acc_S_row[i], acc_S_row[i + 1]), + # (self.scale_log2, self.scale_log2), + # (minus_row_max_scaled, minus_row_max_scaled), + # ) + # acc_S_row[i] = cute.arch.exp2(acc_S_row[i]) + # acc_S_row[i + 1] = cute.arch.exp2(acc_S_row[i + 1]) + + frg_cnt = 4 + frg_tile = cute.size(acc_S_row) // frg_cnt + assert cute.size(acc_S_row) % (frg_cnt * 2) == 0 + acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile)) + acc_S_row_converted_frg = cute.logical_divide(acc_S_row_converted, cute.make_layout(frg_tile)) + for j in range(frg_cnt): + for k in range(0, cute.size(acc_S_row_frg, mode=[0]), 2): + acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = ( + cute.arch.fma_packed_f32x2( + (acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]), + (self.scale_log2, self.scale_log2), + (minus_row_max_scaled, minus_row_max_scaled), + ) + ) + # acc_S_row_frg[k, j] = fa_utils.exp2f(acc_S_row_frg[k, j]) + # acc_S_row_frg[k + 1, j] = fa_utils.exp2f(acc_S_row_frg[k + 1, j]) + acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j]) + acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j]) + acc_S_row_converted_frg[None, j].store( + acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type) + ) diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 4b0b0ce2e47..6ea68c05677 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -210,16 +210,15 @@ def log2f(a: float | Float32, *, loc=None, ip=None) -> Float32: @dsl_user_op -def max3f(a: float | Float32, b: float | Float32, c: float | Float32, *, loc=None, ip=None) -> Float32: +def fmax(a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None) -> Float32: return Float32( - llvm.inline_asm( + nvvm.fmax( T.f32(), - [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip), Float32(c).ir_value(loc=loc, ip=ip)], - "max.f32 $0, $1, $2, $3;", - "=f,f,f,f", - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None, + loc=loc, + ip=ip, ) ) @@ -233,51 +232,22 @@ def fmax_reduce( return x.reduce(cute.ReductionOp.MAX, init_val, 0) else: # [2025-06-15] x.reduce only seems to use 50% 3-input max and 50% 2-input max - # We instead force the 3-input max by calling inline ptx. + # We instead force the 3-input max. res = cute.make_fragment(x.shape, Float32) res.store(x) - local_max = [init_val, -Float32.inf, -Float32.inf, -Float32.inf] - for i in range(0, cute.size(x.shape), 8): - local_max[0] = max3f(local_max[0], res[i], res[i + 1]) - local_max[1] = max3f(local_max[1], res[i + 2], res[i + 3]) - local_max[2] = max3f(local_max[2], res[i + 4], res[i + 5]) - local_max[3] = max3f(local_max[3], res[i + 6], res[i + 7]) - local_max[0] = cute.arch.fmax(local_max[0], local_max[1]) - local_max[2] = cute.arch.fmax(local_max[2], local_max[3]) - return cute.arch.fmax(local_max[0], local_max[2]) - - # local_max = [cute.arch.fmax(res[0], res[1]), cute.arch.fmax(res[2], res[3]), - # cute.arch.fmax(res[4], res[5]), cute.arch.fmax(res[6], res[7])] - # for i in range(8, cute.size(x.shape), 8): - # local_max[0] = max3f(local_max[0], res[i], res[i + 1]) - # local_max[1] = max3f(local_max[1], res[i + 2], res[i + 3]) - # local_max[2] = max3f(local_max[2], res[i + 4], res[i + 5]) - # local_max[3] = max3f(local_max[3], res[i + 6], res[i + 7]) - # local_max[0] = max3f(init_val, local_max[0], local_max[1]) - # local_max[2] = cute.arch.fmax(local_max[2], local_max[3]) - # return cute.arch.fmax(local_max[0], local_max[2]) - - # local_max = [res[0], res[1], res[2], res[3]] - # for i in range(4, cute.size(x.shape), 8): - # local_max[0] = max3f(local_max[0], res[i], res[i + 1]) - # local_max[1] = max3f(local_max[1], res[i + 2], res[i + 3]) - # local_max[2] = max3f(local_max[2], res[i + 4], res[i + 5]) - # local_max[3] = max3f(local_max[3], res[i + 6], res[i + 7]) - # i_f = cutlass.const_expr(cute.size(x.shape) - 4) - # # local_max[0] = max3f(local_max[0], res[i_f], res[i_f + 1]) - # # local_max[1] = max3f(local_max[1], res[i_f + 2], res[i_f + 3]) - # # local_max[0] = max3f(local_max[0], local_max[1], local_max[2]) - # # return max3f(local_max[0], local_max[3], init_val) - # local_max[0] = cute.arch.fmax(local_max[0], res[i_f]) - # local_max[1] = cute.arch.fmax(local_max[1], res[i_f + 1]) - # local_max[2] = cute.arch.fmax(local_max[2], res[i_f + 2]) - # local_max[3] = cute.arch.fmax(local_max[3], res[i_f + 3]) - # local_max[0] = max3f(local_max[0], local_max[1], init_val) - # local_max[2] = cute.arch.fmax(local_max[2], local_max[3]) - # return cute.arch.fmax(local_max[0], local_max[2]) - - # local_max[0] = max3f(local_max[0], local_max[1], local_max[2]) - # return cute.arch.fmax(local_max[0], local_max[3]) + local_max = [ + fmax(init_val, res[0], res[1]), + fmax(res[2], res[3]), + fmax(res[4], res[5]), + fmax(res[6], res[7]), + ] + for i in range(8, cute.size(x.shape), 8): + local_max[0] = fmax(local_max[0], res[i], res[i + 1]) + local_max[1] = fmax(local_max[1], res[i + 2], res[i + 3]) + local_max[2] = fmax(local_max[2], res[i + 4], res[i + 5]) + local_max[3] = fmax(local_max[3], res[i + 6], res[i + 7]) + local_max[0] = fmax(local_max[0], local_max[1]) + return fmax(local_max[0], local_max[2], local_max[3]) def fadd_reduce( From a5e1a3c5fccd8dc219300f4d4bb502e17f7fd4db Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 29 Jun 2025 14:47:21 -0400 Subject: [PATCH 005/258] [Cute] Add first version of flash_fwd_sm100 --- flash_attn/cute/blackwell_helpers.py | 578 +++++++++ flash_attn/cute/flash_fwd.py | 10 +- flash_attn/cute/flash_fwd_sm100.py | 1747 ++++++++++++++++++++++++++ flash_attn/cute/interface.py | 47 +- flash_attn/cute/mask.py | 38 + flash_attn/cute/mma_sm100_desc.py | 285 +++++ flash_attn/cute/softmax.py | 40 +- flash_attn/cute/utils.py | 49 +- flash_attn/utils/testing.py | 2 +- tests/cute/test_flash_attn.py | 9 +- 10 files changed, 2759 insertions(+), 46 deletions(-) create mode 100644 flash_attn/cute/blackwell_helpers.py create mode 100644 flash_attn/cute/flash_fwd_sm100.py create mode 100644 flash_attn/cute/mma_sm100_desc.py diff --git a/flash_attn/cute/blackwell_helpers.py b/flash_attn/cute/blackwell_helpers.py new file mode 100644 index 00000000000..9a83f4a9998 --- /dev/null +++ b/flash_attn/cute/blackwell_helpers.py @@ -0,0 +1,578 @@ +# Copyright (c) 2025, Tri Dao. +from typing import Optional +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import tcgen05 +from cutlass.cutlass_dsl import T +from cutlass._mlir.dialects import llvm + +import flash_attn.cute.mma_sm100_desc as sm100_desc + + +@cute.jit +def gemm( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + zero_init: bool | cutlass.Boolean = False, +) -> None: + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): + tiled_mma.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0) + cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc) + + +def i64_to_i32x2(i: int) -> tuple[int, int]: + """Convert a 64-bit integer to a tuple of two 32-bit integers.""" + return i & 0xFFFF_FFFF, (i >> 32) & 0xFFFF_FFFF + + +@cute.jit +def gemm_ptx( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + sA_swizzle: Optional[cute.Swizzle], + sB_swizzle: cute.Swizzle, + zero_init: bool | cutlass.Boolean = False, +) -> None: + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if not is_ts: + assert sA is not None, "sA must be provided when a_src is not TMEM" + assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM" + sA_layout = sA.layout if sA is not None else None + sB_layout = sB.layout + idesc: int = cutlass.const_expr(sm100_desc.mma_op_to_idesc(op)) + if cutlass.const_expr(not is_ts): + smem_desc_base_a: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K if op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else sm100_desc.Major.MN + )) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = cutlass.const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = cutlass.const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + smem_desc_base_b: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K if op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else sm100_desc.Major.MN + )) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = cutlass.const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = cutlass.const_expr(smem_desc_b_hi) + + if cutlass.const_expr(not is_ts): + smem_desc_start_a_lo = cutlass.Int32(smem_desc_base_a_lo) | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator) + else: + smem_desc_start_a_lo = None + smem_desc_start_b_lo = cutlass.Int32(smem_desc_base_b_lo) | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator) + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): + if cutlass.const_expr(not is_ts): + smem_desc_a_lo = smem_desc_start_a_lo + ((cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4) + smem_desc_b_lo = smem_desc_start_b_lo + ((cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4) + # with cute.arch.elect_one(): + # cute.printf("smem_desc_a_lo = {}, smem_desc_b_lo = {}", smem_desc_a_lo, smem_desc_b_lo) + # cute.printf("smem_desc_a_lo_correct = {}, smem_desc_b_lo_correct = {}", smem_desc_a_lo_correct, smem_desc_b_lo_correct) + with cute.arch.elect_one(): + if cutlass.const_expr(not is_ts): + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + smem_desc_a_lo.ir_value(), + smem_desc_b_lo.ir_value(), + cutlass.Int32(not zero_init or k != 0).ir_value(), + ], + "{\n\t" + ".reg .pred p;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + ".reg .b32 idesc;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b64 smem_desc_a, {{$1, {hex(smem_desc_a_hi)}}};\n\t" + f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, p;\n\t" + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + tCrA[None, None, k].iterator.toint().ir_value(), + smem_desc_b_lo.ir_value(), + cutlass.Int32(not zero_init or k != 0).ir_value(), + ], + "{\n\t" + ".reg .pred p;\n\t" + ".reg .b64 smem_desc_b;\n\t" + f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"tcgen05.mma.cta_group::1.kind::f16 [$0], [$1], smem_desc_b, {hex(idesc)}, p;\n\t" + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + +@cute.jit +def gemm_ptx_loop( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + sA_swizzle: Optional[cute.Swizzle], + sB_swizzle: cute.Swizzle, + zero_init: bool | cutlass.Boolean = False, +) -> None: + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if not is_ts: + assert sA is not None, "sA must be provided when a_src is not TMEM" + assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM" + sA_layout = sA.layout if sA is not None else tCrA.layout + sB_layout = sB.layout + idesc: int = cutlass.const_expr(sm100_desc.mma_op_to_idesc(op)) + if cutlass.const_expr(not is_ts): + smem_desc_base_a: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K if op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else sm100_desc.Major.MN + )) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = cutlass.const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = cutlass.const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + smem_desc_base_b: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K if op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else sm100_desc.Major.MN + )) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = cutlass.const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = cutlass.const_expr(smem_desc_b_hi) + + if cutlass.const_expr(not is_ts): + offset_a = [(cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4 + for k in range(cute.size(tCrA.shape[2]))] + else: + offset_a = [cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32 + for k in range(cute.size(tCrA.shape[2]))] + offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))] + offset_b = [(cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4 + for k in range(cute.size(tCrB.shape[2]))] + offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))] + + if cutlass.const_expr(not is_ts): + smem_desc_start_a_lo = cutlass.Int32(smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)) + else: + smem_desc_start_a_lo = None + smem_desc_start_b_lo = cutlass.Int32(smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)) + pred_str = "p" if isinstance(zero_init, cutlass.Boolean) else "0" if zero_init else "1" + if cutlass.const_expr(not is_ts): + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + cutlass.Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + cutlass.Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + cutlass.Int32(not zero_init).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + "mov.b32 smem_desc_a_lo, $1;\n\t" + "mov.b32 smem_desc_b_lo, $2;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + ) + for k in range(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + cutlass.Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + cutlass.Int32(smem_desc_start_b_lo).ir_value(), + cutlass.Int32(not zero_init).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + "mov.b32 tmem_a, $1;\n\t" + "mov.b32 smem_desc_b_lo, $2;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in range(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_partial( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc_tmem_addr: cutlass.Constexpr[int], + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + sA_swizzle: Optional[cute.Swizzle], + sB_swizzle: cute.Swizzle, + zero_init: bool | cutlass.Boolean = False, +) -> None: + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if not is_ts: + assert sA is not None, "sA must be provided when a_src is not TMEM" + assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM" + sA_layout = sA.layout if sA is not None else tCrA.layout + sB_layout = sB.layout + idesc: int = cutlass.const_expr(sm100_desc.mma_op_to_idesc(op)) + if cutlass.const_expr(not is_ts): + smem_desc_base_a: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K if op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else sm100_desc.Major.MN + )) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = cutlass.const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = cutlass.const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + smem_desc_base_b: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K if op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else sm100_desc.Major.MN + )) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = cutlass.const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = cutlass.const_expr(smem_desc_b_hi) + + if cutlass.const_expr(not is_ts): + offset_a = [(cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 8) >> 4 + for k in range(cute.size(tCrA.shape[2]))] + else: + offset_a = [cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32 + for k in range(cute.size(tCrA.shape[2]))] + offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))] + offset_b = [(cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4 + for k in range(cute.size(tCrB.shape[2]))] + offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))] + + if cutlass.const_expr(not is_ts): + smem_desc_start_a_lo = cutlass.Int32(smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)) + else: + smem_desc_start_a_lo = None + smem_desc_start_b_lo = cutlass.Int32(smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)) + pred_str = "p" if isinstance(zero_init, cutlass.Boolean) else "0" if zero_init else "1" + if cutlass.const_expr(not is_ts): + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + cutlass.Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + cutlass.Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + cutlass.Int32(not zero_init).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + "mov.b32 smem_desc_a_lo, $0;\n\t" + "mov.b32 smem_desc_b_lo, $1;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $2, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + ) + for k in range(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + cutlass.Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + cutlass.Int32(smem_desc_start_b_lo).ir_value(), + cutlass.Int32(not zero_init).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_a, $0;\n\t" + f"mov.b32 smem_desc_b_lo, $1;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $2, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in range(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + +@cute.jit +def gemm_ptx_partial1( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc_tmem_addr: cutlass.Constexpr[int], + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA_base_addr_for_desc: cutlass.Int32, + sA_addr_offset_for_desc: cutlass.Constexpr[int], + sA_stage: cutlass.Int32, + sB_base_addr_for_desc: cutlass.Int32, + sB_addr_offset_for_desc: cutlass.Constexpr[int], + sB_stage: cutlass.Int32, + sA_layout: Optional[cute.Layout], + sB_layout: Optional[cute.Layout], + sA_swizzle: Optional[cute.Swizzle], + sB_swizzle: cute.Swizzle, + zero_init: bool | cutlass.Boolean = False, +) -> None: + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if not is_ts: + assert sA_layout is not None, "sA_layout must be provided when a_src is not TMEM" + assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM" + idesc: int = cutlass.const_expr(sm100_desc.mma_op_to_idesc(op)) + if cutlass.const_expr(not is_ts): + smem_desc_base_a: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K if op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else sm100_desc.Major.MN + )) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = cutlass.const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = cutlass.const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + smem_desc_base_b: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K if op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else sm100_desc.Major.MN + )) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = cutlass.const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = cutlass.const_expr(smem_desc_b_hi) + mask = [cutlass.Int32(0)] * 4 + + if cutlass.const_expr(not is_ts): + offset_a = [(cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 8) >> 4 + for k in range(cute.size(tCrA.shape[2]))] + else: + offset_a = [cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32 + for k in range(cute.size(tCrA.shape[2]))] + offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))] + offset_b = [(cute.crd2idx((0, 0, k), sB_layout) * op.b_dtype.width // 8) >> 4 + for k in range(cute.size(tCrB.shape[2]))] + offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))] + + if cutlass.const_expr(not is_ts): + # smem_desc_start_a_lo = cutlass.Int32(smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)) + smem_desc_start_a_lo = cutlass.const_expr(smem_desc_base_a_lo) + else: + smem_desc_start_a_lo = None + # smem_desc_start_b_lo = cutlass.Int32(smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)) + smem_desc_start_b_lo = cutlass.const_expr(smem_desc_base_b_lo) + pred_str = "p" if isinstance(zero_init, cutlass.Boolean) else "0" if zero_init else "1" + if cutlass.const_expr(not is_ts): + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + # cutlass.Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + cutlass.Int32(sA_base_addr_for_desc).ir_value(), + cutlass.Int32(sA_stage).ir_value(), + # cutlass.Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + cutlass.Int32(sB_base_addr_for_desc).ir_value(), + cutlass.Int32(sB_stage).ir_value(), + cutlass.Int32(not zero_init).ir_value(), + mask[0].ir_value(), + mask[1].ir_value(), + mask[2].ir_value(), + mask[3].ir_value() + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + # "mov.b32 smem_desc_a_lo, $0;\n\t" + # f"add.u32 smem_desc_a_lo, $0, {hex(smem_desc_start_a_lo)};\n\t" + f"mad.lo.u32 smem_desc_a_lo, $1, {hex(sA_addr_offset_for_desc)}, $0;\n\t" + # "mov.b32 smem_desc_b_lo, $2;\n\t" + f"mad.lo.u32 smem_desc_b_lo, $3, {hex(sB_addr_offset_for_desc)}, $2;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $4, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, {pred_str};\n\t" + + "".join( + ( + f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, 1;\n\t" + ) + for k in range(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r,r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + cutlass.Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + cutlass.Int32(smem_desc_start_b_lo).ir_value(), + cutlass.Int32(not zero_init).ir_value(), + mask[0].ir_value(), + mask[1].ir_value(), + mask[2].ir_value(), + mask[3].ir_value() + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b32 tmem_a, $1;\n\t" + f"mov.b32 smem_desc_b_lo, $2;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, {pred_str};\n\t" + + "".join( + ( + f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, 1;\n\t" + ) + for k in range(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 46c36aa7027..2b4372f1811 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -829,8 +829,8 @@ def preprocess_Q(): ) # Currently we can't do loop with negative step # https://github.com/NVIDIA/cutlass/issues/2326 - for n_tile in cutlass.range_dynamic(n_block_min_causal_local_mask, n_block_max - 1, unroll=1): - n_block = n_block_max - 2 - n_tile + n_block_min_causal_local_mask + for n_tile in cutlass.range_dynamic(n_block_max - 1 - n_block_min_causal_local_mask, unroll=1): + n_block = n_block_max - 2 - n_tile compute_one_n_block(n_block, smem_pipe_read, smem_pipe_write, check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=False)) smem_pipe_read = self.advance_pipeline(smem_pipe_read) @@ -1371,7 +1371,7 @@ def load( block_info: BlockInfo, SeqlenInfoCls: Callable, ): - warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx() % 4) + warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 m_block, head_idx, batch_idx = cute.arch.block_idx() seqlen = SeqlenInfoCls(batch_idx) if cutlass.const_expr(self.is_causal): # Longest tile first @@ -1570,8 +1570,8 @@ def scoremod_premask_fn(acc_S): seqlen, m_block, n_block_min ) # Currently we can't do loop with negative step https://github.com/NVIDIA/cutlass/issues/2326 - for n_tile in cutlass.range_dynamic(n_block_min_causal_local_mask, n_block_max - 1, unroll=1): - n_block = n_block_max - 2 - n_tile + n_block_min_causal_local_mask + for n_tile in cutlass.range_dynamic(n_block_max - 1 - n_block_min_causal_local_mask, unroll=1): + n_block = n_block_max - 2 - n_tile consumer_state = mma_one_n_block( n_block, consumer_state, tiled_mma_qk_copy, tiled_mma_pv_copy, check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=False) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py new file mode 100644 index 00000000000..e2310b4d9f0 --- /dev/null +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -0,0 +1,1747 @@ +# Supported features, currently only tested for hdim 128. +# - BF16 & FP16 dtype +# - noncausal & causal attention +# - MHA, GQA, MQA +# Unsupported features that will be added later: +# - varlen +# - writing out lse +# - split-kv (optimizing for inference) +# - testing more hdim (64, 256, etc) +# Based on the cutlass example and cute-dsl example: +# https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha +# https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/fmha.py + +import enum +import math +from typing import Type, Tuple, Callable, Optional +from functools import partial + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import cutlass.cute.nvgpu.tcgen05 as tcgen05 +import cutlass.utils.blackwell_helpers as sm100_utils_basic + +import flash_attn.cute.utils as utils +# import flash_attn.cute.pipeline as pipeline +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.softmax import SoftmaxSm100 +from flash_attn.cute.seqlen_info import SeqlenInfo +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute import mma_sm100_desc as sm100_desc +from flash_attn.cute import blackwell_helpers as sm100_utils + + +# class NamedBarrierFwd(enum.IntEnum): +# Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads() +# WarpSchedulerWG1 = enum.auto() +# WarpSchedulerWG2 = enum.auto() +# WarpSchedulerWG3 = enum.auto() +# PFull = enum.auto() +# PEmpty = enum.auto() + +class FmhaStaticTileSchedulerParams: + def __init__( + self, + is_persistent: bool, + problem_shape_mbh: cute.Shape, + *, + loc=None, + ip=None, + ): + self.is_persistent = is_persistent + self.problem_shape_mbh = problem_shape_mbh + self._loc = loc + self._ip = ip + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.is_persistent, self.problem_shape_mbh]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [self.is_persistent, self.problem_shape_mbh], self._values_pos + ): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return FmhaStaticTileSchedulerParams(*(tuple(obj_list)), loc=self._loc) + + +def create_fmha_static_tile_scheduler_params( + is_persistent: bool, + problem_shape_mbh: cute.Shape, +) -> FmhaStaticTileSchedulerParams: + return FmhaStaticTileSchedulerParams(is_persistent, problem_shape_mbh) + + +class FmhaStaticTileScheduler: + + def __init__( + self, + params: FmhaStaticTileSchedulerParams, + current_work_linear_idx: cutlass.Int32, + blk_coord: cute.Coord, + grid_shape: cute.Shape, + *, + loc=None, + ip=None, + ): + self._params = params + self._blk_coord = blk_coord + self._grid_shape = grid_shape + self._is_persistent = params.is_persistent + self._current_work_linear_idx = current_work_linear_idx + self._problem_shape_mbh = cute.make_layout( + params.problem_shape_mbh, loc=loc, ip=ip + ) + self._num_blocks = cute.size(self._problem_shape_mbh, loc=loc, ip=ip) + self._is_first_block = True + self.num_persistent_sm = cute.size(grid_shape, loc=loc, ip=ip) + self._loc = loc + self._ip = ip + + # called by host + @staticmethod + def get_grid_shape( + params: FmhaStaticTileSchedulerParams, + *, + loc=None, + ip=None, + ) -> cute.Shape: + if params.is_persistent: + hardware_info = cutlass.utils.HardwareInfo() + sm_count = hardware_info.get_device_multiprocessor_count() + return ( + cutlass.min( + sm_count, cute.size(params.problem_shape_mbh, loc=loc, ip=ip) + ), + 1, + 1, + ) + else: + return params.problem_shape_mbh + + def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + is_valid = ( + self._current_work_linear_idx < self._num_blocks + if self._is_persistent + else self._is_first_block + ) + + blk_coord = (0, 0, 0) + if self._is_persistent: + blk_coord = self._problem_shape_mbh.get_hier_coord( + self._current_work_linear_idx, loc=loc, ip=ip + ) + else: + blk_coord = self._blk_coord + + return cutlass.utils.WorkTileInfo(blk_coord, is_valid) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def advance_to_next_work(self, *, advance_count=1, loc=None, ip=None): + if self._is_persistent: + self._current_work_linear_idx += advance_count * self.num_persistent_sm + self._is_first_block = False + + def __extract_mlir_values__(self): + values = cutlass.extract_mlir_values(self._params) + values.extend(cutlass.extract_mlir_values(self._current_work_linear_idx)) + values.extend(cutlass.extract_mlir_values(self._blk_coord)) + values.extend(cutlass.extract_mlir_values(self._grid_shape)) + return values + + def __new_from_mlir_values__(self, values): + assert len(values) == 10 + new_params = cutlass.new_from_mlir_values(self._params, values[0:3]) + new_current_work_linear_idx = cutlass.new_from_mlir_values( + self._current_work_linear_idx, [values[3]] + ) + new_blk_coord = cutlass.new_from_mlir_values(self._blk_coord, values[4:7]) + new_grid_shape = cutlass.new_from_mlir_values(self._grid_shape, values[7:]) + return FmhaStaticTileScheduler( + new_params, new_current_work_linear_idx, new_blk_coord, new_grid_shape + ) + + +def create_fmha_static_tile_scheduler( + params: FmhaStaticTileSchedulerParams, + blk_coord: cute.Coord, + grid_shape: cute.Shape, +) -> FmhaStaticTileScheduler: + return FmhaStaticTileScheduler(params, blk_coord[0], blk_coord, grid_shape) + + +class FlashAttentionForwardSm100: + def __init__( + self, + qk_acc_dtype: Type[cutlass.Numeric], + pv_acc_dtype: Type[cutlass.Numeric], + mma_tiler: Tuple[int, int, int], + is_causal: bool, + qhead_per_kvhead: cutlass.Constexpr[int] = 1, + is_persistent: bool = True, + ): + self.qk_acc_dtype = qk_acc_dtype + self.pv_acc_dtype = pv_acc_dtype + # 2 Q tile per CTA + self.cta_tiler = (2 * mma_tiler[0], mma_tiler[1], mma_tiler[2]) + self.mma_tiler_qk = mma_tiler + self.pv_mma_tiler = (mma_tiler[0], mma_tiler[2], mma_tiler[1]) + self.cluster_shape_mn = (1, 1) + self.is_persistent = is_persistent + self.is_even_N = False + self.is_causal = is_causal + self.qhead_per_kvhead = qhead_per_kvhead + self.pack_gqa = False + self.s0_s1_barrier = False # Does S1 need to wait for S0 to finish + + self.softmax0_warp_ids = (0, 1, 2, 3) + self.softmax1_warp_ids = (4, 5, 6, 7) + self.correction_warp_ids = (8, 9, 10, 11) + self.mma_warp_id = 12 + self.load_warp_id = 13 + self.epilogue_warp_id = 14 + self.empty_warp_id = 15 + SM100_TMEM_CAPACITY_COLUMNS = 512 + self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + + self.threads_per_cta = cute.arch.WARP_SIZE * len( + ( + *self.softmax0_warp_ids, + *self.softmax1_warp_ids, + *self.correction_warp_ids, + self.mma_warp_id, + self.load_warp_id, + self.epilogue_warp_id, + self.empty_warp_id, + ) + ) + + self.tmem_alloc_sync_bar_id = 1 + + self.tmem_s0_offset = 0 + self.tmem_s1_offset = 128 + self.tmem_o0_offset = 256 + self.tmem_o1_offset = 384 + self.tmem_p0_offset = 32 + self.tmem_p1_offset = 160 + self.tmem_p_offset = 32 + # self.tmem_p0_offset = 0 + # self.tmem_p1_offset = 128 + + # vec buffer for row_max & row_sum + self.tmem_vec0_offset = 0 + self.tmem_vec1_offset = 128 + + # self.num_regs_softmax = 192 + # self.num_regs_softmax = 184 + self.num_regs_softmax = 176 + # self.num_regs_correction = 104 + # self.num_regs_correction = 96 + self.num_regs_correction = 80 + # self.num_regs_correction = 64 + # self.num_regs_other = 24 + # self.num_regs_other = 32 + # self.num_regs_other = 64 + self.num_regs_other = 80 + # self.num_regs_other = 96 + # self.num_regs_other = 48 + + self.buffer_align_bytes = 1024 + + def _setup_attributes(self): + """Set up configurations and parameters for the FMHA kernel operation. + + This method initializes and configures various attributes required for the + execution of the fused multi-head attention kernel, mainly about the pipeline stages: + + - Sets up staging parameters for Q, K, V inputs and accumulator data + - Configures pipeline stages for softmax, correction, and epilogue operations + """ + + self.q_stage = 2 + self.kv_stage = 4 if self.q_dtype.width == 8 else 3 + self.acc_stage = 1 + self.epi_stage = 2 + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + mCuSeqlensQ: Optional[cute.Tensor], + mCuSeqlensK: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + mSeqUsedK: Optional[cute.Tensor], + max_seqlen_q: Optional[cutlass.Int32], + softmax_scale: cutlass.Float32, + softcap: cutlass.Float32, + stream: cuda.CUstream, + ): + """Execute the Fused Multi-Head Attention operation on the provided tensors. + + This method prepares the input tensors for processing, validates their shapes and types, + configures the computation parameters, and launches the CUDA kernel. + + The method handles: + 1. Tensor layout transformations for specific memory access patterns + 2. Validation of tensor shapes and data types + 3. Initialization of hardware-specific parameters and memory layouts + 4. Configuration of TMA (Tensor Memory Access) operations + 5. Grid and work scheduling computation + 6. Kernel launch with appropriate parameters + """ + + # setup static attributes before smem/grid/tma computation + self.q_dtype = mQ.element_type + self.k_dtype = mK.element_type + self.v_dtype = mV.element_type + self.o_dtype = mO.element_type + QO_layout_transpose = [1, 3, 2, 0] if cutlass.const_expr(mCuSeqlensQ is None) else [0, 2, 1] + mQ, mO = [ + cute.make_tensor(t.iterator, cute.select(t.layout, mode=QO_layout_transpose)) + for t in (mQ, mO) + ] + KV_layout_transpose = [1, 3, 2, 0] if cutlass.const_expr(mCuSeqlensK is None) else [0, 2, 1] + mK, mV = [ + cute.make_tensor(t.iterator, cute.select(t.layout, mode=KV_layout_transpose)) + for t in (mK, mV) + ] + + # (s, d, h, b) -> (s, d, (h, b)) + mQ, mK, mV, mO = [cute.group_modes(t, begin=2, end=4) for t in (mQ, mK, mV, mO)] + mV = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=[1, 0, 2])) + + self.q_major_mode = cutlass.utils.LayoutEnum.from_tensor(mQ).mma_major_mode() + self.k_major_mode = cutlass.utils.LayoutEnum.from_tensor(mK).mma_major_mode() + self.v_major_mode = cutlass.utils.LayoutEnum.from_tensor(mV).mma_major_mode() + self.o_layout = cutlass.utils.LayoutEnum.from_tensor(mO) + + if cutlass.const_expr(self.q_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of mQ is not supported") + if cutlass.const_expr(self.k_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of mK is not supported") + if cutlass.const_expr(self.v_major_mode != tcgen05.OperandMajorMode.MN): + raise RuntimeError("The layout of mV is not supported") + + # check type consistency + if cutlass.const_expr(self.q_dtype != self.k_dtype): + raise TypeError(f"Type mismatch: {self.q_dtype} != {self.k_dtype}") + if cutlass.const_expr(self.q_dtype != self.v_dtype): + raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}") + self._setup_attributes() + + cta_group = tcgen05.CtaGroup.ONE + # the intermediate tensor p is from tmem & mK-major + p_source = tcgen05.OperandSource.TMEM + p_major_mode = tcgen05.OperandMajorMode.K + tiled_mma_qk = sm100_utils_basic.make_trivial_tiled_mma( + self.q_dtype, + self.q_major_mode, + self.k_major_mode, + self.qk_acc_dtype, + cta_group, + self.mma_tiler_qk[:2], + ) + tiled_mma_pv = sm100_utils_basic.make_trivial_tiled_mma( + self.v_dtype, + p_major_mode, + self.v_major_mode, + self.pv_acc_dtype, + cta_group, + self.pv_mma_tiler[:2], + p_source, + ) + + self.cluster_shape_mnk = (*self.cluster_shape_mn, 1) + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (tiled_mma_qk.thr_id.shape,), + ) + + self.epi_tile = self.pv_mma_tiler[:2] + + q_smem_layout_staged = sm100_utils_basic.make_smem_layout_a( + tiled_mma_qk, self.mma_tiler_qk, self.q_dtype, self.q_stage, + ) + k_smem_layout_staged = sm100_utils_basic.make_smem_layout_b( + tiled_mma_qk, self.mma_tiler_qk, self.k_dtype, self.kv_stage, + ) + p_tmem_layout_staged = sm100_utils_basic.make_smem_layout_a( + tiled_mma_pv, self.pv_mma_tiler, self.q_dtype, self.acc_stage, + ) + v_smem_layout_staged = sm100_utils_basic.make_smem_layout_b( + tiled_mma_pv, self.pv_mma_tiler, self.v_dtype, self.kv_stage, + ) + o_smem_layout_staged = sm100_utils_basic.make_smem_layout_epi( + self.o_dtype, self.o_layout, self.epi_tile, self.epi_stage, + ) + + # TMA load for Q + tma_load_op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp(cta_group) + tma_store_op = cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp() + + q_smem_layout = cute.select(q_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_Q, tma_tensor_q = cute.nvgpu.make_tma_tile_atom_A( + tma_load_op, + mQ, + q_smem_layout, + self.mma_tiler_qk, + tiled_mma_qk, + self.cluster_layout_vmnk.shape, + ) + + # TMA load for K + k_smem_layout = cute.select(k_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_K, tma_tensor_k = cute.nvgpu.make_tma_tile_atom_B( + tma_load_op, + mK, + k_smem_layout, + self.mma_tiler_qk, + tiled_mma_qk, + self.cluster_layout_vmnk.shape, + ) + # TMA load for V + v_smem_layout = cute.select(v_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_V, tma_tensor_v = cute.nvgpu.make_tma_tile_atom_B( + tma_load_op, + mV, + v_smem_layout, + self.pv_mma_tiler, + tiled_mma_pv, + self.cluster_layout_vmnk.shape, + ) + + o_cta_v_layout = cute.composition( + cute.make_identity_layout(mO.shape), self.epi_tile + ) + o_smem_layout = cute.select(o_smem_layout_staged, mode=[0, 1]) + + tma_atom_o, tma_tensor_o = cute.nvgpu.cpasync.make_tma_tile_atom( + tma_store_op, + mO, + o_smem_layout, + o_cta_v_layout, + ) + + self.tma_copy_q_bytes = cute.size_in_bytes(self.q_dtype, q_smem_layout) + self.tma_copy_kv_bytes = cute.size_in_bytes(self.k_dtype, k_smem_layout) + + self.tile_sched_params, grid = self._compute_grid( + mO, + self.cta_tiler, + self.is_persistent, + ) + + self.mbar_load_q_full_offset = 0 + self.mbar_load_q_empty_offset = self.mbar_load_q_full_offset + self.q_stage + self.mbar_load_kv_full_offset = self.mbar_load_q_empty_offset + self.q_stage + self.mbar_load_kv_empty_offset = self.mbar_load_kv_full_offset + self.kv_stage + self.mbar_P_full_O_rescaled_offset = self.mbar_load_kv_empty_offset + self.kv_stage + self.mbar_S_full_offset = self.mbar_P_full_O_rescaled_offset + 2 + self.mbar_O_full_offset = self.mbar_S_full_offset + 2 + self.mbar_softmax_corr_full_offset = self.mbar_O_full_offset + 2 + self.mbar_softmax_corr_empty_offset = self.mbar_softmax_corr_full_offset + 2 + self.mbar_corr_epi_full_offset = self.mbar_softmax_corr_empty_offset + self.epi_stage + self.mbar_corr_epi_empty_offset = self.mbar_corr_epi_full_offset + self.epi_stage + self.mbar_s0_s1_sequence_offset = self.mbar_corr_epi_empty_offset + 2 + self.mbar_max_reg_setting_offset = self.mbar_s0_s1_sequence_offset + 8 + self.mbar_tmem_dealloc_offset = self.mbar_max_reg_setting_offset + 1 + self.mbar_total = self.mbar_tmem_dealloc_offset + 1 + + @cute.struct + class SharedStorage: + # m_barriers for pipelines + mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.mbar_total] + # Tmem holding buffer + tmem_holding_buf: cutlass.Int32 + # Smem tensors + sScale: cute.struct.MemRange[cutlass.Float32, 2 * 128 * 1] + sO: cute.struct.Align[ + cute.struct.MemRange[self.o_dtype, cute.cosize(o_smem_layout_staged)], + self.buffer_align_bytes, + ] + sQ: cute.struct.Align[ + cute.struct.MemRange[self.q_dtype, cute.cosize(q_smem_layout_staged)], + self.buffer_align_bytes, + ] + sK: cute.struct.Align[ + cute.struct.MemRange[self.k_dtype, cute.cosize(k_smem_layout_staged)], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. + # Right after this, we multiply by log2(e) before applying exp2. + # To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val + # (assigning it to softcap_val) and pre-multiply softcap_val * log2(e) + # (assigning it to softmax_scale_log2). + LOG2_E = math.log2(math.e) + # if cutlass.const_expr(not self.has_softcap): + if cutlass.const_expr(True): + softmax_scale_log2 = softmax_scale * LOG2_E + softcap_val = cutlass.Float32(0.0) + else: + softmax_scale_log2 = softcap * LOG2_E + softcap_val = softmax_scale / softcap + + # Launch the kernel synchronously + self.kernel( + tiled_mma_qk, + tiled_mma_pv, + tma_atom_Q, + tma_tensor_q, + tma_atom_K, + tma_tensor_k, + tma_atom_V, + tma_tensor_v, + tma_atom_o, + tma_tensor_o, + softmax_scale_log2, + q_smem_layout_staged, + k_smem_layout_staged, + p_tmem_layout_staged, + v_smem_layout_staged, + o_smem_layout_staged, + self.tile_sched_params, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + smem=self.shared_storage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + ) + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tma_atom_Q: cute.CopyAtom, + mQ: cute.Tensor, + tma_atom_K: cute.CopyAtom, + mK: cute.Tensor, + tma_atom_V: cute.CopyAtom, + mV: cute.Tensor, + tma_atom_o: cute.CopyAtom, + mO: cute.Tensor, + softmax_scale_log2: cutlass.Float32, + q_smem_layout_staged: cute.ComposedLayout, + k_smem_layout_staged: cute.ComposedLayout, + p_tmem_layout_staged: cute.ComposedLayout, + v_smem_layout_staged: cute.ComposedLayout, + o_smem_layout_staged: cute.ComposedLayout, + tile_sched_params: FmhaStaticTileSchedulerParams, + ): + """The device kernel implementation of the Fused Multi-Head Attention. + + This kernel coordinates multiple specialized warps to perform different phases of the FMHA computation: + 1. Load warp: Loads Q, K, V data from global memory to shared memory using TMA + 2. MMA warp: Performs matrix multiplications (Q*K^T and P*V) + 3. Softmax warps: Compute softmax normalization on attention scores + 4. Correction warps: Apply adjustments to intermediate results + 5. Epilogue warp: Handles final output transformation and storage + + The kernel implements a complex pipeline with overlapping computation and memory operations, + using tensor memory access (TMA) for efficient data loading, warp specialization for different + computation phases, and optional attention masking. + """ + + # coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # Alloc + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + mbar_ptr = storage.mbar_ptr.data_ptr() + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + if warp_idx == 0: + # Init "full" barrier with number of producers, "empty" barrier with number of consumers + for i in range(self.q_stage): + cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_load_q_full_offset + i, len([self.load_warp_id])) + cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_load_q_empty_offset + i, len([self.mma_warp_id])) + for i in range(2): + cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_softmax_corr_empty_offset + i, cute.arch.WARP_SIZE * 4) + cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_softmax_corr_full_offset + i, cute.arch.WARP_SIZE * 4) + if cutlass.const_expr(self.s0_s1_barrier): + for i in range(8): + cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_s0_s1_sequence_offset + i, cute.arch.WARP_SIZE) + for i in range(2): + cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_corr_epi_full_offset + i, cute.arch.WARP_SIZE * len(self.correction_warp_ids)) + cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_corr_epi_empty_offset + i, cute.arch.WARP_SIZE * len([self.epilogue_warp_id])) + for i in range(2): + cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_P_full_O_rescaled_offset + i, cute.arch.WARP_SIZE * (len(self.softmax0_warp_ids) + len(self.correction_warp_ids))) + cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_S_full_offset + i, len([self.mma_warp_id])) + cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_O_full_offset + i, len([self.mma_warp_id])) + cute.arch.mbarrier_init_arrive_cnt( + mbar_ptr + self.mbar_max_reg_setting_offset, + cute.arch.WARP_SIZE + * len( + ( + self.empty_warp_id, + self.load_warp_id, + self.mma_warp_id, + self.epilogue_warp_id, + *self.correction_warp_ids, + ) + ), + ) + cute.arch.mbarrier_init_arrive_cnt( + mbar_ptr + self.mbar_tmem_dealloc_offset, + cute.arch.WARP_SIZE + * len( + ( + *self.softmax0_warp_ids, + *self.softmax1_warp_ids, + *self.correction_warp_ids, + ) + ), + ) + # Relying on pipeline_kv constructor to call mbarrier_init_fence and sync + pipeline_kv = self.make_and_init_load_kv_pipeline(mbar_ptr + self.mbar_load_kv_full_offset) + + block_info = BlockInfo( + # This is cta_tiler, not mma_tiler_qk, since we move by block by (2 * mma_tiler[0], mma_tiler[1]) + self.cta_tiler[0], self.cta_tiler[1], + is_causal=self.is_causal, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if self.pack_gqa else 1, + ) + + # Generate smem tensor Q/K/V/O + # (MMA, MMA_Q, MMA_D, PIPE) + sQ = storage.sQ.get_tensor(q_smem_layout_staged.outer, swizzle=q_smem_layout_staged.inner) + # sQ_pi = storage.sQ.get_tensor(q_smem_layout_staged) + # (MMA, MMA_K, MMA_D, PIPE) + sK = storage.sK.get_tensor(k_smem_layout_staged.outer, swizzle=k_smem_layout_staged.inner) + # sK_pi = storage.sK.get_tensor(k_smem_layout_staged) + # (MMA, MMA_K, MMA_D, PIPE) + # Strip swizzle info to reuse smem + sV_ptr = cute.recast_ptr(sK.iterator, v_smem_layout_staged.inner) + sV = cute.make_tensor(sV_ptr, v_smem_layout_staged.outer) + sO = storage.sO.get_tensor(o_smem_layout_staged.outer, swizzle=o_smem_layout_staged.inner) + + sScale = storage.sScale.get_tensor(cute.make_layout(256)) + + thr_mma_qk = tiled_mma_qk.get_slice(0) # default 1SM + thr_mma_pv = tiled_mma_pv.get_slice(0) # default 1SM + + qk_acc_shape = thr_mma_qk.partition_shape_C((self.mma_tiler_qk[0], self.mma_tiler_qk[1])) + tStS_fake = thr_mma_qk.make_fragment_C(qk_acc_shape) + # TODO: this is a fake tensor, need to retrieve tmem_ptr + tmem_ptr = cute.make_ptr(cutlass.Float32, 0, mem_space=cute.AddressSpace.tmem, + assumed_align=16) + tStS = cute.make_tensor(tmem_ptr, tStS_fake.layout) + + pv_acc_shape = thr_mma_pv.partition_shape_C((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) + tOtO = thr_mma_pv.make_fragment_C(pv_acc_shape) + + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + tStS1 = cute.make_tensor(tStS.iterator + self.tmem_s1_offset, tStS.layout) + + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + tOtO1 = cute.make_tensor(tOtO.iterator + self.tmem_o1_offset, tOtO.layout) + + tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer) + tOrP = thr_mma_pv.make_fragment_A(tP)[None, None, None, 0] + + tOrP0 = cute.make_tensor( + tOrP.iterator + + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout, + ) + tOrP1 = cute.make_tensor( + tOrP.iterator + + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p1_offset, + tOrP.layout, + ) + + SeqlenInfoCls = partial( + SeqlenInfo, seqlen_q_static=mQ.shape[0], seqlen_k_static=mK.shape[0] + ) + + if warp_idx >= 12: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_max_reg_setting_offset) + # /////////////////////////////////////////////////////////////////////////////// + # LOAD + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.load_warp_id: + tile_scheduler = create_fmha_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + self.load( + tile_scheduler, + thr_mma_qk, + thr_mma_pv, + mQ, + mK, + mV, + sQ, + sK, + sV, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + pipeline_kv, + mbar_ptr, + block_info, + SeqlenInfoCls, + ) + # /////////////////////////////////////////////////////////////////////////////// + # MMA + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.mma_warp_id: + # if warp_idx == self.mma_warp_id or warp_idx == self.empty_warp_id: + # Alloc tmem buffer + tmem_alloc_cols = cutlass.Int32(self.tmem_alloc_cols) + if warp_idx == self.mma_warp_id: + cute.arch.alloc_tmem(tmem_alloc_cols, storage.tmem_holding_buf) + cute.arch.sync_warp() + # tile_scheduler = create_fmha_static_tile_scheduler( + # tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + # ) + + self.mma( + # tile_scheduler, + tiled_mma_qk, + tiled_mma_pv, + sQ, + sK, + sV, + # sQ_pi.iterator, + # sK_pi.iterator, + q_smem_layout_staged.inner, + k_smem_layout_staged.inner, + v_smem_layout_staged.inner, + tStS0, + tStS1, + tOtO0, + tOtO1, + tOrP0, + tOrP1, + pipeline_kv, + mbar_ptr, + tile_sched_params, + block_info, + SeqlenInfoCls, + ) + + # if warp_idx == self.mma_warp_id: + # dealloc tmem buffer + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_tmem_dealloc_offset, 0) + tmem_alloc_cols = cutlass.Int32(self.tmem_alloc_cols) + # Retrieving tmem ptr and make acc + tmem_ptr = cute.arch.retrieve_tmem_ptr( + cutlass.Float32, + alignment=16, + ptr_to_buffer_holding_addr=storage.tmem_holding_buf, + ) + cute.arch.dealloc_tmem(tmem_ptr, tmem_alloc_cols) + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.epilogue_warp_id: + tile_scheduler = create_fmha_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + self.epilogue_s2g(tile_scheduler, mO, sO, tma_atom_o, mbar_ptr) + + # /////////////////////////////////////////////////////////////////////////////// + # Softmax + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx < self.correction_warp_ids[0]: + # increase register after decreasing + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_max_reg_setting_offset, 0) + cute.arch.warpgroup_reg_alloc(self.num_regs_softmax) + + tile_scheduler = create_fmha_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + softmax_loop = partial( + self.softmax_loop, + softmax_scale_log2=softmax_scale_log2, + thr_mma_qk=thr_mma_qk, + sScale=sScale, + mbar_ptr=mbar_ptr, + tile_scheduler=tile_scheduler, + block_info=block_info, + SeqlenInfoCls=SeqlenInfoCls, + ) + + if cutlass.const_expr(not self.s0_s1_barrier): + stage = cutlass.Int32(0 if warp_idx < self.softmax1_warp_ids[0] else 1) + softmax_loop( + stage=stage, + tStSi=cute.make_tensor(tStS.iterator + (self.tmem_s0_offset if stage == 0 else self.tmem_s1_offset), tStS.layout)) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) + else: + # If there's s0_s1_barrier, it's faster to have 2 WGs having different code + if warp_idx < self.softmax1_warp_ids[0]: + tStSi = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + softmax_loop(stage=0, tStSi=tStSi) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) + if warp_idx < self.correction_warp_ids[0] and warp_idx >= self.softmax1_warp_ids[0]: + tStSi = cute.make_tensor(tStS.iterator + self.tmem_s1_offset, tStS.layout) + softmax_loop(stage=1, tStSi=tStSi) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) + + # /////////////////////////////////////////////////////////////////////////////// + # Correction + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx >= self.correction_warp_ids[0] and warp_idx < self.mma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_correction) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_max_reg_setting_offset) + self.correction_loop( + thr_mma_qk, + thr_mma_pv, + tStS, + tOtO0, + tOtO1, + sScale, + mO, + sO, + tma_atom_o, + mbar_ptr, + tile_sched_params, + block_info, + SeqlenInfoCls, + ) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) + + return + + @cute.jit + def load( + self, + tile_scheduler, + thr_mma_qk: cute.core.ThrMma, + thr_mma_pv: cute.core.ThrMma, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + pipeline_kv: cutlass.utils.PipelineAsync, + mbar_ptr: cute.Pointer, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + ): + # (bM, bK, loopM, loopL) + gQ_qdl = cute.local_tile(mQ, cute.select(self.mma_tiler_qk, mode=[0, 2]), (None, 0, None)) + tSgQ_qdl = thr_mma_qk.partition_A(gQ_qdl) + # (bN, bK, loopN, loopL) + gK_kdl = cute.local_tile(mK, cute.select(self.mma_tiler_qk, mode=[1, 2]), (None, 0, None)) + tSgK_kdl = thr_mma_qk.partition_B(gK_kdl) + # (bK, bN, loopN, loopL) + gV_dkl = cute.local_tile(mV, cute.select(self.pv_mma_tiler, mode=[1, 2]), (0, None, None)) + tOgV_dkl = thr_mma_pv.partition_B(gV_dkl) + tQsQ, tQgQ_qdl = cute.nvgpu.cpasync.tma_partition( + tma_atom_Q, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sQ, 0, 3), + cute.group_modes(tSgQ_qdl, 0, 3), + ) + tKsK, tKgK_kdl = cute.nvgpu.cpasync.tma_partition( + tma_atom_K, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK_kdl, 0, 3), + ) + tVsV, tVgV_dkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_V, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sV, 0, 3), + cute.group_modes(tOgV_dkl, 0, 3), + ) + + q_producer_phase = cutlass.Int32(1) + kv_producer_state = cutlass.utils.make_pipeline_state(cutlass.utils.PipelineUserType.Producer, self.kv_stage) + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m_block, head_idx, batch_idx = work_tile.tile_idx + tQgQ = tQgQ_qdl[None, None, (head_idx, batch_idx)] + head_idx_kv = head_idx // self.qhead_per_kvhead + tKgK, tVgV = [t[None, None, (head_idx_kv, batch_idx)] for t in (tKgK_kdl, tVgV_dkl)] + + def load_Q(stage: int): + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_load_q_empty_offset + stage, q_producer_phase) + with cute.arch.elect_one(): + cute.arch.mbarrier_init_tx_bytes(mbar_ptr + self.mbar_load_q_full_offset + stage, self.tma_copy_q_bytes) + cute.copy( + tma_atom_Q, + tQgQ[None, 2 * m_block + stage], + tQsQ[None, stage], + tma_bar_ptr=mbar_ptr + self.mbar_load_q_full_offset + stage, + ) + + load_K = partial(self.load_K, tma_atom_K, tKgK, tKsK, pipeline_kv) + load_V = partial(self.load_K, tma_atom_V, tVgV, tVsV, pipeline_kv) + + seqlen = SeqlenInfoCls(batch_idx) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + load_Q(0) # Q0 + load_K(n_block_max - 1, kv_producer_state) # K0 + kv_producer_state.advance() + load_Q(1) # Q1 + q_producer_phase ^= 1 + load_V(n_block_max - 1, kv_producer_state) # V0 + kv_producer_state.advance() + for i in cutlass.range_dynamic(n_block_max - 1 - n_block_min, unroll=1): + n_block = n_block_max - 2 - i + load_K(n_block, kv_producer_state) # Ki + kv_producer_state.advance() + load_V(n_block, kv_producer_state) # Vi + kv_producer_state.advance() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + + @cute.jit + def mma( + self, + tiled_mma_qk: cute.core.ThrMma, + tiled_mma_pv: cute.core.ThrMma, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + # sQ_base_addr: cute.Pointer, + # sK_base_addr: cute.Pointer, + sQ_swizzle: cute.Swizzle, + sK_swizzle: cute.Swizzle, + sV_swizzle: cute.Swizzle, + tStS0: cute.Tensor, + tStS1: cute.Tensor, + tOtO0: cute.Tensor, + tOtO1: cute.Tensor, + tOrP0: cute.Tensor, + tOrP1: cute.Tensor, + pipeline_kv: cutlass.utils.PipelineAsync, + mbar_ptr: cute.Pointer, + # tile_scheduler, + tile_sched_params, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + ): + thr_mma_qk = tiled_mma_qk.get_slice(0) # default 1SM + thr_mma_pv = tiled_mma_pv.get_slice(0) # default 1SM + tSrQ = thr_mma_qk.make_fragment_A(sQ) + tSrK = thr_mma_qk.make_fragment_B(sK) + tOrV = thr_mma_pv.make_fragment_B(sV) + tStSs = (tStS0, tStS1) + tSrQs = (tSrQ[None, None, None, 0], tSrQ[None, None, None, 1]) + tOrPs = (tOrP0, tOrP1) + + qk_mma_op, pv_mma_op = tiled_mma_qk.op, tiled_mma_pv.op + # sQ_base_addr_for_desc = cute.arch.make_warp_uniform(sm100_desc.make_smem_desc_start_addr(sQ_base_addr)) + # sK_base_addr_for_desc = cute.arch.make_warp_uniform(sm100_desc.make_smem_desc_start_addr(sK_base_addr)) + # sQ_addr_offset_for_desc = (cute.crd2idx((0, 0, 0, 1), sQ.layout) * sQ.element_type.width // 8) >> 4 + # sK_addr_offset_for_desc = (cute.crd2idx((0, 0, 0, 1), sK.layout) * sK.element_type.width // 8) >> 4 + # sQ_layout = cute.select(sQ.layout, mode=[0, 1, 2]) + # sK_layout = cute.select(sK.layout, mode=[0, 1, 2]) + + gemm_Si = [ + partial( + sm100_utils.gemm_ptx_partial, + qk_mma_op, self.tmem_s0_offset if stage == 0 else self.tmem_s1_offset, tSrQs[stage], + sA=sQ[None, None, None, stage], + sA_swizzle=sQ_swizzle, sB_swizzle=sK_swizzle, zero_init=True + ) + for stage in range(2) + ] + gemm_Pi = [ + partial( + sm100_utils.gemm_ptx_partial, + pv_mma_op, self.tmem_o0_offset if stage == 0 else self.tmem_o1_offset, tOrPs[stage], + sA=None, sA_swizzle=None, sB_swizzle=sV_swizzle + ) + for stage in range(2) + ] + + mma_q_consumer_phase = cutlass.Int32(0) + mma_kv_consumer_state = cutlass.utils.make_pipeline_state( + cutlass.utils.PipelineUserType.Consumer, self.kv_stage + ) + P_full_O_rescaled_phase = cutlass.Int32(0) + + tile_scheduler = create_fmha_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m_block, head_idx, batch_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + + for stage in range(2): + # GEMM_QK00 (Q0 * K0 -> S0) or GEMM_QK01 (Q1 * K0 -> S1) + # 1. wait for Q0 / Q1 + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_load_q_full_offset + stage, mma_q_consumer_phase) + # 2. wait for K0 + if stage == 0: + pipeline_kv.consumer_wait(mma_kv_consumer_state) + tSrKi = tSrK[None, None, None, mma_kv_consumer_state.index] + # We don't need to acquire empty S0 / S1. + # For the first iteration, we don't need to wait as we're guaranteed S0 / S1 + # are empty. For subsequent iterations, the wait happened at the end + # of the while loop. + # 3. gemm + # sm100_utils.gemm(tiled_mma_qk, tStSs[stage], tSrQs[stage], tSrKi, zero_init=True) + gemm_Si[stage](tCrB=tSrKi, sB=sK[None, None, None, mma_kv_consumer_state.index]) + # sm100_utils.gemm_ptx_partial1( + # qk_mma_op, 0 + stage * self.tmem_s1_offset, tSrQs[stage], tSrKi, + # sQ_base_addr_for_desc, sQ_addr_offset_for_desc, stage, + # sK_base_addr_for_desc, sK_addr_offset_for_desc, 0, + # sQ_layout, sK_layout, sQ_swizzle, sK_swizzle, + # zero_init=True + # ) + # 4. release S0 / S1 + with cute.arch.elect_one(): + tcgen05.commit(mbar_ptr + self.mbar_S_full_offset + stage) + mma_q_consumer_phase ^= 1 + # 5. release K0 + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + # End of GEMM (Q1 * K0 -> S1) + # Note: Q0 & Q1 are still needed in the seqlen_kv loop + # so we need to release them after the seqlen_kv loop + + # O hasn't been accumulated yet, its first MMA calculation doesn't need to accumulate + O_should_accumulate = False + for i in cutlass.range_dynamic(n_block_max - 1 - n_block_min, unroll=1): + # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop + # 1. wait for V0 + pipeline_kv.consumer_wait(mma_kv_consumer_state) + mma_kv_release_state = mma_kv_consumer_state.clone() + Vi_index = mma_kv_consumer_state.index + tOrVi = tOrV[None, None, None, Vi_index] + for stage in range(2): + # 2. acquire corrected O0/O1_partial and P0 / P1 + # For the first iteration in this work tile, waiting for O0/O1_partial + # means that the correction warps has finished reading tO during + # the last iteration of the previous work tile has finished. + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, P_full_O_rescaled_phase) + # 3. gemm + # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True) + gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate) + # 4. release accumulated O0_partial / O1_partial + # Don't need to signal O_full to the correction warps anymore since the + # correction warps wait for the softmax warps anyway. By the time the softmax + # warps finished, S_i for the next iteration must have been done, so O_i-1 + # must have been done as well. + # with cute.arch.elect_one(): + # tcgen05.commit(mbar_ptr + self.mbar_O_full_offset + stage) + # 5. release V(i-1) + if stage == 1: + pipeline_kv.consumer_release(mma_kv_release_state) + mma_kv_release_state.advance() + # End of GEMM_PV00 (P0 * V0 -> O0_partial) + + # GEMM_QK0i (Q0 * Ki -> S0) + # 1. wait for Ki + if stage == 0: + mma_kv_consumer_state.advance() + pipeline_kv.consumer_wait(mma_kv_consumer_state) + Ki_index = mma_kv_consumer_state.index + # 2. gemm + # Don't need to wait for the softmax warp to have finished reading the previous + # Si, since this gemm is scheduled after the PV gemm, which guaranteed that Si + # has been read and Pi has been written. + # sm100_utils.gemm(tiled_mma_qk, tStS0, tSrQs[0], tSrK[None, None, None, Ki_index], zero_init=True) + gemm_Si[stage](tCrB=tSrK[None, None, None, Ki_index], sB=sK[None, None, None, Ki_index]) + # 3. release S0 + with cute.arch.elect_one(): + tcgen05.commit(mbar_ptr + self.mbar_S_full_offset + stage) + # End of GEMM_QK0i (Q0 * Ki -> S0) + # 4. release Ki + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + P_full_O_rescaled_phase ^= 1 + O_should_accumulate = True + # End of seqlen_kv loop + + # release Q0 & Q1 + with cute.arch.elect_one(): + tcgen05.commit(mbar_ptr + self.mbar_load_q_empty_offset + 0) + tcgen05.commit(mbar_ptr + self.mbar_load_q_empty_offset + 1) + + # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop + # 1. wait for V0 + pipeline_kv.consumer_wait(mma_kv_consumer_state) + Vi_index = mma_kv_consumer_state.index + tOrVi = tOrV[None, None, None, Vi_index] + for stage in range(2): + # 2. acquire corrected Oi_partial and Pi + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, P_full_O_rescaled_phase) + # 3. gemm + # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True) + gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate) + # 4. release accumulated O0_partial + # We do need O_full here since for the last tile, by the time the softmax warp + # has signaled to the correction warp, the softmax warp has just finished compute + # the row sum of the current tile. It does not guarantee that the 1st tile + # of the next work tile has been computed yet. + with cute.arch.elect_one(): + tcgen05.commit(mbar_ptr + self.mbar_O_full_offset + stage) + # End of GEMM_PV00 (P0 * V0 -> O0_partial) + P_full_O_rescaled_phase ^= 1 + # 5. release Vi_end + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + # End of GEMM_PV1(i_end) (P1 * Vi_end -> O1) + + # Advance to next tile + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + + # for both softmax0 and softmax1 warp group + @cute.jit + def softmax_loop( + self, + stage: int, + # stage: cutlass.Int32, + softmax_scale_log2: cutlass.Float32, + thr_mma_qk: cute.core.ThrMma, + tStSi: cute.Tensor, + sScale: cute.Tensor, + mbar_ptr: cute.Pointer, + tile_scheduler, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + ): + """Compute softmax on attention scores from QK matrix multiplication. + + This method handles the softmax computation for either the first or second half of the + attention matrix, depending on the 'stage' parameter. It calculates row-wise maximum + and sum values needed for stable softmax computation, applies optional masking, and + transforms raw attention scores into probability distributions. + + The implementation uses specialized memory access patterns and efficient math operations + for computing exp(x) using exp2 functions. It also coordinates pipeline + synchronization between MMA, correction, and sequence processing stages. + """ + tidx = cute.arch.thread_idx()[0] % ( + cute.arch.WARP_SIZE + # * (len(self.softmax0_warp_ids) if stage == 0 else len(self.softmax1_warp_ids) + * (len(self.softmax0_warp_ids) + ) + ) + + cS_base = cute.make_identity_tensor((self.mma_tiler_qk[0], self.mma_tiler_qk[1])) + tScS = thr_mma_qk.partition_C(cS_base) + + tStS_scale_layout = cute.composition(tStSi.layout, cute.make_layout((128, 1))) + tStScale = cute.make_tensor(tStSi.iterator, tStS_scale_layout) + tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((128, 1))) + tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) + + tilePlikeFP32 = self.mma_tiler_qk[1] // 32 * self.v_dtype.width + tStP_layout = cute.composition(tStSi.layout, cute.make_layout((128, tilePlikeFP32))) + tStP = cute.make_tensor(tStSi.iterator + self.tmem_p_offset, tStP_layout) + + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), cutlass.Float32, + ) + thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStSi).get_slice(tidx) + tStS_t2r = thr_tmem_load.partition_S(tStSi) + + tmem_store_scale_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(1)), cutlass.Float32, + ) + thr_tmem_store_scale = tcgen05.make_tmem_copy(tmem_store_scale_atom, tStScale).get_slice(tidx) + + tStScale_r2t = thr_tmem_store_scale.partition_D(tStScale) + tSrScale_r2t_shape = thr_tmem_store_scale.partition_S(tScS_vec).shape + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), cutlass.Float32, + ) + tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP) + thr_tmem_store = tiled_tmem_store.get_slice(tidx) + tStP_r2t = thr_tmem_store.partition_D(tStP) + + mma_si_consumer_phase = cutlass.Int32(0) + si_corr_producer_phase = cutlass.Int32(1) + s0_s1_sequence_phase = cutlass.Int32(1 if stage == 0 else 0) + + # self.warp_scheduler_barrier_init() + + warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 + mbar_s0_s1_sequence_offset = self.mbar_s0_s1_sequence_offset + warp_idx_in_wg + + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m_block, head_idx, batch_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + mask = AttentionMask( + self.mma_tiler_qk[0], self.mma_tiler_qk[1], seqlen.seqlen_q, seqlen.seqlen_k, + self.qhead_per_kvhead if self.pack_gqa else 1, + ) + mask_fn = partial( + mask.apply_mask_sm100, m_block=m_block, m_stage=stage, thr_mma=thr_mma_qk, thr_tmem_load=thr_tmem_load, mask_causal=self.is_causal + ) + softmax = SoftmaxSm100(softmax_scale_log2, rescale_threshold=8.0 if self.q_dtype.width == 16 else 0.0) + softmax.reset() + + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase) + si_corr_producer_phase ^= 1 + + softmax_step = partial( + self.softmax_step, + softmax=softmax, + mbar_ptr=mbar_ptr, + mbar_s0_s1_sequence_offset=mbar_s0_s1_sequence_offset, + thr_mma_qk=thr_mma_qk, + thr_tmem_load=thr_tmem_load, + thr_tmem_store=thr_tmem_store, + thr_tmem_store_scale=thr_tmem_store_scale, + tStS_t2r=tStS_t2r, + tStScale_r2t=tStScale_r2t, + tStP_r2t=tStP_r2t, + sScale=sScale, + stage=stage, + ) + + # 1 masking iter + if cutlass.const_expr(not self.is_even_N): + # mask_trip_count = 1 if seqlen.seqlen_k % self.mma_tiler_qk[1] == 0 else 0 + softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block_max - 1, mask_fn=partial(mask_fn, mask_seqlen=True)) + si_corr_producer_phase ^= 1 + mma_si_consumer_phase ^= 1 + s0_s1_sequence_phase ^= 1 + n_block_max -= 1 + # Next couple of iterations with causal masking + if cutlass.const_expr(self.is_causal): + n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( + seqlen, m_block, n_block_min + ) + # Currently we can't do loop with negative step https://github.com/NVIDIA/cutlass/issues/2326 + for n_tile in cutlass.range_dynamic(n_block_max - n_block_min_causal_local_mask, unroll=1): + n_block = n_block_max - 1 - n_tile + softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=partial(mask_fn, mask_seqlen=False)) + si_corr_producer_phase ^= 1 + mma_si_consumer_phase ^= 1 + s0_s1_sequence_phase ^= 1 + n_block_max = n_block_min_causal_local_mask + # The remaining iterations have no masking + for n_tile in cutlass.range_dynamic(n_block_max, unroll=1): + n_block = n_block_max - n_tile - 1 + softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=None) + si_corr_producer_phase ^= 1 + mma_si_consumer_phase ^= 1 + s0_s1_sequence_phase ^= 1 + + # mma_softmax_pipeline.sync_object_array_full.wait(stage, mma_si_consumer_phase) + + # tSrScale_r2t = cute.make_fragment(tSrScale_r2t_shape, cutlass.Float32) + # tSrScale_r2t[0] = softmax.row_sum[0] + # cute.copy(thr_tmem_store_scale, tSrScale_r2t, tStScale_r2t) + # cute.arch.fence_view_async_tmem_store() + sScale[tidx + stage * 128] = softmax.row_sum[0] + + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage) + # if tidx == 0: cute.printf("softmax row sum stage %d: %f\n", stage, softmax.row_sum[0]) + # Advance to next tile + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + + @cute.jit + def softmax_step( + self, + # stage: cutlass.Int32, + mma_si_consumer_phase: cutlass.Int32, + si_corr_producer_phase: cutlass.Int32, + s0_s1_sequence_phase: cutlass.Int32, + n_block: cutlass.Int32, + softmax: SoftmaxSm100, + mbar_ptr: cute.Pointer, + mbar_s0_s1_sequence_offset: cutlass.Int32, + thr_mma_qk: cute.core.ThrMma, + thr_tmem_load: cute.CopyAtom, + thr_tmem_store: cute.CopyAtom, + thr_tmem_store_scale: cute.CopyAtom, + tStS_t2r: cute.Tensor, + tStScale_r2t: cute.Tensor, + tStP_r2t: cute.Tensor, + sScale: cute.Tensor, + mask_fn: Optional[Callable], + stage: int, + ) -> None: + """Perform a single step of the softmax computation on a block of attention scores. + + This method processes one block of the attention matrix, computing numerically stable + softmax by first finding the row maximum, subtracting it from all elements, applying + exponential function, and then normalizing by the sum of exponentials. It also handles + optional masking of attention scores. + + The method involves several key operations: + 1. Loading attention scores from tensor memory + 2. Applying optional masking based on position + 3. Computing row-wise maximum values for numerical stability + 4. Transforming scores using exp2(x*scale - max*scale) + 5. Computing row sums for normalization + 6. Coordinating pipeline synchronization between different processing stages + """ + tilePlikeFP32 = self.mma_tiler_qk[1] // cutlass.Float32.width * self.v_dtype.width + tScS = thr_mma_qk.partition_C(cute.make_identity_tensor((self.mma_tiler_qk[0], self.mma_tiler_qk[1]))) + tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((128, 1))) + tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) + + tScP_layout = cute.composition(tScS.layout, cute.make_layout((128, tilePlikeFP32))) + tScP = cute.make_tensor(tScS.iterator, tScP_layout) + tScS_t2r_shape = thr_tmem_load.partition_D(tScS).shape + + # Wait for Si + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_S_full_offset + stage, mma_si_consumer_phase) + tSrS_t2r = cute.make_fragment(tScS_t2r_shape, self.qk_acc_dtype) + cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r) + if cutlass.const_expr(mask_fn is not None): + mask_fn(tSrS_t2r, n_block=n_block) + row_max, acc_scale = softmax.update_row_max(tSrS_t2r.load()) + + # tSrScale_r2t = cute.make_fragment(thr_tmem_store_scale.partition_S(tScS_vec).shape, cutlass.Float32) + # tSrScale_r2t[0] = acc_scale + # cute.copy(thr_tmem_store_scale, tSrScale_r2t, tStScale_r2t) + # cute.arch.fence_view_async_tmem_store() + thread_idx = thr_tmem_load.thr_idx + sScale[thread_idx + stage * 128] = acc_scale + # if thread_idx == 0: cute.printf("softmax acc_scale stage %d: %f, row_max = %f\n", stage, acc_scale, row_max) + # Notify correction wg that row_max is ready + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage) + + # Sequence barrier wait + if cutlass.const_expr(self.s0_s1_barrier): + cute.arch.mbarrier_wait(mbar_ptr + mbar_s0_s1_sequence_offset + stage * 4, s0_s1_sequence_phase) + tSrP_r2t_f32 = cute.make_fragment(thr_tmem_store.partition_S(tScP).shape, cutlass.Float32) + tSrP_r2t = cute.make_tensor( + cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), tSrS_t2r.layout, + ) + # if thread_idx == 0 and stage == 0: cute.print_tensor(tSrS_t2r) + softmax.scale_apply_exp2_convert(tSrS_t2r, row_max, tSrP_r2t) + # print(tSrP_r2t_f32, tStP_r2t) + # Sequence barrier arrive + if cutlass.const_expr(self.s0_s1_barrier): + cute.arch.mbarrier_arrive(mbar_ptr + mbar_s0_s1_sequence_offset + (1 - stage) * 4) + cute.copy(thr_tmem_store, tSrP_r2t_f32, tStP_r2t) + cute.arch.fence_view_async_tmem_store() + # Notify mma warp that P is ready + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase) + softmax.update_row_sum(tSrS_t2r.load(), acc_scale) + # acc_scale = cute.arch.exp2(acc_scale_) + + @cute.jit + def correction_loop( + self, + thr_mma_qk: cute.core.ThrMma, + thr_mma_pv: cute.core.ThrMma, + tStS: cute.Tensor, + tOtO0: cute.Tensor, + tOtO1: cute.Tensor, + sScale: cute.Tensor, + mO: cute.Tensor, + sO: cute.Tensor, + tma_atom_o: cute.CopyAtom, + mbar_ptr: cute.Pointer, + # tile_scheduler, + tile_sched_params, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + ): + tScS = thr_mma_qk.partition_C(cute.make_identity_tensor((self.mma_tiler_qk[0], self.mma_tiler_qk[1]))) + tStS_scale_layout = cute.composition(tStS.layout, cute.make_layout((128, 1))) + tStScale_0 = cute.make_tensor(tStS.iterator + self.tmem_vec0_offset, tStS_scale_layout) + tStScale_1 = cute.make_tensor(tStS.iterator + self.tmem_vec1_offset, tStS_scale_layout) + tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((128, 1))) + tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) + tmem_load_v_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(1)), self.qk_acc_dtype, + ) + tiled_tmem_load_vec = tcgen05.make_tmem_copy(tmem_load_v_atom, tStScale_0) + thread_idx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.correction_warp_ids)) + thr_tmem_load_vec = tiled_tmem_load_vec.get_slice(thread_idx) + + tStScale_0_t2r = thr_tmem_load_vec.partition_S(tStScale_0) + tStScale_1_t2r = thr_tmem_load_vec.partition_S(tStScale_1) + tSrScale_t2r_shape = thr_tmem_load_vec.partition_D(tScS_vec).shape + + tOtOs = [tOtO0, tOtO1] + tStScales_t2r = [tStScale_0_t2r, tStScale_1_t2r] + + # First iter: no correction is required + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + 0) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + 1) + + softmax_corr_consumer_phase = cutlass.Int32(0) + o_corr_consumer_phase = cutlass.Int32(0) + corr_epi_producer_phase = cutlass.Int32(1) + + tile_scheduler = create_fmha_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m_block, head_idx, batch_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + + # Ignore first signal from softmax as no correction is required + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_full_offset + 0, softmax_corr_consumer_phase) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + 0) + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_full_offset + 1, softmax_corr_consumer_phase) + softmax_corr_consumer_phase ^= 1 + + tSrScale_t2r = cute.make_fragment(tSrScale_t2r_shape, cutlass.Float32) + for i in cutlass.range_dynamic(n_block_max - n_block_min - 1, unroll=1): + for stage in range(2): + # wait for S0 / S1 + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_full_offset + stage, softmax_corr_consumer_phase) + # cute.copy(tiled_tmem_load_vec, tStScale_1_t2r, tSrScale_t2r) + # cute.arch.fence_view_async_tmem_load() + # scale = tSrScale_t2r[stage] + scale = sScale[thread_idx + stage * 128] + should_rescale = cute.arch.vote_ballot_sync(scale < 1.0) != 0 + # should_rescale = True + # if thread_idx == 0: cute.printf("Correction scale i = %d, for stage %d: %f, should_rescale = %d\n", i, stage, scale, should_rescale) + # should_rescale = True + # Don't need O_full anymore, since by the time softmax has signaled the correction + # warps, S_i must have been done, so O_i-1 must have been done as well. + # cute.arch.mbarrier_wait(mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase) + if should_rescale: + self.correction_rescale(thr_mma_pv, tOtOs[stage], thread_idx, scale) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + (1 - stage)) + softmax_corr_consumer_phase ^= 1 + # o_corr_consumer_phase ^= 1 + # End of seqlen_corr_loop_steps + + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + 1) + + for stage in range(2): + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_full_offset + stage, softmax_corr_consumer_phase) + # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r) + # cute.arch.fence_view_async_tmem_load() + # scale = tSrScale_t2r[0] + scale = sScale[thread_idx + stage * 128] + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage) + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase) + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_corr_epi_empty_offset + stage, corr_epi_producer_phase) + self.correction_epilogue( + thr_mma_pv, tOtOs[stage], thread_idx, 1.0 / scale, sO[None, None, stage], + ) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_full_offset + stage) + # Signal for the next work tile that O buffers in tmem are already read, so + # mma warp can write to them + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) + # if thread_idx == 0: cute.printf("Correction final scale for stage %d: %f\n", stage, scale) + + o_corr_consumer_phase ^= 1 + softmax_corr_consumer_phase ^= 1 + corr_epi_producer_phase ^= 1 + + # gO_qdl = cute.local_tile(mO, cute.select(self.pv_mma_tiler, mode=[0, 1]), (None, 0, None)) + # gO = gO_qdl[None, None, None, (head_idx, batch_idx)] + # tOsO, tOgO = cute.nvgpu.cpasync.tma_partition( + # tma_atom_o, + # 0, + # cute.make_layout(1), + # cute.group_modes(sO, 0, 2), + # cute.group_modes(gO, 0, 2), + # ) + # warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 + # stage = warp_idx_in_wg + # if stage < 2: + # # wait from corr, issue tma store on smem + # # 1. wait for O0 / O1 final + # cute.arch.mbarrier_wait(mbar_ptr + self.mbar_corr_epi_full_offset + stage, corr_epi_producer_phase) + # # 2. copy O0 / O1 to gmem + # cute.copy(tma_atom_o, tOsO[None, stage], tOgO[None, 2 * m_block + stage]) + # cute.arch.cp_async_bulk_commit_group() + # # Ensure O0 / O1 buffer is ready to be released + # cute.arch.cp_async_bulk_wait_group(0, read=True) + # cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) + + # Advance to next tile + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + + @cute.jit + def correction_rescale( + self, + thr_mma: cute.core.ThrMma, + tOtO: cute.Tensor, + thread_idx: cutlass.Int32, + scale: cutlass.Float32, + ): + """Rescale intermediate attention results based on softmax normalization factor. + + This method performs a crucial correction step in the attention computation pipeline. + When processing attention in blocks, the softmax normalization factors may change + as new blocks are processed. This method rescales previously computed partial + output values to account for updated normalization factors. + + The implementation uses efficient tensor memory operations to: + 1. Load existing partial attention output from tensor memory + 2. Apply the scaling factor to all elements + 3. Store the rescaled results back to tensor memory + """ + cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) + tOcO = thr_mma.partition_C(cO) + + corr_tile_size = 16 # tuneable parameter + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), + self.pv_acc_dtype, + ) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), + self.pv_acc_dtype, + ) + + tOtO_i_layout = cute.composition(tOtO.layout, cute.make_layout((128, corr_tile_size))) + tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size))) + tOtO_i = cute.make_tensor(tOtO.iterator, tOtO_i_layout) + tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout) + + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tOtO_i) + tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tOtO_i) + thr_tmem_load = tiled_tmem_load.get_slice(thread_idx) + thr_tmem_store = tiled_tmem_store.get_slice(thread_idx) + + tOtO_t2r = thr_tmem_load.partition_S(tOtO_i) + tOrO_t2r_shape = thr_tmem_load.partition_D(tOcO_i).shape + tOtO_r2t = thr_tmem_store.partition_D(tOtO_i) + + tOrO_frg = cute.make_fragment((tOrO_t2r_shape, 128 // corr_tile_size), self.pv_acc_dtype) + for i in range(self.cta_tiler[2] // corr_tile_size): + tOrO_frg_i = tOrO_frg[None, i] + tTMrO_i_layout = cute.composition(tOrO_frg_i.layout, cute.make_layout(tOrO_frg.shape[0])) + tTMrO_i = cute.make_tensor(tOrO_frg_i.iterator, tTMrO_i_layout) + tOtO_t2r_i = cute.make_tensor(tOtO_t2r.iterator + i * corr_tile_size, tOtO_t2r.layout) + cute.copy(tiled_tmem_load, tOtO_t2r_i, tTMrO_i) + for j in range(0, cute.size(tTMrO_i), 2): + tTMrO_i[j], tTMrO_i[j + 1] = cute.arch.mul_packed_f32x2( + (tTMrO_i[j], tTMrO_i[j + 1]), (scale, scale), + ) + tOtO_r2t_i = cute.make_tensor(tOtO_r2t.iterator + i * corr_tile_size, tOtO_r2t.layout) + cute.copy(tiled_tmem_store, tTMrO_i, tOtO_r2t_i) + cute.arch.fence_view_async_tmem_store() + + @cute.jit + def correction_epilogue( + self, + thr_mma: cute.core.ThrMma, + tOtO: cute.Tensor, + thread_idx: cutlass.Int32, + scale: cutlass.Float32, + sO: cute.Tensor, + ): + """Apply final scaling and transformation to attention output before writing to global memory. + + This correction_epilogue function handles the final processing step for attention output values. + It applies a scaling factor to the accumulated attention results and prepares the + data for efficient transfer back to global memory. + + The method performs: + 1. Loading of accumulated attention results from tensor memory + 2. Application of the final output scaling factor + 3. Type conversion if necessary (typically from higher precision accumulator to output precision) + 4. Reorganization of data for optimal memory access patterns + 5. Preparation for efficient TMA store operations + + :param thr_mma: Thread MMA operation for the computation + :type thr_mma: cute.core.ThrMma + :param tOtO: Tensor containing accumulated attention output + :type tOtO: cute.Tensor + :param scale: Final scaling factor to apply to the output + :type scale: cutlass.Float32 + :param sO: Shared memory tensor for the final output + :type sO: cute.Tensor + """ + + cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) + corr_tile_size = 32 * 8 // self.o_dtype.width + tOsO = thr_mma.partition_C(sO) + tOcO = thr_mma.partition_C(cO) + + tOtO_i = cute.logical_divide(tOtO, cute.make_layout((128, corr_tile_size))) + tOcO_i = cute.logical_divide(tOcO, cute.make_layout((128, corr_tile_size))) + tOsO_i = cute.logical_divide(tOsO, cute.make_layout((128, corr_tile_size))) + + epi_subtile = (self.epi_tile[0], corr_tile_size) + tmem_copy_atom = sm100_utils_basic.get_tmem_load_op( + self.pv_mma_tiler, + self.o_layout, + self.o_dtype, + self.pv_acc_dtype, + epi_subtile, + use_2cta_instrs=False, + ) + + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_copy_atom, tOtO_i[(None, None), 0]) + + thr_tmem_load = tiled_tmem_load.get_slice(thread_idx) + smem_copy_atom = sm100_utils_basic.get_smem_store_op( + self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load + ) + tiled_smem_store = cute.make_tiled_copy( + smem_copy_atom, + layout_tv=tiled_tmem_load.layout_dst_tv_tiled, + tiler_mn=tiled_tmem_load.tiler_mn, + ) + + tOtO_t2r = thr_tmem_load.partition_S(tOtO_i[(None, None), None]) + tOsO_s2r = thr_tmem_load.partition_D(tOsO_i[(None, None), None]) + tOcO_t2r = thr_tmem_load.partition_D(tOcO_i[(None, None), None]) + + for i in range(self.cta_tiler[2] // corr_tile_size): + tOtO_t2r_i = tOtO_t2r[None, 0, 0, i] + tOsO_r2s_i = tOsO_s2r[None, 0, 0, i] + tOrO_frg = cute.make_fragment(tOcO_t2r[None, 0, 0, i].shape, self.pv_acc_dtype) + cute.copy(tiled_tmem_load, tOtO_t2r_i, tOrO_frg) + for j in range(0, cute.size(tOrO_frg), 2): + tOrO_frg[j], tOrO_frg[j + 1] = cute.arch.mul_packed_f32x2( + (tOrO_frg[j], tOrO_frg[j + 1]), (scale, scale), + ) + tSMrO = cute.make_fragment(tOrO_frg.shape, self.o_dtype) + o_vec = tOrO_frg.load() + tSMrO.store(o_vec.to(self.o_dtype)) + cute.copy(tiled_smem_store, tSMrO, tOsO_r2s_i) + + # fence view async shared + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta, + ) + + @cute.jit + def epilogue_s2g( + self, + tile_scheduler, + mO: cute.Tensor, + sO: cute.Tensor, + tma_atom_o: cute.CopyAtom, + mbar_ptr: cute.Pointer, + ): + gO_qdl = cute.local_tile(mO, cute.select(self.pv_mma_tiler, mode=[0, 1]), (None, 0, None)) + epi_consumer_phase = cutlass.Int32(0) + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m_block, head_idx, batch_idx = work_tile.tile_idx + gO = gO_qdl[None, None, None, (head_idx, batch_idx)] + tOsO, tOgO = cute.nvgpu.cpasync.tma_partition( + tma_atom_o, + 0, + cute.make_layout(1), + cute.group_modes(sO, 0, 2), + cute.group_modes(gO, 0, 2), + ) + for stage in range(2): + # wait from corr, issue tma store on smem + # 1. wait for O0 / O1 final + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase) + # 2. copy O0 / O1 to gmem + cute.copy(tma_atom_o, tOsO[None, stage], tOgO[None, 2 * m_block + stage]) + cute.arch.cp_async_bulk_commit_group() + for stage in range(2): + # Ensure O0 / O1 buffer is ready to be released + cute.arch.cp_async_bulk_wait_group(1 - stage, read=True) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) + # Advance to next tile + epi_consumer_phase ^= 1 + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + # @cute.jit + def load_K( + self, + tma_atom: cute.CopyAtom, + tKgK: cute.Tensor, + tKsK: cute.Tensor, + pipeline: cutlass.utils.PipelineAsync, + block: cutlass.Int32, + producer_state: cutlass.utils.PipelineState, + ): + pipeline.producer_acquire(producer_state) + cute.copy( + tma_atom, + tKgK[None, block], + tKsK[None, producer_state.index], + tma_bar_ptr=pipeline.producer_get_barrier(producer_state) + ) + + def make_and_init_load_kv_pipeline(self, load_kv_mbar_ptr): + load_kv_producer_group = cutlass.utils.CooperativeGroup(cutlass.utils.Agent.Thread, len([self.load_warp_id]) + ) + load_kv_consumer_group = cutlass.utils.CooperativeGroup(cutlass.utils.Agent.Thread, len([self.mma_warp_id])) + return cutlass.utils.PipelineTmaUmma.create( + barrier_storage=load_kv_mbar_ptr, + num_stages=self.kv_stage, + producer_group=load_kv_producer_group, + consumer_group=load_kv_consumer_group, + tx_count=self.tma_copy_kv_bytes, + ) + + # @cute.jit + # def warp_scheduler_barrier_init(self): + # warp_group_idx = utils.canonical_warp_group_idx(sync=False) + # if warp_group_idx == 0: + # utils.barrier_arrive( + # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1), number_of_threads=2 * 128, + # ) + + # def warp_scheduler_barrier_sync(self): + # cute.arch.barrier( + # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + utils.canonical_warp_group_idx(sync=False), + # number_of_threads=2 * 128 + # ) + + # def warp_scheduler_barrier_arrive(self): + # cur_wg = utils.canonical_warp_group_idx(sync=False) + # next_wg = 1 - cur_wg + # utils.barrier_arrive( + # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg, number_of_threads=2 * 128, + # ) + + @staticmethod + def _compute_grid( + o: cute.Tensor, + cta_tiler: Tuple[int, int, int], + is_persistent: bool, + ) -> Tuple[FmhaStaticTileSchedulerParams, Tuple[int, int, int]]: + o_shape = o.shape + tile_sched_params = create_fmha_static_tile_scheduler_params( + is_persistent, + ( + cute.ceil_div(cute.size(o_shape[0]), cta_tiler[0]), + cute.size(o_shape[2][0]), + cute.size(o_shape[2][1]), + ), + ) + grid = FmhaStaticTileScheduler.get_grid_shape(tile_sched_params) + return tile_sched_params, grid diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 9a5bd894b56..9ed247232e4 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -24,6 +24,7 @@ from flash_attn.cute import utils from flash_attn.cute.flash_fwd import FlashAttentionForwardSm80, FlashAttentionForwardSm90 +from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 from flash_attn.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80 from flash_attn.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess @@ -58,6 +59,7 @@ def _flash_attn_fwd( m_block_size: int = 128, n_block_size: int = 128, num_threads: int = 384, + _compute_capability: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: q, k, v = [maybe_contiguous(t) for t in (q, k, v)] num_head, head_dim = q.shape[-2:] @@ -119,27 +121,40 @@ def _flash_attn_fwd( max_seqlen_q = cutlass.Int32(max_seqlen_q) if max_seqlen_q is not None else None current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + compute_capability = torch.cuda.get_device_capability()[0] if _compute_capability is None else _compute_capability + assert compute_capability in [9, 10], "Unsupported compute capability. Supported: 9.x, 10.x" compile_key = ( dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, softcap != 0.0, cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, seqused_k is None, - m_block_size, n_block_size, num_threads + m_block_size, n_block_size, num_threads, + compute_capability, ) if compile_key not in _flash_attn_fwd.compile_cache: - # fa_fwd = FlashAttentionForwardSm80( - fa_fwd = FlashAttentionForwardSm90( - dtype, - head_dim, - head_dim_v, - qhead_per_kvhead, - is_causal=causal, - has_softcap=softcap != 0.0, - m_block_size=m_block_size, - n_block_size=n_block_size, - # num_stages=1, - num_stages=2, - num_threads=num_threads, - Q_in_regs=False, - ) + if compute_capability == 9: + # fa_fwd = FlashAttentionForwardSm80( + fa_fwd = FlashAttentionForwardSm90( + dtype, + head_dim, + head_dim_v, + qhead_per_kvhead, + is_causal=causal, + has_softcap=softcap != 0.0, + m_block_size=m_block_size, + n_block_size=n_block_size, + # num_stages=1, + num_stages=2, + num_threads=num_threads, + Q_in_regs=False, + ) + else: + fa_fwd = FlashAttentionForwardSm100( + cutlass.Float32, + cutlass.Float32, + (128, 128, head_dim), + is_causal=causal, + qhead_per_kvhead=qhead_per_kvhead, + is_persistent=True, + ) # TODO: check @can_implement _flash_attn_fwd.compile_cache[compile_key] = cute.compile( fa_fwd, q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index eb3770deea8..617e7115f55 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -69,3 +69,41 @@ def apply_mask( # only consider the column index, so the row index sets to 0. if t0ScS_mn[0, c][1] >= col_limit_right: acc_S_mn[r, c] = -cutlass.Float32.inf + + @cute.jit + def apply_mask_sm100( + self, + acc_S: cute.Tensor, + m_block: cutlass.Int32, + n_block: cutlass.Int32, + m_stage: cutlass.Int32, + thr_mma: cute.TiledMma, + thr_tmem_load: cute.TiledCopy, + mask_seqlen: cutlass.Constexpr, + mask_causal: cutlass.Constexpr, + ) -> None: + cS = cute.make_identity_tensor((self.m_block_size, self.n_block_size)) + tScS = thr_mma.partition_C(cS) + tScS_t2r = thr_tmem_load.partition_D(tScS) + seqlenk_col_limit = self.seqlen_k - n_block * self.n_block_size + if not mask_causal: + if mask_seqlen: + for i in range(cute.size(tScS_t2r.shape)): + # if tScS_t2r[i][1] >= seqlenk_col_limit: + # acc_S[i] = -cutlass.Float32.inf + # For some reason the 2 lines above generate really bad SASS, so we just call ptx directly + acc_S[i] = utils.neg_inf_if_ge(acc_S[i], tScS_t2r[i][1], seqlenk_col_limit) + else: # Causal + assert self.qhead_per_kvhead_packgqa == 1, "PackGQA not supported for SM100 yet" + causal_row_offset = 1 + self.seqlen_k - n_block * self.n_block_size - self.seqlen_q + row_idx = tScS_t2r[0][0] + (m_block * 2 + m_stage) * self.m_block_size + col_limit_right = row_idx + causal_row_offset + if cutlass.const_expr(mask_seqlen): + col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) + # if cute.arch.thread_idx()[0] % 32 == 0: + # cute.printf("tidx = %d, tidx tmem = %d, row_idx = %d, col_limit_right = %d, causal_row_offset = %d\n", cute.arch.thread_idx()[0], thr_tmem_load.thr_idx, row_idx, col_limit_right, causal_row_offset) + for i in range(cute.size(tScS_t2r.shape)): + # if tScS_t2r[i][1] >= col_limit_right: + # acc_S[i] = -cutlass.Float32.inf + # For some reason the 2 lines above generate really bad SASS, so we just call ptx directly + acc_S[i] = utils.neg_inf_if_ge(acc_S[i], tScS_t2r[i][1], col_limit_right) diff --git a/flash_attn/cute/mma_sm100_desc.py b/flash_attn/cute/mma_sm100_desc.py new file mode 100644 index 00000000000..0170f0e99ae --- /dev/null +++ b/flash_attn/cute/mma_sm100_desc.py @@ -0,0 +1,285 @@ +# Copyright (c) 2025, Tri Dao. +# Ported Cutlass code from C++ to Python: +# https://github.com/NVIDIA/cutlass/blob/main/include/cute/arch/mma_sm100_desc.hpp +# https://github.com/NVIDIA/cutlass/blob/main/include/cute/atom/mma_traits_sm100.hpp + +from enum import IntEnum + +import cutlass +import cutlass.cute as cute + +# --------------------------------------------------------------------------- +# Enumerations that match the HW encodings (values MUST stay identical) +# --------------------------------------------------------------------------- + + +class Major(IntEnum): # matrix “layout” in the ISA docs + K = 0 + MN = 1 + + +class ScaleIn(IntEnum): # negate flags + One = 0 + Neg = 1 + + +class Saturate(IntEnum): + False_ = 0 + True_ = 1 + + +class CFormat(IntEnum): # 2-bit field (bits 4-5) + F16 = 0 + F32 = 1 + S32 = 2 + + +class F16F32Format(IntEnum): # 3-bit field (A/B element type) + F16 = 0 + BF16 = 1 + TF32 = 2 + + +class S8Format(IntEnum): + UINT8 = 0 + INT8 = 1 + + +class MXF8F6F4Format(IntEnum): + E4M3 = 0 + E5M2 = 1 + E2M3 = 3 + E3M2 = 4 + E2M1 = 5 + + +class MaxShift(IntEnum): + NoShift = 0 + MaxShift8 = 1 + MaxShift16 = 2 + MaxShift32 = 3 + + +# --------------------------------------------------------------------------- +# CUTLASS-type → encoding helpers +# --------------------------------------------------------------------------- + +def to_UMMA_format(cutlass_type) -> int: + """ + Map a CUTLASS scalar class to the 3-bit encoding for Matrix A/B. + """ + if cutlass_type is cutlass.Int8: + return S8Format.INT8 + # Unsigned 8-bit (if available in your CUTLASS build) + if cutlass_type is cutlass.Uint8: + return S8Format.UINT8 + # FP-16 / BF-16 + if cutlass_type is cutlass.Float16: + return F16F32Format.F16 + if cutlass_type is cutlass.BFloat16: + return F16F32Format.BF16 + # TensorFloat-32 (8-bit exponent, 10-bit mantissa packed in 19 bits) + if cutlass_type is cutlass.TFloat32: + return F16F32Format.TF32 + # Float-8 / Float-6 / Float-4 – add whenever CUTLASS exposes them + if cutlass_type is cutlass.FloatE4M3FN: + return MXF8F6F4Format.E4M3 + if cutlass_type is cutlass.FloatE5M2: + return MXF8F6F4Format.E5M2 + raise TypeError(f"Unsupported CUTLASS scalar type for A/B: {cutlass_type!r}") + + +def to_C_format(cutlass_type) -> int: + """ + Map a CUTLASS scalar class to the 2-bit accumulator encoding. + """ + if cutlass_type is cutlass.Float16: + return CFormat.F16 + if cutlass_type is cutlass.Float32: + return CFormat.F32 + if cutlass_type is cutlass.Int32: + return CFormat.S32 + raise TypeError(f"Unsupported CUTLASS scalar type for accumulator: {cutlass_type!r}") + + +# --------------------------------------------------------------------------- +# The constructor – accepts only CUTLASS scalar classes +# --------------------------------------------------------------------------- + +def make_instr_desc( + a_type, # CUTLASS scalar class, e.g. cutlass.Int8 + b_type, + c_type, + M: int, # 64, 128 or 256 + N: int, # 8 … 256 (multiple of 8) + a_major: Major, + b_major: Major, + a_neg: ScaleIn = ScaleIn.One, + b_neg: ScaleIn = ScaleIn.One, + c_sat: Saturate = Saturate.False_, + is_sparse: bool = False, + max_shift: MaxShift = MaxShift.NoShift, +) -> int: + """ + Build the 32-bit instruction descriptor for Blackwell MMA. + All matrix/accumulator **types must be CUTLASS scalar classes** – + passing integers is forbidden. + """ + # --- encode element formats ------------------------------------------------- + a_fmt = int(to_UMMA_format(a_type)) + b_fmt = int(to_UMMA_format(b_type)) + c_fmt = int(to_C_format(c_type)) + + # --- range checks on M/N ----------------------------------------------------- + if M not in (64, 128, 256): + raise ValueError("M must be 64, 128 or 256") + if N < 8 or N > 256 or (N & 7): + raise ValueError("N must be a multiple of 8 in the range 8…256") + + m_dim = M >> 4 # 5-bit field + n_dim = N >> 3 # 6-bit field + + # --- pack the bit-fields ----------------------------------------------------- + desc = 0 + desc |= (0 & 0x3) << 0 # sparse_id2 (always 0 here) + desc |= (int(is_sparse) & 0x1) << 2 # sparse_flag + desc |= (int(c_sat) & 0x1) << 3 # saturate + desc |= (c_fmt & 0x3) << 4 # c_format + desc |= (a_fmt & 0x7) << 7 # a_format + desc |= (b_fmt & 0x7) << 10 # b_format + desc |= (int(a_neg) & 0x1) << 13 # a_negate + desc |= (int(b_neg) & 0x1) << 14 # b_negate + desc |= (int(a_major) & 0x1) << 15 # a_major + desc |= (int(b_major) & 0x1) << 16 # b_major + desc |= (n_dim & 0x3F) << 17 # n_dim (6 bits) + desc |= (m_dim & 0x1F) << 24 # m_dim (5 bits) + desc |= (int(max_shift) & 0x3) << 30 # max_shift (2 bits) + + return desc & 0xFFFF_FFFF # ensure 32-bit result + + +def mma_op_to_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp): + return make_instr_desc( + op.a_dtype, + op.b_dtype, + op.acc_dtype, + op.shape_mnk[0], + op.shape_mnk[1], + Major.K if op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else Major.MN, + Major.K if op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else Major.MN, + ) + + +class LayoutType(IntEnum): # occupies the top-3 bits [61:64) + SWIZZLE_NONE = 0 # (a.k.a. “INTERLEAVE” in older docs) + SWIZZLE_128B_BASE32B = 1 + SWIZZLE_128B = 2 + SWIZZLE_64B = 4 + SWIZZLE_32B = 6 + # values 3,5,7 are reserved / illegal for UMMA + +# --------------------------------------------------------------------------- +# Helpers – figure out the SWIZZLE_* family from the tensor layout +# --------------------------------------------------------------------------- + +def _layout_type(swizzle: cute.Swizzle) -> LayoutType: + # No idea what the right way to get B, M, S is – so we're just parsing it from the __str__ + # Swizzle string has the form "S" + swz_str = str(swizzle) + inside = swz_str[swz_str.index('<') + 1 : swz_str.index('>')] # '3,4,3' + B, M, S = [int(x) for x in inside.split(',')] # [3, 4, 3] + + if M == 4: # Swizzle<*,4,3> + if S != 3: + raise ValueError("Unexpected swizzle shift – want S==3 for M==4") + return { + 0: LayoutType.SWIZZLE_NONE, + 1: LayoutType.SWIZZLE_32B, + 2: LayoutType.SWIZZLE_64B, + 3: LayoutType.SWIZZLE_128B, + }[B] # KeyError ⇒ invalid B→ raise + if M == 5: # Swizzle<2,5,2> (the only legal triple for M==5) + if (B, S) != (2, 2): + raise ValueError("Only Swizzle<2,5,2> supported for 128B_BASE32B") + return LayoutType.SWIZZLE_128B_BASE32B + + # Any other (M,B,S) triple is not a UMMA-legal shared-memory layout + raise ValueError("Unsupported swizzle triple for UMMA smem descriptor") + + +def make_smem_desc_base(layout: cute.Layout, swizzle: cute.Swizzle, major: Major) -> int: + """ + Convert a 2-D *shared-memory* Cute layout into the Blackwell 64-bit + smem-descriptor, without the smem start address. + layout must correspond to layout of an uint128 tensor. + """ + # ------------------------------------------------------------------ meta + layout_type = _layout_type(swizzle) # resolve SWIZZLE_* family + + VERSION = 1 # bits 46–47 + LBO_MODE = 0 # bit 52 + BASE_OFFSET = 0 # bits 49–51 (CUTLASS always 0) + + # ---------------------------------------------------------- strides (units: uint128_t = 16 B) + swizzle_atom_mn_size = { + LayoutType.SWIZZLE_NONE: 1, + LayoutType.SWIZZLE_32B: 2, + LayoutType.SWIZZLE_64B: 4, + LayoutType.SWIZZLE_128B: 8, + LayoutType.SWIZZLE_128B_BASE32B: 8, + }[layout_type] + + if major is Major.MN: + swizzle_atom_k_size = 4 if layout_type is LayoutType.SWIZZLE_128B_BASE32B else 8 + canonical_layout = cute.logical_divide(layout, (swizzle_atom_mn_size, swizzle_atom_k_size)) + if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))): + raise ValueError("Not a canonical UMMA_MN Layout: Expected profile failure.") + stride_00 = canonical_layout.stride[0][0] + if layout_type is not LayoutType.SWIZZLE_NONE and stride_00 != 1: + raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.") + stride_10 = canonical_layout.stride[1][0] + if stride_10 != swizzle_atom_mn_size: + raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.") + stride_01, stride_11 = canonical_layout.stride[0][1], canonical_layout.stride[1][1] + if layout_type is LayoutType.SWIZZLE_NONE: + stride_byte_offset, leading_byte_offset = stride_01, stride_11 + else: + stride_byte_offset, leading_byte_offset = stride_11, stride_01 + else: + if layout_type == LayoutType.SWIZZLE_128B_BASE32B: + raise ValueError("SWIZZLE_128B_BASE32B is invalid for Major-K") + if not cute.size(layout.shape[0]) % 8 == 0: + raise ValueError("Not a canonical UMMA_K Layout: Expected MN-size multiple of 8.") + canonical_layout = cute.logical_divide(layout, (8, 2)) + if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))): + raise ValueError("Not a canonical UMMA_K Layout: Expected profile failure.") + stride_00 = canonical_layout.stride[0][0] + if stride_00 != swizzle_atom_mn_size: + raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.") + stride_10 = canonical_layout.stride[1][0] + if layout_type is not LayoutType.SWIZZLE_NONE and stride_10 != 1: + raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.") + stride_01 = canonical_layout.stride[0][1] + stride_byte_offset, leading_byte_offset = stride_01, stride_10 + + # ------------------------------------------------------------------ pack + desc = 0 + # leading_byte_offset_ [16:30) + desc |= (leading_byte_offset & 0x3FFF) << 16 + # stride_byte_offset_ [32:46) + desc |= (stride_byte_offset & 0x3FFF) << 32 + # version_ [46:48) + desc |= (VERSION & 0x3) << 46 + # base_offset_ [49:52) + desc |= (BASE_OFFSET & 0x7) << 49 + # lbo_mode_ [52:53) + desc |= (LBO_MODE & 0x1) << 52 + # layout_type_ [61:64) + desc |= (int(layout_type) & 0x7) << 61 + + return desc & 0xFFFF_FFFF_FFFF_FFFF # force 64-bit width + + +def make_smem_desc_start_addr(start_addr: cute.Pointer) -> cutlass.Int32: + # 14 bits, remove 4 LSB (bits 0-13 in desc) + return (start_addr.toint() & 0x3FFFF) >> 4 diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index 68f577f8d27..58f8c12c26c 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -134,17 +134,19 @@ def update_row_max(self, acc_S_row: cute.TensorSSA) -> Tuple[Float32, Float32]: row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old) row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2 - acc_scale = utils.exp2f(acc_scale_) if cutlass.const_expr(self.rescale_threshold > 0.0): if acc_scale_ >= -self.rescale_threshold: row_max_new = row_max_old row_max_safe = row_max_old - acc_scale = 1.0 + acc_scale_ = 0.0 + acc_scale = utils.exp2f(acc_scale_) self.row_max[0] = row_max_new return row_max_safe, acc_scale def update_row_sum(self, acc_S_row_exp: cute.TensorSSA, row_scale: Float32) -> None: self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=self.row_sum[0] * row_scale) + # tmp = self._compute_row_sum(acc_S_row_exp) + # self.row_sum[0] = self.row_sum[0] * row_scale + tmp def scale_apply_exp2_convert( self, @@ -152,8 +154,15 @@ def scale_apply_exp2_convert( row_max: Float32, acc_S_row_converted: cute.Tensor, ): + assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" minus_row_max_scaled = -row_max * self.scale_log2 - # assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" + for i in range(0, cute.size(acc_S_row.shape), 2): + acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( + (acc_S_row[i], acc_S_row[i + 1]), + (self.scale_log2, self.scale_log2), + (minus_row_max_scaled, minus_row_max_scaled), + ) + # for i in range(0, cute.size(acc_S_row.shape), 2): # acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( # (acc_S_row[i], acc_S_row[i + 1]), @@ -163,22 +172,23 @@ def scale_apply_exp2_convert( # acc_S_row[i] = cute.arch.exp2(acc_S_row[i]) # acc_S_row[i + 1] = cute.arch.exp2(acc_S_row[i + 1]) - frg_cnt = 4 - frg_tile = cute.size(acc_S_row) // frg_cnt - assert cute.size(acc_S_row) % (frg_cnt * 2) == 0 + frg_tile = 32 + assert frg_tile % 2 == 0 + frg_cnt = cute.size(acc_S_row) // frg_tile + assert cute.size(acc_S_row) % frg_tile == 0 acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile)) acc_S_row_converted_frg = cute.logical_divide(acc_S_row_converted, cute.make_layout(frg_tile)) for j in range(frg_cnt): for k in range(0, cute.size(acc_S_row_frg, mode=[0]), 2): - acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = ( - cute.arch.fma_packed_f32x2( - (acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]), - (self.scale_log2, self.scale_log2), - (minus_row_max_scaled, minus_row_max_scaled), - ) - ) - # acc_S_row_frg[k, j] = fa_utils.exp2f(acc_S_row_frg[k, j]) - # acc_S_row_frg[k + 1, j] = fa_utils.exp2f(acc_S_row_frg[k + 1, j]) + # acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = ( + # cute.arch.fma_packed_f32x2( + # (acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]), + # (self.scale_log2, self.scale_log2), + # (minus_row_max_scaled, minus_row_max_scaled), + # ) + # ) + # acc_S_row_frg[k, j] = utils.exp2f(acc_S_row_frg[k, j]) + # acc_S_row_frg[k + 1, j] = utils.exp2f(acc_S_row_frg[k + 1, j]) acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j]) acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j]) acc_S_row_converted_frg[None, j].store( diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 6ea68c05677..5b4a4438513 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -225,10 +225,12 @@ def fmax(a: float | Float32, b: float | Float32, c: float | Float32 | None = Non def fmax_reduce( x: cute.TensorSSA, - init_val: float | Float32 = -Float32.inf, + init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80 ) -> Float32: if cutlass.const_expr(arch < 100 or cute.size(x.shape) % 8 != 0): + if cutlass.const_expr(init_val is None): + init_val = -cutlass.Float32.inf return x.reduce(cute.ReductionOp.MAX, init_val, 0) else: # [2025-06-15] x.reduce only seems to use 50% 3-input max and 50% 2-input max @@ -236,7 +238,7 @@ def fmax_reduce( res = cute.make_fragment(x.shape, Float32) res.store(x) local_max = [ - fmax(init_val, res[0], res[1]), + fmax(init_val, res[0], res[1]) if cutlass.const_expr(init_val is not None) else fmax(res[0], res[1]), fmax(res[2], res[3]), fmax(res[4], res[5]), fmax(res[6], res[7]), @@ -252,18 +254,20 @@ def fmax_reduce( def fadd_reduce( x: cute.TensorSSA, - init_val: float | Float32 = Float32.zero, + init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80 ) -> Float32: if cutlass.const_expr(arch < 100 or cute.size(x.shape) % 8 != 0): + if cutlass.const_expr(init_val is None): + init_val = Float32.zero return x.reduce(cute.ReductionOp.ADD, init_val, 0) else: res = cute.make_fragment(x.shape, Float32) res.store(x) - local_sum_0 = cute.arch.add_packed_f32x2((init_val, 0.0), (res[0], res[1])) + local_sum_0 = cute.arch.add_packed_f32x2((init_val, 0.0), (res[0], res[1])) if cutlass.const_expr(init_val is not None) else (res[0], res[1]) local_sum = [local_sum_0, (res[2], res[3]), (res[4], res[5]), (res[6], res[7])] for i in range(8, cute.size(x.shape), 8): - local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], (res[i], res[i + 1])) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], (res[i + 0], res[i + 1])) local_sum[1] = cute.arch.add_packed_f32x2(local_sum[1], (res[i + 2], res[i + 3])) local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], (res[i + 4], res[i + 5])) local_sum[3] = cute.arch.add_packed_f32x2(local_sum[3], (res[i + 6], res[i + 7])) @@ -414,3 +418,38 @@ def shuffle_sync( for i in range(cute.size(val_i32)): val_i32[i] = cute.arch.shuffle_sync(val_i32[i], offset, mask_and_clamp=mask_and_clamp) return val[0] + + +@dsl_user_op +def noop_asm(val: cutlass.Int32, *, loc=None, ip=None) -> cute.Numeric: + assert val.width == 32, "noop_asm only supports 32-bit types" + return type(val)( + llvm.inline_asm( + T.i32(), + [cutlass.Int32(val).ir_value(loc=loc, ip=ip)], + "mov.b32 $0, $1;", + "=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def neg_inf_if_ge(val: cutlass.Float32, idx: int, limit: cutlass.Int32, *, loc=None, ip=None) -> cutlass.Float32: + return cutlass.Float32( + llvm.inline_asm( + T.f32(), + [cutlass.Float32(val).ir_value(loc=loc, ip=ip), cutlass.Int32(limit).ir_value(loc=loc, ip=ip)], + "{\n\t" + ".reg .pred p;\n\t" + f"setp.ge.s32 p, {idx}, $2;\n\t" + "selp.f32 $0, 0fFF800000, $1, p;" + "}\n", + "=f,f,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) diff --git a/flash_attn/utils/testing.py b/flash_attn/utils/testing.py index 339af1767c4..772f955dedb 100644 --- a/flash_attn/utils/testing.py +++ b/flash_attn/utils/testing.py @@ -4,7 +4,7 @@ import torch from einops import rearrange, repeat -from padding import pad_input, unpad_input +from flash_attn.bert_padding import pad_input, unpad_input def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False): diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index bc41a56d813..82398622093 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -137,13 +137,13 @@ def test_flash_attn_output( # k_extended = repeat(k_ref, "b s h d -> b s (h k) d", k=nheads // nheads_kv) # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_extended).float() - # if qv is not None: - # qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float() + # # if qv is not None: + # # qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float() # m = qk.amax(-1, keepdim=True) # s_tmp = torch.exp((qk - m) / math.sqrt(d)) # exp_sum = s_tmp.sum(-1) - # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float()) - # lse_ref = torch.logsumexp(qk, dim=-1) + # # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float()) + # # lse_ref = torch.logsumexp(qk, dim=-1) # Numerical error if we just do any arithmetic on out_ref fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() @@ -185,6 +185,7 @@ def test_flash_attn_output( and not dv > 256 and not attention_chunk != 0 and softcap == 0.0 + and False ): g = torch.randn_like(out) # do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) From cc2521394527158b15cd1438e3448cb7f9559cee Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 29 Jun 2025 15:17:12 -0400 Subject: [PATCH 006/258] [Cute] Don't need neg_inf_if_ge ptx any more --- flash_attn/cute/mask.py | 8 ++++---- flash_attn/cute/utils.py | 19 ------------------- 2 files changed, 4 insertions(+), 23 deletions(-) diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index 617e7115f55..1d013caefd5 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -91,8 +91,8 @@ def apply_mask_sm100( for i in range(cute.size(tScS_t2r.shape)): # if tScS_t2r[i][1] >= seqlenk_col_limit: # acc_S[i] = -cutlass.Float32.inf - # For some reason the 2 lines above generate really bad SASS, so we just call ptx directly - acc_S[i] = utils.neg_inf_if_ge(acc_S[i], tScS_t2r[i][1], seqlenk_col_limit) + # For some reason the 2 lines above generate really bad SASS + acc_S[i] = -cutlass.Float32.inf if tScS_t2r[i][1] >= seqlenk_col_limit else acc_S[i] else: # Causal assert self.qhead_per_kvhead_packgqa == 1, "PackGQA not supported for SM100 yet" causal_row_offset = 1 + self.seqlen_k - n_block * self.n_block_size - self.seqlen_q @@ -105,5 +105,5 @@ def apply_mask_sm100( for i in range(cute.size(tScS_t2r.shape)): # if tScS_t2r[i][1] >= col_limit_right: # acc_S[i] = -cutlass.Float32.inf - # For some reason the 2 lines above generate really bad SASS, so we just call ptx directly - acc_S[i] = utils.neg_inf_if_ge(acc_S[i], tScS_t2r[i][1], col_limit_right) + # For some reason the 2 lines above generate really bad SASS + acc_S[i] = -cutlass.Float32.inf if tScS_t2r[i][1] >= col_limit_right else acc_S[i] diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 5b4a4438513..c2de62897e9 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -434,22 +434,3 @@ def noop_asm(val: cutlass.Int32, *, loc=None, ip=None) -> cute.Numeric: asm_dialect=llvm.AsmDialect.AD_ATT, ) ) - - -@dsl_user_op -def neg_inf_if_ge(val: cutlass.Float32, idx: int, limit: cutlass.Int32, *, loc=None, ip=None) -> cutlass.Float32: - return cutlass.Float32( - llvm.inline_asm( - T.f32(), - [cutlass.Float32(val).ir_value(loc=loc, ip=ip), cutlass.Int32(limit).ir_value(loc=loc, ip=ip)], - "{\n\t" - ".reg .pred p;\n\t" - f"setp.ge.s32 p, {idx}, $2;\n\t" - "selp.f32 $0, 0fFF800000, $1, p;" - "}\n", - "=f,f,r", - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - ) From 96acd0f70944c957ef9707a76a425f6ce7995b2c Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 29 Jun 2025 19:18:21 -0400 Subject: [PATCH 007/258] [Cute] Test flash_fwd_sm100.py with hdim 64 --- flash_attn/cute/flash_fwd_sm100.py | 267 ++++++++++++++++------------- flash_attn/cute/interface.py | 5 +- flash_attn/cute/softmax.py | 70 ++++++-- tests/cute/test_flash_attn.py | 5 +- 4 files changed, 207 insertions(+), 140 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index e2310b4d9f0..6ea681e0837 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -1,4 +1,4 @@ -# Supported features, currently only tested for hdim 128. +# Supported features, currently only tested for hdim 64 and 128. # - BF16 & FP16 dtype # - noncausal & causal attention # - MHA, GQA, MQA @@ -183,26 +183,38 @@ def create_fmha_static_tile_scheduler( class FlashAttentionForwardSm100: def __init__( self, - qk_acc_dtype: Type[cutlass.Numeric], - pv_acc_dtype: Type[cutlass.Numeric], - mma_tiler: Tuple[int, int, int], - is_causal: bool, + # dtype: Type[cutlass.Numeric], + head_dim: int, + head_dim_v: Optional[int] = None, + is_causal: bool = False, qhead_per_kvhead: cutlass.Constexpr[int] = 1, + m_block_size: int = 128, + n_block_size: int = 128, is_persistent: bool = True, ): - self.qk_acc_dtype = qk_acc_dtype - self.pv_acc_dtype = pv_acc_dtype + # self.dtype = dtype + # padding head_dim to a multiple of 16 as k_block_size + hdim_multiple_of = 16 + self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) + head_dim_v = head_dim_v if head_dim_v is not None else head_dim + self.same_hdim_kv = head_dim == head_dim_v + assert head_dim == head_dim_v, "head_dim and head_dim_v must be the same for now" + self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) + self.m_block_size = m_block_size + self.n_block_size = n_block_size # 2 Q tile per CTA - self.cta_tiler = (2 * mma_tiler[0], mma_tiler[1], mma_tiler[2]) - self.mma_tiler_qk = mma_tiler - self.pv_mma_tiler = (mma_tiler[0], mma_tiler[2], mma_tiler[1]) + self.cta_tiler = (2 * m_block_size, n_block_size, self.head_dim_padded) + self.mma_tiler_qk = (m_block_size, n_block_size, self.head_dim_padded) + self.pv_mma_tiler = (m_block_size, self.head_dim_v_padded, n_block_size) + self.qk_acc_dtype = cutlass.Float32 + self.pv_acc_dtype = cutlass.Float32 self.cluster_shape_mn = (1, 1) self.is_persistent = is_persistent self.is_even_N = False self.is_causal = is_causal self.qhead_per_kvhead = qhead_per_kvhead self.pack_gqa = False - self.s0_s1_barrier = False # Does S1 need to wait for S0 to finish + self.s0_s1_barrier = head_dim == 64 # Does S1 need to wait for S0 to finish self.softmax0_warp_ids = (0, 1, 2, 3) self.softmax1_warp_ids = (4, 5, 6, 7) @@ -229,18 +241,18 @@ def __init__( self.tmem_alloc_sync_bar_id = 1 self.tmem_s0_offset = 0 - self.tmem_s1_offset = 128 - self.tmem_o0_offset = 256 - self.tmem_o1_offset = 384 - self.tmem_p0_offset = 32 - self.tmem_p1_offset = 160 - self.tmem_p_offset = 32 - # self.tmem_p0_offset = 0 - # self.tmem_p1_offset = 128 + self.tmem_s1_offset = self.tmem_s0_offset + self.n_block_size + self.tmem_o0_offset = self.tmem_s1_offset + self.n_block_size + self.tmem_o1_offset = self.tmem_o0_offset + self.head_dim_v_padded + self.tmem_total = self.tmem_o1_offset + self.head_dim_v_padded + assert self.tmem_total <= SM100_TMEM_CAPACITY_COLUMNS + self.tmem_p_offset = 0 + self.tmem_p0_offset = self.tmem_s0_offset + self.tmem_p_offset + self.tmem_p1_offset = self.tmem_s1_offset + self.tmem_p_offset # vec buffer for row_max & row_sum self.tmem_vec0_offset = 0 - self.tmem_vec1_offset = 128 + self.tmem_vec1_offset = self.tmem_vec0_offset + self.n_block_size # self.num_regs_softmax = 192 # self.num_regs_softmax = 184 @@ -373,19 +385,19 @@ def __call__( self.epi_tile = self.pv_mma_tiler[:2] - q_smem_layout_staged = sm100_utils_basic.make_smem_layout_a( + sQ_layout_staged = sm100_utils_basic.make_smem_layout_a( tiled_mma_qk, self.mma_tiler_qk, self.q_dtype, self.q_stage, ) - k_smem_layout_staged = sm100_utils_basic.make_smem_layout_b( + sK_layout_staged = sm100_utils_basic.make_smem_layout_b( tiled_mma_qk, self.mma_tiler_qk, self.k_dtype, self.kv_stage, ) - p_tmem_layout_staged = sm100_utils_basic.make_smem_layout_a( + tP_layout_staged = sm100_utils_basic.make_smem_layout_a( tiled_mma_pv, self.pv_mma_tiler, self.q_dtype, self.acc_stage, ) - v_smem_layout_staged = sm100_utils_basic.make_smem_layout_b( + sV_layout_staged = sm100_utils_basic.make_smem_layout_b( tiled_mma_pv, self.pv_mma_tiler, self.v_dtype, self.kv_stage, ) - o_smem_layout_staged = sm100_utils_basic.make_smem_layout_epi( + sO_layout_staged = sm100_utils_basic.make_smem_layout_epi( self.o_dtype, self.o_layout, self.epi_tile, self.epi_stage, ) @@ -393,32 +405,32 @@ def __call__( tma_load_op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp(cta_group) tma_store_op = cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp() - q_smem_layout = cute.select(q_smem_layout_staged, mode=[0, 1, 2]) - tma_atom_Q, tma_tensor_q = cute.nvgpu.make_tma_tile_atom_A( + sQ_layout = cute.select(sQ_layout_staged, mode=[0, 1, 2]) + tma_atom_Q, tma_tensor_Q = cute.nvgpu.make_tma_tile_atom_A( tma_load_op, mQ, - q_smem_layout, + sQ_layout, self.mma_tiler_qk, tiled_mma_qk, self.cluster_layout_vmnk.shape, ) # TMA load for K - k_smem_layout = cute.select(k_smem_layout_staged, mode=[0, 1, 2]) - tma_atom_K, tma_tensor_k = cute.nvgpu.make_tma_tile_atom_B( + sK_layout = cute.select(sK_layout_staged, mode=[0, 1, 2]) + tma_atom_K, tma_tensor_K = cute.nvgpu.make_tma_tile_atom_B( tma_load_op, mK, - k_smem_layout, + sK_layout, self.mma_tiler_qk, tiled_mma_qk, self.cluster_layout_vmnk.shape, ) # TMA load for V - v_smem_layout = cute.select(v_smem_layout_staged, mode=[0, 1, 2]) - tma_atom_V, tma_tensor_v = cute.nvgpu.make_tma_tile_atom_B( + sV_layout = cute.select(sV_layout_staged, mode=[0, 1, 2]) + tma_atom_V, tma_tensor_V = cute.nvgpu.make_tma_tile_atom_B( tma_load_op, mV, - v_smem_layout, + sV_layout, self.pv_mma_tiler, tiled_mma_pv, self.cluster_layout_vmnk.shape, @@ -427,23 +439,19 @@ def __call__( o_cta_v_layout = cute.composition( cute.make_identity_layout(mO.shape), self.epi_tile ) - o_smem_layout = cute.select(o_smem_layout_staged, mode=[0, 1]) + sO_layout = cute.select(sO_layout_staged, mode=[0, 1]) - tma_atom_o, tma_tensor_o = cute.nvgpu.cpasync.make_tma_tile_atom( + tma_atom_O, tma_tensor_O = cute.nvgpu.cpasync.make_tma_tile_atom( tma_store_op, mO, - o_smem_layout, + sO_layout, o_cta_v_layout, ) - self.tma_copy_q_bytes = cute.size_in_bytes(self.q_dtype, q_smem_layout) - self.tma_copy_kv_bytes = cute.size_in_bytes(self.k_dtype, k_smem_layout) + self.tma_copy_q_bytes = cute.size_in_bytes(self.q_dtype, sQ_layout) + self.tma_copy_kv_bytes = cute.size_in_bytes(self.k_dtype, sK_layout) - self.tile_sched_params, grid = self._compute_grid( - mO, - self.cta_tiler, - self.is_persistent, - ) + self.tile_sched_params, grid = self._compute_grid(mO, self.cta_tiler, self.is_persistent) self.mbar_load_q_full_offset = 0 self.mbar_load_q_empty_offset = self.mbar_load_q_full_offset + self.q_stage @@ -468,17 +476,17 @@ class SharedStorage: # Tmem holding buffer tmem_holding_buf: cutlass.Int32 # Smem tensors - sScale: cute.struct.MemRange[cutlass.Float32, 2 * 128 * 1] + sScale: cute.struct.MemRange[cutlass.Float32, 2 * self.m_block_size] sO: cute.struct.Align[ - cute.struct.MemRange[self.o_dtype, cute.cosize(o_smem_layout_staged)], + cute.struct.MemRange[self.o_dtype, cute.cosize(sO_layout_staged)], self.buffer_align_bytes, ] sQ: cute.struct.Align[ - cute.struct.MemRange[self.q_dtype, cute.cosize(q_smem_layout_staged)], + cute.struct.MemRange[self.q_dtype, cute.cosize(sQ_layout_staged)], self.buffer_align_bytes, ] sK: cute.struct.Align[ - cute.struct.MemRange[self.k_dtype, cute.cosize(k_smem_layout_staged)], + cute.struct.MemRange[self.k_dtype, cute.cosize(sK_layout_staged)], self.buffer_align_bytes, ] @@ -500,22 +508,27 @@ class SharedStorage: # Launch the kernel synchronously self.kernel( - tiled_mma_qk, - tiled_mma_pv, + tma_tensor_Q, + tma_tensor_K, + tma_tensor_V, + tma_tensor_O, + mLSE, + mCuSeqlensQ, + mCuSeqlensK, + mSeqUsedQ, + mSeqUsedK, tma_atom_Q, - tma_tensor_q, tma_atom_K, - tma_tensor_k, tma_atom_V, - tma_tensor_v, - tma_atom_o, - tma_tensor_o, + tma_atom_O, + tiled_mma_qk, + tiled_mma_pv, softmax_scale_log2, - q_smem_layout_staged, - k_smem_layout_staged, - p_tmem_layout_staged, - v_smem_layout_staged, - o_smem_layout_staged, + sQ_layout_staged, + sK_layout_staged, + tP_layout_staged, + sV_layout_staged, + sO_layout_staged, self.tile_sched_params, ).launch( grid=grid, @@ -530,22 +543,27 @@ class SharedStorage: @cute.kernel def kernel( self, - tiled_mma_qk: cute.TiledMma, - tiled_mma_pv: cute.TiledMma, - tma_atom_Q: cute.CopyAtom, mQ: cute.Tensor, - tma_atom_K: cute.CopyAtom, mK: cute.Tensor, - tma_atom_V: cute.CopyAtom, mV: cute.Tensor, - tma_atom_o: cute.CopyAtom, mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + mCuSeqlensQ: Optional[cute.Tensor], + mCuSeqlensK: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + mSeqUsedK: Optional[cute.Tensor], + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + tma_atom_O: cute.CopyAtom, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, softmax_scale_log2: cutlass.Float32, - q_smem_layout_staged: cute.ComposedLayout, - k_smem_layout_staged: cute.ComposedLayout, - p_tmem_layout_staged: cute.ComposedLayout, - v_smem_layout_staged: cute.ComposedLayout, - o_smem_layout_staged: cute.ComposedLayout, + sQ_layout_staged: cute.ComposedLayout, + sK_layout_staged: cute.ComposedLayout, + tP_layout_staged: cute.ComposedLayout, + sV_layout_staged: cute.ComposedLayout, + sO_layout_staged: cute.ComposedLayout, tile_sched_params: FmhaStaticTileSchedulerParams, ): """The device kernel implementation of the Fused Multi-Head Attention. @@ -625,16 +643,15 @@ def kernel( # Generate smem tensor Q/K/V/O # (MMA, MMA_Q, MMA_D, PIPE) - sQ = storage.sQ.get_tensor(q_smem_layout_staged.outer, swizzle=q_smem_layout_staged.inner) - # sQ_pi = storage.sQ.get_tensor(q_smem_layout_staged) + sQ = storage.sQ.get_tensor(sQ_layout_staged.outer, swizzle=sQ_layout_staged.inner) + # sQ_pi = storage.sQ.get_tensor(sQ_layout_staged) # (MMA, MMA_K, MMA_D, PIPE) - sK = storage.sK.get_tensor(k_smem_layout_staged.outer, swizzle=k_smem_layout_staged.inner) - # sK_pi = storage.sK.get_tensor(k_smem_layout_staged) + sK = storage.sK.get_tensor(sK_layout_staged.outer, swizzle=sK_layout_staged.inner) + # sK_pi = storage.sK.get_tensor(sK_layout_staged) # (MMA, MMA_K, MMA_D, PIPE) # Strip swizzle info to reuse smem - sV_ptr = cute.recast_ptr(sK.iterator, v_smem_layout_staged.inner) - sV = cute.make_tensor(sV_ptr, v_smem_layout_staged.outer) - sO = storage.sO.get_tensor(o_smem_layout_staged.outer, swizzle=o_smem_layout_staged.inner) + sV = cute.make_tensor(cute.recast_ptr(sK.iterator, sV_layout_staged.inner), sV_layout_staged.outer) + sO = storage.sO.get_tensor(sO_layout_staged.outer, swizzle=sO_layout_staged.inner) sScale = storage.sScale.get_tensor(cute.make_layout(256)) @@ -657,7 +674,7 @@ def kernel( tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) tOtO1 = cute.make_tensor(tOtO.iterator + self.tmem_o1_offset, tOtO.layout) - tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer) + tP = cute.make_tensor(tStS.iterator, tP_layout_staged.outer) tOrP = thr_mma_pv.make_fragment_A(tP)[None, None, None, 0] tOrP0 = cute.make_tensor( @@ -726,9 +743,9 @@ def kernel( sV, # sQ_pi.iterator, # sK_pi.iterator, - q_smem_layout_staged.inner, - k_smem_layout_staged.inner, - v_smem_layout_staged.inner, + sQ_layout_staged.inner, + sK_layout_staged.inner, + sV_layout_staged.inner, tStS0, tStS1, tOtO0, @@ -761,7 +778,7 @@ def kernel( tile_scheduler = create_fmha_static_tile_scheduler( tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() ) - self.epilogue_s2g(tile_scheduler, mO, sO, tma_atom_o, mbar_ptr) + self.epilogue_s2g(tile_scheduler, mO, sO, tma_atom_O, mbar_ptr) # /////////////////////////////////////////////////////////////////////////////// # Softmax @@ -817,7 +834,7 @@ def kernel( sScale, mO, sO, - tma_atom_o, + tma_atom_O, mbar_ptr, tile_sched_params, block_info, @@ -1154,13 +1171,13 @@ def softmax_loop( cS_base = cute.make_identity_tensor((self.mma_tiler_qk[0], self.mma_tiler_qk[1])) tScS = thr_mma_qk.partition_C(cS_base) - tStS_scale_layout = cute.composition(tStSi.layout, cute.make_layout((128, 1))) + tStS_scale_layout = cute.composition(tStSi.layout, cute.make_layout((self.m_block_size, 1))) tStScale = cute.make_tensor(tStSi.iterator, tStS_scale_layout) - tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((128, 1))) + tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((self.m_block_size, 1))) tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) tilePlikeFP32 = self.mma_tiler_qk[1] // 32 * self.v_dtype.width - tStP_layout = cute.composition(tStSi.layout, cute.make_layout((128, tilePlikeFP32))) + tStP_layout = cute.composition(tStSi.layout, cute.make_layout((self.m_block_size, tilePlikeFP32))) tStP = cute.make_tensor(tStSi.iterator + self.tmem_p_offset, tStP_layout) tmem_load_atom = cute.make_copy_atom( @@ -1207,9 +1224,6 @@ def softmax_loop( softmax = SoftmaxSm100(softmax_scale_log2, rescale_threshold=8.0 if self.q_dtype.width == 16 else 0.0) softmax.reset() - cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase) - si_corr_producer_phase ^= 1 - softmax_step = partial( self.softmax_step, softmax=softmax, @@ -1226,10 +1240,13 @@ def softmax_loop( stage=stage, ) + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase) + si_corr_producer_phase ^= 1 + # 1 masking iter if cutlass.const_expr(not self.is_even_N): # mask_trip_count = 1 if seqlen.seqlen_k % self.mma_tiler_qk[1] == 0 else 0 - softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block_max - 1, mask_fn=partial(mask_fn, mask_seqlen=True)) + softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block_max - 1, is_first=False, mask_fn=partial(mask_fn, mask_seqlen=True)) si_corr_producer_phase ^= 1 mma_si_consumer_phase ^= 1 s0_s1_sequence_phase ^= 1 @@ -1250,7 +1267,7 @@ def softmax_loop( # The remaining iterations have no masking for n_tile in cutlass.range_dynamic(n_block_max, unroll=1): n_block = n_block_max - n_tile - 1 - softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=None) + softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block) si_corr_producer_phase ^= 1 mma_si_consumer_phase ^= 1 s0_s1_sequence_phase ^= 1 @@ -1261,7 +1278,7 @@ def softmax_loop( # tSrScale_r2t[0] = softmax.row_sum[0] # cute.copy(thr_tmem_store_scale, tSrScale_r2t, tStScale_r2t) # cute.arch.fence_view_async_tmem_store() - sScale[tidx + stage * 128] = softmax.row_sum[0] + sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0] cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage) # if tidx == 0: cute.printf("softmax row sum stage %d: %f\n", stage, softmax.row_sum[0]) @@ -1289,8 +1306,9 @@ def softmax_step( tStScale_r2t: cute.Tensor, tStP_r2t: cute.Tensor, sScale: cute.Tensor, - mask_fn: Optional[Callable], stage: int, + mask_fn: Optional[Callable] = None, + is_first: bool = False, ) -> None: """Perform a single step of the softmax computation on a block of attention scores. @@ -1309,10 +1327,10 @@ def softmax_step( """ tilePlikeFP32 = self.mma_tiler_qk[1] // cutlass.Float32.width * self.v_dtype.width tScS = thr_mma_qk.partition_C(cute.make_identity_tensor((self.mma_tiler_qk[0], self.mma_tiler_qk[1]))) - tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((128, 1))) + tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((self.m_block_size, 1))) tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) - tScP_layout = cute.composition(tScS.layout, cute.make_layout((128, tilePlikeFP32))) + tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.m_block_size, tilePlikeFP32))) tScP = cute.make_tensor(tScS.iterator, tScP_layout) tScS_t2r_shape = thr_tmem_load.partition_D(tScS).shape @@ -1322,18 +1340,22 @@ def softmax_step( cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r) if cutlass.const_expr(mask_fn is not None): mask_fn(tSrS_t2r, n_block=n_block) - row_max, acc_scale = softmax.update_row_max(tSrS_t2r.load()) + row_max, acc_scale = softmax.update_row_max(tSrS_t2r.load(), is_first) # tSrScale_r2t = cute.make_fragment(thr_tmem_store_scale.partition_S(tScS_vec).shape, cutlass.Float32) # tSrScale_r2t[0] = acc_scale # cute.copy(thr_tmem_store_scale, tSrScale_r2t, tStScale_r2t) # cute.arch.fence_view_async_tmem_store() - thread_idx = thr_tmem_load.thr_idx - sScale[thread_idx + stage * 128] = acc_scale - # if thread_idx == 0: cute.printf("softmax acc_scale stage %d: %f, row_max = %f\n", stage, acc_scale, row_max) + if cutlass.const_expr(not is_first): + thread_idx = thr_tmem_load.thr_idx + sScale[thread_idx + stage * self.m_block_size] = acc_scale + # if thread_idx == 0: cute.printf("softmax acc_scale stage %d: %f, row_max = %f\n", stage, acc_scale, row_max) # Notify correction wg that row_max is ready cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage) + # if thread_idx == 0 and stage == 0: cute.print_tensor(tSrS_t2r) + # print(tSrS_t2r) + softmax.scale_subtract_rowmax(tSrS_t2r, row_max) # Sequence barrier wait if cutlass.const_expr(self.s0_s1_barrier): cute.arch.mbarrier_wait(mbar_ptr + mbar_s0_s1_sequence_offset + stage * 4, s0_s1_sequence_phase) @@ -1341,18 +1363,18 @@ def softmax_step( tSrP_r2t = cute.make_tensor( cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), tSrS_t2r.layout, ) - # if thread_idx == 0 and stage == 0: cute.print_tensor(tSrS_t2r) - softmax.scale_apply_exp2_convert(tSrS_t2r, row_max, tSrP_r2t) - # print(tSrP_r2t_f32, tStP_r2t) + # softmax.scale_apply_exp2_convert(tSrS_t2r, row_max, tSrP_r2t) + softmax.apply_exp2_convert(tSrS_t2r, tSrP_r2t) # Sequence barrier arrive if cutlass.const_expr(self.s0_s1_barrier): cute.arch.mbarrier_arrive(mbar_ptr + mbar_s0_s1_sequence_offset + (1 - stage) * 4) + # print(tSrP_r2t_f32, tStP_r2t) cute.copy(thr_tmem_store, tSrP_r2t_f32, tStP_r2t) cute.arch.fence_view_async_tmem_store() # Notify mma warp that P is ready cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase) - softmax.update_row_sum(tSrS_t2r.load(), acc_scale) + softmax.update_row_sum(tSrS_t2r.load(), acc_scale, is_first) # acc_scale = cute.arch.exp2(acc_scale_) @cute.jit @@ -1366,7 +1388,7 @@ def correction_loop( sScale: cute.Tensor, mO: cute.Tensor, sO: cute.Tensor, - tma_atom_o: cute.CopyAtom, + tma_atom_O: cute.CopyAtom, mbar_ptr: cute.Pointer, # tile_scheduler, tile_sched_params, @@ -1374,10 +1396,10 @@ def correction_loop( SeqlenInfoCls: Callable, ): tScS = thr_mma_qk.partition_C(cute.make_identity_tensor((self.mma_tiler_qk[0], self.mma_tiler_qk[1]))) - tStS_scale_layout = cute.composition(tStS.layout, cute.make_layout((128, 1))) + tStS_scale_layout = cute.composition(tStS.layout, cute.make_layout((self.m_block_size, 1))) tStScale_0 = cute.make_tensor(tStS.iterator + self.tmem_vec0_offset, tStS_scale_layout) tStScale_1 = cute.make_tensor(tStS.iterator + self.tmem_vec1_offset, tStS_scale_layout) - tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((128, 1))) + tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((self.m_block_size, 1))) tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) tmem_load_v_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(1)), self.qk_acc_dtype, @@ -1424,7 +1446,7 @@ def correction_loop( # cute.copy(tiled_tmem_load_vec, tStScale_1_t2r, tSrScale_t2r) # cute.arch.fence_view_async_tmem_load() # scale = tSrScale_t2r[stage] - scale = sScale[thread_idx + stage * 128] + scale = sScale[thread_idx + stage * self.m_block_size] should_rescale = cute.arch.vote_ballot_sync(scale < 1.0) != 0 # should_rescale = True # if thread_idx == 0: cute.printf("Correction scale i = %d, for stage %d: %f, should_rescale = %d\n", i, stage, scale, should_rescale) @@ -1447,7 +1469,7 @@ def correction_loop( # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r) # cute.arch.fence_view_async_tmem_load() # scale = tSrScale_t2r[0] - scale = sScale[thread_idx + stage * 128] + scale = sScale[thread_idx + stage * self.m_block_size] cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage) cute.arch.mbarrier_wait(mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase) cute.arch.mbarrier_wait(mbar_ptr + self.mbar_corr_epi_empty_offset + stage, corr_epi_producer_phase) @@ -1467,7 +1489,7 @@ def correction_loop( # gO_qdl = cute.local_tile(mO, cute.select(self.pv_mma_tiler, mode=[0, 1]), (None, 0, None)) # gO = gO_qdl[None, None, None, (head_idx, batch_idx)] # tOsO, tOgO = cute.nvgpu.cpasync.tma_partition( - # tma_atom_o, + # tma_atom_O, # 0, # cute.make_layout(1), # cute.group_modes(sO, 0, 2), @@ -1480,7 +1502,7 @@ def correction_loop( # # 1. wait for O0 / O1 final # cute.arch.mbarrier_wait(mbar_ptr + self.mbar_corr_epi_full_offset + stage, corr_epi_producer_phase) # # 2. copy O0 / O1 to gmem - # cute.copy(tma_atom_o, tOsO[None, stage], tOgO[None, 2 * m_block + stage]) + # cute.copy(tma_atom_O, tOsO[None, stage], tOgO[None, 2 * m_block + stage]) # cute.arch.cp_async_bulk_commit_group() # # Ensure O0 / O1 buffer is ready to be released # cute.arch.cp_async_bulk_wait_group(0, read=True) @@ -1524,8 +1546,8 @@ def correction_rescale( self.pv_acc_dtype, ) - tOtO_i_layout = cute.composition(tOtO.layout, cute.make_layout((128, corr_tile_size))) - tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size))) + tOtO_i_layout = cute.composition(tOtO.layout, cute.make_layout((self.m_block_size, corr_tile_size))) + tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((self.m_block_size, corr_tile_size))) tOtO_i = cute.make_tensor(tOtO.iterator, tOtO_i_layout) tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout) @@ -1538,8 +1560,9 @@ def correction_rescale( tOrO_t2r_shape = thr_tmem_load.partition_D(tOcO_i).shape tOtO_r2t = thr_tmem_store.partition_D(tOtO_i) - tOrO_frg = cute.make_fragment((tOrO_t2r_shape, 128 // corr_tile_size), self.pv_acc_dtype) - for i in range(self.cta_tiler[2] // corr_tile_size): + frg_count = self.head_dim_v_padded // corr_tile_size + tOrO_frg = cute.make_fragment((tOrO_t2r_shape, frg_count), self.pv_acc_dtype) + for i in range(frg_count): tOrO_frg_i = tOrO_frg[None, i] tTMrO_i_layout = cute.composition(tOrO_frg_i.layout, cute.make_layout(tOrO_frg.shape[0])) tTMrO_i = cute.make_tensor(tOrO_frg_i.iterator, tTMrO_i_layout) @@ -1590,9 +1613,9 @@ def correction_epilogue( tOsO = thr_mma.partition_C(sO) tOcO = thr_mma.partition_C(cO) - tOtO_i = cute.logical_divide(tOtO, cute.make_layout((128, corr_tile_size))) - tOcO_i = cute.logical_divide(tOcO, cute.make_layout((128, corr_tile_size))) - tOsO_i = cute.logical_divide(tOsO, cute.make_layout((128, corr_tile_size))) + tOtO_i = cute.logical_divide(tOtO, cute.make_layout((self.m_block_size, corr_tile_size))) + tOcO_i = cute.logical_divide(tOcO, cute.make_layout((self.m_block_size, corr_tile_size))) + tOsO_i = cute.logical_divide(tOsO, cute.make_layout((self.m_block_size, corr_tile_size))) epi_subtile = (self.epi_tile[0], corr_tile_size) tmem_copy_atom = sm100_utils_basic.get_tmem_load_op( @@ -1620,7 +1643,7 @@ def correction_epilogue( tOsO_s2r = thr_tmem_load.partition_D(tOsO_i[(None, None), None]) tOcO_t2r = thr_tmem_load.partition_D(tOcO_i[(None, None), None]) - for i in range(self.cta_tiler[2] // corr_tile_size): + for i in range(self.head_dim_v_padded // corr_tile_size): tOtO_t2r_i = tOtO_t2r[None, 0, 0, i] tOsO_r2s_i = tOsO_s2r[None, 0, 0, i] tOrO_frg = cute.make_fragment(tOcO_t2r[None, 0, 0, i].shape, self.pv_acc_dtype) @@ -1645,7 +1668,7 @@ def epilogue_s2g( tile_scheduler, mO: cute.Tensor, sO: cute.Tensor, - tma_atom_o: cute.CopyAtom, + tma_atom_O: cute.CopyAtom, mbar_ptr: cute.Pointer, ): gO_qdl = cute.local_tile(mO, cute.select(self.pv_mma_tiler, mode=[0, 1]), (None, 0, None)) @@ -1655,7 +1678,7 @@ def epilogue_s2g( m_block, head_idx, batch_idx = work_tile.tile_idx gO = gO_qdl[None, None, None, (head_idx, batch_idx)] tOsO, tOgO = cute.nvgpu.cpasync.tma_partition( - tma_atom_o, + tma_atom_O, 0, cute.make_layout(1), cute.group_modes(sO, 0, 2), @@ -1666,7 +1689,7 @@ def epilogue_s2g( # 1. wait for O0 / O1 final cute.arch.mbarrier_wait(mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase) # 2. copy O0 / O1 to gmem - cute.copy(tma_atom_o, tOsO[None, stage], tOgO[None, 2 * m_block + stage]) + cute.copy(tma_atom_O, tOsO[None, stage], tOgO[None, 2 * m_block + stage]) cute.arch.cp_async_bulk_commit_group() for stage in range(2): # Ensure O0 / O1 buffer is ready to be released diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 9ed247232e4..9743b4a7222 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -148,9 +148,8 @@ def _flash_attn_fwd( ) else: fa_fwd = FlashAttentionForwardSm100( - cutlass.Float32, - cutlass.Float32, - (128, 128, head_dim), + head_dim, + head_dim_v, is_causal=causal, qhead_per_kvhead=qhead_per_kvhead, is_persistent=True, diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index 58f8c12c26c..cb9bd1c897f 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -129,25 +129,69 @@ def __init__(self, scale_log2: Float32, rescale_threshold: cutlass.Constexpr[flo self.rescale_threshold = rescale_threshold @cute.jit - def update_row_max(self, acc_S_row: cute.TensorSSA) -> Tuple[Float32, Float32]: - row_max_old = self.row_max[0] - row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old) - row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 - acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2 - if cutlass.const_expr(self.rescale_threshold > 0.0): - if acc_scale_ >= -self.rescale_threshold: - row_max_new = row_max_old - row_max_safe = row_max_old - acc_scale_ = 0.0 - acc_scale = utils.exp2f(acc_scale_) + def update_row_max(self, acc_S_row: cute.TensorSSA, is_first: int) -> Tuple[Float32, Float32]: + if cutlass.const_expr(is_first): + # row_max_new = self._compute_row_max(acc_S_row, init_val=-Float32.inf) + row_max_new = self._compute_row_max(acc_S_row) + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + acc_scale = 0.0 + else: + row_max_old = self.row_max[0] + row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old) + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2 + if cutlass.const_expr(self.rescale_threshold > 0.0): + if acc_scale_ >= -self.rescale_threshold: + row_max_new = row_max_old + row_max_safe = row_max_old + acc_scale_ = 0.0 + acc_scale = utils.exp2f(acc_scale_) self.row_max[0] = row_max_new return row_max_safe, acc_scale - def update_row_sum(self, acc_S_row_exp: cute.TensorSSA, row_scale: Float32) -> None: - self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=self.row_sum[0] * row_scale) + def update_row_sum(self, acc_S_row_exp: cute.TensorSSA, row_scale: Float32, is_first: int = False) -> None: + init_val = self.row_sum[0] * row_scale if cutlass.const_expr(not is_first) else None + # self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=self.row_sum[0] * row_scale) + self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=init_val) # tmp = self._compute_row_sum(acc_S_row_exp) # self.row_sum[0] = self.row_sum[0] * row_scale + tmp + def scale_subtract_rowmax( + self, + acc_S_row: cute.Tensor, + row_max: Float32, + ): + assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" + minus_row_max_scaled = -row_max * self.scale_log2 + for i in range(0, cute.size(acc_S_row.shape), 2): + acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( + (acc_S_row[i], acc_S_row[i + 1]), + (self.scale_log2, self.scale_log2), + (minus_row_max_scaled, minus_row_max_scaled), + ) + + def apply_exp2_convert( + self, + acc_S_row: cute.Tensor, + acc_S_row_converted: cute.Tensor, + ): + assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" + frg_tile = 32 + assert frg_tile % 2 == 0 + frg_cnt = cute.size(acc_S_row) // frg_tile + assert cute.size(acc_S_row) % frg_tile == 0 + acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile)) + acc_S_row_converted_frg = cute.logical_divide(acc_S_row_converted, cute.make_layout(frg_tile)) + for j in range(frg_cnt): + for k in range(0, cute.size(acc_S_row_frg, mode=[0]), 2): + # acc_S_row_frg[k, j] = utils.exp2f(acc_S_row_frg[k, j]) + # acc_S_row_frg[k + 1, j] = utils.exp2f(acc_S_row_frg[k + 1, j]) + acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j]) + acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j]) + acc_S_row_converted_frg[None, j].store( + acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type) + ) + def scale_apply_exp2_convert( self, acc_S_row: cute.Tensor, diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 82398622093..6fa2609c98f 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -38,7 +38,7 @@ # @pytest.mark.parametrize("d", [64, 128, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128, 192]) -@pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("d", [64, 128]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -79,7 +79,8 @@ def test_flash_attn_output( # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + dv_vals = [d] if dtype == torch.float8_e4m3fn: dv_vals = [d] # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if not DISABLE_LOCAL else [0] From 4834bb596cb23ac70016f2384aaa75e0de7c0fba Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 29 Jun 2025 20:31:14 -0400 Subject: [PATCH 008/258] [Cute] Test flash_fwd_sm100.py with hdim 96 --- flash_attn/cute/flash_fwd_sm100.py | 5 +++-- flash_attn/cute/interface.py | 4 ++-- tests/cute/test_flash_attn.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 6ea681e0837..a69dc102f64 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -1,7 +1,8 @@ -# Supported features, currently only tested for hdim 64 and 128. +# Supported features: # - BF16 & FP16 dtype # - noncausal & causal attention # - MHA, GQA, MQA +# - hdim 64, 96, 128. # Unsupported features that will be added later: # - varlen # - writing out lse @@ -214,7 +215,7 @@ def __init__( self.is_causal = is_causal self.qhead_per_kvhead = qhead_per_kvhead self.pack_gqa = False - self.s0_s1_barrier = head_dim == 64 # Does S1 need to wait for S0 to finish + self.s0_s1_barrier = self.head_dim_padded in [64, 96] # Does S1 need to wait for S0 to finish self.softmax0_warp_ids = (0, 1, 2, 3) self.softmax1_warp_ids = (4, 5, 6, 7) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 9743b4a7222..38acf80a2ca 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -93,7 +93,7 @@ def _flash_attn_fwd( assert all(t is None or t.is_cuda for t in (q, k, v, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)), "inputs must be on CUDA device" assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" assert head_dim <= 256, "head_dim must be less than or equal to 256" - alignment = 128 // q.element_size() + alignment = 16 // q.element_size() assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}" assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}" if softmax_scale is None: @@ -209,7 +209,7 @@ def _flash_attn_bwd( assert all(t.is_cuda for t in (q, k, v, out, dout, lse)), "inputs must be on CUDA device" assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" assert head_dim <= 256, "head_dim must be less than or equal to 256" - alignment = 128 // q.element_size() + alignment = 16 // q.element_size() assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}" assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}" if softmax_scale is None: diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 6fa2609c98f..80e5fae1f09 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -38,7 +38,7 @@ # @pytest.mark.parametrize("d", [64, 128, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128, 192]) -@pytest.mark.parametrize("d", [64, 128]) +@pytest.mark.parametrize("d", [64, 96, 128]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ From b517a592049ed81a4cf9ad3aa4b4a7372e9d9a56 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 29 Jun 2025 22:41:18 -0400 Subject: [PATCH 009/258] [Cute] Write out LSE for flash_fwd_sm100 --- flash_attn/cute/flash_fwd.py | 6 +-- flash_attn/cute/flash_fwd_sm100.py | 82 +++++++++++++++++++++++++----- flash_attn/cute/interface.py | 7 +-- tests/cute/test_flash_attn.py | 5 +- 4 files changed, 79 insertions(+), 21 deletions(-) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 2b4372f1811..4a59491ee94 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -280,7 +280,6 @@ def epilogue( m_block: cutlass.Int32, head_idx: cutlass.Int32, batch_idx: cutlass.Int32, - is_varlen: cutlass.Constexpr[bool] = False, ): # store acc_O rO = cute.make_fragment_like(acc_O, self.dtype) @@ -299,7 +298,7 @@ def epilogue( # Write LSE from rmem -> gmem if cutlass.const_expr(mLSE is not None): - if cutlass.const_expr(not is_varlen): + if cutlass.const_expr(not seqlen.has_cu_seqlens_q): mLSE_cur = mLSE[None, head_idx, batch_idx] else: mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[None, head_idx]) @@ -1061,7 +1060,7 @@ def __call__( for t in (mK, mV) ] LSE_layout_transpose = [2, 1, 0] if cutlass.const_expr(mCuSeqlensQ is None) else [1, 0] - mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) + mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) if mLSE is not None else None tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma() self.num_mma_threads = tiled_mma_qk.size self.num_threads_per_warp_group = 128 @@ -1350,7 +1349,6 @@ def kernel( self.epilogue( acc_O, softmax.row_sum, mO if not self.use_tma_O else mO_tma, mLSE, sO, seqlen, gmem_tiled_copy_O, tma_atom_O, tiled_mma_pv, tidx, m_block, head_idx, batch_idx, - is_varlen=cutlass.const_expr(mCuSeqlensQ is not None), ) @cute.jit diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index a69dc102f64..0669196b8bc 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -5,7 +5,6 @@ # - hdim 64, 96, 128. # Unsupported features that will be added later: # - varlen -# - writing out lse # - split-kv (optimizing for inference) # - testing more hdim (64, 256, etc) # Based on the cutlass example and cute-dsl example: @@ -332,6 +331,8 @@ def __call__( cute.make_tensor(t.iterator, cute.select(t.layout, mode=KV_layout_transpose)) for t in (mK, mV) ] + LSE_layout_transpose = [2, 1, 0] if cutlass.const_expr(mCuSeqlensQ is None) else [1, 0] + mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) if mLSE is not None else None # (s, d, h, b) -> (s, d, (h, b)) mQ, mK, mV, mO = [cute.group_modes(t, begin=2, end=4) for t in (mQ, mK, mV, mO)] @@ -477,7 +478,7 @@ class SharedStorage: # Tmem holding buffer tmem_holding_buf: cutlass.Int32 # Smem tensors - sScale: cute.struct.MemRange[cutlass.Float32, 2 * self.m_block_size] + sScale: cute.struct.MemRange[cutlass.Float32, 2 * self.m_block_size * (1 if mLSE is None else 2)] sO: cute.struct.Align[ cute.struct.MemRange[self.o_dtype, cute.cosize(sO_layout_staged)], self.buffer_align_bytes, @@ -690,7 +691,10 @@ def kernel( ) SeqlenInfoCls = partial( - SeqlenInfo, seqlen_q_static=mQ.shape[0], seqlen_k_static=mK.shape[0] + SeqlenInfo, seqlen_q_static=mQ.shape[0] if not self.pack_gqa else mQ.shape[0][1], + seqlen_k_static=mK.shape[0], + mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, + mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK, ) if warp_idx >= 12: @@ -797,6 +801,7 @@ def kernel( softmax_scale_log2=softmax_scale_log2, thr_mma_qk=thr_mma_qk, sScale=sScale, + mLSE=mLSE, mbar_ptr=mbar_ptr, tile_scheduler=tile_scheduler, block_info=block_info, @@ -834,9 +839,11 @@ def kernel( tOtO1, sScale, mO, + mLSE, sO, tma_atom_O, mbar_ptr, + softmax_scale_log2, tile_sched_params, block_info, SeqlenInfoCls, @@ -1146,6 +1153,7 @@ def softmax_loop( thr_mma_qk: cute.core.ThrMma, tStSi: cute.Tensor, sScale: cute.Tensor, + mLSE: Optional[cute.Tensor], mbar_ptr: cute.Pointer, tile_scheduler, block_info: BlockInfo, @@ -1280,9 +1288,32 @@ def softmax_loop( # cute.copy(thr_tmem_store_scale, tSrScale_r2t, tStScale_r2t) # cute.arch.fence_view_async_tmem_store() sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0] - + if cutlass.const_expr(mLSE is not None): + sScale[tidx + stage * self.m_block_size + self.m_block_size * 2] = softmax.row_max[0] + # if tidx == 0: + # cute.printf("softmax row sum stage %d: %f, row_max = %f\n", stage, softmax.row_sum[0], softmax.row_max[0]) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage) # if tidx == 0: cute.printf("softmax row sum stage %d: %f\n", stage, softmax.row_sum[0]) + + # # Write LSE to gmem + # if cutlass.const_expr(mLSE is not None): + # acc_O_mn_row_is_zero_or_nan = softmax.row_sum[0] == 0.0 or softmax.row_sum[0] != softmax.row_sum[0] + # scale = ( + # cute.arch.rcp_approx(softmax.row_sum[0] if not acc_O_mn_row_is_zero_or_nan else 1.0) + # ) + # LN2 = math.log(2.0) + # lse = ( + # (softmax.row_max[0] * softmax.scale_log2 + utils.log2f(softmax.row_sum[0])) * LN2 + # if not acc_O_mn_row_is_zero_or_nan else -cutlass.Float32.inf + # ) + # if cutlass.const_expr(not seqlen.has_cu_seqlens_q): + # mLSE_cur = mLSE[None, head_idx, batch_idx] + # else: + # mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[None, head_idx]) + # gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block * 2 + stage,)) + # if tidx < seqlen.seqlen_q - (m_block * 2 + stage) * self.m_block_size: + # gLSE[tidx] = lse + # Advance to next tile tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() @@ -1388,9 +1419,11 @@ def correction_loop( tOtO1: cute.Tensor, sScale: cute.Tensor, mO: cute.Tensor, + mLSE: cute.Tensor, sO: cute.Tensor, tma_atom_O: cute.CopyAtom, mbar_ptr: cute.Pointer, + softmax_scale_log2: cutlass.Float32, # tile_scheduler, tile_sched_params, block_info: BlockInfo, @@ -1406,8 +1439,8 @@ def correction_loop( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(1)), self.qk_acc_dtype, ) tiled_tmem_load_vec = tcgen05.make_tmem_copy(tmem_load_v_atom, tStScale_0) - thread_idx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.correction_warp_ids)) - thr_tmem_load_vec = tiled_tmem_load_vec.get_slice(thread_idx) + tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.correction_warp_ids)) + thr_tmem_load_vec = tiled_tmem_load_vec.get_slice(tidx) tStScale_0_t2r = thr_tmem_load_vec.partition_S(tStScale_0) tStScale_1_t2r = thr_tmem_load_vec.partition_S(tStScale_1) @@ -1447,16 +1480,16 @@ def correction_loop( # cute.copy(tiled_tmem_load_vec, tStScale_1_t2r, tSrScale_t2r) # cute.arch.fence_view_async_tmem_load() # scale = tSrScale_t2r[stage] - scale = sScale[thread_idx + stage * self.m_block_size] + scale = sScale[tidx + stage * self.m_block_size] should_rescale = cute.arch.vote_ballot_sync(scale < 1.0) != 0 # should_rescale = True - # if thread_idx == 0: cute.printf("Correction scale i = %d, for stage %d: %f, should_rescale = %d\n", i, stage, scale, should_rescale) + # if tidx == 0: cute.printf("Correction scale i = %d, for stage %d: %f, should_rescale = %d\n", i, stage, scale, should_rescale) # should_rescale = True # Don't need O_full anymore, since by the time softmax has signaled the correction # warps, S_i must have been done, so O_i-1 must have been done as well. # cute.arch.mbarrier_wait(mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase) if should_rescale: - self.correction_rescale(thr_mma_pv, tOtOs[stage], thread_idx, scale) + self.correction_rescale(thr_mma_pv, tOtOs[stage], tidx, scale) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + (1 - stage)) softmax_corr_consumer_phase ^= 1 @@ -1465,23 +1498,48 @@ def correction_loop( cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + 1) + stats = [None, None] for stage in range(2): cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_full_offset + stage, softmax_corr_consumer_phase) # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r) # cute.arch.fence_view_async_tmem_load() # scale = tSrScale_t2r[0] - scale = sScale[thread_idx + stage * self.m_block_size] + row_sum = sScale[tidx + stage * self.m_block_size] + if cutlass.const_expr(mLSE is not None): + row_max = sScale[tidx + stage * self.m_block_size + self.m_block_size * 2] + else: + row_max = None cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage) + acc_O_mn_row_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum + stats[stage] = (row_sum, row_max, acc_O_mn_row_is_zero_or_nan) + scale = cute.arch.rcp_approx(row_sum if not acc_O_mn_row_is_zero_or_nan else 1.0) cute.arch.mbarrier_wait(mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase) cute.arch.mbarrier_wait(mbar_ptr + self.mbar_corr_epi_empty_offset + stage, corr_epi_producer_phase) self.correction_epilogue( - thr_mma_pv, tOtOs[stage], thread_idx, 1.0 / scale, sO[None, None, stage], + thr_mma_pv, tOtOs[stage], tidx, scale, sO[None, None, stage], ) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_full_offset + stage) # Signal for the next work tile that O buffers in tmem are already read, so # mma warp can write to them cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) - # if thread_idx == 0: cute.printf("Correction final scale for stage %d: %f\n", stage, scale) + # if tidx == 0: cute.printf("Correction final scale for stage %d: %f\n", stage, scale) + if cutlass.const_expr(mLSE is not None): + if cutlass.const_expr(not seqlen.has_cu_seqlens_q): + mLSE_cur = mLSE[None, head_idx, batch_idx] + else: + mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[None, head_idx]) + gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block * 2,)) + for stage in range(2): + row_sum, row_max, acc_O_mn_row_is_zero_or_nan = stats[stage] + # if tidx == 0 and stage <= 1: + # cute.printf("row_sum = {}, row_max = {}, acc_O_mn_row_is_zero_or_nan = {}\n", row_sum, row_max, acc_O_mn_row_is_zero_or_nan) + LN2 = math.log(2.0) + lse = ( + (row_max * softmax_scale_log2 + utils.log2f(row_sum)) * LN2 + if not acc_O_mn_row_is_zero_or_nan else -cutlass.Float32.inf + ) + if tidx < seqlen.seqlen_q - (m_block * 2 + stage) * self.m_block_size: + gLSE[tidx + stage * self.m_block_size] = lse o_corr_consumer_phase ^= 1 softmax_corr_consumer_phase ^= 1 diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 38acf80a2ca..3ad5e21eddb 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -105,7 +105,8 @@ def _flash_attn_fwd( q_batch_seqlen_shape = (batch_size, seqlen_q) if cu_seqlens_q is None else (total_q,) out = torch.empty(*q_batch_seqlen_shape, num_head, head_dim_v, dtype=out_torch_dtype, device=device) lse_shape = (batch_size, num_head, seqlen_q) if cu_seqlens_q is None else (num_head, total_q) - lse = torch.empty(lse_shape, dtype=torch.float32, device=device) + requires_grad = q.requires_grad or k.requires_grad or v.requires_grad + lse = torch.empty(lse_shape, dtype=torch.float32, device=device) if requires_grad else None dtype = torch2cute_dtype_map[q.dtype] q_tensor, k_tensor, v_tensor, o_tensor = [ @@ -113,7 +114,7 @@ def _flash_attn_fwd( t.detach(), leading_dim=t.ndim - 1, divisibility=128 // dtype.width ) for t in (q, k, v, out) ] - lse_tensor = utils.convert_from_dlpack(lse, leading_dim=lse.ndim - 1, alignment=4) + lse_tensor = utils.convert_from_dlpack(lse, leading_dim=lse.ndim - 1, alignment=4) if lse is not None else None cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor = [ from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) if t is not None else None for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) @@ -125,7 +126,7 @@ def _flash_attn_fwd( assert compute_capability in [9, 10], "Unsupported compute capability. Supported: 9.x, 10.x" compile_key = ( dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, softcap != 0.0, - cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, seqused_k is None, + lse is None, cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, seqused_k is None, m_block_size, n_block_size, num_threads, compute_capability, ) diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 80e5fae1f09..552b5c6fc5e 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -38,7 +38,8 @@ # @pytest.mark.parametrize("d", [64, 128, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128, 192]) -@pytest.mark.parametrize("d", [64, 96, 128]) +@pytest.mark.parametrize("d", [64, 128]) +# @pytest.mark.parametrize("d", [128]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -186,7 +187,7 @@ def test_flash_attn_output( and not dv > 256 and not attention_chunk != 0 and softcap == 0.0 - and False + # and False ): g = torch.randn_like(out) # do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) From 7661781d001e0900121c000a0aaf21b3f94337d6 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 30 Jun 2025 01:35:22 -0400 Subject: [PATCH 010/258] [Cute] Fix fwd_sm90 epilogue when varlen --- flash_attn/cute/flash_fwd.py | 24 +++++++++++++----------- flash_attn/cute/flash_fwd_sm100.py | 3 ++- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 4a59491ee94..4a84cc7ea1f 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -321,7 +321,7 @@ def epilogue( else: pack_gqa.store_LSE(mLSE_cur, lse, tiled_mma, tidx, m_block, seqlen.seqlen_q) - if cutlass.const_expr(not is_varlen): + if cutlass.const_expr(not seqlen.has_cu_seqlens_q): mO_cur = mO[None, None, head_idx, batch_idx] else: mO_cur = cute.domain_offset((seqlen.offset_q, 0), mO[None, None, head_idx]) @@ -1071,7 +1071,7 @@ def __call__( self.num_mma_regs = 240 self.num_producer_regs = 24 self.use_scheduler_barrier = (self.num_mma_warp_groups >= 2 and self.head_dim_padded <= 128) if self.intra_wg_overlap else (self.num_mma_warp_groups == 2) - self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and not self.pack_gqa + self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None and not self.pack_gqa # TODO: rescale_O_before_gemm self._setup_attributes() SharedStorage = self._get_shared_storage_cls() @@ -1099,9 +1099,12 @@ def __call__( (self.n_block_size, self.head_dim_v_padded), 1 # No mcast for now ) - tma_atom_O, tma_tensor_O = cpasync.make_tma_tile_atom( - gmem_tiled_copy_O, mO, self.sO_layout, (self.m_block_size, self.head_dim_v_padded), # No mcast - ) + if cutlass.const_expr(self.use_tma_O): + tma_atom_O, mO = cpasync.make_tma_tile_atom( + gmem_tiled_copy_O, mO, self.sO_layout, (self.m_block_size, self.head_dim_v_padded), # No mcast + ) + else: + tma_atom_O = None if cutlass.const_expr(self.pack_gqa): shape_Q_packed = ((self.qhead_per_kvhead, mQ.shape[0]), mQ.shape[1], mK.shape[2], *mQ.shape[3:]) stride_Q_packed = ((mQ.stride[2], mQ.stride[0]), mQ.stride[1], mQ.stride[2] * self.qhead_per_kvhead, *mQ.stride[3:]) @@ -1109,9 +1112,10 @@ def __call__( shape_O_packed = ((self.qhead_per_kvhead, mO.shape[0]), mK.shape[1], mK.shape[2], *mO.shape[3:]) stride_O_packed = ((mO.stride[2], mO.stride[0]), mO.stride[1], mO.stride[2] * self.qhead_per_kvhead, *mO.stride[3:]) mO = cute.make_tensor(mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed)) - shape_LSE_packed = ((self.qhead_per_kvhead, mLSE.shape[0]), mK.shape[2], *mLSE.shape[2:]) - stride_LSE_packed = ((mLSE.stride[1], mLSE.stride[0]), mLSE.stride[1] * self.qhead_per_kvhead, *mLSE.stride[2:]) - mLSE = cute.make_tensor(mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed)) + if cutlass.const_expr(mLSE is not None): + shape_LSE_packed = ((self.qhead_per_kvhead, mLSE.shape[0]), mK.shape[2], *mLSE.shape[2:]) + stride_LSE_packed = ((mLSE.stride[1], mLSE.stride[0]), mLSE.stride[1] * self.qhead_per_kvhead, *mLSE.stride[2:]) + mLSE = cute.make_tensor(mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed)) # grid_dim: (m_block, num_head, batch_size) grid_dim = ( cute.ceil_div(cute.size(mQ.shape[0]) if mCuSeqlensQ is None else max_seqlen_q, self.m_block_size), @@ -1135,7 +1139,6 @@ def __call__( tma_tensor_K, tma_tensor_V, mO, - tma_tensor_O, mLSE, mCuSeqlensQ, mCuSeqlensK, @@ -1175,7 +1178,6 @@ def kernel( mK: cute.Tensor, mV: cute.Tensor, mO: cute.Tensor, - mO_tma: cute.Tensor, mLSE: Optional[cute.Tensor], mCuSeqlensQ: Optional[cute.Tensor], mCuSeqlensK: Optional[cute.Tensor], @@ -1347,7 +1349,7 @@ def kernel( # TODO: idk why not using sO_pi is faster sO = cute.make_tensor(cute.recast_ptr(sO_pi.iterator, sO_layout.inner, dtype=sO_pi.element_type), sO_layout.outer) self.epilogue( - acc_O, softmax.row_sum, mO if not self.use_tma_O else mO_tma, mLSE, sO, seqlen, + acc_O, softmax.row_sum, mO, mLSE, sO, seqlen, gmem_tiled_copy_O, tma_atom_O, tiled_mma_pv, tidx, m_block, head_idx, batch_idx, ) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 0669196b8bc..3914d9b9e0a 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -5,8 +5,9 @@ # - hdim 64, 96, 128. # Unsupported features that will be added later: # - varlen +# - sliding window # - split-kv (optimizing for inference) -# - testing more hdim (64, 256, etc) +# - more hdim (192, 256) # Based on the cutlass example and cute-dsl example: # https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha # https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/fmha.py From 10a89168b0a92218f38c393f5c9e691c9feba155 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 1 Jul 2025 21:47:01 -0400 Subject: [PATCH 011/258] [Cute] Implement sliding window for forward pass --- flash_attn/cute/block_info.py | 77 +++++--- flash_attn/cute/flash_fwd.py | 147 +++++++++----- flash_attn/cute/flash_fwd_sm100.py | 308 +++++++++++++++++++---------- flash_attn/cute/interface.py | 43 +++- flash_attn/cute/mask.py | 177 ++++++++++++----- flash_attn/utils/testing.py | 8 +- tests/cute/test_flash_attn.py | 17 +- 7 files changed, 517 insertions(+), 260 deletions(-) diff --git a/flash_attn/cute/block_info.py b/flash_attn/cute/block_info.py index d91c15c54bb..a3505e5dbb5 100644 --- a/flash_attn/cute/block_info.py +++ b/flash_attn/cute/block_info.py @@ -1,4 +1,6 @@ -from typing import Tuple +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +from typing import Tuple, Optional +from dataclasses import dataclass import cutlass import cutlass.cute as cute @@ -6,37 +8,38 @@ from flash_attn.cute.seqlen_info import SeqlenInfo +@dataclass(frozen=True) class BlockInfo: - - def __init__( - self, - m_block_size: cutlass.Constexpr[int], - n_block_size: cutlass.Constexpr[int], - is_causal: cutlass.Constexpr[bool], - qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1, # Only pass in if we're doing PackGQA - *, - loc=None, - ip=None - ): - self.m_block_size: cutlass.Constexpr[int] = m_block_size - self.n_block_size: cutlass.Constexpr[int] = n_block_size - self.is_causal: cutlass.Constexpr[bool] = is_causal - self.qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = qhead_per_kvhead_packgqa - self._loc = loc + m_block_size: cutlass.Constexpr[int] + n_block_size: cutlass.Constexpr[int] + is_causal: cutlass.Constexpr[bool] + is_local: cutlass.Constexpr[bool] = False + window_size_left: Optional[cutlass.Int32] = None + window_size_right: Optional[cutlass.Int32] = None + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 @cute.jit def get_n_block_min_max( self, seqlen_info: SeqlenInfo, m_block: cutlass.Int32 ) -> Tuple[cutlass.Int32, cutlass.Int32]: n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.n_block_size) - n_block_min = 0 - if cutlass.const_expr(self.is_causal): + if cutlass.const_expr( + self.is_causal or (self.is_local and self.window_size_right is not None) + ): m_idx_max = (m_block + 1) * self.m_block_size if cutlass.const_expr(self.qhead_per_kvhead_packgqa > 1): - m_idx_max = (m_idx_max - 1) // self.qhead_per_kvhead_packgqa + 1 + m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa) n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q - n_idx_right = n_idx - n_block_max = min(cute.ceil_div(n_idx_right, self.n_block_size), n_block_max) + n_idx_right = n_idx if self.is_causal else n_idx + self.window_size_right + n_block_max = min(n_block_max, cute.ceil_div(n_idx_right, self.n_block_size)) + n_block_min = 0 + if cutlass.const_expr(self.is_local and self.window_size_left is not None): + m_idx_min = m_block * self.m_block_size + if cutlass.const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx_min = m_idx_min // self.qhead_per_kvhead_packgqa + n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q + n_idx_left = n_idx - self.window_size_left + n_block_min = cutlass.max(n_idx_left // self.n_block_size, 0) return n_block_min, n_block_max @cute.jit @@ -46,16 +49,32 @@ def get_n_block_min_causal_local_mask( m_block: cutlass.Int32, n_block_min: cutlass.Int32, ) -> cutlass.Int32: + """If we have separate iterations with causal or local masking at the start, where do we stop""" m_idx_min = m_block * self.m_block_size if cutlass.const_expr(self.qhead_per_kvhead_packgqa > 1): m_idx_min = m_idx_min // self.qhead_per_kvhead_packgqa n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q - n_idx_right = n_idx + n_idx_right = ( + n_idx + if not self.is_local or self.window_size_right is None + else n_idx + self.window_size_right + ) return cutlass.max(n_block_min, n_idx_right // self.n_block_size) - def __extract_mlir_values__(self): - # We just create a dummy value. Otherwise unpack_to_irvalue in cutlass.py will complain - return [cutlass.Int32(0).ir_value()] - - def __new_from_mlir_values__(self, values): - return BlockInfo(self.m_block_size, self.n_block_size, self.is_causal, self.qhead_per_kvhead_packgqa, loc=self._loc) + @cute.jit + def get_n_block_min_before_local_mask( + self, + seqlen_info: SeqlenInfo, + m_block: cutlass.Int32, + n_block_min: cutlass.Int32, + ) -> cutlass.Int32: + """If we have separate iterations with local masking at the end, where do we stop the non-masked iterations""" + if cutlass.const_expr(not self.is_local or self.window_size_left is None): + return n_block_min + else: + m_idx_max = (m_block + 1) * self.m_block_size + if cutlass.const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa) + n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q + n_idx_left = n_idx - self.window_size_left + return cutlass.max(n_block_min, cute.ceil_div(n_idx_left, self.n_block_size)) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 4a84cc7ea1f..825965f9535 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -7,7 +7,7 @@ import math from types import SimpleNamespace -from typing import Type, Callable, Optional +from typing import Type, Callable, Optional, Tuple from functools import partial import cuda.bindings.driver as cuda @@ -41,7 +41,7 @@ def __init__( head_dim_v: Optional[int] = None, qhead_per_kvhead: int = 1, is_causal: bool = False, - has_softcap: bool = False, + is_local: bool = False, pack_gqa: bool = True, m_block_size: int = 128, n_block_size: int = 128, @@ -76,7 +76,7 @@ def __init__( self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded self.qhead_per_kvhead = qhead_per_kvhead self.is_causal = is_causal - self.has_softcap = has_softcap + self.is_local = is_local self.pack_gqa = pack_gqa self.m_block_size = m_block_size self.n_block_size = n_block_size @@ -542,9 +542,11 @@ def __call__( mV: cute.Tensor, mO: cute.Tensor, mLSE: Optional[cute.Tensor], - softmax_scale: cutlass.Float32, - softcap: cutlass.Float32, stream: cuda.CUstream, + softmax_scale: Optional[cutlass.Float32] = None, + softcap: Optional[cutlass.Float32] = None, + window_size_left: Optional[cutlass.Int32] = None, + window_size_right: Optional[cutlass.Int32] = None, ): """Configures and launches the flash attention kernel. @@ -575,12 +577,12 @@ def __call__( # (assigning it to softcap_val) and pre-multiply softcap_val * log2(e) # (assigning it to softmax_scale_log2). LOG2_E = math.log2(math.e) - if cutlass.const_expr(not self.has_softcap): + if cutlass.const_expr(softcap is not None): softmax_scale_log2 = softmax_scale * LOG2_E - softcap_val = cutlass.Float32(0.0) + softcap_val = None else: softmax_scale_log2 = softcap * LOG2_E - softcap_val = softmax_scale / softcap + softcap_val = cutlass.Float32(softmax_scale / softcap) self.kernel( mQ, mK, @@ -589,6 +591,8 @@ def __call__( mLSE, softmax_scale_log2, softcap_val, + window_size_left, + window_size_right, self.sQ_layout, self.sK_layout, self.sV_layout, @@ -617,7 +621,9 @@ def kernel( mO: cute.Tensor, mLSE: Optional[cute.Tensor], softmax_scale_log2: cutlass.Float32, - softcap_val: cutlass.Float32, + softcap_val: Optional[cutlass.Float32], + window_size_left: cutlass.Int32, + window_size_right: cutlass.Int32, sQ_layout: cute.ComposedLayout, sK_layout: cute.ComposedLayout, sV_layout: cute.ComposedLayout, @@ -636,8 +642,9 @@ def kernel( m_block, num_head, batch_size = cute.arch.block_idx() block_info = BlockInfo( - self.m_block_size, self.n_block_size, self.is_causal, - self.qhead_per_kvhead if self.pack_gqa else 1, + self.m_block_size, self.n_block_size, self.is_causal, self.is_local, + window_size_left, window_size_right, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if self.pack_gqa else 1, ) seqlen = SeqlenInfo(seqlen_q=mQ.shape[0], seqlen_k=mK.shape[0]) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) @@ -754,7 +761,7 @@ def kernel( # Softcapping needs to happen before masking since if we apply after masking, softcapping can turn # -inf to e.g. -50.0, which can affect the attention softmax. def scoremod_premask_fn(acc_S): - if cutlass.const_expr(self.has_softcap): + if cutlass.const_expr(softcap_val is not None): acc_S.store(cute.math.tanh(acc_S.load() * softcap_val, fastmath=True)) compute_one_n_block = partial( @@ -808,10 +815,12 @@ def preprocess_Q(): # We also need masking on S if it's causal, for the last several blocks. mask = AttentionMask( self.m_block_size, self.n_block_size, seqlen.seqlen_q, seqlen.seqlen_k, + window_size_left, window_size_right, self.qhead_per_kvhead if self.pack_gqa else 1, ) mask_fn = partial( - mask.apply_mask, m_block=m_block, thr_mma=thr_mma_qk, mask_causal=self.is_causal + mask.apply_mask, m_block=m_block, thr_mma=thr_mma_qk, + mask_causal=self.is_causal, mask_local=self.is_local, ) # First iteration with seqlen masking @@ -822,7 +831,7 @@ def preprocess_Q(): smem_pipe_read = self.advance_pipeline(smem_pipe_read) smem_pipe_write = self.advance_pipeline(smem_pipe_write) # Next couple of iterations with causal masking - if self.is_causal: + if self.is_causal or self.is_local: n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( seqlen, m_block, n_block_min ) @@ -839,6 +848,7 @@ def preprocess_Q(): compute_one_n_block(n_block - n_tile - 1, smem_pipe_read, smem_pipe_write, check_inf=True) smem_pipe_read = self.advance_pipeline(smem_pipe_read) smem_pipe_write = self.advance_pipeline(smem_pipe_write) + # TODO: local # normalize acc_O by row_sum and calculate the lse row_scale = softmax.finalize() @@ -1031,14 +1041,16 @@ def __call__( mV: cute.Tensor, mO: cute.Tensor, mLSE: Optional[cute.Tensor], - mCuSeqlensQ: Optional[cute.Tensor], - mCuSeqlensK: Optional[cute.Tensor], - mSeqUsedQ: Optional[cute.Tensor], - mSeqUsedK: Optional[cute.Tensor], - max_seqlen_q: Optional[cutlass.Int32], softmax_scale: cutlass.Float32, - softcap: cutlass.Float32, stream: cuda.CUstream, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + max_seqlen_q: Optional[cutlass.Int32] = None, + softcap: cutlass.Float32 | float | None = None, + window_size_left: cutlass.Int32 | int | None = None, + window_size_right: cutlass.Int32 | int | None = None, ): """Configures and launches the flash attention kernel. @@ -1070,6 +1082,8 @@ def __call__( self.num_epilogue_threads = self.num_mma_threads self.num_mma_regs = 240 self.num_producer_regs = 24 + # self.num_mma_regs = 232 + # self.num_producer_regs = 40 self.use_scheduler_barrier = (self.num_mma_warp_groups >= 2 and self.head_dim_padded <= 128) if self.intra_wg_overlap else (self.num_mma_warp_groups == 2) self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None and not self.pack_gqa # TODO: rescale_O_before_gemm @@ -1079,7 +1093,7 @@ def __call__( gmem_tiled_copy_Q = cpasync.CopyBulkTensorTileG2SOp() gmem_tiled_copy_KV = cpasync.CopyBulkTensorTileG2SOp() # Might multicast gmem_tiled_copy_O = cpasync.CopyBulkTensorTileS2GOp() - self.tma_copy_q_bytes = cute.size_in_bytes(mQ.element_type, self.sQ_layout) + self.tma_copy_q_bytes = cute.size_in_bytes(mQ.element_type, cute.select(self.sQ_layout, mode=[0, 1])) self.tma_copy_k_bytes = cute.size_in_bytes(mK.element_type, cute.select(self.sK_layout, mode=[0, 1])) self.tma_copy_v_bytes = cute.size_in_bytes(mV.element_type, cute.select(self.sV_layout, mode=[0, 1])) tma_atom_Q, tma_tensor_Q = cpasync.make_tma_tile_atom( @@ -1128,12 +1142,16 @@ def __call__( # (assigning it to softcap_val) and pre-multiply softcap_val * log2(e) # (assigning it to softmax_scale_log2). LOG2_E = math.log2(math.e) - if cutlass.const_expr(not self.has_softcap): + if cutlass.const_expr(softcap is None): softmax_scale_log2 = softmax_scale * LOG2_E - softcap_val = cutlass.Float32(0.0) + softcap_val = None else: softmax_scale_log2 = softcap * LOG2_E - softcap_val = softmax_scale / softcap + softcap_val = cutlass.Float32(softmax_scale / softcap) + if cutlass.const_expr(window_size_left is not None): + window_size_left = cutlass.Int32(window_size_left) + if cutlass.const_expr(window_size_right is not None): + window_size_right = cutlass.Int32(window_size_right) self.kernel( tma_tensor_Q if not self.pack_gqa else mQ, tma_tensor_K, @@ -1150,6 +1168,8 @@ def __call__( tma_atom_O, softmax_scale_log2, softcap_val, + window_size_left, + window_size_right, self.sQ_layout, self.sK_layout, self.sV_layout, @@ -1162,7 +1182,7 @@ def __call__( # the compiler is unhappy about us using tiled_mma_qk/pv and setting the ACCUMULATE # field inside a for loop, so we work around by creating multiple copies of the # tiled_mma_qk/pv. - *((tiled_mma_qk, tiled_mma_pv) * 3), + *((tiled_mma_qk, tiled_mma_pv) * 4), SharedStorage, ).launch( grid=grid_dim, @@ -1188,7 +1208,9 @@ def kernel( tma_atom_V: Optional[cute.CopyAtom], tma_atom_O: Optional[cute.CopyAtom], softmax_scale_log2: cutlass.Float32, - softcap_val: cutlass.Float32, + softcap_val: Optional[cutlass.Float32], + window_size_left: Optional[cutlass.Int32], + window_size_right: Optional[cutlass.Int32], sQ_layout: cute.ComposedLayout, sK_layout: cute.ComposedLayout, sV_layout: cute.ComposedLayout, @@ -1204,6 +1226,8 @@ def kernel( tiled_mma_pv_copy: cute.TiledMma, tiled_mma_qk_copy1: cute.TiledMma, tiled_mma_pv_copy1: cute.TiledMma, + tiled_mma_qk_copy2: cute.TiledMma, + tiled_mma_pv_copy2: cute.TiledMma, SharedStorage: cutlass.Constexpr, ): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) @@ -1271,8 +1295,9 @@ def kernel( tidx, _, _ = cute.arch.thread_idx() m_block, head_idx, batch_idx = cute.arch.block_idx() block_info = BlockInfo( - self.m_block_size, self.n_block_size, self.is_causal, - self.qhead_per_kvhead if self.pack_gqa else 1, + self.m_block_size, self.n_block_size, self.is_causal, self.is_local, + window_size_left, window_size_right, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if self.pack_gqa else 1, ) SeqlenInfoCls = partial( SeqlenInfo, seqlen_q_static=mQ.shape[0] if not self.pack_gqa else mQ.shape[0][1], @@ -1280,6 +1305,11 @@ def kernel( mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK, ) + AttentionMaskCls = partial( + AttentionMask, self.m_block_size, self.n_block_size, + window_size_left=window_size_left, window_size_right=window_size_right, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if self.pack_gqa else 1, + ) seqlen = SeqlenInfoCls(batch_idx) # Can't early exit so we have to write it this way (under an if statement) if mCuSeqlensQ is None or m_block * self.n_block_size < seqlen.seqlen_q: @@ -1336,10 +1366,13 @@ def kernel( softcap_val, block_info, SeqlenInfoCls, + AttentionMaskCls, tiled_mma_qk_copy, tiled_mma_pv_copy, tiled_mma_qk_copy1, tiled_mma_pv_copy1, + tiled_mma_qk_copy2, + tiled_mma_pv_copy2, ) # /////////////////////////////////////////////////////////////////////////////// # Epilogue @@ -1422,6 +1455,8 @@ def load( cute.arch.mbarrier_init_tx_bytes(mbar_ptr_Q, self.tma_copy_q_bytes) cute.copy(tma_atom_Q, tQgQ, tQsQ, tma_bar_ptr=mbar_ptr_Q) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + # if cute.arch.thread_idx()[0] == 0: + # cute.printf("m_block = %d, n_block_min: %d, n_block_max: %d", m_block, n_block_min, n_block_max) for i in cutlass.range_dynamic(n_block_max - n_block_min, unroll=2): n_block = n_block_max - i - 1 load_K(n_block, producer_state=kv_producer_state) @@ -1448,10 +1483,13 @@ def mma( softcap_val: cutlass.Float32, block_info: BlockInfo, SeqlenInfoCls: Callable, + AttentionMaskCls: Callable, tiled_mma_qk_copy: cute.TiledMma, tiled_mma_pv_copy: cute.TiledMma, tiled_mma_qk_copy1: cute.TiledMma, tiled_mma_pv_copy1: cute.TiledMma, + tiled_mma_qk_copy2: cute.TiledMma, + tiled_mma_pv_copy2: cute.TiledMma, ): warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) warp_group_thread_layout = cute.make_layout( @@ -1487,7 +1525,7 @@ def mma( # Softcapping needs to happen before masking since if we apply after masking, softcapping can turn # -inf to e.g. -50.0, which can affect the attention softmax. def scoremod_premask_fn(acc_S): - if cutlass.const_expr(self.has_softcap): + if cutlass.const_expr(softcap_val is not None): acc_S.store(cute.math.tanh(acc_S.load() * softcap_val, fastmath=True)) mma_one_n_block = partial( @@ -1502,12 +1540,10 @@ def scoremod_premask_fn(acc_S): if cutlass.const_expr(self.is_causal): # Longest tile first m_block = cute.ceil_div(seqlen.seqlen_q * (self.qhead_per_kvhead if self.pack_gqa else 1), self.m_block_size) - m_block - 1 - mask = AttentionMask( - self.m_block_size, self.n_block_size, seqlen.seqlen_q, seqlen.seqlen_k, - self.qhead_per_kvhead if self.pack_gqa else 1 - ) + mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) mask_fn = partial( - mask.apply_mask, m_block=m_block, thr_mma=thr_mma_qk, mask_causal=self.is_causal + mask.apply_mask, m_block=m_block, thr_mma=thr_mma_qk, + mask_causal=self.is_causal, mask_local=self.is_local, ) # Load Q if PackGQA if cutlass.const_expr(self.pack_gqa): @@ -1524,7 +1560,6 @@ def scoremod_premask_fn(acc_S): utils.cp_async_mbarrier_arrive_shared(mbar_ptr_Q, noinc=True) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) - n_block = n_block_max - 1 consumer_state = pipeline.make_pipeline_state( cutlass.utils.PipelineUserType.Consumer, self.num_stages ) @@ -1546,7 +1581,9 @@ def scoremod_premask_fn(acc_S): ) pipeline_k.consumer_release(consumer_state) scoremod_premask_fn(acc_S) - mask_fn(acc_S, n_block=n_block, mask_seqlen=True) + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) + mask_fn(acc_S, n_block=n_block_max - 1, mask_seqlen=True) + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) softmax.online_softmax(acc_S, is_first=True, check_inf=True) rP = cute.make_fragment_like(acc_S, self.dtype) rP.store(acc_S.load().to(self.dtype)) @@ -1561,27 +1598,44 @@ def scoremod_premask_fn(acc_S): else: self.warp_scheduler_barrier_sync() consumer_state = mma_one_n_block( - n_block, consumer_state, tiled_mma_qk, tiled_mma_pv, + n_block_max - 1, consumer_state, tiled_mma_qk, tiled_mma_pv, is_first_n_block=True, check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=True) ) + # if cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block_max = {}, n_block_min = {}", m_block, n_block_max, n_block_min) + n_block_max -= 1 # Next couple of iterations with causal masking - if cutlass.const_expr(self.is_causal): + if cutlass.const_expr(self.is_causal or self.is_local): n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( seqlen, m_block, n_block_min ) + # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_causal_local_mask = {}", n_block_min_causal_local_mask) # Currently we can't do loop with negative step https://github.com/NVIDIA/cutlass/issues/2326 - for n_tile in cutlass.range_dynamic(n_block_max - 1 - n_block_min_causal_local_mask, unroll=1): - n_block = n_block_max - 2 - n_tile + for n_tile in cutlass.range_dynamic(0, n_block_max - n_block_min_causal_local_mask, unroll=1): + n_block = n_block_max - 1 - n_tile consumer_state = mma_one_n_block( n_block, consumer_state, tiled_mma_qk_copy, tiled_mma_pv_copy, check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=False) ) + n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) # The remaining iterations have no masking - for n_tile in cutlass.range_dynamic(n_block, unroll=1): + n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( + seqlen, m_block, n_block_min + ) + # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_before_local_mask = {}, n_block_min = {}", n_block_min_before_local_mask, n_block_min) + for n_tile in cutlass.range_dynamic(0, n_block_max - n_block_min_before_local_mask, unroll=1): + n_block = n_block_max - 1 - n_tile consumer_state = mma_one_n_block( - n_block - n_tile - 1, consumer_state, tiled_mma_qk_copy1, tiled_mma_pv_copy1, - check_inf=True, + n_block, consumer_state, tiled_mma_qk_copy1, tiled_mma_pv_copy1, check_inf=True, ) + # Separate iterations with local masking on the left + if cutlass.const_expr(self.is_local and block_info.window_size_left is not None): + n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) + for n_tile in cutlass.range_dynamic(0, n_block_max - n_block_min, unroll=1): + n_block = n_block_max - 1 - n_tile + consumer_state = mma_one_n_block( + n_block, consumer_state, tiled_mma_qk_copy2, tiled_mma_pv_copy2, + check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=False) + ) # Last "half" iteration if cutlass.const_expr(self.intra_wg_overlap): pipeline_v.consumer_wait(consumer_state, pipeline_v.consumer_try_wait(consumer_state)) @@ -1633,8 +1687,7 @@ def mma_one_n_block( if cutlass.const_expr(mask_fn is not None): mask_fn(acc_S, n_block=n_block) row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf) - # if cute.arch.thread_idx()[0] == 0: - # cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) + # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) rP = cute.make_fragment_like(acc_S, self.dtype) rP.store(acc_S.load().to(self.dtype)) # tOrP = cute.make_tensor(rP.iterator, utils.convert_layout_acc_frgA(rP.layout)) @@ -1693,10 +1746,10 @@ def mma_one_n_block_intrawg_overlap( warpgroup.wait_group(1) pipeline_k.consumer_release(smem_pipe_read) scoremod_premask_fn(acc_S) + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) if cutlass.const_expr(mask_fn is not None): mask_fn(acc_S, n_block=n_block) - # if cute.arch.thread_idx()[0] == 128: - # cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) row_scale = softmax.online_softmax(acc_S, check_inf=check_inf) warpgroup.wait_group(0) pipeline_v.consumer_release(smem_pipe_read_v) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 3914d9b9e0a..c0cccf6c1c1 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -3,9 +3,9 @@ # - noncausal & causal attention # - MHA, GQA, MQA # - hdim 64, 96, 128. +# - sliding window # Unsupported features that will be added later: # - varlen -# - sliding window # - split-kv (optimizing for inference) # - more hdim (192, 256) # Based on the cutlass example and cute-dsl example: @@ -21,6 +21,7 @@ import cutlass import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync import cutlass.cute.nvgpu.tcgen05 as tcgen05 import cutlass.utils.blackwell_helpers as sm100_utils_basic @@ -182,12 +183,16 @@ def create_fmha_static_tile_scheduler( class FlashAttentionForwardSm100: + + arch = 100 + def __init__( self, # dtype: Type[cutlass.Numeric], head_dim: int, head_dim_v: Optional[int] = None, is_causal: bool = False, + is_local: bool = False, qhead_per_kvhead: cutlass.Constexpr[int] = 1, m_block_size: int = 128, n_block_size: int = 128, @@ -201,6 +206,8 @@ def __init__( self.same_hdim_kv = head_dim == head_dim_v assert head_dim == head_dim_v, "head_dim and head_dim_v must be the same for now" self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) + self.check_hdim_oob = head_dim != self.head_dim_padded + self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded self.m_block_size = m_block_size self.n_block_size = n_block_size # 2 Q tile per CTA @@ -213,8 +220,10 @@ def __init__( self.is_persistent = is_persistent self.is_even_N = False self.is_causal = is_causal + self.is_local = is_local self.qhead_per_kvhead = qhead_per_kvhead self.pack_gqa = False + self.use_tma_O = True self.s0_s1_barrier = self.head_dim_padded in [64, 96] # Does S1 need to wait for S0 to finish self.softmax0_warp_ids = (0, 1, 2, 3) @@ -222,7 +231,7 @@ def __init__( self.correction_warp_ids = (8, 9, 10, 11) self.mma_warp_id = 12 self.load_warp_id = 13 - self.epilogue_warp_id = 14 + self.epilogue_warp_ids = (14,) self.empty_warp_id = 15 SM100_TMEM_CAPACITY_COLUMNS = 512 self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS @@ -234,7 +243,7 @@ def __init__( *self.correction_warp_ids, self.mma_warp_id, self.load_warp_id, - self.epilogue_warp_id, + *self.epilogue_warp_ids, self.empty_warp_id, ) ) @@ -294,14 +303,16 @@ def __call__( mV: cute.Tensor, mO: cute.Tensor, mLSE: Optional[cute.Tensor], - mCuSeqlensQ: Optional[cute.Tensor], - mCuSeqlensK: Optional[cute.Tensor], - mSeqUsedQ: Optional[cute.Tensor], - mSeqUsedK: Optional[cute.Tensor], - max_seqlen_q: Optional[cutlass.Int32], softmax_scale: cutlass.Float32, - softcap: cutlass.Float32, stream: cuda.CUstream, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + max_seqlen_q: Optional[cutlass.Int32] = None, + softcap: cutlass.Float32 | float | None = None, + window_size_left: cutlass.Int32 | int | None = None, + window_size_right: cutlass.Int32 | int | None = None, ): """Execute the Fused Multi-Head Attention operation on the provided tensors. @@ -334,10 +345,8 @@ def __call__( ] LSE_layout_transpose = [2, 1, 0] if cutlass.const_expr(mCuSeqlensQ is None) else [1, 0] mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) if mLSE is not None else None - - # (s, d, h, b) -> (s, d, (h, b)) - mQ, mK, mV, mO = [cute.group_modes(t, begin=2, end=4) for t in (mQ, mK, mV, mO)] - mV = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=[1, 0, 2])) + # (s, d, h, b) -> (d, s, h, b) + mV = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=[1, 0, 2, 3])) self.q_major_mode = cutlass.utils.LayoutEnum.from_tensor(mQ).mma_major_mode() self.k_major_mode = cutlass.utils.LayoutEnum.from_tensor(mK).mma_major_mode() @@ -357,6 +366,7 @@ def __call__( if cutlass.const_expr(self.q_dtype != self.v_dtype): raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}") self._setup_attributes() + self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None and not self.pack_gqa cta_group = tcgen05.CtaGroup.ONE # the intermediate tensor p is from tmem & mK-major @@ -405,8 +415,8 @@ def __call__( ) # TMA load for Q - tma_load_op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp(cta_group) - tma_store_op = cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp() + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) + tma_store_op = cpasync.CopyBulkTensorTileS2GOp() sQ_layout = cute.select(sQ_layout_staged, mode=[0, 1, 2]) tma_atom_Q, tma_tensor_Q = cute.nvgpu.make_tma_tile_atom_A( @@ -444,12 +454,32 @@ def __call__( ) sO_layout = cute.select(sO_layout_staged, mode=[0, 1]) - tma_atom_O, tma_tensor_O = cute.nvgpu.cpasync.make_tma_tile_atom( - tma_store_op, - mO, - sO_layout, - o_cta_v_layout, - ) + # print(sO_layout.outer) + self.epilogue_warp_ids = (14,) if self.use_tma_O else (14, 15) + self.num_epilogue_threads = cute.arch.WARP_SIZE * len(self.epilogue_warp_ids) + if cutlass.const_expr(self.use_tma_O): + tma_atom_O, mO = cpasync.make_tma_tile_atom( + tma_store_op, + mO, + sO_layout, + o_cta_v_layout, + ) + gmem_tiled_copy_O = None + else: + tma_atom_O = None + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // self.o_dtype.width + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), self.o_dtype, num_bits_per_copy=universal_copy_bits, + ) + tO_shape_dim_1 = sO_layout.outer.shape[1][0] // async_copy_elems + tO_layout = cute.make_ordered_layout( + (self.num_epilogue_threads // tO_shape_dim_1, tO_shape_dim_1), order=(1, 0), + ) + # So that we don't have to check if we overshoot kBlockM when we store O + assert self.m_block_size % tO_layout.shape[0] == 0 + vO_layout = cute.make_layout((1, async_copy_elems)) + gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout) self.tma_copy_q_bytes = cute.size_in_bytes(self.q_dtype, sQ_layout) self.tma_copy_kv_bytes = cute.size_in_bytes(self.k_dtype, sK_layout) @@ -501,20 +531,22 @@ class SharedStorage: # (assigning it to softcap_val) and pre-multiply softcap_val * log2(e) # (assigning it to softmax_scale_log2). LOG2_E = math.log2(math.e) - # if cutlass.const_expr(not self.has_softcap): - if cutlass.const_expr(True): + if cutlass.const_expr(softcap is None): softmax_scale_log2 = softmax_scale * LOG2_E - softcap_val = cutlass.Float32(0.0) + softcap_val = None else: softmax_scale_log2 = softcap * LOG2_E - softcap_val = softmax_scale / softcap - + softcap_val = cutlass.Float32(softmax_scale / softcap) + if cutlass.const_expr(window_size_left is not None): + window_size_left = cutlass.Int32(window_size_left) + if cutlass.const_expr(window_size_right is not None): + window_size_right = cutlass.Int32(window_size_right) # Launch the kernel synchronously self.kernel( tma_tensor_Q, tma_tensor_K, tma_tensor_V, - tma_tensor_O, + mO, mLSE, mCuSeqlensQ, mCuSeqlensK, @@ -524,14 +556,18 @@ class SharedStorage: tma_atom_K, tma_atom_V, tma_atom_O, - tiled_mma_qk, - tiled_mma_pv, softmax_scale_log2, + softcap_val, + window_size_left, + window_size_right, sQ_layout_staged, sK_layout_staged, tP_layout_staged, sV_layout_staged, sO_layout_staged, + gmem_tiled_copy_O, + tiled_mma_qk, + tiled_mma_pv, self.tile_sched_params, ).launch( grid=grid, @@ -559,14 +595,18 @@ def kernel( tma_atom_K: cute.CopyAtom, tma_atom_V: cute.CopyAtom, tma_atom_O: cute.CopyAtom, - tiled_mma_qk: cute.TiledMma, - tiled_mma_pv: cute.TiledMma, softmax_scale_log2: cutlass.Float32, + softcap_val: Optional[cutlass.Float32], + window_size_left: Optional[cutlass.Int32], + window_size_right: Optional[cutlass.Int32], sQ_layout_staged: cute.ComposedLayout, sK_layout_staged: cute.ComposedLayout, tP_layout_staged: cute.ComposedLayout, sV_layout_staged: cute.ComposedLayout, sO_layout_staged: cute.ComposedLayout, + gmem_tiled_copy_O: Optional[cute.TiledCopy], + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, tile_sched_params: FmhaStaticTileSchedulerParams, ): """The device kernel implementation of the Fused Multi-Head Attention. @@ -586,30 +626,42 @@ def kernel( # coord inside cta tidx, _, _ = cute.arch.thread_idx() + if cutlass.const_expr(not self.pack_gqa): + cpasync.prefetch_descriptor(tma_atom_Q) + cpasync.prefetch_descriptor(tma_atom_K) + cpasync.prefetch_descriptor(tma_atom_V) + if cutlass.const_expr(self.use_tma_O): + cpasync.prefetch_descriptor(tma_atom_O) + # Alloc smem = cutlass.utils.SmemAllocator() storage = smem.allocate(self.shared_storage) mbar_ptr = storage.mbar_ptr.data_ptr() warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - if warp_idx == 0: + if warp_idx == 1: # Init "full" barrier with number of producers, "empty" barrier with number of consumers for i in range(self.q_stage): cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_load_q_full_offset + i, len([self.load_warp_id])) cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_load_q_empty_offset + i, len([self.mma_warp_id])) + if warp_idx == 2: for i in range(2): cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_softmax_corr_empty_offset + i, cute.arch.WARP_SIZE * 4) cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_softmax_corr_full_offset + i, cute.arch.WARP_SIZE * 4) + if warp_idx == 3: if cutlass.const_expr(self.s0_s1_barrier): for i in range(8): cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_s0_s1_sequence_offset + i, cute.arch.WARP_SIZE) + if warp_idx == 4: for i in range(2): cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_corr_epi_full_offset + i, cute.arch.WARP_SIZE * len(self.correction_warp_ids)) - cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_corr_epi_empty_offset + i, cute.arch.WARP_SIZE * len([self.epilogue_warp_id])) + cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_corr_epi_empty_offset + i, cute.arch.WARP_SIZE * len(self.epilogue_warp_ids)) + if warp_idx == 5: for i in range(2): cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_P_full_O_rescaled_offset + i, cute.arch.WARP_SIZE * (len(self.softmax0_warp_ids) + len(self.correction_warp_ids))) cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_S_full_offset + i, len([self.mma_warp_id])) cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_O_full_offset + i, len([self.mma_warp_id])) + if warp_idx == 6: cute.arch.mbarrier_init_arrive_cnt( mbar_ptr + self.mbar_max_reg_setting_offset, cute.arch.WARP_SIZE @@ -618,11 +670,12 @@ def kernel( self.empty_warp_id, self.load_warp_id, self.mma_warp_id, - self.epilogue_warp_id, + *self.epilogue_warp_ids, *self.correction_warp_ids, ) ), ) + if warp_idx == 7: cute.arch.mbarrier_init_arrive_cnt( mbar_ptr + self.mbar_tmem_dealloc_offset, cute.arch.WARP_SIZE @@ -637,13 +690,6 @@ def kernel( # Relying on pipeline_kv constructor to call mbarrier_init_fence and sync pipeline_kv = self.make_and_init_load_kv_pipeline(mbar_ptr + self.mbar_load_kv_full_offset) - block_info = BlockInfo( - # This is cta_tiler, not mma_tiler_qk, since we move by block by (2 * mma_tiler[0], mma_tiler[1]) - self.cta_tiler[0], self.cta_tiler[1], - is_causal=self.is_causal, - qhead_per_kvhead_packgqa=self.qhead_per_kvhead if self.pack_gqa else 1, - ) - # Generate smem tensor Q/K/V/O # (MMA, MMA_Q, MMA_D, PIPE) sQ = storage.sQ.get_tensor(sQ_layout_staged.outer, swizzle=sQ_layout_staged.inner) @@ -691,12 +737,23 @@ def kernel( tOrP.layout, ) + block_info = BlockInfo( + # This is cta_tiler, not mma_tiler_qk, since we move by block by (2 * mma_tiler[0], mma_tiler[1]) + self.cta_tiler[0], self.cta_tiler[1], self.is_causal, self.is_local, + window_size_left, window_size_right, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if self.pack_gqa else 1, + ) SeqlenInfoCls = partial( SeqlenInfo, seqlen_q_static=mQ.shape[0] if not self.pack_gqa else mQ.shape[0][1], seqlen_k_static=mK.shape[0], mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK, ) + AttentionMaskCls = partial( + AttentionMask, self.m_block_size, self.n_block_size, + window_size_left=window_size_left, window_size_right=window_size_right, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if self.pack_gqa else 1, + ) if warp_idx >= 12: cute.arch.warpgroup_reg_dealloc(self.num_regs_other) @@ -780,11 +837,11 @@ def kernel( # /////////////////////////////////////////////////////////////////////////////// # Epilogue # /////////////////////////////////////////////////////////////////////////////// - if warp_idx == self.epilogue_warp_id: + if warp_idx >= self.epilogue_warp_ids[0] and warp_idx <= self.epilogue_warp_ids[-1]: tile_scheduler = create_fmha_static_tile_scheduler( tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() ) - self.epilogue_s2g(tile_scheduler, mO, sO, tma_atom_O, mbar_ptr) + self.epilogue_s2g(tile_scheduler, mO, sO, gmem_tiled_copy_O, tma_atom_O, mbar_ptr, SeqlenInfoCls) # /////////////////////////////////////////////////////////////////////////////// # Softmax @@ -807,6 +864,7 @@ def kernel( tile_scheduler=tile_scheduler, block_info=block_info, SeqlenInfoCls=SeqlenInfoCls, + AttentionMaskCls=AttentionMaskCls, ) if cutlass.const_expr(not self.s0_s1_barrier): @@ -874,34 +932,34 @@ def load( SeqlenInfoCls: Callable, ): # (bM, bK, loopM, loopL) - gQ_qdl = cute.local_tile(mQ, cute.select(self.mma_tiler_qk, mode=[0, 2]), (None, 0, None)) - tSgQ_qdl = thr_mma_qk.partition_A(gQ_qdl) + gQ_qdhb = cute.local_tile(mQ, cute.select(self.mma_tiler_qk, mode=[0, 2]), (None, 0, None, None)) + tSgQ_qdhb = thr_mma_qk.partition_A(gQ_qdhb) # (bN, bK, loopN, loopL) - gK_kdl = cute.local_tile(mK, cute.select(self.mma_tiler_qk, mode=[1, 2]), (None, 0, None)) - tSgK_kdl = thr_mma_qk.partition_B(gK_kdl) + gK_kdhb = cute.local_tile(mK, cute.select(self.mma_tiler_qk, mode=[1, 2]), (None, 0, None, None)) + tSgK_kdhb = thr_mma_qk.partition_B(gK_kdhb) # (bK, bN, loopN, loopL) - gV_dkl = cute.local_tile(mV, cute.select(self.pv_mma_tiler, mode=[1, 2]), (0, None, None)) - tOgV_dkl = thr_mma_pv.partition_B(gV_dkl) - tQsQ, tQgQ_qdl = cute.nvgpu.cpasync.tma_partition( + gV_dkhb = cute.local_tile(mV, cute.select(self.pv_mma_tiler, mode=[1, 2]), (0, None, None, None)) + tOgV_dkhb = thr_mma_pv.partition_B(gV_dkhb) + tQsQ, tQgQ_qdhb = cpasync.tma_partition( tma_atom_Q, 0, # no multicast cute.make_layout(1), cute.group_modes(sQ, 0, 3), - cute.group_modes(tSgQ_qdl, 0, 3), + cute.group_modes(tSgQ_qdhb, 0, 3), ) - tKsK, tKgK_kdl = cute.nvgpu.cpasync.tma_partition( + tKsK, tKgK_kdhb = cpasync.tma_partition( tma_atom_K, 0, # no multicast cute.make_layout(1), cute.group_modes(sK, 0, 3), - cute.group_modes(tSgK_kdl, 0, 3), + cute.group_modes(tSgK_kdhb, 0, 3), ) - tVsV, tVgV_dkl = cute.nvgpu.cpasync.tma_partition( + tVsV, tVgV_dkl = cpasync.tma_partition( tma_atom_V, 0, # no multicast cute.make_layout(1), cute.group_modes(sV, 0, 3), - cute.group_modes(tOgV_dkl, 0, 3), + cute.group_modes(tOgV_dkhb, 0, 3), ) q_producer_phase = cutlass.Int32(1) @@ -909,9 +967,9 @@ def load( work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx = work_tile.tile_idx - tQgQ = tQgQ_qdl[None, None, (head_idx, batch_idx)] + tQgQ = tQgQ_qdhb[None, None, head_idx, batch_idx] head_idx_kv = head_idx // self.qhead_per_kvhead - tKgK, tVgV = [t[None, None, (head_idx_kv, batch_idx)] for t in (tKgK_kdl, tVgV_dkl)] + tKgK, tVgV = [t[None, None, head_idx_kv, batch_idx] for t in (tKgK_kdhb, tVgV_dkl)] def load_Q(stage: int): cute.arch.mbarrier_wait(mbar_ptr + self.mbar_load_q_empty_offset + stage, q_producer_phase) @@ -1159,6 +1217,7 @@ def softmax_loop( tile_scheduler, block_info: BlockInfo, SeqlenInfoCls: Callable, + AttentionMaskCls: Callable, ): """Compute softmax on attention scores from QK matrix multiplication. @@ -1224,12 +1283,9 @@ def softmax_loop( m_block, head_idx, batch_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) - mask = AttentionMask( - self.mma_tiler_qk[0], self.mma_tiler_qk[1], seqlen.seqlen_q, seqlen.seqlen_k, - self.qhead_per_kvhead if self.pack_gqa else 1, - ) + mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) mask_fn = partial( - mask.apply_mask_sm100, m_block=m_block, m_stage=stage, thr_mma=thr_mma_qk, thr_tmem_load=thr_tmem_load, mask_causal=self.is_causal + mask.apply_mask_sm100, m_block=m_block * 2 + stage, thr_mma=thr_mma_qk, thr_tmem_load=thr_tmem_load, mask_causal=self.is_causal, mask_local=self.is_local ) softmax = SoftmaxSm100(softmax_scale_log2, rescale_threshold=8.0 if self.q_dtype.width == 16 else 0.0) softmax.reset() @@ -1256,33 +1312,31 @@ def softmax_loop( # 1 masking iter if cutlass.const_expr(not self.is_even_N): # mask_trip_count = 1 if seqlen.seqlen_k % self.mma_tiler_qk[1] == 0 else 0 - softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block_max - 1, is_first=False, mask_fn=partial(mask_fn, mask_seqlen=True)) - si_corr_producer_phase ^= 1 - mma_si_consumer_phase ^= 1 - s0_s1_sequence_phase ^= 1 + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block_max - 1, is_first=False, mask_fn=partial(mask_fn, mask_seqlen=True)) n_block_max -= 1 # Next couple of iterations with causal masking - if cutlass.const_expr(self.is_causal): + if cutlass.const_expr(self.is_causal or self.is_local): n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( seqlen, m_block, n_block_min ) # Currently we can't do loop with negative step https://github.com/NVIDIA/cutlass/issues/2326 for n_tile in cutlass.range_dynamic(n_block_max - n_block_min_causal_local_mask, unroll=1): n_block = n_block_max - 1 - n_tile - softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=partial(mask_fn, mask_seqlen=False)) - si_corr_producer_phase ^= 1 - mma_si_consumer_phase ^= 1 - s0_s1_sequence_phase ^= 1 - n_block_max = n_block_min_causal_local_mask + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=partial(mask_fn, mask_seqlen=False)) + n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) # The remaining iterations have no masking - for n_tile in cutlass.range_dynamic(n_block_max, unroll=1): + n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( + seqlen, m_block, n_block_min + ) + for n_tile in cutlass.range_dynamic(n_block_max - n_block_min_before_local_mask, unroll=1): n_block = n_block_max - n_tile - 1 - softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block) - si_corr_producer_phase ^= 1 - mma_si_consumer_phase ^= 1 - s0_s1_sequence_phase ^= 1 - - # mma_softmax_pipeline.sync_object_array_full.wait(stage, mma_si_consumer_phase) + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block) + # Separate iterations with local masking on the left + if cutlass.const_expr(self.is_local and block_info.window_size_left is not None): + n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) + for n_tile in cutlass.range_dynamic(0, n_block_max - n_block_min, unroll=1): + n_block = n_block_max - 1 - n_tile + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=partial(mask_fn, mask_seqlen=False)) # tSrScale_r2t = cute.make_fragment(tSrScale_r2t_shape, cutlass.Float32) # tSrScale_r2t[0] = softmax.row_sum[0] @@ -1342,7 +1396,7 @@ def softmax_step( stage: int, mask_fn: Optional[Callable] = None, is_first: bool = False, - ) -> None: + ) -> Tuple[cute.Int32, cute.Int32, cute.Int32]: """Perform a single step of the softmax computation on a block of attention scores. This method processes one block of the attention matrix, computing numerically stable @@ -1409,6 +1463,7 @@ def softmax_step( cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase) softmax.update_row_sum(tSrS_t2r.load(), acc_scale, is_first) # acc_scale = cute.arch.exp2(acc_scale_) + return mma_si_consumer_phase ^ 1, si_corr_producer_phase ^ 1, s0_s1_sequence_phase ^ 1 @cute.jit def correction_loop( @@ -1485,7 +1540,6 @@ def correction_loop( should_rescale = cute.arch.vote_ballot_sync(scale < 1.0) != 0 # should_rescale = True # if tidx == 0: cute.printf("Correction scale i = %d, for stage %d: %f, should_rescale = %d\n", i, stage, scale, should_rescale) - # should_rescale = True # Don't need O_full anymore, since by the time softmax has signaled the correction # warps, S_i must have been done, so O_i-1 must have been done as well. # cute.arch.mbarrier_wait(mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase) @@ -1546,9 +1600,9 @@ def correction_loop( softmax_corr_consumer_phase ^= 1 corr_epi_producer_phase ^= 1 - # gO_qdl = cute.local_tile(mO, cute.select(self.pv_mma_tiler, mode=[0, 1]), (None, 0, None)) - # gO = gO_qdl[None, None, None, (head_idx, batch_idx)] - # tOsO, tOgO = cute.nvgpu.cpasync.tma_partition( + # gO_qdhb = cute.local_tile(mO, cute.select(self.pv_mma_tiler, mode=[0, 1]), (None, 0, None, None)) + # gO = gO_qdhb[None, None, None, head_idx, batch_idx] + # tOsO, tOgO = cpasync.tma_partition( # tma_atom_O, # 0, # cute.make_layout(1), @@ -1728,33 +1782,69 @@ def epilogue_s2g( tile_scheduler, mO: cute.Tensor, sO: cute.Tensor, - tma_atom_O: cute.CopyAtom, + gmem_tiled_copy_O: cute.TiledCopy, + tma_atom_O: Optional[cute.CopyAtom], mbar_ptr: cute.Pointer, + SeqlenInfoCls: Callable, ): - gO_qdl = cute.local_tile(mO, cute.select(self.pv_mma_tiler, mode=[0, 1]), (None, 0, None)) epi_consumer_phase = cutlass.Int32(0) work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx = work_tile.tile_idx - gO = gO_qdl[None, None, None, (head_idx, batch_idx)] - tOsO, tOgO = cute.nvgpu.cpasync.tma_partition( - tma_atom_O, - 0, - cute.make_layout(1), - cute.group_modes(sO, 0, 2), - cute.group_modes(gO, 0, 2), - ) - for stage in range(2): - # wait from corr, issue tma store on smem - # 1. wait for O0 / O1 final - cute.arch.mbarrier_wait(mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase) - # 2. copy O0 / O1 to gmem - cute.copy(tma_atom_O, tOsO[None, stage], tOgO[None, 2 * m_block + stage]) - cute.arch.cp_async_bulk_commit_group() - for stage in range(2): - # Ensure O0 / O1 buffer is ready to be released - cute.arch.cp_async_bulk_wait_group(1 - stage, read=True) - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) + seqlen = SeqlenInfoCls(batch_idx) + if cutlass.const_expr(not seqlen.has_cu_seqlens_q): + mO_cur = mO[None, None, head_idx, batch_idx] + else: + mO_cur = cute.domain_offset((seqlen.offset_q, 0), mO[None, None, head_idx]) + gO = cute.local_tile(mO_cur, (self.m_block_size, self.head_dim_v_padded), (None, 0)) + if cutlass.const_expr(self.use_tma_O): + tOsO, tOgO = cpasync.tma_partition( + tma_atom_O, + 0, + cute.make_layout(1), + cute.group_modes(sO, 0, 2), + cute.group_modes(gO, 0, 2), + ) + for stage in range(2): + # wait from corr, issue tma store on smem + # 1. wait for O0 / O1 final + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase) + # 2. copy O0 / O1 to gmem + cute.copy(tma_atom_O, tOsO[None, stage], tOgO[None, 2 * m_block + stage]) + cute.arch.cp_async_bulk_commit_group() + for stage in range(2): + # Ensure O0 / O1 buffer is ready to be released + cute.arch.cp_async_bulk_wait_group(1 - stage, read=True) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) + else: + tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.epi_warp_ids)) + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + tOsO = gmem_thr_copy_O.partition_S(sO) + tOrO = cute.make_fragment_like(tOsO, self.dtype) + cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded)) + tOgO = gmem_thr_copy_O.partition_D(gO) + tOcO = gmem_thr_copy_O.partition_S(cO) + t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO) + tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) + for stage in range(2): + # wait from corr, issue tma store on smem + # 1. wait for O0 / O1 final + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase) + # 2. copy O0 / O1 to gmem + # load acc O from smem to rmem for wider vectorization + # TODO: need stage + cute.autovec_copy(tOsO, tOrO) + # copy acc O from rmem to gmem + for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): + if t0OcO[0, rest_m, 0][0] < seqlen.seqlen_q - m_block * self.m_block_size - tOcO[0][0]: + cute.copy( + gmem_tiled_copy_O, + tOrO[None, rest_m, None], + tOgO[None, rest_m, None], + pred=tOpO[None, rest_m, None] if self.check_hdim_v_oob else None, + ) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) + # Advance to next tile epi_consumer_phase ^= 1 tile_scheduler.advance_to_next_work() @@ -1813,17 +1903,17 @@ def make_and_init_load_kv_pipeline(self, load_kv_mbar_ptr): @staticmethod def _compute_grid( - o: cute.Tensor, + mO: cute.Tensor, cta_tiler: Tuple[int, int, int], is_persistent: bool, ) -> Tuple[FmhaStaticTileSchedulerParams, Tuple[int, int, int]]: - o_shape = o.shape + o_shape = mO.shape tile_sched_params = create_fmha_static_tile_scheduler_params( is_persistent, ( cute.ceil_div(cute.size(o_shape[0]), cta_tiler[0]), - cute.size(o_shape[2][0]), - cute.size(o_shape[2][1]), + cute.size(o_shape[2]), + cute.size(o_shape[3]), ), ) grid = FmhaStaticTileScheduler.get_grid_shape(tile_sched_params) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 3ad5e21eddb..bbab8301522 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -4,7 +4,6 @@ # Lightly tested with headdim 128. # Features not supported yet: # - varlen -# - sliding window # - split (i.e. FlashDecoding) # - tuned block sizes # - paged KV @@ -52,7 +51,9 @@ def _flash_attn_fwd( max_seqlen_q: Optional[int] = None, softmax_scale: Optional[float] = None, causal: bool = False, - softcap: float = 0.0, + softcap: Optional[float] = None, + window_size_left: Optional[int] = None, + window_size_right: Optional[int] = None, # m_block_size: int = 128, # n_block_size: int = 64, # num_threads: int = 128, @@ -98,6 +99,8 @@ def _flash_attn_fwd( assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}" if softmax_scale is None: softmax_scale = 1.0 / math.sqrt(head_dim) + if softcap == 0.0: + softcap = None qhead_per_kvhead = num_head // num_head_kv out_torch_dtype = q.dtype @@ -120,13 +123,22 @@ def _flash_attn_fwd( for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) ] max_seqlen_q = cutlass.Int32(max_seqlen_q) if max_seqlen_q is not None else None + if causal: + window_size_right = 0 + local = window_size_left is not None or window_size_right is not None + if window_size_left is not None or window_size_right is not None: + if window_size_left is None and window_size_right == 0: + causal, local = True, False + else: + causal, local = False, True current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) compute_capability = torch.cuda.get_device_capability()[0] if _compute_capability is None else _compute_capability assert compute_capability in [9, 10], "Unsupported compute capability. Supported: 9.x, 10.x" compile_key = ( - dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, softcap != 0.0, + dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, softcap is not None, lse is None, cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, seqused_k is None, + window_size_left is not None, window_size_right is not None, m_block_size, n_block_size, num_threads, compute_capability, ) @@ -139,7 +151,7 @@ def _flash_attn_fwd( head_dim_v, qhead_per_kvhead, is_causal=causal, - has_softcap=softcap != 0.0, + is_local=local, m_block_size=m_block_size, n_block_size=n_block_size, # num_stages=1, @@ -152,19 +164,20 @@ def _flash_attn_fwd( head_dim, head_dim_v, is_causal=causal, + is_local=local, qhead_per_kvhead=qhead_per_kvhead, is_persistent=True, ) # TODO: check @can_implement _flash_attn_fwd.compile_cache[compile_key] = cute.compile( - fa_fwd, q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, + fa_fwd, q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream, cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, - max_seqlen_q, softmax_scale, softcap, current_stream + max_seqlen_q, softcap, window_size_left, window_size_right, ) _flash_attn_fwd.compile_cache[compile_key]( - q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, + q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream, cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, - max_seqlen_q, softmax_scale, softcap, current_stream + max_seqlen_q, softcap, window_size_left, window_size_right, ) return out, lse @@ -367,6 +380,7 @@ def forward( v: torch.Tensor, softmax_scale: Optional[float] = None, causal: bool = False, + window_size: Tuple[Optional[int], Optional[int]] = (None, None), softcap: float = 0.0, ): out, lse = _flash_attn_fwd( @@ -375,11 +389,14 @@ def forward( v, softmax_scale=softmax_scale, causal=causal, + window_size_left=window_size[0], + window_size_right=window_size[1], softcap=softcap, ) ctx.save_for_backward(q, k, v, out, lse) ctx.softmax_scale = softmax_scale ctx.causal = causal + ctx.window_size = window_size ctx.softcap = softcap return out, lse @@ -397,7 +414,7 @@ def backward(ctx, dout, *args): ctx.causal, ctx.softcap, ) - return dq, dk, dv, *((None,) * 3) + return dq, dk, dv, *((None,) * 4) class FlashAttnVarlenFunc(torch.autograd.Function): @@ -415,6 +432,7 @@ def forward( max_seqlen_q: Optional[int], softmax_scale: Optional[float] = None, causal: bool = False, + window_size: Tuple[Optional[int], Optional[int]] = (None, None), softcap: float = 0.0, ): out, lse = _flash_attn_fwd( @@ -428,12 +446,15 @@ def forward( max_seqlen_q, softmax_scale=softmax_scale, causal=causal, + window_size_left=window_size[0], + window_size_right=window_size[1], softcap=softcap, ) ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) ctx.max_seqlen_q = max_seqlen_q ctx.softmax_scale = softmax_scale ctx.causal = causal + ctx.window_size = window_size ctx.softcap = softcap return out, lse @@ -451,6 +472,7 @@ def flash_attn_func( v: torch.Tensor, softmax_scale: Optional[float] = None, causal: bool = False, + window_size: Tuple[Optional[int], Optional[int]] = (None, None), softcap: float = 0.0, ): return FlashAttnFunc.apply( @@ -459,6 +481,7 @@ def flash_attn_func( v, softmax_scale, causal, + window_size, softcap, ) @@ -474,6 +497,7 @@ def flash_attn_varlen_func( max_seqlen_q: Optional[int] = None, softmax_scale: Optional[float] = None, causal: bool = False, + window_size: Tuple[Optional[int], Optional[int]] = (None, None), softcap: float = 0.0, ): return FlashAttnVarlenFunc.apply( @@ -487,5 +511,6 @@ def flash_attn_varlen_func( max_seqlen_q, softmax_scale, causal, + window_size, softcap, ) diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index 1d013caefd5..351b8692d5d 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -1,26 +1,23 @@ # Copyright (c) 2025, Tri Dao. +from typing import Optional +from dataclasses import dataclass + import cutlass import cutlass.cute as cute import flash_attn.cute.utils as utils +@dataclass(frozen=True) class AttentionMask: - - def __init__( - self, - m_block_size: cutlass.Constexpr[int], - n_block_size: cutlass.Constexpr[int], - seqlen_q: cutlass.Int32, - seqlen_k: cutlass.Int32, - qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1, # only pass in if we're doing PackGQA - ): - self.m_block_size = m_block_size - self.n_block_size = n_block_size - self.seqlen_q = seqlen_q - self.seqlen_k = seqlen_k - self.qhead_per_kvhead_packgqa = qhead_per_kvhead_packgqa + m_block_size: cutlass.Constexpr[int] + n_block_size: cutlass.Constexpr[int] + seqlen_q: cutlass.Int32 + seqlen_k: cutlass.Int32 + window_size_left: Optional[cutlass.Int32] = None + window_size_right: Optional[cutlass.Int32] = None + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 # only pass in if we're doing PackGQA @cute.jit def apply_mask( @@ -29,9 +26,11 @@ def apply_mask( m_block: cutlass.Int32, n_block: cutlass.Int32, thr_mma: cute.TiledMma, - mask_seqlen: cutlass.Constexpr, - mask_causal: cutlass.Constexpr, + mask_seqlen: cutlass.Constexpr[bool], + mask_causal: cutlass.Constexpr[bool], + mask_local: cutlass.Constexpr[bool] = False, ) -> None: + assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" acc_S_mn = utils.make_acc_tensor_mn_view(acc_S) cS = cute.make_identity_tensor((self.m_block_size, self.n_block_size)) tScS_mn = utils.make_acc_tensor_mn_view(thr_mma.partition_C(cS)) @@ -40,35 +39,78 @@ def apply_mask( t0ScS_mn = utils.make_acc_tensor_mn_view(thr_mma.get_slice(0).partition_C(cS)) thr_col_offset = tScS_mn[0][1] seqlenk_col_limit = self.seqlen_k - n_block * self.n_block_size - thr_col_offset - if not mask_causal: - if mask_seqlen: + if cutlass.const_expr(not mask_causal and not mask_local): + if cutlass.const_expr(mask_seqlen): # traverse column index. for c in range(cute.size(tScS_mn.shape[1])): if t0ScS_mn[0, c][1] >= seqlenk_col_limit: acc_S_mn[None, c].fill(-cutlass.Float32.inf) - else: # Causal + else: # Causal or local # If PackGQA, we split the work of compute divmod among threads in the same row threads_per_row = thr_mma.tv_layout_C.shape[0][0] - if cutlass.const_expr(self.qhead_per_kvhead_packgqa > 1): - assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE" + if cutlass.const_expr(self.qhead_per_kvhead_packgqa != 1): + assert cute.arch.WARP_SIZE % threads_per_row == 0, ( + "threads_per_row must divide WARP_SIZE" + ) assert cute.size(acc_S_mn.shape[0]) <= threads_per_row tidx = thr_mma.thr_idx - mma_m_idx = (m_block * self.m_block_size + tScS_mn[tidx % threads_per_row, 0][0]) // self.qhead_per_kvhead_packgqa - causal_row_offset = 1 + self.seqlen_k - n_block * self.n_block_size - self.seqlen_q - thr_col_offset - for r in range(cute.size(tScS_mn.shape[0])): - # get the column index limit based on current row. Only consider the row index, so the column index sets to 0. - if cutlass.const_expr(self.qhead_per_kvhead_packgqa == 1): - row_idx = tScS_mn[r, 0][0] + m_block * self.m_block_size - else: - row_idx = utils.shuffle_sync(mma_m_idx, r % threads_per_row, width=threads_per_row) - col_limit_right = row_idx + causal_row_offset - if cutlass.const_expr(mask_seqlen): - col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) - # traverse column index. - for c in range(cute.size(tScS_mn.shape[1])): - # only consider the column index, so the row index sets to 0. - if t0ScS_mn[0, c][1] >= col_limit_right: - acc_S_mn[r, c] = -cutlass.Float32.inf + mma_m_idx = ( + m_block * self.m_block_size + tScS_mn[tidx % threads_per_row, 0][0] + ) // self.qhead_per_kvhead_packgqa + causal_row_offset = ( + 1 + self.seqlen_k - n_block * self.n_block_size - self.seqlen_q - thr_col_offset + ) + if cutlass.const_expr(mask_causal): + for r in range(cute.size(tScS_mn.shape[0])): + # get the column index limit based on current row. Only consider the row index, so the column index sets to 0. + if cutlass.const_expr(self.qhead_per_kvhead_packgqa == 1): + row_idx = tScS_mn[r, 0][0] + m_block * self.m_block_size + else: + row_idx = utils.shuffle_sync( + mma_m_idx, r % threads_per_row, width=threads_per_row + ) + col_limit_right = row_idx + causal_row_offset + if cutlass.const_expr(mask_seqlen): + col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) + # traverse column index. + for c in range(cute.size(tScS_mn.shape[1])): + # only consider the column index, so the row index sets to 0. + if t0ScS_mn[0, c][1] >= col_limit_right: + acc_S_mn[r, c] = -cutlass.Float32.inf + else: # Local + local_row_offset_right = ( + causal_row_offset + self.window_size_right + if self.window_size_right is not None + else None + ) + local_row_offset_left = ( + causal_row_offset - 1 - self.window_size_left + if self.window_size_left is not None + else None + ) + for r in range(cute.size(tScS_mn.shape[0])): + if cutlass.const_expr(self.qhead_per_kvhead_packgqa == 1): + row_idx = tScS_mn[r, 0][0] + m_block * self.m_block_size + else: + row_idx = utils.shuffle_sync( + mma_m_idx, r % threads_per_row, width=threads_per_row + ) + if cutlass.const_expr(self.window_size_right is not None): + col_limit_right = row_idx + local_row_offset_right + if cutlass.const_expr(mask_seqlen): + col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) + else: + col_limit_right = self.n_block_size + col_limit_left = ( + row_idx + local_row_offset_left if self.window_size_left is not None else 0 + ) + # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block = {}, r = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", n_block, r, row_idx, causal_row_offset, col_limit_right, col_limit_left) + # traverse column index. + for c in range(cute.size(tScS_mn.shape[1])): + col_idx = t0ScS_mn[0, c][1] + # only consider the column index, so the row index sets to 0. + if col_idx >= col_limit_right or col_idx < col_limit_left: + acc_S_mn[r, c] = -cutlass.Float32.inf @cute.jit def apply_mask_sm100( @@ -76,34 +118,61 @@ def apply_mask_sm100( acc_S: cute.Tensor, m_block: cutlass.Int32, n_block: cutlass.Int32, - m_stage: cutlass.Int32, thr_mma: cute.TiledMma, thr_tmem_load: cute.TiledCopy, mask_seqlen: cutlass.Constexpr, mask_causal: cutlass.Constexpr, + mask_local: cutlass.Constexpr, ) -> None: + assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" cS = cute.make_identity_tensor((self.m_block_size, self.n_block_size)) tScS = thr_mma.partition_C(cS) tScS_t2r = thr_tmem_load.partition_D(tScS) seqlenk_col_limit = self.seqlen_k - n_block * self.n_block_size - if not mask_causal: - if mask_seqlen: + if cutlass.const_expr(not mask_causal and not mask_local): + if cutlass.const_expr(mask_seqlen): for i in range(cute.size(tScS_t2r.shape)): # if tScS_t2r[i][1] >= seqlenk_col_limit: # acc_S[i] = -cutlass.Float32.inf # For some reason the 2 lines above generate really bad SASS - acc_S[i] = -cutlass.Float32.inf if tScS_t2r[i][1] >= seqlenk_col_limit else acc_S[i] - else: # Causal - assert self.qhead_per_kvhead_packgqa == 1, "PackGQA not supported for SM100 yet" + acc_S[i] = ( + -cutlass.Float32.inf if tScS_t2r[i][1] >= seqlenk_col_limit else acc_S[i] + ) + else: # Causal or local causal_row_offset = 1 + self.seqlen_k - n_block * self.n_block_size - self.seqlen_q - row_idx = tScS_t2r[0][0] + (m_block * 2 + m_stage) * self.m_block_size - col_limit_right = row_idx + causal_row_offset - if cutlass.const_expr(mask_seqlen): - col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) - # if cute.arch.thread_idx()[0] % 32 == 0: - # cute.printf("tidx = %d, tidx tmem = %d, row_idx = %d, col_limit_right = %d, causal_row_offset = %d\n", cute.arch.thread_idx()[0], thr_tmem_load.thr_idx, row_idx, col_limit_right, causal_row_offset) - for i in range(cute.size(tScS_t2r.shape)): - # if tScS_t2r[i][1] >= col_limit_right: - # acc_S[i] = -cutlass.Float32.inf - # For some reason the 2 lines above generate really bad SASS - acc_S[i] = -cutlass.Float32.inf if tScS_t2r[i][1] >= col_limit_right else acc_S[i] + row_idx = tScS_t2r[0][0] + m_block * self.m_block_size + if cutlass.const_expr(self.qhead_per_kvhead_packgqa != 1): + row_idx = row_idx // self.qhead_per_kvhead_packgqa + if cutlass.const_expr(mask_causal): + col_limit_right = row_idx + causal_row_offset + if cutlass.const_expr(mask_seqlen): + col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) + # if cute.arch.thread_idx()[0] % 32 == 0: + # cute.printf("tidx = %d, tidx tmem = %d, row_idx = %d, col_limit_right = %d, causal_row_offset = %d\n", cute.arch.thread_idx()[0], thr_tmem_load.thr_idx, row_idx, col_limit_right, causal_row_offset) + for i in range(cute.size(tScS_t2r.shape)): + acc_S[i] = -cutlass.Float32.inf if tScS_t2r[i][1] >= col_limit_right else acc_S[i] + + else: + local_row_offset_right = ( + causal_row_offset + self.window_size_right + if self.window_size_right is not None + else None + ) + local_row_offset_left = ( + causal_row_offset - 1 - self.window_size_left + if self.window_size_left is not None + else None + ) + if cutlass.const_expr(self.window_size_right is not None): + col_limit_right = row_idx + local_row_offset_right + if cutlass.const_expr(mask_seqlen): + col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) + else: + col_limit_right = self.n_block_size + col_limit_left = ( + row_idx + local_row_offset_left if self.window_size_left is not None else 0 + ) + # if cute.arch.thread_idx()[0] == 0 or cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", m_block, n_block, row_idx, causal_row_offset, col_limit_right, col_limit_left) + for i in range(cute.size(tScS_t2r.shape)): + col_idx = tScS_t2r[i][1] + acc_S[i] = -cutlass.Float32.inf if col_idx >= col_limit_right or col_idx < col_limit_left else acc_S[i] diff --git a/flash_attn/utils/testing.py b/flash_attn/utils/testing.py index 772f955dedb..b2c03addd2b 100644 --- a/flash_attn/utils/testing.py +++ b/flash_attn/utils/testing.py @@ -158,7 +158,7 @@ def generate_qkv( def construct_local_mask( seqlen_q, seqlen_k, - window_size=(-1, -1), # -1 means infinite window size + window_size=(None, None), sink_token_length=0, query_padding_mask=None, key_padding_mask=None, @@ -181,7 +181,7 @@ def construct_local_mask( if query_padding_mask is None else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") ) - if window_size[0] < 0: + if window_size[0] is None: return col_idx > row_idx + sk - sq + window_size[1] else: sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk @@ -237,7 +237,7 @@ def attention_ref( causal=False, qv=None, q_descale=None, k_descale=None, v_descale=None, - window_size=(-1, -1), # -1 means infinite window size + window_size=(None, None), attention_chunk=0, sink_token_length=0, softcap=0.0, @@ -297,7 +297,7 @@ def attention_ref( if key_padding_mask is not None: scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) local_mask = None - if window_size[0] >= 0 or window_size[1] >= 0: + if window_size[0] is not None or window_size[1] is not None: local_mask = construct_local_mask( seqlen_q, seqlen_k, diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 552b5c6fc5e..f19080fc001 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -27,10 +27,10 @@ @pytest.mark.parametrize("deterministic", [False]) # @pytest.mark.parametrize("softcap", [0.0, 15.0]) @pytest.mark.parametrize("softcap", [0.0]) -# @pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) -@pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("local", [True]) @pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize("causal", [True]) +# @pytest.mark.parametrize("causal", [False]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) @@ -38,8 +38,8 @@ # @pytest.mark.parametrize("d", [64, 128, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128, 192]) -@pytest.mark.parametrize("d", [64, 128]) -# @pytest.mark.parametrize("d", [128]) +# @pytest.mark.parametrize("d", [64, 128]) +@pytest.mark.parametrize("d", [128]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -69,7 +69,7 @@ def test_flash_attn_output( seqlen_q, seqlen_k, d, causal, local, softcap, deterministic, has_qv, mha_type, dtype ): - if causal and seqlen_k < seqlen_q: + if (causal or local) and seqlen_k < seqlen_q: pytest.skip("Causal attention requires seqlen_k >= seqlen_q") device = "cuda" # set seed @@ -99,7 +99,7 @@ def test_flash_attn_output( else: qv_ref = None # Put window_size after QKV randn so that window_size changes from test to test - window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + window_size = (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() # window_size = (-1, -1) if not local else (16, 0) if dtype == torch.float8_e4m3fn: q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] @@ -165,7 +165,7 @@ def test_flash_attn_output( causal=causal, # qv=qv, # q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, - # window_size=window_size, + window_size=window_size, # attention_chunk=attention_chunk, softcap=softcap, # pack_gqa=pack_gqa, @@ -187,6 +187,7 @@ def test_flash_attn_output( and not dv > 256 and not attention_chunk != 0 and softcap == 0.0 + and not local # and False ): g = torch.randn_like(out) From de2ce8f3beea50dac2d88dc08764963e43ceb757 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 1 Jul 2025 22:48:04 -0400 Subject: [PATCH 012/258] [Cute] Add ruff options --- flash_attn/cute/pyproject.toml | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 flash_attn/cute/pyproject.toml diff --git a/flash_attn/cute/pyproject.toml b/flash_attn/cute/pyproject.toml new file mode 100644 index 00000000000..585c50079a3 --- /dev/null +++ b/flash_attn/cute/pyproject.toml @@ -0,0 +1,8 @@ +[tool.ruff] +line-length = 100 + +[tool.ruff.lint] +ignore = [ + "E731", # do not assign a lambda expression, use a def + "F841", # local variable is assigned to but never used +] \ No newline at end of file From 217c9d34d951ad23523a941640eddffe619a6992 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 1 Jul 2025 22:55:21 -0400 Subject: [PATCH 013/258] [Cute] Run ruff on utility files --- flash_attn/cute/ampere_helpers.py | 34 +++++++++--- flash_attn/cute/mask.py | 10 +++- flash_attn/cute/mma_sm100_desc.py | 86 ++++++++++++++++--------------- flash_attn/cute/pack_gqa.py | 15 ++++-- flash_attn/cute/pipeline.py | 26 ++++------ flash_attn/cute/seqlen_info.py | 13 +++-- flash_attn/cute/softmax.py | 36 +++++++------ flash_attn/cute/utils.py | 79 ++++++++++++++++------------ 8 files changed, 179 insertions(+), 120 deletions(-) diff --git a/flash_attn/cute/ampere_helpers.py b/flash_attn/cute/ampere_helpers.py index 41238edc365..804d052a78b 100644 --- a/flash_attn/cute/ampere_helpers.py +++ b/flash_attn/cute/ampere_helpers.py @@ -8,8 +8,16 @@ def get_smem_layout_atom(dtype: Type[cutlass.Numeric], k_dim: int) -> cute.ComposedLayout: dtype_byte = dtype.width // 8 bytes_per_row = k_dim * dtype_byte - smem_k_block_size = (128 if bytes_per_row % 128 == 0 else (64 if bytes_per_row % 64 == 0 else (32 if bytes_per_row % 32 == 0 else 16))) // dtype_byte - swizzle_bits = 4 if smem_k_block_size == 128 else (3 if smem_k_block_size == 64 else (2 if smem_k_block_size == 32 else 1)) + smem_k_block_size = ( + 128 + if bytes_per_row % 128 == 0 + else (64 if bytes_per_row % 64 == 0 else (32 if bytes_per_row % 32 == 0 else 16)) + ) // dtype_byte + swizzle_bits = ( + 4 + if smem_k_block_size == 128 + else (3 if smem_k_block_size == 64 else (2 if smem_k_block_size == 32 else 1)) + ) swizzle_base = 2 if dtype_byte == 4 else (3 if dtype_byte == 2 else 4) return cute.make_composed_layout( cute.make_swizzle(swizzle_bits, swizzle_base, swizzle_base), @@ -34,8 +42,18 @@ def gemm( ) -> None: if swap_AB: gemm( - tiled_mma, acc, tCrB, tCrA, tCsB, tCsA, smem_thr_copy_B, smem_thr_copy_A, hook_fn, - A_in_regs=B_in_regs, B_in_regs=A_in_regs, swap_AB=False + tiled_mma, + acc, + tCrB, + tCrA, + tCsB, + tCsA, + smem_thr_copy_B, + smem_thr_copy_A, + hook_fn, + A_in_regs=B_in_regs, + B_in_regs=A_in_regs, + swap_AB=False, ) else: tCrA_copy_view = smem_thr_copy_A.retile(tCrA) @@ -47,9 +65,13 @@ def gemm( for k in range(cute.size(tCsA.shape[2])): if k < cute.size(tCsA.shape[2]) - 1: if not A_in_regs: - cute.copy(smem_thr_copy_A, tCsA[None, None, k + 1], tCrA_copy_view[None, None, k + 1]) + cute.copy( + smem_thr_copy_A, tCsA[None, None, k + 1], tCrA_copy_view[None, None, k + 1] + ) if not B_in_regs: - cute.copy(smem_thr_copy_B, tCsB[None, None, k + 1], tCrB_copy_view[None, None, k + 1]) + cute.copy( + smem_thr_copy_B, tCsB[None, None, k + 1], tCrB_copy_view[None, None, k + 1] + ) cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc) if cutlass.const_expr(k == 0 and hook_fn is not None): hook_fn() diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index 351b8692d5d..be04357c695 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -150,7 +150,9 @@ def apply_mask_sm100( # if cute.arch.thread_idx()[0] % 32 == 0: # cute.printf("tidx = %d, tidx tmem = %d, row_idx = %d, col_limit_right = %d, causal_row_offset = %d\n", cute.arch.thread_idx()[0], thr_tmem_load.thr_idx, row_idx, col_limit_right, causal_row_offset) for i in range(cute.size(tScS_t2r.shape)): - acc_S[i] = -cutlass.Float32.inf if tScS_t2r[i][1] >= col_limit_right else acc_S[i] + acc_S[i] = ( + -cutlass.Float32.inf if tScS_t2r[i][1] >= col_limit_right else acc_S[i] + ) else: local_row_offset_right = ( @@ -175,4 +177,8 @@ def apply_mask_sm100( # if cute.arch.thread_idx()[0] == 0 or cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", m_block, n_block, row_idx, causal_row_offset, col_limit_right, col_limit_left) for i in range(cute.size(tScS_t2r.shape)): col_idx = tScS_t2r[i][1] - acc_S[i] = -cutlass.Float32.inf if col_idx >= col_limit_right or col_idx < col_limit_left else acc_S[i] + acc_S[i] = ( + -cutlass.Float32.inf + if col_idx >= col_limit_right or col_idx < col_limit_left + else acc_S[i] + ) diff --git a/flash_attn/cute/mma_sm100_desc.py b/flash_attn/cute/mma_sm100_desc.py index 0170f0e99ae..62f1bc742e1 100644 --- a/flash_attn/cute/mma_sm100_desc.py +++ b/flash_attn/cute/mma_sm100_desc.py @@ -13,36 +13,36 @@ # --------------------------------------------------------------------------- -class Major(IntEnum): # matrix “layout” in the ISA docs - K = 0 +class Major(IntEnum): # matrix “layout” in the ISA docs + K = 0 MN = 1 -class ScaleIn(IntEnum): # negate flags +class ScaleIn(IntEnum): # negate flags One = 0 Neg = 1 class Saturate(IntEnum): False_ = 0 - True_ = 1 + True_ = 1 -class CFormat(IntEnum): # 2-bit field (bits 4-5) +class CFormat(IntEnum): # 2-bit field (bits 4-5) F16 = 0 F32 = 1 S32 = 2 -class F16F32Format(IntEnum): # 3-bit field (A/B element type) - F16 = 0 +class F16F32Format(IntEnum): # 3-bit field (A/B element type) + F16 = 0 BF16 = 1 TF32 = 2 class S8Format(IntEnum): UINT8 = 0 - INT8 = 1 + INT8 = 1 class MXF8F6F4Format(IntEnum): @@ -54,8 +54,8 @@ class MXF8F6F4Format(IntEnum): class MaxShift(IntEnum): - NoShift = 0 - MaxShift8 = 1 + NoShift = 0 + MaxShift8 = 1 MaxShift16 = 2 MaxShift32 = 3 @@ -64,6 +64,7 @@ class MaxShift(IntEnum): # CUTLASS-type → encoding helpers # --------------------------------------------------------------------------- + def to_UMMA_format(cutlass_type) -> int: """ Map a CUTLASS scalar class to the 3-bit encoding for Matrix A/B. @@ -106,18 +107,19 @@ def to_C_format(cutlass_type) -> int: # The constructor – accepts only CUTLASS scalar classes # --------------------------------------------------------------------------- + def make_instr_desc( - a_type, # CUTLASS scalar class, e.g. cutlass.Int8 + a_type, # CUTLASS scalar class, e.g. cutlass.Int8 b_type, c_type, - M: int, # 64, 128 or 256 - N: int, # 8 … 256 (multiple of 8) - a_major: Major, - b_major: Major, - a_neg: ScaleIn = ScaleIn.One, - b_neg: ScaleIn = ScaleIn.One, - c_sat: Saturate = Saturate.False_, - is_sparse: bool = False, + M: int, # 64, 128 or 256 + N: int, # 8 … 256 (multiple of 8) + a_major: Major, + b_major: Major, + a_neg: ScaleIn = ScaleIn.One, + b_neg: ScaleIn = ScaleIn.One, + c_sat: Saturate = Saturate.False_, + is_sparse: bool = False, max_shift: MaxShift = MaxShift.NoShift, ) -> int: """ @@ -170,26 +172,28 @@ def mma_op_to_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp): ) -class LayoutType(IntEnum): # occupies the top-3 bits [61:64) - SWIZZLE_NONE = 0 # (a.k.a. “INTERLEAVE” in older docs) - SWIZZLE_128B_BASE32B = 1 - SWIZZLE_128B = 2 - SWIZZLE_64B = 4 - SWIZZLE_32B = 6 +class LayoutType(IntEnum): # occupies the top-3 bits [61:64) + SWIZZLE_NONE = 0 # (a.k.a. “INTERLEAVE” in older docs) + SWIZZLE_128B_BASE32B = 1 + SWIZZLE_128B = 2 + SWIZZLE_64B = 4 + SWIZZLE_32B = 6 # values 3,5,7 are reserved / illegal for UMMA + # --------------------------------------------------------------------------- # Helpers – figure out the SWIZZLE_* family from the tensor layout # --------------------------------------------------------------------------- + def _layout_type(swizzle: cute.Swizzle) -> LayoutType: # No idea what the right way to get B, M, S is – so we're just parsing it from the __str__ # Swizzle string has the form "S" swz_str = str(swizzle) - inside = swz_str[swz_str.index('<') + 1 : swz_str.index('>')] # '3,4,3' - B, M, S = [int(x) for x in inside.split(',')] # [3, 4, 3] + inside = swz_str[swz_str.index("<") + 1 : swz_str.index(">")] # '3,4,3' + B, M, S = [int(x) for x in inside.split(",")] # [3, 4, 3] - if M == 4: # Swizzle<*,4,3> + if M == 4: # Swizzle<*,4,3> if S != 3: raise ValueError("Unexpected swizzle shift – want S==3 for M==4") return { @@ -197,8 +201,8 @@ def _layout_type(swizzle: cute.Swizzle) -> LayoutType: 1: LayoutType.SWIZZLE_32B, 2: LayoutType.SWIZZLE_64B, 3: LayoutType.SWIZZLE_128B, - }[B] # KeyError ⇒ invalid B→ raise - if M == 5: # Swizzle<2,5,2> (the only legal triple for M==5) + }[B] # KeyError ⇒ invalid B→ raise + if M == 5: # Swizzle<2,5,2> (the only legal triple for M==5) if (B, S) != (2, 2): raise ValueError("Only Swizzle<2,5,2> supported for 128B_BASE32B") return LayoutType.SWIZZLE_128B_BASE32B @@ -214,11 +218,11 @@ def make_smem_desc_base(layout: cute.Layout, swizzle: cute.Swizzle, major: Major layout must correspond to layout of an uint128 tensor. """ # ------------------------------------------------------------------ meta - layout_type = _layout_type(swizzle) # resolve SWIZZLE_* family + layout_type = _layout_type(swizzle) # resolve SWIZZLE_* family - VERSION = 1 # bits 46–47 - LBO_MODE = 0 # bit 52 - BASE_OFFSET = 0 # bits 49–51 (CUTLASS always 0) + VERSION = 1 # bits 46–47 + LBO_MODE = 0 # bit 52 + BASE_OFFSET = 0 # bits 49–51 (CUTLASS always 0) # ---------------------------------------------------------- strides (units: uint128_t = 16 B) swizzle_atom_mn_size = { @@ -263,21 +267,21 @@ def make_smem_desc_base(layout: cute.Layout, swizzle: cute.Swizzle, major: Major stride_byte_offset, leading_byte_offset = stride_01, stride_10 # ------------------------------------------------------------------ pack - desc = 0 + desc = 0 # leading_byte_offset_ [16:30) desc |= (leading_byte_offset & 0x3FFF) << 16 # stride_byte_offset_ [32:46) - desc |= (stride_byte_offset & 0x3FFF) << 32 + desc |= (stride_byte_offset & 0x3FFF) << 32 # version_ [46:48) - desc |= (VERSION & 0x3) << 46 + desc |= (VERSION & 0x3) << 46 # base_offset_ [49:52) - desc |= (BASE_OFFSET & 0x7) << 49 + desc |= (BASE_OFFSET & 0x7) << 49 # lbo_mode_ [52:53) - desc |= (LBO_MODE & 0x1) << 52 + desc |= (LBO_MODE & 0x1) << 52 # layout_type_ [61:64) - desc |= (int(layout_type) & 0x7) << 61 + desc |= (int(layout_type) & 0x7) << 61 - return desc & 0xFFFF_FFFF_FFFF_FFFF # force 64-bit width + return desc & 0xFFFF_FFFF_FFFF_FFFF # force 64-bit width def make_smem_desc_start_addr(start_addr: cute.Pointer) -> cutlass.Int32: diff --git a/flash_attn/cute/pack_gqa.py b/flash_attn/cute/pack_gqa.py index a2dafa73c2f..9d2d43e0a6f 100644 --- a/flash_attn/cute/pack_gqa.py +++ b/flash_attn/cute/pack_gqa.py @@ -10,7 +10,6 @@ class PackGQA: - def __init__( self, m_block_size: cutlass.Constexpr[int], @@ -71,7 +70,10 @@ def load_Q( q_gmem_ptr = cute.make_ptr( mQ.element_type, q_ptr_i64, cute.AddressSpace.gmem, assumed_align=16 ) - if t0QcQ[0, m, 0][0] < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tQcQ_row[0][0]: + if ( + t0QcQ[0, m, 0][0] + < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tQcQ_row[0][0] + ): mQ_cur = cute.make_tensor(q_gmem_ptr, (self.head_dim_padded,)) elems_per_load = cute.size(tQsQ.shape[0][0]) mQ_cur_copy = cute.tiled_divide(mQ_cur, (elems_per_load,)) @@ -107,7 +109,9 @@ def store_LSE( tPrLSEPtr = self.compute_ptr(mLSE, taccOcO_row, tidx, block, threads_per_row, num_threads) for m in range(cute.size(tLSErLSE)): lse_ptr_i64 = utils.shuffle_sync( - tPrLSEPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row, + tPrLSEPtr[m // threads_per_row], + m % threads_per_row, + width=threads_per_row, ) lse_gmem_ptr = cute.make_ptr( mLSE.element_type, lse_ptr_i64, cute.AddressSpace.gmem, assumed_align=4 @@ -145,7 +149,10 @@ def store_O( o_gmem_ptr = cute.make_ptr( mO.element_type, o_ptr_i64, cute.AddressSpace.gmem, assumed_align=16 ) - if t0OcO[0, m, 0][0] < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tOcO_row[0][0]: + if ( + t0OcO[0, m, 0][0] + < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tOcO_row[0][0] + ): mO_cur = cute.make_tensor(o_gmem_ptr, (self.head_dim_padded,)) elems_per_load = cute.size(tOrO.shape[0][0]) mO_cur_copy = cute.tiled_divide(mO_cur, (elems_per_load,)) diff --git a/flash_attn/cute/pipeline.py b/flash_attn/cute/pipeline.py index 3df229c4f3e..775e1754b3d 100644 --- a/flash_attn/cute/pipeline.py +++ b/flash_attn/cute/pipeline.py @@ -88,24 +88,20 @@ def make_pipeline_state(type: PipelineUserType, stages: int): elif type is PipelineUserType.Consumer: return PipelineStateSimple(stages, Int32(0)) else: - assert ( - False - ), "Error: invalid PipelineUserType specified for make_pipeline_state." - + assert False, "Error: invalid PipelineUserType specified for make_pipeline_state." @dataclass(frozen=True) class PipelineTmaAsyncNoCluster(PipelineAsync): - """ - If size(ClusterShape) == 1, PipelineTmaAsync has all threads - signaling the barrier during consumer_release. This causes a perf regression in FA3 - forward pass (especially hdim 128 causal). We instead implement a version of - PipelineTmaAsync where only 1 out of 128 threads signals the barrier. - - Assumptions: - (1) num_consumers % NumThreadsPerWarpGroup == 0 - (2) all 128 threads in the warp group are sync'ed right before calling consumer_release + If size(ClusterShape) == 1, PipelineTmaAsync has all threads + signaling the barrier during consumer_release. This causes a perf regression in FA3 + forward pass (especially hdim 128 causal). We instead implement a version of + PipelineTmaAsync where only 1 out of 128 threads signals the barrier. + + Assumptions: + (1) num_consumers % NumThreadsPerWarpGroup == 0 + (2) all 128 threads in the warp group are sync'ed right before calling consumer_release """ @staticmethod @@ -152,9 +148,7 @@ def create( dst_rank, ) - def producer_acquire( - self, state: PipelineState, try_acquire_token: Optional[Boolean] = None - ): + def producer_acquire(self, state: PipelineState, try_acquire_token: Optional[Boolean] = None): """ TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. """ diff --git a/flash_attn/cute/seqlen_info.py b/flash_attn/cute/seqlen_info.py index 6316e5ee814..8d7eb904c8b 100644 --- a/flash_attn/cute/seqlen_info.py +++ b/flash_attn/cute/seqlen_info.py @@ -5,7 +5,6 @@ class SeqlenInfo: - def __init__( self, batch_idx: cutlass.Int32, @@ -21,10 +20,18 @@ def __init__( if cutlass.const_expr(mSeqUsedQ is not None): self.seqlen_q = mSeqUsedQ[batch_idx] else: - self.seqlen_q = seqlen_q_static if cutlass.const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx + 1] - self.offset_q + self.seqlen_q = ( + seqlen_q_static + if cutlass.const_expr(mCuSeqlensQ is None) + else mCuSeqlensQ[batch_idx + 1] - self.offset_q + ) if cutlass.const_expr(mSeqUsedK is not None): self.seqlen_k = mSeqUsedK[batch_idx] else: - self.seqlen_k = seqlen_k_static if cutlass.const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx + 1] - self.offset_k + self.seqlen_k = ( + seqlen_k_static + if cutlass.const_expr(mCuSeqlensK is None) + else mCuSeqlensK[batch_idx + 1] - self.offset_k + ) self.has_cu_seqlens_q: int = mCuSeqlensQ is not None self.has_cu_seqlens_k: int = mCuSeqlensK is not None diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index cb9bd1c897f..f94f8579e87 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -12,7 +12,6 @@ class Softmax: - def __init__( self, scale_log2: Float32, @@ -29,16 +28,12 @@ def reset(self) -> None: self.row_sum.fill(0.0) def _compute_row_max( - self, - acc_S_row: cute.TensorSSA, - init_val: float | Float32 = -Float32.inf + self, acc_S_row: cute.TensorSSA, init_val: float | Float32 = -Float32.inf ) -> Float32: return utils.fmax_reduce(acc_S_row, init_val, arch=self.arch) def _compute_row_sum( - self, - acc_S_row_exp: cute.TensorSSA, - init_val: float | Float32 = Float32.zero + self, acc_S_row_exp: cute.TensorSSA, init_val: float | Float32 = Float32.zero ) -> Float32: return utils.fadd_reduce(acc_S_row_exp, init_val, arch=self.arch) @@ -81,7 +76,9 @@ def online_softmax( acc_S_row_exp = utils.exp2f(acc_S_row * self.scale_log2 - row_max_cur_scaled) # row_scale[r] = utils.exp2f(row_max_prev * self.scale_log2 - row_max_cur_scaled) row_scale[r] = utils.exp2f((row_max_prev - row_max_cur) * self.scale_log2) - acc_S_row_sum = self._compute_row_sum(acc_S_row_exp) + self.row_sum[r] * row_scale[r] + acc_S_row_sum = ( + self._compute_row_sum(acc_S_row_exp) + self.row_sum[r] * row_scale[r] + ) self.row_max[r] = row_max_cur self.row_sum[r] = acc_S_row_sum acc_S_mn[r, None].store(acc_S_row_exp) @@ -89,14 +86,15 @@ def online_softmax( @cute.jit def finalize(self, final_scale: Float32 = 1.0) -> cute.Tensor: - """Finalize the online softmax by computing the scale and logsumexp. - """ + """Finalize the online softmax by computing the scale and logsumexp.""" # quad reduction for row_sum as we didn't do it during each iteration of online softmax self.row_sum.store(utils.warp_reduce(self.row_sum.load(), operator.add, width=4)) row_scale = cute.make_fragment_like(self.row_max, Float32) for r in range(cute.size(self.row_sum)): # if row_sum is zero or nan, set acc_O_mn_row to 1.0 - acc_O_mn_row_is_zero_or_nan = self.row_sum[r] == 0.0 or self.row_sum[r] != self.row_sum[r] + acc_O_mn_row_is_zero_or_nan = ( + self.row_sum[r] == 0.0 or self.row_sum[r] != self.row_sum[r] + ) row_scale[r] = ( cute.arch.rcp_approx(self.row_sum[r] if not acc_O_mn_row_is_zero_or_nan else 1.0) ) * final_scale @@ -104,7 +102,8 @@ def finalize(self, final_scale: Float32 = 1.0) -> cute.Tensor: LN2 = math.log(2.0) self.row_sum[r] = ( (self.row_max[r] * self.scale_log2 + utils.log2f(row_sum_cur)) * LN2 - if not acc_O_mn_row_is_zero_or_nan else -Float32.inf + if not acc_O_mn_row_is_zero_or_nan + else -Float32.inf ) return row_scale @@ -123,7 +122,6 @@ def rescale_O(self, acc_O: cute.Tensor, row_scale: cute.Tensor) -> None: class SoftmaxSm100(Softmax): - def __init__(self, scale_log2: Float32, rescale_threshold: cutlass.Constexpr[float] = 0.0): super().__init__(scale_log2, num_rows=1, arch=100) self.rescale_threshold = rescale_threshold @@ -149,7 +147,9 @@ def update_row_max(self, acc_S_row: cute.TensorSSA, is_first: int) -> Tuple[Floa self.row_max[0] = row_max_new return row_max_safe, acc_scale - def update_row_sum(self, acc_S_row_exp: cute.TensorSSA, row_scale: Float32, is_first: int = False) -> None: + def update_row_sum( + self, acc_S_row_exp: cute.TensorSSA, row_scale: Float32, is_first: int = False + ) -> None: init_val = self.row_sum[0] * row_scale if cutlass.const_expr(not is_first) else None # self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=self.row_sum[0] * row_scale) self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=init_val) @@ -181,7 +181,9 @@ def apply_exp2_convert( frg_cnt = cute.size(acc_S_row) // frg_tile assert cute.size(acc_S_row) % frg_tile == 0 acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile)) - acc_S_row_converted_frg = cute.logical_divide(acc_S_row_converted, cute.make_layout(frg_tile)) + acc_S_row_converted_frg = cute.logical_divide( + acc_S_row_converted, cute.make_layout(frg_tile) + ) for j in range(frg_cnt): for k in range(0, cute.size(acc_S_row_frg, mode=[0]), 2): # acc_S_row_frg[k, j] = utils.exp2f(acc_S_row_frg[k, j]) @@ -221,7 +223,9 @@ def scale_apply_exp2_convert( frg_cnt = cute.size(acc_S_row) // frg_tile assert cute.size(acc_S_row) % frg_tile == 0 acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile)) - acc_S_row_converted_frg = cute.logical_divide(acc_S_row_converted, cute.make_layout(frg_tile)) + acc_S_row_converted_frg = cute.logical_divide( + acc_S_row_converted, cute.make_layout(frg_tile) + ) for j in range(frg_cnt): for k in range(0, cute.size(acc_S_row_frg, mode=[0]), 2): # acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = ( diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index c2de62897e9..af6a8c7332a 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -74,14 +74,19 @@ def mma_make_fragment_B( return thr_mma.make_fragment_B(thr_mma.partition_B(smem)) -def get_smem_store_atom(arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric]) -> cute.CopyAtom: +def get_smem_store_atom( + arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric] +) -> cute.CopyAtom: if arch < 90: return cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), element_type, num_bits_per_copy=2 * element_type.width, + cute.nvgpu.CopyUniversalOp(), + element_type, + num_bits_per_copy=2 * element_type.width, ) else: return cute.make_copy_atom( - cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=False, num_matrices=4), element_type, + cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=False, num_matrices=4), + element_type, ) @@ -94,7 +99,7 @@ def max_constexpr( def warp_reduce( val: cute.TensorSSA | cute.Numeric, op: Callable, - width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE + width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE, ) -> cute.TensorSSA | cute.Numeric: if isinstance(val, cute.TensorSSA): res = cute.make_fragment(val.shape, val.dtype) @@ -117,12 +122,20 @@ def convert_layout_acc_mn(acc_layout: cute.Layout) -> cute.Layout: acc_layout_mn = cute.make_layout( ( (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]), # MMA_M - (acc_layout_col_major.shape[0][0], *acc_layout_col_major.shape[0][2:], acc_layout_col_major.shape[2]), # MMA_N + ( + acc_layout_col_major.shape[0][0], + *acc_layout_col_major.shape[0][2:], + acc_layout_col_major.shape[2], + ), # MMA_N *acc_layout_col_major.shape[3:], ), stride=( (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]), # MMA_M - (acc_layout_col_major.stride[0][0], *acc_layout_col_major.stride[0][2:], acc_layout_col_major.stride[2]), # MMA_N + ( + acc_layout_col_major.stride[0][0], + *acc_layout_col_major.stride[0][2:], + acc_layout_col_major.stride[2], + ), # MMA_N *acc_layout_col_major.stride[3:], ), ) @@ -154,8 +167,7 @@ def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout: def transpose_view(a: cute.Tensor) -> cute.Tensor: - """Transpose the first two dimensions of a tensor on smem. - """ + """Transpose the first two dimensions of a tensor on smem.""" shape = (a.shape[1], a.shape[0], *a.shape[2:]) order = (1, 0, *range(2, cute.rank(a))) return cute.composition(a, cute.make_ordered_layout(shape, order=order)) @@ -210,7 +222,9 @@ def log2f(a: float | Float32, *, loc=None, ip=None) -> Float32: @dsl_user_op -def fmax(a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None) -> Float32: +def fmax( + a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None +) -> Float32: return Float32( nvvm.fmax( T.f32(), @@ -224,9 +238,7 @@ def fmax(a: float | Float32, b: float | Float32, c: float | Float32 | None = Non def fmax_reduce( - x: cute.TensorSSA, - init_val: float | Float32 | None = None, - arch: cutlass.Constexpr[int] = 80 + x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80 ) -> Float32: if cutlass.const_expr(arch < 100 or cute.size(x.shape) % 8 != 0): if cutlass.const_expr(init_val is None): @@ -238,7 +250,9 @@ def fmax_reduce( res = cute.make_fragment(x.shape, Float32) res.store(x) local_max = [ - fmax(init_val, res[0], res[1]) if cutlass.const_expr(init_val is not None) else fmax(res[0], res[1]), + fmax(init_val, res[0], res[1]) + if cutlass.const_expr(init_val is not None) + else fmax(res[0], res[1]), fmax(res[2], res[3]), fmax(res[4], res[5]), fmax(res[6], res[7]), @@ -253,9 +267,7 @@ def fmax_reduce( def fadd_reduce( - x: cute.TensorSSA, - init_val: float | Float32 | None = None, - arch: cutlass.Constexpr[int] = 80 + x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80 ) -> Float32: if cutlass.const_expr(arch < 100 or cute.size(x.shape) % 8 != 0): if cutlass.const_expr(init_val is None): @@ -264,7 +276,11 @@ def fadd_reduce( else: res = cute.make_fragment(x.shape, Float32) res.store(x) - local_sum_0 = cute.arch.add_packed_f32x2((init_val, 0.0), (res[0], res[1])) if cutlass.const_expr(init_val is not None) else (res[0], res[1]) + local_sum_0 = ( + cute.arch.add_packed_f32x2((init_val, 0.0), (res[0], res[1])) + if cutlass.const_expr(init_val is not None) + else (res[0], res[1]) + ) local_sum = [local_sum_0, (res[2], res[3]), (res[4], res[5]), (res[6], res[7])] for i in range(8, cute.size(x.shape), 8): local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], (res[i + 0], res[i + 1])) @@ -278,9 +294,7 @@ def fadd_reduce( @dsl_user_op -def atomic_add_fp32( - a: float | Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None -) -> None: +def atomic_add_fp32(a: float | Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> None: # gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value() # # cache_hint = cutlass.Int64(0x12F0000000000000) # llvm.inline_asm( @@ -297,10 +311,7 @@ def atomic_add_fp32( # asm_dialect=llvm.AsmDialect.AD_ATT, # ) nvvm.atomicrmw( - res=T.f32(), - op=nvvm.AtomicOpKind.FADD, - ptr=gmem_ptr.llvm_ptr, - a=Float32(a).ir_value() + res=T.f32(), op=nvvm.AtomicOpKind.FADD, ptr=gmem_ptr.llvm_ptr, a=Float32(a).ir_value() ) @@ -325,11 +336,15 @@ def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor: @dsl_user_op -def barrier_sync(barrier_id: int | cutlass.Int32, number_of_threads: int | cutlass.Int32, - *, loc=None, ip=None) -> None: +def barrier_sync( + barrier_id: int | cutlass.Int32, number_of_threads: int | cutlass.Int32, *, loc=None, ip=None +) -> None: llvm.inline_asm( None, - [cutlass.Int32(barrier_id).ir_value(loc=loc, ip=ip), cutlass.Int32(number_of_threads).ir_value(loc=loc, ip=ip)], + [ + cutlass.Int32(barrier_id).ir_value(loc=loc, ip=ip), + cutlass.Int32(number_of_threads).ir_value(loc=loc, ip=ip), + ], "bar.sync $0, $1;", "r,r", has_side_effects=True, @@ -339,15 +354,15 @@ def barrier_sync(barrier_id: int | cutlass.Int32, number_of_threads: int | cutla @dsl_user_op -def barrier_arrive(barrier_id: int | cutlass.Int32, number_of_threads: int | cutlass.Int32, *, loc=None, ip=None) -> None: +def barrier_arrive( + barrier_id: int | cutlass.Int32, number_of_threads: int | cutlass.Int32, *, loc=None, ip=None +) -> None: """ Arrive at a named barrier. """ barrier_id = cutlass.Int32(barrier_id).ir_value(loc=loc, ip=ip) number_of_threads = cutlass.Int32(number_of_threads).ir_value(loc=loc, ip=ip) - nvvm.barrier_arrive( - barrier_id=barrier_id, number_of_threads=number_of_threads, loc=loc, ip=ip - ) + nvvm.barrier_arrive(barrier_id=barrier_id, number_of_threads=number_of_threads, loc=loc, ip=ip) # llvm.inline_asm( # None, # [barrier_id, number_of_threads], @@ -405,7 +420,7 @@ def shuffle_sync( width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE, *, loc=None, - ip=None + ip=None, ) -> cute.Numeric: assert value.width % 32 == 0, "value type must be a multiple of 32 bits" # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000 From 3222ea302bed64ca4190e838675527ef257a1aff Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 1 Jul 2025 22:57:26 -0400 Subject: [PATCH 014/258] [Cute] Run ruff on bwd_pre/postprocess.py --- flash_attn/cute/flash_bwd_postprocess.py | 48 +++++++++++++------- flash_attn/cute/flash_bwd_preprocess.py | 57 +++++++++++++++++------- 2 files changed, 75 insertions(+), 30 deletions(-) diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index 3662de580a6..616ea30e1e5 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -81,38 +81,48 @@ def _setup_attributes(self): num_bits_per_copy=universal_copy_bits, ) # We don't do bound checking for the gmem -> smem load so we just assert here. - assert (self.m_block_size * self.head_dim_padded // async_copy_elems_accum) % self.tiled_mma.size == 0 + assert ( + self.m_block_size * self.head_dim_padded // async_copy_elems_accum + ) % self.tiled_mma.size == 0 self.g2s_tiled_copy_dQaccum = cute.make_tiled_copy_tv( atom_async_copy_accum, cute.make_layout(self.tiled_mma.size), - cute.make_layout(async_copy_elems_accum) + cute.make_layout(async_copy_elems_accum), ) atom_universal_copy_accum = cute.make_copy_atom( # multiply by 4 for Sm90 - cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=cutlass.Float32.width, + cute.nvgpu.CopyUniversalOp(), + cutlass.Float32, + num_bits_per_copy=cutlass.Float32.width, ) self.s2r_tiled_copy_dQaccum = cute.make_tiled_copy_tv( atom_universal_copy_accum, cute.make_layout(self.tiled_mma.size), - cute.make_layout(1) # 4 for Sm90 + cute.make_layout(1), # 4 for Sm90 ) async_copy_elems = universal_copy_bits // self.dtype.width # atom_universal_copy: universal copy atom for dQ store atom_universal_copy = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=universal_copy_bits, + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=universal_copy_bits, ) # tdQ_layout: thread layout for dQ store assert self.head_dim_padded % async_copy_elems == 0 - gmem_threads_per_row = math.gcd(self.head_dim_padded // async_copy_elems, - self.tiled_mma.size) + gmem_threads_per_row = math.gcd( + self.head_dim_padded // async_copy_elems, self.tiled_mma.size + ) assert self.tiled_mma.size % gmem_threads_per_row == 0 tdQ_layout = cute.make_ordered_layout( - (self.tiled_mma.size // gmem_threads_per_row, gmem_threads_per_row), order=(1, 0), + (self.tiled_mma.size // gmem_threads_per_row, gmem_threads_per_row), + order=(1, 0), ) # Value layouts for copies vdQ_layout = cute.make_layout((1, async_copy_elems)) - self.gmem_tiled_copy_dQ = cute.make_tiled_copy_tv(atom_universal_copy, tdQ_layout, vdQ_layout) + self.gmem_tiled_copy_dQ = cute.make_tiled_copy_tv( + atom_universal_copy, tdQ_layout, vdQ_layout + ) # /////////////////////////////////////////////////////////////////////////////// # Shared memory layout: dQaccum / dQ # /////////////////////////////////////////////////////////////////////////////// @@ -126,7 +136,6 @@ def _setup_attributes(self): sdQ_layout_atom, (self.m_block_size, self.head_dim_padded), (0, 1) ) - @cute.jit def __call__( self, @@ -143,7 +152,11 @@ def __call__( raise TypeError("dQaccum tensor must be Float32") num_mma_warps = self.num_threads // 32 - AtomLayoutdQ = (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1) if not self.dQ_swapAB else (num_mma_warps // self.AtomLayoutMdQ, self.AtomLayoutMdQ, 1) + AtomLayoutdQ = ( + (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1) + if not self.dQ_swapAB + else (num_mma_warps // self.AtomLayoutMdQ, self.AtomLayoutMdQ, 1) + ) tiled_mma = cute.make_tiled_mma( warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)), AtomLayoutdQ, @@ -153,8 +166,10 @@ def __call__( self._setup_attributes() - smem_size = max(cute.size_in_bytes(cutlass.Float32, self.sdQaccum_layout), - cute.size_in_bytes(self.dtype, self.sdQ_layout)) + smem_size = max( + cute.size_in_bytes(cutlass.Float32, self.sdQaccum_layout), + cute.size_in_bytes(self.dtype, self.sdQ_layout), + ) # grid_dim: (m_block, num_head, batch_size) grid_dim = ( @@ -202,7 +217,9 @@ def kernel( # Get the appropriate tiles for this thread block. # /////////////////////////////////////////////////////////////////////////////// blkdQaccum_shape = (self.m_block_size * self.head_dim_padded,) - gdQaccum = cute.local_tile(mdQaccum[batch_size, num_head, None], blkdQaccum_shape, (m_block,)) + gdQaccum = cute.local_tile( + mdQaccum[batch_size, num_head, None], blkdQaccum_shape, (m_block,) + ) blkdQ_shape = (self.m_block_size, self.head_dim_padded) gdQ = cute.local_tile(mdQ[batch_size, None, num_head, None], blkdQ_shape, (m_block, 0)) @@ -235,7 +252,8 @@ def kernel( # thr_mma = tiled_mma.get_slice(tidx) # print(tiled_mma) acc_shape = tiled_mma.partition_shape_C( - (self.m_block_size, self.head_dim_padded) if not dQ_swapAB + (self.m_block_size, self.head_dim_padded) + if not dQ_swapAB else (self.head_dim_padded, self.m_block_size) ) acc = cute.make_fragment(acc_shape, cutlass.Float32) diff --git a/flash_attn/cute/flash_bwd_preprocess.py b/flash_attn/cute/flash_bwd_preprocess.py index 21f209ed97f..c6955574083 100644 --- a/flash_attn/cute/flash_bwd_preprocess.py +++ b/flash_attn/cute/flash_bwd_preprocess.py @@ -73,33 +73,52 @@ def _setup_attributes(self): # Thread layouts for copies # We want kBlockKGmem to be a power of 2 so that when we do the summing, # it's just between threads in the same warp - gmem_k_block_size = 128 if self.head_dim_padded % 128 == 0 else (64 if self.head_dim_padded % 64 == 0 else (32 if self.head_dim_padded % 32 == 0 else 16)) + gmem_k_block_size = ( + 128 + if self.head_dim_padded % 128 == 0 + else ( + 64 + if self.head_dim_padded % 64 == 0 + else (32 if self.head_dim_padded % 32 == 0 else 16) + ) + ) universal_copy_bits = 128 async_copy_elems = universal_copy_bits // self.dtype.width # atom_universal_copy: universal copy atom for O & dO load atom_universal_copy = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=universal_copy_bits, + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=universal_copy_bits, ) # tOdO_layout: thread layout for O & dO load self.gmem_threads_per_row = gmem_k_block_size // async_copy_elems assert self.num_threads % self.gmem_threads_per_row == 0 tOdO_layout = cute.make_ordered_layout( - (self.num_threads // self.gmem_threads_per_row, self.gmem_threads_per_row), order=(1, 0), + (self.num_threads // self.gmem_threads_per_row, self.gmem_threads_per_row), + order=(1, 0), ) # Value layouts for copies vOdO_layout = cute.make_layout((1, async_copy_elems)) - self.gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tOdO_layout, vOdO_layout) - self.gmem_tiled_copy_dO = cute.make_tiled_copy_tv(atom_universal_copy, tOdO_layout, vOdO_layout) + self.gmem_tiled_copy_O = cute.make_tiled_copy_tv( + atom_universal_copy, tOdO_layout, vOdO_layout + ) + self.gmem_tiled_copy_dO = cute.make_tiled_copy_tv( + atom_universal_copy, tOdO_layout, vOdO_layout + ) async_copy_elems_accum = universal_copy_bits // cutlass.Float32.width atom_universal_copy_accum = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=universal_copy_bits, + cute.nvgpu.CopyUniversalOp(), + cutlass.Float32, + num_bits_per_copy=universal_copy_bits, ) - assert (self.m_block_size * self.head_dim_padded // async_copy_elems_accum) % self.num_threads == 0 + assert ( + self.m_block_size * self.head_dim_padded // async_copy_elems_accum + ) % self.num_threads == 0 self.gmem_tiled_copy_dQaccum = cute.make_tiled_copy_tv( atom_universal_copy_accum, cute.make_layout(self.num_threads), - cute.make_layout(async_copy_elems_accum) + cute.make_layout(async_copy_elems_accum), ) @cute.jit @@ -202,7 +221,9 @@ def kernel( seqlen_q_rounded = cute.round_up(seqlen_q, self.m_block_size) if cutlass.const_expr(mLSE is not None): - gLSE = cute.local_tile(mLSE[batch_size, num_head, None], (self.m_block_size,), (m_block,)) + gLSE = cute.local_tile( + mLSE[batch_size, num_head, None], (self.m_block_size,), (m_block,) + ) lse = cutlass.Float32.inf if tidx < seqlen_q - m_block * self.m_block_size: lse = gLSE[tidx] @@ -229,15 +250,17 @@ def kernel( pred=tOpdO[None, m, None] if self.check_hdim_oob else None, ) # Sum across the "k" dimension - dpsum = ( - tOrO.load().to(cutlass.Float32) * tOrdO.load().to(cutlass.Float32) - ).reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=(0, None, 1)) + dpsum = (tOrO.load().to(cutlass.Float32) * tOrdO.load().to(cutlass.Float32)).reduce( + cute.ReductionOp.ADD, init_val=0.0, reduction_profile=(0, None, 1) + ) dpsum = utils.warp_reduce(dpsum, operator.add, width=self.gmem_threads_per_row) dP_sum = cute.make_fragment(cute.size(tOrO, mode=[1]), cutlass.Float32) dP_sum.store(dpsum) # Write dPsum from rmem -> gmem - gdPsum = cute.local_tile(mdPsum[batch_size, num_head, None], (self.m_block_size,), (m_block,)) + gdPsum = cute.local_tile( + mdPsum[batch_size, num_head, None], (self.m_block_size,), (m_block,) + ) # Only the thread corresponding to column 0 writes out the lse to gmem if tOcO[0, 0, 0][1] == 0: for m in cutlass.range_constexpr(cute.size(dP_sum)): @@ -247,7 +270,9 @@ def kernel( # Clear dQaccum if cutlass.const_expr(mdQaccum is not None): blkdQaccum_shape = (self.m_block_size * self.head_dim_padded,) - gdQaccum = cute.local_tile(mdQaccum[batch_size, num_head, None], blkdQaccum_shape, (m_block,)) + gdQaccum = cute.local_tile( + mdQaccum[batch_size, num_head, None], blkdQaccum_shape, (m_block,) + ) gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx) tQgQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum) zero = cute.make_fragment_like(tQgQaccum) @@ -255,7 +280,9 @@ def kernel( cute.copy(gmem_tiled_copy_dQaccum, zero, tQgQaccum) if cutlass.const_expr(mLSE is not None): - gLSElog2 = cute.local_tile(mLSElog2[batch_size, num_head, None], (self.m_block_size,), (m_block,)) + gLSElog2 = cute.local_tile( + mLSElog2[batch_size, num_head, None], (self.m_block_size,), (m_block,) + ) LOG2_E = math.log2(math.e) if tidx < seqlen_q_rounded - m_block * self.m_block_size: gLSElog2[tidx] = lse * LOG2_E if lse != -cutlass.Float32.inf else 0.0 From 62349eb3bef7ffc1c2651464ee7873af230bc1ff Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 2 Jul 2025 22:02:39 -0400 Subject: [PATCH 015/258] [Cute] Move tile scheduler to a separate file --- flash_attn/cute/flash_fwd.py | 3 +- flash_attn/cute/flash_fwd_sm100.py | 290 +++++++---------------------- flash_attn/cute/interface.py | 2 +- flash_attn/cute/pipeline.py | 3 - flash_attn/cute/tile_scheduler.py | 175 +++++++++++++++++ tests/cute/test_flash_attn.py | 10 +- 6 files changed, 254 insertions(+), 229 deletions(-) create mode 100644 flash_attn/cute/tile_scheduler.py diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 825965f9535..f2fa3e3c2f3 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -1283,9 +1283,8 @@ def kernel( else: sV = storage.sQ.get_tensor(sV_layout.outer, swizzle=sV_layout.inner, dtype=mV.element_type) if cutlass.const_expr(sP_layout is not None): - # sP_pi = storage.sP.get_tensor(sP_layout) + sP_pi = storage.sP.get_tensor(sP_layout) sP = storage.sP.get_tensor(sP_layout.outer, swizzle=sP_layout.inner) - sP_pi = cute.make_tensor(sP.iterator, sP_layout) else: sP, sP_pi = None # Transpose view of V to tensor with layout (head_dim_v, n_block_size) for tiled mma diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index c0cccf6c1c1..f2b8235580f 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -33,6 +33,7 @@ from flash_attn.cute.block_info import BlockInfo from flash_attn.cute import mma_sm100_desc as sm100_desc from flash_attn.cute import blackwell_helpers as sm100_utils +from flash_attn.cute.tile_scheduler import TileSchedulerParams, SingleTileScheduler, StaticPersistentTileScheduler # class NamedBarrierFwd(enum.IntEnum): @@ -43,143 +44,19 @@ # PFull = enum.auto() # PEmpty = enum.auto() -class FmhaStaticTileSchedulerParams: - def __init__( - self, - is_persistent: bool, - problem_shape_mbh: cute.Shape, - *, - loc=None, - ip=None, - ): - self.is_persistent = is_persistent - self.problem_shape_mbh = problem_shape_mbh - self._loc = loc - self._ip = ip - - def __extract_mlir_values__(self): - values, self._values_pos = [], [] - for obj in [self.is_persistent, self.problem_shape_mbh]: - obj_values = cutlass.extract_mlir_values(obj) - values += obj_values - self._values_pos.append(len(obj_values)) - return values - def __new_from_mlir_values__(self, values): - obj_list = [] - for obj, n_items in zip( - [self.is_persistent, self.problem_shape_mbh], self._values_pos - ): - obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) - values = values[n_items:] - return FmhaStaticTileSchedulerParams(*(tuple(obj_list)), loc=self._loc) +def get_tile_scheduler_cls(params: TileSchedulerParams) -> Callable: + """Returns the appropriate tile scheduler class based on the parameters.""" + if cutlass.const_expr(params.is_persistent): + return StaticPersistentTileScheduler + else: + return SingleTileScheduler -def create_fmha_static_tile_scheduler_params( - is_persistent: bool, - problem_shape_mbh: cute.Shape, -) -> FmhaStaticTileSchedulerParams: - return FmhaStaticTileSchedulerParams(is_persistent, problem_shape_mbh) - - -class FmhaStaticTileScheduler: - - def __init__( - self, - params: FmhaStaticTileSchedulerParams, - current_work_linear_idx: cutlass.Int32, - blk_coord: cute.Coord, - grid_shape: cute.Shape, - *, - loc=None, - ip=None, - ): - self._params = params - self._blk_coord = blk_coord - self._grid_shape = grid_shape - self._is_persistent = params.is_persistent - self._current_work_linear_idx = current_work_linear_idx - self._problem_shape_mbh = cute.make_layout( - params.problem_shape_mbh, loc=loc, ip=ip - ) - self._num_blocks = cute.size(self._problem_shape_mbh, loc=loc, ip=ip) - self._is_first_block = True - self.num_persistent_sm = cute.size(grid_shape, loc=loc, ip=ip) - self._loc = loc - self._ip = ip - - # called by host - @staticmethod - def get_grid_shape( - params: FmhaStaticTileSchedulerParams, - *, - loc=None, - ip=None, - ) -> cute.Shape: - if params.is_persistent: - hardware_info = cutlass.utils.HardwareInfo() - sm_count = hardware_info.get_device_multiprocessor_count() - return ( - cutlass.min( - sm_count, cute.size(params.problem_shape_mbh, loc=loc, ip=ip) - ), - 1, - 1, - ) - else: - return params.problem_shape_mbh - - def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: - is_valid = ( - self._current_work_linear_idx < self._num_blocks - if self._is_persistent - else self._is_first_block - ) - - blk_coord = (0, 0, 0) - if self._is_persistent: - blk_coord = self._problem_shape_mbh.get_hier_coord( - self._current_work_linear_idx, loc=loc, ip=ip - ) - else: - blk_coord = self._blk_coord - - return cutlass.utils.WorkTileInfo(blk_coord, is_valid) - - def initial_work_tile_info(self, *, loc=None, ip=None): - return self.get_current_work(loc=loc, ip=ip) - - def advance_to_next_work(self, *, advance_count=1, loc=None, ip=None): - if self._is_persistent: - self._current_work_linear_idx += advance_count * self.num_persistent_sm - self._is_first_block = False - - def __extract_mlir_values__(self): - values = cutlass.extract_mlir_values(self._params) - values.extend(cutlass.extract_mlir_values(self._current_work_linear_idx)) - values.extend(cutlass.extract_mlir_values(self._blk_coord)) - values.extend(cutlass.extract_mlir_values(self._grid_shape)) - return values - - def __new_from_mlir_values__(self, values): - assert len(values) == 10 - new_params = cutlass.new_from_mlir_values(self._params, values[0:3]) - new_current_work_linear_idx = cutlass.new_from_mlir_values( - self._current_work_linear_idx, [values[3]] - ) - new_blk_coord = cutlass.new_from_mlir_values(self._blk_coord, values[4:7]) - new_grid_shape = cutlass.new_from_mlir_values(self._grid_shape, values[7:]) - return FmhaStaticTileScheduler( - new_params, new_current_work_linear_idx, new_blk_coord, new_grid_shape - ) - - -def create_fmha_static_tile_scheduler( - params: FmhaStaticTileSchedulerParams, - blk_coord: cute.Coord, - grid_shape: cute.Shape, -) -> FmhaStaticTileScheduler: - return FmhaStaticTileScheduler(params, blk_coord[0], blk_coord, grid_shape) +def create_tile_scheduler( + params: TileSchedulerParams, +) -> SingleTileScheduler | StaticPersistentTileScheduler: + return get_tile_scheduler_cls(params).create(params) class FlashAttentionForwardSm100: @@ -223,7 +100,6 @@ def __init__( self.is_local = is_local self.qhead_per_kvhead = qhead_per_kvhead self.pack_gqa = False - self.use_tma_O = True self.s0_s1_barrier = self.head_dim_padded in [64, 96] # Does S1 need to wait for S0 to finish self.softmax0_warp_ids = (0, 1, 2, 3) @@ -232,7 +108,7 @@ def __init__( self.mma_warp_id = 12 self.load_warp_id = 13 self.epilogue_warp_ids = (14,) - self.empty_warp_id = 15 + self.empty_warp_ids = (15,) SM100_TMEM_CAPACITY_COLUMNS = 512 self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS @@ -244,7 +120,7 @@ def __init__( self.mma_warp_id, self.load_warp_id, *self.epilogue_warp_ids, - self.empty_warp_id, + *self.empty_warp_ids, ) ) @@ -366,7 +242,7 @@ def __call__( if cutlass.const_expr(self.q_dtype != self.v_dtype): raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}") self._setup_attributes() - self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None and not self.pack_gqa + self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None and not self.pack_gqa and False cta_group = tcgen05.CtaGroup.ONE # the intermediate tensor p is from tmem & mK-major @@ -398,19 +274,19 @@ def __call__( self.epi_tile = self.pv_mma_tiler[:2] - sQ_layout_staged = sm100_utils_basic.make_smem_layout_a( + sQ_layout = sm100_utils_basic.make_smem_layout_a( tiled_mma_qk, self.mma_tiler_qk, self.q_dtype, self.q_stage, ) - sK_layout_staged = sm100_utils_basic.make_smem_layout_b( + sK_layout = sm100_utils_basic.make_smem_layout_b( tiled_mma_qk, self.mma_tiler_qk, self.k_dtype, self.kv_stage, ) - tP_layout_staged = sm100_utils_basic.make_smem_layout_a( + tP_layout = sm100_utils_basic.make_smem_layout_a( tiled_mma_pv, self.pv_mma_tiler, self.q_dtype, self.acc_stage, ) - sV_layout_staged = sm100_utils_basic.make_smem_layout_b( + sV_layout = sm100_utils_basic.make_smem_layout_b( tiled_mma_pv, self.pv_mma_tiler, self.v_dtype, self.kv_stage, ) - sO_layout_staged = sm100_utils_basic.make_smem_layout_epi( + sO_layout = sm100_utils_basic.make_smem_layout_epi( self.o_dtype, self.o_layout, self.epi_tile, self.epi_stage, ) @@ -418,50 +294,46 @@ def __call__( tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) tma_store_op = cpasync.CopyBulkTensorTileS2GOp() - sQ_layout = cute.select(sQ_layout_staged, mode=[0, 1, 2]) tma_atom_Q, tma_tensor_Q = cute.nvgpu.make_tma_tile_atom_A( tma_load_op, mQ, - sQ_layout, + cute.select(sQ_layout, mode=[0, 1, 2]), self.mma_tiler_qk, tiled_mma_qk, self.cluster_layout_vmnk.shape, ) # TMA load for K - sK_layout = cute.select(sK_layout_staged, mode=[0, 1, 2]) tma_atom_K, tma_tensor_K = cute.nvgpu.make_tma_tile_atom_B( tma_load_op, mK, - sK_layout, + cute.select(sK_layout, mode=[0, 1, 2]), self.mma_tiler_qk, tiled_mma_qk, self.cluster_layout_vmnk.shape, ) # TMA load for V - sV_layout = cute.select(sV_layout_staged, mode=[0, 1, 2]) tma_atom_V, tma_tensor_V = cute.nvgpu.make_tma_tile_atom_B( tma_load_op, mV, - sV_layout, + cute.select(sV_layout, mode=[0, 1, 2]), self.pv_mma_tiler, tiled_mma_pv, self.cluster_layout_vmnk.shape, ) - o_cta_v_layout = cute.composition( - cute.make_identity_layout(mO.shape), self.epi_tile - ) - sO_layout = cute.select(sO_layout_staged, mode=[0, 1]) + o_cta_v_layout = cute.composition(cute.make_identity_layout(mO.shape), self.epi_tile) # print(sO_layout.outer) - self.epilogue_warp_ids = (14,) if self.use_tma_O else (14, 15) + if not self.use_tma_O: + self.epilogue_warp_ids = (14, 15) + self.empty_warp_ids = () self.num_epilogue_threads = cute.arch.WARP_SIZE * len(self.epilogue_warp_ids) if cutlass.const_expr(self.use_tma_O): tma_atom_O, mO = cpasync.make_tma_tile_atom( tma_store_op, mO, - sO_layout, + cute.select(sO_layout, mode=[0, 1]), o_cta_v_layout, ) gmem_tiled_copy_O = None @@ -481,8 +353,8 @@ def __call__( vO_layout = cute.make_layout((1, async_copy_elems)) gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout) - self.tma_copy_q_bytes = cute.size_in_bytes(self.q_dtype, sQ_layout) - self.tma_copy_kv_bytes = cute.size_in_bytes(self.k_dtype, sK_layout) + self.tma_copy_q_bytes = cute.size_in_bytes(self.q_dtype, cute.select(sQ_layout, mode=[0, 1, 2])) + self.tma_copy_kv_bytes = cute.size_in_bytes(self.k_dtype, cute.select(sK_layout, mode=[0, 1, 2])) self.tile_sched_params, grid = self._compute_grid(mO, self.cta_tiler, self.is_persistent) @@ -511,15 +383,15 @@ class SharedStorage: # Smem tensors sScale: cute.struct.MemRange[cutlass.Float32, 2 * self.m_block_size * (1 if mLSE is None else 2)] sO: cute.struct.Align[ - cute.struct.MemRange[self.o_dtype, cute.cosize(sO_layout_staged)], + cute.struct.MemRange[self.o_dtype, cute.cosize(sO_layout)], self.buffer_align_bytes, ] sQ: cute.struct.Align[ - cute.struct.MemRange[self.q_dtype, cute.cosize(sQ_layout_staged)], + cute.struct.MemRange[self.q_dtype, cute.cosize(sQ_layout)], self.buffer_align_bytes, ] sK: cute.struct.Align[ - cute.struct.MemRange[self.k_dtype, cute.cosize(sK_layout_staged)], + cute.struct.MemRange[self.k_dtype, cute.cosize(sK_layout)], self.buffer_align_bytes, ] @@ -560,11 +432,11 @@ class SharedStorage: softcap_val, window_size_left, window_size_right, - sQ_layout_staged, - sK_layout_staged, - tP_layout_staged, - sV_layout_staged, - sO_layout_staged, + sQ_layout, + sK_layout, + tP_layout, + sV_layout, + sO_layout, gmem_tiled_copy_O, tiled_mma_qk, tiled_mma_pv, @@ -599,15 +471,15 @@ def kernel( softcap_val: Optional[cutlass.Float32], window_size_left: Optional[cutlass.Int32], window_size_right: Optional[cutlass.Int32], - sQ_layout_staged: cute.ComposedLayout, - sK_layout_staged: cute.ComposedLayout, - tP_layout_staged: cute.ComposedLayout, - sV_layout_staged: cute.ComposedLayout, - sO_layout_staged: cute.ComposedLayout, + sQ_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + tP_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sO_layout: cute.ComposedLayout, gmem_tiled_copy_O: Optional[cute.TiledCopy], tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, - tile_sched_params: FmhaStaticTileSchedulerParams, + tile_sched_params: TileSchedulerParams, ): """The device kernel implementation of the Fused Multi-Head Attention. @@ -667,7 +539,7 @@ def kernel( cute.arch.WARP_SIZE * len( ( - self.empty_warp_id, + *self.empty_warp_ids, self.load_warp_id, self.mma_warp_id, *self.epilogue_warp_ids, @@ -692,15 +564,15 @@ def kernel( # Generate smem tensor Q/K/V/O # (MMA, MMA_Q, MMA_D, PIPE) - sQ = storage.sQ.get_tensor(sQ_layout_staged.outer, swizzle=sQ_layout_staged.inner) - # sQ_pi = storage.sQ.get_tensor(sQ_layout_staged) + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) + # sQ_pi = storage.sQ.get_tensor(sQ_layout) # (MMA, MMA_K, MMA_D, PIPE) - sK = storage.sK.get_tensor(sK_layout_staged.outer, swizzle=sK_layout_staged.inner) - # sK_pi = storage.sK.get_tensor(sK_layout_staged) + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + # sK_pi = storage.sK.get_tensor(sK_layout) # (MMA, MMA_K, MMA_D, PIPE) # Strip swizzle info to reuse smem - sV = cute.make_tensor(cute.recast_ptr(sK.iterator, sV_layout_staged.inner), sV_layout_staged.outer) - sO = storage.sO.get_tensor(sO_layout_staged.outer, swizzle=sO_layout_staged.inner) + sV = cute.make_tensor(cute.recast_ptr(sK.iterator, sV_layout.inner), sV_layout.outer) + sO = storage.sO.get_tensor(sO_layout.outer, swizzle=sO_layout.inner) sScale = storage.sScale.get_tensor(cute.make_layout(256)) @@ -723,7 +595,7 @@ def kernel( tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) tOtO1 = cute.make_tensor(tOtO.iterator + self.tmem_o1_offset, tOtO.layout) - tP = cute.make_tensor(tStS.iterator, tP_layout_staged.outer) + tP = cute.make_tensor(tStS.iterator, tP_layout.outer) tOrP = thr_mma_pv.make_fragment_A(tP)[None, None, None, 0] tOrP0 = cute.make_tensor( @@ -762,9 +634,7 @@ def kernel( # LOAD # /////////////////////////////////////////////////////////////////////////////// if warp_idx == self.load_warp_id: - tile_scheduler = create_fmha_static_tile_scheduler( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) + tile_scheduler = create_tile_scheduler(tile_sched_params) self.load( tile_scheduler, thr_mma_qk, @@ -787,18 +657,14 @@ def kernel( # MMA # /////////////////////////////////////////////////////////////////////////////// if warp_idx == self.mma_warp_id: - # if warp_idx == self.mma_warp_id or warp_idx == self.empty_warp_id: + # if warp_idx == self.mma_warp_id or warp_idx == self.empty_warp_ids: # Alloc tmem buffer tmem_alloc_cols = cutlass.Int32(self.tmem_alloc_cols) if warp_idx == self.mma_warp_id: cute.arch.alloc_tmem(tmem_alloc_cols, storage.tmem_holding_buf) cute.arch.sync_warp() - # tile_scheduler = create_fmha_static_tile_scheduler( - # tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - # ) self.mma( - # tile_scheduler, tiled_mma_qk, tiled_mma_pv, sQ, @@ -806,9 +672,9 @@ def kernel( sV, # sQ_pi.iterator, # sK_pi.iterator, - sQ_layout_staged.inner, - sK_layout_staged.inner, - sV_layout_staged.inner, + sQ_layout.inner, + sK_layout.inner, + sV_layout.inner, tStS0, tStS1, tOtO0, @@ -838,9 +704,7 @@ def kernel( # Epilogue # /////////////////////////////////////////////////////////////////////////////// if warp_idx >= self.epilogue_warp_ids[0] and warp_idx <= self.epilogue_warp_ids[-1]: - tile_scheduler = create_fmha_static_tile_scheduler( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) + tile_scheduler = create_tile_scheduler(tile_sched_params) self.epilogue_s2g(tile_scheduler, mO, sO, gmem_tiled_copy_O, tma_atom_O, mbar_ptr, SeqlenInfoCls) # /////////////////////////////////////////////////////////////////////////////// @@ -851,9 +715,7 @@ def kernel( cute.arch.mbarrier_wait(mbar_ptr + self.mbar_max_reg_setting_offset, 0) cute.arch.warpgroup_reg_alloc(self.num_regs_softmax) - tile_scheduler = create_fmha_static_tile_scheduler( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) + tile_scheduler = create_tile_scheduler(tile_sched_params) softmax_loop = partial( self.softmax_loop, softmax_scale_log2=softmax_scale_log2, @@ -1000,6 +862,7 @@ def load_Q(stage: int): kv_producer_state.advance() load_V(n_block, kv_producer_state) # Vi kv_producer_state.advance() + tile_scheduler.prefetch_next_work() tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() # End of persistent scheduler loop @@ -1025,7 +888,6 @@ def mma( tOrP1: cute.Tensor, pipeline_kv: cutlass.utils.PipelineAsync, mbar_ptr: cute.Pointer, - # tile_scheduler, tile_sched_params, block_info: BlockInfo, SeqlenInfoCls: Callable, @@ -1071,9 +933,7 @@ def mma( ) P_full_O_rescaled_phase = cutlass.Int32(0) - tile_scheduler = create_fmha_static_tile_scheduler( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) + tile_scheduler = create_tile_scheduler(tile_sched_params) work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx = work_tile.tile_idx @@ -1480,7 +1340,6 @@ def correction_loop( tma_atom_O: cute.CopyAtom, mbar_ptr: cute.Pointer, softmax_scale_log2: cutlass.Float32, - # tile_scheduler, tile_sched_params, block_info: BlockInfo, SeqlenInfoCls: Callable, @@ -1513,9 +1372,7 @@ def correction_loop( o_corr_consumer_phase = cutlass.Int32(0) corr_epi_producer_phase = cutlass.Int32(1) - tile_scheduler = create_fmha_static_tile_scheduler( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) + tile_scheduler = create_tile_scheduler(tile_sched_params) work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx = work_tile.tile_idx @@ -1817,10 +1674,9 @@ def epilogue_s2g( cute.arch.cp_async_bulk_wait_group(1 - stage, read=True) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) else: - tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.epi_warp_ids)) + tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.epilogue_warp_ids)) gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) tOsO = gmem_thr_copy_O.partition_S(sO) - tOrO = cute.make_fragment_like(tOsO, self.dtype) cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded)) tOgO = gmem_thr_copy_O.partition_D(gO) tOcO = gmem_thr_copy_O.partition_S(cO) @@ -1832,15 +1688,15 @@ def epilogue_s2g( cute.arch.mbarrier_wait(mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase) # 2. copy O0 / O1 to gmem # load acc O from smem to rmem for wider vectorization - # TODO: need stage - cute.autovec_copy(tOsO, tOrO) + tOrO = cute.make_fragment_like(tOsO[None, None, None, 0], self.o_dtype) + cute.autovec_copy(tOsO[None, None, None, stage], tOrO) # copy acc O from rmem to gmem for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): - if t0OcO[0, rest_m, 0][0] < seqlen.seqlen_q - m_block * self.m_block_size - tOcO[0][0]: + if t0OcO[0, rest_m, 0][0] < seqlen.seqlen_q - (m_block * 2 + stage) * self.m_block_size - tOcO[0][0]: cute.copy( gmem_tiled_copy_O, tOrO[None, rest_m, None], - tOgO[None, rest_m, None], + tOgO[None, rest_m, None, 2 * m_block + stage], pred=tOpO[None, rest_m, None] if self.check_hdim_v_oob else None, ) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) @@ -1906,15 +1762,13 @@ def _compute_grid( mO: cute.Tensor, cta_tiler: Tuple[int, int, int], is_persistent: bool, - ) -> Tuple[FmhaStaticTileSchedulerParams, Tuple[int, int, int]]: + ) -> Tuple[TileSchedulerParams, Tuple[int, int, int]]: o_shape = mO.shape - tile_sched_params = create_fmha_static_tile_scheduler_params( + tile_sched_params = TileSchedulerParams( + cute.ceil_div(cute.size(o_shape[0]), cta_tiler[0]), + cute.size(o_shape[2]), + cute.size(o_shape[3]), is_persistent, - ( - cute.ceil_div(cute.size(o_shape[0]), cta_tiler[0]), - cute.size(o_shape[2]), - cute.size(o_shape[3]), - ), ) - grid = FmhaStaticTileScheduler.get_grid_shape(tile_sched_params) + grid = get_tile_scheduler_cls(tile_sched_params).get_grid_shape(tile_sched_params) return tile_sched_params, grid diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index bbab8301522..cf49d0ef248 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -166,7 +166,7 @@ def _flash_attn_fwd( is_causal=causal, is_local=local, qhead_per_kvhead=qhead_per_kvhead, - is_persistent=True, + is_persistent=not causal and not local and cu_seqlens_q is None and seqused_q is None, ) # TODO: check @can_implement _flash_attn_fwd.compile_cache[compile_key] = cute.compile( diff --git a/flash_attn/cute/pipeline.py b/flash_attn/cute/pipeline.py index 775e1754b3d..6efc1a96747 100644 --- a/flash_attn/cute/pipeline.py +++ b/flash_attn/cute/pipeline.py @@ -67,9 +67,6 @@ def advance(self): # [Int32], # ) - def __get_mlir_types__(self): - return [self._phase_index.type] - def __extract_mlir_values__(self): phase_index = self._phase_index return [phase_index.ir_value()] diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py new file mode 100644 index 00000000000..f6d7029bb82 --- /dev/null +++ b/flash_attn/cute/tile_scheduler.py @@ -0,0 +1,175 @@ +# Copyright (c) 2025, Tri Dao. + +from typing import Optional, Tuple + +import cutlass +import cutlass.cute as cute +from cutlass import Int32 + + +class TileSchedulerParams: + def __init__( + self, + # block_size: cutlass.Constexpr[int], + num_blocks: Int32, + num_head: Int32, + num_batch: Int32, + is_persistent: cutlass.Constexpr[bool] = False, + # qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1, # Only pass in if using packed GQA + *, + loc=None, + ip=None, + ): + # self.block_size = block_size + self.num_blocks = num_blocks + self.num_head = num_head + self.num_batch = num_batch + self.is_persistent = is_persistent + # self.qhead_per_kvhead_packgqa = qhead_per_kvhead_packgqa + self._loc = loc + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.num_blocks, self.num_head, self.num_batch]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip([self.num_blocks, self.num_head, self.num_batch], self._values_pos): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return TileSchedulerParams( + # self.block_size, *(tuple(obj_list)), self.qhead_per_kvhead_packgqa, loc=self._loc + *(tuple(obj_list)), + self.is_persistent, + loc=self._loc, + ) + + +class SingleTileScheduler: + def __init__(self, blk_coord: cute.Coord, *, loc=None, ip=None): + self._blk_coord = blk_coord + self._is_first_block = True + self._loc = loc + self._ip = ip + + @staticmethod + def create(params: TileSchedulerParams, *, loc=None, ip=None) -> "SingleTileScheduler": + blk_coord = cute.arch.block_idx() + return SingleTileScheduler(blk_coord, loc=loc, ip=ip) + + # called by host + @staticmethod + def get_grid_shape( + params: TileSchedulerParams, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + return params.num_blocks, params.num_head, params.num_batch + + def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + return cutlass.utils.WorkTileInfo(self._blk_coord, self._is_first_block) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + self._is_first_block = False + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self._blk_coord]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip([self._blk_coord], self._values_pos): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return SingleTileScheduler(*(tuple(obj_list)), loc=self._loc) + + +class StaticPersistentTileScheduler: + def __init__( + self, + num_blocks: Int32, + num_head: Int32, + total_blocks: Int32, + tile_idx: Int32, + *, + loc=None, + ip=None, + ): + self.num_blocks = num_blocks + self.num_head = num_head + self.total_blocks = total_blocks + self._tile_idx = tile_idx + self._loc = loc + self._ip = ip + + @staticmethod + def create(params: TileSchedulerParams, *, loc=None, ip=None) -> "SingleTileScheduler": + tile_idx = cute.arch.block_idx()[0] + total_blocks = params.num_blocks * params.num_head * params.num_batch + return StaticPersistentTileScheduler( + params.num_blocks, params.num_head, total_blocks, tile_idx, loc=loc, ip=ip + ) + + # called by host + @staticmethod + def get_grid_shape( + params: TileSchedulerParams, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + hardware_info = cutlass.utils.HardwareInfo() + sm_count = hardware_info.get_device_multiprocessor_count() + total_blocks = params.num_blocks * params.num_head * params.num_batch + return (cutlass.min(sm_count, total_blocks), Int32(1), Int32(1)) + + def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + hn_idx = self._tile_idx // self.num_blocks + block_idx = self._tile_idx - hn_idx * self.num_blocks + batch_idx = hn_idx // self.num_head + head_idx = hn_idx - batch_idx * self.num_head + is_valid = self._tile_idx < self.total_blocks + return cutlass.utils.WorkTileInfo( + (Int32(block_idx), Int32(head_idx), Int32(batch_idx)), is_valid + ) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + self._tile_idx += cute.arch.grid_dim()[0] + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.num_blocks, self.num_head, self.total_blocks, self._tile_idx]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [self.num_blocks, self.num_head, self.total_blocks, self._tile_idx], self._values_pos + ): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return StaticPersistentTileScheduler(*(tuple(obj_list)), loc=self._loc) diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index f19080fc001..268744f67fd 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -28,7 +28,7 @@ # @pytest.mark.parametrize("softcap", [0.0, 15.0]) @pytest.mark.parametrize("softcap", [0.0]) @pytest.mark.parametrize("local", [False, True]) -# @pytest.mark.parametrize("local", [True]) +# @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [False]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) @@ -238,10 +238,10 @@ def test_flash_attn_output( @pytest.mark.parametrize("deterministic", [False]) # @pytest.mark.parametrize("softcap", [0.0, 15.0]) @pytest.mark.parametrize("softcap", [0.0]) -# @pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) +# @pytest.mark.parametrize("local", [False, True]) @pytest.mark.parametrize("local", [False]) # @pytest.mark.parametrize("causal", [False, True]) -@pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("causal", [True]) # @pytest.mark.parametrize("add_unused_qkv", [False, True]) @pytest.mark.parametrize("add_unused_qkv", [False]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) @@ -306,7 +306,7 @@ def test_flash_attn_varlen_output( else: qv_ref = None # Put window_size after QKV randn so that window_size changes from test to test - window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + window_size = (None, None) if not local else torch.randint(0, seqlen_k, (2,)) if dtype == torch.float8_e4m3fn: q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] else: @@ -423,7 +423,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # qv=qv_unpad, # q_descale=q_descale, # k_descale=k_descale, v_descale=v_descale, - # window_size=window_size, + window_size=window_size, # attention_chunk=attention_chunk, softcap=softcap, ) From 8d454a3a9336954dae75013958dc3903ce781b66 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 2 Jul 2025 23:26:14 -0400 Subject: [PATCH 016/258] [Cute] Add FastDivmod --- flash_attn/cute/fast_math.py | 97 ++++++++++++++++++++++++++++++ flash_attn/cute/flash_fwd_sm100.py | 5 +- flash_attn/cute/tile_scheduler.py | 84 ++++++++++++++++++++------ 3 files changed, 165 insertions(+), 21 deletions(-) create mode 100644 flash_attn/cute/fast_math.py diff --git a/flash_attn/cute/fast_math.py b/flash_attn/cute/fast_math.py new file mode 100644 index 00000000000..b21573aa50d --- /dev/null +++ b/flash_attn/cute/fast_math.py @@ -0,0 +1,97 @@ +# Copyright (c) 2025, Tri Dao. + +from typing import Tuple + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, Uint32 +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm + + +@cute.jit +def clz(x: Int32) -> Int32: + # for i in cutlass.range_dynamic(32): + # if (1 << (31 - i)) & x: + # return Int32(i) + # return Int32(32) + # Early exit is not supported yet + res = Int32(32) + done = False + for i in cutlass.range_dynamic(32): + if ((1 << (31 - i)) & x) and not done: + res = Int32(i) + done = True + return res + + +def find_log2(x: Int32) -> Int32: + a: Int32 = Int32(31 - clz(x)) + return a + ((x & (x - 1)) != 0) # Round up, add 1 if not a power of 2. + + +@dsl_user_op +def umulhi(a: Int32, b: Int32, *, loc=None, ip=None) -> Uint32: + return Uint32( + llvm.inline_asm( + T.i32(), + [Int32(a).ir_value(loc=loc, ip=ip), Int32(b).ir_value(loc=loc, ip=ip)], + "mul.hi.u32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +class FastDivmod: + def __init__( + self, divisor: Int32, multipler: Uint32, shift_right: Uint32, *, loc=None, ip=None + ): + self.divisor = divisor + self.multiplier = multipler + self.shift_right = shift_right + self._loc = loc + + # called by host + @staticmethod + def create(divisor: Int32, *, loc=None, ip=None) -> "FastDivmod": + """Construct the FastDivmod object, in host code. + This precomputes some values based on the divisor and is computationally expensive. + """ + p = Uint32(31 + find_log2(divisor)) + divisor_u32 = Uint32(divisor) + multiplier = Uint32(((cutlass.Uint64(1) << p) + divisor_u32 - 1) // divisor_u32) + shift_right = Uint32(p - 32) + return FastDivmod(divisor, multiplier, shift_right, loc=loc, ip=ip) + + @cute.jit + def div(self, dividend: Int32) -> Int32: + return ( + Int32(umulhi(dividend, self.multiplier) >> self.shift_right) + if self.divisor != 1 + else dividend + ) + + def divmod(self, dividend: Int32) -> Tuple[Int32, Int32]: + quotient = self.div(dividend) + remainder = dividend - quotient * self.divisor + return quotient, remainder + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.divisor, self.multiplier, self.shift_right]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [self.divisor, self.multiplier, self.shift_right], self._values_pos + ): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return FastDivmod(*(tuple(obj_list)), loc=self._loc) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index f2b8235580f..6491c480a8e 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -33,6 +33,7 @@ from flash_attn.cute.block_info import BlockInfo from flash_attn.cute import mma_sm100_desc as sm100_desc from flash_attn.cute import blackwell_helpers as sm100_utils +from flash_attn.cute.fast_math import FastDivmod from flash_attn.cute.tile_scheduler import TileSchedulerParams, SingleTileScheduler, StaticPersistentTileScheduler @@ -242,7 +243,7 @@ def __call__( if cutlass.const_expr(self.q_dtype != self.v_dtype): raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}") self._setup_attributes() - self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None and not self.pack_gqa and False + self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None and not self.pack_gqa cta_group = tcgen05.CtaGroup.ONE # the intermediate tensor p is from tmem & mK-major @@ -1764,7 +1765,7 @@ def _compute_grid( is_persistent: bool, ) -> Tuple[TileSchedulerParams, Tuple[int, int, int]]: o_shape = mO.shape - tile_sched_params = TileSchedulerParams( + tile_sched_params = TileSchedulerParams.create( cute.ceil_div(cute.size(o_shape[0]), cta_tiler[0]), cute.size(o_shape[2]), cute.size(o_shape[3]), diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index f6d7029bb82..6c3635b5dd5 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -6,31 +6,66 @@ import cutlass.cute as cute from cutlass import Int32 +from flash_attn.cute.fast_math import FastDivmod + class TileSchedulerParams: def __init__( self, # block_size: cutlass.Constexpr[int], - num_blocks: Int32, + num_block: Int32, num_head: Int32, num_batch: Int32, + num_block_divmod: FastDivmod, + num_head_divmod: FastDivmod, is_persistent: cutlass.Constexpr[bool] = False, - # qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1, # Only pass in if using packed GQA + # qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1, # Only pass in if using packed GPA *, loc=None, ip=None, ): # self.block_size = block_size - self.num_blocks = num_blocks + self.num_block = num_block self.num_head = num_head self.num_batch = num_batch + self.num_block_divmod = num_block_divmod + self.num_head_divmod = num_head_divmod self.is_persistent = is_persistent # self.qhead_per_kvhead_packgqa = qhead_per_kvhead_packgqa self._loc = loc + @staticmethod + def create( + num_block: Int32, + num_head: Int32, + num_batch: Int32, + is_persistent: cutlass.Constexpr[bool] = False, + *, + loc=None, + ip=None, + ) -> "TileSchedulerParams": + num_block_divmod = FastDivmod.create(num_block, loc=loc, ip=ip) + num_head_divmod = FastDivmod.create(num_head, loc=loc, ip=ip) + return TileSchedulerParams( + num_block, + num_head, + num_batch, + num_block_divmod, + num_head_divmod, + is_persistent, + loc=loc, + ip=ip, + ) + def __extract_mlir_values__(self): values, self._values_pos = [], [] - for obj in [self.num_blocks, self.num_head, self.num_batch]: + for obj in [ + self.num_block, + self.num_head, + self.num_batch, + self.num_block_divmod, + self.num_head_divmod, + ]: obj_values = cutlass.extract_mlir_values(obj) values += obj_values self._values_pos.append(len(obj_values)) @@ -38,7 +73,16 @@ def __extract_mlir_values__(self): def __new_from_mlir_values__(self, values): obj_list = [] - for obj, n_items in zip([self.num_blocks, self.num_head, self.num_batch], self._values_pos): + for obj, n_items in zip( + [ + self.num_block, + self.num_head, + self.num_batch, + self.num_block_divmod, + self.num_head_divmod, + ], + self._values_pos, + ): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] return TileSchedulerParams( @@ -69,7 +113,7 @@ def get_grid_shape( loc=None, ip=None, ) -> Tuple[Int32, Int32, Int32]: - return params.num_blocks, params.num_head, params.num_batch + return params.num_block, params.num_head, params.num_batch def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: return cutlass.utils.WorkTileInfo(self._blk_coord, self._is_first_block) @@ -102,16 +146,16 @@ def __new_from_mlir_values__(self, values): class StaticPersistentTileScheduler: def __init__( self, - num_blocks: Int32, - num_head: Int32, + num_block_divmod: FastDivmod, + num_head_divmod: FastDivmod, total_blocks: Int32, tile_idx: Int32, *, loc=None, ip=None, ): - self.num_blocks = num_blocks - self.num_head = num_head + self.num_block_divmod = num_block_divmod + self.num_head_divmod = num_head_divmod self.total_blocks = total_blocks self._tile_idx = tile_idx self._loc = loc @@ -120,9 +164,9 @@ def __init__( @staticmethod def create(params: TileSchedulerParams, *, loc=None, ip=None) -> "SingleTileScheduler": tile_idx = cute.arch.block_idx()[0] - total_blocks = params.num_blocks * params.num_head * params.num_batch + total_blocks = params.num_block * params.num_head * params.num_batch return StaticPersistentTileScheduler( - params.num_blocks, params.num_head, total_blocks, tile_idx, loc=loc, ip=ip + params.num_block_divmod, params.num_head_divmod, total_blocks, tile_idx, loc=loc, ip=ip ) # called by host @@ -135,15 +179,16 @@ def get_grid_shape( ) -> Tuple[Int32, Int32, Int32]: hardware_info = cutlass.utils.HardwareInfo() sm_count = hardware_info.get_device_multiprocessor_count() - total_blocks = params.num_blocks * params.num_head * params.num_batch + total_blocks = params.num_block * params.num_head * params.num_batch return (cutlass.min(sm_count, total_blocks), Int32(1), Int32(1)) + # @cute.jit def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: - hn_idx = self._tile_idx // self.num_blocks - block_idx = self._tile_idx - hn_idx * self.num_blocks - batch_idx = hn_idx // self.num_head - head_idx = hn_idx - batch_idx * self.num_head + hn_idx, block_idx = self.num_block_divmod.divmod(self._tile_idx) + batch_idx, head_idx = self.num_head_divmod.divmod(hn_idx) is_valid = self._tile_idx < self.total_blocks + # if cute.arch.thread_idx()[0] == 0: + # cute.printf("TileScheduler: tile_idx=%d, hn_idx=%d, block_idx=%d, batch_idx=%d, head_idx=%d, is_valid=%d", self._tile_idx, hn_idx, block_idx, batch_idx, head_idx, is_valid) return cutlass.utils.WorkTileInfo( (Int32(block_idx), Int32(head_idx), Int32(batch_idx)), is_valid ) @@ -159,7 +204,7 @@ def advance_to_next_work(self, *, loc=None, ip=None): def __extract_mlir_values__(self): values, self._values_pos = [], [] - for obj in [self.num_blocks, self.num_head, self.total_blocks, self._tile_idx]: + for obj in [self.num_block_divmod, self.num_head_divmod, self.total_blocks, self._tile_idx]: obj_values = cutlass.extract_mlir_values(obj) values += obj_values self._values_pos.append(len(obj_values)) @@ -168,7 +213,8 @@ def __extract_mlir_values__(self): def __new_from_mlir_values__(self, values): obj_list = [] for obj, n_items in zip( - [self.num_blocks, self.num_head, self.total_blocks, self._tile_idx], self._values_pos + [self.num_block_divmod, self.num_head_divmod, self.total_blocks, self._tile_idx], + self._values_pos, ): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] From e94e0c25f2426e1b0aa25bed3e112f7c6e49c47d Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Thu, 3 Jul 2025 00:31:11 -0400 Subject: [PATCH 017/258] [Cute] Refactor TileScheduler classes --- flash_attn/cute/flash_fwd_sm100.py | 37 ++++---- flash_attn/cute/tile_scheduler.py | 141 ++++++++++++++--------------- 2 files changed, 83 insertions(+), 95 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 6491c480a8e..8797e61ab5a 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -34,7 +34,7 @@ from flash_attn.cute import mma_sm100_desc as sm100_desc from flash_attn.cute import blackwell_helpers as sm100_utils from flash_attn.cute.fast_math import FastDivmod -from flash_attn.cute.tile_scheduler import TileSchedulerParams, SingleTileScheduler, StaticPersistentTileScheduler +from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, StaticPersistentTileScheduler, ParamsBase # class NamedBarrierFwd(enum.IntEnum): @@ -46,20 +46,14 @@ # PEmpty = enum.auto() -def get_tile_scheduler_cls(params: TileSchedulerParams) -> Callable: +def get_tile_scheduler_cls(args: TileSchedulerArguments) -> Callable: """Returns the appropriate tile scheduler class based on the parameters.""" - if cutlass.const_expr(params.is_persistent): + if cutlass.const_expr(args.is_persistent): return StaticPersistentTileScheduler else: return SingleTileScheduler -def create_tile_scheduler( - params: TileSchedulerParams, -) -> SingleTileScheduler | StaticPersistentTileScheduler: - return get_tile_scheduler_cls(params).create(params) - - class FlashAttentionForwardSm100: arch = 100 @@ -357,7 +351,7 @@ def __call__( self.tma_copy_q_bytes = cute.size_in_bytes(self.q_dtype, cute.select(sQ_layout, mode=[0, 1, 2])) self.tma_copy_kv_bytes = cute.size_in_bytes(self.k_dtype, cute.select(sK_layout, mode=[0, 1, 2])) - self.tile_sched_params, grid = self._compute_grid(mO, self.cta_tiler, self.is_persistent) + self.tile_scheduler_cls, self.tile_sched_params, grid = self._compute_grid(mO, self.cta_tiler, self.is_persistent) self.mbar_load_q_full_offset = 0 self.mbar_load_q_empty_offset = self.mbar_load_q_full_offset + self.q_stage @@ -480,7 +474,8 @@ def kernel( gmem_tiled_copy_O: Optional[cute.TiledCopy], tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, - tile_sched_params: TileSchedulerParams, + # tile_sched_params: TileSchedulerArguments, + tile_sched_params: ParamsBase, ): """The device kernel implementation of the Fused Multi-Head Attention. @@ -635,7 +630,7 @@ def kernel( # LOAD # /////////////////////////////////////////////////////////////////////////////// if warp_idx == self.load_warp_id: - tile_scheduler = create_tile_scheduler(tile_sched_params) + tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params) self.load( tile_scheduler, thr_mma_qk, @@ -705,7 +700,7 @@ def kernel( # Epilogue # /////////////////////////////////////////////////////////////////////////////// if warp_idx >= self.epilogue_warp_ids[0] and warp_idx <= self.epilogue_warp_ids[-1]: - tile_scheduler = create_tile_scheduler(tile_sched_params) + tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params) self.epilogue_s2g(tile_scheduler, mO, sO, gmem_tiled_copy_O, tma_atom_O, mbar_ptr, SeqlenInfoCls) # /////////////////////////////////////////////////////////////////////////////// @@ -716,7 +711,7 @@ def kernel( cute.arch.mbarrier_wait(mbar_ptr + self.mbar_max_reg_setting_offset, 0) cute.arch.warpgroup_reg_alloc(self.num_regs_softmax) - tile_scheduler = create_tile_scheduler(tile_sched_params) + tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params) softmax_loop = partial( self.softmax_loop, softmax_scale_log2=softmax_scale_log2, @@ -934,7 +929,7 @@ def mma( ) P_full_O_rescaled_phase = cutlass.Int32(0) - tile_scheduler = create_tile_scheduler(tile_sched_params) + tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params) work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx = work_tile.tile_idx @@ -1373,7 +1368,7 @@ def correction_loop( o_corr_consumer_phase = cutlass.Int32(0) corr_epi_producer_phase = cutlass.Int32(1) - tile_scheduler = create_tile_scheduler(tile_sched_params) + tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params) work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx = work_tile.tile_idx @@ -1763,13 +1758,15 @@ def _compute_grid( mO: cute.Tensor, cta_tiler: Tuple[int, int, int], is_persistent: bool, - ) -> Tuple[TileSchedulerParams, Tuple[int, int, int]]: + ) -> Tuple[TileSchedulerArguments, Tuple[int, int, int]]: o_shape = mO.shape - tile_sched_params = TileSchedulerParams.create( + tile_sched_args = TileSchedulerArguments( cute.ceil_div(cute.size(o_shape[0]), cta_tiler[0]), cute.size(o_shape[2]), cute.size(o_shape[3]), is_persistent, ) - grid = get_tile_scheduler_cls(tile_sched_params).get_grid_shape(tile_sched_params) - return tile_sched_params, grid + tile_scheduler_cls = get_tile_scheduler_cls(tile_sched_args) + tile_sched_params = tile_scheduler_cls.to_underlying_arguments(tile_sched_args) + grid = tile_scheduler_cls.get_grid_shape(tile_sched_params) + return tile_scheduler_cls, tile_sched_params, grid diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index 6c3635b5dd5..38d943b13e7 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -1,6 +1,7 @@ # Copyright (c) 2025, Tri Dao. from typing import Optional, Tuple +from dataclasses import dataclass, fields import cutlass import cutlass.cute as cute @@ -9,91 +10,55 @@ from flash_attn.cute.fast_math import FastDivmod -class TileSchedulerParams: - def __init__( - self, - # block_size: cutlass.Constexpr[int], - num_block: Int32, - num_head: Int32, - num_batch: Int32, - num_block_divmod: FastDivmod, - num_head_divmod: FastDivmod, - is_persistent: cutlass.Constexpr[bool] = False, - # qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1, # Only pass in if using packed GPA - *, - loc=None, - ip=None, - ): - # self.block_size = block_size - self.num_block = num_block - self.num_head = num_head - self.num_batch = num_batch - self.num_block_divmod = num_block_divmod - self.num_head_divmod = num_head_divmod - self.is_persistent = is_persistent - # self.qhead_per_kvhead_packgqa = qhead_per_kvhead_packgqa - self._loc = loc - - @staticmethod - def create( - num_block: Int32, - num_head: Int32, - num_batch: Int32, - is_persistent: cutlass.Constexpr[bool] = False, - *, - loc=None, - ip=None, - ) -> "TileSchedulerParams": - num_block_divmod = FastDivmod.create(num_block, loc=loc, ip=ip) - num_head_divmod = FastDivmod.create(num_head, loc=loc, ip=ip) - return TileSchedulerParams( - num_block, - num_head, - num_batch, - num_block_divmod, - num_head_divmod, - is_persistent, - loc=loc, - ip=ip, - ) +@dataclass +class ParamsBase: + """We require cutlass.Constexpr fields to come after the non-Constexpr fields""" def __extract_mlir_values__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, cutlass.Constexpr)] values, self._values_pos = [], [] - for obj in [ - self.num_block, - self.num_head, - self.num_batch, - self.num_block_divmod, - self.num_head_divmod, - ]: + for obj in non_constexpr_fields: obj_values = cutlass.extract_mlir_values(obj) values += obj_values self._values_pos.append(len(obj_values)) return values def __new_from_mlir_values__(self, values): + all_fields = [getattr(self, field.name) for field in fields(self)] + constexpr_fields = [f for f in all_fields if isinstance(f, cutlass.Constexpr)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, cutlass.Constexpr)] obj_list = [] for obj, n_items in zip( - [ - self.num_block, - self.num_head, - self.num_batch, - self.num_block_divmod, - self.num_head_divmod, - ], + non_constexpr_fields, self._values_pos, ): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] - return TileSchedulerParams( - # self.block_size, *(tuple(obj_list)), self.qhead_per_kvhead_packgqa, loc=self._loc - *(tuple(obj_list)), - self.is_persistent, - loc=self._loc, - ) + return self.__class__(*(tuple(obj_list)), *(tuple(constexpr_fields))) + + +@dataclass +class TileSchedulerArguments(ParamsBase): + num_block: Int32 + num_head: Int32 + num_batch: Int32 + is_persistent: cutlass.Constexpr[bool] = False class SingleTileScheduler: + @dataclass + class Params(ParamsBase): + num_block: Int32 + num_head: Int32 + num_batch: Int32 + + @staticmethod + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "SingleTileScheduler.Params": + return SingleTileScheduler.Params(args.num_block, args.num_head, args.num_batch) + def __init__(self, blk_coord: cute.Coord, *, loc=None, ip=None): self._blk_coord = blk_coord self._is_first_block = True @@ -101,14 +66,18 @@ def __init__(self, blk_coord: cute.Coord, *, loc=None, ip=None): self._ip = ip @staticmethod - def create(params: TileSchedulerParams, *, loc=None, ip=None) -> "SingleTileScheduler": + def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + return SingleTileScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + def create(params: Params, *, loc=None, ip=None) -> "SingleTileScheduler": blk_coord = cute.arch.block_idx() return SingleTileScheduler(blk_coord, loc=loc, ip=ip) # called by host @staticmethod def get_grid_shape( - params: TileSchedulerParams, + params: Params, *, loc=None, ip=None, @@ -144,6 +113,21 @@ def __new_from_mlir_values__(self, values): class StaticPersistentTileScheduler: + @dataclass + class Params(ParamsBase): + num_block_divmod: FastDivmod + num_head_divmod: FastDivmod + total_blocks: Int32 + + @staticmethod + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "StaticPersistentTileScheduler.Params": + total_blocks = args.num_block * args.num_head * args.num_batch + return StaticPersistentTileScheduler.Params( + FastDivmod.create(args.num_block), FastDivmod.create(args.num_head), total_blocks + ) + def __init__( self, num_block_divmod: FastDivmod, @@ -162,25 +146,32 @@ def __init__( self._ip = ip @staticmethod - def create(params: TileSchedulerParams, *, loc=None, ip=None) -> "SingleTileScheduler": + def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + return StaticPersistentTileScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + def create(params: Params, *, loc=None, ip=None) -> "StaticPersistentTileScheduler": tile_idx = cute.arch.block_idx()[0] - total_blocks = params.num_block * params.num_head * params.num_batch return StaticPersistentTileScheduler( - params.num_block_divmod, params.num_head_divmod, total_blocks, tile_idx, loc=loc, ip=ip + params.num_block_divmod, + params.num_head_divmod, + params.total_blocks, + tile_idx, + loc=loc, + ip=ip, ) # called by host @staticmethod def get_grid_shape( - params: TileSchedulerParams, + params: Params, *, loc=None, ip=None, ) -> Tuple[Int32, Int32, Int32]: hardware_info = cutlass.utils.HardwareInfo() sm_count = hardware_info.get_device_multiprocessor_count() - total_blocks = params.num_block * params.num_head * params.num_batch - return (cutlass.min(sm_count, total_blocks), Int32(1), Int32(1)) + return (cutlass.min(sm_count, params.total_blocks), Int32(1), Int32(1)) # @cute.jit def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: From 525fb4323bc0d2a02b640a1f8a9d5c48a5c59f1b Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Thu, 3 Jul 2025 11:30:57 -0400 Subject: [PATCH 018/258] [Cute] Port SingleTileLPTScheduler from C++ to Python --- flash_attn/cute/flash_fwd_sm100.py | 9 +- flash_attn/cute/tile_scheduler.py | 190 +++++++++++++++++++++++++++-- 2 files changed, 184 insertions(+), 15 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 8797e61ab5a..e44f819156a 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -34,7 +34,7 @@ from flash_attn.cute import mma_sm100_desc as sm100_desc from flash_attn.cute import blackwell_helpers as sm100_utils from flash_attn.cute.fast_math import FastDivmod -from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, StaticPersistentTileScheduler, ParamsBase +from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, StaticPersistentTileScheduler, SingleTileLPTScheduler, ParamsBase # class NamedBarrierFwd(enum.IntEnum): @@ -51,7 +51,8 @@ def get_tile_scheduler_cls(args: TileSchedulerArguments) -> Callable: if cutlass.const_expr(args.is_persistent): return StaticPersistentTileScheduler else: - return SingleTileScheduler + # return SingleTileScheduler + return SingleTileLPTScheduler class FlashAttentionForwardSm100: @@ -1764,6 +1765,10 @@ def _compute_grid( cute.ceil_div(cute.size(o_shape[0]), cta_tiler[0]), cute.size(o_shape[2]), cute.size(o_shape[3]), + cute.size(o_shape[0]), # TODO + o_shape[1], + o_shape[1], + 2, # TODO is_persistent, ) tile_scheduler_cls = get_tile_scheduler_cls(tile_sched_args) diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index 38d943b13e7..6421b64c4bd 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -7,13 +7,11 @@ import cutlass.cute as cute from cutlass import Int32 -from flash_attn.cute.fast_math import FastDivmod +from flash_attn.cute.fast_math import FastDivmod, clz @dataclass class ParamsBase: - """We require cutlass.Constexpr fields to come after the non-Constexpr fields""" - def __extract_mlir_values__(self): all_fields = [getattr(self, field.name) for field in fields(self)] non_constexpr_fields = [f for f in all_fields if not isinstance(f, cutlass.Constexpr)] @@ -25,17 +23,15 @@ def __extract_mlir_values__(self): return values def __new_from_mlir_values__(self, values): - all_fields = [getattr(self, field.name) for field in fields(self)] - constexpr_fields = [f for f in all_fields if isinstance(f, cutlass.Constexpr)] - non_constexpr_fields = [f for f in all_fields if not isinstance(f, cutlass.Constexpr)] - obj_list = [] - for obj, n_items in zip( - non_constexpr_fields, - self._values_pos, - ): - obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + all_fields = {field.name: getattr(self, field.name) for field in fields(self)} + constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, cutlass.Constexpr)} + non_constexpr_fields = { + n: f for n, f in all_fields.items() if not isinstance(f, cutlass.Constexpr) + } + for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): + non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) values = values[n_items:] - return self.__class__(*(tuple(obj_list)), *(tuple(constexpr_fields))) + return self.__class__(**non_constexpr_fields, **constexpr_fields) @dataclass @@ -43,6 +39,10 @@ class TileSchedulerArguments(ParamsBase): num_block: Int32 num_head: Int32 num_batch: Int32 + seqlen_k: Int32 + headdim: Int32 + headdim_v: Int32 + element_size: cutlass.Constexpr[int] = 2 is_persistent: cutlass.Constexpr[bool] = False @@ -210,3 +210,167 @@ def __new_from_mlir_values__(self, values): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] return StaticPersistentTileScheduler(*(tuple(obj_list)), loc=self._loc) + + +class SingleTileLPTScheduler: + @dataclass + class Params(ParamsBase): + total_blocks: Int32 + num_block_divmod: FastDivmod + num_head_divmod: FastDivmod + l2_minor_divmod: FastDivmod + l2_major_divmod: FastDivmod + l2_minor_residual_divmod: FastDivmod + num_hb_quotient: Int32 + + @staticmethod + @cute.jit + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "SingleTileLPTScheduler.Params": + size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size + size_one_head = size_one_kv_head + size_l2 = 50 * 1024 * 1024 # 40 MB for K & V + # Swizzle is the size of each "section". Round swizzle to a power of 2 + # Need to be careful about the case where only one head will fit + log2_floor = lambda n: 31 - clz(n) + # swizzle is how many heads can fit in L2 + # Seems faster if swizzle if a power of 2 + swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head)) + # If we're in the last section (called residual), we don't want to divide by + # swizzle. Instead we want to divide by the remainder. + num_hb_quotient = (args.num_head * args.num_batch) // swizzle + num_hb_remainder = (args.num_head * args.num_batch) % swizzle + return SingleTileLPTScheduler.Params( + total_blocks=args.num_block * args.num_head * args.num_batch, + num_block_divmod=FastDivmod.create(args.num_block), + num_head_divmod=FastDivmod.create(args.num_head), + l2_minor_divmod=FastDivmod.create(swizzle), + l2_major_divmod=FastDivmod.create(swizzle * args.num_block), + l2_minor_residual_divmod=FastDivmod.create( + max(num_hb_remainder, 1) + ), # don't divide by 0 + num_hb_quotient=Int32(num_hb_quotient), + ) + + def __init__( + self, + total_blocks: Int32, + num_block_divmod: FastDivmod, + num_head_divmod: FastDivmod, + l2_minor_divmod: FastDivmod, + l2_major_divmod: FastDivmod, + l2_minor_residual_divmod: FastDivmod, + num_hb_quotient: Int32, + tile_idx: Int32, + *, + loc=None, + ip=None, + ): + self.total_blocks = total_blocks + self.num_block_divmod = num_block_divmod + self.num_head_divmod = num_head_divmod + self.l2_minor_divmod = l2_minor_divmod + self.l2_major_divmod = l2_major_divmod + self.l2_minor_residual_divmod = l2_minor_residual_divmod + self.num_hb_quotient = num_hb_quotient + self._tile_idx = tile_idx + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + return SingleTileLPTScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + def create(params: Params, *, loc=None, ip=None) -> "SingleTileLPTScheduler": + tile_idx = cute.arch.block_idx()[0] + return SingleTileLPTScheduler( + params.total_blocks, + params.num_block_divmod, + params.num_head_divmod, + params.l2_minor_divmod, + params.l2_major_divmod, + params.l2_minor_residual_divmod, + params.num_hb_quotient, + tile_idx, + loc=loc, + ip=ip, + ) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + return (params.total_blocks, Int32(1), Int32(1)) + + @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + # Implement LPT scheduling coordinate calculation + bidhb, l2_mod = self.l2_major_divmod.divmod(self._tile_idx) + # If we're in the last section (called residual), we don't want to divide by + # swizzle. Instead we want to divide by the remainder. + block, bidhb_residual = 0, 0 + if bidhb < self.num_hb_quotient: + block, bidhb_residual = self.l2_minor_divmod.divmod(l2_mod) + else: + block, bidhb_residual = self.l2_minor_residual_divmod.divmod(l2_mod) + # TODO: should this be l2_minor or l2_minor_residual? + bidhb_actual = bidhb * self.l2_minor_divmod.divisor + bidhb_residual + batch_idx, head_idx = self.num_head_divmod.divmod(bidhb_actual) + # Longest-processing-time-first + block = self.num_block_divmod.divisor - 1 - block + is_valid = self._tile_idx < self.total_blocks + return cutlass.utils.WorkTileInfo( + (Int32(block), Int32(head_idx), Int32(batch_idx)), is_valid + ) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + # Single tile scheduler - set to invalid tile_idx to indicate no more work + self._tile_idx = self.total_blocks + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [ + self.total_blocks, + self.num_block_divmod, + self.num_head_divmod, + self.l2_minor_divmod, + self.l2_major_divmod, + self.l2_minor_residual_divmod, + self.num_hb_quotient, + self._tile_idx, + ]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [ + self.total_blocks, + self.num_block_divmod, + self.num_head_divmod, + self.l2_minor_divmod, + self.l2_major_divmod, + self.l2_minor_residual_divmod, + self.num_hb_quotient, + self._tile_idx, + ], + self._values_pos, + ): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return SingleTileLPTScheduler(*(tuple(obj_list)), loc=self._loc) From 60e1e89d33d6f57038b810937ebc9dca088d168c Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 4 Jul 2025 12:28:59 -0400 Subject: [PATCH 019/258] [Cute] Update comment about cute version --- flash_attn/cute/interface.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index cf49d0ef248..cd01726f19a 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -1,7 +1,5 @@ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -# [2025-06-01] Initial version in Cute-DSL. -# Only support basic forward and backward pass for FlashAttention, optimized for Ampere. -# Lightly tested with headdim 128. +# [2025-06-01] Version in Cute-DSL, for Hopper and Blackwell. You'd need to install nvidia-cutlass-dsl. # Features not supported yet: # - varlen # - split (i.e. FlashDecoding) From 6a44198ea27e58d7590ce33a4e681c21dd342827 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 4 Jul 2025 16:46:34 -0400 Subject: [PATCH 020/258] [Cute] Update to cute-dsl 4.1.0.dev0 --- flash_attn/cute/ampere_helpers.py | 26 +- flash_attn/cute/blackwell_helpers.py | 42 ++-- flash_attn/cute/block_info.py | 4 +- flash_attn/cute/fast_math.py | 4 +- flash_attn/cute/flash_bwd.py | 100 ++++---- flash_attn/cute/flash_bwd_postprocess.py | 6 +- flash_attn/cute/flash_bwd_preprocess.py | 6 +- flash_attn/cute/flash_fwd.py | 292 +++++++++++------------ flash_attn/cute/flash_fwd_sm100.py | 195 +++++++-------- flash_attn/cute/hopper_helpers.py | 3 +- flash_attn/cute/interface.py | 2 +- flash_attn/cute/mask.py | 28 +-- flash_attn/cute/pack_gqa.py | 14 +- flash_attn/cute/pipeline.py | 27 +-- flash_attn/cute/softmax.py | 25 +- flash_attn/cute/tile_scheduler.py | 1 - flash_attn/cute/utils.py | 92 ++----- 17 files changed, 412 insertions(+), 455 deletions(-) diff --git a/flash_attn/cute/ampere_helpers.py b/flash_attn/cute/ampere_helpers.py index 804d052a78b..839f407f75c 100644 --- a/flash_attn/cute/ampere_helpers.py +++ b/flash_attn/cute/ampere_helpers.py @@ -6,9 +6,9 @@ def get_smem_layout_atom(dtype: Type[cutlass.Numeric], k_dim: int) -> cute.ComposedLayout: - dtype_byte = dtype.width // 8 - bytes_per_row = k_dim * dtype_byte - smem_k_block_size = ( + dtype_byte = cutlass.const_expr(dtype.width // 8) + bytes_per_row = cutlass.const_expr(k_dim * dtype_byte) + smem_k_block_size = cutlass.const_expr( 128 if bytes_per_row % 128 == 0 else (64 if bytes_per_row % 64 == 0 else (32 if bytes_per_row % 32 == 0 else 16)) @@ -22,10 +22,11 @@ def get_smem_layout_atom(dtype: Type[cutlass.Numeric], k_dim: int) -> cute.Compo return cute.make_composed_layout( cute.make_swizzle(swizzle_bits, swizzle_base, swizzle_base), 0, - cute.make_ordered_layout((8 if k_dim % 32 == 0 else 16, smem_k_block_size), order=(1, 0)), + cute.make_ordered_layout((8 if cutlass.const_expr(k_dim % 32 == 0) else 16, smem_k_block_size), order=(1, 0)), ) +@cute.jit def gemm( tiled_mma: cute.TiledMma, acc: cute.Tensor, @@ -40,7 +41,7 @@ def gemm( B_in_regs: cutlass.Constexpr[bool] = False, swap_AB: cutlass.Constexpr[bool] = False, ) -> None: - if swap_AB: + if cutlass.const_expr(swap_AB): gemm( tiled_mma, acc, @@ -58,17 +59,17 @@ def gemm( else: tCrA_copy_view = smem_thr_copy_A.retile(tCrA) tCrB_copy_view = smem_thr_copy_B.retile(tCrB) - if not A_in_regs: + if cutlass.const_expr(not A_in_regs): cute.copy(smem_thr_copy_A, tCsA[None, None, 0], tCrA_copy_view[None, None, 0]) - if not B_in_regs: + if cutlass.const_expr(not B_in_regs): cute.copy(smem_thr_copy_B, tCsB[None, None, 0], tCrB_copy_view[None, None, 0]) - for k in range(cute.size(tCsA.shape[2])): + for k in cutlass.range_constexpr(cute.size(tCsA.shape[2])): if k < cute.size(tCsA.shape[2]) - 1: - if not A_in_regs: + if cutlass.const_expr(not A_in_regs): cute.copy( smem_thr_copy_A, tCsA[None, None, k + 1], tCrA_copy_view[None, None, k + 1] ) - if not B_in_regs: + if cutlass.const_expr(not B_in_regs): cute.copy( smem_thr_copy_B, tCsB[None, None, k + 1], tCrB_copy_view[None, None, k + 1] ) @@ -77,6 +78,7 @@ def gemm( hook_fn() +@cute.jit def gemm_rs( tiled_mma: cute.TiledMma, acc: cute.Tensor, @@ -88,8 +90,8 @@ def gemm_rs( ) -> None: tCrB_copy_view = smem_thr_copy_B.retile(tCrB) cute.copy(smem_thr_copy_B, tCsB[None, None, 0], tCrB_copy_view[None, None, 0]) - for k in range(cute.size(tCrA.shape[2])): - if k < cute.size(tCrA.shape[2]) - 1: + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): + if cutlass.const_expr(k < cute.size(tCrA.shape[2]) - 1): cute.copy(smem_thr_copy_B, tCsB[None, None, k + 1], tCrB_copy_view[None, None, k + 1]) cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc) if cutlass.const_expr(k == 0 and hook_fn is not None): diff --git a/flash_attn/cute/blackwell_helpers.py b/flash_attn/cute/blackwell_helpers.py index 9a83f4a9998..ca9c4b77a88 100644 --- a/flash_attn/cute/blackwell_helpers.py +++ b/flash_attn/cute/blackwell_helpers.py @@ -1,5 +1,5 @@ # Copyright (c) 2025, Tri Dao. -from typing import Optional +from typing import Optional, Tuple import cutlass import cutlass.cute as cute from cutlass.cute.nvgpu import tcgen05 @@ -22,7 +22,7 @@ def gemm( cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc) -def i64_to_i32x2(i: int) -> tuple[int, int]: +def i64_to_i32x2(i: int) -> Tuple[int, int]: """Convert a 64-bit integer to a tuple of two 32-bit integers.""" return i & 0xFFFF_FFFF, (i >> 32) & 0xFFFF_FFFF @@ -40,7 +40,7 @@ def gemm_ptx( zero_init: bool | cutlass.Boolean = False, ) -> None: is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM - if not is_ts: + if cutlass.const_expr(not is_ts): assert sA is not None, "sA must be provided when a_src is not TMEM" assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM" sA_layout = sA.layout if sA is not None else None @@ -50,7 +50,7 @@ def gemm_ptx( smem_desc_base_a: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), sA_swizzle, - sm100_desc.Major.K if op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else sm100_desc.Major.MN + sm100_desc.Major.K if cutlass.const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN )) smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) smem_desc_base_a_lo = cutlass.const_expr(smem_desc_base_a_lo) @@ -61,7 +61,7 @@ def gemm_ptx( smem_desc_base_b: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), sB_swizzle, - sm100_desc.Major.K if op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else sm100_desc.Major.MN + sm100_desc.Major.K if cutlass.const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN )) smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) smem_desc_base_b_lo = cutlass.const_expr(smem_desc_base_b_lo) @@ -139,7 +139,7 @@ def gemm_ptx_loop( zero_init: bool | cutlass.Boolean = False, ) -> None: is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM - if not is_ts: + if cutlass.const_expr(not is_ts): assert sA is not None, "sA must be provided when a_src is not TMEM" assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM" sA_layout = sA.layout if sA is not None else tCrA.layout @@ -149,7 +149,7 @@ def gemm_ptx_loop( smem_desc_base_a: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), sA_swizzle, - sm100_desc.Major.K if op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else sm100_desc.Major.MN + sm100_desc.Major.K if cutlass.const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN )) smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) smem_desc_base_a_lo = cutlass.const_expr(smem_desc_base_a_lo) @@ -160,7 +160,7 @@ def gemm_ptx_loop( smem_desc_base_b: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), sB_swizzle, - sm100_desc.Major.K if op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else sm100_desc.Major.MN + sm100_desc.Major.K if cutlass.const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN )) smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) smem_desc_base_b_lo = cutlass.const_expr(smem_desc_base_b_lo) @@ -168,14 +168,14 @@ def gemm_ptx_loop( if cutlass.const_expr(not is_ts): offset_a = [(cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4 - for k in range(cute.size(tCrA.shape[2]))] + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2]))] else: offset_a = [cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32 - for k in range(cute.size(tCrA.shape[2]))] - offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))] + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2]))] + offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2]))] offset_b = [(cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4 - for k in range(cute.size(tCrB.shape[2]))] - offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))] + for k in cutlass.range_constexpr(cute.size(tCrB.shape[2]))] + offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrB.shape[2]))] if cutlass.const_expr(not is_ts): smem_desc_start_a_lo = cutlass.Int32(smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)) @@ -217,7 +217,7 @@ def gemm_ptx_loop( f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, 1;\n\t" ) - for k in range(1, cute.size(tCrA.shape[2])) + for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) ) + "}\n", "r,r,r,r", @@ -258,7 +258,7 @@ def gemm_ptx_loop( # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, 1;\n\t" f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" ) - for k in range(1, cute.size(tCrA.shape[2])) + for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) ) + "}\n", "r,r,r,r", @@ -281,7 +281,7 @@ def gemm_ptx_partial( zero_init: bool | cutlass.Boolean = False, ) -> None: is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM - if not is_ts: + if cutlass.const_expr(not is_ts): assert sA is not None, "sA must be provided when a_src is not TMEM" assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM" sA_layout = sA.layout if sA is not None else tCrA.layout @@ -291,7 +291,7 @@ def gemm_ptx_partial( smem_desc_base_a: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), sA_swizzle, - sm100_desc.Major.K if op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else sm100_desc.Major.MN + sm100_desc.Major.K if cutlass.const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN )) smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) smem_desc_base_a_lo = cutlass.const_expr(smem_desc_base_a_lo) @@ -302,7 +302,7 @@ def gemm_ptx_partial( smem_desc_base_b: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), sB_swizzle, - sm100_desc.Major.K if op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else sm100_desc.Major.MN + sm100_desc.Major.K if cutlass.const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN )) smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) smem_desc_base_b_lo = cutlass.const_expr(smem_desc_base_b_lo) @@ -432,7 +432,7 @@ def gemm_ptx_partial1( zero_init: bool | cutlass.Boolean = False, ) -> None: is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM - if not is_ts: + if cutlass.const_expr(not is_ts): assert sA_layout is not None, "sA_layout must be provided when a_src is not TMEM" assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM" idesc: int = cutlass.const_expr(sm100_desc.mma_op_to_idesc(op)) @@ -440,7 +440,7 @@ def gemm_ptx_partial1( smem_desc_base_a: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), sA_swizzle, - sm100_desc.Major.K if op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else sm100_desc.Major.MN + sm100_desc.Major.K if cutlass.const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN )) smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) smem_desc_base_a_lo = cutlass.const_expr(smem_desc_base_a_lo) @@ -451,7 +451,7 @@ def gemm_ptx_partial1( smem_desc_base_b: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), sB_swizzle, - sm100_desc.Major.K if op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else sm100_desc.Major.MN + sm100_desc.Major.K if cutlass.const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN )) smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) smem_desc_base_b_lo = cutlass.const_expr(smem_desc_base_b_lo) diff --git a/flash_attn/cute/block_info.py b/flash_attn/cute/block_info.py index a3505e5dbb5..2739a31c4ef 100644 --- a/flash_attn/cute/block_info.py +++ b/flash_attn/cute/block_info.py @@ -30,7 +30,7 @@ def get_n_block_min_max( if cutlass.const_expr(self.qhead_per_kvhead_packgqa > 1): m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa) n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q - n_idx_right = n_idx if self.is_causal else n_idx + self.window_size_right + n_idx_right = n_idx if cutlass.const_expr(self.is_causal) else n_idx + self.window_size_right n_block_max = min(n_block_max, cute.ceil_div(n_idx_right, self.n_block_size)) n_block_min = 0 if cutlass.const_expr(self.is_local and self.window_size_left is not None): @@ -56,7 +56,7 @@ def get_n_block_min_causal_local_mask( n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q n_idx_right = ( n_idx - if not self.is_local or self.window_size_right is None + if cutlass.const_expr(not self.is_local or self.window_size_right is None) else n_idx + self.window_size_right ) return cutlass.max(n_block_min, n_idx_right // self.n_block_size) diff --git a/flash_attn/cute/fast_math.py b/flash_attn/cute/fast_math.py index b21573aa50d..943388fd291 100644 --- a/flash_attn/cute/fast_math.py +++ b/flash_attn/cute/fast_math.py @@ -11,14 +11,14 @@ @cute.jit def clz(x: Int32) -> Int32: - # for i in cutlass.range_dynamic(32): + # for i in cutlass.range_constexpr(32): # if (1 << (31 - i)) & x: # return Int32(i) # return Int32(32) # Early exit is not supported yet res = Int32(32) done = False - for i in cutlass.range_dynamic(32): + for i in cutlass.range(32): if ((1 << (31 - i)) & x) and not done: res = Int32(i) done = True diff --git a/flash_attn/cute/flash_bwd.py b/flash_attn/cute/flash_bwd.py index 03d41b31e6b..3ae61ba08dc 100644 --- a/flash_attn/cute/flash_bwd.py +++ b/flash_attn/cute/flash_bwd.py @@ -262,25 +262,25 @@ def _setup_attributes(self): cute.make_layout(self.num_threads), cute.make_layout(1) ) - if self.qhead_per_kvhead > 1: + if cutlass.const_expr(self.qhead_per_kvhead > 1): self.gmem_tiled_copy_dK = self.gmem_tiled_copy_dQaccum self.gmem_tiled_copy_dV = self.gmem_tiled_copy_dQaccum def _get_tiled_mma(self): num_mma_warps = self.num_threads // 32 - AtomLayoutSdP = (self.AtomLayoutMSdP, num_mma_warps // self.AtomLayoutMSdP, 1) if not self.SdP_swapAB else (num_mma_warps // self.AtomLayoutMSdP, self.AtomLayoutMSdP, 1) + AtomLayoutSdP = (self.AtomLayoutMSdP, num_mma_warps // self.AtomLayoutMSdP, 1) if cutlass.const_expr(not self.SdP_swapAB) else (num_mma_warps // self.AtomLayoutMSdP, self.AtomLayoutMSdP, 1) tiled_mma_sdp = cute.make_tiled_mma( warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)), AtomLayoutSdP, permutation_mnk=(AtomLayoutSdP[0] * 16, AtomLayoutSdP[1] * 16, 16), ) - AtomLayoutdKV = (self.AtomLayoutNdKV, num_mma_warps // self.AtomLayoutNdKV, 1) if not self.dKV_swapAB else (num_mma_warps // self.AtomLayoutNdKV, self.AtomLayoutNdKV, 1) + AtomLayoutdKV = (self.AtomLayoutNdKV, num_mma_warps // self.AtomLayoutNdKV, 1) if cutlass.const_expr(not self.dKV_swapAB) else (num_mma_warps // self.AtomLayoutNdKV, self.AtomLayoutNdKV, 1) tiled_mma_dkv = cute.make_tiled_mma( warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)), AtomLayoutdKV, permutation_mnk=(AtomLayoutdKV[0] * 16, AtomLayoutdKV[1] * 16, 16), ) - AtomLayoutdQ = (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1) if not self.dQ_swapAB else (num_mma_warps // self.AtomLayoutMdQ, self.AtomLayoutMdQ, 1) + AtomLayoutdQ = (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1) if cutlass.const_expr(not self.dQ_swapAB) else (num_mma_warps // self.AtomLayoutMdQ, self.AtomLayoutMdQ, 1) tiled_mma_dq = cute.make_tiled_mma( warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)), AtomLayoutdQ, @@ -293,7 +293,7 @@ def _get_shared_storage_cls(self): cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], 1024] for layout in (self.sQ_layout, self.sK_layout, self.sV_layout, self.sdO_layout) ] - cosize_sQV = utils.max_constexpr(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout)) + cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout)) sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024] sLSE_struct, sdPsum_struct = [ cute.struct.Align[cute.struct.MemRange[cutlass.Float32, cute.cosize(layout)], 128] @@ -431,7 +431,7 @@ def kernel( m_block_max = cute.ceil_div(mQ.shape[1], self.m_block_size) m_block_min = 0 - if self.is_causal: + if cutlass.const_expr(self.is_causal): m_block_min = max( (n_block * self.n_block_size + mQ.shape[1] - mK.shape[1]) // self.m_block_size, m_block_min, @@ -526,7 +526,7 @@ def kernel( tdQrdS = utils.mma_make_fragment_A(sdS, thr_mma_dq, swapAB=self.dQ_swapAB) tdQrK = utils.mma_make_fragment_B(sKt, thr_mma_dq, swapAB=self.dQ_swapAB) - LSEslice = (None, 0, None) if not self.SdP_swapAB else (0, None, None) + LSEslice = (None, 0, None) if cutlass.const_expr(not self.SdP_swapAB) else (0, None, None) tSsLSEMma = utils.make_acc_tensor_mn_view(thr_mma_sdp.partition_C(sLSEMma))[LSEslice] tSsdPsumMma = utils.make_acc_tensor_mn_view(thr_mma_sdp.partition_C(sdPsumMma))[LSEslice] @@ -672,7 +672,7 @@ def kernel( m_block = m_block_min assert self.num_stages_Q >= self.num_stages_dO - for stage in range(self.num_stages_Q): + for stage in cutlass.range_constexpr(self.num_stages_Q): if cutlass.const_expr(self.num_stages_Q == 1 or stage < self.num_stages_Q - 1): if stage == 0 or m_block + stage < m_block_max: load_Q_LSE(m_block + stage, smem_pipe_write_q=stage) @@ -695,7 +695,7 @@ def kernel( smem_pipe_read_do = cutlass.Int32(0) smem_pipe_write_q = cutlass.Int32(self.num_stages_Q - 1) smem_pipe_write_do = cutlass.Int32(0) - for m_tile in cutlass.range_dynamic(m_block_min, m_block_max, unroll=1): + for m_tile in cutlass.range(m_block_min, m_block_max, unroll=1): compute_one_m_block( m_tile, smem_pipe_read_q, smem_pipe_read_do, smem_pipe_write_q, smem_pipe_write_do, mask_fn=mask_fn, @@ -738,7 +738,7 @@ def compute_one_m_block( mask_fn: Optional[Callable] = None, ): def load_Q_next(): - m_block_next = m_block + (self.num_stages_Q - 1 if self.num_stages_Q > 1 else 1) + m_block_next = m_block + (self.num_stages_Q - 1 if cutlass.const_expr(self.num_stages_Q > 1) else 1) if m_block_next < m_block_max: load_Q_LSE(m_block_next, smem_pipe_write_q) cute.arch.cp_async_commit_group() @@ -750,22 +750,22 @@ def load_dO_next(): # MMA S acc_shape_SdP = mma_params.thr_mma_sdp.partition_shape_C( - (self.m_block_size, self.n_block_size) if not self.SdP_swapAB else (self.n_block_size, self.m_block_size) + (self.m_block_size, self.n_block_size) if cutlass.const_expr(not self.SdP_swapAB) else (self.n_block_size, self.m_block_size) ) acc_S = cute.make_fragment(acc_shape_SdP, cutlass.Float32) acc_S.fill(0.0) - cute.arch.cp_async_wait_group(1 if self.num_stages_Q > 1 else 0) + cute.arch.cp_async_wait_group(1 if cutlass.const_expr(self.num_stages_Q > 1) else 0) cute.arch.barrier() sm80_utils.gemm( mma_params.thr_mma_sdp, acc_S, mma_params.tSrQ, mma_params.tSrK, - smem_copy_params.tSsQ[None, None, None, smem_pipe_read_q if self.num_stages_Q > 1 else 0], + smem_copy_params.tSsQ[None, None, None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0], smem_copy_params.tSsK, smem_copy_params.smem_thr_copy_QdO, smem_copy_params.smem_thr_copy_KV, swap_AB=self.SdP_swapAB, ) tLSErLSE = cute.make_fragment_like(smem_copy_params.tSsLSEMma[None, 0]) cute.autovec_copy( - smem_copy_params.tSsLSEMma[None, smem_pipe_read_q if self.num_stages_Q > 1 else 0], tLSErLSE + smem_copy_params.tSsLSEMma[None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0], tLSErLSE ) if cutlass.const_expr(mask_fn is not None): mask_fn(acc_S, m_block=m_block) @@ -774,31 +774,31 @@ def load_dO_next(): # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_S_mn) # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == 1: cute.print_tensor(tLSErLSE) assert cute.size(acc_S_mn, mode=[0]) == cute.size(tLSErLSE) - for r in range(cute.size(acc_S_mn, mode=[0])): + for r in cutlass.range_constexpr(cute.size(acc_S_mn, mode=[0])): acc_S_mn[r, None].store(utils.exp2f(acc_S_mn[r, None].load() * softmax_scale_log2 - tLSErLSE[r])) # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_S_mn) # MMA dP acc_dP = cute.make_fragment(acc_shape_SdP, cutlass.Float32) acc_dP.fill(0.0) - cute.arch.cp_async_wait_group(1 if self.num_stages_dO > 1 else 0) + cute.arch.cp_async_wait_group(1 if cutlass.const_expr(self.num_stages_dO > 1) else 0) cute.arch.barrier() sm80_utils.gemm( mma_params.thr_mma_sdp, acc_dP, mma_params.tdPrdO, mma_params.tdPrV, - smem_copy_params.tdPsdO[None, None, None, smem_pipe_read_do if self.num_stages_dO > 1 else 0], + smem_copy_params.tdPsdO[None, None, None, smem_pipe_read_do if cutlass.const_expr(self.num_stages_dO > 1) else 0], smem_copy_params.tdPsV, smem_copy_params.smem_thr_copy_QdO, smem_copy_params.smem_thr_copy_KV, - hook_fn=load_Q_next if self.num_stages_Q > 1 else None, + hook_fn=load_Q_next if cutlass.const_expr(self.num_stages_Q > 1) else None, swap_AB=self.SdP_swapAB, ) tLSErdPsum = cute.make_fragment_like(smem_copy_params.tSsdPsumMma[None, 0]) cute.autovec_copy( - smem_copy_params.tSsdPsumMma[None, smem_pipe_read_do if self.num_stages_dO > 1 else 0], tLSErdPsum + smem_copy_params.tSsdPsumMma[None, smem_pipe_read_do if cutlass.const_expr(self.num_stages_dO > 1) else 0], tLSErdPsum ) acc_dP_mn = utils.make_acc_tensor_mn_view(acc_dP) # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn) assert cute.size(acc_dP_mn, mode=[0]) == cute.size(tLSErdPsum) - for r in range(cute.size(acc_dP_mn, mode=[0])): + for r in cutlass.range_constexpr(cute.size(acc_dP_mn, mode=[0])): acc_dP_mn[r, None].store(acc_S_mn[r, None].load() * (acc_dP_mn[r, None].load() - tLSErdPsum[r])) # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn) rP = cute.make_fragment_like(acc_S, self.dtype) @@ -823,7 +823,7 @@ def load_dO_next(): sm80_utils.gemm( mma_params.thr_mma_dkv, mma_params.acc_dV, tdVrP, mma_params.tdVrdO, smem_copy_params.tdVsPt, - smem_copy_params.tdVsdOt[None, None, None, smem_pipe_read_do if self.num_stages_dO > 1 else 0], + smem_copy_params.tdVsdOt[None, None, None, smem_pipe_read_do if cutlass.const_expr(self.num_stages_dO > 1) else 0], smem_copy_params.smem_thr_copy_PdSt, smem_copy_params.smem_thr_copy_QdOt, A_in_regs=self.Mma_dKV_is_RS, swap_AB=self.dKV_swapAB, @@ -834,7 +834,7 @@ def load_dO_next(): # MMA dQ def dQ_mma(hook_fn): acc_shape_dQ = mma_params.thr_mma_dq.partition_shape_C( - (self.m_block_size, self.head_dim_padded) if not self.dQ_swapAB else (self.head_dim_padded, self.m_block_size) + (self.m_block_size, self.head_dim_padded) if cutlass.const_expr(not self.dQ_swapAB) else (self.head_dim_padded, self.m_block_size) ) acc_dQ = cute.make_fragment(acc_shape_dQ, cutlass.Float32) acc_dQ.fill(0.0) @@ -850,7 +850,7 @@ def dQ_mma(hook_fn): tdQgdQaccum_atomic = gmem_copy_params.tdQgdQaccum[None, None, m_block] assert cute.size(acc_dQ_atomic) == cute.size(tdQgdQaccum_atomic) # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(acc_dQ) - for i in range(cute.size(acc_dQ_atomic)): + for i in cutlass.range_constexpr(cute.size(acc_dQ_atomic)): utils.atomic_add_fp32(acc_dQ_atomic[i], utils.elem_pointer(tdQgdQaccum_atomic, i)) # utils.atomic_add_fp32(acc_dQ[i], tdQgdQaccum_atomic.iterator + i * tdQgdQaccum_atomic.stride[1]) # if cute.arch.thread_idx()[0] == 64 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dQ) @@ -867,7 +867,7 @@ def dQ_mma(hook_fn): sm80_utils.gemm( mma_params.thr_mma_dkv, mma_params.acc_dK, tdKrdS, mma_params.tdKrQ, smem_copy_params.tdKsdSt, - smem_copy_params.tdKsQt[None, None, None, smem_pipe_read_q if self.num_stages_Q > 1 else 0], + smem_copy_params.tdKsQt[None, None, None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0], smem_copy_params.smem_thr_copy_PdSt, smem_copy_params.smem_thr_copy_QdOt, A_in_regs=self.Mma_dKV_is_RS, swap_AB=self.dKV_swapAB, @@ -959,7 +959,7 @@ def epilogue( gmem_tiled_copy_dK, tdKrdK[None, rest_m, None], tdKgdK[None, rest_m, None], - pred=tdKpdK[None, rest_m, None] if self.check_hdim_oob else None, + pred=tdKpdK[None, rest_m, None] if cutlass.const_expr(self.check_hdim_oob) else None, ) for rest_m in cutlass.range_constexpr(cute.size(tdVrdV.shape[1])): if t0dVcdV[0, rest_m, 0][0] < mdV.shape[1] - n_block * self.n_block_size - tdVcdV[0][0]: @@ -967,7 +967,7 @@ def epilogue( gmem_tiled_copy_dV, tdVrdV[None, rest_m, None], tdVgdV[None, rest_m, None], - pred=tdVpdV[None, rest_m, None] if self.check_hdim_v_oob else None, + pred=tdVpdV[None, rest_m, None] if cutlass.const_expr(self.check_hdim_v_oob) else None, ) else: # qhead_per_kvhead > 1, do atomic add @@ -982,9 +982,9 @@ def epilogue( acc_dK_atomic = gmem_thr_copy_dK.retile(acc_dK) assert cute.size(acc_dV_atomic) == cute.size(tdVgdVaccum) assert cute.size(acc_dK_atomic) == cute.size(tdKgdKaccum) - for i in range(cute.size(acc_dV_atomic)): + for i in cutlass.range_constexpr(cute.size(acc_dV_atomic)): utils.atomic_add_fp32(acc_dV_atomic[i], utils.elem_pointer(tdVgdVaccum, i)) - for i in range(cute.size(acc_dK_atomic)): + for i in cutlass.range_constexpr(cute.size(acc_dK_atomic)): utils.atomic_add_fp32(acc_dK_atomic[i], utils.elem_pointer(tdKgdKaccum, i)) @cute.jit @@ -1005,16 +1005,16 @@ def load_K( tKcK = gmem_thr_copy.partition_S(cK) t0KcK = gmem_thr_copy.get_slice(0).partition_S(cK) tKpK = utils.predicate_k(tKcK, limit=headdim) - for n in range(cute.size(tKsK.shape[1])): + for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])): # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked if self.is_even_n_smem_k or n < cute.size(tKsK.shape[1]) - 1 or tKcK[0, n, 0][0] < self.n_block_size: # Instead of using tKcK, we using t0KcK and subtract the offset from the limit # (seqlen - block * kBlockN). This is because the entries of t0KcK are known at compile time. predicate_n = t0KcK[0, n, 0][0] < seqlen - block * self.n_block_size - tKcK[0][0] predicate = cute.make_fragment_like(tKpK[None, 0, None]) - for k in range(cute.size(predicate.shape[1])): - for i in range(cute.size(predicate.shape[0])): - predicate[i, k] = (tKpK[i, n, k] if self.check_hdim_oob else True) and predicate_n + for k in cutlass.range_constexpr(cute.size(predicate.shape[1])): + for i in cutlass.range_constexpr(cute.size(predicate.shape[0])): + predicate[i, k] = (tKpK[i, n, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_n cute.copy( gmem_thr_copy, tKgK[None, n, None], tKsK[None, n, None], pred=predicate, ) @@ -1034,16 +1034,16 @@ def load_V( tVcV = gmem_thr_copy.partition_S(cV) t0VcV = gmem_thr_copy.get_slice(0).partition_S(cV) tVpV = utils.predicate_k(tVcV, limit=headdim) - for n in range(cute.size(tVsV.shape[1])): + for n in cutlass.range_constexpr(cute.size(tVsV.shape[1])): # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked if self.is_even_n_smem_v or n < cute.size(tVsV.shape[1]) - 1 or tVcV[0, n, 0][0] < self.n_block_size: # Instead of using tVcV, we using t0VcV and subtract the offset from the limit # (seqlen - block * kBlockN). This is because the entries of t0VcV are known at compile time. predicate_n = t0VcV[0, n, 0][0] < seqlen - block * self.n_block_size - tVcV[0][0] predicate = cute.make_fragment_like(tVpV[None, 0, None]) - for k in range(cute.size(predicate.shape[1])): - for i in range(cute.size(predicate.shape[0])): - predicate[i, k] = (tVpV[i, n, k] if self.check_hdim_oob else True) and predicate_n + for k in cutlass.range_constexpr(cute.size(predicate.shape[1])): + for i in cutlass.range_constexpr(cute.size(predicate.shape[0])): + predicate[i, k] = (tVpV[i, n, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_n cute.copy( gmem_thr_copy, tVgV[None, n, None], tVsV[None, n, None], pred=predicate, ) @@ -1065,31 +1065,31 @@ def load_Q_LSE( smem_pipe_write_q: cutlass.Int32, seqlen: cutlass.Int32, ): - for m in range(cute.size(tQsQ.shape[1])): + for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])): # If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked if self.is_even_m_smem_q or m < cute.size(tQsQ.shape[1]) - 1 or tQcQ[0, m, 0][0] < self.m_block_size: # Instead of using tQcQ, we using t0QcQ and subtract the offset from the limit # (seqlen - block * kBlockM). This is because the entries of t0QcQ are known at compile time. predicate_m = t0QcQ[0, m, 0][0] < seqlen - block * self.m_block_size - tQcQ[0][0] predicate = cute.make_fragment_like(tQpQ[None, 0, None]) - for k in range(cute.size(predicate.shape[1])): - for i in range(cute.size(predicate.shape[0])): - predicate[i, k] = (tQpQ[i, m, k] if self.check_hdim_oob else True) and predicate_m + for k in cutlass.range_constexpr(cute.size(predicate.shape[1])): + for i in cutlass.range_constexpr(cute.size(predicate.shape[0])): + predicate[i, k] = (tQpQ[i, m, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_m cute.copy( gmem_tiled_copy_Q, tQgQ[None, m, None, block], - tQsQ[None, m, None, smem_pipe_write_q if self.num_stages_Q > 1 else 0], + tQsQ[None, m, None, smem_pipe_write_q if cutlass.const_expr(self.num_stages_Q) > 1 else 0], pred=predicate, ) # We need to clear the sQ smem tiles since we'll use sQt for mma_dK # We made sure LSE length is padded so we read `kBlockM` elements so that all # elements in sLSE are filled. Without this we might have uninitialized sLSE values. - for m in range(cute.size(tLSEsLSE.shape[1])): + for m in cutlass.range_constexpr(cute.size(tLSEsLSE.shape[1])): if tLSEcLSE[0, m][0] < self.m_block_size: cute.copy( gmem_tiled_copy_LSE, tLSEgLSE[None, m, block], - tLSEsLSE[None, m, smem_pipe_write_q if self.num_stages_Q > 1 else 0], + tLSEsLSE[None, m, smem_pipe_write_q if cutlass.const_expr(self.num_stages_Q > 1) else 0], ) @cute.jit @@ -1109,29 +1109,29 @@ def load_dO_dPsum( smem_pipe_write_q: cutlass.Int32, seqlen: cutlass.Int32, ): - for m in range(cute.size(tdOsdO.shape[1])): + for m in cutlass.range_constexpr(cute.size(tdOsdO.shape[1])): # If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked if self.is_even_m_smem_do or m < cute.size(tdOsdO.shape[1]) - 1 or tdOcdO[0, m, 0][0] < self.m_block_size: # Instead of using tdOcdO, we using t0dOcdO and subtract the offset from the limit # (seqlen - block * kBlockM). This is because the entries of t0dOcdO are known at compile time. predicate_m = t0dOcdO[0, m, 0][0] < seqlen - block * self.m_block_size - tdOcdO[0][0] predicate = cute.make_fragment_like(tdOpdO[None, 0, None]) - for k in range(cute.size(predicate.shape[1])): - for i in range(cute.size(predicate.shape[0])): - predicate[i, k] = (tdOpdO[i, m, k] if self.check_hdim_oob else True) and predicate_m + for k in cutlass.range_constexpr(cute.size(predicate.shape[1])): + for i in cutlass.range_constexpr(cute.size(predicate.shape[0])): + predicate[i, k] = (tdOpdO[i, m, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_m cute.copy( gmem_tiled_copy_dO, tdOgdO[None, m, None, block], - tdOsdO[None, m, None, smem_pipe_write_q if self.num_stages_dO > 1 else 0], + tdOsdO[None, m, None, smem_pipe_write_q if cutlass.const_expr(self.num_stages_dO > 1) else 0], pred=predicate, ) # We need to clear the sQ smem tiles since we'll use sQt for mma_dK # We made sure LSE length is padded so we read `kBlockM` elements so that all # elements in sLSE are filled. Without this we might have uninitialized sLSE values. - for m in range(cute.size(tdPsumgdPsum.shape[1])): + for m in cutlass.range_constexpr(cute.size(tdPsumgdPsum.shape[1])): if tdPsumcdPsum[0, m][0] < self.m_block_size: cute.copy( gmem_tiled_copy_dPsum, tdPsumgdPsum[None, m, block], - tdPsumsdPsum[None, m, smem_pipe_write_q if self.num_stages_dO > 1 else 0], + tdPsumsdPsum[None, m, smem_pipe_write_q if cutlass.const_expr(self.num_stages_dO > 1) else 0], ) diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index 616ea30e1e5..9136dcd8460 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -154,7 +154,7 @@ def __call__( num_mma_warps = self.num_threads // 32 AtomLayoutdQ = ( (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1) - if not self.dQ_swapAB + if cutlass.const_expr(not self.dQ_swapAB) else (num_mma_warps // self.AtomLayoutMdQ, self.AtomLayoutMdQ, 1) ) tiled_mma = cute.make_tiled_mma( @@ -253,7 +253,7 @@ def kernel( # print(tiled_mma) acc_shape = tiled_mma.partition_shape_C( (self.m_block_size, self.head_dim_padded) - if not dQ_swapAB + if cutlass.const_expr(not dQ_swapAB) else (self.head_dim_padded, self.m_block_size) ) acc = cute.make_fragment(acc_shape, cutlass.Float32) @@ -265,7 +265,7 @@ def kernel( # print(acc) # print(tdQsdQaccum) # ((1, 1), 64) # print(tdQrdQaccum) # ((1, 4), 4, 4) - for i in range(cute.size(tdQsdQaccum)): + for i in cutlass.range_constexpr(cute.size(tdQsdQaccum)): tdQrdQaccum[i] = tdQsdQaccum[i] # Convert tdQrdQaccum from fp32 to fp16/bf16 rdQ = cute.make_fragment_like(acc, self.dtype) diff --git a/flash_attn/cute/flash_bwd_preprocess.py b/flash_attn/cute/flash_bwd_preprocess.py index c6955574083..7a2734ec205 100644 --- a/flash_attn/cute/flash_bwd_preprocess.py +++ b/flash_attn/cute/flash_bwd_preprocess.py @@ -233,7 +233,7 @@ def kernel( assert cute.size(tOgO, mode=[0]) == cute.size(tOgdO, mode=[0]) assert cute.size(tOgO, mode=[1]) == cute.size(tOgdO, mode=[1]) assert cute.size(tOgO, mode=[2]) == cute.size(tOgdO, mode=[2]) - for m in range(cute.size(tOrO.shape[1])): + for m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): # Instead of using tOcO, we using t0OcO and subtract the offset from the limit # (seqlen_q - m_block * kBlockM). This is because the entries of t0OcO are known at compile time. if t0OcO[0, m, 0][0] < seqlen_q - m_block * self.m_block_size - tOcO[0][0]: @@ -241,13 +241,13 @@ def kernel( gmem_thr_copy_O, tOgO[None, m, None], tOrO[None, m, None], - pred=tOpO[None, m, None] if self.check_hdim_oob else None, + pred=tOpO[None, m, None] if cutlass.const_expr(self.check_hdim_oob) else None, ) cute.copy( gmem_thr_copy_dO, tOgdO[None, m, None], tOrdO[None, m, None], - pred=tOpdO[None, m, None] if self.check_hdim_oob else None, + pred=tOpdO[None, m, None] if cutlass.const_expr(self.check_hdim_oob) else None, ) # Sum across the "k" dimension dpsum = (tOrO.load().to(cutlass.Float32) * tOrdO.load().to(cutlass.Float32)).reduce( diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index f2fa3e3c2f3..11b34607a1d 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -14,6 +14,7 @@ import cutlass import cutlass.cute as cute +from cutlass import const_expr from cutlass.cute.nvgpu import cpasync, warp, warpgroup import cutlass.utils.ampere_helpers as sm80_utils_basic import cutlass.utils.hopper_helpers as sm90_utils_basic @@ -146,19 +147,19 @@ def _check_type( mSeqUsedK_type: Type[cutlass.Numeric] | None, ): # Get the data type and check if it is fp16 or bf16 - if cutlass.const_expr(not (mQ_type == mK_type == mV_type == mO_type)): + if const_expr(not (mQ_type == mK_type == mV_type == mO_type)): raise TypeError("All tensors must have the same data type") - if cutlass.const_expr(mQ_type not in [cutlass.Float16, cutlass.BFloat16]): + if const_expr(mQ_type not in [cutlass.Float16, cutlass.BFloat16]): raise TypeError("Only Float16 or BFloat16 is supported") - if cutlass.const_expr(mLSE_type not in [None, cutlass.Float32]): + if const_expr(mLSE_type not in [None, cutlass.Float32]): raise TypeError("LSE tensor must be Float32") - if cutlass.const_expr(mCuSeqlensQ_type not in [None, cutlass.Int32]): + if const_expr(mCuSeqlensQ_type not in [None, cutlass.Int32]): raise TypeError("cu_seqlens_q tensor must be Int32") - if cutlass.const_expr(mCuSeqlensK_type not in [None, cutlass.Int32]): + if const_expr(mCuSeqlensK_type not in [None, cutlass.Int32]): raise TypeError("cu_seqlens_k tensor must be Int32") - if cutlass.const_expr(mSeqUsedQ_type not in [None, cutlass.Int32]): + if const_expr(mSeqUsedQ_type not in [None, cutlass.Int32]): raise TypeError("seqused_q tensor must be Int32") - if cutlass.const_expr(mSeqUsedK_type not in [None, cutlass.Int32]): + if const_expr(mSeqUsedK_type not in [None, cutlass.Int32]): raise TypeError("seqused_k tensor must be Int32") assert mQ_type == self.dtype @@ -179,7 +180,7 @@ def _setup_attributes(self): self.sO_layout = cute.tile_to_shape( sO_layout_atom, (self.m_block_size, self.head_dim_v_padded), (0, 1), ) - if cutlass.const_expr(sP_layout_atom is not None): + if const_expr(sP_layout_atom is not None): self.sP_layout = cute.tile_to_shape( sP_layout_atom, (self.m_block_size, self.n_block_size), (0, 1), ) @@ -297,12 +298,12 @@ def epilogue( pack_gqa = PackGQA(self.m_block_size, self.head_dim_padded, self.check_hdim_oob, self.qhead_per_kvhead) # Write LSE from rmem -> gmem - if cutlass.const_expr(mLSE is not None): - if cutlass.const_expr(not seqlen.has_cu_seqlens_q): + if const_expr(mLSE is not None): + if const_expr(not seqlen.has_cu_seqlens_q): mLSE_cur = mLSE[None, head_idx, batch_idx] else: mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[None, head_idx]) - if cutlass.const_expr(not self.pack_gqa): + if const_expr(not self.pack_gqa): gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block,)) gLSE_expanded_layout = cute.append( gLSE.layout, cute.make_layout((self.head_dim_v_padded,), stride=(0,)) @@ -321,7 +322,7 @@ def epilogue( else: pack_gqa.store_LSE(mLSE_cur, lse, tiled_mma, tidx, m_block, seqlen.seqlen_q) - if cutlass.const_expr(not seqlen.has_cu_seqlens_q): + if const_expr(not seqlen.has_cu_seqlens_q): mO_cur = mO[None, None, head_idx, batch_idx] else: mO_cur = cute.domain_offset((seqlen.offset_q, 0), mO[None, None, head_idx]) @@ -329,10 +330,10 @@ def epilogue( # taccOgO = thr_mma.partition_C(gO) # cute.autovec_copy(rO, taccOgO) # sync to make sure all smem stores are done - if cutlass.const_expr(self.use_tma_O): + if const_expr(self.use_tma_O): # ensure smem writes are visible to TMA cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) - utils.barrier_arrive(barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE) + cute.arch.barrier_arrive(barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE) gO = cute.local_tile(mO_cur, (self.m_block_size, self.head_dim_v_padded), (m_block, 0)) tOsO, tOgO = cpasync.tma_partition( tma_atom_O, @@ -354,7 +355,7 @@ def epilogue( tOrO = cute.make_fragment_like(tOsO, self.dtype) # load acc O from smem to rmem for wider vectorization cute.autovec_copy(tOsO, tOrO) - if cutlass.const_expr(not self.pack_gqa): + if const_expr(not self.pack_gqa): gO = cute.local_tile(mO_cur, (self.m_block_size, self.head_dim_v_padded), (m_block, 0)) tOgO = gmem_thr_copy_O.partition_D(gO) tOcO = gmem_thr_copy_O.partition_S(cO) @@ -367,7 +368,7 @@ def epilogue( gmem_tiled_copy_O, tOrO[None, rest_m, None], tOgO[None, rest_m, None], - pred=tOpO[None, rest_m, None] if self.check_hdim_v_oob else None, + pred=tOpO[None, rest_m, None] if const_expr(self.check_hdim_v_oob) else None, ) else: pack_gqa.store_O(mO_cur, tOrO, gmem_tiled_copy_O, tidx, m_block, seqlen.seqlen_q) @@ -391,7 +392,7 @@ def load_Q( tQcQ = gmem_thr_copy.partition_S(cQ) t0QcQ = gmem_thr_copy.get_slice(0).partition_S(cQ) tQpQ = utils.predicate_k(tQcQ, limit=headdim) - for m in range(cute.size(tQsQ.shape[1])): + for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])): # Instead of using tQcQ, we using t0QcQ and subtract the offset from the limit # (seqlen - block * kBlockM). This is because the entries of t0QcQ are known at compile time. if t0QcQ[0, m, 0][0] < seqlen - block * self.m_block_size - tQcQ[0][0]: @@ -399,7 +400,7 @@ def load_Q( gmem_thr_copy, tQgQ[None, m, None], tQsQ[None, m, None], - pred=tQpQ[None, m, None] if self.check_hdim_oob else None, + pred=tQpQ[None, m, None] if const_expr(self.check_hdim_oob) else None, ) # We don't need to clear the sQ smem tiles since we'll only write out the valid outputs @@ -419,32 +420,32 @@ def load_K( ): # Do we need to check if we overshoot kBlockN when we load K? is_even_n_smem_k = self.n_block_size % gmem_tiled_copy.tiler_mn[0].shape == 0 - if cutlass.const_expr(need_predicates or not is_even_n_smem_k): + if const_expr(need_predicates or not is_even_n_smem_k): # Instead of using tKcK, we using t0KcK and subtract the offset from the limit # (seqlen - block * kBlockN). This is because the entries of t0KcK are known at compile time. - if cutlass.const_expr(is_even_n_smem_k): + if const_expr(is_even_n_smem_k): seqlen_limit = seqlen - block * self.n_block_size else: - if cutlass.const_expr(not need_predicates): + if const_expr(not need_predicates): seqlen_limit = self.n_block_size else: seqlen_limit = cutlass.min(seqlen - block * self.n_block_size, self.n_block_size) seqlen_limit -= tKcK[0][0] - for n in range(cute.size(tKsK.shape[1])): + for n in cutlass.range_constepxr(cute.size(tKsK.shape[1])): if t0KcK[0, n, 0][0] < seqlen_limit: cute.copy( gmem_tiled_copy, tKgK[None, n, None, block], - tKsK[None, n, None, smem_pipe_write if self.num_stages > 1 else 0], - pred=tKpK[None, n, None] if self.check_hdim_oob else None, + tKsK[None, n, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0], + pred=tKpK[None, n, None] if const_expr(self.check_hdim_oob) else None, ) # We don't need to clear the sK smem tiles since we'll mask out the scores anyway. else: cute.copy( gmem_tiled_copy, tKgK[None, None, None, block], - tKsK[None, None, None, smem_pipe_write if self.num_stages > 1 else 0], - pred=tKpK if self.check_hdim_oob else None, + tKsK[None, None, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0], + pred=tKpK if const_expr(self.check_hdim_oob) else None, ) @cute.jit @@ -463,30 +464,30 @@ def load_V( ): # Do we need to check if we overshoot kBlockN when we load V? is_even_n_smem_v = self.n_block_size % gmem_tiled_copy.tiler_mn[0].shape == 0 - if cutlass.const_expr(need_predicates or not is_even_n_smem_v): - for n in range(cute.size(tVsV.shape[1])): + if const_expr(need_predicates or not is_even_n_smem_v): + for n in cutlass.range_constepxr(cute.size(tVsV.shape[1])): # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked if is_even_n_smem_v or n < cute.size(tVsV.shape[1]) - 1 or tVcV[0, n, 0][0] < self.n_block_size: - predicate = tVpV[None, n, None] if self.check_hdim_v_oob else None - if cutlass.const_expr(need_predicates): + predicate = tVpV[None, n, None] if const_expr(self.check_hdim_v_oob) else None + if const_expr(need_predicates): seqlen_limit = seqlen - block * self.n_block_size - tVcV[0][0] predicate_n = t0VcV[0, n, 0][0] < seqlen_limit predicate = cute.make_fragment_like(tVpV[None, 0, None]) - for k in range(cute.size(predicate.shape[1])): - for i in range(cute.size(predicate.shape[0])): - predicate[i, k] = (tVpV[i, n, k] if self.check_hdim_v_oob else True) and predicate_n + for k in cutlass.range_constepxr(cute.size(predicate.shape[1])): + for i in cutlass.range_constepxr(cute.size(predicate.shape[0])): + predicate[i, k] = (tVpV[i, n, k] if const_expr(self.check_hdim_v_oob) else True) and predicate_n cute.copy( gmem_tiled_copy, tVgV[None, n, None, block], - tVsV[None, n, None, smem_pipe_write if self.num_stages > 1 else 0], + tVsV[None, n, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0], pred=predicate, ) else: cute.copy( gmem_tiled_copy, tVgV[None, None, None, block], - tVsV[None, None, None, smem_pipe_write if self.num_stages > 1 else 0], - pred=tVpV if self.check_hdim_v_oob else None, + tVsV[None, None, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0], + pred=tVpV if const_expr(self.check_hdim_v_oob) else None, ) @@ -518,7 +519,7 @@ def _get_shared_storage_cls(self): cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], 1024] for layout in (self.sQ_layout, self.sK_layout, self.sV_layout) ] - cosize_sQV = utils.max_constexpr(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout)) + cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout)) sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024] @cute.struct @@ -532,7 +533,7 @@ class SharedStorageSharedQV: sQ: sQV_struct sK: sK_struct - return SharedStorageQKV if cutlass.const_expr(not self.Q_in_regs) else SharedStorageSharedQV + return SharedStorageQKV if const_expr(not self.Q_in_regs) else SharedStorageSharedQV @cute.jit def __call__( @@ -577,7 +578,7 @@ def __call__( # (assigning it to softcap_val) and pre-multiply softcap_val * log2(e) # (assigning it to softmax_scale_log2). LOG2_E = math.log2(math.e) - if cutlass.const_expr(softcap is not None): + if const_expr(softcap is not None): softmax_scale_log2 = softmax_scale * LOG2_E softcap_val = None else: @@ -644,7 +645,7 @@ def kernel( block_info = BlockInfo( self.m_block_size, self.n_block_size, self.is_causal, self.is_local, window_size_left, window_size_right, - qhead_per_kvhead_packgqa=self.qhead_per_kvhead if self.pack_gqa else 1, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) seqlen = SeqlenInfo(seqlen_q=mQ.shape[0], seqlen_k=mK.shape[0]) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) @@ -672,7 +673,7 @@ def kernel( storage = smem.allocate(SharedStorage) sQ = storage.sQ.get_tensor(sQ_layout) sK = storage.sK.get_tensor(sK_layout) - if cutlass.const_expr(not self.Q_in_regs): + if const_expr(not self.Q_in_regs): sV = storage.sV.get_tensor(sV_layout) else: sV = cute.make_tensor(cute.recast_ptr(sQ.iterator, dtype=self.dtype), sV_layout) @@ -723,7 +724,7 @@ def kernel( cK = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded)) tKcK = gmem_thr_copy_K.partition_S(cK) t0KcK = gmem_thr_copy_K.get_slice(0).partition_S(cK) - if cutlass.const_expr(self.head_dim_padded == self.head_dim_v_padded): + if const_expr(self.head_dim_padded == self.head_dim_v_padded): tVcV = tKcK t0VcV = t0KcK else: @@ -734,7 +735,7 @@ def kernel( # use "if" on the mn dimension. # This is to reduce register pressure and gets 2-3% performance gain. tKpK = utils.predicate_k(tKcK, limit=mK.shape[1]) - if cutlass.const_expr(self.same_hdim_kv): + if const_expr(self.same_hdim_kv): tVpV = tKpK else: tVpV = utils.predicate_k(tVcV, limit=mV.shape[1]) @@ -761,7 +762,7 @@ def kernel( # Softcapping needs to happen before masking since if we apply after masking, softcapping can turn # -inf to e.g. -50.0, which can affect the attention softmax. def scoremod_premask_fn(acc_S): - if cutlass.const_expr(softcap_val is not None): + if const_expr(softcap_val is not None): acc_S.store(cute.math.tanh(acc_S.load() * softcap_val, fastmath=True)) compute_one_n_block = partial( @@ -779,7 +780,7 @@ def scoremod_premask_fn(acc_S): def preprocess_Q(): cute.arch.cp_async_wait_group(self.num_stages * 2 - 1) - if cutlass.const_expr(self.Q_in_regs): + if const_expr(self.Q_in_regs): cute.arch.barrier() tSrQ_copy_view = smem_thr_copy_Q.retile(tSrQ) cute.copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view) @@ -787,22 +788,22 @@ def preprocess_Q(): # If Q_in_regs, we load Q, then load 1 stage of K, then (optionally) rotate Q and # read from smem_q to registers, then load V. # If !Q_in_regs, we load Q, load all stages of K & V, then (optionally) rotate Q. - if cutlass.const_expr(self.Q_in_regs): + if const_expr(self.Q_in_regs): load_K(n_block, smem_pipe_write=0, need_predicates=True) cute.arch.cp_async_commit_group() preprocess_Q() cute.arch.barrier() # Make sure all threads have read smem_q before loading V - for stage in range(self.num_stages): - if cutlass.const_expr(not self.Q_in_regs or stage > 0): + for stage in cutlass.range_constepxr(self.num_stages): + if const_expr(not self.Q_in_regs or stage > 0): if stage == 0 or n_block - stage >= 0: load_K(n_block - stage, smem_pipe_write=stage, need_predicates=stage==0) cute.arch.cp_async_commit_group() - if stage < self.num_stages - 1: + if const_expr(stage < self.num_stages - 1): if stage == 0 or n_block - stage >= 0: load_V(n_block - stage, smem_pipe_write=stage, need_predicates=stage==0) cute.arch.cp_async_commit_group() - if cutlass.const_expr(not self.Q_in_regs): + if const_expr(not self.Q_in_regs): preprocess_Q() # /////////////////////////////////////////////////////////////////////////////// @@ -816,7 +817,7 @@ def preprocess_Q(): mask = AttentionMask( self.m_block_size, self.n_block_size, seqlen.seqlen_q, seqlen.seqlen_k, window_size_left, window_size_right, - self.qhead_per_kvhead if self.pack_gqa else 1, + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) mask_fn = partial( mask.apply_mask, m_block=m_block, thr_mma=thr_mma_qk, @@ -831,20 +832,18 @@ def preprocess_Q(): smem_pipe_read = self.advance_pipeline(smem_pipe_read) smem_pipe_write = self.advance_pipeline(smem_pipe_write) # Next couple of iterations with causal masking - if self.is_causal or self.is_local: + if const_expr(self.is_causal or self.is_local): n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( seqlen, m_block, n_block_min ) - # Currently we can't do loop with negative step - # https://github.com/NVIDIA/cutlass/issues/2326 - for n_tile in cutlass.range_dynamic(n_block_max - 1 - n_block_min_causal_local_mask, unroll=1): + for n_tile in cutlass.range(n_block_max - 1 - n_block_min_causal_local_mask, unroll=1): n_block = n_block_max - 2 - n_tile compute_one_n_block(n_block, smem_pipe_read, smem_pipe_write, check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=False)) smem_pipe_read = self.advance_pipeline(smem_pipe_read) smem_pipe_write = self.advance_pipeline(smem_pipe_write) # The remaining iterations have no masking - for n_tile in cutlass.range_dynamic(n_block, unroll=1): + for n_tile in cutlass.range(n_block, unroll=1): compute_one_n_block(n_block - n_tile - 1, smem_pipe_read, smem_pipe_write, check_inf=True) smem_pipe_read = self.advance_pipeline(smem_pipe_read) smem_pipe_write = self.advance_pipeline(smem_pipe_write) @@ -904,7 +903,7 @@ def load_V_next(): sm80_utils.gemm( mma_params.thr_mma_qk, acc_S, mma_params.tSrQ, mma_params.tSrK, smem_copy_params.tSsQ, - smem_copy_params.tSsK[None, None, None, smem_pipe_read if self.num_stages > 1 else 0], + smem_copy_params.tSsK[None, None, None, smem_pipe_read if const_expr(self.num_stages > 1) else 0], smem_copy_params.smem_thr_copy_Q, smem_copy_params.smem_thr_copy_K, # hook_fn=load_V_next, A_in_regs=self.Q_in_regs, @@ -916,26 +915,26 @@ def load_K_next(): load_K(n_block - self.num_stages, smem_pipe_write, need_predicates=False) cute.arch.cp_async_commit_group() # wait for smem tile V for O - if cutlass.const_expr(self.num_stages == 1): + if const_expr(self.num_stages == 1): sync() load_K_next() - if cutlass.const_expr(mask_fn is not None): + if const_expr(mask_fn is not None): mask_fn(acc_S, n_block=n_block) row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf) softmax.rescale_O(mma_params.acc_O, row_scale) rP = cute.make_fragment_like(acc_S, self.dtype) rP.store(acc_S.load().to(self.dtype)) tOrP = cute.make_tensor(rP.iterator, utils.convert_layout_acc_frgA(rP.layout)) - if cutlass.const_expr(self.num_stages > 1): + if const_expr(self.num_stages > 1): sync() load_K_next() sm80_utils.gemm_rs( mma_params.thr_mma_pv, mma_params.acc_O, tOrP, mma_params.tOrVt, - smem_copy_params.tOsVt[None, None, None, smem_pipe_read if self.num_stages > 1 else 0], + smem_copy_params.tOsVt[None, None, None, smem_pipe_read if const_expr(self.num_stages > 1) else 0], smem_copy_params.smem_thr_copy_V, # hook_fn=load_K_next, ) - # if cutlass.const_expr(self.num_stages > 1): + # if const_expr(self.num_stages > 1): # load_K_next() @@ -993,7 +992,7 @@ def _get_tiled_mma(self): def _get_shared_storage_cls(self): # If PackGQA, we use cp.async to load Q, so we want sQ to align to 1024 bytes - sQ_alignment = 128 if not self.pack_gqa else 1024 + sQ_alignment = 128 if const_expr(not self.pack_gqa) else 1024 sK_alignment = 128 sV_alignment = 128 sQ_struct, sK_struct, sV_struct = [ @@ -1003,9 +1002,9 @@ def _get_shared_storage_cls(self): (sQ_alignment, sK_alignment, sV_alignment) ) ] - cosize_sQV = utils.max_constexpr(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout)) + cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout)) sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024] - cosize_sP = cute.cosize(self.sP_layout) if self.sP_layout is not None else 0 + cosize_sP = cute.cosize(self.sP_layout) if const_expr(self.sP_layout is not None) else 0 sP_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sP], 1024] # 1 for Q, 1 for O, self.num_stages*2 for K, self.num_stages*2 for V, mbar_ptr_QO_struct = cute.struct.MemRange[cutlass.Int64, 2] @@ -1031,7 +1030,7 @@ class SharedStorageSharedQV: sK: sK_struct sP: sP_struct - return SharedStorageQKV if cutlass.const_expr(not self.Q_in_regs) else SharedStorageSharedQV + return SharedStorageQKV if const_expr(not self.Q_in_regs) else SharedStorageSharedQV @cute.jit def __call__( @@ -1061,18 +1060,18 @@ def __call__( *(t.element_type if t is not None else None for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)) ) - QO_layout_transpose = [1, 3, 2, 0] if cutlass.const_expr(mCuSeqlensQ is None) else [0, 2, 1] + QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] mQ, mO = [ cute.make_tensor(t.iterator, cute.select(t.layout, mode=QO_layout_transpose)) for t in (mQ, mO) ] - KV_layout_transpose = [1, 3, 2, 0] if cutlass.const_expr(mCuSeqlensK is None) else [0, 2, 1] + KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1] mK, mV = [ cute.make_tensor(t.iterator, cute.select(t.layout, mode=KV_layout_transpose)) for t in (mK, mV) ] - LSE_layout_transpose = [2, 1, 0] if cutlass.const_expr(mCuSeqlensQ is None) else [1, 0] - mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) if mLSE is not None else None + LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] + mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) if const_expr(mLSE is not None) else None tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma() self.num_mma_threads = tiled_mma_qk.size self.num_threads_per_warp_group = 128 @@ -1084,7 +1083,7 @@ def __call__( self.num_producer_regs = 24 # self.num_mma_regs = 232 # self.num_producer_regs = 40 - self.use_scheduler_barrier = (self.num_mma_warp_groups >= 2 and self.head_dim_padded <= 128) if self.intra_wg_overlap else (self.num_mma_warp_groups == 2) + self.use_scheduler_barrier = (self.num_mma_warp_groups >= 2 and self.head_dim_padded <= 128) if const_expr(self.intra_wg_overlap) else (self.num_mma_warp_groups == 2) self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None and not self.pack_gqa # TODO: rescale_O_before_gemm self._setup_attributes() @@ -1096,45 +1095,45 @@ def __call__( self.tma_copy_q_bytes = cute.size_in_bytes(mQ.element_type, cute.select(self.sQ_layout, mode=[0, 1])) self.tma_copy_k_bytes = cute.size_in_bytes(mK.element_type, cute.select(self.sK_layout, mode=[0, 1])) self.tma_copy_v_bytes = cute.size_in_bytes(mV.element_type, cute.select(self.sV_layout, mode=[0, 1])) - tma_atom_Q, tma_tensor_Q = cpasync.make_tma_tile_atom( + tma_atom_Q, tma_tensor_Q = cpasync.make_tiled_tma_atom( gmem_tiled_copy_Q, mQ, self.sQ_layout, (self.m_block_size, self.head_dim_padded), # No mcast ) - tma_atom_K, tma_tensor_K = cpasync.make_tma_tile_atom( + tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom( gmem_tiled_copy_KV, mK, cute.select(self.sK_layout, mode=[0, 1]), (self.n_block_size, self.head_dim_padded), 1 # No mcast for now ) - tma_atom_V, tma_tensor_V = cpasync.make_tma_tile_atom( + tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom( gmem_tiled_copy_KV, mV, cute.select(self.sV_layout, mode=[0, 1]), (self.n_block_size, self.head_dim_v_padded), 1 # No mcast for now ) - if cutlass.const_expr(self.use_tma_O): - tma_atom_O, mO = cpasync.make_tma_tile_atom( + if const_expr(self.use_tma_O): + tma_atom_O, mO = cpasync.make_tiled_tma_atom( gmem_tiled_copy_O, mO, self.sO_layout, (self.m_block_size, self.head_dim_v_padded), # No mcast ) else: tma_atom_O = None - if cutlass.const_expr(self.pack_gqa): + if const_expr(self.pack_gqa): shape_Q_packed = ((self.qhead_per_kvhead, mQ.shape[0]), mQ.shape[1], mK.shape[2], *mQ.shape[3:]) stride_Q_packed = ((mQ.stride[2], mQ.stride[0]), mQ.stride[1], mQ.stride[2] * self.qhead_per_kvhead, *mQ.stride[3:]) mQ = cute.make_tensor(mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed)) shape_O_packed = ((self.qhead_per_kvhead, mO.shape[0]), mK.shape[1], mK.shape[2], *mO.shape[3:]) stride_O_packed = ((mO.stride[2], mO.stride[0]), mO.stride[1], mO.stride[2] * self.qhead_per_kvhead, *mO.stride[3:]) mO = cute.make_tensor(mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed)) - if cutlass.const_expr(mLSE is not None): + if const_expr(mLSE is not None): shape_LSE_packed = ((self.qhead_per_kvhead, mLSE.shape[0]), mK.shape[2], *mLSE.shape[2:]) stride_LSE_packed = ((mLSE.stride[1], mLSE.stride[0]), mLSE.stride[1] * self.qhead_per_kvhead, *mLSE.stride[2:]) mLSE = cute.make_tensor(mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed)) # grid_dim: (m_block, num_head, batch_size) grid_dim = ( - cute.ceil_div(cute.size(mQ.shape[0]) if mCuSeqlensQ is None else max_seqlen_q, self.m_block_size), + cute.ceil_div(cute.size(mQ.shape[0]) if const_expr(mCuSeqlensQ is None) else max_seqlen_q, self.m_block_size), cute.size(mQ.shape[2]), - cute.size(mQ.shape[3] if mCuSeqlensQ is None else mCuSeqlensQ.shape[0] - 1), + cute.size(mQ.shape[3] if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ.shape[0] - 1), ) # If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. # Right after this, we multiply by log2(e) before applying exp2. @@ -1142,18 +1141,18 @@ def __call__( # (assigning it to softcap_val) and pre-multiply softcap_val * log2(e) # (assigning it to softmax_scale_log2). LOG2_E = math.log2(math.e) - if cutlass.const_expr(softcap is None): + if const_expr(softcap is None): softmax_scale_log2 = softmax_scale * LOG2_E softcap_val = None else: softmax_scale_log2 = softcap * LOG2_E softcap_val = cutlass.Float32(softmax_scale / softcap) - if cutlass.const_expr(window_size_left is not None): + if const_expr(window_size_left is not None): window_size_left = cutlass.Int32(window_size_left) - if cutlass.const_expr(window_size_right is not None): + if const_expr(window_size_right is not None): window_size_right = cutlass.Int32(window_size_right) self.kernel( - tma_tensor_Q if not self.pack_gqa else mQ, + tma_tensor_Q if const_expr(not self.pack_gqa) else mQ, tma_tensor_K, tma_tensor_V, mO, @@ -1233,11 +1232,11 @@ def kernel( warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) # Prefetch tma descriptor if warp_idx == 0: - if cutlass.const_expr(not self.pack_gqa): + if const_expr(not self.pack_gqa): cpasync.prefetch_descriptor(tma_atom_Q) cpasync.prefetch_descriptor(tma_atom_K) cpasync.prefetch_descriptor(tma_atom_V) - if cutlass.const_expr(self.use_tma_O): + if const_expr(self.use_tma_O): cpasync.prefetch_descriptor(tma_atom_O) smem = cutlass.utils.SmemAllocator() @@ -1248,13 +1247,13 @@ def kernel( if warp_idx == 0: # if tidx < 2: # # barrierO num threads should be self.num_mma_threads - # cute.arch.mbarrier_init_arrive_cnt(mbar_ptr_Q + tidx, 1 if tidx == 0 else self.num_mma_threads) - cute.arch.mbarrier_init_arrive_cnt(mbar_ptr_Q, 1 if not self.pack_gqa else self.num_Q_load_threads) - # cute.arch.mbarrier_init_arrive_cnt(mbar_ptr_Q + 1, self.num_mma_threads) + # cute.arch.mbarrier_init(mbar_ptr_Q + tidx, 1 if tidx == 0 else self.num_mma_threads) + cute.arch.mbarrier_init(mbar_ptr_Q, 1 if const_expr(not self.pack_gqa) else self.num_Q_load_threads) + # cute.arch.mbarrier_init(mbar_ptr_Q + 1, self.num_mma_threads) # We rely on pipeline_k and pipeline_v to initialize the mbarrier fence and sync - pipeline_kv_producer_group = cutlass.utils.CooperativeGroup(cutlass.utils.Agent.Thread) - pipeline_kv_consumer_group = cutlass.utils.CooperativeGroup( - cutlass.utils.Agent.Thread, self.num_mma_threads // self.num_threads_per_warp_group + pipeline_kv_producer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread) + pipeline_kv_consumer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, self.num_mma_threads // self.num_threads_per_warp_group ) pipeline_k = pipeline.PipelineTmaAsyncNoCluster.create( barrier_storage=storage.mbar_ptr_K.data_ptr(), @@ -1278,11 +1277,11 @@ def kernel( # TODO: how to get sQ_pi for cp.async if pack_gqa? sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) - if cutlass.const_expr(not self.Q_in_regs): + if const_expr(not self.Q_in_regs): sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) else: sV = storage.sQ.get_tensor(sV_layout.outer, swizzle=sV_layout.inner, dtype=mV.element_type) - if cutlass.const_expr(sP_layout is not None): + if const_expr(sP_layout is not None): sP_pi = storage.sP.get_tensor(sP_layout) sP = storage.sP.get_tensor(sP_layout.outer, swizzle=sP_layout.inner) else: @@ -1296,10 +1295,10 @@ def kernel( block_info = BlockInfo( self.m_block_size, self.n_block_size, self.is_causal, self.is_local, window_size_left, window_size_right, - qhead_per_kvhead_packgqa=self.qhead_per_kvhead if self.pack_gqa else 1, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) SeqlenInfoCls = partial( - SeqlenInfo, seqlen_q_static=mQ.shape[0] if not self.pack_gqa else mQ.shape[0][1], + SeqlenInfo, seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1], seqlen_k_static=mK.shape[0], mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK, @@ -1307,13 +1306,13 @@ def kernel( AttentionMaskCls = partial( AttentionMask, self.m_block_size, self.n_block_size, window_size_left=window_size_left, window_size_right=window_size_right, - qhead_per_kvhead_packgqa=self.qhead_per_kvhead if self.pack_gqa else 1, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) seqlen = SeqlenInfoCls(batch_idx) # Can't early exit so we have to write it this way (under an if statement) if mCuSeqlensQ is None or m_block * self.n_block_size < seqlen.seqlen_q: - if cutlass.const_expr(self.is_causal): # Longest tile first - m_block = cute.ceil_div(seqlen.seqlen_q * (self.qhead_per_kvhead if self.pack_gqa else 1), self.m_block_size) - m_block - 1 + if const_expr(self.is_causal): # Longest tile first + m_block = cute.ceil_div(seqlen.seqlen_q * (self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1), self.m_block_size) - m_block - 1 # TODO: return early if n_block_max == 0 # if self.is_causal: # if n_block_max <= 0: @@ -1397,8 +1396,8 @@ def load( tma_atom_Q: cute.CopyAtom, tma_atom_K: cute.CopyAtom, tma_atom_V: cute.CopyAtom, - pipeline_k: cutlass.utils.PipelineAsync, - pipeline_v: cutlass.utils.PipelineAsync, + pipeline_k: cutlass.pipeline.PipelineAsync, + pipeline_v: cutlass.pipeline.PipelineAsync, mbar_ptr_Q: cutlass.Pointer, block_info: BlockInfo, SeqlenInfoCls: Callable, @@ -1406,20 +1405,20 @@ def load( warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 m_block, head_idx, batch_idx = cute.arch.block_idx() seqlen = SeqlenInfoCls(batch_idx) - if cutlass.const_expr(self.is_causal): # Longest tile first - m_block = cute.ceil_div(seqlen.seqlen_q * (self.qhead_per_kvhead if self.pack_gqa else 1), self.m_block_size) - m_block - 1 - if cutlass.const_expr(not seqlen.has_cu_seqlens_q): + if const_expr(self.is_causal): # Longest tile first + m_block = cute.ceil_div(seqlen.seqlen_q * (self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1), self.m_block_size) - m_block - 1 + if const_expr(not seqlen.has_cu_seqlens_q): mQ_cur = mQ[None, None, head_idx, batch_idx] else: mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, None, head_idx]) - head_idx_kv = head_idx // self.qhead_per_kvhead if not self.pack_gqa else head_idx - if cutlass.const_expr(not seqlen.has_cu_seqlens_k): + head_idx_kv = head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx + if const_expr(not seqlen.has_cu_seqlens_k): mK_cur, mV_cur = [t[None, None, head_idx_kv, batch_idx] for t in (mK, mV)] else: mK_cur, mV_cur = [cute.domain_offset((seqlen.offset_k, 0), t[None, None, head_idx_kv]) for t in (mK, mV)] gK = cute.local_tile(mK_cur, (self.n_block_size, self.head_dim_padded), (None, 0)) gV = cute.local_tile(mV_cur, (self.n_block_size, self.head_dim_v_padded), (None, 0)) - if cutlass.const_expr(not self.pack_gqa): + if const_expr(not self.pack_gqa): gQ = cute.local_tile(mQ_cur, (self.m_block_size, self.head_dim_padded), (m_block, 0)) tQsQ, tQgQ = cpasync.tma_partition( tma_atom_Q, @@ -1443,20 +1442,20 @@ def load( cute.group_modes(gV, 0, 2), ) kv_producer_state = pipeline.make_pipeline_state( - cutlass.utils.PipelineUserType.Producer, self.num_stages + cutlass.pipeline.PipelineUserType.Producer, self.num_stages ) load_K = partial(self.load_K, tma_atom_K, tKgK, tKsK, pipeline_k) load_V = partial(self.load_K, tma_atom_V, tVgV, tVsV, pipeline_v) if warp_idx_in_wg == 0: # load_Q - if cutlass.const_expr(not self.pack_gqa): + if const_expr(not self.pack_gqa): with cute.arch.elect_one(): - cute.arch.mbarrier_init_tx_bytes(mbar_ptr_Q, self.tma_copy_q_bytes) + cute.arch.mbarrier_arrive_and_expect_tx(mbar_ptr_Q, self.tma_copy_q_bytes) cute.copy(tma_atom_Q, tQgQ, tQsQ, tma_bar_ptr=mbar_ptr_Q) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) # if cute.arch.thread_idx()[0] == 0: # cute.printf("m_block = %d, n_block_min: %d, n_block_max: %d", m_block, n_block_min, n_block_max) - for i in cutlass.range_dynamic(n_block_max - n_block_min, unroll=2): + for i in cutlass.range(n_block_max - n_block_min, unroll=2): n_block = n_block_max - i - 1 load_K(n_block, producer_state=kv_producer_state) load_V(n_block, producer_state=kv_producer_state) @@ -1474,8 +1473,8 @@ def mma( sK: cute.Tensor, sVt: cute.Tensor, sP: cute.Tensor | None, - pipeline_k: cutlass.utils.PipelineAsync, - pipeline_v: cutlass.utils.PipelineAsync, + pipeline_k: cutlass.pipeline.PipelineAsync, + pipeline_v: cutlass.pipeline.PipelineAsync, mbar_ptr_Q: cutlass.Pointer, gmem_tiled_copy_Q: cute.TiledCopy, tidx: cutlass.Int32, @@ -1499,7 +1498,7 @@ def mma( wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx)) tSrQ = tiled_mma_qk.make_fragment_A(wg_mma_qk.partition_A(sQ)) tSrK = tiled_mma_qk.make_fragment_B(wg_mma_qk.partition_B(sK)) - tOrP = tiled_mma_pv.make_fragment_A(wg_mma_pv.partition_A(sP)) if cutlass.const_expr(sP is not None) else None + tOrP = tiled_mma_pv.make_fragment_A(wg_mma_pv.partition_A(sP)) if const_expr(sP is not None) else None tOrVt = tiled_mma_pv.make_fragment_B(wg_mma_pv.partition_B(sVt)) # /////////////////////////////////////////////////////////////////////////////// @@ -1507,8 +1506,8 @@ def mma( # /////////////////////////////////////////////////////////////////////////////// smem_copy_atom_P = utils.get_smem_store_atom(self.arch, self.dtype) smem_thr_copy_P = utils.make_tiled_copy_C(smem_copy_atom_P, tiled_mma_qk).get_slice(tidx) - # tPsP = smem_thr_copy_P.partition_D(sP_pi) if cutlass.const_expr(sP_pi is not None) else None - tPsP = smem_thr_copy_P.partition_D(sP) if cutlass.const_expr(sP is not None) else None + # tPsP = smem_thr_copy_P.partition_D(sP_pi) if const_expr(sP_pi is not None) else None + tPsP = smem_thr_copy_P.partition_D(sP) if const_expr(sP is not None) else None # if cute.arch.thread_idx()[0] == 0: # cute.printf(sP_pi.layout, sP_pi.iterator) # cute.printf(sP.layout, sP.iterator) @@ -1524,11 +1523,11 @@ def mma( # Softcapping needs to happen before masking since if we apply after masking, softcapping can turn # -inf to e.g. -50.0, which can affect the attention softmax. def scoremod_premask_fn(acc_S): - if cutlass.const_expr(softcap_val is not None): + if const_expr(softcap_val is not None): acc_S.store(cute.math.tanh(acc_S.load() * softcap_val, fastmath=True)) mma_one_n_block = partial( - self.mma_one_n_block_intrawg_overlap if cutlass.const_expr(self.intra_wg_overlap) else self.mma_one_n_block, + self.mma_one_n_block_intrawg_overlap if const_expr(self.intra_wg_overlap) else self.mma_one_n_block, pipeline_k=pipeline_k, pipeline_v=pipeline_v, mma_params=mma_params, smem_copy_params=smem_copy_params, softmax=softmax, scoremod_premask_fn=scoremod_premask_fn, @@ -1536,8 +1535,8 @@ def scoremod_premask_fn(acc_S): m_block, head_idx, batch_idx = cute.arch.block_idx() seqlen = SeqlenInfoCls(batch_idx) - if cutlass.const_expr(self.is_causal): # Longest tile first - m_block = cute.ceil_div(seqlen.seqlen_q * (self.qhead_per_kvhead if self.pack_gqa else 1), self.m_block_size) - m_block - 1 + if const_expr(self.is_causal): # Longest tile first + m_block = cute.ceil_div(seqlen.seqlen_q * (self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1), self.m_block_size) - m_block - 1 mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) mask_fn = partial( @@ -1545,9 +1544,9 @@ def scoremod_premask_fn(acc_S): mask_causal=self.is_causal, mask_local=self.is_local, ) # Load Q if PackGQA - if cutlass.const_expr(self.pack_gqa): + if const_expr(self.pack_gqa): pack_gqa = PackGQA(self.m_block_size, self.head_dim_padded, self.check_hdim_oob, self.qhead_per_kvhead) - if cutlass.const_expr(not seqlen.has_cu_seqlens_q): + if const_expr(not seqlen.has_cu_seqlens_q): mQ_cur = mQ[None, None, head_idx, batch_idx] else: mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, None, head_idx]) @@ -1560,7 +1559,7 @@ def scoremod_premask_fn(acc_S): n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) consumer_state = pipeline.make_pipeline_state( - cutlass.utils.PipelineUserType.Consumer, self.num_stages + cutlass.pipeline.PipelineUserType.Consumer, self.num_stages ) cute.arch.mbarrier_wait(mbar_ptr_Q, phase=0) softmax.reset() @@ -1569,7 +1568,7 @@ def scoremod_premask_fn(acc_S): # We need masking on S for the very last block when K and V has length not multiple of n_block_size. # We also need masking on S if it's causal, for the last several blocks. # First iteration with seqlen masking - if cutlass.const_expr(self.intra_wg_overlap): + if const_expr(self.intra_wg_overlap): acc_S = cute.make_fragment( tiled_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)), cutlass.Float32 ) @@ -1603,13 +1602,12 @@ def scoremod_premask_fn(acc_S): # if cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block_max = {}, n_block_min = {}", m_block, n_block_max, n_block_min) n_block_max -= 1 # Next couple of iterations with causal masking - if cutlass.const_expr(self.is_causal or self.is_local): + if const_expr(self.is_causal or self.is_local): n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( seqlen, m_block, n_block_min ) # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_causal_local_mask = {}", n_block_min_causal_local_mask) - # Currently we can't do loop with negative step https://github.com/NVIDIA/cutlass/issues/2326 - for n_tile in cutlass.range_dynamic(0, n_block_max - n_block_min_causal_local_mask, unroll=1): + for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1): n_block = n_block_max - 1 - n_tile consumer_state = mma_one_n_block( n_block, consumer_state, tiled_mma_qk_copy, tiled_mma_pv_copy, @@ -1621,22 +1619,22 @@ def scoremod_premask_fn(acc_S): seqlen, m_block, n_block_min ) # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_before_local_mask = {}, n_block_min = {}", n_block_min_before_local_mask, n_block_min) - for n_tile in cutlass.range_dynamic(0, n_block_max - n_block_min_before_local_mask, unroll=1): + for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): n_block = n_block_max - 1 - n_tile consumer_state = mma_one_n_block( n_block, consumer_state, tiled_mma_qk_copy1, tiled_mma_pv_copy1, check_inf=True, ) # Separate iterations with local masking on the left - if cutlass.const_expr(self.is_local and block_info.window_size_left is not None): + if const_expr(self.is_local and block_info.window_size_left is not None): n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) - for n_tile in cutlass.range_dynamic(0, n_block_max - n_block_min, unroll=1): + for n_tile in cutlass.range(n_block_max - n_block_min, unroll=1): n_block = n_block_max - 1 - n_tile consumer_state = mma_one_n_block( n_block, consumer_state, tiled_mma_qk_copy2, tiled_mma_pv_copy2, check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=False) ) # Last "half" iteration - if cutlass.const_expr(self.intra_wg_overlap): + if const_expr(self.intra_wg_overlap): pipeline_v.consumer_wait(consumer_state, pipeline_v.consumer_try_wait(consumer_state)) sm90_utils.gemm( tiled_mma_pv, mma_params.acc_O, mma_params.tOrP, @@ -1657,11 +1655,11 @@ def scoremod_premask_fn(acc_S): def mma_one_n_block( self, n_block: cutlass.Int32, - smem_pipe_read: cutlass.utils.PipelineState | pipeline.PipelineStateSimple, + smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, - pipeline_k: cutlass.utils.PipelineAsync, - pipeline_v: cutlass.utils.PipelineAsync, + pipeline_k: cutlass.pipeline.PipelineAsync, + pipeline_v: cutlass.pipeline.PipelineAsync, mma_params: SimpleNamespace, smem_copy_params: SimpleNamespace, softmax: Softmax, @@ -1683,7 +1681,7 @@ def mma_one_n_block( warpgroup.wait_group(0) pipeline_k.consumer_release(smem_pipe_read) scoremod_premask_fn(acc_S) - if cutlass.const_expr(mask_fn is not None): + if const_expr(mask_fn is not None): mask_fn(acc_S, n_block=n_block) row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf) # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) @@ -1711,11 +1709,11 @@ def mma_one_n_block( def mma_one_n_block_intrawg_overlap( self, n_block: cutlass.Int32, - smem_pipe_read: cutlass.utils.PipelineState | pipeline.PipelineStateSimple, + smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, - pipeline_k: cutlass.utils.PipelineAsync, - pipeline_v: cutlass.utils.PipelineAsync, + pipeline_k: cutlass.pipeline.PipelineAsync, + pipeline_v: cutlass.pipeline.PipelineAsync, mma_params: SimpleNamespace, smem_copy_params: SimpleNamespace, softmax: Softmax, @@ -1746,7 +1744,7 @@ def mma_one_n_block_intrawg_overlap( pipeline_k.consumer_release(smem_pipe_read) scoremod_premask_fn(acc_S) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) - if cutlass.const_expr(mask_fn is not None): + if const_expr(mask_fn is not None): mask_fn(acc_S, n_block=n_block) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) row_scale = softmax.online_softmax(acc_S, check_inf=check_inf) @@ -1766,26 +1764,26 @@ def mma_one_n_block_intrawg_overlap( @cute.jit def mma_init(self): warp_group_idx = utils.canonical_warp_group_idx(sync=False) - if cutlass.const_expr(self.use_scheduler_barrier): + if const_expr(self.use_scheduler_barrier): if warp_group_idx == 1: - utils.barrier_arrive( + cute.arch.barrier_arrive( barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1), number_of_threads=2 * self.num_threads_per_warp_group, ) def warp_scheduler_barrier_sync(self): - if cutlass.const_expr(self.use_scheduler_barrier): + if const_expr(self.use_scheduler_barrier): cute.arch.barrier( barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) - 1 + utils.canonical_warp_group_idx(sync=False), number_of_threads=2 * self.num_threads_per_warp_group ) def warp_scheduler_barrier_arrive(self): - if cutlass.const_expr(self.use_scheduler_barrier): + if const_expr(self.use_scheduler_barrier): assert self.num_mma_warp_groups in [2, 3] cur_wg = utils.canonical_warp_group_idx(sync=False) - 1 - next_wg = 1 - cur_wg if self.num_mma_warp_groups == 2 else (cur_wg + 1 if cur_wg < self.num_mma_warp_groups - 1 else 0) - utils.barrier_arrive( + next_wg = 1 - cur_wg if const_expr(self.num_mma_warp_groups == 2) else (cur_wg + 1 if cur_wg < self.num_mma_warp_groups - 1 else 0) + cute.arch.barrier_arrive( barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg, number_of_threads=2 * self.num_threads_per_warp_group, ) @@ -1796,9 +1794,9 @@ def load_K( tma_atom: cute.CopyAtom, tKgK: cute.Tensor, tKsK: cute.Tensor, - pipeline: cutlass.utils.PipelineAsync, + pipeline: cutlass.pipeline.PipelineAsync, block: cutlass.Int32, - producer_state: cutlass.utils.PipelineState | pipeline.PipelineStateSimple, + producer_state: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, ): # TODO: mcast # TODO check warp_idx if we have 128 producer threads diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index e44f819156a..80a5751dc39 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -21,6 +21,7 @@ import cutlass import cutlass.cute as cute +from cutlass import const_expr from cutlass.cute.nvgpu import cpasync import cutlass.cute.nvgpu.tcgen05 as tcgen05 import cutlass.utils.blackwell_helpers as sm100_utils_basic @@ -48,7 +49,7 @@ def get_tile_scheduler_cls(args: TileSchedulerArguments) -> Callable: """Returns the appropriate tile scheduler class based on the parameters.""" - if cutlass.const_expr(args.is_persistent): + if const_expr(args.is_persistent): return StaticPersistentTileScheduler else: # return SingleTileScheduler @@ -205,18 +206,18 @@ def __call__( self.k_dtype = mK.element_type self.v_dtype = mV.element_type self.o_dtype = mO.element_type - QO_layout_transpose = [1, 3, 2, 0] if cutlass.const_expr(mCuSeqlensQ is None) else [0, 2, 1] + QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] mQ, mO = [ cute.make_tensor(t.iterator, cute.select(t.layout, mode=QO_layout_transpose)) for t in (mQ, mO) ] - KV_layout_transpose = [1, 3, 2, 0] if cutlass.const_expr(mCuSeqlensK is None) else [0, 2, 1] + KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1] mK, mV = [ cute.make_tensor(t.iterator, cute.select(t.layout, mode=KV_layout_transpose)) for t in (mK, mV) ] - LSE_layout_transpose = [2, 1, 0] if cutlass.const_expr(mCuSeqlensQ is None) else [1, 0] - mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) if mLSE is not None else None + LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] + mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) if const_expr(mLSE is not None) else None # (s, d, h, b) -> (d, s, h, b) mV = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=[1, 0, 2, 3])) @@ -225,17 +226,17 @@ def __call__( self.v_major_mode = cutlass.utils.LayoutEnum.from_tensor(mV).mma_major_mode() self.o_layout = cutlass.utils.LayoutEnum.from_tensor(mO) - if cutlass.const_expr(self.q_major_mode != tcgen05.OperandMajorMode.K): + if const_expr(self.q_major_mode != tcgen05.OperandMajorMode.K): raise RuntimeError("The layout of mQ is not supported") - if cutlass.const_expr(self.k_major_mode != tcgen05.OperandMajorMode.K): + if const_expr(self.k_major_mode != tcgen05.OperandMajorMode.K): raise RuntimeError("The layout of mK is not supported") - if cutlass.const_expr(self.v_major_mode != tcgen05.OperandMajorMode.MN): + if const_expr(self.v_major_mode != tcgen05.OperandMajorMode.MN): raise RuntimeError("The layout of mV is not supported") # check type consistency - if cutlass.const_expr(self.q_dtype != self.k_dtype): + if const_expr(self.q_dtype != self.k_dtype): raise TypeError(f"Type mismatch: {self.q_dtype} != {self.k_dtype}") - if cutlass.const_expr(self.q_dtype != self.v_dtype): + if const_expr(self.q_dtype != self.v_dtype): raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}") self._setup_attributes() self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None and not self.pack_gqa @@ -290,7 +291,7 @@ def __call__( tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) tma_store_op = cpasync.CopyBulkTensorTileS2GOp() - tma_atom_Q, tma_tensor_Q = cute.nvgpu.make_tma_tile_atom_A( + tma_atom_Q, tma_tensor_Q = cute.nvgpu.make_tiled_tma_atom_A( tma_load_op, mQ, cute.select(sQ_layout, mode=[0, 1, 2]), @@ -300,7 +301,7 @@ def __call__( ) # TMA load for K - tma_atom_K, tma_tensor_K = cute.nvgpu.make_tma_tile_atom_B( + tma_atom_K, tma_tensor_K = cute.nvgpu.make_tiled_tma_atom_B( tma_load_op, mK, cute.select(sK_layout, mode=[0, 1, 2]), @@ -309,7 +310,7 @@ def __call__( self.cluster_layout_vmnk.shape, ) # TMA load for V - tma_atom_V, tma_tensor_V = cute.nvgpu.make_tma_tile_atom_B( + tma_atom_V, tma_tensor_V = cute.nvgpu.make_tiled_tma_atom_B( tma_load_op, mV, cute.select(sV_layout, mode=[0, 1, 2]), @@ -321,12 +322,12 @@ def __call__( o_cta_v_layout = cute.composition(cute.make_identity_layout(mO.shape), self.epi_tile) # print(sO_layout.outer) - if not self.use_tma_O: + if const_expr(not self.use_tma_O): self.epilogue_warp_ids = (14, 15) self.empty_warp_ids = () self.num_epilogue_threads = cute.arch.WARP_SIZE * len(self.epilogue_warp_ids) - if cutlass.const_expr(self.use_tma_O): - tma_atom_O, mO = cpasync.make_tma_tile_atom( + if const_expr(self.use_tma_O): + tma_atom_O, mO = cpasync.make_tiled_tma_atom( tma_store_op, mO, cute.select(sO_layout, mode=[0, 1]), @@ -377,7 +378,7 @@ class SharedStorage: # Tmem holding buffer tmem_holding_buf: cutlass.Int32 # Smem tensors - sScale: cute.struct.MemRange[cutlass.Float32, 2 * self.m_block_size * (1 if mLSE is None else 2)] + sScale: cute.struct.MemRange[cutlass.Float32, 2 * self.m_block_size * (1 if const_expr(mLSE is None) else 2)] sO: cute.struct.Align[ cute.struct.MemRange[self.o_dtype, cute.cosize(sO_layout)], self.buffer_align_bytes, @@ -399,15 +400,15 @@ class SharedStorage: # (assigning it to softcap_val) and pre-multiply softcap_val * log2(e) # (assigning it to softmax_scale_log2). LOG2_E = math.log2(math.e) - if cutlass.const_expr(softcap is None): + if const_expr(softcap is None): softmax_scale_log2 = softmax_scale * LOG2_E softcap_val = None else: softmax_scale_log2 = softcap * LOG2_E softcap_val = cutlass.Float32(softmax_scale / softcap) - if cutlass.const_expr(window_size_left is not None): + if const_expr(window_size_left is not None): window_size_left = cutlass.Int32(window_size_left) - if cutlass.const_expr(window_size_right is not None): + if const_expr(window_size_right is not None): window_size_right = cutlass.Int32(window_size_right) # Launch the kernel synchronously self.kernel( @@ -495,11 +496,11 @@ def kernel( # coord inside cta tidx, _, _ = cute.arch.thread_idx() - if cutlass.const_expr(not self.pack_gqa): + if const_expr(not self.pack_gqa): cpasync.prefetch_descriptor(tma_atom_Q) cpasync.prefetch_descriptor(tma_atom_K) cpasync.prefetch_descriptor(tma_atom_V) - if cutlass.const_expr(self.use_tma_O): + if const_expr(self.use_tma_O): cpasync.prefetch_descriptor(tma_atom_O) # Alloc @@ -510,28 +511,28 @@ def kernel( warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) if warp_idx == 1: # Init "full" barrier with number of producers, "empty" barrier with number of consumers - for i in range(self.q_stage): - cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_load_q_full_offset + i, len([self.load_warp_id])) - cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_load_q_empty_offset + i, len([self.mma_warp_id])) + for i in cutlass.range_constexpr(self.q_stage): + cute.arch.mbarrier_init(mbar_ptr + self.mbar_load_q_full_offset + i, len([self.load_warp_id])) + cute.arch.mbarrier_init(mbar_ptr + self.mbar_load_q_empty_offset + i, len([self.mma_warp_id])) if warp_idx == 2: - for i in range(2): - cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_softmax_corr_empty_offset + i, cute.arch.WARP_SIZE * 4) - cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_softmax_corr_full_offset + i, cute.arch.WARP_SIZE * 4) + for i in cutlass.range_constexpr(2): + cute.arch.mbarrier_init(mbar_ptr + self.mbar_softmax_corr_empty_offset + i, cute.arch.WARP_SIZE * 4) + cute.arch.mbarrier_init(mbar_ptr + self.mbar_softmax_corr_full_offset + i, cute.arch.WARP_SIZE * 4) if warp_idx == 3: - if cutlass.const_expr(self.s0_s1_barrier): - for i in range(8): - cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_s0_s1_sequence_offset + i, cute.arch.WARP_SIZE) + if const_expr(self.s0_s1_barrier): + for i in cutlass.range_constexpr(8): + cute.arch.mbarrier_init(mbar_ptr + self.mbar_s0_s1_sequence_offset + i, cute.arch.WARP_SIZE) if warp_idx == 4: - for i in range(2): - cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_corr_epi_full_offset + i, cute.arch.WARP_SIZE * len(self.correction_warp_ids)) - cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_corr_epi_empty_offset + i, cute.arch.WARP_SIZE * len(self.epilogue_warp_ids)) + for i in cutlass.range_constexpr(2): + cute.arch.mbarrier_init(mbar_ptr + self.mbar_corr_epi_full_offset + i, cute.arch.WARP_SIZE * len(self.correction_warp_ids)) + cute.arch.mbarrier_init(mbar_ptr + self.mbar_corr_epi_empty_offset + i, cute.arch.WARP_SIZE * len(self.epilogue_warp_ids)) if warp_idx == 5: - for i in range(2): - cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_P_full_O_rescaled_offset + i, cute.arch.WARP_SIZE * (len(self.softmax0_warp_ids) + len(self.correction_warp_ids))) - cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_S_full_offset + i, len([self.mma_warp_id])) - cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + self.mbar_O_full_offset + i, len([self.mma_warp_id])) + for i in cutlass.range_constexpr(2): + cute.arch.mbarrier_init(mbar_ptr + self.mbar_P_full_O_rescaled_offset + i, cute.arch.WARP_SIZE * (len(self.softmax0_warp_ids) + len(self.correction_warp_ids))) + cute.arch.mbarrier_init(mbar_ptr + self.mbar_S_full_offset + i, len([self.mma_warp_id])) + cute.arch.mbarrier_init(mbar_ptr + self.mbar_O_full_offset + i, len([self.mma_warp_id])) if warp_idx == 6: - cute.arch.mbarrier_init_arrive_cnt( + cute.arch.mbarrier_init( mbar_ptr + self.mbar_max_reg_setting_offset, cute.arch.WARP_SIZE * len( @@ -545,7 +546,7 @@ def kernel( ), ) if warp_idx == 7: - cute.arch.mbarrier_init_arrive_cnt( + cute.arch.mbarrier_init( mbar_ptr + self.mbar_tmem_dealloc_offset, cute.arch.WARP_SIZE * len( @@ -610,10 +611,10 @@ def kernel( # This is cta_tiler, not mma_tiler_qk, since we move by block by (2 * mma_tiler[0], mma_tiler[1]) self.cta_tiler[0], self.cta_tiler[1], self.is_causal, self.is_local, window_size_left, window_size_right, - qhead_per_kvhead_packgqa=self.qhead_per_kvhead if self.pack_gqa else 1, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) SeqlenInfoCls = partial( - SeqlenInfo, seqlen_q_static=mQ.shape[0] if not self.pack_gqa else mQ.shape[0][1], + SeqlenInfo, seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1], seqlen_k_static=mK.shape[0], mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK, @@ -621,7 +622,7 @@ def kernel( AttentionMaskCls = partial( AttentionMask, self.m_block_size, self.n_block_size, window_size_left=window_size_left, window_size_right=window_size_right, - qhead_per_kvhead_packgqa=self.qhead_per_kvhead if self.pack_gqa else 1, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) if warp_idx >= 12: @@ -726,7 +727,7 @@ def kernel( AttentionMaskCls=AttentionMaskCls, ) - if cutlass.const_expr(not self.s0_s1_barrier): + if const_expr(not self.s0_s1_barrier): stage = cutlass.Int32(0 if warp_idx < self.softmax1_warp_ids[0] else 1) softmax_loop( stage=stage, @@ -785,7 +786,7 @@ def load( tma_atom_Q: cute.CopyAtom, tma_atom_K: cute.CopyAtom, tma_atom_V: cute.CopyAtom, - pipeline_kv: cutlass.utils.PipelineAsync, + pipeline_kv: cutlass.pipeline.PipelineAsync, mbar_ptr: cute.Pointer, block_info: BlockInfo, SeqlenInfoCls: Callable, @@ -822,7 +823,7 @@ def load( ) q_producer_phase = cutlass.Int32(1) - kv_producer_state = cutlass.utils.make_pipeline_state(cutlass.utils.PipelineUserType.Producer, self.kv_stage) + kv_producer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, self.kv_stage) work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx = work_tile.tile_idx @@ -833,7 +834,7 @@ def load( def load_Q(stage: int): cute.arch.mbarrier_wait(mbar_ptr + self.mbar_load_q_empty_offset + stage, q_producer_phase) with cute.arch.elect_one(): - cute.arch.mbarrier_init_tx_bytes(mbar_ptr + self.mbar_load_q_full_offset + stage, self.tma_copy_q_bytes) + cute.arch.mbarrier_arrive_and_expect_tx(mbar_ptr + self.mbar_load_q_full_offset + stage, self.tma_copy_q_bytes) cute.copy( tma_atom_Q, tQgQ[None, 2 * m_block + stage], @@ -853,7 +854,7 @@ def load_Q(stage: int): q_producer_phase ^= 1 load_V(n_block_max - 1, kv_producer_state) # V0 kv_producer_state.advance() - for i in cutlass.range_dynamic(n_block_max - 1 - n_block_min, unroll=1): + for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): n_block = n_block_max - 2 - i load_K(n_block, kv_producer_state) # Ki kv_producer_state.advance() @@ -883,7 +884,7 @@ def mma( tOtO1: cute.Tensor, tOrP0: cute.Tensor, tOrP1: cute.Tensor, - pipeline_kv: cutlass.utils.PipelineAsync, + pipeline_kv: cutlass.pipeline.PipelineAsync, mbar_ptr: cute.Pointer, tile_sched_params, block_info: BlockInfo, @@ -909,7 +910,7 @@ def mma( gemm_Si = [ partial( sm100_utils.gemm_ptx_partial, - qk_mma_op, self.tmem_s0_offset if stage == 0 else self.tmem_s1_offset, tSrQs[stage], + qk_mma_op, self.tmem_s0_offset if const_expr(stage == 0) else self.tmem_s1_offset, tSrQs[stage], sA=sQ[None, None, None, stage], sA_swizzle=sQ_swizzle, sB_swizzle=sK_swizzle, zero_init=True ) @@ -918,15 +919,15 @@ def mma( gemm_Pi = [ partial( sm100_utils.gemm_ptx_partial, - pv_mma_op, self.tmem_o0_offset if stage == 0 else self.tmem_o1_offset, tOrPs[stage], + pv_mma_op, self.tmem_o0_offset if const_expr(stage == 0) else self.tmem_o1_offset, tOrPs[stage], sA=None, sA_swizzle=None, sB_swizzle=sV_swizzle ) for stage in range(2) ] mma_q_consumer_phase = cutlass.Int32(0) - mma_kv_consumer_state = cutlass.utils.make_pipeline_state( - cutlass.utils.PipelineUserType.Consumer, self.kv_stage + mma_kv_consumer_state = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.kv_stage ) P_full_O_rescaled_phase = cutlass.Int32(0) @@ -937,12 +938,12 @@ def mma( seqlen = SeqlenInfoCls(batch_idx) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) - for stage in range(2): + for stage in cutlass.range_constexpr(2): # GEMM_QK00 (Q0 * K0 -> S0) or GEMM_QK01 (Q1 * K0 -> S1) # 1. wait for Q0 / Q1 cute.arch.mbarrier_wait(mbar_ptr + self.mbar_load_q_full_offset + stage, mma_q_consumer_phase) # 2. wait for K0 - if stage == 0: + if const_expr(stage == 0): pipeline_kv.consumer_wait(mma_kv_consumer_state) tSrKi = tSrK[None, None, None, mma_kv_consumer_state.index] # We don't need to acquire empty S0 / S1. @@ -972,14 +973,14 @@ def mma( # O hasn't been accumulated yet, its first MMA calculation doesn't need to accumulate O_should_accumulate = False - for i in cutlass.range_dynamic(n_block_max - 1 - n_block_min, unroll=1): + for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop # 1. wait for V0 pipeline_kv.consumer_wait(mma_kv_consumer_state) mma_kv_release_state = mma_kv_consumer_state.clone() Vi_index = mma_kv_consumer_state.index tOrVi = tOrV[None, None, None, Vi_index] - for stage in range(2): + for stage in cutlass.range_constexpr(2): # 2. acquire corrected O0/O1_partial and P0 / P1 # For the first iteration in this work tile, waiting for O0/O1_partial # means that the correction warps has finished reading tO during @@ -996,14 +997,14 @@ def mma( # with cute.arch.elect_one(): # tcgen05.commit(mbar_ptr + self.mbar_O_full_offset + stage) # 5. release V(i-1) - if stage == 1: + if const_expr(stage == 1): pipeline_kv.consumer_release(mma_kv_release_state) mma_kv_release_state.advance() # End of GEMM_PV00 (P0 * V0 -> O0_partial) # GEMM_QK0i (Q0 * Ki -> S0) # 1. wait for Ki - if stage == 0: + if const_expr(stage == 0): mma_kv_consumer_state.advance() pipeline_kv.consumer_wait(mma_kv_consumer_state) Ki_index = mma_kv_consumer_state.index @@ -1034,7 +1035,7 @@ def mma( pipeline_kv.consumer_wait(mma_kv_consumer_state) Vi_index = mma_kv_consumer_state.index tOrVi = tOrV[None, None, None, Vi_index] - for stage in range(2): + for stage in cutlass.range_constexpr(2): # 2. acquire corrected Oi_partial and Pi cute.arch.mbarrier_wait(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, P_full_O_rescaled_phase) # 3. gemm @@ -1144,7 +1145,7 @@ def softmax_loop( mask_fn = partial( mask.apply_mask_sm100, m_block=m_block * 2 + stage, thr_mma=thr_mma_qk, thr_tmem_load=thr_tmem_load, mask_causal=self.is_causal, mask_local=self.is_local ) - softmax = SoftmaxSm100(softmax_scale_log2, rescale_threshold=8.0 if self.q_dtype.width == 16 else 0.0) + softmax = SoftmaxSm100(softmax_scale_log2, rescale_threshold=8.0 if const_expr(self.q_dtype.width == 16) else 0.0) softmax.reset() softmax_step = partial( @@ -1167,17 +1168,17 @@ def softmax_loop( si_corr_producer_phase ^= 1 # 1 masking iter - if cutlass.const_expr(not self.is_even_N): + if const_expr(not self.is_even_N): # mask_trip_count = 1 if seqlen.seqlen_k % self.mma_tiler_qk[1] == 0 else 0 mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block_max - 1, is_first=False, mask_fn=partial(mask_fn, mask_seqlen=True)) n_block_max -= 1 # Next couple of iterations with causal masking - if cutlass.const_expr(self.is_causal or self.is_local): + if const_expr(self.is_causal or self.is_local): n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( seqlen, m_block, n_block_min ) # Currently we can't do loop with negative step https://github.com/NVIDIA/cutlass/issues/2326 - for n_tile in cutlass.range_dynamic(n_block_max - n_block_min_causal_local_mask, unroll=1): + for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1): n_block = n_block_max - 1 - n_tile mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=partial(mask_fn, mask_seqlen=False)) n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) @@ -1185,13 +1186,13 @@ def softmax_loop( n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( seqlen, m_block, n_block_min ) - for n_tile in cutlass.range_dynamic(n_block_max - n_block_min_before_local_mask, unroll=1): + for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): n_block = n_block_max - n_tile - 1 mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block) # Separate iterations with local masking on the left - if cutlass.const_expr(self.is_local and block_info.window_size_left is not None): + if const_expr(self.is_local and block_info.window_size_left is not None): n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) - for n_tile in cutlass.range_dynamic(0, n_block_max - n_block_min, unroll=1): + for n_tile in cutlass.range(0, n_block_max - n_block_min, unroll=1): n_block = n_block_max - 1 - n_tile mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=partial(mask_fn, mask_seqlen=False)) @@ -1200,7 +1201,7 @@ def softmax_loop( # cute.copy(thr_tmem_store_scale, tSrScale_r2t, tStScale_r2t) # cute.arch.fence_view_async_tmem_store() sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0] - if cutlass.const_expr(mLSE is not None): + if const_expr(mLSE is not None): sScale[tidx + stage * self.m_block_size + self.m_block_size * 2] = softmax.row_max[0] # if tidx == 0: # cute.printf("softmax row sum stage %d: %f, row_max = %f\n", stage, softmax.row_sum[0], softmax.row_max[0]) @@ -1208,7 +1209,7 @@ def softmax_loop( # if tidx == 0: cute.printf("softmax row sum stage %d: %f\n", stage, softmax.row_sum[0]) # # Write LSE to gmem - # if cutlass.const_expr(mLSE is not None): + # if const_expr(mLSE is not None): # acc_O_mn_row_is_zero_or_nan = softmax.row_sum[0] == 0.0 or softmax.row_sum[0] != softmax.row_sum[0] # scale = ( # cute.arch.rcp_approx(softmax.row_sum[0] if not acc_O_mn_row_is_zero_or_nan else 1.0) @@ -1218,7 +1219,7 @@ def softmax_loop( # (softmax.row_max[0] * softmax.scale_log2 + utils.log2f(softmax.row_sum[0])) * LN2 # if not acc_O_mn_row_is_zero_or_nan else -cutlass.Float32.inf # ) - # if cutlass.const_expr(not seqlen.has_cu_seqlens_q): + # if const_expr(not seqlen.has_cu_seqlens_q): # mLSE_cur = mLSE[None, head_idx, batch_idx] # else: # mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[None, head_idx]) @@ -1282,7 +1283,7 @@ def softmax_step( cute.arch.mbarrier_wait(mbar_ptr + self.mbar_S_full_offset + stage, mma_si_consumer_phase) tSrS_t2r = cute.make_fragment(tScS_t2r_shape, self.qk_acc_dtype) cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r) - if cutlass.const_expr(mask_fn is not None): + if const_expr(mask_fn is not None): mask_fn(tSrS_t2r, n_block=n_block) row_max, acc_scale = softmax.update_row_max(tSrS_t2r.load(), is_first) @@ -1290,7 +1291,7 @@ def softmax_step( # tSrScale_r2t[0] = acc_scale # cute.copy(thr_tmem_store_scale, tSrScale_r2t, tStScale_r2t) # cute.arch.fence_view_async_tmem_store() - if cutlass.const_expr(not is_first): + if const_expr(not is_first): thread_idx = thr_tmem_load.thr_idx sScale[thread_idx + stage * self.m_block_size] = acc_scale # if thread_idx == 0: cute.printf("softmax acc_scale stage %d: %f, row_max = %f\n", stage, acc_scale, row_max) @@ -1301,7 +1302,7 @@ def softmax_step( # print(tSrS_t2r) softmax.scale_subtract_rowmax(tSrS_t2r, row_max) # Sequence barrier wait - if cutlass.const_expr(self.s0_s1_barrier): + if const_expr(self.s0_s1_barrier): cute.arch.mbarrier_wait(mbar_ptr + mbar_s0_s1_sequence_offset + stage * 4, s0_s1_sequence_phase) tSrP_r2t_f32 = cute.make_fragment(thr_tmem_store.partition_S(tScP).shape, cutlass.Float32) tSrP_r2t = cute.make_tensor( @@ -1310,7 +1311,7 @@ def softmax_step( # softmax.scale_apply_exp2_convert(tSrS_t2r, row_max, tSrP_r2t) softmax.apply_exp2_convert(tSrS_t2r, tSrP_r2t) # Sequence barrier arrive - if cutlass.const_expr(self.s0_s1_barrier): + if const_expr(self.s0_s1_barrier): cute.arch.mbarrier_arrive(mbar_ptr + mbar_s0_s1_sequence_offset + (1 - stage) * 4) # print(tSrP_r2t_f32, tStP_r2t) cute.copy(thr_tmem_store, tSrP_r2t_f32, tStP_r2t) @@ -1383,8 +1384,8 @@ def correction_loop( softmax_corr_consumer_phase ^= 1 tSrScale_t2r = cute.make_fragment(tSrScale_t2r_shape, cutlass.Float32) - for i in cutlass.range_dynamic(n_block_max - n_block_min - 1, unroll=1): - for stage in range(2): + for i in cutlass.range(n_block_max - n_block_min - 1, unroll=1): + for stage in cutlass.range_constexpr(2): # wait for S0 / S1 cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_full_offset + stage, softmax_corr_consumer_phase) # cute.copy(tiled_tmem_load_vec, tStScale_1_t2r, tSrScale_t2r) @@ -1408,13 +1409,13 @@ def correction_loop( cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + 1) stats = [None, None] - for stage in range(2): + for stage in cutlass.range_constexpr(2): cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_full_offset + stage, softmax_corr_consumer_phase) # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r) # cute.arch.fence_view_async_tmem_load() # scale = tSrScale_t2r[0] row_sum = sScale[tidx + stage * self.m_block_size] - if cutlass.const_expr(mLSE is not None): + if const_expr(mLSE is not None): row_max = sScale[tidx + stage * self.m_block_size + self.m_block_size * 2] else: row_max = None @@ -1432,13 +1433,13 @@ def correction_loop( # mma warp can write to them cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) # if tidx == 0: cute.printf("Correction final scale for stage %d: %f\n", stage, scale) - if cutlass.const_expr(mLSE is not None): - if cutlass.const_expr(not seqlen.has_cu_seqlens_q): + if const_expr(mLSE is not None): + if const_expr(not seqlen.has_cu_seqlens_q): mLSE_cur = mLSE[None, head_idx, batch_idx] else: mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[None, head_idx]) gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block * 2,)) - for stage in range(2): + for stage in cutlass.range_constexpr(2): row_sum, row_max, acc_O_mn_row_is_zero_or_nan = stats[stage] # if tidx == 0 and stage <= 1: # cute.printf("row_sum = {}, row_max = {}, acc_O_mn_row_is_zero_or_nan = {}\n", row_sum, row_max, acc_O_mn_row_is_zero_or_nan) @@ -1530,13 +1531,13 @@ def correction_rescale( frg_count = self.head_dim_v_padded // corr_tile_size tOrO_frg = cute.make_fragment((tOrO_t2r_shape, frg_count), self.pv_acc_dtype) - for i in range(frg_count): + for i in cutlass.range_constexpr(frg_count): tOrO_frg_i = tOrO_frg[None, i] tTMrO_i_layout = cute.composition(tOrO_frg_i.layout, cute.make_layout(tOrO_frg.shape[0])) tTMrO_i = cute.make_tensor(tOrO_frg_i.iterator, tTMrO_i_layout) tOtO_t2r_i = cute.make_tensor(tOtO_t2r.iterator + i * corr_tile_size, tOtO_t2r.layout) cute.copy(tiled_tmem_load, tOtO_t2r_i, tTMrO_i) - for j in range(0, cute.size(tTMrO_i), 2): + for j in cutlass.range_constexpr(0, cute.size(tTMrO_i), 2): tTMrO_i[j], tTMrO_i[j + 1] = cute.arch.mul_packed_f32x2( (tTMrO_i[j], tTMrO_i[j + 1]), (scale, scale), ) @@ -1611,12 +1612,12 @@ def correction_epilogue( tOsO_s2r = thr_tmem_load.partition_D(tOsO_i[(None, None), None]) tOcO_t2r = thr_tmem_load.partition_D(tOcO_i[(None, None), None]) - for i in range(self.head_dim_v_padded // corr_tile_size): + for i in cutlass.range_constexpr(self.head_dim_v_padded // corr_tile_size): tOtO_t2r_i = tOtO_t2r[None, 0, 0, i] tOsO_r2s_i = tOsO_s2r[None, 0, 0, i] tOrO_frg = cute.make_fragment(tOcO_t2r[None, 0, 0, i].shape, self.pv_acc_dtype) cute.copy(tiled_tmem_load, tOtO_t2r_i, tOrO_frg) - for j in range(0, cute.size(tOrO_frg), 2): + for j in cutlass.range_constexpr(0, cute.size(tOrO_frg), 2): tOrO_frg[j], tOrO_frg[j + 1] = cute.arch.mul_packed_f32x2( (tOrO_frg[j], tOrO_frg[j + 1]), (scale, scale), ) @@ -1646,12 +1647,12 @@ def epilogue_s2g( while work_tile.is_valid_tile: m_block, head_idx, batch_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) - if cutlass.const_expr(not seqlen.has_cu_seqlens_q): + if const_expr(not seqlen.has_cu_seqlens_q): mO_cur = mO[None, None, head_idx, batch_idx] else: mO_cur = cute.domain_offset((seqlen.offset_q, 0), mO[None, None, head_idx]) gO = cute.local_tile(mO_cur, (self.m_block_size, self.head_dim_v_padded), (None, 0)) - if cutlass.const_expr(self.use_tma_O): + if const_expr(self.use_tma_O): tOsO, tOgO = cpasync.tma_partition( tma_atom_O, 0, @@ -1659,14 +1660,14 @@ def epilogue_s2g( cute.group_modes(sO, 0, 2), cute.group_modes(gO, 0, 2), ) - for stage in range(2): + for stage in cutlass.range_constexpr(2): # wait from corr, issue tma store on smem # 1. wait for O0 / O1 final cute.arch.mbarrier_wait(mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase) # 2. copy O0 / O1 to gmem cute.copy(tma_atom_O, tOsO[None, stage], tOgO[None, 2 * m_block + stage]) cute.arch.cp_async_bulk_commit_group() - for stage in range(2): + for stage in cutlass.range_constexpr(2): # Ensure O0 / O1 buffer is ready to be released cute.arch.cp_async_bulk_wait_group(1 - stage, read=True) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) @@ -1679,7 +1680,7 @@ def epilogue_s2g( tOcO = gmem_thr_copy_O.partition_S(cO) t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO) tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) - for stage in range(2): + for stage in cutlass.range_constexpr(2): # wait from corr, issue tma store on smem # 1. wait for O0 / O1 final cute.arch.mbarrier_wait(mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase) @@ -1709,9 +1710,9 @@ def load_K( tma_atom: cute.CopyAtom, tKgK: cute.Tensor, tKsK: cute.Tensor, - pipeline: cutlass.utils.PipelineAsync, + pipeline: cutlass.pipeline.PipelineAsync, block: cutlass.Int32, - producer_state: cutlass.utils.PipelineState, + producer_state: cutlass.pipeline.PipelineState, ): pipeline.producer_acquire(producer_state) cute.copy( @@ -1722,10 +1723,10 @@ def load_K( ) def make_and_init_load_kv_pipeline(self, load_kv_mbar_ptr): - load_kv_producer_group = cutlass.utils.CooperativeGroup(cutlass.utils.Agent.Thread, len([self.load_warp_id]) + load_kv_producer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, len([self.load_warp_id]) ) - load_kv_consumer_group = cutlass.utils.CooperativeGroup(cutlass.utils.Agent.Thread, len([self.mma_warp_id])) - return cutlass.utils.PipelineTmaUmma.create( + load_kv_consumer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, len([self.mma_warp_id])) + return cutlass.pipeline.PipelineTmaUmma.create( barrier_storage=load_kv_mbar_ptr, num_stages=self.kv_stage, producer_group=load_kv_producer_group, @@ -1737,7 +1738,7 @@ def make_and_init_load_kv_pipeline(self, load_kv_mbar_ptr): # def warp_scheduler_barrier_init(self): # warp_group_idx = utils.canonical_warp_group_idx(sync=False) # if warp_group_idx == 0: - # utils.barrier_arrive( + # cute.arch.barrier_arrive( # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1), number_of_threads=2 * 128, # ) @@ -1750,7 +1751,7 @@ def make_and_init_load_kv_pipeline(self, load_kv_mbar_ptr): # def warp_scheduler_barrier_arrive(self): # cur_wg = utils.canonical_warp_group_idx(sync=False) # next_wg = 1 - cur_wg - # utils.barrier_arrive( + # cute.arch.barrier_arrive( # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg, number_of_threads=2 * 128, # ) diff --git a/flash_attn/cute/hopper_helpers.py b/flash_attn/cute/hopper_helpers.py index d42c33e76e7..6408e11f786 100644 --- a/flash_attn/cute/hopper_helpers.py +++ b/flash_attn/cute/hopper_helpers.py @@ -4,6 +4,7 @@ from cutlass.cute.nvgpu import warpgroup +@cute.jit def gemm( tiled_mma: cute.TiledMma, acc: cute.Tensor, @@ -14,7 +15,7 @@ def gemm( # A_in_regs: cutlass.Constexpr[bool] = False, swap_AB: cutlass.Constexpr[bool] = False, ) -> None: - if swap_AB: + if cutlass.const_expr(swap_AB): gemm(tiled_mma, acc, tCrB, tCrA, zero_init=zero_init, wg_wait=wg_wait, swap_AB=False) else: warpgroup.fence() diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index cd01726f19a..c68165a3b60 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -1,5 +1,5 @@ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -# [2025-06-01] Version in Cute-DSL, for Hopper and Blackwell. You'd need to install nvidia-cutlass-dsl. +# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'd need to install nvidia-cutlass-dsl==4.1.0.dev0. # Features not supported yet: # - varlen # - split (i.e. FlashDecoding) diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index be04357c695..660a5efbc00 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -42,7 +42,7 @@ def apply_mask( if cutlass.const_expr(not mask_causal and not mask_local): if cutlass.const_expr(mask_seqlen): # traverse column index. - for c in range(cute.size(tScS_mn.shape[1])): + for c in cutlass.range_constexpr(cute.size(tScS_mn.shape[1])): if t0ScS_mn[0, c][1] >= seqlenk_col_limit: acc_S_mn[None, c].fill(-cutlass.Float32.inf) else: # Causal or local @@ -61,7 +61,7 @@ def apply_mask( 1 + self.seqlen_k - n_block * self.n_block_size - self.seqlen_q - thr_col_offset ) if cutlass.const_expr(mask_causal): - for r in range(cute.size(tScS_mn.shape[0])): + for r in cutlass.range_constexpr(cute.size(tScS_mn.shape[0])): # get the column index limit based on current row. Only consider the row index, so the column index sets to 0. if cutlass.const_expr(self.qhead_per_kvhead_packgqa == 1): row_idx = tScS_mn[r, 0][0] + m_block * self.m_block_size @@ -73,22 +73,22 @@ def apply_mask( if cutlass.const_expr(mask_seqlen): col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) # traverse column index. - for c in range(cute.size(tScS_mn.shape[1])): + for c in cutlass.range_constexpr(cute.size(tScS_mn.shape[1])): # only consider the column index, so the row index sets to 0. if t0ScS_mn[0, c][1] >= col_limit_right: acc_S_mn[r, c] = -cutlass.Float32.inf else: # Local local_row_offset_right = ( causal_row_offset + self.window_size_right - if self.window_size_right is not None + if cutlass.const_expr(self.window_size_right is not None) else None ) local_row_offset_left = ( causal_row_offset - 1 - self.window_size_left - if self.window_size_left is not None + if cutlass.const_expr(self.window_size_left is not None) else None ) - for r in range(cute.size(tScS_mn.shape[0])): + for r in cutlass.range_constexpr(cute.size(tScS_mn.shape[0])): if cutlass.const_expr(self.qhead_per_kvhead_packgqa == 1): row_idx = tScS_mn[r, 0][0] + m_block * self.m_block_size else: @@ -102,11 +102,11 @@ def apply_mask( else: col_limit_right = self.n_block_size col_limit_left = ( - row_idx + local_row_offset_left if self.window_size_left is not None else 0 + row_idx + local_row_offset_left if cutlass.const_expr(self.window_size_left is not None) else 0 ) # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block = {}, r = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", n_block, r, row_idx, causal_row_offset, col_limit_right, col_limit_left) # traverse column index. - for c in range(cute.size(tScS_mn.shape[1])): + for c in cutlass.range_constexpr(cute.size(tScS_mn.shape[1])): col_idx = t0ScS_mn[0, c][1] # only consider the column index, so the row index sets to 0. if col_idx >= col_limit_right or col_idx < col_limit_left: @@ -131,7 +131,7 @@ def apply_mask_sm100( seqlenk_col_limit = self.seqlen_k - n_block * self.n_block_size if cutlass.const_expr(not mask_causal and not mask_local): if cutlass.const_expr(mask_seqlen): - for i in range(cute.size(tScS_t2r.shape)): + for i in cutlass.range_constexpr(cute.size(tScS_t2r.shape)): # if tScS_t2r[i][1] >= seqlenk_col_limit: # acc_S[i] = -cutlass.Float32.inf # For some reason the 2 lines above generate really bad SASS @@ -149,7 +149,7 @@ def apply_mask_sm100( col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) # if cute.arch.thread_idx()[0] % 32 == 0: # cute.printf("tidx = %d, tidx tmem = %d, row_idx = %d, col_limit_right = %d, causal_row_offset = %d\n", cute.arch.thread_idx()[0], thr_tmem_load.thr_idx, row_idx, col_limit_right, causal_row_offset) - for i in range(cute.size(tScS_t2r.shape)): + for i in cutlass.range_constexpr(cute.size(tScS_t2r.shape)): acc_S[i] = ( -cutlass.Float32.inf if tScS_t2r[i][1] >= col_limit_right else acc_S[i] ) @@ -157,12 +157,12 @@ def apply_mask_sm100( else: local_row_offset_right = ( causal_row_offset + self.window_size_right - if self.window_size_right is not None + if cutlass.const_expr(self.window_size_right is not None) else None ) local_row_offset_left = ( causal_row_offset - 1 - self.window_size_left - if self.window_size_left is not None + if cutlass.const_expr(self.window_size_left is not None) else None ) if cutlass.const_expr(self.window_size_right is not None): @@ -172,10 +172,10 @@ def apply_mask_sm100( else: col_limit_right = self.n_block_size col_limit_left = ( - row_idx + local_row_offset_left if self.window_size_left is not None else 0 + row_idx + local_row_offset_left if cutlass.const_expr(self.window_size_left is not None) else 0 ) # if cute.arch.thread_idx()[0] == 0 or cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", m_block, n_block, row_idx, causal_row_offset, col_limit_right, col_limit_left) - for i in range(cute.size(tScS_t2r.shape)): + for i in cutlass.range_constexpr(cute.size(tScS_t2r.shape)): col_idx = tScS_t2r[i][1] acc_S[i] = ( -cutlass.Float32.inf diff --git a/flash_attn/cute/pack_gqa.py b/flash_attn/cute/pack_gqa.py index 9d2d43e0a6f..46d8dd38798 100644 --- a/flash_attn/cute/pack_gqa.py +++ b/flash_attn/cute/pack_gqa.py @@ -63,7 +63,7 @@ def load_Q( assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE" num_threads = gmem_tiled_copy.size tPrQPtr = self.compute_ptr(mQ[None, 0], tQcQ_row, tidx, block, threads_per_row, num_threads) - for m in range(cute.size(tQsQ.shape[1])): + for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])): q_ptr_i64 = utils.shuffle_sync( tPrQPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row ) @@ -77,13 +77,13 @@ def load_Q( mQ_cur = cute.make_tensor(q_gmem_ptr, (self.head_dim_padded,)) elems_per_load = cute.size(tQsQ.shape[0][0]) mQ_cur_copy = cute.tiled_divide(mQ_cur, (elems_per_load,)) - for k in range(cute.size(tQsQ.shape[2])): + for k in cutlass.range_constexpr(cute.size(tQsQ.shape[2])): ki = tQcQ[0, 0, k][1] // elems_per_load cute.copy( gmem_thr_copy, mQ_cur_copy[None, ki], tQsQ[None, m, k], - pred=tQpQ[None, m, k] if self.check_hdim_oob else None, + pred=tQpQ[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None, ) # We don't need to clear the sQ smem tiles since we'll only write out the valid outputs @@ -107,7 +107,7 @@ def store_LSE( assert cute.size(tLSErLSE) <= threads_per_row num_threads = tiled_mma.size tPrLSEPtr = self.compute_ptr(mLSE, taccOcO_row, tidx, block, threads_per_row, num_threads) - for m in range(cute.size(tLSErLSE)): + for m in cutlass.range_constexpr(cute.size(tLSErLSE)): lse_ptr_i64 = utils.shuffle_sync( tPrLSEPtr[m // threads_per_row], m % threads_per_row, @@ -142,7 +142,7 @@ def store_O( assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE" num_threads = gmem_tiled_copy.size tPrOPtr = self.compute_ptr(mO[None, 0], tOcO_row, tidx, block, threads_per_row, num_threads) - for m in range(cute.size(tOrO.shape[1])): + for m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): o_ptr_i64 = utils.shuffle_sync( tPrOPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row ) @@ -156,11 +156,11 @@ def store_O( mO_cur = cute.make_tensor(o_gmem_ptr, (self.head_dim_padded,)) elems_per_load = cute.size(tOrO.shape[0][0]) mO_cur_copy = cute.tiled_divide(mO_cur, (elems_per_load,)) - for k in range(cute.size(tOrO.shape[2])): + for k in cutlass.range_constexpr(cute.size(tOrO.shape[2])): ki = tOcO[0, 0, k][1] // elems_per_load cute.copy( gmem_thr_copy, tOrO[None, m, k], mO_cur_copy[None, ki], - pred=tOpO[None, m, k] if self.check_hdim_oob else None, + pred=tOpO[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None, ) diff --git a/flash_attn/cute/pipeline.py b/flash_attn/cute/pipeline.py index 6efc1a96747..7ea4743c2ed 100644 --- a/flash_attn/cute/pipeline.py +++ b/flash_attn/cute/pipeline.py @@ -7,9 +7,8 @@ import cutlass import cutlass.cute as cute from cutlass.cutlass_dsl import Boolean, Int32, if_generate -from cutlass.utils import PipelineAsync, PipelineState, CooperativeGroup, pipeline_init_wait -from cutlass.utils.pipeline import PipelineUserType -from cutlass.utils.pipeline import _PipelineOp +from cutlass.pipeline import PipelineAsync, PipelineState, CooperativeGroup, pipeline_init_wait +from cutlass.pipeline import PipelineUserType, PipelineOp class PipelineStateSimple: @@ -108,7 +107,7 @@ def create( producer_group: CooperativeGroup, consumer_group: CooperativeGroup, tx_count: int, - init_wait: bool = True, + init_wait: cutlass.Constexpr[bool] = True, ): """ This helper function computes any necessary attributes and returns an instance of PipelineTmaAsync. @@ -123,23 +122,23 @@ def create( :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage :type tx_count: int """ - producer_type = _PipelineOp.TmaLoad - consumer_type = _PipelineOp.AsyncThread + producer_type = PipelineOp.TmaLoad + consumer_type = PipelineOp.AsyncThread producer = (producer_type, producer_group) consumer = (consumer_type, consumer_group) - sync_object_array_full = PipelineAsync._make_sync_object_array( + sync_object_full = PipelineAsync._make_sync_object( barrier_storage.align(min_align=8), num_stages, producer, tx_count ) - sync_object_array_empty = PipelineAsync._make_sync_object_array( + sync_object_empty = PipelineAsync._make_sync_object( barrier_storage.align(min_align=8) + num_stages, num_stages, consumer ) dst_rank = None producer_mask = None - if init_wait: + if cutlass.const_expr(init_wait): pipeline_init_wait() return PipelineTmaAsyncNoCluster( - sync_object_array_full, - sync_object_array_empty, + sync_object_full, + sync_object_empty, num_stages, producer_mask, dst_rank, @@ -151,9 +150,9 @@ def producer_acquire(self, state: PipelineState, try_acquire_token: Optional[Boo """ if_generate( try_acquire_token is None or try_acquire_token == 0, - lambda: self.sync_object_array_empty.wait(state.index, state.phase), + lambda: self.sync_object_empty.wait(state.index, state.phase), ) - self.sync_object_array_full.arrive(state.index, self.producer_mask) + self.sync_object_full.arrive(state.index, self.producer_mask) def producer_commit(self, state: PipelineState): """ @@ -168,5 +167,5 @@ def consumer_release(self, state: PipelineState): # Only 1 thread per warp group signals the empty buffer. if_generate( cute.arch.thread_idx()[0] % 128 == 0, - lambda: self.sync_object_array_empty.arrive(state.index, self.consumer_mask), + lambda: self.sync_object_empty.arrive(state.index, self.consumer_mask), ) diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index f94f8579e87..506a5d8b3c8 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -55,7 +55,7 @@ def online_softmax( acc_S_mn = utils.make_acc_tensor_mn_view(acc_S) row_scale = cute.make_fragment_like(self.row_max, Float32) # Each iteration processes one row of acc_S - for r in range(cute.size(self.row_max)): + for r in cutlass.range_constexpr(cute.size(self.row_max)): acc_S_row = acc_S_mn[r, None].load() # (n_block_size) row_max_cur = self._compute_row_max( acc_S_row, @@ -63,8 +63,7 @@ def online_softmax( ) row_max_cur = utils.warp_reduce(row_max_cur, cute.arch.fmax, width=4) if cutlass.const_expr(check_inf): - if row_max_cur == -Float32.inf: - row_max_cur = 0.0 + row_max_cur = 0.0 if row_max_cur == -Float32.inf else row_max_cur if cutlass.const_expr(is_first): row_max_cur_scaled = row_max_cur * self.scale_log2 acc_S_row_exp = utils.exp2f(acc_S_row * self.scale_log2 - row_max_cur_scaled) @@ -90,7 +89,7 @@ def finalize(self, final_scale: Float32 = 1.0) -> cute.Tensor: # quad reduction for row_sum as we didn't do it during each iteration of online softmax self.row_sum.store(utils.warp_reduce(self.row_sum.load(), operator.add, width=4)) row_scale = cute.make_fragment_like(self.row_max, Float32) - for r in range(cute.size(self.row_sum)): + for r in cutlass.range_constexpr(cute.size(self.row_sum)): # if row_sum is zero or nan, set acc_O_mn_row to 1.0 acc_O_mn_row_is_zero_or_nan = ( self.row_sum[r] == 0.0 or self.row_sum[r] != self.row_sum[r] @@ -117,7 +116,7 @@ def rescale_O(self, acc_O: cute.Tensor, row_scale: cute.Tensor) -> None: """ acc_O_mn = utils.make_acc_tensor_mn_view(acc_O) assert cute.size(row_scale) == cute.size(acc_O_mn, mode=[0]) - for r in range(cute.size(row_scale)): + for r in cutlass.range_constexpr(cute.size(row_scale)): acc_O_mn[r, None].store(acc_O_mn[r, None].load() * row_scale[r]) @@ -156,6 +155,7 @@ def update_row_sum( # tmp = self._compute_row_sum(acc_S_row_exp) # self.row_sum[0] = self.row_sum[0] * row_scale + tmp + @cute.jit def scale_subtract_rowmax( self, acc_S_row: cute.Tensor, @@ -163,13 +163,14 @@ def scale_subtract_rowmax( ): assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" minus_row_max_scaled = -row_max * self.scale_log2 - for i in range(0, cute.size(acc_S_row.shape), 2): + for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2): acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( (acc_S_row[i], acc_S_row[i + 1]), (self.scale_log2, self.scale_log2), (minus_row_max_scaled, minus_row_max_scaled), ) + @cute.jit def apply_exp2_convert( self, acc_S_row: cute.Tensor, @@ -184,8 +185,8 @@ def apply_exp2_convert( acc_S_row_converted_frg = cute.logical_divide( acc_S_row_converted, cute.make_layout(frg_tile) ) - for j in range(frg_cnt): - for k in range(0, cute.size(acc_S_row_frg, mode=[0]), 2): + for j in cutlass.range_constexpr(frg_cnt): + for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2): # acc_S_row_frg[k, j] = utils.exp2f(acc_S_row_frg[k, j]) # acc_S_row_frg[k + 1, j] = utils.exp2f(acc_S_row_frg[k + 1, j]) acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j]) @@ -202,14 +203,14 @@ def scale_apply_exp2_convert( ): assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" minus_row_max_scaled = -row_max * self.scale_log2 - for i in range(0, cute.size(acc_S_row.shape), 2): + for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2): acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( (acc_S_row[i], acc_S_row[i + 1]), (self.scale_log2, self.scale_log2), (minus_row_max_scaled, minus_row_max_scaled), ) - # for i in range(0, cute.size(acc_S_row.shape), 2): + # for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2): # acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( # (acc_S_row[i], acc_S_row[i + 1]), # (self.scale_log2, self.scale_log2), @@ -226,8 +227,8 @@ def scale_apply_exp2_convert( acc_S_row_converted_frg = cute.logical_divide( acc_S_row_converted, cute.make_layout(frg_tile) ) - for j in range(frg_cnt): - for k in range(0, cute.size(acc_S_row_frg, mode=[0]), 2): + for j in cutlass.range_constexpr(frg_cnt): + for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2): # acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = ( # cute.arch.fma_packed_f32x2( # (acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]), diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index 6421b64c4bd..d5cb1c10313 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -319,7 +319,6 @@ def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: block, bidhb_residual = self.l2_minor_divmod.divmod(l2_mod) else: block, bidhb_residual = self.l2_minor_residual_divmod.divmod(l2_mod) - # TODO: should this be l2_minor or l2_minor_residual? bidhb_actual = bidhb * self.l2_minor_divmod.divisor + bidhb_residual batch_idx, head_idx = self.num_head_divmod.divmod(bidhb_actual) # Longest-processing-time-first diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index af6a8c7332a..80543965093 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -25,7 +25,7 @@ def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Te def make_tiled_copy_A( copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False ) -> cute.TiledCopy: - if swapAB: + if cutlass.const_expr(swapAB): return make_tiled_copy_B(copy_atom, tiled_mma) else: return cute.make_tiled_copy( @@ -38,7 +38,7 @@ def make_tiled_copy_A( def make_tiled_copy_B( copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False ) -> cute.TiledCopy: - if swapAB: + if cutlass.const_expr(swapAB): return make_tiled_copy_A(copy_atom, tiled_mma) else: return cute.make_tiled_copy( @@ -59,7 +59,7 @@ def make_tiled_copy_C(copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma) -> cut def mma_make_fragment_A( smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False ) -> cute.Tensor: - if swapAB: + if cutlass.const_expr(swapAB): return mma_make_fragment_B(smem, thr_mma) else: return thr_mma.make_fragment_A(thr_mma.partition_A(smem)) @@ -68,7 +68,7 @@ def mma_make_fragment_A( def mma_make_fragment_B( smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False ) -> cute.Tensor: - if swapAB: + if cutlass.const_expr(swapAB): return mma_make_fragment_A(smem, thr_mma) else: return thr_mma.make_fragment_B(thr_mma.partition_B(smem)) @@ -77,7 +77,7 @@ def mma_make_fragment_B( def get_smem_store_atom( arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric] ) -> cute.CopyAtom: - if arch < 90: + if cutlass.const_expr(arch < 90): return cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), element_type, @@ -90,25 +90,20 @@ def get_smem_store_atom( ) -def max_constexpr( - a: cutlass.Constexpr[cute.Numeric], b: cutlass.Constexpr[cute.Numeric] -) -> cutlass.Constexpr[cute.Numeric]: - return a if a > b else b - - +@cute.jit def warp_reduce( val: cute.TensorSSA | cute.Numeric, op: Callable, width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE, ) -> cute.TensorSSA | cute.Numeric: - if isinstance(val, cute.TensorSSA): + if cutlass.const_expr(isinstance(val, cute.TensorSSA)): res = cute.make_fragment(val.shape, val.dtype) res.store(val) - for i in range(cute.size(val.shape)): + for i in cutlass.range_constexpr(cute.size(val.shape)): res[i] = warp_reduce(res[i], op, width) return res.load() else: - for i in range(int(math.log2(width))): + for i in cutlass.range_constexpr(int(math.log2(width))): val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i)) return val @@ -188,22 +183,22 @@ def exp2f_asm(a: float | Float32, *, loc=None, ip=None) -> Float32: ) +@cute.jit def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32: """exp2f calculation for both vector and scalar. - :param x: input value :type x: cute.TensorSSA or Float32 :return: exp2 value :rtype: cute.TensorSSA or Float32 """ - if isinstance(x, cute.TensorSSA): + if cutlass.const_expr(isinstance(x, cute.TensorSSA)): res = cute.make_fragment(x.shape, Float32) res.store(x) - for i in range(cute.size(x.shape)): - res[i] = exp2f_asm(res[i]) + for i in cutlass.range_constexpr(cute.size(x.shape)): + res[i] = cute.arch.exp2(res[i]) return res.load() else: - return exp2f_asm(x) + return cute.arch.exp2(x) @dsl_user_op @@ -237,6 +232,7 @@ def fmax( ) +@cute.jit def fmax_reduce( x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80 ) -> Float32: @@ -257,7 +253,7 @@ def fmax_reduce( fmax(res[4], res[5]), fmax(res[6], res[7]), ] - for i in range(8, cute.size(x.shape), 8): + for i in cutlass.range_constexpr(8, cute.size(x.shape), 8): local_max[0] = fmax(local_max[0], res[i], res[i + 1]) local_max[1] = fmax(local_max[1], res[i + 2], res[i + 3]) local_max[2] = fmax(local_max[2], res[i + 4], res[i + 5]) @@ -266,6 +262,7 @@ def fmax_reduce( return fmax(local_max[0], local_max[2], local_max[3]) +@cute.jit def fadd_reduce( x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80 ) -> Float32: @@ -282,7 +279,7 @@ def fadd_reduce( else (res[0], res[1]) ) local_sum = [local_sum_0, (res[2], res[3]), (res[4], res[5]), (res[6], res[7])] - for i in range(8, cute.size(x.shape), 8): + for i in cutlass.range_constexpr(8, cute.size(x.shape), 8): local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], (res[i + 0], res[i + 1])) local_sum[1] = cute.arch.add_packed_f32x2(local_sum[1], (res[i + 2], res[i + 3])) local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], (res[i + 4], res[i + 5])) @@ -320,60 +317,22 @@ def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cut return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip) +@cute.jit def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor: # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if" tApA = cute.make_fragment( cute.make_layout( - (tAcA.shape[0][1], cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])), + (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])), stride=(cute.size(tAcA, mode=[2]), 0, 1), ), cutlass.Boolean, ) - for rest_v in range(tApA.shape[0]): - for rest_k in range(tApA.shape[2]): + for rest_v in cutlass.range_constexpr(tApA.shape[0]): + for rest_k in cutlass.range_constexpr(tApA.shape[2]): tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit) return tApA -@dsl_user_op -def barrier_sync( - barrier_id: int | cutlass.Int32, number_of_threads: int | cutlass.Int32, *, loc=None, ip=None -) -> None: - llvm.inline_asm( - None, - [ - cutlass.Int32(barrier_id).ir_value(loc=loc, ip=ip), - cutlass.Int32(number_of_threads).ir_value(loc=loc, ip=ip), - ], - "bar.sync $0, $1;", - "r,r", - has_side_effects=True, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - - -@dsl_user_op -def barrier_arrive( - barrier_id: int | cutlass.Int32, number_of_threads: int | cutlass.Int32, *, loc=None, ip=None -) -> None: - """ - Arrive at a named barrier. - """ - barrier_id = cutlass.Int32(barrier_id).ir_value(loc=loc, ip=ip) - number_of_threads = cutlass.Int32(number_of_threads).ir_value(loc=loc, ip=ip) - nvvm.barrier_arrive(barrier_id=barrier_id, number_of_threads=number_of_threads, loc=loc, ip=ip) - # llvm.inline_asm( - # None, - # [barrier_id, number_of_threads], - # "bar.arrive $0, $1;", - # "r,r", - # has_side_effects=True, - # is_align_stack=False, - # asm_dialect=llvm.AsmDialect.AD_ATT, - # ) - - @dsl_user_op def cp_async_mbarrier_arrive_shared( mbar_ptr: cute.Pointer, noinc: bool = False, *, loc=None, ip=None @@ -413,14 +372,11 @@ def canonical_warp_group_idx(sync: bool = True) -> cutlass.Int32: # ) -@dsl_user_op +@cute.jit def shuffle_sync( value: cute.Numeric, offset: cute.typing.Int, width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE, - *, - loc=None, - ip=None, ) -> cute.Numeric: assert value.width % 32 == 0, "value type must be a multiple of 32 bits" # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000 @@ -430,7 +386,7 @@ def shuffle_sync( val = cute.make_fragment(1, type(value)) val[0] = value val_i32 = cute.recast_tensor(val, cutlass.Int32) - for i in range(cute.size(val_i32)): + for i in cutlass.range_constexpr(cute.size(val_i32)): val_i32[i] = cute.arch.shuffle_sync(val_i32[i], offset, mask_and_clamp=mask_and_clamp) return val[0] From 25bd20c135950429b89cdb92dcbf2e771957b04a Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 4 Jul 2025 20:48:57 -0400 Subject: [PATCH 021/258] [Cute] Use RS WGMMA for fwd_sm90 --- flash_attn/cute/flash_fwd.py | 134 ++++++++++++++++-------------- flash_attn/cute/hopper_helpers.py | 9 +- flash_attn/cute/interface.py | 8 +- flash_attn/cute/mask.py | 12 ++- flash_attn/cute/utils.py | 50 +++++++---- 5 files changed, 127 insertions(+), 86 deletions(-) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 11b34607a1d..66710700041 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -945,6 +945,7 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase): def __init__(self, *args, intra_wg_overlap: bool = True, **kwargs): super().__init__(*args, **kwargs) self.intra_wg_overlap = intra_wg_overlap + self.mma_pv_is_rs = True def _get_smem_layout_atom(self): sQ_layout_atom = warpgroup.make_smem_layout_atom( @@ -961,12 +962,15 @@ def _get_smem_layout_atom(self): self.dtype ) sO_layout_atom = sV_layout_atom - sP_layout_atom = warpgroup.make_smem_layout_atom( - sm90_utils_basic.get_smem_layout_atom( - cutlass.utils.LayoutEnum.ROW_MAJOR, self.dtype, self.n_block_size - ), - self.dtype - ) + if not self.mma_pv_is_rs: + sP_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils_basic.get_smem_layout_atom( + cutlass.utils.LayoutEnum.ROW_MAJOR, self.dtype, self.n_block_size + ), + self.dtype + ) + else: + sP_layout_atom = None return sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom def _get_tiled_mma(self): @@ -987,8 +991,19 @@ def _get_tiled_mma(self): cutlass.Float32, atom_layout_mnk=(self.m_block_size // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 tiler_mn=(64, self.head_dim_v_padded), + a_source=warpgroup.OperandSource.RMEM if self.mma_pv_is_rs else warpgroup.OperandSource.SMEM, ) - return tiled_mma_qk, tiled_mma_pv + tiled_mma_pv_rs = sm90_utils_basic.make_trivial_tiled_mma( + self.dtype, + self.dtype, + warpgroup.OperandMajorMode.K, + warpgroup.OperandMajorMode.MN, + cutlass.Float32, + atom_layout_mnk=(self.m_block_size // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 + tiler_mn=(64, self.head_dim_v_padded), + a_source=warpgroup.OperandSource.RMEM + ) + return tiled_mma_qk, tiled_mma_pv, tiled_mma_pv_rs def _get_shared_storage_cls(self): # If PackGQA, we use cp.async to load Q, so we want sQ to align to 1024 bytes @@ -1072,7 +1087,7 @@ def __call__( ] LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) if const_expr(mLSE is not None) else None - tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma() + tiled_mma_qk, tiled_mma_pv, tiled_mma_pv_rs = self._get_tiled_mma() self.num_mma_threads = tiled_mma_qk.size self.num_threads_per_warp_group = 128 self.num_mma_warp_groups = self.num_mma_threads // self.num_threads_per_warp_group @@ -1178,10 +1193,9 @@ def __call__( self.gmem_tiled_copy_K, self.gmem_tiled_copy_V, self.gmem_tiled_copy_O, - # the compiler is unhappy about us using tiled_mma_qk/pv and setting the ACCUMULATE - # field inside a for loop, so we work around by creating multiple copies of the - # tiled_mma_qk/pv. - *((tiled_mma_qk, tiled_mma_pv) * 4), + tiled_mma_qk, + tiled_mma_pv, + tiled_mma_pv_rs, SharedStorage, ).launch( grid=grid_dim, @@ -1221,12 +1235,7 @@ def kernel( gmem_tiled_copy_O: cute.TiledCopy, tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, - tiled_mma_qk_copy: cute.TiledMma, - tiled_mma_pv_copy: cute.TiledMma, - tiled_mma_qk_copy1: cute.TiledMma, - tiled_mma_pv_copy1: cute.TiledMma, - tiled_mma_qk_copy2: cute.TiledMma, - tiled_mma_pv_copy2: cute.TiledMma, + tiled_mma_pv_rs: cute.TiledMma, SharedStorage: cutlass.Constexpr, ): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) @@ -1285,7 +1294,7 @@ def kernel( sP_pi = storage.sP.get_tensor(sP_layout) sP = storage.sP.get_tensor(sP_layout.outer, swizzle=sP_layout.inner) else: - sP, sP_pi = None + sP, sP_pi = None, None # Transpose view of V to tensor with layout (head_dim_v, n_block_size) for tiled mma sVt = utils.transpose_view(sV) @@ -1349,6 +1358,7 @@ def kernel( self.mma( tiled_mma_qk, tiled_mma_pv, + tiled_mma_pv_rs, softmax, acc_O, mQ, @@ -1365,12 +1375,6 @@ def kernel( block_info, SeqlenInfoCls, AttentionMaskCls, - tiled_mma_qk_copy, - tiled_mma_pv_copy, - tiled_mma_qk_copy1, - tiled_mma_pv_copy1, - tiled_mma_qk_copy2, - tiled_mma_pv_copy2, ) # /////////////////////////////////////////////////////////////////////////////// # Epilogue @@ -1466,6 +1470,7 @@ def mma( self, tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, + tiled_mma_pv_rs: cute.TiledMma, softmax: Softmax, acc_O: cute.Tensor, mQ: cute.Tensor, @@ -1482,12 +1487,6 @@ def mma( block_info: BlockInfo, SeqlenInfoCls: Callable, AttentionMaskCls: Callable, - tiled_mma_qk_copy: cute.TiledMma, - tiled_mma_pv_copy: cute.TiledMma, - tiled_mma_qk_copy1: cute.TiledMma, - tiled_mma_pv_copy1: cute.TiledMma, - tiled_mma_qk_copy2: cute.TiledMma, - tiled_mma_pv_copy2: cute.TiledMma, ): warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) warp_group_thread_layout = cute.make_layout( @@ -1498,7 +1497,12 @@ def mma( wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx)) tSrQ = tiled_mma_qk.make_fragment_A(wg_mma_qk.partition_A(sQ)) tSrK = tiled_mma_qk.make_fragment_B(wg_mma_qk.partition_B(sK)) - tOrP = tiled_mma_pv.make_fragment_A(wg_mma_pv.partition_A(sP)) if const_expr(sP is not None) else None + if const_expr(self.mma_pv_is_rs): + acc_S_shape = tiled_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)) + acc_S_layout = cute.make_layout(acc_S_shape) + tOrP = cute.make_fragment(utils.convert_layout_acc_frgA(acc_S_layout), self.dtype) + else: + tOrP = tiled_mma_pv.make_fragment_A(wg_mma_pv.partition_A(sP)) tOrVt = tiled_mma_pv.make_fragment_B(wg_mma_pv.partition_B(sVt)) # /////////////////////////////////////////////////////////////////////////////// @@ -1528,6 +1532,7 @@ def scoremod_premask_fn(acc_S): mma_one_n_block = partial( self.mma_one_n_block_intrawg_overlap if const_expr(self.intra_wg_overlap) else self.mma_one_n_block, + tiled_mma_qk=tiled_mma_qk, tiled_mma_pv=tiled_mma_pv, tiled_mma_pv_rs=tiled_mma_pv_rs, pipeline_k=pipeline_k, pipeline_v=pipeline_v, mma_params=mma_params, smem_copy_params=smem_copy_params, softmax=softmax, scoremod_premask_fn=scoremod_premask_fn, @@ -1583,20 +1588,21 @@ def scoremod_premask_fn(acc_S): mask_fn(acc_S, n_block=n_block_max - 1, mask_seqlen=True) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) softmax.online_softmax(acc_S, is_first=True, check_inf=True) - rP = cute.make_fragment_like(acc_S, self.dtype) - rP.store(acc_S.load().to(self.dtype)) - # tOrP = cute.make_tensor(rP.iterator, utils.convert_layout_acc_frgA(rP.layout)) - tPrP = smem_thr_copy_P.retile(rP) - cute.copy(smem_thr_copy_P, tPrP, tPsP) - # Fence and barrier to make sure smem store is visible to WGMMA - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) - cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV + tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) + tOrP = mma_params.tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) + tOrP.store(tOrP_acc.load().to(self.dtype)) + if const_expr(not self.mma_pv_is_rs): + tPrP = smem_thr_copy_P.retile(tOrP) + cute.copy(smem_thr_copy_P, tPrP, tPsP) + # Fence and barrier to make sure smem store is visible to WGMMA + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV # Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter acc_O.fill(0.0) else: self.warp_scheduler_barrier_sync() consumer_state = mma_one_n_block( - n_block_max - 1, consumer_state, tiled_mma_qk, tiled_mma_pv, + n_block_max - 1, consumer_state, is_first_n_block=True, check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=True) ) # if cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block_max = {}, n_block_min = {}", m_block, n_block_max, n_block_min) @@ -1610,7 +1616,7 @@ def scoremod_premask_fn(acc_S): for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1): n_block = n_block_max - 1 - n_tile consumer_state = mma_one_n_block( - n_block, consumer_state, tiled_mma_qk_copy, tiled_mma_pv_copy, + n_block, consumer_state, check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=False) ) n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) @@ -1621,16 +1627,14 @@ def scoremod_premask_fn(acc_S): # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_before_local_mask = {}, n_block_min = {}", n_block_min_before_local_mask, n_block_min) for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): n_block = n_block_max - 1 - n_tile - consumer_state = mma_one_n_block( - n_block, consumer_state, tiled_mma_qk_copy1, tiled_mma_pv_copy1, check_inf=True, - ) + consumer_state = mma_one_n_block(n_block, consumer_state, check_inf=True) # Separate iterations with local masking on the left if const_expr(self.is_local and block_info.window_size_left is not None): n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) for n_tile in cutlass.range(n_block_max - n_block_min, unroll=1): n_block = n_block_max - 1 - n_tile consumer_state = mma_one_n_block( - n_block, consumer_state, tiled_mma_qk_copy2, tiled_mma_pv_copy2, + n_block, consumer_state, check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=False) ) # Last "half" iteration @@ -1658,6 +1662,7 @@ def mma_one_n_block( smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, + tiled_mma_pv_rs: cute.TiledMma, pipeline_k: cutlass.pipeline.PipelineAsync, pipeline_v: cutlass.pipeline.PipelineAsync, mma_params: SimpleNamespace, @@ -1685,15 +1690,17 @@ def mma_one_n_block( mask_fn(acc_S, n_block=n_block) row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf) # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) - rP = cute.make_fragment_like(acc_S, self.dtype) - rP.store(acc_S.load().to(self.dtype)) - # tOrP = cute.make_tensor(rP.iterator, utils.convert_layout_acc_frgA(rP.layout)) - tPrP = smem_copy_params.smem_thr_copy_P.retile(rP) - cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) + tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) + tOrP = mma_params.tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) + tOrP.store(tOrP_acc.load().to(self.dtype)) + if const_expr(not self.mma_pv_is_rs): + tPrP = smem_copy_params.smem_thr_copy_P.retile(mma_params.tOrP) + cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) softmax.rescale_O(mma_params.acc_O, row_scale) - # Fence and barrier to make sure smem store is visible to WGMMA - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) - cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV + if const_expr(not self.mma_pv_is_rs): + # Fence and barrier to make sure smem store is visible to WGMMA + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read)) self.warp_scheduler_barrier_sync() sm90_utils.gemm( @@ -1712,6 +1719,7 @@ def mma_one_n_block_intrawg_overlap( smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, + tiled_mma_pv_rs: cute.TiledMma, pipeline_k: cutlass.pipeline.PipelineAsync, pipeline_v: cutlass.pipeline.PipelineAsync, mma_params: SimpleNamespace, @@ -1750,15 +1758,17 @@ def mma_one_n_block_intrawg_overlap( row_scale = softmax.online_softmax(acc_S, check_inf=check_inf) warpgroup.wait_group(0) pipeline_v.consumer_release(smem_pipe_read_v) - rP = cute.make_fragment_like(acc_S, self.dtype) - rP.store(acc_S.load().to(self.dtype)) - # tOrP = cute.make_tensor(rP.iterator, utils.convert_layout_acc_frgA(rP.layout)) - tPrP = smem_copy_params.smem_thr_copy_P.retile(rP) - cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) + tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) + tOrP = mma_params.tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) + tOrP.store(tOrP_acc.load().to(self.dtype)) + if const_expr(not self.mma_pv_is_rs): + tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP) + cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) softmax.rescale_O(mma_params.acc_O, row_scale) - # Fence and barrier to make sure smem store is visible to WGMMA - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) - cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV + if const_expr(not self.mma_pv_is_rs): + # Fence and barrier to make sure smem store is visible to WGMMA + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV return smem_pipe_read @cute.jit diff --git a/flash_attn/cute/hopper_helpers.py b/flash_attn/cute/hopper_helpers.py index 6408e11f786..3a57e43da08 100644 --- a/flash_attn/cute/hopper_helpers.py +++ b/flash_attn/cute/hopper_helpers.py @@ -19,10 +19,13 @@ def gemm( gemm(tiled_mma, acc, tCrB, tCrA, zero_init=zero_init, wg_wait=wg_wait, swap_AB=False) else: warpgroup.fence() - tiled_mma.set(warpgroup.Field.ACCUMULATE, not zero_init) + # We make a new mma_atom since we'll be modifying its attribute (accumulate). + # Otherwise the compiler complains "operand #0 does not dominate this use" + mma_atom = cute.make_mma_atom(tiled_mma.op) + mma_atom.set(warpgroup.Field.ACCUMULATE, not zero_init) for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): - cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc) - tiled_mma.set(warpgroup.Field.ACCUMULATE, True) + cute.gemm(mma_atom, acc, tCrA[None, None, k], tCrB[None, None, k], acc) + mma_atom.set(warpgroup.Field.ACCUMULATE, True) warpgroup.commit_group() if cutlass.const_expr(wg_wait >= 0): warpgroup.wait_group(wg_wait) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index c68165a3b60..e2f03832912 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -129,10 +129,14 @@ def _flash_attn_fwd( causal, local = True, False else: causal, local = False, True - current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - compute_capability = torch.cuda.get_device_capability()[0] if _compute_capability is None else _compute_capability assert compute_capability in [9, 10], "Unsupported compute capability. Supported: 9.x, 10.x" + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + # if compute_capability == 9: # TODO: tune block size according to hdim + # if not causal and not local: + # n_block_size = 128 + compile_key = ( dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, softcap is not None, lse is None, cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, seqused_k is None, diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index 660a5efbc00..89ce612c6ec 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -43,8 +43,11 @@ def apply_mask( if cutlass.const_expr(mask_seqlen): # traverse column index. for c in cutlass.range_constexpr(cute.size(tScS_mn.shape[1])): - if t0ScS_mn[0, c][1] >= seqlenk_col_limit: - acc_S_mn[None, c].fill(-cutlass.Float32.inf) + # if t0ScS_mn[0, c][1] >= seqlenk_col_limit: + # acc_S_mn[None, c].fill(-cutlass.Float32.inf) + oob = t0ScS_mn[0, c][1] >= seqlenk_col_limit + for r in cutlass.range_constexpr(cute.size(tScS_mn.shape[0])): + acc_S_mn[r, c] = -cutlass.Float32.inf if oob else acc_S_mn[r, c] else: # Causal or local # If PackGQA, we split the work of compute divmod among threads in the same row threads_per_row = thr_mma.tv_layout_C.shape[0][0] @@ -75,8 +78,9 @@ def apply_mask( # traverse column index. for c in cutlass.range_constexpr(cute.size(tScS_mn.shape[1])): # only consider the column index, so the row index sets to 0. - if t0ScS_mn[0, c][1] >= col_limit_right: - acc_S_mn[r, c] = -cutlass.Float32.inf + # if t0ScS_mn[0, c][1] >= col_limit_right: + # acc_S_mn[r, c] = -cutlass.Float32.inf + acc_S_mn[r, c] = -cutlass.Float32.inf if t0ScS_mn[0, c][1] >= col_limit_right else acc_S_mn[r, c] else: # Local local_row_offset_right = ( causal_row_offset + self.window_size_right diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 80543965093..eb82940cdee 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -141,23 +141,43 @@ def make_acc_tensor_mn_view(acc: cute.Tensor) -> cute.Tensor: return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout)) +@cute.jit def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout: # For back to back gemm, convert layout of acc0 to gemm 1 accept layout. - # Due to the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) - # (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2)) - acc_layout_divided = cute.logical_divide(acc_layout, (None, None, 2)) - rA_mma_view = cute.make_layout( - ( - (acc_layout_divided.shape[0], acc_layout_divided.shape[2][0]), - acc_layout_divided.shape[1], - acc_layout_divided.shape[2][1], - ), - stride=( - (acc_layout_divided.stride[0], acc_layout_divided.stride[2][0]), - acc_layout_divided.stride[1], - acc_layout_divided.stride[2][1], - ), - ) + # For Sm80, as the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + # For Sm90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N)) + # TODO: Sm90 FP8 + if cutlass.const_expr(cute.rank(acc_layout.shape[0]) == 3): # Sm90 + l = cute.logical_divide( + acc_layout, ((None, None, 2), None, None) + ) # ((2, 2, (2, N / 16)), MMA_M, MMA_N) + rA_mma_view = cute.make_layout( + ( + (l.shape[0][0], l.shape[0][1], l.shape[0][2][0]), + l.shape[1], + (l.shape[0][2][1], l.shape[2]), + ), + stride=( + (l.stride[0][0], l.stride[0][1], l.stride[0][2][0]), + l.stride[1], + (l.stride[0][2][1], l.stride[2]), + ), + ) + else: # Sm80 + # (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2)) + l = cute.logical_divide(acc_layout, (None, None, 2)) + rA_mma_view = cute.make_layout( + ( + (l.shape[0], l.shape[2][0]), + l.shape[1], + l.shape[2][1], + ), + stride=( + (l.stride[0], l.stride[2][0]), + l.stride[1], + l.stride[2][1], + ), + ) return rA_mma_view From 0d0ab1ba229f00069a6b013bfc6da0db9e0f8039 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 4 Jul 2025 22:53:27 -0400 Subject: [PATCH 022/258] [Cute] Use tile_scheduler in fwd_sm90 --- flash_attn/cute/flash_fwd.py | 547 +++++++++++++++++++---------------- 1 file changed, 295 insertions(+), 252 deletions(-) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 66710700041..f6504df7038 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -29,6 +29,7 @@ from flash_attn.cute import pipeline from flash_attn.cute.pack_gqa import PackGQA from flash_attn.cute.named_barrier import NamedBarrierFwd +from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, SingleTileLPTScheduler, ParamsBase class FlashAttentionForwardBase: @@ -1144,12 +1145,26 @@ def __call__( shape_LSE_packed = ((self.qhead_per_kvhead, mLSE.shape[0]), mK.shape[2], *mLSE.shape[2:]) stride_LSE_packed = ((mLSE.stride[1], mLSE.stride[0]), mLSE.stride[1] * self.qhead_per_kvhead, *mLSE.stride[2:]) mLSE = cute.make_tensor(mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed)) - # grid_dim: (m_block, num_head, batch_size) - grid_dim = ( - cute.ceil_div(cute.size(mQ.shape[0]) if const_expr(mCuSeqlensQ is None) else max_seqlen_q, self.m_block_size), + + TileScheduler = SingleTileScheduler if const_expr(not self.is_causal or self.is_local) else SingleTileLPTScheduler + tile_sched_args = TileSchedulerArguments( + cute.ceil_div(cute.size(mQ.shape[0]), self.m_block_size), cute.size(mQ.shape[2]), - cute.size(mQ.shape[3] if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ.shape[0] - 1), + cute.size(mQ.shape[3]), + cute.size(mK.shape[0]), + mQ.shape[1], + mV.shape[1], + self.dtype.width // 8, + is_persistent=False, ) + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + # TODO: deal with PackGQA and varlen + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) + # grid_dim = ( + # cute.ceil_div(cute.size(mQ.shape[0]) if const_expr(mCuSeqlensQ is None) else max_seqlen_q, self.m_block_size), + # cute.size(mQ.shape[2]), + # cute.size(mQ.shape[3] if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ.shape[0] - 1), + # ) # If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. # Right after this, we multiply by log2(e) before applying exp2. # To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val @@ -1196,6 +1211,8 @@ def __call__( tiled_mma_qk, tiled_mma_pv, tiled_mma_pv_rs, + tile_sched_params, + TileScheduler, SharedStorage, ).launch( grid=grid_dim, @@ -1236,7 +1253,9 @@ def kernel( tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, tiled_mma_pv_rs: cute.TiledMma, - SharedStorage: cutlass.Constexpr, + tile_sched_params: ParamsBase, + TileScheduler: cutlass.Constexpr[Callable], + SharedStorage: cutlass.Constexpr[Callable], ): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) # Prefetch tma descriptor @@ -1253,7 +1272,7 @@ def kernel( # Mbarrier init mbar_ptr_Q = storage.mbar_ptr.data_ptr() - if warp_idx == 0: + if warp_idx == 1: # if tidx < 2: # # barrierO num threads should be self.num_mma_threads # cute.arch.mbarrier_init(mbar_ptr_Q + tidx, 1 if tidx == 0 else self.num_mma_threads) @@ -1290,17 +1309,20 @@ def kernel( sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) else: sV = storage.sQ.get_tensor(sV_layout.outer, swizzle=sV_layout.inner, dtype=mV.element_type) + # Transpose view of V to tensor with layout (head_dim_v, n_block_size) for tiled mma + sVt = utils.transpose_view(sV) if const_expr(sP_layout is not None): sP_pi = storage.sP.get_tensor(sP_layout) sP = storage.sP.get_tensor(sP_layout.outer, swizzle=sP_layout.inner) else: sP, sP_pi = None, None - # Transpose view of V to tensor with layout (head_dim_v, n_block_size) for tiled mma - sVt = utils.transpose_view(sV) + # reuse sQ's data iterator + sO_pi = storage.sQ.get_tensor(sO_layout) + # TODO: idk why not using sO_pi is faster + sO = cute.make_tensor(cute.recast_ptr(sO_pi.iterator, sO_layout.inner, dtype=sO_pi.element_type), sO_layout.outer) # Thread index, block index tidx, _, _ = cute.arch.thread_idx() - m_block, head_idx, batch_idx = cute.arch.block_idx() block_info = BlockInfo( self.m_block_size, self.n_block_size, self.is_causal, self.is_local, window_size_left, window_size_right, @@ -1317,76 +1339,60 @@ def kernel( window_size_left=window_size_left, window_size_right=window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) - seqlen = SeqlenInfoCls(batch_idx) - # Can't early exit so we have to write it this way (under an if statement) - if mCuSeqlensQ is None or m_block * self.n_block_size < seqlen.seqlen_q: - if const_expr(self.is_causal): # Longest tile first - m_block = cute.ceil_div(seqlen.seqlen_q * (self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1), self.m_block_size) - m_block - 1 - # TODO: return early if n_block_max == 0 - # if self.is_causal: - # if n_block_max <= 0: - # return - - if warp_idx < 4: # Producer - cute.arch.warpgroup_reg_dealloc(self.num_producer_regs) - self.load( - mQ, - mK, - mV, - sQ, - sK, - sV, - tma_atom_Q, - tma_atom_K, - tma_atom_V, - pipeline_k, - pipeline_v, - mbar_ptr_Q, - block_info, - SeqlenInfoCls - ) + TileSchedulerCls = partial(TileScheduler.create, tile_sched_params) + + if warp_idx < 4: # Producer + cute.arch.warpgroup_reg_dealloc(self.num_producer_regs) + self.load( + mQ, + mK, + mV, + sQ, + sK, + sV, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + pipeline_k, + pipeline_v, + mbar_ptr_Q, + block_info, + SeqlenInfoCls, + TileSchedulerCls, + ) - else: # Consumer - cute.arch.warpgroup_reg_alloc(self.num_mma_regs) - # /////////////////////////////////////////////////////////////////////////////// - # Tile MMA compute thread partitions and allocate accumulators - # /////////////////////////////////////////////////////////////////////////////// - tidx = tidx - 128 - acc_shape_O = tiled_mma_pv.partition_shape_C((self.m_block_size, self.head_dim_v_padded)) - acc_O = cute.make_fragment(acc_shape_O, cutlass.Float32) - softmax = Softmax(softmax_scale_log2, num_rows=acc_O.shape[0][0] * acc_O.shape[1]) - self.mma( - tiled_mma_qk, - tiled_mma_pv, - tiled_mma_pv_rs, - softmax, - acc_O, - mQ, - sQ, - sK, - sVt, - sP, - pipeline_k, - pipeline_v, - mbar_ptr_Q, - gmem_tiled_copy_Q, - tidx, - softcap_val, - block_info, - SeqlenInfoCls, - AttentionMaskCls, - ) - # /////////////////////////////////////////////////////////////////////////////// - # Epilogue - # /////////////////////////////////////////////////////////////////////////////// - # reuse sQ's data iterator - sO_pi = cute.make_tensor(sQ.iterator, sO_layout) - # TODO: idk why not using sO_pi is faster - sO = cute.make_tensor(cute.recast_ptr(sO_pi.iterator, sO_layout.inner, dtype=sO_pi.element_type), sO_layout.outer) - self.epilogue( - acc_O, softmax.row_sum, mO, mLSE, sO, seqlen, - gmem_tiled_copy_O, tma_atom_O, tiled_mma_pv, tidx, m_block, head_idx, batch_idx, - ) + else: # Consumer + cute.arch.warpgroup_reg_alloc(self.num_mma_regs) + # /////////////////////////////////////////////////////////////////////////////// + # Tile MMA compute thread partitions and allocate accumulators + # /////////////////////////////////////////////////////////////////////////////// + tidx = tidx - 128 + self.mma( + tiled_mma_qk, + tiled_mma_pv, + tiled_mma_pv_rs, + mQ, + mO, + mLSE, + sQ, + sK, + sVt, + sP, + sO, + pipeline_k, + pipeline_v, + mbar_ptr_Q, + gmem_tiled_copy_Q, + gmem_tiled_copy_O, + tma_atom_O, + tidx, + softmax_scale_log2, + softcap_val, + block_info, + SeqlenInfoCls, + AttentionMaskCls, + TileSchedulerCls, + ) @cute.jit def load( @@ -1405,65 +1411,75 @@ def load( mbar_ptr_Q: cutlass.Pointer, block_info: BlockInfo, SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, ): warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 - m_block, head_idx, batch_idx = cute.arch.block_idx() - seqlen = SeqlenInfoCls(batch_idx) - if const_expr(self.is_causal): # Longest tile first - m_block = cute.ceil_div(seqlen.seqlen_q * (self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1), self.m_block_size) - m_block - 1 - if const_expr(not seqlen.has_cu_seqlens_q): - mQ_cur = mQ[None, None, head_idx, batch_idx] - else: - mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, None, head_idx]) - head_idx_kv = head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx - if const_expr(not seqlen.has_cu_seqlens_k): - mK_cur, mV_cur = [t[None, None, head_idx_kv, batch_idx] for t in (mK, mV)] - else: - mK_cur, mV_cur = [cute.domain_offset((seqlen.offset_k, 0), t[None, None, head_idx_kv]) for t in (mK, mV)] - gK = cute.local_tile(mK_cur, (self.n_block_size, self.head_dim_padded), (None, 0)) - gV = cute.local_tile(mV_cur, (self.n_block_size, self.head_dim_v_padded), (None, 0)) - if const_expr(not self.pack_gqa): - gQ = cute.local_tile(mQ_cur, (self.m_block_size, self.head_dim_padded), (m_block, 0)) - tQsQ, tQgQ = cpasync.tma_partition( - tma_atom_Q, - 0, - cute.make_layout(1), - cute.group_modes(sQ, 0, 2), - cute.group_modes(gQ, 0, 2), - ) - tKsK, tKgK = cpasync.tma_partition( - tma_atom_K, - 0, - cute.make_layout(1), - cute.group_modes(sK, 0, 2), - cute.group_modes(gK, 0, 2), - ) - tVsV, tVgV = cpasync.tma_partition( - tma_atom_V, - 0, - cute.make_layout(1), - cute.group_modes(sV, 0, 2), - cute.group_modes(gV, 0, 2), - ) - kv_producer_state = pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Producer, self.num_stages - ) - load_K = partial(self.load_K, tma_atom_K, tKgK, tKsK, pipeline_k) - load_V = partial(self.load_K, tma_atom_V, tVgV, tVsV, pipeline_v) if warp_idx_in_wg == 0: - # load_Q - if const_expr(not self.pack_gqa): - with cute.arch.elect_one(): - cute.arch.mbarrier_arrive_and_expect_tx(mbar_ptr_Q, self.tma_copy_q_bytes) - cute.copy(tma_atom_Q, tQgQ, tQsQ, tma_bar_ptr=mbar_ptr_Q) - n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) - # if cute.arch.thread_idx()[0] == 0: - # cute.printf("m_block = %d, n_block_min: %d, n_block_max: %d", m_block, n_block_min, n_block_max) - for i in cutlass.range(n_block_max - n_block_min, unroll=2): - n_block = n_block_max - i - 1 - load_K(n_block, producer_state=kv_producer_state) - load_V(n_block, producer_state=kv_producer_state) - kv_producer_state.advance() + q_producer_phase = cutlass.Int32(1) + kv_producer_state = pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.num_stages + ) + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m_block, head_idx, batch_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + if const_expr(not seqlen.has_cu_seqlens_q): + mQ_cur = mQ[None, None, head_idx, batch_idx] + else: + mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, None, head_idx]) + head_idx_kv = head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx + if const_expr(not seqlen.has_cu_seqlens_k): + mK_cur, mV_cur = [t[None, None, head_idx_kv, batch_idx] for t in (mK, mV)] + else: + mK_cur, mV_cur = [cute.domain_offset((seqlen.offset_k, 0), t[None, None, head_idx_kv]) for t in (mK, mV)] + gK = cute.local_tile(mK_cur, (self.n_block_size, self.head_dim_padded), (None, 0)) + gV = cute.local_tile(mV_cur, (self.n_block_size, self.head_dim_v_padded), (None, 0)) + if const_expr(not self.pack_gqa): + gQ = cute.local_tile(mQ_cur, (self.m_block_size, self.head_dim_padded), (m_block, 0)) + tQsQ, tQgQ = cpasync.tma_partition( + tma_atom_Q, + 0, + cute.make_layout(1), + cute.group_modes(sQ, 0, 2), + cute.group_modes(gQ, 0, 2), + ) + tKsK, tKgK = cpasync.tma_partition( + tma_atom_K, + 0, + cute.make_layout(1), + cute.group_modes(sK, 0, 2), + cute.group_modes(gK, 0, 2), + ) + tVsV, tVgV = cpasync.tma_partition( + tma_atom_V, + 0, + cute.make_layout(1), + cute.group_modes(sV, 0, 2), + cute.group_modes(gV, 0, 2), + ) + load_K = partial(self.load_K, tma_atom_K, tKgK, tKsK, pipeline_k) + load_V = partial(self.load_K, tma_atom_V, tVgV, tVsV, pipeline_v) + # load_Q + if const_expr(not self.pack_gqa): + # TODO: wait for Q to be empty + q_producer_phase ^= 1 + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx(mbar_ptr_Q, self.tma_copy_q_bytes) + cute.copy(tma_atom_Q, tQgQ, tQsQ, tma_bar_ptr=mbar_ptr_Q) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + # if cute.arch.thread_idx()[0] == 0: + # cute.printf("m_block = %d, n_block_min: %d, n_block_max: %d", m_block, n_block_min, n_block_max) + for i in cutlass.range(n_block_max - n_block_min, unroll=2): + n_block = n_block_max - i - 1 + load_K(n_block, producer_state=kv_producer_state) + load_V(n_block, producer_state=kv_producer_state) + kv_producer_state.advance() + tile_scheduler.prefetch_next_work() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + @cute.jit def mma( @@ -1471,22 +1487,29 @@ def mma( tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, tiled_mma_pv_rs: cute.TiledMma, - softmax: Softmax, - acc_O: cute.Tensor, + # softmax: Softmax, + # acc_O: cute.Tensor, mQ: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], sQ: cute.Tensor, sK: cute.Tensor, sVt: cute.Tensor, - sP: cute.Tensor | None, + sP: Optional[cute.Tensor], + sO: cute.Tensor, pipeline_k: cutlass.pipeline.PipelineAsync, pipeline_v: cutlass.pipeline.PipelineAsync, mbar_ptr_Q: cutlass.Pointer, gmem_tiled_copy_Q: cute.TiledCopy, + gmem_tiled_copy_O: cute.TiledCopy, + tma_atom_O: Optional[cute.CopyAtom], tidx: cutlass.Int32, + softmax_scale_log2: cutlass.Float32, softcap_val: cutlass.Float32, block_info: BlockInfo, SeqlenInfoCls: Callable, AttentionMaskCls: Callable, + TileSchedulerCls: Callable, ): warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) warp_group_thread_layout = cute.make_layout( @@ -1519,141 +1542,161 @@ def mma( self.mma_init() - # shape: (atom_v_m * rest_m) + acc_shape_O = tiled_mma_pv.partition_shape_C((self.m_block_size, self.head_dim_v_padded)) + acc_O = cute.make_fragment(acc_shape_O, cutlass.Float32) # group parameters for mma_one_n_block mma_params = SimpleNamespace(tSrQ=tSrQ, tSrK=tSrK, tOrP=tOrP, tOrVt=tOrVt, acc_O=acc_O) smem_copy_params = SimpleNamespace(smem_thr_copy_P=smem_thr_copy_P, tPsP=tPsP) - # Softcapping needs to happen before masking since if we apply after masking, softcapping can turn - # -inf to e.g. -50.0, which can affect the attention softmax. - def scoremod_premask_fn(acc_S): - if const_expr(softcap_val is not None): - acc_S.store(cute.math.tanh(acc_S.load() * softcap_val, fastmath=True)) - - mma_one_n_block = partial( + mma_one_n_block_all = partial( self.mma_one_n_block_intrawg_overlap if const_expr(self.intra_wg_overlap) else self.mma_one_n_block, tiled_mma_qk=tiled_mma_qk, tiled_mma_pv=tiled_mma_pv, tiled_mma_pv_rs=tiled_mma_pv_rs, pipeline_k=pipeline_k, pipeline_v=pipeline_v, mma_params=mma_params, smem_copy_params=smem_copy_params, - softmax=softmax, scoremod_premask_fn=scoremod_premask_fn, + check_inf=True, ) - m_block, head_idx, batch_idx = cute.arch.block_idx() - seqlen = SeqlenInfoCls(batch_idx) - if const_expr(self.is_causal): # Longest tile first - m_block = cute.ceil_div(seqlen.seqlen_q * (self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1), self.m_block_size) - m_block - 1 - - mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) - mask_fn = partial( - mask.apply_mask, m_block=m_block, thr_mma=thr_mma_qk, - mask_causal=self.is_causal, mask_local=self.is_local, - ) - # Load Q if PackGQA - if const_expr(self.pack_gqa): - pack_gqa = PackGQA(self.m_block_size, self.head_dim_padded, self.check_hdim_oob, self.qhead_per_kvhead) - if const_expr(not seqlen.has_cu_seqlens_q): - mQ_cur = mQ[None, None, head_idx, batch_idx] - else: - mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, None, head_idx]) - # gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx) - # gQ = cute.local_tile(mQ_cur, (self.m_block_size, self.head_dim_padded), (m_block, 0)) - # self.load_Q(gmem_thr_copy_Q, gQ, sQ, m_block, seqlen=seqlen.seqlen_q, - # headdim=mQ.shape[1]) - pack_gqa.load_Q(mQ_cur, sQ, gmem_tiled_copy_Q, tidx, m_block, seqlen.seqlen_q) - utils.cp_async_mbarrier_arrive_shared(mbar_ptr_Q, noinc=True) - - n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) - consumer_state = pipeline.make_pipeline_state( + q_consumer_phase = cutlass.Int32(0) + kv_consumer_state = pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.num_stages ) - cute.arch.mbarrier_wait(mbar_ptr_Q, phase=0) - softmax.reset() - # For performance reason, we separate out two kinds of iterations: - # those that need masking on S, and those that don't. - # We need masking on S for the very last block when K and V has length not multiple of n_block_size. - # We also need masking on S if it's causal, for the last several blocks. - # First iteration with seqlen masking - if const_expr(self.intra_wg_overlap): - acc_S = cute.make_fragment( - tiled_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)), cutlass.Float32 - ) - pipeline_k.consumer_wait(consumer_state) - sm90_utils.gemm( - tiled_mma_qk, acc_S, tSrQ, tSrK[None, None, None, consumer_state.index], - zero_init=True, wg_wait=0 + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + # Softcapping needs to happen before masking since if we apply after masking, softcapping can turn + # -inf to e.g. -50.0, which can affect the attention softmax. + def scoremod_premask_fn(acc_S): + if const_expr(softcap_val is not None): + acc_S.store(cute.math.tanh(acc_S.load() * softcap_val, fastmath=True)) + + # shape: (atom_v_m * rest_m) + softmax = Softmax(softmax_scale_log2, num_rows=acc_O.shape[0][0] * acc_O.shape[1]) + mma_one_n_block = partial( + mma_one_n_block_all, softmax=softmax, scoremod_premask_fn=scoremod_premask_fn ) - pipeline_k.consumer_release(consumer_state) - scoremod_premask_fn(acc_S) - # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) - mask_fn(acc_S, n_block=n_block_max - 1, mask_seqlen=True) - # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) - softmax.online_softmax(acc_S, is_first=True, check_inf=True) - tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) - tOrP = mma_params.tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) - tOrP.store(tOrP_acc.load().to(self.dtype)) - if const_expr(not self.mma_pv_is_rs): - tPrP = smem_thr_copy_P.retile(tOrP) - cute.copy(smem_thr_copy_P, tPrP, tPsP) - # Fence and barrier to make sure smem store is visible to WGMMA - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) - cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV - # Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter - acc_O.fill(0.0) - else: - self.warp_scheduler_barrier_sync() - consumer_state = mma_one_n_block( - n_block_max - 1, consumer_state, - is_first_n_block=True, check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=True) + + m_block, head_idx, batch_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) + mask_fn = partial( + mask.apply_mask, m_block=m_block, thr_mma=thr_mma_qk, + mask_causal=self.is_causal, mask_local=self.is_local, ) - # if cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block_max = {}, n_block_min = {}", m_block, n_block_max, n_block_min) - n_block_max -= 1 - # Next couple of iterations with causal masking - if const_expr(self.is_causal or self.is_local): - n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( + softmax.reset() + # Load Q if PackGQA + if const_expr(self.pack_gqa): + pack_gqa = PackGQA(self.m_block_size, self.head_dim_padded, self.check_hdim_oob, self.qhead_per_kvhead) + if const_expr(not seqlen.has_cu_seqlens_q): + mQ_cur = mQ[None, None, head_idx, batch_idx] + else: + mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, None, head_idx]) + # gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx) + # gQ = cute.local_tile(mQ_cur, (self.m_block_size, self.head_dim_padded), (m_block, 0)) + # self.load_Q(gmem_thr_copy_Q, gQ, sQ, m_block, seqlen=seqlen.seqlen_q, + # headdim=mQ.shape[1]) + pack_gqa.load_Q(mQ_cur, sQ, gmem_tiled_copy_Q, tidx, m_block, seqlen.seqlen_q) + utils.cp_async_mbarrier_arrive_shared(mbar_ptr_Q, noinc=True) + + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + cute.arch.mbarrier_wait(mbar_ptr_Q, phase=q_consumer_phase) + q_consumer_phase ^= 1 + # For performance reason, we separate out two kinds of iterations: + # those that need masking on S, and those that don't. + # We need masking on S for the very last block when K and V has length not multiple of n_block_size. + # We also need masking on S if it's causal, for the last several blocks. + # First iteration with seqlen masking + if const_expr(self.intra_wg_overlap): + acc_S = cute.make_fragment( + tiled_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)), cutlass.Float32 + ) + pipeline_k.consumer_wait(kv_consumer_state) + sm90_utils.gemm( + tiled_mma_qk, acc_S, tSrQ, tSrK[None, None, None, kv_consumer_state.index], + zero_init=True, wg_wait=0 + ) + pipeline_k.consumer_release(kv_consumer_state) + scoremod_premask_fn(acc_S) + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) + mask_fn(acc_S, n_block=n_block_max - 1, mask_seqlen=True) + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) + softmax.online_softmax(acc_S, is_first=True) + tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) + tOrP = mma_params.tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) + tOrP.store(tOrP_acc.load().to(self.dtype)) + if const_expr(not self.mma_pv_is_rs): + tPrP = smem_thr_copy_P.retile(tOrP) + cute.copy(smem_thr_copy_P, tPrP, tPsP) + # Fence and barrier to make sure smem store is visible to WGMMA + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV + # Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter + acc_O.fill(0.0) + else: + self.warp_scheduler_barrier_sync() + kv_consumer_state = mma_one_n_block( + n_block_max - 1, kv_consumer_state, + is_first_n_block=True, mask_fn=partial(mask_fn, mask_seqlen=True) + ) + # if cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block_max = {}, n_block_min = {}", m_block, n_block_max, n_block_min) + n_block_max -= 1 + # Next couple of iterations with causal masking + if const_expr(self.is_causal or self.is_local): + n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( + seqlen, m_block, n_block_min + ) + # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_causal_local_mask = {}", n_block_min_causal_local_mask) + for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1): + n_block = n_block_max - 1 - n_tile + kv_consumer_state = mma_one_n_block( + n_block, kv_consumer_state, mask_fn=partial(mask_fn, mask_seqlen=False) + ) + n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) + # The remaining iterations have no masking + n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( seqlen, m_block, n_block_min ) - # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_causal_local_mask = {}", n_block_min_causal_local_mask) - for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1): - n_block = n_block_max - 1 - n_tile - consumer_state = mma_one_n_block( - n_block, consumer_state, - check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=False) - ) - n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) - # The remaining iterations have no masking - n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( - seqlen, m_block, n_block_min - ) - # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_before_local_mask = {}, n_block_min = {}", n_block_min_before_local_mask, n_block_min) - for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): - n_block = n_block_max - 1 - n_tile - consumer_state = mma_one_n_block(n_block, consumer_state, check_inf=True) - # Separate iterations with local masking on the left - if const_expr(self.is_local and block_info.window_size_left is not None): - n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) - for n_tile in cutlass.range(n_block_max - n_block_min, unroll=1): + # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_before_local_mask = {}, n_block_min = {}", n_block_min_before_local_mask, n_block_min) + for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): n_block = n_block_max - 1 - n_tile - consumer_state = mma_one_n_block( - n_block, consumer_state, - check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=False) + kv_consumer_state = mma_one_n_block(n_block, kv_consumer_state, check_inf=True) + # Separate iterations with local masking on the left + if const_expr(self.is_local and block_info.window_size_left is not None): + n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) + for n_tile in cutlass.range(n_block_max - n_block_min, unroll=1): + n_block = n_block_max - 1 - n_tile + kv_consumer_state = mma_one_n_block( + n_block, kv_consumer_state, + check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=False) + ) + # Last "half" iteration + if const_expr(self.intra_wg_overlap): + pipeline_v.consumer_wait(kv_consumer_state, pipeline_v.consumer_try_wait(kv_consumer_state)) + sm90_utils.gemm( + tiled_mma_pv, mma_params.acc_O, mma_params.tOrP, + mma_params.tOrVt[None, None, None, kv_consumer_state.index], + zero_init=False, wg_wait=-1 ) - # Last "half" iteration - if const_expr(self.intra_wg_overlap): - pipeline_v.consumer_wait(consumer_state, pipeline_v.consumer_try_wait(consumer_state)) - sm90_utils.gemm( - tiled_mma_pv, mma_params.acc_O, mma_params.tOrP, - mma_params.tOrVt[None, None, None, consumer_state.index], - zero_init=False, wg_wait=-1 + warpgroup.wait_group(0) + pipeline_v.consumer_release(kv_consumer_state) + kv_consumer_state.advance() + else: + self.warp_scheduler_barrier_arrive() + + # normalize acc_O by row_sum and calculate the lse + row_scale = softmax.finalize() + softmax.rescale_O(acc_O, row_scale) + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue + # /////////////////////////////////////////////////////////////////////////////// + self.epilogue( + acc_O, softmax.row_sum, mO, mLSE, sO, seqlen, + gmem_tiled_copy_O, tma_atom_O, tiled_mma_pv, tidx, m_block, head_idx, batch_idx, ) - warpgroup.wait_group(0) - pipeline_v.consumer_release(consumer_state) - consumer_state.advance() - else: - self.warp_scheduler_barrier_arrive() - # normalize acc_O by row_sum and calculate the lse - row_scale = softmax.finalize() - softmax.rescale_O(acc_O, row_scale) + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() @cute.jit def mma_one_n_block( From 312bb9b35ecbac27ae11bcac38bfaec68dd3aba3 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 5 Jul 2025 16:34:49 -0400 Subject: [PATCH 023/258] [Cute] Add SingleTileVarlenScheduler to fwd_sm90 --- flash_attn/cute/flash_fwd.py | 36 +++-- flash_attn/cute/flash_fwd_sm100.py | 63 ++++----- flash_attn/cute/interface.py | 2 +- flash_attn/cute/tile_scheduler.py | 219 ++++++++++++++++++++++++++++- flash_attn/cute/utils.py | 15 ++ tests/cute/test_flash_attn.py | 8 +- 6 files changed, 291 insertions(+), 52 deletions(-) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index f6504df7038..dbac72f4918 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -29,7 +29,7 @@ from flash_attn.cute import pipeline from flash_attn.cute.pack_gqa import PackGQA from flash_attn.cute.named_barrier import NamedBarrierFwd -from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, SingleTileLPTScheduler, ParamsBase +from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, SingleTileLPTScheduler, SingleTileVarlenScheduler, ParamsBase class FlashAttentionForwardBase: @@ -303,7 +303,8 @@ def epilogue( if const_expr(not seqlen.has_cu_seqlens_q): mLSE_cur = mLSE[None, head_idx, batch_idx] else: - mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[None, head_idx]) + offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) + mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx]) if const_expr(not self.pack_gqa): gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block,)) gLSE_expanded_layout = cute.append( @@ -326,7 +327,8 @@ def epilogue( if const_expr(not seqlen.has_cu_seqlens_q): mO_cur = mO[None, None, head_idx, batch_idx] else: - mO_cur = cute.domain_offset((seqlen.offset_q, 0), mO[None, None, head_idx]) + offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) + mO_cur = cute.domain_offset((offset, 0), mO[None, None, head_idx]) # thr_mma = tiled_mma.get_slice(tidx) # taccOgO = thr_mma.partition_C(gO) # cute.autovec_copy(rO, taccOgO) @@ -1146,19 +1148,26 @@ def __call__( stride_LSE_packed = ((mLSE.stride[1], mLSE.stride[0]), mLSE.stride[1] * self.qhead_per_kvhead, *mLSE.stride[2:]) mLSE = cute.make_tensor(mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed)) - TileScheduler = SingleTileScheduler if const_expr(not self.is_causal or self.is_local) else SingleTileLPTScheduler + if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): + TileScheduler = SingleTileVarlenScheduler + else: + TileScheduler = SingleTileScheduler if const_expr(not self.is_causal or self.is_local) else SingleTileLPTScheduler tile_sched_args = TileSchedulerArguments( cute.ceil_div(cute.size(mQ.shape[0]), self.m_block_size), cute.size(mQ.shape[2]), - cute.size(mQ.shape[3]), + cute.size(mQ.shape[3]) if const_expr(mCuSeqlensQ is None) else cute.size(mCuSeqlensQ.shape[0] - 1), cute.size(mK.shape[0]), mQ.shape[1], mV.shape[1], - self.dtype.width // 8, + total_q=cute.size(mQ.shape[0]) if const_expr(mCuSeqlensQ is not None) else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), + block_size=self.m_block_size, + mCuSeqlensQ=mCuSeqlensQ, + mSeqUsedQ=mSeqUsedQ, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + element_size=self.dtype.width // 8, is_persistent=False, ) tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) - # TODO: deal with PackGQA and varlen grid_dim = TileScheduler.get_grid_shape(tile_sched_params) # grid_dim = ( # cute.ceil_div(cute.size(mQ.shape[0]) if const_expr(mCuSeqlensQ is None) else max_seqlen_q, self.m_block_size), @@ -1422,12 +1431,14 @@ def load( tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: + # if work_tile.is_valid_tile: m_block, head_idx, batch_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) if const_expr(not seqlen.has_cu_seqlens_q): mQ_cur = mQ[None, None, head_idx, batch_idx] else: - mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, None, head_idx]) + offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) + mQ_cur = cute.domain_offset((offset, 0), mQ[None, None, head_idx]) head_idx_kv = head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx if const_expr(not seqlen.has_cu_seqlens_k): mK_cur, mV_cur = [t[None, None, head_idx_kv, batch_idx] for t in (mK, mV)] @@ -1522,8 +1533,9 @@ def mma( tSrK = tiled_mma_qk.make_fragment_B(wg_mma_qk.partition_B(sK)) if const_expr(self.mma_pv_is_rs): acc_S_shape = tiled_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)) - acc_S_layout = cute.make_layout(acc_S_shape) - tOrP = cute.make_fragment(utils.convert_layout_acc_frgA(acc_S_layout), self.dtype) + tOrP = cute.make_fragment( + utils.convert_layout_acc_frgA(cute.make_layout(acc_S_shape)), self.dtype + ) else: tOrP = tiled_mma_pv.make_fragment_A(wg_mma_pv.partition_A(sP)) tOrVt = tiled_mma_pv.make_fragment_B(wg_mma_pv.partition_B(sVt)) @@ -1564,6 +1576,7 @@ def mma( tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: + # if work_tile.is_valid_tile: # Softcapping needs to happen before masking since if we apply after masking, softcapping can turn # -inf to e.g. -50.0, which can affect the attention softmax. def scoremod_premask_fn(acc_S): @@ -1590,7 +1603,8 @@ def scoremod_premask_fn(acc_S): if const_expr(not seqlen.has_cu_seqlens_q): mQ_cur = mQ[None, None, head_idx, batch_idx] else: - mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, None, head_idx]) + offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) + mQ_cur = cute.domain_offset((offset, 0), mQ[None, None, head_idx]) # gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx) # gQ = cute.local_tile(mQ_cur, (self.m_block_size, self.head_dim_padded), (m_block, 0)) # self.load_Q(gmem_thr_copy_Q, gQ, sQ, m_block, seqlen=seqlen.seqlen_q, diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 80a5751dc39..9de5f2c4fe6 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -35,7 +35,7 @@ from flash_attn.cute import mma_sm100_desc as sm100_desc from flash_attn.cute import blackwell_helpers as sm100_utils from flash_attn.cute.fast_math import FastDivmod -from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, StaticPersistentTileScheduler, SingleTileLPTScheduler, ParamsBase +from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, StaticPersistentTileScheduler, SingleTileLPTScheduler, SingleTileVarlenScheduler, ParamsBase # class NamedBarrierFwd(enum.IntEnum): @@ -47,15 +47,6 @@ # PEmpty = enum.auto() -def get_tile_scheduler_cls(args: TileSchedulerArguments) -> Callable: - """Returns the appropriate tile scheduler class based on the parameters.""" - if const_expr(args.is_persistent): - return StaticPersistentTileScheduler - else: - # return SingleTileScheduler - return SingleTileLPTScheduler - - class FlashAttentionForwardSm100: arch = 100 @@ -353,7 +344,31 @@ def __call__( self.tma_copy_q_bytes = cute.size_in_bytes(self.q_dtype, cute.select(sQ_layout, mode=[0, 1, 2])) self.tma_copy_kv_bytes = cute.size_in_bytes(self.k_dtype, cute.select(sK_layout, mode=[0, 1, 2])) - self.tile_scheduler_cls, self.tile_sched_params, grid = self._compute_grid(mO, self.cta_tiler, self.is_persistent) + if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): + TileScheduler = SingleTileVarlenScheduler + else: + if const_expr(self.is_causal or self.is_local): + TileScheduler = SingleTileLPTScheduler + else: + TileScheduler = SingleTileScheduler if const_expr(not self.is_persistent) else StaticPersistentTileScheduler + tile_sched_args = TileSchedulerArguments( + cute.ceil_div(cute.size(mQ.shape[0]), self.cta_tiler[0]), + cute.size(mQ.shape[2]), + cute.size(mQ.shape[3]) if const_expr(mCuSeqlensQ is None) else cute.size(mCuSeqlensQ.shape[0] - 1), + cute.size(mK.shape[0]), + mQ.shape[1], + mV.shape[0], # Note that this is different from Sm90 since we transpose mV in Sm100 + total_q=cute.size(mQ.shape[0]) if const_expr(mCuSeqlensQ is not None) else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), + block_size=self.cta_tiler[0], + mCuSeqlensQ=mCuSeqlensQ, + mSeqUsedQ=mSeqUsedQ, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + element_size=self.k_dtype.width // 8, + is_persistent=self.is_persistent, + ) + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + self.tile_scheduler_cls = TileScheduler + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) self.mbar_load_q_full_offset = 0 self.mbar_load_q_empty_offset = self.mbar_load_q_full_offset + self.q_stage @@ -437,9 +452,9 @@ class SharedStorage: gmem_tiled_copy_O, tiled_mma_qk, tiled_mma_pv, - self.tile_sched_params, + tile_sched_params, ).launch( - grid=grid, + grid=grid_dim, block=[self.threads_per_cta, 1, 1], cluster=self.cluster_shape_mnk, smem=self.shared_storage.size_in_bytes(), @@ -1754,25 +1769,3 @@ def make_and_init_load_kv_pipeline(self, load_kv_mbar_ptr): # cute.arch.barrier_arrive( # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg, number_of_threads=2 * 128, # ) - - @staticmethod - def _compute_grid( - mO: cute.Tensor, - cta_tiler: Tuple[int, int, int], - is_persistent: bool, - ) -> Tuple[TileSchedulerArguments, Tuple[int, int, int]]: - o_shape = mO.shape - tile_sched_args = TileSchedulerArguments( - cute.ceil_div(cute.size(o_shape[0]), cta_tiler[0]), - cute.size(o_shape[2]), - cute.size(o_shape[3]), - cute.size(o_shape[0]), # TODO - o_shape[1], - o_shape[1], - 2, # TODO - is_persistent, - ) - tile_scheduler_cls = get_tile_scheduler_cls(tile_sched_args) - tile_sched_params = tile_scheduler_cls.to_underlying_arguments(tile_sched_args) - grid = tile_scheduler_cls.get_grid_shape(tile_sched_params) - return tile_scheduler_cls, tile_sched_params, grid diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index e2f03832912..f07af019964 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -135,7 +135,7 @@ def _flash_attn_fwd( # if compute_capability == 9: # TODO: tune block size according to hdim # if not causal and not local: - # n_block_size = 128 + # n_block_size = 176 compile_key = ( dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, softcap is not None, diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index d5cb1c10313..e0bf202f022 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -7,6 +7,7 @@ import cutlass.cute as cute from cutlass import Int32 +import flash_attn.cute.utils as utils from flash_attn.cute.fast_math import FastDivmod, clz @@ -42,6 +43,11 @@ class TileSchedulerArguments(ParamsBase): seqlen_k: Int32 headdim: Int32 headdim_v: Int32 + total_q: Int32 + block_size: cutlass.Constexpr[int] + mCuSeqlensQ: Optional[cute.Tensor] = None + mSeqUsedQ: Optional[cute.Tensor] = None + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 element_size: cutlass.Constexpr[int] = 2 is_persistent: cutlass.Constexpr[bool] = False @@ -228,15 +234,18 @@ class Params(ParamsBase): def create( args: TileSchedulerArguments, *, loc=None, ip=None ) -> "SingleTileLPTScheduler.Params": + # cute.printf(args.num_block, args.num_head, args.num_batch, args.seqlen_k, args.headdim, args.headdim_v, args.total_q, args.block_size, args.qhead_per_kvhead_packgqa, args.element_size) size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size size_one_head = size_one_kv_head size_l2 = 50 * 1024 * 1024 # 40 MB for K & V # Swizzle is the size of each "section". Round swizzle to a power of 2 # Need to be careful about the case where only one head will fit - log2_floor = lambda n: 31 - clz(n) # swizzle is how many heads can fit in L2 + # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head) # Seems faster if swizzle if a power of 2 + log2_floor = lambda n: 31 - clz(n) swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head)) + # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head) # If we're in the last section (called residual), we don't want to divide by # swizzle. Instead we want to divide by the remainder. num_hb_quotient = (args.num_head * args.num_batch) // swizzle @@ -283,6 +292,7 @@ def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) return SingleTileLPTScheduler.Params.create(args, loc=loc, ip=ip) @staticmethod + @cute.jit def create(params: Params, *, loc=None, ip=None) -> "SingleTileLPTScheduler": tile_idx = cute.arch.block_idx()[0] return SingleTileLPTScheduler( @@ -373,3 +383,210 @@ def __new_from_mlir_values__(self, values): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] return SingleTileLPTScheduler(*(tuple(obj_list)), loc=self._loc) + + +class SingleTileVarlenScheduler: + @dataclass + class Params(ParamsBase): + num_head: Int32 + num_batch: Int32 + total_q: Int32 + block_size: cutlass.Constexpr[int] + mCuSeqlensQ: Optional[cute.Tensor] = None + mSeqUsedQ: Optional[cute.Tensor] = None + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + + @staticmethod + @cute.jit + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "SingleTileVarlenScheduler.Params": + return SingleTileVarlenScheduler.Params( + num_head=args.num_head, + num_batch=args.num_batch, + total_q=args.total_q, + block_size=args.block_size, + mCuSeqlensQ=args.mCuSeqlensQ, + mSeqUsedQ=args.mSeqUsedQ, + qhead_per_kvhead_packgqa=args.qhead_per_kvhead_packgqa, + ) + + def __init__( + self, + num_head: Int32, + num_batch: Int32, + tile_idx: Int32, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + block_size: cutlass.Constexpr[int] = 128, + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1, + *, + loc=None, + ip=None, + ): + self.num_head = num_head + self.num_batch = num_batch + self.mCuSeqlensQ = mCuSeqlensQ + self.mSeqUsedQ = mSeqUsedQ + assert self.mCuSeqlensQ is not None or self.mSeqUsedQ is not None, ( + "At least one of mCuSeqlensQ or mSeqUsedQ must be provided" + ) + self.block_size = block_size + self.qhead_per_kvhead_packgqa = qhead_per_kvhead_packgqa + self._tile_idx = tile_idx + self._is_first_block = True + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + return SingleTileVarlenScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + def create(params: Params, *, loc=None, ip=None) -> "SingleTileVarlenScheduler": + tile_idx = cute.arch.block_idx()[0] + return SingleTileVarlenScheduler( + params.num_head, + params.num_batch, + tile_idx, + mCuSeqlensQ=params.mCuSeqlensQ, + mSeqUsedQ=params.mSeqUsedQ, + block_size=params.block_size, + qhead_per_kvhead_packgqa=params.qhead_per_kvhead_packgqa, + loc=loc, + ip=ip, + ) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + total_blocks_max = ( + params.total_q + params.num_batch * (params.block_size - 1) + ) // params.block_size + return (total_blocks_max * params.num_head, Int32(1), Int32(1)) + + @cute.jit + def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32: + batch_idx = lane + bidb_start + if cutlass.const_expr(self.mSeqUsedQ is not None): + seqlen = Int32(0) + if batch_idx < self.num_batch: + seqlen = self.mSeqUsedQ[batch_idx] + else: + assert self.mCuSeqlensQ is not None + cur_cu_seqlen = Int32(0) + if batch_idx < self.num_batch: + cur_cu_seqlen = self.mCuSeqlensQ[batch_idx] + # Very important that we set mask_and_clamp to 0 + next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1, mask_and_clamp=0) + seqlen = next_cu_seqlen - cur_cu_seqlen + if cutlass.const_expr(self.qhead_per_kvhead_packgqa > 1): + seqlen *= self.qhead_per_kvhead_packgqa + return ( + cute.ceil_div(seqlen, self.block_size) + if batch_idx < self.num_batch and lane < cute.arch.WARP_SIZE - 1 + else Int32(0) + ) + + @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + lane_idx = cute.arch.lane_idx() + num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=0) + num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx) + # Total number of blocks for the next 31 batches + m_blocks_in_group = cute.arch.shuffle_sync(num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1) + # Same for all lanes + group_end_tile = m_blocks_in_group * self.num_head + # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, num_m_blocks_cumulative = %d, m_blocks_in_group = %d", self._tile_idx, group_end_tile, num_m_blocks, num_m_blocks_cumulative, m_blocks_in_group) + block, head_idx, batch_idx = Int32(0), Int32(0), Int32(0) + next_tile_idx = self._tile_idx + while group_end_tile <= next_tile_idx: + batch_idx += cute.arch.WARP_SIZE - 1 + if batch_idx >= self.num_batch: + batch_idx = Int32(self.num_batch) + group_end_tile = next_tile_idx + 1 + else: + num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=batch_idx) + num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx) + m_blocks_in_group = cute.arch.shuffle_sync( + num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1 + ) + group_end_tile += m_blocks_in_group * self.num_head + is_valid = False + if batch_idx >= self.num_batch: + block, head_idx, batch_idx = Int32(0), Int32(0), Int32(self.num_batch) + else: + group_start_tile = group_end_tile - m_blocks_in_group * self.num_head + # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, batch_idx = %d", self._tile_idx, group_end_tile, num_m_blocks, batch_idx) + # The next problem to process is the first one that does not have ending tile position + # that is greater than or equal to tile index. + batch_idx_in_group = cute.arch.popc( + cute.arch.vote_ballot_sync( + group_start_tile + num_m_blocks_cumulative * self.num_head <= next_tile_idx + ) + ) + batch_idx += batch_idx_in_group + num_m_blocks_prev_lane = ( + 0 + if batch_idx_in_group == 0 + else cute.arch.shuffle_sync(num_m_blocks_cumulative, batch_idx_in_group - 1) + ) + num_m_blocks = cute.arch.shuffle_sync(num_m_blocks, batch_idx_in_group) + mh_block = next_tile_idx - group_start_tile - num_m_blocks_prev_lane * self.num_head + head_idx = mh_block // num_m_blocks + block = mh_block - head_idx * num_m_blocks + is_valid = self._is_first_block and batch_idx < self.num_batch + # if cute.arch.thread_idx()[0] == 128: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, head_idx=%d, block=%d, is_valid = %d", self._tile_idx, batch_idx, head_idx, block, is_valid) + return cutlass.utils.WorkTileInfo( + (Int32(block), Int32(head_idx), Int32(batch_idx)), is_valid + ) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + # Single tile scheduler - set to invalid tile_idx to indicate no more work + self._is_first_block = False + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [ + self.num_head, + self.num_batch, + self._tile_idx, + self.mCuSeqlensQ, + self.mSeqUsedQ, + ]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [ + self.num_head, + self.num_batch, + self._tile_idx, + self.mCuSeqlensQ, + self.mSeqUsedQ, + ], + self._values_pos, + ): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return SingleTileVarlenScheduler( + *(tuple(obj_list)), + block_size=self.block_size, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead_packgqa, + loc=self._loc, + ) diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index eb82940cdee..e12dcac2584 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -425,3 +425,18 @@ def noop_asm(val: cutlass.Int32, *, loc=None, ip=None) -> cute.Numeric: asm_dialect=llvm.AsmDialect.AD_ATT, ) ) + + +@cute.jit +def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32: + if cutlass.const_expr(lane is None): + lane = cute.arch.lane_idx() + # if cute.arch.thread_idx()[0] >= 128 and cute.arch.thread_idx()[0] < 128 + 32 and cute.arch.block_idx()[0] == 0: cute.printf("tidx = %d, val = %d", cute.arch.thread_idx()[0] % 32, val) + for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))): + offset = 1 << i + # Very important that we set mask_and_clamp to 0 + partial_sum = cute.arch.shuffle_sync_up(val, offset=offset, mask_and_clamp=0) + if lane >= offset: + val += partial_sum + # if cute.arch.thread_idx()[0] >= 128 and cute.arch.thread_idx()[0] < 128 + 32 and cute.arch.block_idx()[0] == 0: cute.printf("tidx = %d, partial_sum = %d, val = %d", cute.arch.thread_idx()[0] % 32, partial_sum, val) + return val diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 268744f67fd..16a1c3fa65c 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -230,8 +230,8 @@ def test_flash_attn_output( # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) -# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) -@pytest.mark.parametrize("mha_type", ["mha"]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) # @pytest.mark.parametrize("has_qv", [False, True]) @pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @@ -241,7 +241,7 @@ def test_flash_attn_output( # @pytest.mark.parametrize("local", [False, True]) @pytest.mark.parametrize("local", [False]) # @pytest.mark.parametrize("causal", [False, True]) -@pytest.mark.parametrize("causal", [True]) +@pytest.mark.parametrize("causal", [False]) # @pytest.mark.parametrize("add_unused_qkv", [False, True]) @pytest.mark.parametrize("add_unused_qkv", [False]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) @@ -282,7 +282,7 @@ def test_flash_attn_varlen_output( device = "cuda" # set seed torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) - batch_size = 9 if seqlen_q <= 2048 else 2 + batch_size = 49 if seqlen_q <= 2048 else 2 nheads = 6 # batch_size = 1 # nheads = 1 From 10e8c39fdaaf5c422dbd3f13c662f5d93830029e Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 5 Jul 2025 23:30:04 -0400 Subject: [PATCH 024/258] [Cute] Do manual f32->f16x2 conversion for fwd_sm90 --- flash_attn/cute/blackwell_helpers.py | 15 +++---- flash_attn/cute/flash_fwd.py | 12 ++++-- flash_attn/cute/interface.py | 7 ++-- flash_attn/cute/utils.py | 58 ++++++++++++++++++++++++++-- 4 files changed, 73 insertions(+), 19 deletions(-) diff --git a/flash_attn/cute/blackwell_helpers.py b/flash_attn/cute/blackwell_helpers.py index ca9c4b77a88..176b083c4f5 100644 --- a/flash_attn/cute/blackwell_helpers.py +++ b/flash_attn/cute/blackwell_helpers.py @@ -308,15 +308,10 @@ def gemm_ptx_partial( smem_desc_base_b_lo = cutlass.const_expr(smem_desc_base_b_lo) smem_desc_b_hi = cutlass.const_expr(smem_desc_b_hi) - if cutlass.const_expr(not is_ts): - offset_a = [(cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 8) >> 4 - for k in range(cute.size(tCrA.shape[2]))] - else: - offset_a = [cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32 - for k in range(cute.size(tCrA.shape[2]))] + tCrA_layout = tCrA.layout if cutlass.const_expr(not is_ts) else cute.recast_layout(32, tCrA.element_type.width, tCrA.layout) + offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(cute.size(tCrA.shape[2]))] offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))] - offset_b = [(cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4 - for k in range(cute.size(tCrB.shape[2]))] + offset_b = [cute.crd2idx((0, 0, k), tCrB.layout) for k in range(cute.size(tCrB.shape[2]))] offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))] if cutlass.const_expr(not is_ts): @@ -330,8 +325,8 @@ def gemm_ptx_partial( None, [ # acc.iterator.toint().ir_value(), - cutlass.Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), - cutlass.Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + cutlass.Int32(smem_desc_start_a_lo).ir_value(), + cutlass.Int32(smem_desc_start_b_lo).ir_value(), cutlass.Int32(not zero_init).ir_value(), ], "{\n\t" diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index dbac72f4918..11755a06bcc 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -1637,7 +1637,11 @@ def scoremod_premask_fn(acc_S): softmax.online_softmax(acc_S, is_first=True) tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) tOrP = mma_params.tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) - tOrP.store(tOrP_acc.load().to(self.dtype)) + # tOrP.store(tOrP_acc.load().to(self.dtype)) + # the "to(self.dtype)" conversion fails to vectorize for block sizes other + # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of + # 2 elements. So we just call ptx directly. + utils.cvt_f16(tOrP_acc, tOrP) if const_expr(not self.mma_pv_is_rs): tPrP = smem_thr_copy_P.retile(tOrP) cute.copy(smem_thr_copy_P, tPrP, tPsP) @@ -1749,7 +1753,8 @@ def mma_one_n_block( # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) tOrP = mma_params.tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) - tOrP.store(tOrP_acc.load().to(self.dtype)) + # tOrP.store(tOrP_acc.load().to(self.dtype)) + utils.cvt_f16(tOrP_acc, tOrP) if const_expr(not self.mma_pv_is_rs): tPrP = smem_copy_params.smem_thr_copy_P.retile(mma_params.tOrP) cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) @@ -1817,7 +1822,8 @@ def mma_one_n_block_intrawg_overlap( pipeline_v.consumer_release(smem_pipe_read_v) tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) tOrP = mma_params.tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) - tOrP.store(tOrP_acc.load().to(self.dtype)) + # tOrP.store(tOrP_acc.load().to(self.dtype)) + utils.cvt_f16(tOrP_acc, tOrP) if const_expr(not self.mma_pv_is_rs): tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP) cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index f07af019964..6d370bc0078 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -133,9 +133,9 @@ def _flash_attn_fwd( assert compute_capability in [9, 10], "Unsupported compute capability. Supported: 9.x, 10.x" current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - # if compute_capability == 9: # TODO: tune block size according to hdim - # if not causal and not local: - # n_block_size = 176 + if compute_capability == 9: # TODO: tune block size according to hdim + if not causal and not local: + n_block_size = 192 compile_key = ( dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, softcap is not None, @@ -154,6 +154,7 @@ def _flash_attn_fwd( qhead_per_kvhead, is_causal=causal, is_local=local, + pack_gqa=False, m_block_size=m_block_size, n_block_size=n_block_size, # num_stages=1, diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index e12dcac2584..b6c9711aedf 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -257,9 +257,21 @@ def fmax_reduce( x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80 ) -> Float32: if cutlass.const_expr(arch < 100 or cute.size(x.shape) % 8 != 0): - if cutlass.const_expr(init_val is None): - init_val = -cutlass.Float32.inf - return x.reduce(cute.ReductionOp.MAX, init_val, 0) + # if cutlass.const_expr(init_val is None): + # init_val = -cutlass.Float32.if + # return x.reduce(cute.ReductionOp.MAX, init_val, 0) + res = cute.make_fragment(x.shape, Float32) + res.store(x) + local_max = [res[0], res[1], res[2], res[3]] + for i in cutlass.range_constexpr(4, cute.size(x.shape), 4): + local_max[0] = fmax(local_max[0], res[i + 0]) + local_max[1] = fmax(local_max[1], res[i + 1]) + local_max[2] = fmax(local_max[2], res[i + 2]) + local_max[3] = fmax(local_max[3], res[i + 3]) + local_max[0] = fmax(local_max[0], local_max[1]) + local_max[2] = fmax(local_max[2], local_max[3]) + local_max[0] = fmax(local_max[0], local_max[2]) + return local_max[0] if cutlass.const_expr(init_val is None) else fmax(local_max[0], init_val) else: # [2025-06-15] x.reduce only seems to use 50% 3-input max and 50% 2-input max # We instead force the 3-input max. @@ -290,6 +302,18 @@ def fadd_reduce( if cutlass.const_expr(init_val is None): init_val = Float32.zero return x.reduce(cute.ReductionOp.ADD, init_val, 0) + # res = cute.make_fragment(x.shape, Float32) + # res.store(x) + # local_sum = [res[0], res[1], res[2], res[3]] + # for i in cutlass.range_constexpr(4, cute.size(x.shape), 4): + # local_sum[0] += res[i + 0] + # local_sum[1] += res[i + 1] + # local_sum[2] += res[i + 2] + # local_sum[3] += res[i + 3] + # local_sum[0] += local_sum[1] + # local_sum[2] += local_sum[3] + # local_sum[0] += local_sum[2] + # return local_sum[0] if cutlass.const_expr(init_val is None) else local_sum[0] + init_val else: res = cute.make_fragment(x.shape, Float32) res.store(x) @@ -440,3 +464,31 @@ def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> val += partial_sum # if cute.arch.thread_idx()[0] >= 128 and cute.arch.thread_idx()[0] < 128 + 32 and cute.arch.block_idx()[0] == 0: cute.printf("tidx = %d, partial_sum = %d, val = %d", cute.arch.thread_idx()[0] % 32, partial_sum, val) return val + + +@dsl_user_op +def cvt_f16x2_f32(a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None) -> cutlass.Int32: + assert to_dtype in [cutlass.BFloat16, cutlass.Float16], "to_dtype must be BFloat16 or Float16" + return cutlass.Int32( + llvm.inline_asm( + T.i32(), + [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip)], + f"cvt.rn.{'bf16x2' if to_dtype is cutlass.BFloat16 else 'f16x2'}.f32 $0, $2, $1;", + "=r,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@cute.jit +def cvt_f16(src: cute.Tensor, dst: cute.Tensor): + assert cute.size(dst.shape) == cute.size(src.shape), "dst and src must have the same size" + assert cute.size(src.shape) % 2 == 0, "src must have an even number of elements" + assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], "dst must be BFloat16 or Float16" + assert src.element_type is Float32, "src must be Float32" + dst_i32 = cute.recast_tensor(dst, cutlass.Int32) + assert cute.size(dst_i32.shape) * 2 == cute.size(src.shape) + for i in cutlass.range_constexpr(cute.size(dst_i32)): + dst_i32[i] = cvt_f16x2_f32(src[2 * i], src[2 * i + 1], dst.element_type) From 3fc8c3ce281db3dc64a4f690295efaf14a68a510 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 6 Jul 2025 17:43:47 -0400 Subject: [PATCH 025/258] [Cute] Split tP arrival for fwd_sm100 --- flash_attn/cute/blackwell_helpers.py | 49 +++++++++++++++++++++++----- flash_attn/cute/flash_fwd_sm100.py | 31 +++++++++++------- flash_attn/cute/softmax.py | 17 +++++----- flash_attn/cute/utils.py | 7 ++++ 4 files changed, 76 insertions(+), 28 deletions(-) diff --git a/flash_attn/cute/blackwell_helpers.py b/flash_attn/cute/blackwell_helpers.py index 176b083c4f5..6b963e6069d 100644 --- a/flash_attn/cute/blackwell_helpers.py +++ b/flash_attn/cute/blackwell_helpers.py @@ -278,6 +278,8 @@ def gemm_ptx_partial( sB: cute.Tensor, sA_swizzle: Optional[cute.Swizzle], sB_swizzle: cute.Swizzle, + mbar_ptr: Optional[cutlass.Pointer] = None, + mbar_phase: Optional[cutlass.Int32] = None, zero_init: bool | cutlass.Boolean = False, ) -> None: is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM @@ -321,6 +323,7 @@ def gemm_ptx_partial( smem_desc_start_b_lo = cutlass.Int32(smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)) pred_str = "p" if isinstance(zero_init, cutlass.Boolean) else "0" if zero_init else "1" if cutlass.const_expr(not is_ts): + assert mbar_ptr is None, "mbar_ptr must be None when a_src is not TMEM" llvm.inline_asm( None, [ @@ -365,14 +368,34 @@ def gemm_ptx_partial( asm_dialect=llvm.AsmDialect.AD_ATT, ) else: + input_args = [ + cutlass.Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + cutlass.Int32(smem_desc_start_b_lo).ir_value(), + cutlass.Int32(not zero_init).ir_value(), + ] + if cutlass.const_expr(mbar_ptr is not None): + assert mbar_phase is not None, "mbar_phase must be provided when mbar_ptr is not None" + input_args.append(mbar_ptr.toint().ir_value()) + input_args.append(cutlass.Int32(mbar_phase).ir_value()) + mbar_wait_str = ( + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [$3], $4, 10000000; \n\t" + "@P1 bra DONE; \n\t" + "bra LAB_WAIT; \n\t" + "DONE: \n\t" + ) + else: + mbar_wait_str = "" llvm.inline_asm( None, - [ - # acc.iterator.toint().ir_value(), - cutlass.Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), - cutlass.Int32(smem_desc_start_b_lo).ir_value(), - cutlass.Int32(not zero_init).ir_value(), - ], + # [ + # # acc.iterator.toint().ir_value(), + # cutlass.Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + # cutlass.Int32(smem_desc_start_b_lo).ir_value(), + # cutlass.Int32(not zero_init).ir_value(), + # ], + input_args, "{\n\t" ".reg .pred leader_thread;\n\t" ".reg .pred p;\n\t" @@ -399,10 +422,20 @@ def gemm_ptx_partial( # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t" f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" ) - for k in range(1, cute.size(tCrA.shape[2])) + for k in range(1, cute.size(tCrA.shape[2]) if cutlass.const_expr(mbar_ptr is None) else cute.size(tCrA.shape[2]) // 2) ) + + mbar_wait_str + + ("".join( + ( + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in range(cute.size(tCrA.shape[2]) // 2, cute.size(tCrA.shape[2])) + ) if cutlass.const_expr(mbar_ptr is not None) else "") + "}\n", - "r,r,r", + # "r,r,r", + "r,r,r" if cutlass.const_expr(mbar_ptr is None) else "r,r,r,r,r", has_side_effects=True, is_align_stack=False, asm_dialect=llvm.AsmDialect.AD_ATT, diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 9de5f2c4fe6..9997a80a2ca 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -83,7 +83,6 @@ def __init__( self.pv_acc_dtype = cutlass.Float32 self.cluster_shape_mn = (1, 1) self.is_persistent = is_persistent - self.is_even_N = False self.is_causal = is_causal self.is_local = is_local self.qhead_per_kvhead = qhead_per_kvhead @@ -384,7 +383,9 @@ def __call__( self.mbar_s0_s1_sequence_offset = self.mbar_corr_epi_empty_offset + 2 self.mbar_max_reg_setting_offset = self.mbar_s0_s1_sequence_offset + 8 self.mbar_tmem_dealloc_offset = self.mbar_max_reg_setting_offset + 1 - self.mbar_total = self.mbar_tmem_dealloc_offset + 1 + # self.mbar_total = self.mbar_tmem_dealloc_offset + 1 + self.mbar_P_full_2_offset = self.mbar_tmem_dealloc_offset + 1 + self.mbar_total = self.mbar_P_full_2_offset + 2 @cute.struct class SharedStorage: @@ -546,6 +547,9 @@ def kernel( cute.arch.mbarrier_init(mbar_ptr + self.mbar_P_full_O_rescaled_offset + i, cute.arch.WARP_SIZE * (len(self.softmax0_warp_ids) + len(self.correction_warp_ids))) cute.arch.mbarrier_init(mbar_ptr + self.mbar_S_full_offset + i, len([self.mma_warp_id])) cute.arch.mbarrier_init(mbar_ptr + self.mbar_O_full_offset + i, len([self.mma_warp_id])) + if warp_idx == 8: + for i in cutlass.range_constexpr(2): + cute.arch.mbarrier_init(mbar_ptr + self.mbar_P_full_2_offset + i, cute.arch.WARP_SIZE * len(self.softmax0_warp_ids)) if warp_idx == 6: cute.arch.mbarrier_init( mbar_ptr + self.mbar_max_reg_setting_offset, @@ -1003,7 +1007,8 @@ def mma( cute.arch.mbarrier_wait(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, P_full_O_rescaled_phase) # 3. gemm # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True) - gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate) + # gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate) + gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate, mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, mbar_phase= P_full_O_rescaled_phase) # 4. release accumulated O0_partial / O1_partial # Don't need to signal O_full to the correction warps anymore since the # correction warps wait for the softmax warps anyway. By the time the softmax @@ -1055,7 +1060,8 @@ def mma( cute.arch.mbarrier_wait(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, P_full_O_rescaled_phase) # 3. gemm # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True) - gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate) + # gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate) + gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate, mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, mbar_phase=P_full_O_rescaled_phase) # 4. release accumulated O0_partial # We do need O_full here since for the last tile, by the time the softmax warp # has signaled to the correction warp, the softmax warp has just finished compute @@ -1136,7 +1142,7 @@ def softmax_loop( tStScale_r2t = thr_tmem_store_scale.partition_D(tStScale) tSrScale_r2t_shape = thr_tmem_store_scale.partition_S(tScS_vec).shape tmem_store_atom = cute.make_copy_atom( - tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), cutlass.Float32, + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), cutlass.Float32, ) tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP) thr_tmem_store = tiled_tmem_store.get_slice(tidx) @@ -1183,16 +1189,13 @@ def softmax_loop( si_corr_producer_phase ^= 1 # 1 masking iter - if const_expr(not self.is_even_N): - # mask_trip_count = 1 if seqlen.seqlen_k % self.mma_tiler_qk[1] == 0 else 0 - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block_max - 1, is_first=False, mask_fn=partial(mask_fn, mask_seqlen=True)) - n_block_max -= 1 + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block_max - 1, is_first=True, mask_fn=partial(mask_fn, mask_seqlen=True)) + n_block_max -= 1 # Next couple of iterations with causal masking if const_expr(self.is_causal or self.is_local): n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( seqlen, m_block, n_block_min ) - # Currently we can't do loop with negative step https://github.com/NVIDIA/cutlass/issues/2326 for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1): n_block = n_block_max - 1 - n_tile mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=partial(mask_fn, mask_seqlen=False)) @@ -1329,10 +1332,16 @@ def softmax_step( if const_expr(self.s0_s1_barrier): cute.arch.mbarrier_arrive(mbar_ptr + mbar_s0_s1_sequence_offset + (1 - stage) * 4) # print(tSrP_r2t_f32, tStP_r2t) - cute.copy(thr_tmem_store, tSrP_r2t_f32, tStP_r2t) + # cute.copy(thr_tmem_store, tSrP_r2t_f32, tStP_r2t) + for i in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2]) // 2): + cute.copy(thr_tmem_store, tSrP_r2t_f32[None, None, i], tStP_r2t[None, None, i]) cute.arch.fence_view_async_tmem_store() # Notify mma warp that P is ready cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) + for i in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2]) // 2, cute.size(tStP_r2t.shape[2])): + cute.copy(thr_tmem_store, tSrP_r2t_f32[None, None, i], tStP_r2t[None, None, i]) + # Notify mma warp that the 2nd half of P is ready + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_2_offset + stage) cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase) softmax.update_row_sum(tSrS_t2r.load(), acc_scale, is_first) # acc_scale = cute.arch.exp2(acc_scale_) diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index 506a5d8b3c8..dfbfa708fc8 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -28,12 +28,12 @@ def reset(self) -> None: self.row_sum.fill(0.0) def _compute_row_max( - self, acc_S_row: cute.TensorSSA, init_val: float | Float32 = -Float32.inf + self, acc_S_row: cute.TensorSSA, init_val: float | Float32 | None = None ) -> Float32: return utils.fmax_reduce(acc_S_row, init_val, arch=self.arch) def _compute_row_sum( - self, acc_S_row_exp: cute.TensorSSA, init_val: float | Float32 = Float32.zero + self, acc_S_row_exp: cute.TensorSSA, init_val: float | Float32 | None = None ) -> Float32: return utils.fadd_reduce(acc_S_row_exp, init_val, arch=self.arch) @@ -59,7 +59,7 @@ def online_softmax( acc_S_row = acc_S_mn[r, None].load() # (n_block_size) row_max_cur = self._compute_row_max( acc_S_row, - init_val=-Float32.inf if cutlass.const_expr(is_first) else self.row_max[r], + init_val=self.row_max[r] if cutlass.const_expr(not is_first) else None, ) row_max_cur = utils.warp_reduce(row_max_cur, cute.arch.fmax, width=4) if cutlass.const_expr(check_inf): @@ -76,7 +76,7 @@ def online_softmax( # row_scale[r] = utils.exp2f(row_max_prev * self.scale_log2 - row_max_cur_scaled) row_scale[r] = utils.exp2f((row_max_prev - row_max_cur) * self.scale_log2) acc_S_row_sum = ( - self._compute_row_sum(acc_S_row_exp) + self.row_sum[r] * row_scale[r] + self._compute_row_sum(acc_S_row_exp, init_val=self.row_sum[r] * row_scale[r]) ) self.row_max[r] = row_max_cur self.row_sum[r] = acc_S_row_sum @@ -128,7 +128,6 @@ def __init__(self, scale_log2: Float32, rescale_threshold: cutlass.Constexpr[flo @cute.jit def update_row_max(self, acc_S_row: cute.TensorSSA, is_first: int) -> Tuple[Float32, Float32]: if cutlass.const_expr(is_first): - # row_max_new = self._compute_row_max(acc_S_row, init_val=-Float32.inf) row_max_new = self._compute_row_max(acc_S_row) row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 acc_scale = 0.0 @@ -137,12 +136,12 @@ def update_row_max(self, acc_S_row: cute.TensorSSA, is_first: int) -> Tuple[Floa row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old) row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2 + acc_scale = utils.exp2f(acc_scale_) if cutlass.const_expr(self.rescale_threshold > 0.0): if acc_scale_ >= -self.rescale_threshold: row_max_new = row_max_old row_max_safe = row_max_old - acc_scale_ = 0.0 - acc_scale = utils.exp2f(acc_scale_) + acc_scale = 1.0 self.row_max[0] = row_max_new return row_max_safe, acc_scale @@ -162,12 +161,12 @@ def scale_subtract_rowmax( row_max: Float32, ): assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" - minus_row_max_scaled = -row_max * self.scale_log2 + row_max_scaled = row_max * self.scale_log2 for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2): acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( (acc_S_row[i], acc_S_row[i + 1]), (self.scale_log2, self.scale_log2), - (minus_row_max_scaled, minus_row_max_scaled), + (-row_max_scaled, -row_max_scaled), ) @cute.jit diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index b6c9711aedf..4b2fe92bac5 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -262,6 +262,12 @@ def fmax_reduce( # return x.reduce(cute.ReductionOp.MAX, init_val, 0) res = cute.make_fragment(x.shape, Float32) res.store(x) + # local_max = [res[0], res[1]] + # for i in cutlass.range_constexpr(2, cute.size(x.shape), 2): + # local_max[0] = fmax(local_max[0], res[i + 0]) + # local_max[1] = fmax(local_max[1], res[i + 1]) + # local_max[0] = fmax(local_max[0], local_max[1]) + # return local_max[0] if cutlass.const_expr(init_val is None) else fmax(local_max[0], init_val) local_max = [res[0], res[1], res[2], res[3]] for i in cutlass.range_constexpr(4, cute.size(x.shape), 4): local_max[0] = fmax(local_max[0], res[i + 0]) @@ -319,6 +325,7 @@ def fadd_reduce( res.store(x) local_sum_0 = ( cute.arch.add_packed_f32x2((init_val, 0.0), (res[0], res[1])) + # cute.arch.add_packed_f32x2((init_val / 2, init_val / 2), (res[0], res[1])) if cutlass.const_expr(init_val is not None) else (res[0], res[1]) ) From 723c36b350edb45b3d2942353093f2c8c0aba562 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 6 Jul 2025 17:49:59 -0400 Subject: [PATCH 026/258] [Cute] Set tP arrival split to be 3/4 --- flash_attn/cute/blackwell_helpers.py | 4 ++-- flash_attn/cute/flash_fwd_sm100.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/flash_attn/cute/blackwell_helpers.py b/flash_attn/cute/blackwell_helpers.py index 6b963e6069d..ea464168faa 100644 --- a/flash_attn/cute/blackwell_helpers.py +++ b/flash_attn/cute/blackwell_helpers.py @@ -422,7 +422,7 @@ def gemm_ptx_partial( # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t" f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" ) - for k in range(1, cute.size(tCrA.shape[2]) if cutlass.const_expr(mbar_ptr is None) else cute.size(tCrA.shape[2]) // 2) + for k in range(1, cute.size(tCrA.shape[2]) if cutlass.const_expr(mbar_ptr is None) else cute.size(tCrA.shape[2]) // 4 * 3) ) + mbar_wait_str + ("".join( @@ -431,7 +431,7 @@ def gemm_ptx_partial( f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" ) - for k in range(cute.size(tCrA.shape[2]) // 2, cute.size(tCrA.shape[2])) + for k in range(cute.size(tCrA.shape[2]) // 4 * 3, cute.size(tCrA.shape[2])) ) if cutlass.const_expr(mbar_ptr is not None) else "") + "}\n", # "r,r,r", diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 9997a80a2ca..e9a535a7258 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -1333,12 +1333,12 @@ def softmax_step( cute.arch.mbarrier_arrive(mbar_ptr + mbar_s0_s1_sequence_offset + (1 - stage) * 4) # print(tSrP_r2t_f32, tStP_r2t) # cute.copy(thr_tmem_store, tSrP_r2t_f32, tStP_r2t) - for i in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2]) // 2): + for i in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2]) // 4 * 3): cute.copy(thr_tmem_store, tSrP_r2t_f32[None, None, i], tStP_r2t[None, None, i]) cute.arch.fence_view_async_tmem_store() # Notify mma warp that P is ready cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) - for i in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2]) // 2, cute.size(tStP_r2t.shape[2])): + for i in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2]) // 4 * 3, cute.size(tStP_r2t.shape[2])): cute.copy(thr_tmem_store, tSrP_r2t_f32[None, None, i], tStP_r2t[None, None, i]) # Notify mma warp that the 2nd half of P is ready cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_2_offset + stage) From e540fc1beabc6d36e77c8eb0151fab35f31d0b34 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 6 Jul 2025 17:54:40 -0400 Subject: [PATCH 027/258] [Cute] Fix missing tmem_store fence --- flash_attn/cute/flash_fwd_sm100.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index e9a535a7258..d9dd1b71ab7 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -1340,6 +1340,7 @@ def softmax_step( cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) for i in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2]) // 4 * 3, cute.size(tStP_r2t.shape[2])): cute.copy(thr_tmem_store, tSrP_r2t_f32[None, None, i], tStP_r2t[None, None, i]) + cute.arch.fence_view_async_tmem_store() # Notify mma warp that the 2nd half of P is ready cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_2_offset + stage) cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase) From aace11d5f1a60fc020a625402ba78a730096a3f1 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 6 Jul 2025 19:06:59 -0400 Subject: [PATCH 028/258] [Cute] Tune num registers for fwd_sm100 --- flash_attn/cute/flash_fwd_sm100.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index d9dd1b71ab7..963445c0c16 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -132,13 +132,13 @@ def __init__( self.num_regs_softmax = 176 # self.num_regs_correction = 104 # self.num_regs_correction = 96 - self.num_regs_correction = 80 - # self.num_regs_correction = 64 + # self.num_regs_correction = 80 + self.num_regs_correction = 64 if self.is_causal or self.is_local else 80 # self.num_regs_other = 24 # self.num_regs_other = 32 # self.num_regs_other = 64 - self.num_regs_other = 80 - # self.num_regs_other = 96 + # self.num_regs_other = 80 + self.num_regs_other = 96 if self.is_causal or self.is_local else 80 # self.num_regs_other = 48 self.buffer_align_bytes = 1024 From f14dcb1d439a6c43163e288da51dd314632fabde Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 8 Jul 2025 20:26:49 -0400 Subject: [PATCH 029/258] [Cute] Check that compute_capability is 9.x or 10.x --- flash_attn/cute/interface.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 6d370bc0078..5816714a520 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -162,7 +162,7 @@ def _flash_attn_fwd( num_threads=num_threads, Q_in_regs=False, ) - else: + elif compute_capability == 10: fa_fwd = FlashAttentionForwardSm100( head_dim, head_dim_v, @@ -171,6 +171,8 @@ def _flash_attn_fwd( qhead_per_kvhead=qhead_per_kvhead, is_persistent=not causal and not local and cu_seqlens_q is None and seqused_q is None, ) + else: + raise ValueError(f"Unsupported compute capability: {compute_capability}. Supported: 9.x, 10.x") # TODO: check @can_implement _flash_attn_fwd.compile_cache[compile_key] = cute.compile( fa_fwd, q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream, From 8ba246f6cc8813d41f9289e2781b7d8fa22a97cb Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Wed, 9 Jul 2025 11:10:28 -0700 Subject: [PATCH 030/258] [BE] Better compress flash attention binaries (#1744) --- hopper/setup.py | 3 +++ setup.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/hopper/setup.py b/hopper/setup.py index c15c438f56c..10894252db0 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -524,6 +524,9 @@ def nvcc_threads_args(): "-DCUTLASS_ENABLE_GDC_FOR_SM90", # For PDL "-DCUTLASS_DEBUG_TRACE_LEVEL=0", # Can toggle for debugging "-DNDEBUG", # Important, otherwise performance is severely impacted + "-Xfatbin", # compress all binary sections + "-compress-all", + "-compress-mode=size", # compress with CUDA fatbin more aggressively ] if get_platform() == "win_amd64": nvcc_flags.extend( diff --git a/setup.py b/setup.py index a7f15a99724..9f994023e8d 100644 --- a/setup.py +++ b/setup.py @@ -286,6 +286,9 @@ def validate_and_update_archs(archs): "--expt-relaxed-constexpr", "--expt-extended-lambda", "--use_fast_math", + "-Xfatbin", + "-compress-all", + "-compress-mode=size", # "--ptxas-options=-v", # "--ptxas-options=-O2", # "-lineinfo", From 944811ec93fac746321b2ccf5f23934c35d4b326 Mon Sep 17 00:00:00 2001 From: LosCrossOS <165311345+loscrossos@users.noreply.github.com> Date: Wed, 9 Jul 2025 20:23:14 +0200 Subject: [PATCH 031/258] adding changes for Windows compile fix for MSVC. (#1716) Signed-off-by: loscrossos <165311345+loscrossos@users.noreply.github.com> --- setup.py | 60 +++++++++++++++++++++++++++++++------------------------- 1 file changed, 33 insertions(+), 27 deletions(-) diff --git a/setup.py b/setup.py index 9f994023e8d..d54e93f6649 100644 --- a/setup.py +++ b/setup.py @@ -195,6 +195,37 @@ def validate_and_update_archs(archs): # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 if FORCE_CXX11_ABI: torch._C._GLIBCXX_USE_CXX11_ABI = True + + nvcc_flags = [ + "-O3", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + "-Xfatbin", + "-compress-all", + "-compress-mode=size", + # "--ptxas-options=-v", + # "--ptxas-options=-O2", + # "-lineinfo", + # "-DFLASHATTENTION_DISABLE_BACKWARD", + # "-DFLASHATTENTION_DISABLE_DROPOUT", + # "-DFLASHATTENTION_DISABLE_ALIBI", + # "-DFLASHATTENTION_DISABLE_SOFTCAP", + # "-DFLASHATTENTION_DISABLE_UNEVEN_K", + # "-DFLASHATTENTION_DISABLE_LOCAL", + ] + + compiler_c17_flag=["-O3", "-std=c++17"] + # Add Windows-specific flags + if sys.platform == "win32" and os.getenv('DISTUTILS_USE_SDK') == '1': + nvcc_flags.extend(["-Xcompiler", "/Zc:__cplusplus"]) + compiler_c17_flag=["-O2", "/std:c++17", "/Zc:__cplusplus"] + ext_modules.append( CUDAExtension( name="flash_attn_2_cuda", @@ -274,33 +305,8 @@ def validate_and_update_archs(archs): "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu", ], extra_compile_args={ - "cxx": ["-O3", "-std=c++17"], - "nvcc": append_nvcc_threads( - [ - "-O3", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - "-Xfatbin", - "-compress-all", - "-compress-mode=size", - # "--ptxas-options=-v", - # "--ptxas-options=-O2", - # "-lineinfo", - # "-DFLASHATTENTION_DISABLE_BACKWARD", - # "-DFLASHATTENTION_DISABLE_DROPOUT", - # "-DFLASHATTENTION_DISABLE_ALIBI", - # "-DFLASHATTENTION_DISABLE_SOFTCAP", - # "-DFLASHATTENTION_DISABLE_UNEVEN_K", - # "-DFLASHATTENTION_DISABLE_LOCAL", - ] - + cc_flag - ), + "cxx": compiler_c17_flag, + "nvcc": append_nvcc_threads(nvcc_flags + cc_flag), }, include_dirs=[ Path(this_dir) / "csrc" / "flash_attn", From 1e556445878e3724ccfe9384df061a1fce3ff1a4 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 9 Jul 2025 14:28:18 -0400 Subject: [PATCH 032/258] [CI] Compile with nvcc 12.9.1 --- .github/workflows/publish.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 6205ebf4b69..0a6a57510d7 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -45,7 +45,7 @@ jobs: os: [ubuntu-22.04] python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] torch-version: ['2.4.0', '2.5.1', '2.6.0', '2.7.1'] - cuda-version: ['12.9.0'] + cuda-version: ['12.9.1'] # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) @@ -90,7 +90,7 @@ jobs: - name: Install CUDA ${{ matrix.cuda-version }} if: ${{ matrix.cuda-version != 'cpu' }} - uses: Jimver/cuda-toolkit@v0.2.25 + uses: Jimver/cuda-toolkit@v0.2.26 id: cuda-toolkit with: cuda: ${{ matrix.cuda-version }} From 7b0bfcc3d1f69786f0c4277c582ad58acdfb297d Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 9 Jul 2025 14:33:49 -0400 Subject: [PATCH 033/258] Bump to v2.8.1 --- flash_attn/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index 9ef52f504bb..fa45a44cbe1 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.8.0.post2" +__version__ = "2.8.1" from flash_attn.flash_attn_interface import ( flash_attn_func, From adf27d1db38223288981c4dc3509efafbddd3422 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 9 Jul 2025 14:58:38 -0400 Subject: [PATCH 034/258] [WIP] Add benchmarking script --- benchmarks/benchmark_attn.py | 397 +++++++++++++++++++++++++++++++++++ 1 file changed, 397 insertions(+) create mode 100644 benchmarks/benchmark_attn.py diff --git a/benchmarks/benchmark_attn.py b/benchmarks/benchmark_attn.py new file mode 100644 index 00000000000..8d4a5c0c0f7 --- /dev/null +++ b/benchmarks/benchmark_attn.py @@ -0,0 +1,397 @@ +from collections import namedtuple +from functools import partial +import math +import os +from typing import NamedTuple +import torch +import torch.nn as nn +import torch.nn.functional as F + +import time + +try: + import cudnn +except ImportError: + cudnn = None +# cudnn = None + +Timing = NamedTuple('timing', [('mean', float)]) + + +from einops import rearrange, repeat + +# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler +from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler +from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func +from flash_attn.cute.interface import flash_attn_func as flash_attn_func_python +from flash_attn.cute.interface import flash_attn_varlen_func as flash_attn_varlen_func_python +try: + from flash_attn_interface import flash_attn_func as flash_attn_func_v3 +except ImportError: + flash_attn_func_v3 = None + +if torch.cuda.get_device_capability()[0] != 9: + flash_attn_func_v3 = None +# flash_attn_func_v3 = None + +flash_attn_func = None + +from triton.testing import do_bench + +def time_fwd(func, *args, repeats=30, verbose=True, desc="", **kwargs): + # # Warmup + # for _ in range(5): + # func(*args, **kwargs) + # time.sleep(1) + # return benchmark_forward(func, *args, **kwargs, repeats=repeats, verbose=verbose, desc=desc)[1] + # s = torch.cuda.Stream() + # s.wait_stream(torch.cuda.current_stream()) + # with torch.cuda.stream(s): + # for _ in range(2): + # out = func(*args, **kwargs) + # torch.cuda.current_stream().wait_stream(s) + # graph = torch.cuda.CUDAGraph() + # with torch.cuda.graph(graph): + # out = func(*args, **kwargs) + # time_f = benchmark_forward(lambda: graph.replay(), repeats=repeats, verbose=verbose, desc=desc) + # # return time_f[1].mean + # return time_f[1] + return Timing(do_bench(lambda: func(*args, **kwargs), warmup=5, rep=repeats) * 1e-3) + + +def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(None, None)): + if causal: + avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2 + else: + if window_size == (None, None): + avg_seqlen = seqlen_k + else: + row_idx = torch.arange(seqlen_q, device='cuda') + col_left = torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0)) if window_size[0] is not None else torch.zeros_like(row_idx) + col_right = torch.minimum(row_idx + seqlen_k - seqlen_q - window_size[1], torch.tensor(seqlen_k - 1)) if window_size[1] is not None else torch.full_like(row_idx, seqlen_k - 1) + avg_seqlen = (col_right - col_left + 1).float().mean().item() + return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v) + + +def convert_to_cudnn_type(torch_type): + if torch_type == torch.float16: + return cudnn.data_type.HALF + elif torch_type == torch.bfloat16: + return cudnn.data_type.BFLOAT16 + elif torch_type == torch.float32: + return cudnn.data_type.FLOAT + elif torch_type == torch.int32: + return cudnn.data_type.INT32 + elif torch_type == torch.int64: + return cudnn.data_type.INT64 + else: + raise ValueError("Unsupported tensor data type.") + + +def cudnn_spda_setup(q, k, v, causal=False, window_size_left=None): + b, nheads, seqlen_q, headdim = q.shape + _, nheads_k, seqlen_k, _ = k.shape + assert v.shape == (b, nheads_k, seqlen_k, headdim) + assert cudnn is not None, 'CUDNN is not available' + q_gpu, k_gpu, v_gpu = q, k, v + o_gpu = torch.empty_like(q_gpu) + stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device) + graph = cudnn.pygraph( + io_data_type=convert_to_cudnn_type(q.dtype), + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + q = graph.tensor_like(q_gpu.detach()) + k = graph.tensor_like(k_gpu.detach()) + v = graph.tensor_like(v_gpu.detach()) + + o, stats = graph.sdpa( + name="sdpa", + q=q, + k=k, + v=v, + is_inference=False, + attn_scale=1.0 / math.sqrt(headdim), + # use_causal_mask_bottom_right=causal or window_size_left is not None, + use_causal_mask=causal or window_size_left is not None, + sliding_window_length=window_size_left if window_size_left is not None and not causal else None, + ) + + o.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride()) + stats.set_output(True).set_data_type(cudnn.data_type.FLOAT) + + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + graph.build_plans() + + variant_pack = { + q: q_gpu, + k: k_gpu, + v: v_gpu, + o: o_gpu, + stats: stats_gpu, + } + + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) + + def run(*args, **kwargs): + graph.execute(variant_pack, workspace) + return o_gpu + + return run + + +def cudnn_spda_bwd_setup(q, k, v, o, g, lse, causal=False, window_size_left=None): + b, nheads, seqlen_q, headdim = q.shape + _, nheads_k, seqlen_k, _ = k.shape + assert v.shape == (b, nheads_k, seqlen_k, headdim) + assert g.shape == (b, nheads, seqlen_q, headdim) + assert o.shape == (b, nheads, seqlen_q, headdim) + assert lse.shape == (b, nheads, seqlen_q, 1) + assert cudnn is not None, 'CUDNN is not available' + q_gpu, k_gpu, v_gpu, o_gpu, g_gpu = q, k, v, o, g + dq_gpu = torch.empty_like(q_gpu) + dk_gpu = torch.empty_like(k_gpu) + dv_gpu = torch.empty_like(v_gpu) + graph = cudnn.pygraph( + io_data_type=convert_to_cudnn_type(q.dtype), + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + q = graph.tensor_like(q_gpu.detach()) + k = graph.tensor_like(k_gpu.detach()) + v = graph.tensor_like(v_gpu.detach()) + o = graph.tensor_like(o_gpu.detach()) + g = graph.tensor_like(g_gpu.detach()) + stats = graph.tensor_like(lse.detach()) + + dq, dk, dv = graph.sdpa_backward( + name="sdpa_backward", + q=q, + k=k, + v=v, + o=o, + dO=g, + stats=stats, + attn_scale=1.0 / math.sqrt(headdim), + # use_causal_mask_bottom_right=causal or window_size_left is not None, + use_causal_mask=causal or window_size_left is not None, + sliding_window_length=window_size_left if window_size_left is not None and not causal else None, + ) + + dq.set_output(True).set_dim(dq_gpu.shape).set_stride(dq_gpu.stride()) + dk.set_output(True).set_dim(dk_gpu.shape).set_stride(dk_gpu.stride()) + dv.set_output(True).set_dim(dv_gpu.shape).set_stride(dv_gpu.stride()) + + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + graph.build_plans() + + variant_pack = { + q: q_gpu, + k: k_gpu, + v: v_gpu, + o: o_gpu, + g: g_gpu, + stats: lse, + dq: dq_gpu, + dk: dk_gpu, + dv: dv_gpu, + } + + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) + + def run(*args, **kwargs): + graph.execute(variant_pack, workspace) + return dq_gpu, dk_gpu, dv_gpu + + return run + + +torch.manual_seed(0) +repeats = 10 +dropout_p = 0.0 +causal = False +dtype = torch.bfloat16 +# dtype = torch.float8_e4m3fn +dtype_gen = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype +device = 'cuda' +verbose = True +varlen = False +has_backward = False +page_size = None +softcap = 0.0 +V_colmajor = False +deterministic = False +batch_size = 2 +# seqlen = 2048 +seqlen = 8192 +# seqlen = 4096 +# seqlen = 2047 +dim = 2048 +# headdim = 128 +# headdim = 64 +headdim = 256 +# for headdim in [64, 128, 256]: +# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)] +# bs_seqlen_vals = [(16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)] +# bs_seqlen_vals = [(32, 512), (16, 1024)] +# bs_seqlen_vals = [(2, 64 * 132)] +bs_seqlen_vals = [(4, 8192)] +# bs_seqlen_vals = [(1, 16 * 1024)] +time_f = {} +time_b = {} + +# for headdim in [64, 128, 256]: +# for headdim in [64, 96, 128, 192]: +# for headdim in [64, 96, 128, 192, 256]: +# for headdim in [64, 96, 128]: +# for headdim in [64, 128, 256]: +# for headdim in [64, 96, 128, 192, 256]: +for headdim in [128]: + nheads = dim // headdim + # nheads = 128 + # headdim = 64 + # batch_size = 64 + # seqlen = 512 + # nheads = 8 + # headdim = 128 + nheads_kv = nheads + # nheads_kv = nheads // 4 + # nheads_kv = 1 + headdim_v = headdim + # headdim_v = 512 + has_qv = headdim == 64 and headdim_v == 512 + # has_qv = False + + for batch_size, seqlen in bs_seqlen_vals: + num_splits = 0 + # window_size = (-1, -1) + window_size = (None, None) + window_size_fa = (-1, -1) + # window_size = (seqlen // 2 - 1, 0) + pack_gqa = None + # seqlen_q = 64 + seqlen_q = seqlen + leftpad_k = None + # leftpad_k = torch.full((batch_size,), 0, device=device, dtype=torch.int32) + q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=has_backward) + k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=has_backward) + v = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, device=device, dtype=dtype_gen, requires_grad=has_backward) + q, k, v = [x.detach().to(dtype).requires_grad_(has_backward) for x in [q, k, v]] + v_colmajor = v.detach().transpose(-1, -3).contiguous().transpose(-1, -3).requires_grad_(has_backward) + v_fa3 = v if not V_colmajor else v_colmajor + qv = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) if has_qv else None + # q = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype) + # k = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype) + # v = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim_v), device=device, dtype=torch.int32).to(dtype) + g = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) + o = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) + stats = torch.randn(batch_size, seqlen_q, nheads, 1, device=device, dtype=torch.float32) + if varlen: + q_unpad, k_unpad, v_unpad = [rearrange(x.detach(), "b s h d -> (b s) h d").requires_grad_(has_backward) for x in [q, k, v]] + cu_seqlens_q = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen_q + cu_seqlens_k = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen + # cu_seqlens_q = torch.tensor([0, 248, 249, 250, 251, 252, 253, 254, 255, 256], device=device, dtype=torch.int32) + # q_unpad = q_unpad[:256] + # seqlen_q = 256 + # cu_seqlens_q = torch.tensor([0, 376, 377, 378, 379, 380, 381, 382, 383, 384], device=device, dtype=torch.int32) + # q_unpad = q_unpad[:384] + # seqlen_q = 384 + if page_size is not None: + assert seqlen % page_size == 0 + k_paged, v_paged = [rearrange(x, "b (n p) h d -> (b n) p h d", p=page_size) for x in [k, v]] + page_table = rearrange(torch.arange(batch_size * seqlen // page_size, device=device, dtype=torch.int32), + "(b s) -> b s", s=seqlen // page_size) + else: + page_table = None + + for causal in [False, True]: + # for causal in [False]: + print(f"\n### {headdim = }, {causal = }, {seqlen = } ###") + nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size) + if cudnn is not None: + # if False: + if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v: + cudnn_spda = cudnn_spda_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=causal, window_size_left=window_size[0]) + cudnn_spda_bwd = cudnn_spda_bwd_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), o.transpose(1, 2), g.transpose(1, 2), stats.transpose(1, 2), causal=causal, window_size_left=window_size[0]) + if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func is not None: + # if False: + if not varlen: + m0 = time_fwd(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2') + else: + m0 = time_fwd(flash_attn_varlen_func, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, dropout_p, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2') + time_f[(causal, headdim, batch_size, seqlen), "Flash2"] = m0.mean + if has_backward: + time.sleep(1) + if not varlen: + _, m0b = benchmark_backward(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic, + repeats=repeats, verbose=False, desc='Fav2') + else: + _, m0b = benchmark_backward(flash_attn_varlen_func, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, dropout_p, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic, + repeats=repeats, verbose=False, desc='Fav2') + time_b[(causal, headdim, batch_size, seqlen), "Flash2"] = m0b.mean + # pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=True) + + if cudnn is not None: + # if False: + if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v: + time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark + m2 = time_fwd(cudnn_spda, repeats=repeats, verbose=verbose, desc='CuDNN') + time_f[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2.mean + time.sleep(1) + m2b = time_fwd(cudnn_spda_bwd, repeats=repeats, verbose=verbose, desc='CuDNN') + time_b[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2b.mean + # pytorch_profiler(cudnn_spda, backward=False) + # pytorch_profiler(cudnn_spda_bwd, backward=False) + time.sleep(1) + if flash_attn_func_v3 is not None: + if not varlen: + # m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, cache_leftpad = leftpad_k, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') + m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size_fa, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') + # pytorch_profiler(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa) + else: + m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') + # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits) + time_f[(causal, headdim, batch_size, seqlen), "Flash3"] = m1.mean + if flash_attn_func_python is not None: + m1_py = time_fwd(flash_attn_func_python, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav3 python') + if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func_v3 is not None and has_backward: + time.sleep(1) + if not varlen: + _, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, softcap=softcap, repeats=repeats, verbose=False, desc='Fav3') + else: + _, m1b = benchmark_backward(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic, + repeats=repeats, verbose=False, desc='Fav3') + time_b[(causal, headdim, batch_size, seqlen), "Flash3"] = m1b.mean + # time.sleep(1) + # if not varlen: + # pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, deterministic=deterministic, backward=True) + # else: + # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, deterministic=deterministic, backward=True) + # benchmark_forward(torch.clone, k, repeats=repeats, verbose=verbose, desc='Memcpy') + if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func_python is not None and has_backward: + _, m1b_py = benchmark_backward(flash_attn_func_python, q, k, v, causal=causal, softcap=softcap, repeats=repeats, verbose=False, desc='Fav2 python') + + if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func is not None: + # if False: + print(f'FAv2 fwd: {m0.mean * 1e3:.3f}ms, {(nFLOPS / m0.mean * 1e-12):.1f} TFLOPS') + if has_backward: + print(f'FAv2 bwd: {m0b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m0b.mean * 1e-12):.1f} TFLOPS') + if cudnn is not None: + print(f'CuDNN fwd: {m2.mean * 1e3:.3f}ms, {(nFLOPS / m2.mean * 1e-12):.1f} TFLOPS') + if has_backward: + print(f'CuDNN bwd: {m2b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m2b.mean * 1e-12):.1f} TFLOPS') + if flash_attn_func_v3 is not None: + print(f'FAv3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS') + if dtype != torch.float8_e4m3fn and headdim == headdim_v and has_backward: + print(f'FAv3 bwd: {m1b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b.mean * 1e-12):.1f} TFLOPS') + + if flash_attn_func_python is not None: + print(f'FA Python fwd: {m1_py.mean * 1e3:.3f}ms, {(nFLOPS / m1_py.mean * 1e-12):.1f} TFLOPS') + if dtype != torch.float8_e4m3fn and headdim == headdim_v and has_backward: + print(f'FAv2 Python bwd: {m1b_py.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b_py.mean * 1e-12):.1f} TFLOPS') From ed209409acedbb2379f870bbd03abce31a7a51b7 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 11 Jul 2025 15:39:36 -0400 Subject: [PATCH 035/258] [FA3] Don't return lse --- hopper/flash_attn_interface.py | 4 ++-- hopper/test_flash_attn.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index cfb8881b4b2..0e93f234aa3 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -304,7 +304,7 @@ def forward( ctx.softcap = softcap ctx.deterministic = deterministic ctx.sm_margin = sm_margin - return out, softmax_lse + return out @staticmethod def backward(ctx, dout, *args): @@ -403,7 +403,7 @@ def forward( ctx.softcap = softcap ctx.deterministic = deterministic ctx.sm_margin = sm_margin - return out, softmax_lse + return out @staticmethod def backward(ctx, dout, *args): diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 109b5fcac00..f1247e689da 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -193,7 +193,7 @@ def test_flash_attn_output( pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): - out, lse = flash_attn_func( + out = flash_attn_func( q, k, v, @@ -460,7 +460,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): - out_unpad, lse = flash_attn_varlen_func( + out_unpad = flash_attn_varlen_func( q_unpad, k_unpad, v_unpad, @@ -1050,7 +1050,7 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype): k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) torch.random.manual_seed(42) - out0, lse0 = flash_attn_func(q, k, v, causal=causal) + out0 = flash_attn_func(q, k, v, causal=causal) g = torch.randn_like(out0) dq0, dk0, dv0 = torch.autograd.grad(out0, (q, k, v), g) # Numerical error if we just do any arithmetic on dq @@ -1058,9 +1058,9 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype): for i in range(1000): torch.random.manual_seed(42) - out, lse = flash_attn_func(q, k, v, causal=causal) + out = flash_attn_func(q, k, v, causal=causal) assert torch.equal(out, out0) - assert torch.equal(lse, lse0) + # assert torch.equal(lse, lse0) dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) dq_equal = torch.allclose(dq, dq0, atol=dq_atol) From 87855ac853a4c76e7f0194ab78ea408cdbac3ec0 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 13 Jul 2025 15:39:42 -0400 Subject: [PATCH 036/258] [Cute] Clean up flash_fwd_sm90 and flash_fwd_sm100 a bit --- flash_attn/cute/flash_fwd.py | 35 +-- flash_attn/cute/flash_fwd_sm100.py | 344 +++++++++++++---------------- 2 files changed, 179 insertions(+), 200 deletions(-) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 11755a06bcc..bc4b29b97c1 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -1169,11 +1169,6 @@ def __call__( ) tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) grid_dim = TileScheduler.get_grid_shape(tile_sched_params) - # grid_dim = ( - # cute.ceil_div(cute.size(mQ.shape[0]) if const_expr(mCuSeqlensQ is None) else max_seqlen_q, self.m_block_size), - # cute.size(mQ.shape[2]), - # cute.size(mQ.shape[3] if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ.shape[0] - 1), - # ) # If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. # Right after this, we multiply by log2(e) before applying exp2. # To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val @@ -1228,6 +1223,7 @@ def __call__( block=[self.num_threads, 1, 1], smem=SharedStorage.size_in_bytes(), stream=stream, + min_blocks_per_mp=1, ) @cute.kernel @@ -1330,8 +1326,6 @@ def kernel( # TODO: idk why not using sO_pi is faster sO = cute.make_tensor(cute.recast_ptr(sO_pi.iterator, sO_layout.inner, dtype=sO_pi.element_type), sO_layout.outer) - # Thread index, block index - tidx, _, _ = cute.arch.thread_idx() block_info = BlockInfo( self.m_block_size, self.n_block_size, self.is_causal, self.is_local, window_size_left, window_size_right, @@ -1375,6 +1369,7 @@ def kernel( # /////////////////////////////////////////////////////////////////////////////// # Tile MMA compute thread partitions and allocate accumulators # /////////////////////////////////////////////////////////////////////////////// + tidx, _, _ = cute.arch.thread_idx() tidx = tidx - 128 self.mma( tiled_mma_qk, @@ -1619,6 +1614,7 @@ def scoremod_premask_fn(acc_S): # those that need masking on S, and those that don't. # We need masking on S for the very last block when K and V has length not multiple of n_block_size. # We also need masking on S if it's causal, for the last several blocks. + O_should_accumulate = False # First iteration with seqlen masking if const_expr(self.intra_wg_overlap): acc_S = cute.make_fragment( @@ -1649,13 +1645,15 @@ def scoremod_premask_fn(acc_S): cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV # Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter - acc_O.fill(0.0) + # acc_O.fill(0.0) else: self.warp_scheduler_barrier_sync() kv_consumer_state = mma_one_n_block( n_block_max - 1, kv_consumer_state, - is_first_n_block=True, mask_fn=partial(mask_fn, mask_seqlen=True) + is_first_n_block=True, mask_fn=partial(mask_fn, mask_seqlen=True), + O_should_accumulate=False ) + O_should_accumulate = True # if cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block_max = {}, n_block_min = {}", m_block, n_block_max, n_block_min) n_block_max -= 1 # Next couple of iterations with causal masking @@ -1667,8 +1665,10 @@ def scoremod_premask_fn(acc_S): for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1): n_block = n_block_max - 1 - n_tile kv_consumer_state = mma_one_n_block( - n_block, kv_consumer_state, mask_fn=partial(mask_fn, mask_seqlen=False) + n_block, kv_consumer_state, mask_fn=partial(mask_fn, mask_seqlen=False), + O_should_accumulate=O_should_accumulate ) + O_should_accumulate = True n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) # The remaining iterations have no masking n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( @@ -1677,7 +1677,8 @@ def scoremod_premask_fn(acc_S): # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_before_local_mask = {}, n_block_min = {}", n_block_min_before_local_mask, n_block_min) for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): n_block = n_block_max - 1 - n_tile - kv_consumer_state = mma_one_n_block(n_block, kv_consumer_state, check_inf=True) + kv_consumer_state = mma_one_n_block(n_block, kv_consumer_state, check_inf=True, O_should_accumulate=O_should_accumulate) + O_should_accumulate = True # Separate iterations with local masking on the left if const_expr(self.is_local and block_info.window_size_left is not None): n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) @@ -1685,15 +1686,17 @@ def scoremod_premask_fn(acc_S): n_block = n_block_max - 1 - n_tile kv_consumer_state = mma_one_n_block( n_block, kv_consumer_state, - check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=False) + check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=False), + O_should_accumulate=O_should_accumulate ) + O_should_accumulate = True # Last "half" iteration if const_expr(self.intra_wg_overlap): pipeline_v.consumer_wait(kv_consumer_state, pipeline_v.consumer_try_wait(kv_consumer_state)) sm90_utils.gemm( tiled_mma_pv, mma_params.acc_O, mma_params.tOrP, mma_params.tOrVt[None, None, None, kv_consumer_state.index], - zero_init=False, wg_wait=-1 + zero_init=not O_should_accumulate, wg_wait=-1 ) warpgroup.wait_group(0) pipeline_v.consumer_release(kv_consumer_state) @@ -1733,6 +1736,7 @@ def mma_one_n_block( mask_fn: Optional[Callable] = None, is_first_n_block: cutlass.Constexpr = False, check_inf: cutlass.Constexpr = True, + O_should_accumulate: cutlass.Boolean = True, ): acc_S = cute.make_fragment( tiled_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)), cutlass.Float32 @@ -1768,7 +1772,7 @@ def mma_one_n_block( sm90_utils.gemm( tiled_mma_pv, mma_params.acc_O, mma_params.tOrP, mma_params.tOrVt[None, None, None, smem_pipe_read.index], - zero_init=is_first_n_block, wg_wait=0 + zero_init=not O_should_accumulate, wg_wait=0 ) pipeline_v.consumer_release(smem_pipe_read) smem_pipe_read.advance() @@ -1790,6 +1794,7 @@ def mma_one_n_block_intrawg_overlap( scoremod_premask_fn: Callable, mask_fn: Optional[Callable] = None, check_inf: cutlass.Constexpr = True, + O_should_accumulate: cutlass.Boolean = True, ): smem_pipe_read_v = smem_pipe_read.clone() smem_pipe_read.advance() @@ -1807,7 +1812,7 @@ def mma_one_n_block_intrawg_overlap( sm90_utils.gemm( tiled_mma_pv, mma_params.acc_O, mma_params.tOrP, mma_params.tOrVt[None, None, None, smem_pipe_read_v.index], - zero_init=False, wg_wait=-1 + zero_init=not O_should_accumulate, wg_wait=-1 ) self.warp_scheduler_barrier_arrive() warpgroup.wait_group(1) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 963445c0c16..a3380fedd2d 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -21,7 +21,7 @@ import cutlass import cutlass.cute as cute -from cutlass import const_expr +from cutlass import Float32, Int32, const_expr from cutlass.cute.nvgpu import cpasync import cutlass.cute.nvgpu.tcgen05 as tcgen05 import cutlass.utils.blackwell_helpers as sm100_utils_basic @@ -79,8 +79,8 @@ def __init__( self.cta_tiler = (2 * m_block_size, n_block_size, self.head_dim_padded) self.mma_tiler_qk = (m_block_size, n_block_size, self.head_dim_padded) self.pv_mma_tiler = (m_block_size, self.head_dim_v_padded, n_block_size) - self.qk_acc_dtype = cutlass.Float32 - self.pv_acc_dtype = cutlass.Float32 + self.qk_acc_dtype = Float32 + self.pv_acc_dtype = Float32 self.cluster_shape_mn = (1, 1) self.is_persistent = is_persistent self.is_causal = is_causal @@ -140,6 +140,7 @@ def __init__( # self.num_regs_other = 80 self.num_regs_other = 96 if self.is_causal or self.is_local else 80 # self.num_regs_other = 48 + self.num_regs_empty = 24 self.buffer_align_bytes = 1024 @@ -166,16 +167,16 @@ def __call__( mV: cute.Tensor, mO: cute.Tensor, mLSE: Optional[cute.Tensor], - softmax_scale: cutlass.Float32, + softmax_scale: Float32, stream: cuda.CUstream, mCuSeqlensQ: Optional[cute.Tensor] = None, mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, - max_seqlen_q: Optional[cutlass.Int32] = None, - softcap: cutlass.Float32 | float | None = None, - window_size_left: cutlass.Int32 | int | None = None, - window_size_right: cutlass.Int32 | int | None = None, + max_seqlen_q: Optional[Int32] = None, + softcap: Float32 | float | None = None, + window_size_left: Int32 | int | None = None, + window_size_right: Int32 | int | None = None, ): """Execute the Fused Multi-Head Attention operation on the provided tensors. @@ -381,9 +382,7 @@ def __call__( self.mbar_corr_epi_full_offset = self.mbar_softmax_corr_empty_offset + self.epi_stage self.mbar_corr_epi_empty_offset = self.mbar_corr_epi_full_offset + self.epi_stage self.mbar_s0_s1_sequence_offset = self.mbar_corr_epi_empty_offset + 2 - self.mbar_max_reg_setting_offset = self.mbar_s0_s1_sequence_offset + 8 - self.mbar_tmem_dealloc_offset = self.mbar_max_reg_setting_offset + 1 - # self.mbar_total = self.mbar_tmem_dealloc_offset + 1 + self.mbar_tmem_dealloc_offset = self.mbar_s0_s1_sequence_offset + 8 self.mbar_P_full_2_offset = self.mbar_tmem_dealloc_offset + 1 self.mbar_total = self.mbar_P_full_2_offset + 2 @@ -392,9 +391,9 @@ class SharedStorage: # m_barriers for pipelines mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.mbar_total] # Tmem holding buffer - tmem_holding_buf: cutlass.Int32 + tmem_holding_buf: Int32 # Smem tensors - sScale: cute.struct.MemRange[cutlass.Float32, 2 * self.m_block_size * (1 if const_expr(mLSE is None) else 2)] + sScale: cute.struct.MemRange[Float32, 2 * self.m_block_size * (1 if const_expr(mLSE is None) else 2)] sO: cute.struct.Align[ cute.struct.MemRange[self.o_dtype, cute.cosize(sO_layout)], self.buffer_align_bytes, @@ -421,11 +420,11 @@ class SharedStorage: softcap_val = None else: softmax_scale_log2 = softcap * LOG2_E - softcap_val = cutlass.Float32(softmax_scale / softcap) + softcap_val = Float32(softmax_scale / softcap) if const_expr(window_size_left is not None): - window_size_left = cutlass.Int32(window_size_left) + window_size_left = Int32(window_size_left) if const_expr(window_size_right is not None): - window_size_right = cutlass.Int32(window_size_right) + window_size_right = Int32(window_size_right) # Launch the kernel synchronously self.kernel( tma_tensor_Q, @@ -480,10 +479,10 @@ def kernel( tma_atom_K: cute.CopyAtom, tma_atom_V: cute.CopyAtom, tma_atom_O: cute.CopyAtom, - softmax_scale_log2: cutlass.Float32, - softcap_val: Optional[cutlass.Float32], - window_size_left: Optional[cutlass.Int32], - window_size_right: Optional[cutlass.Int32], + softmax_scale_log2: Float32, + softcap_val: Optional[Float32], + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], sQ_layout: cute.ComposedLayout, sK_layout: cute.ComposedLayout, tP_layout: cute.ComposedLayout, @@ -492,7 +491,6 @@ def kernel( gmem_tiled_copy_O: Optional[cute.TiledCopy], tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, - # tile_sched_params: TileSchedulerArguments, tile_sched_params: ParamsBase, ): """The device kernel implementation of the Fused Multi-Head Attention. @@ -509,22 +507,22 @@ def kernel( computation phases, and optional attention masking. """ - # coord inside cta - tidx, _, _ = cute.arch.thread_idx() + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - if const_expr(not self.pack_gqa): - cpasync.prefetch_descriptor(tma_atom_Q) - cpasync.prefetch_descriptor(tma_atom_K) - cpasync.prefetch_descriptor(tma_atom_V) - if const_expr(self.use_tma_O): - cpasync.prefetch_descriptor(tma_atom_O) + # Prefetch tma descriptor + if warp_idx == 0: + if const_expr(not self.pack_gqa): + cpasync.prefetch_descriptor(tma_atom_Q) + cpasync.prefetch_descriptor(tma_atom_K) + cpasync.prefetch_descriptor(tma_atom_V) + if const_expr(self.use_tma_O): + cpasync.prefetch_descriptor(tma_atom_O) # Alloc smem = cutlass.utils.SmemAllocator() storage = smem.allocate(self.shared_storage) mbar_ptr = storage.mbar_ptr.data_ptr() - warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) if warp_idx == 1: # Init "full" barrier with number of producers, "empty" barrier with number of consumers for i in cutlass.range_constexpr(self.q_stage): @@ -547,23 +545,9 @@ def kernel( cute.arch.mbarrier_init(mbar_ptr + self.mbar_P_full_O_rescaled_offset + i, cute.arch.WARP_SIZE * (len(self.softmax0_warp_ids) + len(self.correction_warp_ids))) cute.arch.mbarrier_init(mbar_ptr + self.mbar_S_full_offset + i, len([self.mma_warp_id])) cute.arch.mbarrier_init(mbar_ptr + self.mbar_O_full_offset + i, len([self.mma_warp_id])) - if warp_idx == 8: + if warp_idx == 6: for i in cutlass.range_constexpr(2): cute.arch.mbarrier_init(mbar_ptr + self.mbar_P_full_2_offset + i, cute.arch.WARP_SIZE * len(self.softmax0_warp_ids)) - if warp_idx == 6: - cute.arch.mbarrier_init( - mbar_ptr + self.mbar_max_reg_setting_offset, - cute.arch.WARP_SIZE - * len( - ( - *self.empty_warp_ids, - self.load_warp_id, - self.mma_warp_id, - *self.epilogue_warp_ids, - *self.correction_warp_ids, - ) - ), - ) if warp_idx == 7: cute.arch.mbarrier_init( mbar_ptr + self.mbar_tmem_dealloc_offset, @@ -599,7 +583,7 @@ def kernel( qk_acc_shape = thr_mma_qk.partition_shape_C((self.mma_tiler_qk[0], self.mma_tiler_qk[1])) tStS_fake = thr_mma_qk.make_fragment_C(qk_acc_shape) # TODO: this is a fake tensor, need to retrieve tmem_ptr - tmem_ptr = cute.make_ptr(cutlass.Float32, 0, mem_space=cute.AddressSpace.tmem, + tmem_ptr = cute.make_ptr(Float32, 0, mem_space=cute.AddressSpace.tmem, assumed_align=16) tStS = cute.make_tensor(tmem_ptr, tStS_fake.layout) @@ -643,96 +627,99 @@ def kernel( window_size_left=window_size_left, window_size_right=window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) + TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params) + + # /////////////////////////////////////////////////////////////////////////////// + # EMPTY + # /////////////////////////////////////////////////////////////////////////////// + if const_expr(len(self.empty_warp_ids) > 0): + if warp_idx == self.empty_warp_ids[0]: + cute.arch.warpgroup_reg_dealloc(self.num_regs_empty) - if warp_idx >= 12: + # /////////////////////////////////////////////////////////////////////////////// + # LOAD + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.load_warp_id: cute.arch.warpgroup_reg_dealloc(self.num_regs_other) - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_max_reg_setting_offset) - # /////////////////////////////////////////////////////////////////////////////// - # LOAD - # /////////////////////////////////////////////////////////////////////////////// - if warp_idx == self.load_warp_id: - tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params) - self.load( - tile_scheduler, - thr_mma_qk, - thr_mma_pv, - mQ, - mK, - mV, - sQ, - sK, - sV, - tma_atom_Q, - tma_atom_K, - tma_atom_V, - pipeline_kv, - mbar_ptr, - block_info, - SeqlenInfoCls, - ) - # /////////////////////////////////////////////////////////////////////////////// - # MMA - # /////////////////////////////////////////////////////////////////////////////// + self.load( + thr_mma_qk, + thr_mma_pv, + mQ, + mK, + mV, + sQ, + sK, + sV, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + pipeline_kv, + mbar_ptr, + block_info, + SeqlenInfoCls, + TileSchedulerCls, + ) + + # /////////////////////////////////////////////////////////////////////////////// + # MMA + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.mma_warp_id: + # if warp_idx == self.mma_warp_id or warp_idx == self.empty_warp_ids: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + # Alloc tmem buffer + tmem_alloc_cols = Int32(self.tmem_alloc_cols) if warp_idx == self.mma_warp_id: - # if warp_idx == self.mma_warp_id or warp_idx == self.empty_warp_ids: - # Alloc tmem buffer - tmem_alloc_cols = cutlass.Int32(self.tmem_alloc_cols) - if warp_idx == self.mma_warp_id: - cute.arch.alloc_tmem(tmem_alloc_cols, storage.tmem_holding_buf) - cute.arch.sync_warp() - - self.mma( - tiled_mma_qk, - tiled_mma_pv, - sQ, - sK, - sV, - # sQ_pi.iterator, - # sK_pi.iterator, - sQ_layout.inner, - sK_layout.inner, - sV_layout.inner, - tStS0, - tStS1, - tOtO0, - tOtO1, - tOrP0, - tOrP1, - pipeline_kv, - mbar_ptr, - tile_sched_params, - block_info, - SeqlenInfoCls, - ) + cute.arch.alloc_tmem(tmem_alloc_cols, storage.tmem_holding_buf) + cute.arch.sync_warp() + + self.mma( + tiled_mma_qk, + tiled_mma_pv, + sQ, + sK, + sV, + sQ_layout.inner, + sK_layout.inner, + sV_layout.inner, + tStS0, + tStS1, + tOtO0, + tOtO1, + tOrP0, + tOrP1, + pipeline_kv, + mbar_ptr, + block_info, + SeqlenInfoCls, + TileSchedulerCls, + ) - # if warp_idx == self.mma_warp_id: - # dealloc tmem buffer - cute.arch.mbarrier_wait(mbar_ptr + self.mbar_tmem_dealloc_offset, 0) - tmem_alloc_cols = cutlass.Int32(self.tmem_alloc_cols) - # Retrieving tmem ptr and make acc - tmem_ptr = cute.arch.retrieve_tmem_ptr( - cutlass.Float32, - alignment=16, - ptr_to_buffer_holding_addr=storage.tmem_holding_buf, - ) - cute.arch.dealloc_tmem(tmem_ptr, tmem_alloc_cols) + # if warp_idx == self.mma_warp_id: + # dealloc tmem buffer + cute.arch.relinquish_tmem_alloc_permit() + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_tmem_dealloc_offset, 0) + tmem_alloc_cols = Int32(self.tmem_alloc_cols) + # Retrieving tmem ptr and make acc + tmem_ptr = cute.arch.retrieve_tmem_ptr( + Float32, + alignment=16, + ptr_to_buffer_holding_addr=storage.tmem_holding_buf, + ) + cute.arch.dealloc_tmem(tmem_ptr, tmem_alloc_cols) - # /////////////////////////////////////////////////////////////////////////////// - # Epilogue - # /////////////////////////////////////////////////////////////////////////////// - if warp_idx >= self.epilogue_warp_ids[0] and warp_idx <= self.epilogue_warp_ids[-1]: - tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params) - self.epilogue_s2g(tile_scheduler, mO, sO, gmem_tiled_copy_O, tma_atom_O, mbar_ptr, SeqlenInfoCls) + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx >= self.epilogue_warp_ids[0] and warp_idx <= self.epilogue_warp_ids[-1]: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + self.epilogue_s2g(mO, sO, gmem_tiled_copy_O, tma_atom_O, mbar_ptr, SeqlenInfoCls, TileSchedulerCls) # /////////////////////////////////////////////////////////////////////////////// # Softmax # /////////////////////////////////////////////////////////////////////////////// if warp_idx < self.correction_warp_ids[0]: # increase register after decreasing - cute.arch.mbarrier_wait(mbar_ptr + self.mbar_max_reg_setting_offset, 0) cute.arch.warpgroup_reg_alloc(self.num_regs_softmax) - - tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params) softmax_loop = partial( self.softmax_loop, softmax_scale_log2=softmax_scale_log2, @@ -740,14 +727,14 @@ def kernel( sScale=sScale, mLSE=mLSE, mbar_ptr=mbar_ptr, - tile_scheduler=tile_scheduler, block_info=block_info, SeqlenInfoCls=SeqlenInfoCls, AttentionMaskCls=AttentionMaskCls, + TileSchedulerCls=TileSchedulerCls, ) if const_expr(not self.s0_s1_barrier): - stage = cutlass.Int32(0 if warp_idx < self.softmax1_warp_ids[0] else 1) + stage = Int32(0 if warp_idx < self.softmax1_warp_ids[0] else 1) softmax_loop( stage=stage, tStSi=cute.make_tensor(tStS.iterator + (self.tmem_s0_offset if stage == 0 else self.tmem_s1_offset), tStS.layout)) @@ -768,7 +755,6 @@ def kernel( # /////////////////////////////////////////////////////////////////////////////// if warp_idx >= self.correction_warp_ids[0] and warp_idx < self.mma_warp_id: cute.arch.warpgroup_reg_dealloc(self.num_regs_correction) - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_max_reg_setting_offset) self.correction_loop( thr_mma_qk, thr_mma_pv, @@ -782,9 +768,9 @@ def kernel( tma_atom_O, mbar_ptr, softmax_scale_log2, - tile_sched_params, block_info, SeqlenInfoCls, + TileSchedulerCls, ) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) @@ -793,7 +779,6 @@ def kernel( @cute.jit def load( self, - tile_scheduler, thr_mma_qk: cute.core.ThrMma, thr_mma_pv: cute.core.ThrMma, mQ: cute.Tensor, @@ -809,6 +794,7 @@ def load( mbar_ptr: cute.Pointer, block_info: BlockInfo, SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, ): # (bM, bK, loopM, loopL) gQ_qdhb = cute.local_tile(mQ, cute.select(self.mma_tiler_qk, mode=[0, 2]), (None, 0, None, None)) @@ -841,8 +827,9 @@ def load( cute.group_modes(tOgV_dkhb, 0, 3), ) - q_producer_phase = cutlass.Int32(1) + q_producer_phase = Int32(1) kv_producer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, self.kv_stage) + tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx = work_tile.tile_idx @@ -892,8 +879,6 @@ def mma( sQ: cute.Tensor, sK: cute.Tensor, sV: cute.Tensor, - # sQ_base_addr: cute.Pointer, - # sK_base_addr: cute.Pointer, sQ_swizzle: cute.Swizzle, sK_swizzle: cute.Swizzle, sV_swizzle: cute.Swizzle, @@ -905,9 +890,9 @@ def mma( tOrP1: cute.Tensor, pipeline_kv: cutlass.pipeline.PipelineAsync, mbar_ptr: cute.Pointer, - tile_sched_params, block_info: BlockInfo, SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, ): thr_mma_qk = tiled_mma_qk.get_slice(0) # default 1SM thr_mma_pv = tiled_mma_pv.get_slice(0) # default 1SM @@ -919,12 +904,6 @@ def mma( tOrPs = (tOrP0, tOrP1) qk_mma_op, pv_mma_op = tiled_mma_qk.op, tiled_mma_pv.op - # sQ_base_addr_for_desc = cute.arch.make_warp_uniform(sm100_desc.make_smem_desc_start_addr(sQ_base_addr)) - # sK_base_addr_for_desc = cute.arch.make_warp_uniform(sm100_desc.make_smem_desc_start_addr(sK_base_addr)) - # sQ_addr_offset_for_desc = (cute.crd2idx((0, 0, 0, 1), sQ.layout) * sQ.element_type.width // 8) >> 4 - # sK_addr_offset_for_desc = (cute.crd2idx((0, 0, 0, 1), sK.layout) * sK.element_type.width // 8) >> 4 - # sQ_layout = cute.select(sQ.layout, mode=[0, 1, 2]) - # sK_layout = cute.select(sK.layout, mode=[0, 1, 2]) gemm_Si = [ partial( @@ -944,13 +923,13 @@ def mma( for stage in range(2) ] - mma_q_consumer_phase = cutlass.Int32(0) + mma_q_consumer_phase = Int32(0) mma_kv_consumer_state = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.kv_stage ) - P_full_O_rescaled_phase = cutlass.Int32(0) + P_full_O_rescaled_phase = Int32(0) - tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params) + tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx = work_tile.tile_idx @@ -972,13 +951,6 @@ def mma( # 3. gemm # sm100_utils.gemm(tiled_mma_qk, tStSs[stage], tSrQs[stage], tSrKi, zero_init=True) gemm_Si[stage](tCrB=tSrKi, sB=sK[None, None, None, mma_kv_consumer_state.index]) - # sm100_utils.gemm_ptx_partial1( - # qk_mma_op, 0 + stage * self.tmem_s1_offset, tSrQs[stage], tSrKi, - # sQ_base_addr_for_desc, sQ_addr_offset_for_desc, stage, - # sK_base_addr_for_desc, sK_addr_offset_for_desc, 0, - # sQ_layout, sK_layout, sQ_swizzle, sK_swizzle, - # zero_init=True - # ) # 4. release S0 / S1 with cute.arch.elect_one(): tcgen05.commit(mbar_ptr + self.mbar_S_full_offset + stage) @@ -1086,17 +1058,17 @@ def mma( def softmax_loop( self, stage: int, - # stage: cutlass.Int32, - softmax_scale_log2: cutlass.Float32, + # stage: Int32, + softmax_scale_log2: Float32, thr_mma_qk: cute.core.ThrMma, tStSi: cute.Tensor, sScale: cute.Tensor, mLSE: Optional[cute.Tensor], mbar_ptr: cute.Pointer, - tile_scheduler, block_info: BlockInfo, SeqlenInfoCls: Callable, AttentionMaskCls: Callable, + TileSchedulerCls: Callable, ): """Compute softmax on attention scores from QK matrix multiplication. @@ -1129,34 +1101,35 @@ def softmax_loop( tStP = cute.make_tensor(tStSi.iterator + self.tmem_p_offset, tStP_layout) tmem_load_atom = cute.make_copy_atom( - tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), cutlass.Float32, + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32, ) thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStSi).get_slice(tidx) tStS_t2r = thr_tmem_load.partition_S(tStSi) tmem_store_scale_atom = cute.make_copy_atom( - tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(1)), cutlass.Float32, + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(1)), Float32, ) thr_tmem_store_scale = tcgen05.make_tmem_copy(tmem_store_scale_atom, tStScale).get_slice(tidx) tStScale_r2t = thr_tmem_store_scale.partition_D(tStScale) tSrScale_r2t_shape = thr_tmem_store_scale.partition_S(tScS_vec).shape tmem_store_atom = cute.make_copy_atom( - tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), cutlass.Float32, + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), Float32, ) tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP) thr_tmem_store = tiled_tmem_store.get_slice(tidx) tStP_r2t = thr_tmem_store.partition_D(tStP) - mma_si_consumer_phase = cutlass.Int32(0) - si_corr_producer_phase = cutlass.Int32(1) - s0_s1_sequence_phase = cutlass.Int32(1 if stage == 0 else 0) + mma_si_consumer_phase = Int32(0) + si_corr_producer_phase = Int32(1) + s0_s1_sequence_phase = Int32(1 if stage == 0 else 0) # self.warp_scheduler_barrier_init() warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 mbar_s0_s1_sequence_offset = self.mbar_s0_s1_sequence_offset + warp_idx_in_wg + tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx = work_tile.tile_idx @@ -1214,7 +1187,7 @@ def softmax_loop( n_block = n_block_max - 1 - n_tile mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=partial(mask_fn, mask_seqlen=False)) - # tSrScale_r2t = cute.make_fragment(tSrScale_r2t_shape, cutlass.Float32) + # tSrScale_r2t = cute.make_fragment(tSrScale_r2t_shape, Float32) # tSrScale_r2t[0] = softmax.row_sum[0] # cute.copy(thr_tmem_store_scale, tSrScale_r2t, tStScale_r2t) # cute.arch.fence_view_async_tmem_store() @@ -1235,7 +1208,7 @@ def softmax_loop( # LN2 = math.log(2.0) # lse = ( # (softmax.row_max[0] * softmax.scale_log2 + utils.log2f(softmax.row_sum[0])) * LN2 - # if not acc_O_mn_row_is_zero_or_nan else -cutlass.Float32.inf + # if not acc_O_mn_row_is_zero_or_nan else -Float32.inf # ) # if const_expr(not seqlen.has_cu_seqlens_q): # mLSE_cur = mLSE[None, head_idx, batch_idx] @@ -1253,14 +1226,14 @@ def softmax_loop( @cute.jit def softmax_step( self, - # stage: cutlass.Int32, - mma_si_consumer_phase: cutlass.Int32, - si_corr_producer_phase: cutlass.Int32, - s0_s1_sequence_phase: cutlass.Int32, - n_block: cutlass.Int32, + # stage: Int32, + mma_si_consumer_phase: Int32, + si_corr_producer_phase: Int32, + s0_s1_sequence_phase: Int32, + n_block: Int32, softmax: SoftmaxSm100, mbar_ptr: cute.Pointer, - mbar_s0_s1_sequence_offset: cutlass.Int32, + mbar_s0_s1_sequence_offset: Int32, thr_mma_qk: cute.core.ThrMma, thr_tmem_load: cute.CopyAtom, thr_tmem_store: cute.CopyAtom, @@ -1288,7 +1261,7 @@ def softmax_step( 5. Computing row sums for normalization 6. Coordinating pipeline synchronization between different processing stages """ - tilePlikeFP32 = self.mma_tiler_qk[1] // cutlass.Float32.width * self.v_dtype.width + tilePlikeFP32 = self.mma_tiler_qk[1] // Float32.width * self.v_dtype.width tScS = thr_mma_qk.partition_C(cute.make_identity_tensor((self.mma_tiler_qk[0], self.mma_tiler_qk[1]))) tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((self.m_block_size, 1))) tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) @@ -1305,7 +1278,7 @@ def softmax_step( mask_fn(tSrS_t2r, n_block=n_block) row_max, acc_scale = softmax.update_row_max(tSrS_t2r.load(), is_first) - # tSrScale_r2t = cute.make_fragment(thr_tmem_store_scale.partition_S(tScS_vec).shape, cutlass.Float32) + # tSrScale_r2t = cute.make_fragment(thr_tmem_store_scale.partition_S(tScS_vec).shape, Float32) # tSrScale_r2t[0] = acc_scale # cute.copy(thr_tmem_store_scale, tSrScale_r2t, tStScale_r2t) # cute.arch.fence_view_async_tmem_store() @@ -1322,7 +1295,7 @@ def softmax_step( # Sequence barrier wait if const_expr(self.s0_s1_barrier): cute.arch.mbarrier_wait(mbar_ptr + mbar_s0_s1_sequence_offset + stage * 4, s0_s1_sequence_phase) - tSrP_r2t_f32 = cute.make_fragment(thr_tmem_store.partition_S(tScP).shape, cutlass.Float32) + tSrP_r2t_f32 = cute.make_fragment(thr_tmem_store.partition_S(tScP).shape, Float32) tSrP_r2t = cute.make_tensor( cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), tSrS_t2r.layout, ) @@ -1362,10 +1335,10 @@ def correction_loop( sO: cute.Tensor, tma_atom_O: cute.CopyAtom, mbar_ptr: cute.Pointer, - softmax_scale_log2: cutlass.Float32, - tile_sched_params, + softmax_scale_log2: Float32, block_info: BlockInfo, SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, ): tScS = thr_mma_qk.partition_C(cute.make_identity_tensor((self.mma_tiler_qk[0], self.mma_tiler_qk[1]))) tStS_scale_layout = cute.composition(tStS.layout, cute.make_layout((self.m_block_size, 1))) @@ -1391,11 +1364,11 @@ def correction_loop( cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + 0) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + 1) - softmax_corr_consumer_phase = cutlass.Int32(0) - o_corr_consumer_phase = cutlass.Int32(0) - corr_epi_producer_phase = cutlass.Int32(1) + softmax_corr_consumer_phase = Int32(0) + o_corr_consumer_phase = Int32(0) + corr_epi_producer_phase = Int32(1) - tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params) + tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx = work_tile.tile_idx @@ -1408,7 +1381,7 @@ def correction_loop( cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_full_offset + 1, softmax_corr_consumer_phase) softmax_corr_consumer_phase ^= 1 - tSrScale_t2r = cute.make_fragment(tSrScale_t2r_shape, cutlass.Float32) + tSrScale_t2r = cute.make_fragment(tSrScale_t2r_shape, Float32) for i in cutlass.range(n_block_max - n_block_min - 1, unroll=1): for stage in cutlass.range_constexpr(2): # wait for S0 / S1 @@ -1471,7 +1444,7 @@ def correction_loop( LN2 = math.log(2.0) lse = ( (row_max * softmax_scale_log2 + utils.log2f(row_sum)) * LN2 - if not acc_O_mn_row_is_zero_or_nan else -cutlass.Float32.inf + if not acc_O_mn_row_is_zero_or_nan else -Float32.inf ) if tidx < seqlen.seqlen_q - (m_block * 2 + stage) * self.m_block_size: gLSE[tidx + stage * self.m_block_size] = lse @@ -1512,8 +1485,8 @@ def correction_rescale( self, thr_mma: cute.core.ThrMma, tOtO: cute.Tensor, - thread_idx: cutlass.Int32, - scale: cutlass.Float32, + thread_idx: Int32, + scale: Float32, ): """Rescale intermediate attention results based on softmax normalization factor. @@ -1575,8 +1548,8 @@ def correction_epilogue( self, thr_mma: cute.core.ThrMma, tOtO: cute.Tensor, - thread_idx: cutlass.Int32, - scale: cutlass.Float32, + thread_idx: Int32, + scale: Float32, sO: cute.Tensor, ): """Apply final scaling and transformation to attention output before writing to global memory. @@ -1597,7 +1570,7 @@ def correction_epilogue( :param tOtO: Tensor containing accumulated attention output :type tOtO: cute.Tensor :param scale: Final scaling factor to apply to the output - :type scale: cutlass.Float32 + :type scale: Float32 :param sO: Shared memory tensor for the final output :type sO: cute.Tensor """ @@ -1659,15 +1632,16 @@ def correction_epilogue( @cute.jit def epilogue_s2g( self, - tile_scheduler, mO: cute.Tensor, sO: cute.Tensor, gmem_tiled_copy_O: cute.TiledCopy, tma_atom_O: Optional[cute.CopyAtom], mbar_ptr: cute.Pointer, SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, ): - epi_consumer_phase = cutlass.Int32(0) + epi_consumer_phase = Int32(0) + tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx = work_tile.tile_idx @@ -1736,7 +1710,7 @@ def load_K( tKgK: cute.Tensor, tKsK: cute.Tensor, pipeline: cutlass.pipeline.PipelineAsync, - block: cutlass.Int32, + block: Int32, producer_state: cutlass.pipeline.PipelineState, ): pipeline.producer_acquire(producer_state) From 3d0e14a79b3890b5f874f397aa64cb03fe061322 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 13 Jul 2025 17:08:18 -0400 Subject: [PATCH 037/258] [Cute] Support varlen in flash_fwd_sm100 --- flash_attn/cute/flash_fwd_sm100.py | 85 +++++++++++++++++------------- flash_attn/cute/interface.py | 12 ++++- tests/cute/test_flash_attn.py | 15 ++++-- 3 files changed, 68 insertions(+), 44 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index a3380fedd2d..46c1a1c93d3 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -3,9 +3,9 @@ # - noncausal & causal attention # - MHA, GQA, MQA # - hdim 64, 96, 128. +# - varlen # - sliding window # Unsupported features that will be added later: -# - varlen # - split-kv (optimizing for inference) # - more hdim (192, 256) # Based on the cutlass example and cute-dsl example: @@ -210,7 +210,8 @@ def __call__( LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) if const_expr(mLSE is not None) else None # (s, d, h, b) -> (d, s, h, b) - mV = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=[1, 0, 2, 3])) + V_layout_transpose = [1, 0, 2, 3] if const_expr(mCuSeqlensK is None) else [1, 0, 2] + mV = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=V_layout_transpose)) self.q_major_mode = cutlass.utils.LayoutEnum.from_tensor(mQ).mma_major_mode() self.k_major_mode = cutlass.utils.LayoutEnum.from_tensor(mK).mma_major_mode() @@ -796,36 +797,6 @@ def load( SeqlenInfoCls: Callable, TileSchedulerCls: Callable, ): - # (bM, bK, loopM, loopL) - gQ_qdhb = cute.local_tile(mQ, cute.select(self.mma_tiler_qk, mode=[0, 2]), (None, 0, None, None)) - tSgQ_qdhb = thr_mma_qk.partition_A(gQ_qdhb) - # (bN, bK, loopN, loopL) - gK_kdhb = cute.local_tile(mK, cute.select(self.mma_tiler_qk, mode=[1, 2]), (None, 0, None, None)) - tSgK_kdhb = thr_mma_qk.partition_B(gK_kdhb) - # (bK, bN, loopN, loopL) - gV_dkhb = cute.local_tile(mV, cute.select(self.pv_mma_tiler, mode=[1, 2]), (0, None, None, None)) - tOgV_dkhb = thr_mma_pv.partition_B(gV_dkhb) - tQsQ, tQgQ_qdhb = cpasync.tma_partition( - tma_atom_Q, - 0, # no multicast - cute.make_layout(1), - cute.group_modes(sQ, 0, 3), - cute.group_modes(tSgQ_qdhb, 0, 3), - ) - tKsK, tKgK_kdhb = cpasync.tma_partition( - tma_atom_K, - 0, # no multicast - cute.make_layout(1), - cute.group_modes(sK, 0, 3), - cute.group_modes(tSgK_kdhb, 0, 3), - ) - tVsV, tVgV_dkl = cpasync.tma_partition( - tma_atom_V, - 0, # no multicast - cute.make_layout(1), - cute.group_modes(sV, 0, 3), - cute.group_modes(tOgV_dkhb, 0, 3), - ) q_producer_phase = Int32(1) kv_producer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, self.kv_stage) @@ -833,9 +804,46 @@ def load( work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx = work_tile.tile_idx - tQgQ = tQgQ_qdhb[None, None, head_idx, batch_idx] - head_idx_kv = head_idx // self.qhead_per_kvhead - tKgK, tVgV = [t[None, None, head_idx_kv, batch_idx] for t in (tKgK_kdhb, tVgV_dkl)] + seqlen = SeqlenInfoCls(batch_idx) + if const_expr(not seqlen.has_cu_seqlens_q): + mQ_cur = mQ[None, None, head_idx, batch_idx] + else: + offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) + mQ_cur = cute.domain_offset((offset, 0), mQ[None, None, head_idx]) + head_idx_kv = head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx + if const_expr(not seqlen.has_cu_seqlens_k): + mK_cur, mV_cur = [t[None, None, head_idx_kv, batch_idx] for t in (mK, mV)] + else: + mK_cur = cute.domain_offset((seqlen.offset_k, 0), mK[None, None, head_idx_kv]) + mV_cur = cute.domain_offset((0, seqlen.offset_k), mV[None, None, head_idx_kv]) + + gQ = cute.local_tile(mQ_cur, cute.select(self.mma_tiler_qk, mode=[0, 2]), (None, 0)) + tSgQ = thr_mma_qk.partition_A(gQ) + gK = cute.local_tile(mK_cur, cute.select(self.mma_tiler_qk, mode=[1, 2]), (None, 0)) + tSgK = thr_mma_qk.partition_B(gK) + gV = cute.local_tile(mV_cur, cute.select(self.pv_mma_tiler, mode=[1, 2]), (0, None)) + tOgV = thr_mma_pv.partition_B(gV) + tQsQ, tQgQ = cpasync.tma_partition( + tma_atom_Q, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sQ, 0, 3), + cute.group_modes(tSgQ, 0, 3), + ) + tKsK, tKgK = cpasync.tma_partition( + tma_atom_K, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK, 0, 3), + ) + tVsV, tVgV = cpasync.tma_partition( + tma_atom_V, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sV, 0, 3), + cute.group_modes(tOgV, 0, 3), + ) def load_Q(stage: int): cute.arch.mbarrier_wait(mbar_ptr + self.mbar_load_q_empty_offset + stage, q_producer_phase) @@ -851,7 +859,6 @@ def load_Q(stage: int): load_K = partial(self.load_K, tma_atom_K, tKgK, tKsK, pipeline_kv) load_V = partial(self.load_K, tma_atom_V, tVgV, tVsV, pipeline_kv) - seqlen = SeqlenInfoCls(batch_idx) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) load_Q(0) # Q0 load_K(n_block_max - 1, kv_producer_state) # K0 @@ -1435,7 +1442,8 @@ def correction_loop( if const_expr(not seqlen.has_cu_seqlens_q): mLSE_cur = mLSE[None, head_idx, batch_idx] else: - mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[None, head_idx]) + offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) + mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx]) gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block * 2,)) for stage in cutlass.range_constexpr(2): row_sum, row_max, acc_O_mn_row_is_zero_or_nan = stats[stage] @@ -1649,7 +1657,8 @@ def epilogue_s2g( if const_expr(not seqlen.has_cu_seqlens_q): mO_cur = mO[None, None, head_idx, batch_idx] else: - mO_cur = cute.domain_offset((seqlen.offset_q, 0), mO[None, None, head_idx]) + offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) + mO_cur = cute.domain_offset((offset, 0), mO[None, None, head_idx]) gO = cute.local_tile(mO_cur, (self.m_block_size, self.head_dim_v_padded), (None, 0)) if const_expr(self.use_tma_O): tOsO, tOgO = cpasync.tma_partition( diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 5816714a520..816df0e1cc7 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -1,12 +1,22 @@ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. # [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'd need to install nvidia-cutlass-dsl==4.1.0.dev0. -# Features not supported yet: + +# Supported features: +# - BF16 & FP16 dtype +# - noncausal & causal attention +# - MHA, GQA, MQA +# - hdim 64, 96, 128. # - varlen +# - sliding window +# - bwd pass for Ampere (will also run on Hopper/Blackwell, but will be slow) + +# Features not supported yet: # - split (i.e. FlashDecoding) # - tuned block sizes # - paged KV # - append KV to existing KV cache # - FP8 +# - bwd pass optimized for Hopper/Blackwell import math from typing import Optional, Tuple diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 16a1c3fa65c..fed0f365d47 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -238,10 +238,10 @@ def test_flash_attn_output( @pytest.mark.parametrize("deterministic", [False]) # @pytest.mark.parametrize("softcap", [0.0, 15.0]) @pytest.mark.parametrize("softcap", [0.0]) -# @pytest.mark.parametrize("local", [False, True]) -@pytest.mark.parametrize("local", [False]) -# @pytest.mark.parametrize("causal", [False, True]) -@pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) # @pytest.mark.parametrize("add_unused_qkv", [False, True]) @pytest.mark.parametrize("add_unused_qkv", [False]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) @@ -279,6 +279,8 @@ def test_flash_attn_output( def test_flash_attn_varlen_output( seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype ): + if (causal or local): # Right now we only support causal attention with seqlen_k == seqlen_q + seqlen_k = seqlen_q device = "cuda" # set seed torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) @@ -306,7 +308,7 @@ def test_flash_attn_varlen_output( else: qv_ref = None # Put window_size after QKV randn so that window_size changes from test to test - window_size = (None, None) if not local else torch.randint(0, seqlen_k, (2,)) + window_size = (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() if dtype == torch.float8_e4m3fn: q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] else: @@ -343,6 +345,9 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device ) + if causal or local: + key_padding_mask = query_padding_mask + ( q_unpad, k_unpad, From 730e2309b8a2feaf9542dc5e55be62c739e611c1 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 13 Jul 2025 17:16:55 -0400 Subject: [PATCH 038/258] [Cute] Don't need max_seqlen_q for varlen fwd anymore --- flash_attn/cute/flash_fwd.py | 1 - flash_attn/cute/flash_fwd_sm100.py | 1 - flash_attn/cute/interface.py | 14 +++----------- tests/cute/test_flash_attn.py | 1 - 4 files changed, 3 insertions(+), 14 deletions(-) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index bc4b29b97c1..0226dfffaa9 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -1064,7 +1064,6 @@ def __call__( mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, - max_seqlen_q: Optional[cutlass.Int32] = None, softcap: cutlass.Float32 | float | None = None, window_size_left: cutlass.Int32 | int | None = None, window_size_right: cutlass.Int32 | int | None = None, diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 46c1a1c93d3..001048f3c8c 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -173,7 +173,6 @@ def __call__( mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, - max_seqlen_q: Optional[Int32] = None, softcap: Float32 | float | None = None, window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 816df0e1cc7..8ede8958dbe 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -56,7 +56,6 @@ def _flash_attn_fwd( cu_seqlens_k: Optional[torch.Tensor] = None, seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, - max_seqlen_q: Optional[int] = None, softmax_scale: Optional[float] = None, causal: bool = False, softcap: Optional[float] = None, @@ -77,7 +76,7 @@ def _flash_attn_fwd( total_q = batch_size * seqlen_q else: batch_size = cu_seqlens_q.shape[0] - 1 - seqlen_q = max_seqlen_q + seqlen_q = None total_q = q.shape[0] seqlen_k, num_head_kv, _ = k.shape[-3:] head_dim_v = v.shape[-1] @@ -89,7 +88,6 @@ def _flash_attn_fwd( assert v.shape == (seqlen_k, num_head_kv, head_dim_v) assert cu_seqlens_k.shape == (batch_size + 1,), "cu_seqlens_k must have shape (batch_size + 1,)" if cu_seqlens_q is not None: - assert max_seqlen_q is not None, "max_seqlen_q must be provided if cu_seqlens_q is provided" assert cu_seqlens_q.shape == (batch_size + 1,), "cu_seqlens_q must have shape (batch_size + 1,)" assert seqused_q is None or seqused_q.shape == (batch_size,), "seqused_q must have shape (batch_size,)" assert seqused_k is None or seqused_k.shape == (batch_size,), "seqused_k must have shape (batch_size,)" @@ -130,7 +128,6 @@ def _flash_attn_fwd( from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) if t is not None else None for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) ] - max_seqlen_q = cutlass.Int32(max_seqlen_q) if max_seqlen_q is not None else None if causal: window_size_right = 0 local = window_size_left is not None or window_size_right is not None @@ -187,12 +184,12 @@ def _flash_attn_fwd( _flash_attn_fwd.compile_cache[compile_key] = cute.compile( fa_fwd, q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream, cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, - max_seqlen_q, softcap, window_size_left, window_size_right, + softcap, window_size_left, window_size_right, ) _flash_attn_fwd.compile_cache[compile_key]( q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream, cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, - max_seqlen_q, softcap, window_size_left, window_size_right, + softcap, window_size_left, window_size_right, ) return out, lse @@ -444,7 +441,6 @@ def forward( cu_seqlens_k: Optional[torch.Tensor], seqused_q: Optional[torch.Tensor], seqused_k: Optional[torch.Tensor], - max_seqlen_q: Optional[int], softmax_scale: Optional[float] = None, causal: bool = False, window_size: Tuple[Optional[int], Optional[int]] = (None, None), @@ -458,7 +454,6 @@ def forward( cu_seqlens_k, seqused_q, seqused_k, - max_seqlen_q, softmax_scale=softmax_scale, causal=causal, window_size_left=window_size[0], @@ -466,7 +461,6 @@ def forward( softcap=softcap, ) ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) - ctx.max_seqlen_q = max_seqlen_q ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size @@ -509,7 +503,6 @@ def flash_attn_varlen_func( cu_seqlens_k: Optional[torch.Tensor] = None, seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, - max_seqlen_q: Optional[int] = None, softmax_scale: Optional[float] = None, causal: bool = False, window_size: Tuple[Optional[int], Optional[int]] = (None, None), @@ -523,7 +516,6 @@ def flash_attn_varlen_func( cu_seqlens_k, seqused_q, seqused_k, - max_seqlen_q, softmax_scale, causal, window_size, diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index fed0f365d47..f1e6f85e7ff 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -423,7 +423,6 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # max_seqlen_k, seqused_q=seqused_q, seqused_k=seqused_k, - max_seqlen_q=max_seqlen_q, causal=causal, # qv=qv_unpad, # q_descale=q_descale, From 10ee063e407035acc1719c5f980e2a62c2531242 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 13 Jul 2025 19:19:50 -0400 Subject: [PATCH 039/258] [Cute] Fix varlen scheduler when SeqUsedQ is not passed in --- benchmarks/benchmark_attn.py | 5 ++++- flash_attn/cute/tile_scheduler.py | 5 ++--- tests/cute/test_flash_attn.py | 4 ++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/benchmarks/benchmark_attn.py b/benchmarks/benchmark_attn.py index 8d4a5c0c0f7..b68220e5e47 100644 --- a/benchmarks/benchmark_attn.py +++ b/benchmarks/benchmark_attn.py @@ -359,7 +359,10 @@ def run(*args, **kwargs): # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits) time_f[(causal, headdim, batch_size, seqlen), "Flash3"] = m1.mean if flash_attn_func_python is not None: - m1_py = time_fwd(flash_attn_func_python, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav3 python') + if not varlen: + m1_py = time_fwd(flash_attn_func_python, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav3 python') + else: + m1_py = time_fwd(flash_attn_varlen_func_python, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav3 python') if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func_v3 is not None and has_backward: time.sleep(1) if not varlen: diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index e0bf202f022..ee64cbe7657 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -480,10 +480,9 @@ def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32: else: assert self.mCuSeqlensQ is not None cur_cu_seqlen = Int32(0) - if batch_idx < self.num_batch: + if batch_idx <= self.num_batch: cur_cu_seqlen = self.mCuSeqlensQ[batch_idx] - # Very important that we set mask_and_clamp to 0 - next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1, mask_and_clamp=0) + next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1) seqlen = next_cu_seqlen - cur_cu_seqlen if cutlass.const_expr(self.qhead_per_kvhead_packgqa > 1): seqlen *= self.qhead_per_kvhead_packgqa diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index f1e6f85e7ff..848c68eb8a1 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -421,8 +421,8 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, # max_seqlen_k, - seqused_q=seqused_q, - seqused_k=seqused_k, + # seqused_q=seqused_q, + # seqused_k=seqused_k, causal=causal, # qv=qv_unpad, # q_descale=q_descale, From c5b0c631074e4c8d53fdebea2d71ea621baf9344 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 13 Jul 2025 22:55:54 -0400 Subject: [PATCH 040/258] [Cute] Use LPT for SingleTileVarlenScheduler --- benchmarks/benchmark_attn.py | 4 ++- flash_attn/cute/flash_fwd.py | 1 + flash_attn/cute/flash_fwd_sm100.py | 1 + flash_attn/cute/tile_scheduler.py | 39 ++++++++++++++++++++++++++++-- 4 files changed, 42 insertions(+), 3 deletions(-) diff --git a/benchmarks/benchmark_attn.py b/benchmarks/benchmark_attn.py index b68220e5e47..b08a9c84dcf 100644 --- a/benchmarks/benchmark_attn.py +++ b/benchmarks/benchmark_attn.py @@ -27,8 +27,10 @@ from flash_attn.cute.interface import flash_attn_varlen_func as flash_attn_varlen_func_python try: from flash_attn_interface import flash_attn_func as flash_attn_func_v3 + from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3 except ImportError: flash_attn_func_v3 = None + flash_attn_varlen_func_v3 = None if torch.cuda.get_device_capability()[0] != 9: flash_attn_func_v3 = None @@ -355,7 +357,7 @@ def run(*args, **kwargs): m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size_fa, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') # pytorch_profiler(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa) else: - m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') + m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size_fa, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits) time_f[(causal, headdim, batch_size, seqlen), "Flash3"] = m1.mean if flash_attn_func_python is not None: diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 0226dfffaa9..3c0651f7893 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -1165,6 +1165,7 @@ def __call__( qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, element_size=self.dtype.width // 8, is_persistent=False, + lpt=self.is_causal or self.is_local, ) tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) grid_dim = TileScheduler.get_grid_shape(tile_sched_params) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 001048f3c8c..dfac68787d2 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -365,6 +365,7 @@ def __call__( qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, element_size=self.k_dtype.width // 8, is_persistent=self.is_persistent, + lpt=self.is_causal or self.is_local, ) tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) self.tile_scheduler_cls = TileScheduler diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index ee64cbe7657..c7fad36b22a 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -50,6 +50,7 @@ class TileSchedulerArguments(ParamsBase): qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 element_size: cutlass.Constexpr[int] = 2 is_persistent: cutlass.Constexpr[bool] = False + lpt: cutlass.Constexpr[bool] = False class SingleTileScheduler: @@ -391,41 +392,50 @@ class Params(ParamsBase): num_head: Int32 num_batch: Int32 total_q: Int32 + max_kvblock_in_l2: Int32 block_size: cutlass.Constexpr[int] mCuSeqlensQ: Optional[cute.Tensor] = None mSeqUsedQ: Optional[cute.Tensor] = None qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + lpt: cutlass.Constexpr[bool] = False @staticmethod @cute.jit def create( args: TileSchedulerArguments, *, loc=None, ip=None ) -> "SingleTileVarlenScheduler.Params": + size_l2 = 50 * 1024 * 1024 # 50 MB for K & V + max_kvblock_in_l2 = size_l2 // ((args.headdim + args.headdim_v) * args.element_size * args.block_size) return SingleTileVarlenScheduler.Params( num_head=args.num_head, num_batch=args.num_batch, total_q=args.total_q, + max_kvblock_in_l2=max_kvblock_in_l2, block_size=args.block_size, mCuSeqlensQ=args.mCuSeqlensQ, mSeqUsedQ=args.mSeqUsedQ, qhead_per_kvhead_packgqa=args.qhead_per_kvhead_packgqa, + lpt=args.lpt, ) def __init__( self, num_head: Int32, num_batch: Int32, + max_kvblock_in_l2: Int32, tile_idx: Int32, mCuSeqlensQ: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, block_size: cutlass.Constexpr[int] = 128, qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1, + lpt: cutlass.Constexpr[bool] = False, *, loc=None, ip=None, ): self.num_head = num_head self.num_batch = num_batch + self.max_kvblock_in_l2 = max_kvblock_in_l2 self.mCuSeqlensQ = mCuSeqlensQ self.mSeqUsedQ = mSeqUsedQ assert self.mCuSeqlensQ is not None or self.mSeqUsedQ is not None, ( @@ -433,6 +443,7 @@ def __init__( ) self.block_size = block_size self.qhead_per_kvhead_packgqa = qhead_per_kvhead_packgqa + self.lpt = lpt self._tile_idx = tile_idx self._is_first_block = True self._loc = loc @@ -448,11 +459,13 @@ def create(params: Params, *, loc=None, ip=None) -> "SingleTileVarlenScheduler": return SingleTileVarlenScheduler( params.num_head, params.num_batch, + params.max_kvblock_in_l2, tile_idx, mCuSeqlensQ=params.mCuSeqlensQ, mSeqUsedQ=params.mSeqUsedQ, block_size=params.block_size, qhead_per_kvhead_packgqa=params.qhead_per_kvhead_packgqa, + lpt=params.lpt, loc=loc, ip=ip, ) @@ -537,8 +550,27 @@ def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: ) num_m_blocks = cute.arch.shuffle_sync(num_m_blocks, batch_idx_in_group) mh_block = next_tile_idx - group_start_tile - num_m_blocks_prev_lane * self.num_head - head_idx = mh_block // num_m_blocks - block = mh_block - head_idx * num_m_blocks + if cutlass.const_expr(self.lpt): + # This is a version of the SingleTileLPTScheduler, complicated by the fact that + # the seqlen can vary per batch. + # TODO: is there any case where num_m_blocks is 0? + # TODO: by right we should read the seqlen_kv but we're assuming seqlen_q == seqlen_k here + # nheads_in_l2 = min(max(self.max_kvblock_in_l2 // num_m_blocks, 1), self.num_head) + # Seems faster to have this be a power of 2 + nheads_in_l2 = 16 if num_m_blocks * 16 <= self.max_kvblock_in_l2 else (8 if num_m_blocks * 8 <= self.max_kvblock_in_l2 else (4 if num_m_blocks * 4 <= self.max_kvblock_in_l2 else (2 if num_m_blocks * 2 <= self.max_kvblock_in_l2 else 1))) + nheads_in_l2 = min(nheads_in_l2, self.num_head) + mh_in_l2 = nheads_in_l2 * num_m_blocks + section_idx = mh_block // mh_in_l2 + l2_mod = mh_block - section_idx * mh_in_l2 + # Deal with tail section + nheads_in_this_section = nheads_in_l2 if nheads_in_l2 * (section_idx + 1) <= self.num_head else self.num_head - section_idx * nheads_in_l2 + block = l2_mod // nheads_in_this_section + head_idx_residual = l2_mod - block * nheads_in_this_section + head_idx = section_idx * nheads_in_l2 + head_idx_residual + block = num_m_blocks - 1 - block + else: + head_idx = mh_block // num_m_blocks + block = mh_block - head_idx * num_m_blocks is_valid = self._is_first_block and batch_idx < self.num_batch # if cute.arch.thread_idx()[0] == 128: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, head_idx=%d, block=%d, is_valid = %d", self._tile_idx, batch_idx, head_idx, block, is_valid) return cutlass.utils.WorkTileInfo( @@ -560,6 +592,7 @@ def __extract_mlir_values__(self): for obj in [ self.num_head, self.num_batch, + self.max_kvblock_in_l2, self._tile_idx, self.mCuSeqlensQ, self.mSeqUsedQ, @@ -575,6 +608,7 @@ def __new_from_mlir_values__(self, values): [ self.num_head, self.num_batch, + self.max_kvblock_in_l2, self._tile_idx, self.mCuSeqlensQ, self.mSeqUsedQ, @@ -587,5 +621,6 @@ def __new_from_mlir_values__(self, values): *(tuple(obj_list)), block_size=self.block_size, qhead_per_kvhead_packgqa=self.qhead_per_kvhead_packgqa, + lpt=self.lpt, loc=self._loc, ) From bac1001e4f6caa09d70537495d6746a685a2fa78 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 15 Jul 2025 00:51:41 -0400 Subject: [PATCH 041/258] [Cute] Use bit manipulation for masking in sm100 --- flash_attn/cute/flash_fwd_sm100.py | 28 +++++++------ flash_attn/cute/mask.py | 65 ++++++++++++++++++++++++------ flash_attn/cute/softmax.py | 1 + flash_attn/cute/utils.py | 11 +++-- 4 files changed, 74 insertions(+), 31 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index dfac68787d2..a08871637b7 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -127,19 +127,21 @@ def __init__( self.tmem_vec0_offset = 0 self.tmem_vec1_offset = self.tmem_vec0_offset + self.n_block_size - # self.num_regs_softmax = 192 - # self.num_regs_softmax = 184 - self.num_regs_softmax = 176 - # self.num_regs_correction = 104 - # self.num_regs_correction = 96 - # self.num_regs_correction = 80 - self.num_regs_correction = 64 if self.is_causal or self.is_local else 80 - # self.num_regs_other = 24 - # self.num_regs_other = 32 - # self.num_regs_other = 64 - # self.num_regs_other = 80 - self.num_regs_other = 96 if self.is_causal or self.is_local else 80 - # self.num_regs_other = 48 + if self.head_dim_padded < 96: + self.num_regs_softmax = 192 + self.num_regs_correction = 64 + self.num_regs_other = 64 + else: + # self.num_regs_softmax = 184 + self.num_regs_softmax = 176 + # self.num_regs_correction = 96 + # self.num_regs_correction = 80 + self.num_regs_correction = 64 if self.is_causal or self.is_local else 80 + # self.num_regs_other = 32 + # self.num_regs_other = 64 + # self.num_regs_other = 80 + # self.num_regs_other = 48 + self.num_regs_other = 96 if self.is_causal or self.is_local else 80 self.num_regs_empty = 24 self.buffer_align_bytes = 1024 diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index 89ce612c6ec..ab795c15da0 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -135,13 +135,39 @@ def apply_mask_sm100( seqlenk_col_limit = self.seqlen_k - n_block * self.n_block_size if cutlass.const_expr(not mask_causal and not mask_local): if cutlass.const_expr(mask_seqlen): - for i in cutlass.range_constexpr(cute.size(tScS_t2r.shape)): - # if tScS_t2r[i][1] >= seqlenk_col_limit: - # acc_S[i] = -cutlass.Float32.inf - # For some reason the 2 lines above generate really bad SASS - acc_S[i] = ( - -cutlass.Float32.inf if tScS_t2r[i][1] >= seqlenk_col_limit else acc_S[i] - ) + ncol = cutlass.const_expr(cute.size(tScS_t2r.shape)) + if cutlass.const_expr(not ncol % 16 == 0): + for i in cutlass.range_constexpr(ncol): + # if tScS_t2r[i][1] >= seqlenk_col_limit: + # acc_S[i] = -cutlass.Float32.inf + # For some reason the 2 lines above generate really bad SASS + acc_S[i] = ( + -cutlass.Float32.inf if tScS_t2r[i][1] >= seqlenk_col_limit else acc_S[i] + ) + else: + # Bit manipulation, compiles down to the R2P instruction + # We know that tScS_t2r[i][1] == i, for the particular tmem copy atom we're using + # Ideally we'd move by 32 instead of 16, but mask >> i isn't correct for i == 31 + # (see below). + for s in cutlass.range_constexpr(ncol // 16): + col_limit_right_s = seqlenk_col_limit - s * 16 + # Don't need to clamp to 32 since the shr.u32 instruction does that already + col_limit_right_cur = cutlass.Uint32(max(col_limit_right_s, 0)) + # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11 + mask = cutlass.Uint32((1 << col_limit_right_cur) - 1) + # if tidx == 0: cute.printf("mask = 0x%x, col_limit_right_s = %d, col_limit_right_cur = %d", mask, col_limit_right_s, col_limit_right_cur) + for i in cutlass.range_constexpr(16): + # mask >> i does not produce correct result for 0b11..11 >> 31 + # However, if we use utils.shr_u32, the compiler doesn't generate + # the R2P instruction, so it's slower. + # Instead we just move by 16 instead of 32. + mask_i_bit = cutlass.Boolean((mask >> i) & 1) + # mask_i_bit = cutlass.Boolean(utils.shr_u32(mask, i) & 1) + # if tidx == 0: cute.printf("mask_i_bit = %d, after shift = 0x%x, i = %d, s = %d", mask_i_bit, utils.shr_u32(mask, i), i, s) + acc_S[s * 16 + i] = acc_S[s * 16 + i] if mask_i_bit else -cutlass.Float32.inf + # This is the equivalent of: + # acc_S[s * 16 + i] = acc_S[s * 16 + i] if col_limit_right_s <= i else -cutlass.Float32.inf + # if tidx == 0: cute.print_tensor(acc_S) else: # Causal or local causal_row_offset = 1 + self.seqlen_k - n_block * self.n_block_size - self.seqlen_q row_idx = tScS_t2r[0][0] + m_block * self.m_block_size @@ -153,11 +179,26 @@ def apply_mask_sm100( col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) # if cute.arch.thread_idx()[0] % 32 == 0: # cute.printf("tidx = %d, tidx tmem = %d, row_idx = %d, col_limit_right = %d, causal_row_offset = %d\n", cute.arch.thread_idx()[0], thr_tmem_load.thr_idx, row_idx, col_limit_right, causal_row_offset) - for i in cutlass.range_constexpr(cute.size(tScS_t2r.shape)): - acc_S[i] = ( - -cutlass.Float32.inf if tScS_t2r[i][1] >= col_limit_right else acc_S[i] - ) - + ncol = cutlass.const_expr(cute.size(tScS_t2r.shape)) + if cutlass.const_expr(not ncol % 16 == 0): + for i in cutlass.range_constexpr(ncol): + acc_S[i] = ( + -cutlass.Float32.inf if tScS_t2r[i][1] >= col_limit_right else acc_S[i] + ) + else: + # Bit manipulation, compiles down to the R2P instruction + # We know that tScS_t2r[i][1] == i, for the particular tmem copy atom we're using + for s in cutlass.range_constexpr(ncol // 16): + col_limit_right_s = col_limit_right - s * 16 + col_limit_right_cur = cutlass.Uint32(max(col_limit_right_s, 0)) + # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11 + mask = cutlass.Uint32((1 << col_limit_right_cur) - 1) + for i in cutlass.range_constexpr(16): + # mask_i_bit = cutlass.Boolean(utils.shr_u32(mask, i) & 1) + mask_i_bit = cutlass.Boolean((mask >> i) & 1) + acc_S[s * 16 + i] = acc_S[s * 16 + i] if mask_i_bit else -cutlass.Float32.inf + # This is the equivalent of: + # acc_S[s * 16 + i] = acc_S[s * 16 + i] if col_limit_right_s <= i else -cutlass.Float32.inf else: local_row_offset_right = ( causal_row_offset + self.window_size_right diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index dfbfa708fc8..bf98cf9126e 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -194,6 +194,7 @@ def apply_exp2_convert( acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type) ) + @cute.jit def scale_apply_exp2_convert( self, acc_S_row: cute.Tensor, diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 4b2fe92bac5..df6ad0fe3b3 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -443,14 +443,13 @@ def shuffle_sync( @dsl_user_op -def noop_asm(val: cutlass.Int32, *, loc=None, ip=None) -> cute.Numeric: - assert val.width == 32, "noop_asm only supports 32-bit types" - return type(val)( +def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32: + return cutlass.Uint32( llvm.inline_asm( T.i32(), - [cutlass.Int32(val).ir_value(loc=loc, ip=ip)], - "mov.b32 $0, $1;", - "=r,r", + [cutlass.Uint32(val).ir_value(loc=loc, ip=ip), cutlass.Uint32(shift).ir_value(loc=loc, ip=ip)], + "shr.s32 $0, $1, $2;", + "=r,r,r", has_side_effects=False, is_align_stack=False, asm_dialect=llvm.AsmDialect.AD_ATT, From b959a98990035f09cf366ab3f043166def55571c Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 15 Jul 2025 01:27:38 -0400 Subject: [PATCH 042/258] [Cute] Don't need a separate masking iter if causal for fwd_sm100 --- flash_attn/cute/flash_fwd_sm100.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index a08871637b7..c887e6eee4d 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -87,7 +87,8 @@ def __init__( self.is_local = is_local self.qhead_per_kvhead = qhead_per_kvhead self.pack_gqa = False - self.s0_s1_barrier = self.head_dim_padded in [64, 96] # Does S1 need to wait for S0 to finish + # Does S1 need to wait for S0 to finish + self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local) self.softmax0_warp_ids = (0, 1, 2, 3) self.softmax1_warp_ids = (4, 5, 6, 7) @@ -1170,17 +1171,20 @@ def softmax_loop( cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase) si_corr_producer_phase ^= 1 - # 1 masking iter - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block_max - 1, is_first=True, mask_fn=partial(mask_fn, mask_seqlen=True)) - n_block_max -= 1 - # Next couple of iterations with causal masking - if const_expr(self.is_causal or self.is_local): + if const_expr(not (self.is_causal or self.is_local)): + # 1 masking iter + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block_max - 1, is_first=True, mask_fn=partial(mask_fn, mask_seqlen=True)) + n_block_max -= 1 + else: + # Next couple of iterations with causal masking + # Careful, we're not setting is_first=True for any iteration here. + # Currently this doesn't matter, but we might change the synchronization later n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( seqlen, m_block, n_block_min ) for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1): n_block = n_block_max - 1 - n_tile - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=partial(mask_fn, mask_seqlen=False)) + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=partial(mask_fn, mask_seqlen=True)) n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) # The remaining iterations have no masking n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( @@ -1194,7 +1198,8 @@ def softmax_loop( n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) for n_tile in cutlass.range(0, n_block_max - n_block_min, unroll=1): n_block = n_block_max - 1 - n_tile - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=partial(mask_fn, mask_seqlen=False)) + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=partial(mask_fn, mask_seqlen=True)) + # Now that we no longer already have the 1st iteration, need mask_seqlen=True here # tSrScale_r2t = cute.make_fragment(tSrScale_r2t_shape, Float32) # tSrScale_r2t[0] = softmax.row_sum[0] From ed6964c01298105732b6a6b8e8693223939a0494 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 15 Jul 2025 01:32:45 -0400 Subject: [PATCH 043/258] [Cute] Back to having a separate iteration with masking a couple of failing varlen tests if we don't have that, will investigate later --- flash_attn/cute/flash_fwd_sm100.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index c887e6eee4d..b2b6c6c58ed 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -1171,20 +1171,17 @@ def softmax_loop( cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase) si_corr_producer_phase ^= 1 - if const_expr(not (self.is_causal or self.is_local)): - # 1 masking iter - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block_max - 1, is_first=True, mask_fn=partial(mask_fn, mask_seqlen=True)) - n_block_max -= 1 - else: - # Next couple of iterations with causal masking - # Careful, we're not setting is_first=True for any iteration here. - # Currently this doesn't matter, but we might change the synchronization later + # 1 masking iter + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block_max - 1, is_first=True, mask_fn=partial(mask_fn, mask_seqlen=True)) + n_block_max -= 1 + # Next couple of iterations with causal masking + if const_expr(self.is_causal or self.is_local): n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( seqlen, m_block, n_block_min ) for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1): n_block = n_block_max - 1 - n_tile - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=partial(mask_fn, mask_seqlen=True)) + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=partial(mask_fn, mask_seqlen=False)) n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) # The remaining iterations have no masking n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( @@ -1198,7 +1195,7 @@ def softmax_loop( n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) for n_tile in cutlass.range(0, n_block_max - n_block_min, unroll=1): n_block = n_block_max - 1 - n_tile - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=partial(mask_fn, mask_seqlen=True)) + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=partial(mask_fn, mask_seqlen=False)) # Now that we no longer already have the 1st iteration, need mask_seqlen=True here # tSrScale_r2t = cute.make_fragment(tSrScale_r2t_shape, Float32) From c909b679e0321e610a8b97d7a517d08355ad0b5a Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 15 Jul 2025 16:31:04 -0400 Subject: [PATCH 044/258] [Cute] Try e2e --- flash_attn/cute/flash_fwd_sm100.py | 2 +- flash_attn/cute/softmax.py | 15 ++++++- flash_attn/cute/utils.py | 65 ++++++++++++++++++++++++++++-- 3 files changed, 76 insertions(+), 6 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index b2b6c6c58ed..d8b86b612b8 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -1311,7 +1311,7 @@ def softmax_step( cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), tSrS_t2r.layout, ) # softmax.scale_apply_exp2_convert(tSrS_t2r, row_max, tSrP_r2t) - softmax.apply_exp2_convert(tSrS_t2r, tSrP_r2t) + softmax.apply_exp2_convert(tSrS_t2r, tSrP_r2t, e2e=mask_fn is None and (self.is_causal or self.is_local)) # Sequence barrier arrive if const_expr(self.s0_s1_barrier): cute.arch.mbarrier_arrive(mbar_ptr + mbar_s0_s1_sequence_offset + (1 - stage) * 4) diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index bf98cf9126e..e7b8f913ebf 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -174,6 +174,10 @@ def apply_exp2_convert( self, acc_S_row: cute.Tensor, acc_S_row_converted: cute.Tensor, + e2e: cutlass.Constexpr[bool] = False, + e2e_freq: cutlass.Constexpr[bool] = 8, + e2e_res: cutlass.Constexpr[bool] = 2, + e2e_frg_limit: cutlass.Constexpr[bool] = 1, ): assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" frg_tile = 32 @@ -188,8 +192,15 @@ def apply_exp2_convert( for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2): # acc_S_row_frg[k, j] = utils.exp2f(acc_S_row_frg[k, j]) # acc_S_row_frg[k + 1, j] = utils.exp2f(acc_S_row_frg[k + 1, j]) - acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j]) - acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j]) + if cutlass.const_expr(not e2e): + acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j]) + acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j]) + else: + if cutlass.const_expr(k % e2e_freq < e2e_freq - e2e_res or j >= frg_cnt - e2e_frg_limit): + acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j]) + acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j]) + else: + acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.e2e_asm2(acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]) acc_S_row_converted_frg[None, j].store( acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type) ) diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index df6ad0fe3b3..1819446809f 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -1,14 +1,14 @@ # Copyright (c) 2025, Tri Dao. import math -from typing import Type, Callable, Optional +from typing import Type, Callable, Optional, Tuple import cutlass import cutlass.cute as cute -from cutlass import Float32 +from cutlass import Float32, Int32 from cutlass.cutlass_dsl import T, dsl_user_op -from cutlass._mlir.dialects import nvvm, llvm +from cutlass._mlir.dialects import nvvm, llvm, arith, vector from cutlass.cute.runtime import from_dlpack @@ -498,3 +498,62 @@ def cvt_f16(src: cute.Tensor, dst: cute.Tensor): assert cute.size(dst_i32.shape) * 2 == cute.size(src.shape) for i in cutlass.range_constexpr(cute.size(dst_i32)): dst_i32[i] = cvt_f16x2_f32(src[2 * i], src[2 * i + 1], dst.element_type) + + +@dsl_user_op +def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float32]: + vec_i64x1 = vector.from_elements(T.vector(1, T.i64()), (c.ir_value(),), loc=loc, ip=ip) + vec_f32x2 = vector.bitcast(T.vector(2, T.f32()), vec_i64x1) + res0 = Float32( + vector.extract(vec_f32x2, dynamic_position=[], static_position=[0], loc=loc, ip=ip) + ) + res1 = Float32( + vector.extract(vec_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip) + ) + return res0, res1 + + +@cute.jit +def e2e_asm2(x: Float32, y: Float32) -> Tuple[Float32, Float32]: + out_i64 = cutlass.Int64( + llvm.inline_asm( + T.i64(), + [Float32(x).ir_value(), Float32(y).ir_value()], + "{\n\t" + ".reg .f32 f1, f2, f3, f4, f5, f6, f7;\n\t" + ".reg .b64 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;\n\t" + ".reg .s32 r1, r2, r3, r4, r5, r6, r7, r8;\n\t" + "max.ftz.f32 f1, $1, 0fC2FE0000;\n\t" + "max.ftz.f32 f2, $2, 0fC2FE0000;\n\t" + "mov.b64 l1, {f1, f2};\n\t" + "mov.f32 f3, 0f4B400000;\n\t" + "mov.b64 l2, {f3, f3};\n\t" + "add.rm.ftz.f32x2 l7, l1, l2;\n\t" + "sub.rn.ftz.f32x2 l8, l7, l2;\n\t" + "sub.rn.ftz.f32x2 l9, l1, l8;\n\t" + "mov.f32 f7, 0f3D9DF09D;\n\t" + "mov.b64 l6, {f7, f7};\n\t" + "mov.f32 f6, 0f3E6906A4;\n\t" + "mov.b64 l5, {f6, f6};\n\t" + "mov.f32 f5, 0f3F31F519;\n\t" + "mov.b64 l4, {f5, f5};\n\t" + "mov.f32 f4, 0f3F800000;\n\t" + "mov.b64 l3, {f4, f4};\n\t" + "fma.rn.ftz.f32x2 l10, l9, l6, l5;\n\t" + "fma.rn.ftz.f32x2 l10, l10, l9, l4;\n\t" + "fma.rn.ftz.f32x2 l10, l10, l9, l3;\n\t" + "mov.b64 {r1, r2}, l7;\n\t" + "mov.b64 {r3, r4}, l10;\n\t" + "shl.b32 r5, r1, 23;\n\t" + "add.u32 r7, r5, r3;\n\t" + "shl.b32 r6, r2, 23;\n\t" + "add.u32 r8, r6, r4;\n\t" + "mov.b64 $0, {r7, r8};\n\t" + "}\n", + "=l,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + return i64_to_f32x2(out_i64) From 75c7d998c60973c35f032ffabbeba5e9f4fa8567 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 15 Jul 2025 16:31:32 -0400 Subject: [PATCH 045/258] [Cute] Bench hdim 64 --- benchmarks/benchmark_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/benchmark_attn.py b/benchmarks/benchmark_attn.py index b08a9c84dcf..85f86282ce6 100644 --- a/benchmarks/benchmark_attn.py +++ b/benchmarks/benchmark_attn.py @@ -254,7 +254,7 @@ def run(*args, **kwargs): # for headdim in [64, 96, 128]: # for headdim in [64, 128, 256]: # for headdim in [64, 96, 128, 192, 256]: -for headdim in [128]: +for headdim in [64]: nheads = dim // headdim # nheads = 128 # headdim = 64 From 5639535e8814fd57c29683a333adbf379dfa4411 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 15 Jul 2025 16:51:25 -0400 Subject: [PATCH 046/258] [Cute] Bench both hdim 64 and 128 --- benchmarks/benchmark_attn.py | 2 +- flash_attn/cute/flash_fwd_sm100.py | 2 +- flash_attn/cute/softmax.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/benchmarks/benchmark_attn.py b/benchmarks/benchmark_attn.py index 85f86282ce6..2107c6c0026 100644 --- a/benchmarks/benchmark_attn.py +++ b/benchmarks/benchmark_attn.py @@ -254,7 +254,7 @@ def run(*args, **kwargs): # for headdim in [64, 96, 128]: # for headdim in [64, 128, 256]: # for headdim in [64, 96, 128, 192, 256]: -for headdim in [64]: +for headdim in [64, 128]: nheads = dim // headdim # nheads = 128 # headdim = 64 diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index d8b86b612b8..96fd560f463 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -1311,7 +1311,7 @@ def softmax_step( cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), tSrS_t2r.layout, ) # softmax.scale_apply_exp2_convert(tSrS_t2r, row_max, tSrP_r2t) - softmax.apply_exp2_convert(tSrS_t2r, tSrP_r2t, e2e=mask_fn is None and (self.is_causal or self.is_local)) + softmax.apply_exp2_convert(tSrS_t2r, tSrP_r2t, e2e=mask_fn is None, e2e_freq=16 if self.head_dim_padded <= 64 else 32) # Sequence barrier arrive if const_expr(self.s0_s1_barrier): cute.arch.mbarrier_arrive(mbar_ptr + mbar_s0_s1_sequence_offset + (1 - stage) * 4) diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index e7b8f913ebf..fa955290426 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -175,8 +175,8 @@ def apply_exp2_convert( acc_S_row: cute.Tensor, acc_S_row_converted: cute.Tensor, e2e: cutlass.Constexpr[bool] = False, - e2e_freq: cutlass.Constexpr[bool] = 8, - e2e_res: cutlass.Constexpr[bool] = 2, + e2e_freq: cutlass.Constexpr[bool] = 32, + e2e_res: cutlass.Constexpr[bool] = 4, e2e_frg_limit: cutlass.Constexpr[bool] = 1, ): assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" From 5d98558b557ba975d751e10c7c8c3939497551e2 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 15 Jul 2025 17:06:04 -0400 Subject: [PATCH 047/258] [Cute] Tune num regs --- flash_attn/cute/flash_fwd_sm100.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 96fd560f463..414bf3c6df9 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -133,16 +133,18 @@ def __init__( self.num_regs_correction = 64 self.num_regs_other = 64 else: - # self.num_regs_softmax = 184 - self.num_regs_softmax = 176 + self.num_regs_softmax = 192 if self.is_causal or self.is_local else 184 + # self.num_regs_softmax = 176 # self.num_regs_correction = 96 # self.num_regs_correction = 80 - self.num_regs_correction = 64 if self.is_causal or self.is_local else 80 + # self.num_regs_correction = 64 if self.is_causal or self.is_local else 80 + self.num_regs_correction = 64 # self.num_regs_other = 32 # self.num_regs_other = 64 # self.num_regs_other = 80 # self.num_regs_other = 48 - self.num_regs_other = 96 if self.is_causal or self.is_local else 80 + # self.num_regs_other = 96 if self.is_causal or self.is_local else 80 + self.num_regs_other = 64 if self.is_causal or self.is_local else 80 self.num_regs_empty = 24 self.buffer_align_bytes = 1024 @@ -1311,7 +1313,7 @@ def softmax_step( cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), tSrS_t2r.layout, ) # softmax.scale_apply_exp2_convert(tSrS_t2r, row_max, tSrP_r2t) - softmax.apply_exp2_convert(tSrS_t2r, tSrP_r2t, e2e=mask_fn is None, e2e_freq=16 if self.head_dim_padded <= 64 else 32) + softmax.apply_exp2_convert(tSrS_t2r, tSrP_r2t, e2e=mask_fn is None, e2e_freq=16 if self.head_dim_padded <= 64 else 16) # Sequence barrier arrive if const_expr(self.s0_s1_barrier): cute.arch.mbarrier_arrive(mbar_ptr + mbar_s0_s1_sequence_offset + (1 - stage) * 4) From 50e0736f45a11f0e6d4e37a6cce59c8bff98b3c3 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 15 Jul 2025 17:55:05 -0400 Subject: [PATCH 048/258] [Cute] Tune regs a bit --- flash_attn/cute/flash_fwd_sm100.py | 7 ++++--- flash_attn/cute/utils.py | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 414bf3c6df9..2375c3ebdaa 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -88,7 +88,8 @@ def __init__( self.qhead_per_kvhead = qhead_per_kvhead self.pack_gqa = False # Does S1 need to wait for S0 to finish - self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local) + # self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local) + self.s0_s1_barrier = False self.softmax0_warp_ids = (0, 1, 2, 3) self.softmax1_warp_ids = (4, 5, 6, 7) @@ -129,9 +130,9 @@ def __init__( self.tmem_vec1_offset = self.tmem_vec0_offset + self.n_block_size if self.head_dim_padded < 96: - self.num_regs_softmax = 192 + self.num_regs_softmax = 200 self.num_regs_correction = 64 - self.num_regs_other = 64 + self.num_regs_other = 48 else: self.num_regs_softmax = 192 if self.is_causal or self.is_local else 184 # self.num_regs_softmax = 176 diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 1819446809f..fbd836be1d9 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -545,9 +545,9 @@ def e2e_asm2(x: Float32, y: Float32) -> Tuple[Float32, Float32]: "mov.b64 {r1, r2}, l7;\n\t" "mov.b64 {r3, r4}, l10;\n\t" "shl.b32 r5, r1, 23;\n\t" - "add.u32 r7, r5, r3;\n\t" + "add.s32 r7, r5, r3;\n\t" "shl.b32 r6, r2, 23;\n\t" - "add.u32 r8, r6, r4;\n\t" + "add.s32 r8, r6, r4;\n\t" "mov.b64 $0, {r7, r8};\n\t" "}\n", "=l,f,f", From 34a3656b70711aed2383c4d486186e68ac1a2619 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 15 Jul 2025 18:43:49 -0400 Subject: [PATCH 049/258] [Cute] Bench multiple seqlens --- benchmarks/benchmark_attn.py | 10 +++++----- flash_attn/cute/softmax.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/benchmarks/benchmark_attn.py b/benchmarks/benchmark_attn.py index 2107c6c0026..bad67de2097 100644 --- a/benchmarks/benchmark_attn.py +++ b/benchmarks/benchmark_attn.py @@ -240,10 +240,10 @@ def run(*args, **kwargs): headdim = 256 # for headdim in [64, 128, 256]: # bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)] -# bs_seqlen_vals = [(16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)] +bs_seqlen_vals = [(32, 1024), (16, 2048), (8, 4096), (4, 8192), (2, 16384), (1, 32768)] # bs_seqlen_vals = [(32, 512), (16, 1024)] # bs_seqlen_vals = [(2, 64 * 132)] -bs_seqlen_vals = [(4, 8192)] +# bs_seqlen_vals = [(4, 8192)] # bs_seqlen_vals = [(1, 16 * 1024)] time_f = {} time_b = {} @@ -254,7 +254,7 @@ def run(*args, **kwargs): # for headdim in [64, 96, 128]: # for headdim in [64, 128, 256]: # for headdim in [64, 96, 128, 192, 256]: -for headdim in [64, 128]: +for headdim in [128]: nheads = dim // headdim # nheads = 128 # headdim = 64 @@ -312,8 +312,8 @@ def run(*args, **kwargs): else: page_table = None - for causal in [False, True]: - # for causal in [False]: + # for causal in [False, True]: + for causal in [True]: print(f"\n### {headdim = }, {causal = }, {seqlen = } ###") nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size) if cudnn is not None: diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index fa955290426..5799cd4bd98 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -175,9 +175,9 @@ def apply_exp2_convert( acc_S_row: cute.Tensor, acc_S_row_converted: cute.Tensor, e2e: cutlass.Constexpr[bool] = False, - e2e_freq: cutlass.Constexpr[bool] = 32, - e2e_res: cutlass.Constexpr[bool] = 4, - e2e_frg_limit: cutlass.Constexpr[bool] = 1, + e2e_freq: cutlass.Constexpr[int] = 16, + e2e_res: cutlass.Constexpr[int] = 4, + e2e_frg_limit: cutlass.Constexpr[int] = 1, ): assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" frg_tile = 32 From 24f0957be6cff1bf9ad9a65939d56227b92ad3d0 Mon Sep 17 00:00:00 2001 From: One Date: Wed, 23 Jul 2025 01:36:36 +0800 Subject: [PATCH 050/258] Revert "[BE] Better compress flash attention binaries (#1744)" (#1751) This reverts commit 8ba246f6cc8813d41f9289e2781b7d8fa22a97cb. --- hopper/setup.py | 3 --- setup.py | 3 --- 2 files changed, 6 deletions(-) diff --git a/hopper/setup.py b/hopper/setup.py index 10894252db0..c15c438f56c 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -524,9 +524,6 @@ def nvcc_threads_args(): "-DCUTLASS_ENABLE_GDC_FOR_SM90", # For PDL "-DCUTLASS_DEBUG_TRACE_LEVEL=0", # Can toggle for debugging "-DNDEBUG", # Important, otherwise performance is severely impacted - "-Xfatbin", # compress all binary sections - "-compress-all", - "-compress-mode=size", # compress with CUDA fatbin more aggressively ] if get_platform() == "win_amd64": nvcc_flags.extend( diff --git a/setup.py b/setup.py index d54e93f6649..cafc818fa2c 100644 --- a/setup.py +++ b/setup.py @@ -206,9 +206,6 @@ def validate_and_update_archs(archs): "--expt-relaxed-constexpr", "--expt-extended-lambda", "--use_fast_math", - "-Xfatbin", - "-compress-all", - "-compress-mode=size", # "--ptxas-options=-v", # "--ptxas-options=-O2", # "-lineinfo", From 7321879fde54f09ed94f7f6ce9377e2f4cf1fac0 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 23 Jul 2025 22:44:59 -0700 Subject: [PATCH 051/258] Bump to v2.8.2 --- flash_attn/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index fa45a44cbe1..69eae460e36 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.8.1" +__version__ = "2.8.2" from flash_attn.flash_attn_interface import ( flash_attn_func, From 413d07e9deef1e3c793c7de59d7146b43ae4d558 Mon Sep 17 00:00:00 2001 From: rocking Date: Thu, 31 Jul 2025 06:09:21 +0800 Subject: [PATCH 052/258] [AMD ROCm] Fix compilation issue in gfx942 (#1787) * update ck * Set default head dim, some instances might have bug * update ck * To pass the test --- csrc/composable_kernel | 2 +- setup.py | 9 +++++---- tests/test_flash_attn_ck.py | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/csrc/composable_kernel b/csrc/composable_kernel index 663992e99b4..e8709c24f40 160000 --- a/csrc/composable_kernel +++ b/csrc/composable_kernel @@ -1 +1 @@ -Subproject commit 663992e99b412991eab554b0deb89bb916d40161 +Subproject commit e8709c24f403173ad21a2da907d1347957e324fb diff --git a/setup.py b/setup.py index cafc818fa2c..a108c412c00 100644 --- a/setup.py +++ b/setup.py @@ -325,10 +325,11 @@ def validate_and_update_archs(archs): if not os.path.exists("./build"): os.makedirs("build") - subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd", "--output_dir", "build", "--receipt", "2"], check=True) - subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_appendkv", "--output_dir", "build", "--receipt", "2"], check=True) - subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_splitkv", "--output_dir", "build", "--receipt", "2"], check=True) - subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "bwd", "--output_dir", "build", "--receipt", "2"], check=True) + optdim = os.getenv("OPT_DIM", "32,64,128,256") + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True) + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_appendkv", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True) + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_splitkv", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True) + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "bwd", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True) # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h # See https://github.com/pytorch/pytorch/pull/70650 diff --git a/tests/test_flash_attn_ck.py b/tests/test_flash_attn_ck.py index 503b7bf01c3..d5590fcfc82 100644 --- a/tests/test_flash_attn_ck.py +++ b/tests/test_flash_attn_ck.py @@ -1399,7 +1399,7 @@ def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype): print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}") print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}") assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() - assert (q.grad - q_ref.grad).abs().max().item() <= 5 * ( + assert (q.grad - q_ref.grad).abs().max().item() <= 7 * ( q_pt.grad - q_ref.grad ).abs().max().item() + 1e-3 assert (k.grad - k_ref.grad).abs().max().item() <= 5 * ( From 1a15733e52b86d4264f8a78bda8d54365ebc2b45 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 1 Aug 2025 13:32:57 -0400 Subject: [PATCH 053/258] [Cute] Support hdim_v != hdim_qk --- flash_attn/cute/flash_fwd_sm100.py | 108 +++++++++++++++++------------ tests/cute/test_flash_attn.py | 7 +- 2 files changed, 69 insertions(+), 46 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 2375c3ebdaa..7681e0e3523 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -69,8 +69,8 @@ def __init__( self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) head_dim_v = head_dim_v if head_dim_v is not None else head_dim self.same_hdim_kv = head_dim == head_dim_v - assert head_dim == head_dim_v, "head_dim and head_dim_v must be the same for now" self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) + self.same_hdim_kv_padded = self.head_dim_padded == self.head_dim_v_padded self.check_hdim_oob = head_dim != self.head_dim_padded self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded self.m_block_size = m_block_size @@ -78,7 +78,7 @@ def __init__( # 2 Q tile per CTA self.cta_tiler = (2 * m_block_size, n_block_size, self.head_dim_padded) self.mma_tiler_qk = (m_block_size, n_block_size, self.head_dim_padded) - self.pv_mma_tiler = (m_block_size, self.head_dim_v_padded, n_block_size) + self.mma_tiler_pv = (m_block_size, self.head_dim_v_padded, n_block_size) self.qk_acc_dtype = Float32 self.pv_acc_dtype = Float32 self.cluster_shape_mn = (1, 1) @@ -256,7 +256,7 @@ def __call__( self.v_major_mode, self.pv_acc_dtype, cta_group, - self.pv_mma_tiler[:2], + self.mma_tiler_pv[:2], p_source, ) @@ -266,7 +266,7 @@ def __call__( (tiled_mma_qk.thr_id.shape,), ) - self.epi_tile = self.pv_mma_tiler[:2] + self.epi_tile = self.mma_tiler_pv[:2] sQ_layout = sm100_utils_basic.make_smem_layout_a( tiled_mma_qk, self.mma_tiler_qk, self.q_dtype, self.q_stage, @@ -275,14 +275,19 @@ def __call__( tiled_mma_qk, self.mma_tiler_qk, self.k_dtype, self.kv_stage, ) tP_layout = sm100_utils_basic.make_smem_layout_a( - tiled_mma_pv, self.pv_mma_tiler, self.q_dtype, self.acc_stage, + tiled_mma_pv, self.mma_tiler_pv, self.q_dtype, self.acc_stage, ) sV_layout = sm100_utils_basic.make_smem_layout_b( - tiled_mma_pv, self.pv_mma_tiler, self.v_dtype, self.kv_stage, + tiled_mma_pv, self.mma_tiler_pv, self.v_dtype, self.kv_stage, ) sO_layout = sm100_utils_basic.make_smem_layout_epi( self.o_dtype, self.o_layout, self.epi_tile, self.epi_stage, ) + if const_expr(not self.same_hdim_kv_padded): + # sK and sV are using the same physical smem so we need to adjust the stride so that they line up + stage_stride = const_expr(max(sK_layout.outer.stride[-1], sV_layout.outer.stride[-1])) + sK_layout = cute.make_composed_layout(sK_layout.inner, 0, cute.make_layout((*sK_layout.outer.shape[:-1], self.kv_stage), stride=(*sK_layout.outer.stride[:-1], stage_stride))) + sV_layout = cute.make_composed_layout(sV_layout.inner, 0, cute.make_layout((*sV_layout.outer.shape[:-1], self.kv_stage), stride=(*sV_layout.outer.stride[:-1], stage_stride))) # TMA load for Q tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) @@ -311,7 +316,7 @@ def __call__( tma_load_op, mV, cute.select(sV_layout, mode=[0, 1, 2]), - self.pv_mma_tiler, + self.mma_tiler_pv, tiled_mma_pv, self.cluster_layout_vmnk.shape, ) @@ -348,7 +353,8 @@ def __call__( gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout) self.tma_copy_q_bytes = cute.size_in_bytes(self.q_dtype, cute.select(sQ_layout, mode=[0, 1, 2])) - self.tma_copy_kv_bytes = cute.size_in_bytes(self.k_dtype, cute.select(sK_layout, mode=[0, 1, 2])) + self.tma_copy_k_bytes = cute.size_in_bytes(self.k_dtype, cute.select(sK_layout, mode=[0, 1, 2])) + self.tma_copy_v_bytes = cute.size_in_bytes(self.v_dtype, cute.select(sV_layout, mode=[0, 1, 2])) if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): TileScheduler = SingleTileVarlenScheduler @@ -594,7 +600,7 @@ def kernel( assumed_align=16) tStS = cute.make_tensor(tmem_ptr, tStS_fake.layout) - pv_acc_shape = thr_mma_pv.partition_shape_C((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) + pv_acc_shape = thr_mma_pv.partition_shape_C((self.mma_tiler_pv[0], self.mma_tiler_pv[1])) tOtO = thr_mma_pv.make_fragment_C(pv_acc_shape) tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) @@ -827,7 +833,7 @@ def load( tSgQ = thr_mma_qk.partition_A(gQ) gK = cute.local_tile(mK_cur, cute.select(self.mma_tiler_qk, mode=[1, 2]), (None, 0)) tSgK = thr_mma_qk.partition_B(gK) - gV = cute.local_tile(mV_cur, cute.select(self.pv_mma_tiler, mode=[1, 2]), (0, None)) + gV = cute.local_tile(mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None)) tOgV = thr_mma_pv.partition_B(gV) tQsQ, tQgQ = cpasync.tma_partition( tma_atom_Q, @@ -851,33 +857,38 @@ def load( cute.group_modes(tOgV, 0, 3), ) - def load_Q(stage: int): - cute.arch.mbarrier_wait(mbar_ptr + self.mbar_load_q_empty_offset + stage, q_producer_phase) - with cute.arch.elect_one(): - cute.arch.mbarrier_arrive_and_expect_tx(mbar_ptr + self.mbar_load_q_full_offset + stage, self.tma_copy_q_bytes) - cute.copy( - tma_atom_Q, - tQgQ[None, 2 * m_block + stage], - tQsQ[None, stage], - tma_bar_ptr=mbar_ptr + self.mbar_load_q_full_offset + stage, - ) - - load_K = partial(self.load_K, tma_atom_K, tKgK, tKsK, pipeline_kv) - load_V = partial(self.load_K, tma_atom_V, tVgV, tVsV, pipeline_kv) + load_Q = partial( + self.load_QKV, tma_atom_Q, tQgQ, tQsQ, + mbar_ptr + self.mbar_load_q_full_offset, mbar_ptr + self.mbar_load_q_empty_offset, + self.tma_copy_q_bytes, + phase=q_producer_phase, + ) + # We have to use mbarrier directly in the load for KV instead of replying on + # pipeline_kv, because we could have different number of TMA bytes for K and V + load_K = partial( + self.load_QKV, tma_atom_K, tKgK, tKsK, + mbar_ptr + self.mbar_load_kv_full_offset, mbar_ptr + self.mbar_load_kv_empty_offset, + self.tma_copy_k_bytes + ) + load_V = partial( + self.load_QKV, tma_atom_V, tVgV, tVsV, + mbar_ptr + self.mbar_load_kv_full_offset, mbar_ptr + self.mbar_load_kv_empty_offset, + self.tma_copy_v_bytes + ) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) - load_Q(0) # Q0 - load_K(n_block_max - 1, kv_producer_state) # K0 + load_Q(block=2 * m_block + 0, stage=0) # Q0 + load_K(block=n_block_max - 1, producer_state=kv_producer_state) # K0 kv_producer_state.advance() - load_Q(1) # Q1 + load_Q(block=2 * m_block + 1, stage=1) # Q1 q_producer_phase ^= 1 - load_V(n_block_max - 1, kv_producer_state) # V0 + load_V(block=n_block_max - 1, producer_state=kv_producer_state) # V0 kv_producer_state.advance() for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): n_block = n_block_max - 2 - i - load_K(n_block, kv_producer_state) # Ki + load_K(block=n_block, producer_state=kv_producer_state) # Ki kv_producer_state.advance() - load_V(n_block, kv_producer_state) # Vi + load_V(block=n_block, producer_state=kv_producer_state) # Vi kv_producer_state.advance() tile_scheduler.prefetch_next_work() tile_scheduler.advance_to_next_work() @@ -1468,7 +1479,7 @@ def correction_loop( softmax_corr_consumer_phase ^= 1 corr_epi_producer_phase ^= 1 - # gO_qdhb = cute.local_tile(mO, cute.select(self.pv_mma_tiler, mode=[0, 1]), (None, 0, None, None)) + # gO_qdhb = cute.local_tile(mO, cute.select(self.mma_tiler_pv, mode=[0, 1]), (None, 0, None, None)) # gO = gO_qdhb[None, None, None, head_idx, batch_idx] # tOsO, tOgO = cpasync.tma_partition( # tma_atom_O, @@ -1515,7 +1526,7 @@ def correction_rescale( 2. Apply the scaling factor to all elements 3. Store the rescaled results back to tensor memory """ - cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) + cO = cute.make_identity_tensor((self.mma_tiler_pv[0], self.mma_tiler_pv[1])) tOcO = thr_mma.partition_C(cO) corr_tile_size = 16 # tuneable parameter @@ -1590,7 +1601,7 @@ def correction_epilogue( :type sO: cute.Tensor """ - cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) + cO = cute.make_identity_tensor((self.mma_tiler_pv[0], self.mma_tiler_pv[1])) corr_tile_size = 32 * 8 // self.o_dtype.width tOsO = thr_mma.partition_C(sO) tOcO = thr_mma.partition_C(cO) @@ -1601,7 +1612,7 @@ def correction_epilogue( epi_subtile = (self.epi_tile[0], corr_tile_size) tmem_copy_atom = sm100_utils_basic.get_tmem_load_op( - self.pv_mma_tiler, + self.mma_tiler_pv, self.o_layout, self.o_dtype, self.pv_acc_dtype, @@ -1719,22 +1730,31 @@ def epilogue_s2g( tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() - # @cute.jit - def load_K( + def load_QKV( self, tma_atom: cute.CopyAtom, - tKgK: cute.Tensor, - tKsK: cute.Tensor, - pipeline: cutlass.pipeline.PipelineAsync, + tXgX: cute.Tensor, + tXsX: cute.Tensor, + mbar_full_ptr: cute.Pointer, + mbar_empty_ptr: cute.Pointer, + tma_copy_bytes: int, block: Int32, - producer_state: cutlass.pipeline.PipelineState, + producer_state: Optional[cutlass.pipeline.PipelineState] = None, + stage: Optional[Int32] = None, + phase: Optional[Int32] = None, ): - pipeline.producer_acquire(producer_state) + if cutlass.const_expr(producer_state is not None): + stage, phase = producer_state.index, producer_state.phase + else: + assert stage is not None and phase is not None, "stage and phase must be provided if producer_state is None" + cute.arch.mbarrier_wait(mbar_empty_ptr + stage, phase) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx(mbar_full_ptr + stage, tma_copy_bytes) cute.copy( tma_atom, - tKgK[None, block], - tKsK[None, producer_state.index], - tma_bar_ptr=pipeline.producer_get_barrier(producer_state) + tXgX[None, block], + tXsX[None, stage], + tma_bar_ptr=mbar_full_ptr + stage, ) def make_and_init_load_kv_pipeline(self, load_kv_mbar_ptr): @@ -1746,7 +1766,7 @@ def make_and_init_load_kv_pipeline(self, load_kv_mbar_ptr): num_stages=self.kv_stage, producer_group=load_kv_producer_group, consumer_group=load_kv_consumer_group, - tx_count=self.tma_copy_kv_bytes, + tx_count=self.tma_copy_k_bytes, ) # @cute.jit diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 848c68eb8a1..253f1fd7007 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -81,7 +81,7 @@ def test_flash_attn_output( nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) - dv_vals = [d] + dv_vals = [d] if d != 128 else [64, d] if dtype == torch.float8_e4m3fn: dv_vals = [d] # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if not DISABLE_LOCAL else [0] @@ -188,6 +188,7 @@ def test_flash_attn_output( and not attention_chunk != 0 and softcap == 0.0 and not local + and dv == d # and False ): g = torch.randn_like(out) @@ -290,7 +291,8 @@ def test_flash_attn_varlen_output( # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + dv_vals = [d] if d != 128 else [64, d] if dtype == torch.float8_e4m3fn: dv_vals = [d] # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k and not DISABLE_LOCAL else [0] @@ -450,6 +452,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): and not has_qv and not dv > 256 and not attention_chunk != 0 + and dv == d and False ): g_unpad = torch.randn_like(out_unpad) From 1b36ab19c8f5f666e99196f2474803d01b9cdc74 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 1 Aug 2025 17:12:54 -0400 Subject: [PATCH 054/258] [Cute] Support hdim (192,128) --- benchmarks/benchmark_attn.py | 2 +- flash_attn/cute/flash_fwd_sm100.py | 125 ++++++++++++++++++++++------- tests/cute/test_flash_attn.py | 8 +- 3 files changed, 102 insertions(+), 33 deletions(-) diff --git a/benchmarks/benchmark_attn.py b/benchmarks/benchmark_attn.py index bad67de2097..289518822ab 100644 --- a/benchmarks/benchmark_attn.py +++ b/benchmarks/benchmark_attn.py @@ -387,7 +387,7 @@ def run(*args, **kwargs): print(f'FAv2 fwd: {m0.mean * 1e3:.3f}ms, {(nFLOPS / m0.mean * 1e-12):.1f} TFLOPS') if has_backward: print(f'FAv2 bwd: {m0b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m0b.mean * 1e-12):.1f} TFLOPS') - if cudnn is not None: + if cudnn is not None and headdim == headdim_v: print(f'CuDNN fwd: {m2.mean * 1e3:.3f}ms, {(nFLOPS / m2.mean * 1e-12):.1f} TFLOPS') if has_backward: print(f'CuDNN bwd: {m2b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m2b.mean * 1e-12):.1f} TFLOPS') diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 7681e0e3523..fd94f6e3b62 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -2,7 +2,7 @@ # - BF16 & FP16 dtype # - noncausal & causal attention # - MHA, GQA, MQA -# - hdim 64, 96, 128. +# - hdim 64, 96, 128, (192, 128). # - varlen # - sliding window # Unsupported features that will be added later: @@ -90,6 +90,10 @@ def __init__( # Does S1 need to wait for S0 to finish # self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local) self.s0_s1_barrier = False + self.overlap_sO_sQ = self.head_dim_padded == 192 and self.head_dim_v_padded >= 64 + if self.overlap_sO_sQ: + assert self.head_dim_padded >= self.head_dim_v_padded # We assume sQ is larger than sO + self.is_persistent = False self.softmax0_warp_ids = (0, 1, 2, 3) self.softmax1_warp_ids = (4, 5, 6, 7) @@ -162,8 +166,20 @@ def _setup_attributes(self): self.q_stage = 2 self.kv_stage = 4 if self.q_dtype.width == 8 else 3 + # TODO: temp solution to get this to run as uneven_kv_smem isn't working yet + if self.head_dim_padded == 192 and self.head_dim_v_padded == 128: + self.kv_stage = 2 self.acc_stage = 1 self.epi_stage = 2 + # For hdim 192,128, we don't have enough smem to store all 3 stages of KV: + # 128 x 192 x 2 bytes x 3 stages = 144KB, as we need 64KB for Q and 64 KB for O. + # Instead we store smem as [smem_large, smem_small, smem_large], where smem_large is + # 128 x 192 and smem_small is 128 x 128. We set the stride between the stages to be + # 128 * 160, so that indexing the 0th and 2nd stages will get the right address, + # but for the 1st stage we need to add or subtract (depending on phase) 128 x 64. + self.uneven_kv_smem = self.head_dim_padded == 192 and self.head_dim_v_padded == 128 and self.kv_stage == 3 + self.uneven_kv_smem_offset = self.m_block_size * (self.head_dim_padded - self.head_dim_v_padded) // 2 if self.uneven_kv_smem else 0 + assert self.uneven_kv_smem_offset % 1024 == 0 @cute.jit def __call__( @@ -285,7 +301,9 @@ def __call__( ) if const_expr(not self.same_hdim_kv_padded): # sK and sV are using the same physical smem so we need to adjust the stride so that they line up - stage_stride = const_expr(max(sK_layout.outer.stride[-1], sV_layout.outer.stride[-1])) + stride_sK = const_expr(max(sK_layout.outer.stride[-1], 0)) # take max to turn tuple to Int32 + stride_sV = const_expr(max(sV_layout.outer.stride[-1], 0)) + stage_stride = const_expr(max(stride_sK, stride_sV) if not self.uneven_kv_smem else (stride_sK + stride_sV) // 2) sK_layout = cute.make_composed_layout(sK_layout.inner, 0, cute.make_layout((*sK_layout.outer.shape[:-1], self.kv_stage), stride=(*sK_layout.outer.stride[:-1], stage_stride))) sV_layout = cute.make_composed_layout(sV_layout.inner, 0, cute.make_layout((*sV_layout.outer.shape[:-1], self.kv_stage), stride=(*sV_layout.outer.stride[:-1], stage_stride))) @@ -399,6 +417,8 @@ def __call__( self.mbar_P_full_2_offset = self.mbar_tmem_dealloc_offset + 1 self.mbar_total = self.mbar_P_full_2_offset + 2 + sO_size = cute.cosize(sO_layout) if const_expr(not self.overlap_sO_sQ) else 0 + @cute.struct class SharedStorage: # m_barriers for pipelines @@ -408,7 +428,7 @@ class SharedStorage: # Smem tensors sScale: cute.struct.MemRange[Float32, 2 * self.m_block_size * (1 if const_expr(mLSE is None) else 2)] sO: cute.struct.Align[ - cute.struct.MemRange[self.o_dtype, cute.cosize(sO_layout)], + cute.struct.MemRange[self.o_dtype, sO_size], self.buffer_align_bytes, ] sQ: cute.struct.Align[ @@ -416,6 +436,7 @@ class SharedStorage: self.buffer_align_bytes, ] sK: cute.struct.Align[ + # cute.cosize(sK_layout) is correct even in the case of self.uneven_kv_smem cute.struct.MemRange[self.k_dtype, cute.cosize(sK_layout)], self.buffer_align_bytes, ] @@ -586,7 +607,10 @@ def kernel( # (MMA, MMA_K, MMA_D, PIPE) # Strip swizzle info to reuse smem sV = cute.make_tensor(cute.recast_ptr(sK.iterator, sV_layout.inner), sV_layout.outer) - sO = storage.sO.get_tensor(sO_layout.outer, swizzle=sO_layout.inner) + if const_expr(not self.overlap_sO_sQ): + sO = storage.sO.get_tensor(sO_layout.outer, swizzle=sO_layout.inner) + else: + sO = cute.make_tensor(cute.recast_ptr(sQ.iterator, sO_layout.inner), sO_layout.outer) sScale = storage.sScale.get_tensor(cute.make_layout(256)) @@ -858,7 +882,7 @@ def load( ) load_Q = partial( - self.load_QKV, tma_atom_Q, tQgQ, tQsQ, + self.load_Q, tma_atom_Q, tQgQ, tQsQ, mbar_ptr + self.mbar_load_q_full_offset, mbar_ptr + self.mbar_load_q_empty_offset, self.tma_copy_q_bytes, phase=q_producer_phase, @@ -866,12 +890,12 @@ def load( # We have to use mbarrier directly in the load for KV instead of replying on # pipeline_kv, because we could have different number of TMA bytes for K and V load_K = partial( - self.load_QKV, tma_atom_K, tKgK, tKsK, + self.load_KV, tma_atom_K, tKgK, tKsK, mbar_ptr + self.mbar_load_kv_full_offset, mbar_ptr + self.mbar_load_kv_empty_offset, self.tma_copy_k_bytes ) load_V = partial( - self.load_QKV, tma_atom_V, tVgV, tVsV, + self.load_KV, tma_atom_V, tVgV, tVsV, mbar_ptr + self.mbar_load_kv_full_offset, mbar_ptr + self.mbar_load_kv_empty_offset, self.tma_copy_v_bytes ) @@ -974,7 +998,10 @@ def mma( # of the while loop. # 3. gemm # sm100_utils.gemm(tiled_mma_qk, tStSs[stage], tSrQs[stage], tSrKi, zero_init=True) - gemm_Si[stage](tCrB=tSrKi, sB=sK[None, None, None, mma_kv_consumer_state.index]) + sK_cur = sK[None, None, None, mma_kv_consumer_state.index] + if const_expr(self.uneven_kv_smem): + sK_cur = self.offset_kv_smem(sK_cur, mma_kv_consumer_state.index, mma_kv_consumer_state.phase) + gemm_Si[stage](tCrB=tSrKi, sB=sK_cur) # 4. release S0 / S1 with cute.arch.elect_one(): tcgen05.commit(mbar_ptr + self.mbar_S_full_offset + stage) @@ -993,7 +1020,7 @@ def mma( # 1. wait for V0 pipeline_kv.consumer_wait(mma_kv_consumer_state) mma_kv_release_state = mma_kv_consumer_state.clone() - Vi_index = mma_kv_consumer_state.index + Vi_index, Vi_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase tOrVi = tOrV[None, None, None, Vi_index] for stage in cutlass.range_constexpr(2): # 2. acquire corrected O0/O1_partial and P0 / P1 @@ -1004,7 +1031,10 @@ def mma( # 3. gemm # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True) # gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate) - gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate, mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, mbar_phase= P_full_O_rescaled_phase) + sV_cur = sV[None, None, None, Vi_index] + if const_expr(self.uneven_kv_smem): + sV_cur = self.offset_kv_smem(sV_cur, Vi_index, Vi_phase) + gemm_Pi[stage](tCrB=tOrVi, sB=sV_cur, zero_init=not O_should_accumulate, mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, mbar_phase= P_full_O_rescaled_phase) # 4. release accumulated O0_partial / O1_partial # Don't need to signal O_full to the correction warps anymore since the # correction warps wait for the softmax warps anyway. By the time the softmax @@ -1023,13 +1053,16 @@ def mma( if const_expr(stage == 0): mma_kv_consumer_state.advance() pipeline_kv.consumer_wait(mma_kv_consumer_state) - Ki_index = mma_kv_consumer_state.index + Ki_index, Ki_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase # 2. gemm # Don't need to wait for the softmax warp to have finished reading the previous # Si, since this gemm is scheduled after the PV gemm, which guaranteed that Si # has been read and Pi has been written. # sm100_utils.gemm(tiled_mma_qk, tStS0, tSrQs[0], tSrK[None, None, None, Ki_index], zero_init=True) - gemm_Si[stage](tCrB=tSrK[None, None, None, Ki_index], sB=sK[None, None, None, Ki_index]) + sK_cur = sK[None, None, None, Ki_index] + if const_expr(self.uneven_kv_smem): + sK_cur = self.offset_kv_smem(sK_cur, Ki_index, Ki_phase) + gemm_Si[stage](tCrB=tSrK[None, None, None, Ki_index], sB=sK_cur) # 3. release S0 with cute.arch.elect_one(): tcgen05.commit(mbar_ptr + self.mbar_S_full_offset + stage) @@ -1049,7 +1082,7 @@ def mma( # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop # 1. wait for V0 pipeline_kv.consumer_wait(mma_kv_consumer_state) - Vi_index = mma_kv_consumer_state.index + Vi_index, Vi_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase tOrVi = tOrV[None, None, None, Vi_index] for stage in cutlass.range_constexpr(2): # 2. acquire corrected Oi_partial and Pi @@ -1057,7 +1090,10 @@ def mma( # 3. gemm # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True) # gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate) - gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate, mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, mbar_phase=P_full_O_rescaled_phase) + sV_cur = sV[None, None, None, Vi_index] + if const_expr(self.uneven_kv_smem): + sV_cur = self.offset_kv_smem(sV_cur, Vi_index, Vi_phase) + gemm_Pi[stage](tCrB=tOrVi, sB=sV_cur, zero_init=not O_should_accumulate, mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, mbar_phase=P_full_O_rescaled_phase) # 4. release accumulated O0_partial # We do need O_full here since for the last tile, by the time the softmax warp # has signaled to the correction warp, the softmax warp has just finished compute @@ -1431,6 +1467,9 @@ def correction_loop( cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + 1) + # Even in the case of self.overlap_sO_sQ, we can write to stage 0 of sO without + # additional sync because the MMA in the top half must have been done. + # Similarly we can write to stage 1 of sO without additional sync. stats = [None, None] for stage in cutlass.range_constexpr(2): cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_full_offset + stage, softmax_corr_consumer_phase) @@ -1730,33 +1769,63 @@ def epilogue_s2g( tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() - def load_QKV( + def load_Q( self, tma_atom: cute.CopyAtom, - tXgX: cute.Tensor, - tXsX: cute.Tensor, + tQgQ: cute.Tensor, + tQsQ: cute.Tensor, mbar_full_ptr: cute.Pointer, mbar_empty_ptr: cute.Pointer, tma_copy_bytes: int, block: Int32, - producer_state: Optional[cutlass.pipeline.PipelineState] = None, - stage: Optional[Int32] = None, - phase: Optional[Int32] = None, + stage: int, + phase: int, ): - if cutlass.const_expr(producer_state is not None): - stage, phase = producer_state.index, producer_state.phase - else: - assert stage is not None and phase is not None, "stage and phase must be provided if producer_state is None" cute.arch.mbarrier_wait(mbar_empty_ptr + stage, phase) with cute.arch.elect_one(): cute.arch.mbarrier_arrive_and_expect_tx(mbar_full_ptr + stage, tma_copy_bytes) cute.copy( - tma_atom, - tXgX[None, block], - tXsX[None, stage], - tma_bar_ptr=mbar_full_ptr + stage, + tma_atom, tQgQ[None, block], tQsQ[None, stage], tma_bar_ptr=mbar_full_ptr + stage ) + @cute.jit + def load_KV( + self, + tma_atom: cute.CopyAtom, + tXgX: cute.Tensor, + tXsX: cute.Tensor, + mbar_full_ptr: cute.Pointer, + mbar_empty_ptr: cute.Pointer, + tma_copy_bytes: int, + block: Int32, + producer_state: cutlass.pipeline.PipelineState, + ): + stage, phase = producer_state.index, producer_state.phase + cute.arch.mbarrier_wait(mbar_empty_ptr + stage, phase) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx(mbar_full_ptr + stage, tma_copy_bytes) + tXsX_cur = tXsX[None, stage] + # print(tXsX_cur) + if const_expr(self.uneven_kv_smem): + tXsX_cur = self.offset_kv_smem(tXsX_cur, stage, phase) + # print(tXsX_cur) + cute.copy(tma_atom, tXgX[None, block], tXsX_cur, tma_bar_ptr=mbar_full_ptr + stage) + + @cute.jit + # def offset_kv_smem(self, sX: cute.Tensor, state: cutlass.pipeline.PipelineState): + def offset_kv_smem(self, sX: cute.Tensor, stage: Int32, phase: Int32): + if const_expr(self.uneven_kv_smem): + # smem layout is [smem_large, smem_small, smem_large], and the current stride is + # (smem_large + smem_small) // 2. So for stage == 1, move right by offset if + # phase == 0, or left by offset if phase == 1. + # stage, phase = state.index, state.phase + offset = 0 if stage != 1 else self.uneven_kv_smem_offset * (1 - 2 * phase) + return cute.make_tensor(sX.iterator + offset, sX.layout) + # new_ptr = utils.ptr_offset_aligned(tXsX_cur.iterator, offset) + # tXsX_cur = cute.make_tensor(new_ptr, tXsX_cur.layout) + else: + return sX + def make_and_init_load_kv_pipeline(self, load_kv_mbar_ptr): load_kv_producer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, len([self.load_warp_id]) ) diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 253f1fd7007..9f966b1044f 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -39,7 +39,7 @@ # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128, 192]) # @pytest.mark.parametrize("d", [64, 128]) -@pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("d", [128, 192]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -81,7 +81,7 @@ def test_flash_attn_output( nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) - dv_vals = [d] if d != 128 else [64, d] + dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d]) if dtype == torch.float8_e4m3fn: dv_vals = [d] # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if not DISABLE_LOCAL else [0] @@ -251,7 +251,7 @@ def test_flash_attn_output( # @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128]) -@pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("d", [128, 192]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -292,7 +292,7 @@ def test_flash_attn_varlen_output( nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) - dv_vals = [d] if d != 128 else [64, d] + dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d]) if dtype == torch.float8_e4m3fn: dv_vals = [d] # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k and not DISABLE_LOCAL else [0] From 733730723b1ba54bbca3a3a26309db711cdbb633 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 1 Aug 2025 21:55:19 -0400 Subject: [PATCH 055/258] [Cute] Use kv_stage=3 for hdim (192,128) --- benchmarks/benchmark_attn.py | 31 +++++++++++++++------------ flash_attn/cute/flash_fwd_sm100.py | 34 ++++++++++++++---------------- 2 files changed, 34 insertions(+), 31 deletions(-) diff --git a/benchmarks/benchmark_attn.py b/benchmarks/benchmark_attn.py index 289518822ab..d6379b43510 100644 --- a/benchmarks/benchmark_attn.py +++ b/benchmarks/benchmark_attn.py @@ -93,10 +93,11 @@ def convert_to_cudnn_type(torch_type): def cudnn_spda_setup(q, k, v, causal=False, window_size_left=None): b, nheads, seqlen_q, headdim = q.shape _, nheads_k, seqlen_k, _ = k.shape - assert v.shape == (b, nheads_k, seqlen_k, headdim) + headdim_v = v.shape[-1] + assert v.shape == (b, nheads_k, seqlen_k, headdim_v) assert cudnn is not None, 'CUDNN is not available' q_gpu, k_gpu, v_gpu = q, k, v - o_gpu = torch.empty_like(q_gpu) + o_gpu = torch.empty((b, nheads, seqlen_q, headdim_v), dtype=q.dtype, device=q.device) stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device) graph = cudnn.pygraph( io_data_type=convert_to_cudnn_type(q.dtype), @@ -148,9 +149,10 @@ def run(*args, **kwargs): def cudnn_spda_bwd_setup(q, k, v, o, g, lse, causal=False, window_size_left=None): b, nheads, seqlen_q, headdim = q.shape _, nheads_k, seqlen_k, _ = k.shape - assert v.shape == (b, nheads_k, seqlen_k, headdim) - assert g.shape == (b, nheads, seqlen_q, headdim) - assert o.shape == (b, nheads, seqlen_q, headdim) + headdim_v = v.shape[-1] + assert v.shape == (b, nheads_k, seqlen_k, headdim_v) + assert g.shape == (b, nheads, seqlen_q, headdim_v) + assert o.shape == (b, nheads, seqlen_q, headdim_v) assert lse.shape == (b, nheads, seqlen_q, 1) assert cudnn is not None, 'CUDNN is not available' q_gpu, k_gpu, v_gpu, o_gpu, g_gpu = q, k, v, o, g @@ -265,7 +267,8 @@ def run(*args, **kwargs): nheads_kv = nheads # nheads_kv = nheads // 4 # nheads_kv = 1 - headdim_v = headdim + # headdim_v = headdim + headdim_v = 128 if headdim == 192 else headdim # headdim_v = 512 has_qv = headdim == 64 and headdim_v == 512 # has_qv = False @@ -318,9 +321,10 @@ def run(*args, **kwargs): nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size) if cudnn is not None: # if False: - if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v: + if headdim <= 256 and dtype != torch.float8_e4m3fn: cudnn_spda = cudnn_spda_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=causal, window_size_left=window_size[0]) - cudnn_spda_bwd = cudnn_spda_bwd_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), o.transpose(1, 2), g.transpose(1, 2), stats.transpose(1, 2), causal=causal, window_size_left=window_size[0]) + if has_backward and headdim == headdim_v: + cudnn_spda_bwd = cudnn_spda_bwd_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), o.transpose(1, 2), g.transpose(1, 2), stats.transpose(1, 2), causal=causal, window_size_left=window_size[0]) if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func is not None: # if False: if not varlen: @@ -341,13 +345,14 @@ def run(*args, **kwargs): if cudnn is not None: # if False: - if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v: + if headdim <= 256 and dtype != torch.float8_e4m3fn: time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark m2 = time_fwd(cudnn_spda, repeats=repeats, verbose=verbose, desc='CuDNN') time_f[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2.mean - time.sleep(1) - m2b = time_fwd(cudnn_spda_bwd, repeats=repeats, verbose=verbose, desc='CuDNN') - time_b[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2b.mean + if has_backward: + time.sleep(1) + m2b = time_fwd(cudnn_spda_bwd, repeats=repeats, verbose=verbose, desc='CuDNN') + time_b[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2b.mean # pytorch_profiler(cudnn_spda, backward=False) # pytorch_profiler(cudnn_spda_bwd, backward=False) time.sleep(1) @@ -387,7 +392,7 @@ def run(*args, **kwargs): print(f'FAv2 fwd: {m0.mean * 1e3:.3f}ms, {(nFLOPS / m0.mean * 1e-12):.1f} TFLOPS') if has_backward: print(f'FAv2 bwd: {m0b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m0b.mean * 1e-12):.1f} TFLOPS') - if cudnn is not None and headdim == headdim_v: + if cudnn is not None: print(f'CuDNN fwd: {m2.mean * 1e3:.3f}ms, {(nFLOPS / m2.mean * 1e-12):.1f} TFLOPS') if has_backward: print(f'CuDNN bwd: {m2b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m2b.mean * 1e-12):.1f} TFLOPS') diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index fd94f6e3b62..ee1c104333f 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -166,13 +166,10 @@ def _setup_attributes(self): self.q_stage = 2 self.kv_stage = 4 if self.q_dtype.width == 8 else 3 - # TODO: temp solution to get this to run as uneven_kv_smem isn't working yet - if self.head_dim_padded == 192 and self.head_dim_v_padded == 128: - self.kv_stage = 2 self.acc_stage = 1 self.epi_stage = 2 # For hdim 192,128, we don't have enough smem to store all 3 stages of KV: - # 128 x 192 x 2 bytes x 3 stages = 144KB, as we need 64KB for Q and 64 KB for O. + # 128 x 192 x 2 bytes x 3 stages = 144KB, and we need 96KB for Q. # Instead we store smem as [smem_large, smem_small, smem_large], where smem_large is # 128 x 192 and smem_small is 128 x 128. We set the stride between the stages to be # 128 * 160, so that indexing the 0th and 2nd stages will get the right address, @@ -884,7 +881,6 @@ def load( load_Q = partial( self.load_Q, tma_atom_Q, tQgQ, tQsQ, mbar_ptr + self.mbar_load_q_full_offset, mbar_ptr + self.mbar_load_q_empty_offset, - self.tma_copy_q_bytes, phase=q_producer_phase, ) # We have to use mbarrier directly in the load for KV instead of replying on @@ -892,12 +888,12 @@ def load( load_K = partial( self.load_KV, tma_atom_K, tKgK, tKsK, mbar_ptr + self.mbar_load_kv_full_offset, mbar_ptr + self.mbar_load_kv_empty_offset, - self.tma_copy_k_bytes + K_or_V="K", ) load_V = partial( self.load_KV, tma_atom_V, tVgV, tVsV, mbar_ptr + self.mbar_load_kv_full_offset, mbar_ptr + self.mbar_load_kv_empty_offset, - self.tma_copy_v_bytes + K_or_V="V", ) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) @@ -1361,7 +1357,8 @@ def softmax_step( cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), tSrS_t2r.layout, ) # softmax.scale_apply_exp2_convert(tSrS_t2r, row_max, tSrP_r2t) - softmax.apply_exp2_convert(tSrS_t2r, tSrP_r2t, e2e=mask_fn is None, e2e_freq=16 if self.head_dim_padded <= 64 else 16) + softmax.apply_exp2_convert(tSrS_t2r, tSrP_r2t, e2e=mask_fn is None and self.head_dim_padded <= 128, + e2e_freq=16 if self.head_dim_padded <= 64 else 16) # Sequence barrier arrive if const_expr(self.s0_s1_barrier): cute.arch.mbarrier_arrive(mbar_ptr + mbar_s0_s1_sequence_offset + (1 - stage) * 4) @@ -1776,14 +1773,13 @@ def load_Q( tQsQ: cute.Tensor, mbar_full_ptr: cute.Pointer, mbar_empty_ptr: cute.Pointer, - tma_copy_bytes: int, block: Int32, stage: int, phase: int, ): cute.arch.mbarrier_wait(mbar_empty_ptr + stage, phase) with cute.arch.elect_one(): - cute.arch.mbarrier_arrive_and_expect_tx(mbar_full_ptr + stage, tma_copy_bytes) + cute.arch.mbarrier_arrive_and_expect_tx(mbar_full_ptr + stage, self.tma_copy_q_bytes) cute.copy( tma_atom, tQgQ[None, block], tQsQ[None, stage], tma_bar_ptr=mbar_full_ptr + stage ) @@ -1796,33 +1792,35 @@ def load_KV( tXsX: cute.Tensor, mbar_full_ptr: cute.Pointer, mbar_empty_ptr: cute.Pointer, - tma_copy_bytes: int, block: Int32, producer_state: cutlass.pipeline.PipelineState, + K_or_V: str, ): + assert K_or_V in ("K", "V") + tma_copy_bytes = self.tma_copy_k_bytes if const_expr(K_or_V == "K") else self.tma_copy_v_bytes stage, phase = producer_state.index, producer_state.phase cute.arch.mbarrier_wait(mbar_empty_ptr + stage, phase) + if const_expr(K_or_V == "K" and self.uneven_kv_smem): + # Before this round, the smem location was occupied by V, which is smaller than + # K. So we need to wait for the stage after that (stage 1) to be empty as well. + if stage == 0: + cute.arch.mbarrier_wait(mbar_empty_ptr + 1, phase) with cute.arch.elect_one(): cute.arch.mbarrier_arrive_and_expect_tx(mbar_full_ptr + stage, tma_copy_bytes) tXsX_cur = tXsX[None, stage] - # print(tXsX_cur) if const_expr(self.uneven_kv_smem): - tXsX_cur = self.offset_kv_smem(tXsX_cur, stage, phase) - # print(tXsX_cur) + # Since this is the producer_state, the phase starts at 1, so we have to invert it + tXsX_cur = self.offset_kv_smem(tXsX_cur, stage, phase ^ 1) cute.copy(tma_atom, tXgX[None, block], tXsX_cur, tma_bar_ptr=mbar_full_ptr + stage) @cute.jit - # def offset_kv_smem(self, sX: cute.Tensor, state: cutlass.pipeline.PipelineState): def offset_kv_smem(self, sX: cute.Tensor, stage: Int32, phase: Int32): if const_expr(self.uneven_kv_smem): # smem layout is [smem_large, smem_small, smem_large], and the current stride is # (smem_large + smem_small) // 2. So for stage == 1, move right by offset if # phase == 0, or left by offset if phase == 1. - # stage, phase = state.index, state.phase offset = 0 if stage != 1 else self.uneven_kv_smem_offset * (1 - 2 * phase) return cute.make_tensor(sX.iterator + offset, sX.layout) - # new_ptr = utils.ptr_offset_aligned(tXsX_cur.iterator, offset) - # tXsX_cur = cute.make_tensor(new_ptr, tXsX_cur.layout) else: return sX From d6dbdaf1d978b05e0eb3653d5cef7c551f2a4e07 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 2 Aug 2025 00:56:15 -0400 Subject: [PATCH 056/258] [Cute] Simplify some variables, be more careful about self.q_stage --- flash_attn/cute/flash_fwd_sm100.py | 142 +++++++++++++---------------- 1 file changed, 65 insertions(+), 77 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index ee1c104333f..25430b8fcde 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -75,8 +75,11 @@ def __init__( self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded self.m_block_size = m_block_size self.n_block_size = n_block_size + self.q_stage = 2 + assert self.q_stage in [1, 2] + # 2 Q tile per CTA - self.cta_tiler = (2 * m_block_size, n_block_size, self.head_dim_padded) + self.cta_tiler = (self.q_stage * m_block_size, n_block_size, self.head_dim_padded) self.mma_tiler_qk = (m_block_size, n_block_size, self.head_dim_padded) self.mma_tiler_pv = (m_block_size, self.head_dim_v_padded, n_block_size) self.qk_acc_dtype = Float32 @@ -119,15 +122,12 @@ def __init__( self.tmem_alloc_sync_bar_id = 1 - self.tmem_s0_offset = 0 - self.tmem_s1_offset = self.tmem_s0_offset + self.n_block_size - self.tmem_o0_offset = self.tmem_s1_offset + self.n_block_size - self.tmem_o1_offset = self.tmem_o0_offset + self.head_dim_v_padded - self.tmem_total = self.tmem_o1_offset + self.head_dim_v_padded + self.tmem_s_offset = [0, self.n_block_size] # e.g., 0, 128 + self.tmem_o_offset = [self.tmem_s_offset[-1] + self.n_block_size + i * self.head_dim_v_padded for i in range(self.q_stage)] # e.g., 256, 384 + self.tmem_total = self.tmem_o_offset[-1] + self.head_dim_v_padded assert self.tmem_total <= SM100_TMEM_CAPACITY_COLUMNS - self.tmem_p_offset = 0 - self.tmem_p0_offset = self.tmem_s0_offset + self.tmem_p_offset - self.tmem_p1_offset = self.tmem_s1_offset + self.tmem_p_offset + self.tmem_s_to_p_offset = 0 + self.tmem_p_offset = [self.tmem_s_offset[i] + self.tmem_s_to_p_offset for i in range(2)] # 0, 128 # vec buffer for row_max & row_sum self.tmem_vec0_offset = 0 @@ -164,7 +164,6 @@ def _setup_attributes(self): - Configures pipeline stages for softmax, correction, and epilogue operations """ - self.q_stage = 2 self.kv_stage = 4 if self.q_dtype.width == 8 else 3 self.acc_stage = 1 self.epi_stage = 2 @@ -568,7 +567,7 @@ def kernel( for i in cutlass.range_constexpr(8): cute.arch.mbarrier_init(mbar_ptr + self.mbar_s0_s1_sequence_offset + i, cute.arch.WARP_SIZE) if warp_idx == 4: - for i in cutlass.range_constexpr(2): + for i in cutlass.range_constexpr(self.q_stage): cute.arch.mbarrier_init(mbar_ptr + self.mbar_corr_epi_full_offset + i, cute.arch.WARP_SIZE * len(self.correction_warp_ids)) cute.arch.mbarrier_init(mbar_ptr + self.mbar_corr_epi_empty_offset + i, cute.arch.WARP_SIZE * len(self.epilogue_warp_ids)) if warp_idx == 5: @@ -609,14 +608,17 @@ def kernel( else: sO = cute.make_tensor(cute.recast_ptr(sQ.iterator, sO_layout.inner), sO_layout.outer) - sScale = storage.sScale.get_tensor(cute.make_layout(256)) + sScale = storage.sScale.get_tensor(cute.make_layout( + 2 * self.m_block_size * (1 if const_expr(mLSE is None) else 2) + )) thr_mma_qk = tiled_mma_qk.get_slice(0) # default 1SM thr_mma_pv = tiled_mma_pv.get_slice(0) # default 1SM qk_acc_shape = thr_mma_qk.partition_shape_C((self.mma_tiler_qk[0], self.mma_tiler_qk[1])) tStS_fake = thr_mma_qk.make_fragment_C(qk_acc_shape) - # TODO: this is a fake tensor, need to retrieve tmem_ptr + # This is a fake tensor, by right need to retrieve tmem_ptr. But we know that we always + # request 512 columns of tmem, so we know that it starts at 0. tmem_ptr = cute.make_ptr(Float32, 0, mem_space=cute.AddressSpace.tmem, assumed_align=16) tStS = cute.make_tensor(tmem_ptr, tStS_fake.layout) @@ -624,25 +626,19 @@ def kernel( pv_acc_shape = thr_mma_pv.partition_shape_C((self.mma_tiler_pv[0], self.mma_tiler_pv[1])) tOtO = thr_mma_pv.make_fragment_C(pv_acc_shape) - tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) - tStS1 = cute.make_tensor(tStS.iterator + self.tmem_s1_offset, tStS.layout) - - tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) - tOtO1 = cute.make_tensor(tOtO.iterator + self.tmem_o1_offset, tOtO.layout) + tStSs = tuple(cute.make_tensor(tStS.iterator + self.tmem_s_offset[stage], tStS.layout) + for stage in range(2)) + tOtOs = tuple(cute.make_tensor(tOtO.iterator + self.tmem_o_offset[stage], tOtO.layout) + for stage in range(self.q_stage)) tP = cute.make_tensor(tStS.iterator, tP_layout.outer) tOrP = thr_mma_pv.make_fragment_A(tP)[None, None, None, 0] - tOrP0 = cute.make_tensor( + tOrPs = [cute.make_tensor( tOrP.iterator - + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p_offset[stage], tOrP.layout, - ) - tOrP1 = cute.make_tensor( - tOrP.iterator - + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p1_offset, - tOrP.layout, - ) + ) for stage in range(2)] block_info = BlockInfo( # This is cta_tiler, not mma_tiler_qk, since we move by block by (2 * mma_tiler[0], mma_tiler[1]) @@ -715,12 +711,9 @@ def kernel( sQ_layout.inner, sK_layout.inner, sV_layout.inner, - tStS0, - tStS1, - tOtO0, - tOtO1, - tOrP0, - tOrP1, + tStSs, + tOtOs, + tOrPs, pipeline_kv, mbar_ptr, block_info, @@ -771,16 +764,16 @@ def kernel( stage = Int32(0 if warp_idx < self.softmax1_warp_ids[0] else 1) softmax_loop( stage=stage, - tStSi=cute.make_tensor(tStS.iterator + (self.tmem_s0_offset if stage == 0 else self.tmem_s1_offset), tStS.layout)) + tStSi=cute.make_tensor(tStS.iterator + (self.tmem_s_offset[0] if stage == 0 else self.tmem_s_offset[1]), tStS.layout)) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) else: # If there's s0_s1_barrier, it's faster to have 2 WGs having different code if warp_idx < self.softmax1_warp_ids[0]: - tStSi = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + tStSi = cute.make_tensor(tStS.iterator + self.tmem_s_offset[0], tStS.layout) softmax_loop(stage=0, tStSi=tStSi) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) if warp_idx < self.correction_warp_ids[0] and warp_idx >= self.softmax1_warp_ids[0]: - tStSi = cute.make_tensor(tStS.iterator + self.tmem_s1_offset, tStS.layout) + tStSi = cute.make_tensor(tStS.iterator + self.tmem_s_offset[1], tStS.layout) softmax_loop(stage=1, tStSi=tStSi) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) @@ -793,8 +786,7 @@ def kernel( thr_mma_qk, thr_mma_pv, tStS, - tOtO0, - tOtO1, + tOtOs, sScale, mO, mLSE, @@ -897,10 +889,11 @@ def load( ) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) - load_Q(block=2 * m_block + 0, stage=0) # Q0 + load_Q(block=self.q_stage * m_block + 0, stage=0) # Q0 load_K(block=n_block_max - 1, producer_state=kv_producer_state) # K0 kv_producer_state.advance() - load_Q(block=2 * m_block + 1, stage=1) # Q1 + if const_expr(self.q_stage == 2): + load_Q(block=self.q_stage * m_block + 1, stage=1) # Q1 q_producer_phase ^= 1 load_V(block=n_block_max - 1, producer_state=kv_producer_state) # V0 kv_producer_state.advance() @@ -926,12 +919,9 @@ def mma( sQ_swizzle: cute.Swizzle, sK_swizzle: cute.Swizzle, sV_swizzle: cute.Swizzle, - tStS0: cute.Tensor, - tStS1: cute.Tensor, - tOtO0: cute.Tensor, - tOtO1: cute.Tensor, - tOrP0: cute.Tensor, - tOrP1: cute.Tensor, + tStSs: Tuple[cute.Tensor, cute.Tensor], + tOtOs: tuple[cute.Tensor], + tOrPs: Tuple[cute.Tensor, cute.Tensor], pipeline_kv: cutlass.pipeline.PipelineAsync, mbar_ptr: cute.Pointer, block_info: BlockInfo, @@ -943,17 +933,17 @@ def mma( tSrQ = thr_mma_qk.make_fragment_A(sQ) tSrK = thr_mma_qk.make_fragment_B(sK) tOrV = thr_mma_pv.make_fragment_B(sV) - tStSs = (tStS0, tStS1) - tSrQs = (tSrQ[None, None, None, 0], tSrQ[None, None, None, 1]) - tOrPs = (tOrP0, tOrP1) + if const_expr(self.q_stage == 2): + tSrQs = (tSrQ[None, None, None, 0], tSrQ[None, None, None, 1]) + else: + tSrQs = (tSrQ[None, None, None, 0], tSrQ[None, None, None, 0]) qk_mma_op, pv_mma_op = tiled_mma_qk.op, tiled_mma_pv.op gemm_Si = [ partial( sm100_utils.gemm_ptx_partial, - qk_mma_op, self.tmem_s0_offset if const_expr(stage == 0) else self.tmem_s1_offset, tSrQs[stage], - sA=sQ[None, None, None, stage], + qk_mma_op, self.tmem_s_offset[stage], tSrQs[stage], sA=sQ[None, None, None, stage], sA_swizzle=sQ_swizzle, sB_swizzle=sK_swizzle, zero_init=True ) for stage in range(2) @@ -961,7 +951,7 @@ def mma( gemm_Pi = [ partial( sm100_utils.gemm_ptx_partial, - pv_mma_op, self.tmem_o0_offset if const_expr(stage == 0) else self.tmem_o1_offset, tOrPs[stage], + pv_mma_op, self.tmem_o_offset[stage if self.q_stage == 2 else 0], tOrPs[stage], sA=None, sA_swizzle=None, sB_swizzle=sV_swizzle ) for stage in range(2) @@ -980,7 +970,7 @@ def mma( seqlen = SeqlenInfoCls(batch_idx) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) - for stage in cutlass.range_constexpr(2): + for stage in cutlass.range_constexpr(self.q_stage): # GEMM_QK00 (Q0 * K0 -> S0) or GEMM_QK01 (Q1 * K0 -> S1) # 1. wait for Q0 / Q1 cute.arch.mbarrier_wait(mbar_ptr + self.mbar_load_q_full_offset + stage, mma_q_consumer_phase) @@ -1072,8 +1062,8 @@ def mma( # release Q0 & Q1 with cute.arch.elect_one(): - tcgen05.commit(mbar_ptr + self.mbar_load_q_empty_offset + 0) - tcgen05.commit(mbar_ptr + self.mbar_load_q_empty_offset + 1) + for stage in cutlass.range_constexpr(self.q_stage): + tcgen05.commit(mbar_ptr + self.mbar_load_q_empty_offset + stage) # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop # 1. wait for V0 @@ -1113,8 +1103,7 @@ def mma( @cute.jit def softmax_loop( self, - stage: int, - # stage: Int32, + stage: int | Int32, softmax_scale_log2: Float32, thr_mma_qk: cute.core.ThrMma, tStSi: cute.Tensor, @@ -1154,7 +1143,7 @@ def softmax_loop( tilePlikeFP32 = self.mma_tiler_qk[1] // 32 * self.v_dtype.width tStP_layout = cute.composition(tStSi.layout, cute.make_layout((self.m_block_size, tilePlikeFP32))) - tStP = cute.make_tensor(tStSi.iterator + self.tmem_p_offset, tStP_layout) + tStP = cute.make_tensor(tStSi.iterator + self.tmem_s_to_p_offset, tStP_layout) tmem_load_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32, @@ -1283,7 +1272,6 @@ def softmax_loop( @cute.jit def softmax_step( self, - # stage: Int32, mma_si_consumer_phase: Int32, si_corr_producer_phase: Int32, s0_s1_sequence_phase: Int32, @@ -1299,7 +1287,7 @@ def softmax_step( tStScale_r2t: cute.Tensor, tStP_r2t: cute.Tensor, sScale: cute.Tensor, - stage: int, + stage: int | Int32, mask_fn: Optional[Callable] = None, is_first: bool = False, ) -> Tuple[cute.Int32, cute.Int32, cute.Int32]: @@ -1385,8 +1373,7 @@ def correction_loop( thr_mma_qk: cute.core.ThrMma, thr_mma_pv: cute.core.ThrMma, tStS: cute.Tensor, - tOtO0: cute.Tensor, - tOtO1: cute.Tensor, + tOtOs: tuple[cute.Tensor], sScale: cute.Tensor, mO: cute.Tensor, mLSE: cute.Tensor, @@ -1415,7 +1402,6 @@ def correction_loop( tStScale_1_t2r = thr_tmem_load_vec.partition_S(tStScale_1) tSrScale_t2r_shape = thr_tmem_load_vec.partition_D(tScS_vec).shape - tOtOs = [tOtO0, tOtO1] tStScales_t2r = [tStScale_0_t2r, tStScale_1_t2r] # First iter: no correction is required @@ -1455,7 +1441,9 @@ def correction_loop( # warps, S_i must have been done, so O_i-1 must have been done as well. # cute.arch.mbarrier_wait(mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase) if should_rescale: - self.correction_rescale(thr_mma_pv, tOtOs[stage], tidx, scale) + self.correction_rescale( + thr_mma_pv, tOtOs[stage if self.q_stage == 2 else 0], tidx, scale + ) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + (1 - stage)) softmax_corr_consumer_phase ^= 1 @@ -1467,8 +1455,8 @@ def correction_loop( # Even in the case of self.overlap_sO_sQ, we can write to stage 0 of sO without # additional sync because the MMA in the top half must have been done. # Similarly we can write to stage 1 of sO without additional sync. - stats = [None, None] - for stage in cutlass.range_constexpr(2): + stats = [None] * self.q_stage + for stage in cutlass.range_constexpr(self.q_stage): cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_full_offset + stage, softmax_corr_consumer_phase) # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r) # cute.arch.fence_view_async_tmem_load() @@ -1498,8 +1486,8 @@ def correction_loop( else: offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx]) - gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block * 2,)) - for stage in cutlass.range_constexpr(2): + gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (self.q_stage * m_block,)) + for stage in cutlass.range_constexpr(self.q_stage): row_sum, row_max, acc_O_mn_row_is_zero_or_nan = stats[stage] # if tidx == 0 and stage <= 1: # cute.printf("row_sum = {}, row_max = {}, acc_O_mn_row_is_zero_or_nan = {}\n", row_sum, row_max, acc_O_mn_row_is_zero_or_nan) @@ -1508,7 +1496,7 @@ def correction_loop( (row_max * softmax_scale_log2 + utils.log2f(row_sum)) * LN2 if not acc_O_mn_row_is_zero_or_nan else -Float32.inf ) - if tidx < seqlen.seqlen_q - (m_block * 2 + stage) * self.m_block_size: + if tidx < seqlen.seqlen_q - (self.q_stage * m_block + stage) * self.m_block_size: gLSE[tidx + stage * self.m_block_size] = lse o_corr_consumer_phase ^= 1 @@ -1526,12 +1514,12 @@ def correction_loop( # ) # warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 # stage = warp_idx_in_wg - # if stage < 2: + # if stage < self.q_stage: # # wait from corr, issue tma store on smem # # 1. wait for O0 / O1 final # cute.arch.mbarrier_wait(mbar_ptr + self.mbar_corr_epi_full_offset + stage, corr_epi_producer_phase) # # 2. copy O0 / O1 to gmem - # cute.copy(tma_atom_O, tOsO[None, stage], tOgO[None, 2 * m_block + stage]) + # cute.copy(tma_atom_O, tOsO[None, stage], tOgO[None, self.q_stage * m_block + stage]) # cute.arch.cp_async_bulk_commit_group() # # Ensure O0 / O1 buffer is ready to be released # cute.arch.cp_async_bulk_wait_group(0, read=True) @@ -1722,14 +1710,14 @@ def epilogue_s2g( cute.group_modes(sO, 0, 2), cute.group_modes(gO, 0, 2), ) - for stage in cutlass.range_constexpr(2): + for stage in cutlass.range_constexpr(self.q_stage): # wait from corr, issue tma store on smem # 1. wait for O0 / O1 final cute.arch.mbarrier_wait(mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase) # 2. copy O0 / O1 to gmem - cute.copy(tma_atom_O, tOsO[None, stage], tOgO[None, 2 * m_block + stage]) + cute.copy(tma_atom_O, tOsO[None, stage], tOgO[None, self.q_stage * m_block + stage]) cute.arch.cp_async_bulk_commit_group() - for stage in cutlass.range_constexpr(2): + for stage in cutlass.range_constexpr(self.q_stage): # Ensure O0 / O1 buffer is ready to be released cute.arch.cp_async_bulk_wait_group(1 - stage, read=True) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) @@ -1742,7 +1730,7 @@ def epilogue_s2g( tOcO = gmem_thr_copy_O.partition_S(cO) t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO) tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) - for stage in cutlass.range_constexpr(2): + for stage in cutlass.range_constexpr(self.q_stage): # wait from corr, issue tma store on smem # 1. wait for O0 / O1 final cute.arch.mbarrier_wait(mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase) @@ -1752,11 +1740,11 @@ def epilogue_s2g( cute.autovec_copy(tOsO[None, None, None, stage], tOrO) # copy acc O from rmem to gmem for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): - if t0OcO[0, rest_m, 0][0] < seqlen.seqlen_q - (m_block * 2 + stage) * self.m_block_size - tOcO[0][0]: + if t0OcO[0, rest_m, 0][0] < seqlen.seqlen_q - (self.q_stage * m_block + stage) * self.m_block_size - tOcO[0][0]: cute.copy( gmem_tiled_copy_O, tOrO[None, rest_m, None], - tOgO[None, rest_m, None, 2 * m_block + stage], + tOgO[None, rest_m, None, self.q_stage * m_block + stage], pred=tOpO[None, rest_m, None] if self.check_hdim_v_oob else None, ) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) @@ -1775,7 +1763,7 @@ def load_Q( mbar_empty_ptr: cute.Pointer, block: Int32, stage: int, - phase: int, + phase: Int32, ): cute.arch.mbarrier_wait(mbar_empty_ptr + stage, phase) with cute.arch.elect_one(): From b8eb683bc4b702d735186a652bf5ed147f92782c Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 3 Aug 2025 09:44:36 -0400 Subject: [PATCH 057/258] [Cute] Update to nvidia-cutlass-dsl==4.1.0 --- flash_attn/cute/flash_bwd.py | 14 ++++++------ flash_attn/cute/flash_bwd_postprocess.py | 6 ++--- flash_attn/cute/flash_bwd_preprocess.py | 4 ++-- flash_attn/cute/flash_fwd.py | 4 ++-- flash_attn/cute/flash_fwd_sm100.py | 28 ++++++++++-------------- flash_attn/cute/interface.py | 3 ++- flash_attn/cute/mask.py | 26 +++++++++++----------- flash_attn/cute/softmax.py | 8 +++---- flash_attn/cute/utils.py | 24 ++++---------------- 9 files changed, 49 insertions(+), 68 deletions(-) diff --git a/flash_attn/cute/flash_bwd.py b/flash_attn/cute/flash_bwd.py index 3ae61ba08dc..79f5ee8ec13 100644 --- a/flash_attn/cute/flash_bwd.py +++ b/flash_attn/cute/flash_bwd.py @@ -559,7 +559,7 @@ def kernel( smem_copy_atom_transposed, tiled_mma_dq, swapAB=self.dQ_swapAB ).get_slice(tidx) # TODO: what's the number of bits? What if SdP_swapAB - r2s_thr_copy_PdS = utils.make_tiled_copy_C( + r2s_thr_copy_PdS = cute.make_tiled_copy_C( cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=2 * self.dtype.width ), @@ -774,7 +774,7 @@ def load_dO_next(): # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_S_mn) # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == 1: cute.print_tensor(tLSErLSE) assert cute.size(acc_S_mn, mode=[0]) == cute.size(tLSErLSE) - for r in cutlass.range_constexpr(cute.size(acc_S_mn, mode=[0])): + for r in cutlass.range(cute.size(acc_S_mn, mode=[0]), unroll_full=True): acc_S_mn[r, None].store(utils.exp2f(acc_S_mn[r, None].load() * softmax_scale_log2 - tLSErLSE[r])) # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_S_mn) @@ -798,7 +798,7 @@ def load_dO_next(): acc_dP_mn = utils.make_acc_tensor_mn_view(acc_dP) # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn) assert cute.size(acc_dP_mn, mode=[0]) == cute.size(tLSErdPsum) - for r in cutlass.range_constexpr(cute.size(acc_dP_mn, mode=[0])): + for r in cutlass.range(cute.size(acc_dP_mn, mode=[0]), unroll_full=True): acc_dP_mn[r, None].store(acc_S_mn[r, None].load() * (acc_dP_mn[r, None].load() - tLSErdPsum[r])) # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn) rP = cute.make_fragment_like(acc_S, self.dtype) @@ -850,7 +850,7 @@ def dQ_mma(hook_fn): tdQgdQaccum_atomic = gmem_copy_params.tdQgdQaccum[None, None, m_block] assert cute.size(acc_dQ_atomic) == cute.size(tdQgdQaccum_atomic) # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(acc_dQ) - for i in cutlass.range_constexpr(cute.size(acc_dQ_atomic)): + for i in cutlass.range(cute.size(acc_dQ_atomic), unroll_full=True): utils.atomic_add_fp32(acc_dQ_atomic[i], utils.elem_pointer(tdQgdQaccum_atomic, i)) # utils.atomic_add_fp32(acc_dQ[i], tdQgdQaccum_atomic.iterator + i * tdQgdQaccum_atomic.stride[1]) # if cute.arch.thread_idx()[0] == 64 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dQ) @@ -910,7 +910,7 @@ def epilogue( smem_copy_atom_dKV = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=2 * self.dtype.width ) - smem_thr_copy_dKV = utils.make_tiled_copy_C(smem_copy_atom_dKV, tiled_mma).get_slice(tidx) + smem_thr_copy_dKV = cute.make_tiled_copy_C(smem_copy_atom_dKV, tiled_mma).get_slice(tidx) taccdVrdV = smem_thr_copy_dKV.retile(rdV) taccdKrdK = smem_thr_copy_dKV.retile(rdK) taccdVsdV = smem_thr_copy_dKV.partition_D(sdV) @@ -982,9 +982,9 @@ def epilogue( acc_dK_atomic = gmem_thr_copy_dK.retile(acc_dK) assert cute.size(acc_dV_atomic) == cute.size(tdVgdVaccum) assert cute.size(acc_dK_atomic) == cute.size(tdKgdKaccum) - for i in cutlass.range_constexpr(cute.size(acc_dV_atomic)): + for i in cutlass.range(cute.size(acc_dV_atomic), unroll_full=True): utils.atomic_add_fp32(acc_dV_atomic[i], utils.elem_pointer(tdVgdVaccum, i)) - for i in cutlass.range_constexpr(cute.size(acc_dK_atomic)): + for i in cutlass.range(cute.size(acc_dK_atomic), unroll_full=True): utils.atomic_add_fp32(acc_dK_atomic[i], utils.elem_pointer(tdKgdKaccum, i)) @cute.jit diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index 9136dcd8460..6a408906d53 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -265,7 +265,7 @@ def kernel( # print(acc) # print(tdQsdQaccum) # ((1, 1), 64) # print(tdQrdQaccum) # ((1, 4), 4, 4) - for i in cutlass.range_constexpr(cute.size(tdQsdQaccum)): + for i in cutlass.range(cute.size(tdQsdQaccum), unroll_full=True): tdQrdQaccum[i] = tdQsdQaccum[i] # Convert tdQrdQaccum from fp32 to fp16/bf16 rdQ = cute.make_fragment_like(acc, self.dtype) @@ -276,7 +276,7 @@ def kernel( smem_copy_atom_dQ = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=cutlass.Float32.width ) - smem_thr_copy_dQ = utils.make_tiled_copy_C(smem_copy_atom_dQ, tiled_mma).get_slice(tidx) + smem_thr_copy_dQ = cute.make_tiled_copy_C(smem_copy_atom_dQ, tiled_mma).get_slice(tidx) taccdQrdQ = smem_thr_copy_dQ.retile(rdQ) taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ) cute.copy(smem_copy_atom_dQ, taccdQrdQ, taccdQsdQ) @@ -296,7 +296,7 @@ def kernel( cdQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) tdQcdQ = gmem_thr_copy_dQ.partition_S(cdQ) tdQpdQ = utils.predicate_k(tdQcdQ, limit=mdQ.shape[3]) - for rest_m in cutlass.range_constexpr(cute.size(tdQrdQ.shape[1])): + for rest_m in cutlass.range(cute.size(tdQrdQ.shape[1]), unroll_full=True): if tdQcdQ[0, rest_m, 0][0] < mdQ.shape[1] - m_block * self.m_block_size: cute.copy( gmem_tiled_copy_dQ, diff --git a/flash_attn/cute/flash_bwd_preprocess.py b/flash_attn/cute/flash_bwd_preprocess.py index 7a2734ec205..a5da7b7009e 100644 --- a/flash_attn/cute/flash_bwd_preprocess.py +++ b/flash_attn/cute/flash_bwd_preprocess.py @@ -233,7 +233,7 @@ def kernel( assert cute.size(tOgO, mode=[0]) == cute.size(tOgdO, mode=[0]) assert cute.size(tOgO, mode=[1]) == cute.size(tOgdO, mode=[1]) assert cute.size(tOgO, mode=[2]) == cute.size(tOgdO, mode=[2]) - for m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): + for m in cutlass.range(cute.size(tOrO.shape[1]), unroll_full=True): # Instead of using tOcO, we using t0OcO and subtract the offset from the limit # (seqlen_q - m_block * kBlockM). This is because the entries of t0OcO are known at compile time. if t0OcO[0, m, 0][0] < seqlen_q - m_block * self.m_block_size - tOcO[0][0]: @@ -263,7 +263,7 @@ def kernel( ) # Only the thread corresponding to column 0 writes out the lse to gmem if tOcO[0, 0, 0][1] == 0: - for m in cutlass.range_constexpr(cute.size(dP_sum)): + for m in cutlass.range(cute.size(dP_sum), unroll_full=True): row = tOcO[0, m, 0][0] gdPsum[row] = dP_sum[m] if row < mO.shape[1] - m_block * self.m_block_size else 0.0 diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 3c0651f7893..311540abaf7 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -289,7 +289,7 @@ def epilogue( # Make sure all threads have finished reading V cute.arch.barrier(barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads) smem_copy_atom_O = utils.get_smem_store_atom(self.arch, self.dtype) - smem_thr_copy_O = utils.make_tiled_copy_C(smem_copy_atom_O, tiled_mma).get_slice(tidx) + smem_thr_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma).get_slice(tidx) taccOrO = smem_thr_copy_O.retile(rO) taccOsO = smem_thr_copy_O.partition_D(sO) # copy acc O from rmem to smem with the smem copy atom @@ -1539,7 +1539,7 @@ def mma( # Smem copy atom tiling # /////////////////////////////////////////////////////////////////////////////// smem_copy_atom_P = utils.get_smem_store_atom(self.arch, self.dtype) - smem_thr_copy_P = utils.make_tiled_copy_C(smem_copy_atom_P, tiled_mma_qk).get_slice(tidx) + smem_thr_copy_P = cute.make_tiled_copy_C(smem_copy_atom_P, tiled_mma_qk).get_slice(tidx) # tPsP = smem_thr_copy_P.partition_D(sP_pi) if const_expr(sP_pi is not None) else None tPsP = smem_thr_copy_P.partition_D(sP) if const_expr(sP is not None) else None # if cute.arch.thread_idx()[0] == 0: diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 25430b8fcde..f17a489bd6f 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -126,12 +126,11 @@ def __init__( self.tmem_o_offset = [self.tmem_s_offset[-1] + self.n_block_size + i * self.head_dim_v_padded for i in range(self.q_stage)] # e.g., 256, 384 self.tmem_total = self.tmem_o_offset[-1] + self.head_dim_v_padded assert self.tmem_total <= SM100_TMEM_CAPACITY_COLUMNS - self.tmem_s_to_p_offset = 0 + self.tmem_s_to_p_offset = self.n_block_size // 2 self.tmem_p_offset = [self.tmem_s_offset[i] + self.tmem_s_to_p_offset for i in range(2)] # 0, 128 # vec buffer for row_max & row_sum - self.tmem_vec0_offset = 0 - self.tmem_vec1_offset = self.tmem_vec0_offset + self.n_block_size + self.tmem_vec_offset = self.tmem_s_offset if self.head_dim_padded < 96: self.num_regs_softmax = 200 @@ -1323,11 +1322,11 @@ def softmax_step( mask_fn(tSrS_t2r, n_block=n_block) row_max, acc_scale = softmax.update_row_max(tSrS_t2r.load(), is_first) - # tSrScale_r2t = cute.make_fragment(thr_tmem_store_scale.partition_S(tScS_vec).shape, Float32) - # tSrScale_r2t[0] = acc_scale - # cute.copy(thr_tmem_store_scale, tSrScale_r2t, tStScale_r2t) - # cute.arch.fence_view_async_tmem_store() if const_expr(not is_first): + # tSrScale_r2t = cute.make_fragment(thr_tmem_store_scale.partition_S(tScS_vec).shape, Float32) + # tSrScale_r2t[0] = acc_scale + # cute.copy(thr_tmem_store_scale, tSrScale_r2t, tStScale_r2t) + # cute.arch.fence_view_async_tmem_store() thread_idx = thr_tmem_load.thr_idx sScale[thread_idx + stage * self.m_block_size] = acc_scale # if thread_idx == 0: cute.printf("softmax acc_scale stage %d: %f, row_max = %f\n", stage, acc_scale, row_max) @@ -1387,23 +1386,20 @@ def correction_loop( ): tScS = thr_mma_qk.partition_C(cute.make_identity_tensor((self.mma_tiler_qk[0], self.mma_tiler_qk[1]))) tStS_scale_layout = cute.composition(tStS.layout, cute.make_layout((self.m_block_size, 1))) - tStScale_0 = cute.make_tensor(tStS.iterator + self.tmem_vec0_offset, tStS_scale_layout) - tStScale_1 = cute.make_tensor(tStS.iterator + self.tmem_vec1_offset, tStS_scale_layout) + tStScales = tuple(cute.make_tensor(tStS.iterator + self.tmem_vec_offset[stage], tStS_scale_layout) + for stage in range(2)) tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((self.m_block_size, 1))) tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) tmem_load_v_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(1)), self.qk_acc_dtype, ) - tiled_tmem_load_vec = tcgen05.make_tmem_copy(tmem_load_v_atom, tStScale_0) + tiled_tmem_load_vec = tcgen05.make_tmem_copy(tmem_load_v_atom, tStScales[0]) tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.correction_warp_ids)) thr_tmem_load_vec = tiled_tmem_load_vec.get_slice(tidx) - tStScale_0_t2r = thr_tmem_load_vec.partition_S(tStScale_0) - tStScale_1_t2r = thr_tmem_load_vec.partition_S(tStScale_1) + tStScales_t2r = [thr_tmem_load_vec.partition_S(tStScales[stage]) for stage in range(2)] tSrScale_t2r_shape = thr_tmem_load_vec.partition_D(tScS_vec).shape - tStScales_t2r = [tStScale_0_t2r, tStScale_1_t2r] - # First iter: no correction is required cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + 0) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + 1) @@ -1430,9 +1426,9 @@ def correction_loop( for stage in cutlass.range_constexpr(2): # wait for S0 / S1 cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_full_offset + stage, softmax_corr_consumer_phase) - # cute.copy(tiled_tmem_load_vec, tStScale_1_t2r, tSrScale_t2r) + # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r) # cute.arch.fence_view_async_tmem_load() - # scale = tSrScale_t2r[stage] + # scale = tSrScale_t2r[0] scale = sScale[tidx + stage * self.m_block_size] should_rescale = cute.arch.vote_ballot_sync(scale < 1.0) != 0 # should_rescale = True diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 8ede8958dbe..624c325f764 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -1,11 +1,12 @@ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'd need to install nvidia-cutlass-dsl==4.1.0.dev0. +# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'd need to install nvidia-cutlass-dsl==4.1.0. # Supported features: # - BF16 & FP16 dtype # - noncausal & causal attention # - MHA, GQA, MQA # - hdim 64, 96, 128. +# - (hdim_qk, hdim_v) = (192, 128) for Blackwell (i.e. DeepSeek shape) # - varlen # - sliding window # - bwd pass for Ampere (will also run on Hopper/Blackwell, but will be slow) diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index ab795c15da0..1415cf1b65c 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -42,11 +42,11 @@ def apply_mask( if cutlass.const_expr(not mask_causal and not mask_local): if cutlass.const_expr(mask_seqlen): # traverse column index. - for c in cutlass.range_constexpr(cute.size(tScS_mn.shape[1])): + for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): # if t0ScS_mn[0, c][1] >= seqlenk_col_limit: # acc_S_mn[None, c].fill(-cutlass.Float32.inf) oob = t0ScS_mn[0, c][1] >= seqlenk_col_limit - for r in cutlass.range_constexpr(cute.size(tScS_mn.shape[0])): + for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): acc_S_mn[r, c] = -cutlass.Float32.inf if oob else acc_S_mn[r, c] else: # Causal or local # If PackGQA, we split the work of compute divmod among threads in the same row @@ -64,7 +64,7 @@ def apply_mask( 1 + self.seqlen_k - n_block * self.n_block_size - self.seqlen_q - thr_col_offset ) if cutlass.const_expr(mask_causal): - for r in cutlass.range_constexpr(cute.size(tScS_mn.shape[0])): + for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): # get the column index limit based on current row. Only consider the row index, so the column index sets to 0. if cutlass.const_expr(self.qhead_per_kvhead_packgqa == 1): row_idx = tScS_mn[r, 0][0] + m_block * self.m_block_size @@ -76,7 +76,7 @@ def apply_mask( if cutlass.const_expr(mask_seqlen): col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) # traverse column index. - for c in cutlass.range_constexpr(cute.size(tScS_mn.shape[1])): + for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): # only consider the column index, so the row index sets to 0. # if t0ScS_mn[0, c][1] >= col_limit_right: # acc_S_mn[r, c] = -cutlass.Float32.inf @@ -92,7 +92,7 @@ def apply_mask( if cutlass.const_expr(self.window_size_left is not None) else None ) - for r in cutlass.range_constexpr(cute.size(tScS_mn.shape[0])): + for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): if cutlass.const_expr(self.qhead_per_kvhead_packgqa == 1): row_idx = tScS_mn[r, 0][0] + m_block * self.m_block_size else: @@ -110,7 +110,7 @@ def apply_mask( ) # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block = {}, r = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", n_block, r, row_idx, causal_row_offset, col_limit_right, col_limit_left) # traverse column index. - for c in cutlass.range_constexpr(cute.size(tScS_mn.shape[1])): + for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): col_idx = t0ScS_mn[0, c][1] # only consider the column index, so the row index sets to 0. if col_idx >= col_limit_right or col_idx < col_limit_left: @@ -137,7 +137,7 @@ def apply_mask_sm100( if cutlass.const_expr(mask_seqlen): ncol = cutlass.const_expr(cute.size(tScS_t2r.shape)) if cutlass.const_expr(not ncol % 16 == 0): - for i in cutlass.range_constexpr(ncol): + for i in cutlass.range(ncol, unroll_full=True): # if tScS_t2r[i][1] >= seqlenk_col_limit: # acc_S[i] = -cutlass.Float32.inf # For some reason the 2 lines above generate really bad SASS @@ -149,14 +149,14 @@ def apply_mask_sm100( # We know that tScS_t2r[i][1] == i, for the particular tmem copy atom we're using # Ideally we'd move by 32 instead of 16, but mask >> i isn't correct for i == 31 # (see below). - for s in cutlass.range_constexpr(ncol // 16): + for s in cutlass.range(ncol // 16, unroll_full=True): col_limit_right_s = seqlenk_col_limit - s * 16 # Don't need to clamp to 32 since the shr.u32 instruction does that already col_limit_right_cur = cutlass.Uint32(max(col_limit_right_s, 0)) # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11 mask = cutlass.Uint32((1 << col_limit_right_cur) - 1) # if tidx == 0: cute.printf("mask = 0x%x, col_limit_right_s = %d, col_limit_right_cur = %d", mask, col_limit_right_s, col_limit_right_cur) - for i in cutlass.range_constexpr(16): + for i in cutlass.range(16, unroll_full=True): # mask >> i does not produce correct result for 0b11..11 >> 31 # However, if we use utils.shr_u32, the compiler doesn't generate # the R2P instruction, so it's slower. @@ -181,19 +181,19 @@ def apply_mask_sm100( # cute.printf("tidx = %d, tidx tmem = %d, row_idx = %d, col_limit_right = %d, causal_row_offset = %d\n", cute.arch.thread_idx()[0], thr_tmem_load.thr_idx, row_idx, col_limit_right, causal_row_offset) ncol = cutlass.const_expr(cute.size(tScS_t2r.shape)) if cutlass.const_expr(not ncol % 16 == 0): - for i in cutlass.range_constexpr(ncol): + for i in cutlass.range(ncol, unroll_full=True): acc_S[i] = ( -cutlass.Float32.inf if tScS_t2r[i][1] >= col_limit_right else acc_S[i] ) else: # Bit manipulation, compiles down to the R2P instruction # We know that tScS_t2r[i][1] == i, for the particular tmem copy atom we're using - for s in cutlass.range_constexpr(ncol // 16): + for s in cutlass.range(ncol // 16, unroll_full=True): col_limit_right_s = col_limit_right - s * 16 col_limit_right_cur = cutlass.Uint32(max(col_limit_right_s, 0)) # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11 mask = cutlass.Uint32((1 << col_limit_right_cur) - 1) - for i in cutlass.range_constexpr(16): + for i in cutlass.range(16, unroll_full=True): # mask_i_bit = cutlass.Boolean(utils.shr_u32(mask, i) & 1) mask_i_bit = cutlass.Boolean((mask >> i) & 1) acc_S[s * 16 + i] = acc_S[s * 16 + i] if mask_i_bit else -cutlass.Float32.inf @@ -220,7 +220,7 @@ def apply_mask_sm100( row_idx + local_row_offset_left if cutlass.const_expr(self.window_size_left is not None) else 0 ) # if cute.arch.thread_idx()[0] == 0 or cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", m_block, n_block, row_idx, causal_row_offset, col_limit_right, col_limit_left) - for i in cutlass.range_constexpr(cute.size(tScS_t2r.shape)): + for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True): col_idx = tScS_t2r[i][1] acc_S[i] = ( -cutlass.Float32.inf diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index 5799cd4bd98..e0407e99cdf 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -55,7 +55,7 @@ def online_softmax( acc_S_mn = utils.make_acc_tensor_mn_view(acc_S) row_scale = cute.make_fragment_like(self.row_max, Float32) # Each iteration processes one row of acc_S - for r in cutlass.range_constexpr(cute.size(self.row_max)): + for r in cutlass.range(cute.size(self.row_max), unroll_full=True): acc_S_row = acc_S_mn[r, None].load() # (n_block_size) row_max_cur = self._compute_row_max( acc_S_row, @@ -89,7 +89,7 @@ def finalize(self, final_scale: Float32 = 1.0) -> cute.Tensor: # quad reduction for row_sum as we didn't do it during each iteration of online softmax self.row_sum.store(utils.warp_reduce(self.row_sum.load(), operator.add, width=4)) row_scale = cute.make_fragment_like(self.row_max, Float32) - for r in cutlass.range_constexpr(cute.size(self.row_sum)): + for r in cutlass.range(cute.size(self.row_sum), unroll_full=True): # if row_sum is zero or nan, set acc_O_mn_row to 1.0 acc_O_mn_row_is_zero_or_nan = ( self.row_sum[r] == 0.0 or self.row_sum[r] != self.row_sum[r] @@ -116,7 +116,7 @@ def rescale_O(self, acc_O: cute.Tensor, row_scale: cute.Tensor) -> None: """ acc_O_mn = utils.make_acc_tensor_mn_view(acc_O) assert cute.size(row_scale) == cute.size(acc_O_mn, mode=[0]) - for r in cutlass.range_constexpr(cute.size(row_scale)): + for r in cutlass.range(cute.size(row_scale), unroll_full=True): acc_O_mn[r, None].store(acc_O_mn[r, None].load() * row_scale[r]) @@ -162,7 +162,7 @@ def scale_subtract_rowmax( ): assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" row_max_scaled = row_max * self.scale_log2 - for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2): + for i in cutlass.range(0, cute.size(acc_S_row.shape), 2, unroll_full=True): acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( (acc_S_row[i], acc_S_row[i + 1]), (self.scale_log2, self.scale_log2), diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index fbd836be1d9..4f0adb8dd42 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -26,34 +26,18 @@ def make_tiled_copy_A( copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False ) -> cute.TiledCopy: if cutlass.const_expr(swapAB): - return make_tiled_copy_B(copy_atom, tiled_mma) + return cute.make_tiled_copy_B(copy_atom, tiled_mma) else: - return cute.make_tiled_copy( - copy_atom, - layout_tv=tiled_mma.tv_layout_A_tiled, - tiler_mn=(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(2)), - ) + return cute.make_tiled_copy_A(copy_atom, tiled_mma) def make_tiled_copy_B( copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False ) -> cute.TiledCopy: if cutlass.const_expr(swapAB): - return make_tiled_copy_A(copy_atom, tiled_mma) + return cute.make_tiled_copy_A(copy_atom, tiled_mma) else: - return cute.make_tiled_copy( - copy_atom, - layout_tv=tiled_mma.tv_layout_B_tiled, - tiler_mn=(tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)), - ) - - -def make_tiled_copy_C(copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma) -> cute.TiledCopy: - return cute.make_tiled_copy( - copy_atom, - layout_tv=tiled_mma.tv_layout_C_tiled, - tiler_mn=(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(1)), - ) + return cute.make_tiled_copy_B(copy_atom, tiled_mma) def mma_make_fragment_A( From cc5c5745038b160615ba0a38878612affef147e3 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 29 Jul 2025 18:10:32 -0400 Subject: [PATCH 058/258] [Cute] Implement additive sink for fwd_sm100 --- flash_attn/cute/flash_fwd_sm100.py | 10 +++++++++- flash_attn/cute/interface.py | 26 ++++++++++++++++++++------ flash_attn/utils/testing.py | 11 ++++++++++- tests/cute/test_flash_attn.py | 26 ++++++++++++++++++++++++-- 4 files changed, 63 insertions(+), 10 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index f17a489bd6f..3db9153378d 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -193,6 +193,7 @@ def __call__( softcap: Float32 | float | None = None, window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, + additive_sink: Optional[cute.Tensor] = None, ): """Execute the Fused Multi-Head Attention operation on the provided tensors. @@ -473,6 +474,7 @@ class SharedStorage: softcap_val, window_size_left, window_size_right, + additive_sink, sQ_layout, sK_layout, tP_layout, @@ -512,6 +514,7 @@ def kernel( softcap_val: Optional[Float32], window_size_left: Optional[Int32], window_size_right: Optional[Int32], + additive_sink: Optional[cute.Tensor], sQ_layout: cute.ComposedLayout, sK_layout: cute.ComposedLayout, tP_layout: cute.ComposedLayout, @@ -790,6 +793,7 @@ def kernel( mO, mLSE, sO, + additive_sink, tma_atom_O, mbar_ptr, softmax_scale_log2, @@ -1377,6 +1381,7 @@ def correction_loop( mO: cute.Tensor, mLSE: cute.Tensor, sO: cute.Tensor, + additive_sink: Optional[cute.Tensor], tma_atom_O: cute.CopyAtom, mbar_ptr: cute.Pointer, softmax_scale_log2: Float32, @@ -1452,17 +1457,20 @@ def correction_loop( # additional sync because the MMA in the top half must have been done. # Similarly we can write to stage 1 of sO without additional sync. stats = [None] * self.q_stage + add_sink_val = additive_sink[head_idx] if const_expr(additive_sink is not None) else None for stage in cutlass.range_constexpr(self.q_stage): cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_full_offset + stage, softmax_corr_consumer_phase) # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r) # cute.arch.fence_view_async_tmem_load() # scale = tSrScale_t2r[0] row_sum = sScale[tidx + stage * self.m_block_size] - if const_expr(mLSE is not None): + if const_expr(mLSE is not None or additive_sink is not None): row_max = sScale[tidx + stage * self.m_block_size + self.m_block_size * 2] else: row_max = None cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage) + if const_expr(additive_sink is not None): + row_sum += add_sink_val * utils.exp2f(-row_max * softmax_scale_log2) acc_O_mn_row_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum stats[stage] = (row_sum, row_max, acc_O_mn_row_is_zero_or_nan) scale = cute.arch.rcp_approx(row_sum if not acc_O_mn_row_is_zero_or_nan else 1.0) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 624c325f764..e60b42f0304 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -62,6 +62,7 @@ def _flash_attn_fwd( softcap: Optional[float] = None, window_size_left: Optional[int] = None, window_size_right: Optional[int] = None, + additive_sink: Optional[torch.Tensor] = None, # m_block_size: int = 128, # n_block_size: int = 64, # num_threads: int = 128, @@ -98,7 +99,10 @@ def _flash_attn_fwd( if t is not None: assert t.dtype == torch.int32, "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be int32" assert t.stride(0) == 1, "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be contiguous" - assert all(t is None or t.is_cuda for t in (q, k, v, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)), "inputs must be on CUDA device" + if additive_sink is not None: + assert additive_sink.shape == (num_head,) + assert additive_sink.dtype == torch.float32 + assert all(t is None or t.is_cuda for t in (q, k, v, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, additive_sink)), "inputs must be on CUDA device" assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" assert head_dim <= 256, "head_dim must be less than or equal to 256" alignment = 16 // q.element_size() @@ -125,9 +129,9 @@ def _flash_attn_fwd( ) for t in (q, k, v, out) ] lse_tensor = utils.convert_from_dlpack(lse, leading_dim=lse.ndim - 1, alignment=4) if lse is not None else None - cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor = [ + cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, additive_sink_tensor = [ from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) if t is not None else None - for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) + for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, additive_sink) ] if causal: window_size_right = 0 @@ -149,11 +153,13 @@ def _flash_attn_fwd( dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, softcap is not None, lse is None, cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, seqused_k is None, window_size_left is not None, window_size_right is not None, + additive_sink is not None, m_block_size, n_block_size, num_threads, compute_capability, ) if compile_key not in _flash_attn_fwd.compile_cache: if compute_capability == 9: + assert additive_sink is None, "Sm90 doesn't support additive sink" # fa_fwd = FlashAttentionForwardSm80( fa_fwd = FlashAttentionForwardSm90( dtype, @@ -185,12 +191,12 @@ def _flash_attn_fwd( _flash_attn_fwd.compile_cache[compile_key] = cute.compile( fa_fwd, q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream, cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, - softcap, window_size_left, window_size_right, + softcap, window_size_left, window_size_right, additive_sink_tensor, ) _flash_attn_fwd.compile_cache[compile_key]( q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream, cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, - softcap, window_size_left, window_size_right, + softcap, window_size_left, window_size_right, additive_sink_tensor, ) return out, lse @@ -394,6 +400,7 @@ def forward( softmax_scale: Optional[float] = None, causal: bool = False, window_size: Tuple[Optional[int], Optional[int]] = (None, None), + additive_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, ): out, lse = _flash_attn_fwd( @@ -404,6 +411,7 @@ def forward( causal=causal, window_size_left=window_size[0], window_size_right=window_size[1], + additive_sink=additive_sink, softcap=softcap, ) ctx.save_for_backward(q, k, v, out, lse) @@ -427,7 +435,7 @@ def backward(ctx, dout, *args): ctx.causal, ctx.softcap, ) - return dq, dk, dv, *((None,) * 4) + return dq, dk, dv, *((None,) * 5) class FlashAttnVarlenFunc(torch.autograd.Function): @@ -445,6 +453,7 @@ def forward( softmax_scale: Optional[float] = None, causal: bool = False, window_size: Tuple[Optional[int], Optional[int]] = (None, None), + additive_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, ): out, lse = _flash_attn_fwd( @@ -459,6 +468,7 @@ def forward( causal=causal, window_size_left=window_size[0], window_size_right=window_size[1], + additive_sink=additive_sink, softcap=softcap, ) ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) @@ -483,6 +493,7 @@ def flash_attn_func( softmax_scale: Optional[float] = None, causal: bool = False, window_size: Tuple[Optional[int], Optional[int]] = (None, None), + additive_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, ): return FlashAttnFunc.apply( @@ -492,6 +503,7 @@ def flash_attn_func( softmax_scale, causal, window_size, + additive_sink, softcap, ) @@ -507,6 +519,7 @@ def flash_attn_varlen_func( softmax_scale: Optional[float] = None, causal: bool = False, window_size: Tuple[Optional[int], Optional[int]] = (None, None), + additive_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, ): return FlashAttnVarlenFunc.apply( @@ -520,5 +533,6 @@ def flash_attn_varlen_func( softmax_scale, causal, window_size, + additive_sink, softcap, ) diff --git a/flash_attn/utils/testing.py b/flash_attn/utils/testing.py index b2c03addd2b..984940e818c 100644 --- a/flash_attn/utils/testing.py +++ b/flash_attn/utils/testing.py @@ -1,5 +1,6 @@ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. import math +from typing import Optional import torch from einops import rearrange, repeat @@ -240,6 +241,7 @@ def attention_ref( window_size=(None, None), attention_chunk=0, sink_token_length=0, + additive_sink: Optional[torch.Tensor] = None, softcap=0.0, upcast=True, reorder_ops=False, @@ -323,7 +325,14 @@ def attention_ref( scores.masked_fill_(local_mask, float("-inf")) if attn_bias is not None: scores = scores + attn_bias - attention = torch.softmax(scores, dim=-1).to(v.dtype) + if additive_sink is None: + attention = torch.softmax(scores, dim=-1).to(v.dtype) + else: + scores_fp32 = scores.to(torch.float32) + row_max = torch.amax(scores, dim=-1, keepdim=True) + numerator = torch.exp(scores_fp32 - row_max) + row_sum = torch.sum(numerator, dim=-1, keepdim=True) + rearrange(additive_sink, "h -> h 1 1") * torch.exp(-row_max) + attention = (numerator / row_sum).to(v.dtype) # We want to mask here so that the attention matrix doesn't have any NaNs # Otherwise we'll get NaN in dV if query_padding_mask is not None: diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 9f966b1044f..65692cfba0d 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -21,6 +21,8 @@ @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) +@pytest.mark.parametrize("has_additive_sink", [False, True]) +# @pytest.mark.parametrize("has_additive_sink", [False]) # @pytest.mark.parametrize("has_qv", [False, True]) @pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @@ -67,7 +69,7 @@ ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) def test_flash_attn_output( - seqlen_q, seqlen_k, d, causal, local, softcap, deterministic, has_qv, mha_type, dtype + seqlen_q, seqlen_k, d, causal, local, softcap, deterministic, has_qv, has_additive_sink, mha_type, dtype ): if (causal or local) and seqlen_k < seqlen_q: pytest.skip("Causal attention requires seqlen_k >= seqlen_q") @@ -101,6 +103,11 @@ def test_flash_attn_output( # Put window_size after QKV randn so that window_size changes from test to test window_size = (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() # window_size = (-1, -1) if not local else (16, 0) + if has_additive_sink: + # We don't want negative here + additive_sink = torch.rand(nheads, dtype=torch.float32, device=device) * 5.0 + else: + additive_sink = None if dtype == torch.float8_e4m3fn: q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] else: @@ -118,6 +125,7 @@ def test_flash_attn_output( q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, + additive_sink=additive_sink, softcap=softcap ) out_pt, attn_pt = attention_ref( @@ -131,6 +139,7 @@ def test_flash_attn_output( q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, + additive_sink=additive_sink, softcap=softcap, upcast=False, reorder_ops=True, @@ -168,6 +177,7 @@ def test_flash_attn_output( window_size=window_size, # attention_chunk=attention_chunk, softcap=softcap, + additive_sink=additive_sink, # pack_gqa=pack_gqa, # num_splits=num_splits ) @@ -189,6 +199,7 @@ def test_flash_attn_output( and softcap == 0.0 and not local and dv == d + and additive_sink is None # and False ): g = torch.randn_like(out) @@ -233,6 +244,8 @@ def test_flash_attn_output( @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) +@pytest.mark.parametrize("has_additive_sink", [False, True]) +# @pytest.mark.parametrize("has_additive_sink", [False]) # @pytest.mark.parametrize("has_qv", [False, True]) @pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @@ -278,7 +291,7 @@ def test_flash_attn_output( ], ) def test_flash_attn_varlen_output( - seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype + seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, has_additive_sink, mha_type, dtype ): if (causal or local): # Right now we only support causal attention with seqlen_k == seqlen_q seqlen_k = seqlen_q @@ -311,6 +324,11 @@ def test_flash_attn_varlen_output( qv_ref = None # Put window_size after QKV randn so that window_size changes from test to test window_size = (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + if has_additive_sink: + # We don't want negative here + additive_sink = torch.rand(nheads, dtype=torch.float32, device=device) * 5.0 + else: + additive_sink = None if dtype == torch.float8_e4m3fn: q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] else: @@ -382,6 +400,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, + additive_sink=additive_sink, softcap=softcap ) out_pt, attn_pt = attention_ref( @@ -395,6 +414,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, + additive_sink=additive_sink, softcap=softcap, upcast=False, reorder_ops=True, @@ -431,6 +451,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # k_descale=k_descale, v_descale=v_descale, window_size=window_size, # attention_chunk=attention_chunk, + additive_sink=additive_sink, softcap=softcap, ) out = output_pad_fn(out_unpad) @@ -453,6 +474,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): and not dv > 256 and not attention_chunk != 0 and dv == d + and not has_additive_sink and False ): g_unpad = torch.randn_like(out_unpad) From 5bdd30e4467722ed02c9f12f8e730886e62cfdae Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 29 Jul 2025 18:52:58 -0400 Subject: [PATCH 059/258] [Cute] Sink values in bf16 --- flash_attn/cute/flash_fwd_sm100.py | 2 +- flash_attn/cute/interface.py | 2 +- tests/cute/test_flash_attn.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 3db9153378d..c4f71e930e3 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -1457,7 +1457,7 @@ def correction_loop( # additional sync because the MMA in the top half must have been done. # Similarly we can write to stage 1 of sO without additional sync. stats = [None] * self.q_stage - add_sink_val = additive_sink[head_idx] if const_expr(additive_sink is not None) else None + add_sink_val = Float32(additive_sink[head_idx]) if const_expr(additive_sink is not None) else None for stage in cutlass.range_constexpr(self.q_stage): cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_full_offset + stage, softmax_corr_consumer_phase) # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index e60b42f0304..eacd9d964b2 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -101,7 +101,7 @@ def _flash_attn_fwd( assert t.stride(0) == 1, "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be contiguous" if additive_sink is not None: assert additive_sink.shape == (num_head,) - assert additive_sink.dtype == torch.float32 + assert additive_sink.dtype == torch.bfloat16, "additive_sink must be bfloat16" assert all(t is None or t.is_cuda for t in (q, k, v, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, additive_sink)), "inputs must be on CUDA device" assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" assert head_dim <= 256, "head_dim must be less than or equal to 256" diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 65692cfba0d..5918e444226 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -105,7 +105,7 @@ def test_flash_attn_output( # window_size = (-1, -1) if not local else (16, 0) if has_additive_sink: # We don't want negative here - additive_sink = torch.rand(nheads, dtype=torch.float32, device=device) * 5.0 + additive_sink = torch.rand(nheads, dtype=torch.bfloat16, device=device) * 5.0 else: additive_sink = None if dtype == torch.float8_e4m3fn: @@ -326,7 +326,7 @@ def test_flash_attn_varlen_output( window_size = (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() if has_additive_sink: # We don't want negative here - additive_sink = torch.rand(nheads, dtype=torch.float32, device=device) * 5.0 + additive_sink = torch.rand(nheads, dtype=torch.bfloat16, device=device) * 5.0 else: additive_sink = None if dtype == torch.float8_e4m3fn: From e81c237e2872e0bc9aa0ebb52828f6736ed294ac Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 5 Aug 2025 20:42:41 -0400 Subject: [PATCH 060/258] [Cute] Fix sink impl MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously we implemented e^x ----------- sink + Σ e^x Now we implemented e^x ----------- e^sink + Σ e^x --- flash_attn/cute/flash_fwd_sm100.py | 19 +++++++------- flash_attn/cute/interface.py | 32 +++++++++++------------ flash_attn/utils/testing.py | 14 +++++----- tests/cute/test_flash_attn.py | 42 ++++++++++++++---------------- 4 files changed, 54 insertions(+), 53 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index c4f71e930e3..0106be59d5e 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -193,7 +193,7 @@ def __call__( softcap: Float32 | float | None = None, window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, - additive_sink: Optional[cute.Tensor] = None, + learnable_sink: Optional[cute.Tensor] = None, ): """Execute the Fused Multi-Head Attention operation on the provided tensors. @@ -474,7 +474,7 @@ class SharedStorage: softcap_val, window_size_left, window_size_right, - additive_sink, + learnable_sink, sQ_layout, sK_layout, tP_layout, @@ -514,7 +514,7 @@ def kernel( softcap_val: Optional[Float32], window_size_left: Optional[Int32], window_size_right: Optional[Int32], - additive_sink: Optional[cute.Tensor], + learnable_sink: Optional[cute.Tensor], sQ_layout: cute.ComposedLayout, sK_layout: cute.ComposedLayout, tP_layout: cute.ComposedLayout, @@ -793,7 +793,7 @@ def kernel( mO, mLSE, sO, - additive_sink, + learnable_sink, tma_atom_O, mbar_ptr, softmax_scale_log2, @@ -1381,7 +1381,7 @@ def correction_loop( mO: cute.Tensor, mLSE: cute.Tensor, sO: cute.Tensor, - additive_sink: Optional[cute.Tensor], + learnable_sink: Optional[cute.Tensor], tma_atom_O: cute.CopyAtom, mbar_ptr: cute.Pointer, softmax_scale_log2: Float32, @@ -1457,20 +1457,21 @@ def correction_loop( # additional sync because the MMA in the top half must have been done. # Similarly we can write to stage 1 of sO without additional sync. stats = [None] * self.q_stage - add_sink_val = Float32(additive_sink[head_idx]) if const_expr(additive_sink is not None) else None + learnable_sink_val = Float32(learnable_sink[head_idx]) if const_expr(learnable_sink is not None) else None for stage in cutlass.range_constexpr(self.q_stage): cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_full_offset + stage, softmax_corr_consumer_phase) # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r) # cute.arch.fence_view_async_tmem_load() # scale = tSrScale_t2r[0] row_sum = sScale[tidx + stage * self.m_block_size] - if const_expr(mLSE is not None or additive_sink is not None): + if const_expr(mLSE is not None or learnable_sink is not None): row_max = sScale[tidx + stage * self.m_block_size + self.m_block_size * 2] else: row_max = None cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage) - if const_expr(additive_sink is not None): - row_sum += add_sink_val * utils.exp2f(-row_max * softmax_scale_log2) + if const_expr(learnable_sink is not None): + LOG2_E = math.log2(math.e) + row_sum += utils.exp2f(learnable_sink_val * LOG2_E - row_max * softmax_scale_log2) acc_O_mn_row_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum stats[stage] = (row_sum, row_max, acc_O_mn_row_is_zero_or_nan) scale = cute.arch.rcp_approx(row_sum if not acc_O_mn_row_is_zero_or_nan else 1.0) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index eacd9d964b2..dff4564d180 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -62,7 +62,7 @@ def _flash_attn_fwd( softcap: Optional[float] = None, window_size_left: Optional[int] = None, window_size_right: Optional[int] = None, - additive_sink: Optional[torch.Tensor] = None, + learnable_sink: Optional[torch.Tensor] = None, # m_block_size: int = 128, # n_block_size: int = 64, # num_threads: int = 128, @@ -99,10 +99,10 @@ def _flash_attn_fwd( if t is not None: assert t.dtype == torch.int32, "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be int32" assert t.stride(0) == 1, "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be contiguous" - if additive_sink is not None: - assert additive_sink.shape == (num_head,) - assert additive_sink.dtype == torch.bfloat16, "additive_sink must be bfloat16" - assert all(t is None or t.is_cuda for t in (q, k, v, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, additive_sink)), "inputs must be on CUDA device" + if learnable_sink is not None: + assert learnable_sink.shape == (num_head,) + assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16" + assert all(t is None or t.is_cuda for t in (q, k, v, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink)), "inputs must be on CUDA device" assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" assert head_dim <= 256, "head_dim must be less than or equal to 256" alignment = 16 // q.element_size() @@ -131,7 +131,7 @@ def _flash_attn_fwd( lse_tensor = utils.convert_from_dlpack(lse, leading_dim=lse.ndim - 1, alignment=4) if lse is not None else None cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, additive_sink_tensor = [ from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) if t is not None else None - for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, additive_sink) + for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink) ] if causal: window_size_right = 0 @@ -153,13 +153,13 @@ def _flash_attn_fwd( dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, softcap is not None, lse is None, cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, seqused_k is None, window_size_left is not None, window_size_right is not None, - additive_sink is not None, + learnable_sink is not None, m_block_size, n_block_size, num_threads, compute_capability, ) if compile_key not in _flash_attn_fwd.compile_cache: if compute_capability == 9: - assert additive_sink is None, "Sm90 doesn't support additive sink" + assert learnable_sink is None, "Sm90 doesn't support additive sink" # fa_fwd = FlashAttentionForwardSm80( fa_fwd = FlashAttentionForwardSm90( dtype, @@ -400,7 +400,7 @@ def forward( softmax_scale: Optional[float] = None, causal: bool = False, window_size: Tuple[Optional[int], Optional[int]] = (None, None), - additive_sink: Optional[torch.Tensor] = None, + learnable_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, ): out, lse = _flash_attn_fwd( @@ -411,7 +411,7 @@ def forward( causal=causal, window_size_left=window_size[0], window_size_right=window_size[1], - additive_sink=additive_sink, + learnable_sink=learnable_sink, softcap=softcap, ) ctx.save_for_backward(q, k, v, out, lse) @@ -453,7 +453,7 @@ def forward( softmax_scale: Optional[float] = None, causal: bool = False, window_size: Tuple[Optional[int], Optional[int]] = (None, None), - additive_sink: Optional[torch.Tensor] = None, + learnable_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, ): out, lse = _flash_attn_fwd( @@ -468,7 +468,7 @@ def forward( causal=causal, window_size_left=window_size[0], window_size_right=window_size[1], - additive_sink=additive_sink, + learnable_sink=learnable_sink, softcap=softcap, ) ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) @@ -493,7 +493,7 @@ def flash_attn_func( softmax_scale: Optional[float] = None, causal: bool = False, window_size: Tuple[Optional[int], Optional[int]] = (None, None), - additive_sink: Optional[torch.Tensor] = None, + learnable_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, ): return FlashAttnFunc.apply( @@ -503,7 +503,7 @@ def flash_attn_func( softmax_scale, causal, window_size, - additive_sink, + learnable_sink, softcap, ) @@ -519,7 +519,7 @@ def flash_attn_varlen_func( softmax_scale: Optional[float] = None, causal: bool = False, window_size: Tuple[Optional[int], Optional[int]] = (None, None), - additive_sink: Optional[torch.Tensor] = None, + learnable_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, ): return FlashAttnVarlenFunc.apply( @@ -533,6 +533,6 @@ def flash_attn_varlen_func( softmax_scale, causal, window_size, - additive_sink, + learnable_sink, softcap, ) diff --git a/flash_attn/utils/testing.py b/flash_attn/utils/testing.py index 984940e818c..81be51f1de8 100644 --- a/flash_attn/utils/testing.py +++ b/flash_attn/utils/testing.py @@ -241,7 +241,7 @@ def attention_ref( window_size=(None, None), attention_chunk=0, sink_token_length=0, - additive_sink: Optional[torch.Tensor] = None, + learnable_sink: Optional[torch.Tensor] = None, softcap=0.0, upcast=True, reorder_ops=False, @@ -325,14 +325,16 @@ def attention_ref( scores.masked_fill_(local_mask, float("-inf")) if attn_bias is not None: scores = scores + attn_bias - if additive_sink is None: + if learnable_sink is None: attention = torch.softmax(scores, dim=-1).to(v.dtype) else: scores_fp32 = scores.to(torch.float32) - row_max = torch.amax(scores, dim=-1, keepdim=True) - numerator = torch.exp(scores_fp32 - row_max) - row_sum = torch.sum(numerator, dim=-1, keepdim=True) + rearrange(additive_sink, "h -> h 1 1") * torch.exp(-row_max) - attention = (numerator / row_sum).to(v.dtype) + logits_max = torch.amax(scores_fp32, dim=-1, keepdim=True) + learnable_sink = rearrange(learnable_sink, "h -> h 1 1") + logits_or_sinks_max = torch.maximum(learnable_sink, logits_max) + unnormalized_scores = torch.exp(scores_fp32 - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp(learnable_sink - logits_or_sinks_max) + attention = (unnormalized_scores / normalizer).to(v.dtype) # We want to mask here so that the attention matrix doesn't have any NaNs # Otherwise we'll get NaN in dV if query_padding_mask is not None: diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 5918e444226..58fe891d32c 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -21,8 +21,8 @@ @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) -@pytest.mark.parametrize("has_additive_sink", [False, True]) -# @pytest.mark.parametrize("has_additive_sink", [False]) +@pytest.mark.parametrize("has_learnable_sink", [False, True]) +# @pytest.mark.parametrize("has_learnable_sink", [False]) # @pytest.mark.parametrize("has_qv", [False, True]) @pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @@ -69,7 +69,7 @@ ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) def test_flash_attn_output( - seqlen_q, seqlen_k, d, causal, local, softcap, deterministic, has_qv, has_additive_sink, mha_type, dtype + seqlen_q, seqlen_k, d, causal, local, softcap, deterministic, has_qv, has_learnable_sink, mha_type, dtype ): if (causal or local) and seqlen_k < seqlen_q: pytest.skip("Causal attention requires seqlen_k >= seqlen_q") @@ -103,11 +103,10 @@ def test_flash_attn_output( # Put window_size after QKV randn so that window_size changes from test to test window_size = (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() # window_size = (-1, -1) if not local else (16, 0) - if has_additive_sink: - # We don't want negative here - additive_sink = torch.rand(nheads, dtype=torch.bfloat16, device=device) * 5.0 + if has_learnable_sink: + learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) else: - additive_sink = None + learnable_sink = None if dtype == torch.float8_e4m3fn: q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] else: @@ -125,7 +124,7 @@ def test_flash_attn_output( q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, - additive_sink=additive_sink, + learnable_sink=learnable_sink, softcap=softcap ) out_pt, attn_pt = attention_ref( @@ -139,7 +138,7 @@ def test_flash_attn_output( q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, - additive_sink=additive_sink, + learnable_sink=learnable_sink, softcap=softcap, upcast=False, reorder_ops=True, @@ -177,7 +176,7 @@ def test_flash_attn_output( window_size=window_size, # attention_chunk=attention_chunk, softcap=softcap, - additive_sink=additive_sink, + learnable_sink=learnable_sink, # pack_gqa=pack_gqa, # num_splits=num_splits ) @@ -199,7 +198,7 @@ def test_flash_attn_output( and softcap == 0.0 and not local and dv == d - and additive_sink is None + and learnable_sink is None # and False ): g = torch.randn_like(out) @@ -244,8 +243,8 @@ def test_flash_attn_output( @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) -@pytest.mark.parametrize("has_additive_sink", [False, True]) -# @pytest.mark.parametrize("has_additive_sink", [False]) +@pytest.mark.parametrize("has_learnable_sink", [False, True]) +# @pytest.mark.parametrize("has_learnable_sink", [False]) # @pytest.mark.parametrize("has_qv", [False, True]) @pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @@ -291,7 +290,7 @@ def test_flash_attn_output( ], ) def test_flash_attn_varlen_output( - seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, has_additive_sink, mha_type, dtype + seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, has_learnable_sink, mha_type, dtype ): if (causal or local): # Right now we only support causal attention with seqlen_k == seqlen_q seqlen_k = seqlen_q @@ -324,11 +323,10 @@ def test_flash_attn_varlen_output( qv_ref = None # Put window_size after QKV randn so that window_size changes from test to test window_size = (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() - if has_additive_sink: - # We don't want negative here - additive_sink = torch.rand(nheads, dtype=torch.bfloat16, device=device) * 5.0 + if has_learnable_sink: + learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) else: - additive_sink = None + learnable_sink = None if dtype == torch.float8_e4m3fn: q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] else: @@ -400,7 +398,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, - additive_sink=additive_sink, + learnable_sink=learnable_sink, softcap=softcap ) out_pt, attn_pt = attention_ref( @@ -414,7 +412,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, - additive_sink=additive_sink, + learnable_sink=learnable_sink, softcap=softcap, upcast=False, reorder_ops=True, @@ -451,7 +449,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # k_descale=k_descale, v_descale=v_descale, window_size=window_size, # attention_chunk=attention_chunk, - additive_sink=additive_sink, + learnable_sink=learnable_sink, softcap=softcap, ) out = output_pad_fn(out_unpad) @@ -474,7 +472,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): and not dv > 256 and not attention_chunk != 0 and dv == d - and not has_additive_sink + and not has_learnable_sink and False ): g_unpad = torch.randn_like(out_unpad) From 2f78d4840b2d8afa8f1b1a6d25559a83ed4e6492 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 6 Aug 2025 13:38:19 -0400 Subject: [PATCH 061/258] [Cute] Fix row_max not being written to smem when there's sink --- flash_attn/cute/flash_fwd_sm100.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 0106be59d5e..81e94c52f6f 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -755,6 +755,7 @@ def kernel( thr_mma_qk=thr_mma_qk, sScale=sScale, mLSE=mLSE, + learnable_sink=learnable_sink, mbar_ptr=mbar_ptr, block_info=block_info, SeqlenInfoCls=SeqlenInfoCls, @@ -1112,6 +1113,7 @@ def softmax_loop( tStSi: cute.Tensor, sScale: cute.Tensor, mLSE: Optional[cute.Tensor], + learnable_sink: Optional[cute.Tensor], mbar_ptr: cute.Pointer, block_info: BlockInfo, SeqlenInfoCls: Callable, @@ -1241,7 +1243,7 @@ def softmax_loop( # cute.copy(thr_tmem_store_scale, tSrScale_r2t, tStScale_r2t) # cute.arch.fence_view_async_tmem_store() sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0] - if const_expr(mLSE is not None): + if const_expr(mLSE is not None or learnable_sink is not None): sScale[tidx + stage * self.m_block_size + self.m_block_size * 2] = softmax.row_max[0] # if tidx == 0: # cute.printf("softmax row sum stage %d: %f, row_max = %f\n", stage, softmax.row_sum[0], softmax.row_max[0]) From dc742f2c47baa4b15cc33e6a2444f33d02c0a6d4 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 6 Aug 2025 15:13:07 -0400 Subject: [PATCH 062/258] [Cute] Make flash_attn.cute installable as a standalone package --- flash_attn/cute/README.md | 0 flash_attn/cute/__init__.py | 13 +++++++++++ flash_attn/cute/pyproject.toml | 42 ++++++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+) create mode 100644 flash_attn/cute/README.md create mode 100644 flash_attn/cute/__init__.py diff --git a/flash_attn/cute/README.md b/flash_attn/cute/README.md new file mode 100644 index 00000000000..e69de29bb2d diff --git a/flash_attn/cute/__init__.py b/flash_attn/cute/__init__.py new file mode 100644 index 00000000000..f1a4ed2d214 --- /dev/null +++ b/flash_attn/cute/__init__.py @@ -0,0 +1,13 @@ +"""Flash Attention CUTE (CUDA Template Engine) implementation.""" + +from .interface import ( + flash_attn_func, + flash_attn_varlen_func, +) + +__version__ = "0.1.0" + +__all__ = [ + "flash_attn_func", + "flash_attn_varlen_func", +] diff --git a/flash_attn/cute/pyproject.toml b/flash_attn/cute/pyproject.toml index 585c50079a3..8c4d89e52e1 100644 --- a/flash_attn/cute/pyproject.toml +++ b/flash_attn/cute/pyproject.toml @@ -1,8 +1,50 @@ +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[project] +name = "flash-attn-cute" +version = "0.1.0" +description = "Flash Attention CUTE (CUDA Template Engine) implementation" +readme = "README.md" +requires-python = ">=3.12" +license = {text = "BSD 3-Clause License"} +authors = [ + {name = "Tri Dao"}, +] +classifiers = [ + "Development Status :: 3 - Alpha", + "License :: OSI Approved :: BSD License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.12", +] + +dependencies = [ + "nvidia-cutlass-dsl==4.1.0", + "torch", + "einops", +] + +[project.optional-dependencies] +dev = [ + "pytest", + "ruff", +] + +[project.urls] +Homepage = "https://github.com/Dao-AILab/flash-attention" +Repository = "https://github.com/Dao-AILab/flash-attention" + +[tool.setuptools] +packages = ["flash_attn.cute"] +package-dir = {"flash_attn.cute" = "."} + [tool.ruff] line-length = 100 [tool.ruff.lint] ignore = [ "E731", # do not assign a lambda expression, use a def + "E741", # Do not use variables named 'I', 'O', or 'l' "F841", # local variable is assigned to but never used ] \ No newline at end of file From 66ee1b5be2a12132f49e3807b3e44e09c36a4165 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 9 Aug 2025 14:50:27 -0400 Subject: [PATCH 063/258] [Cute] No longer assume Q, K, V are compact --- flash_attn/cute/flash_fwd.py | 10 ++++++++++ flash_attn/cute/flash_fwd_sm100.py | 3 +++ flash_attn/cute/interface.py | 18 ++++++++---------- tests/cute/test_flash_attn.py | 4 +++- 4 files changed, 24 insertions(+), 11 deletions(-) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 311540abaf7..61333ca7357 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -551,12 +551,14 @@ def __call__( softcap: Optional[cutlass.Float32] = None, window_size_left: Optional[cutlass.Int32] = None, window_size_right: Optional[cutlass.Int32] = None, + learnable_sink: Optional[cute.Tensor] = None, ): """Configures and launches the flash attention kernel. mQ/mK/mV/mO has same data types(supports fp16 and bf16) and same layout: (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1) """ + assert learnable_sink is None, "Learnable sink is not supported in this kernel" self._check_type(*(t.element_type if t is not None else None for t in (mQ, mK, mV, mO, mLSE))) tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma() self.num_mma_threads = tiled_mma_pv.size @@ -567,6 +569,9 @@ def __call__( self.use_tma_O = self.arch >= 90 self._setup_attributes() SharedStorage = self._get_shared_storage_cls() + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1]) + mQ, mK, mV, mO = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) for t in (mQ, mK, mV, mO)] mQ, mK, mV, mO = [cute.make_tensor(t.iterator, cute.select(t.layout, mode=[1, 3, 2, 0])) for t in (mQ, mK, mV, mO)] mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=[2, 1, 0])) # grid_dim: (m_block, num_head, batch_size) @@ -1067,16 +1072,21 @@ def __call__( softcap: cutlass.Float32 | float | None = None, window_size_left: cutlass.Int32 | int | None = None, window_size_right: cutlass.Int32 | int | None = None, + learnable_sink: Optional[cute.Tensor] = None, ): """Configures and launches the flash attention kernel. mQ/mK/mV/mO has same data types(supports fp16 and bf16) and same layout: (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1) """ + assert learnable_sink is None, "Learnable sink is not supported in this kernel" self._check_type( *(t.element_type if t is not None else None for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)) ) + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1]) + mQ, mK, mV, mO = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) for t in (mQ, mK, mV, mO)] QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] mQ, mO = [ cute.make_tensor(t.iterator, cute.select(t.layout, mode=QO_layout_transpose)) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 81e94c52f6f..f0406a06c1c 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -214,6 +214,9 @@ def __call__( self.k_dtype = mK.element_type self.v_dtype = mV.element_type self.o_dtype = mO.element_type + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1]) + mQ, mK, mV, mO = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) for t in (mQ, mK, mV, mO)] QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] mQ, mO = [ cute.make_tensor(t.iterator, cute.select(t.layout, mode=QO_layout_transpose)) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index dff4564d180..3e154ace813 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -124,11 +124,10 @@ def _flash_attn_fwd( dtype = torch2cute_dtype_map[q.dtype] q_tensor, k_tensor, v_tensor, o_tensor = [ - utils.convert_from_dlpack( - t.detach(), leading_dim=t.ndim - 1, divisibility=128 // dtype.width - ) for t in (q, k, v, out) + from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) + for t in (q, k, v, out) ] - lse_tensor = utils.convert_from_dlpack(lse, leading_dim=lse.ndim - 1, alignment=4) if lse is not None else None + lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 1) if lse is not None else None cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, additive_sink_tensor = [ from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) if t is not None else None for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink) @@ -267,18 +266,17 @@ def _flash_attn_bwd( dtype = torch2cute_dtype_map[q.dtype] q_tensor, k_tensor, v_tensor, o_tensor, do_tensor, dq_tensor, dk_tensor, dv_tensor = [ - utils.convert_from_dlpack( - t.detach(), leading_dim=3, divisibility=128 // dtype.width - ) for t in (q, k, v, out, dout, dq, dk, dv) + from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) + for t in (q, k, v, out, dout, dq, dk, dv) ] - lse_tensor = utils.convert_from_dlpack(lse.detach(), leading_dim=2, alignment=4) + lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=2) dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [ - utils.convert_from_dlpack(t.detach(), leading_dim=2, divisibility=128 // cutlass.Float32.width) + from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=2) for t in (dq_accum, dpsum, lse_log2) ] if qhead_per_kvhead > 1: dk_accum_tensor, dv_accum_tensor = [ - utils.convert_from_dlpack(t.detach(), leading_dim=2, divisibility=128 // cutlass.Float32.width) + from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=2) for t in (dk_accum, dv_accum) ] current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 58fe891d32c..61da6991c79 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -42,6 +42,7 @@ # @pytest.mark.parametrize("d", [64, 96, 128, 192]) # @pytest.mark.parametrize("d", [64, 128]) @pytest.mark.parametrize("d", [128, 192]) +# @pytest.mark.parametrize("d", [128]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -199,7 +200,7 @@ def test_flash_attn_output( and not local and dv == d and learnable_sink is None - # and False + and False ): g = torch.randn_like(out) # do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) @@ -264,6 +265,7 @@ def test_flash_attn_output( # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128]) @pytest.mark.parametrize("d", [128, 192]) +# @pytest.mark.parametrize("d", [128]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ From 5844fa69c73a838d26ac3917904952f0f9a98976 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 9 Aug 2025 15:52:54 -0400 Subject: [PATCH 064/258] [Cute] Fix not allocating enough smem for sScale when there's sink --- flash_attn/cute/flash_fwd_sm100.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index f0406a06c1c..d630668aa8d 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -425,7 +425,8 @@ class SharedStorage: # Tmem holding buffer tmem_holding_buf: Int32 # Smem tensors - sScale: cute.struct.MemRange[Float32, 2 * self.m_block_size * (1 if const_expr(mLSE is None) else 2)] + # store row max and row sum + sScale: cute.struct.MemRange[Float32, self.q_stage * self.m_block_size * 2] sO: cute.struct.Align[ cute.struct.MemRange[self.o_dtype, sO_size], self.buffer_align_bytes, @@ -613,9 +614,7 @@ def kernel( else: sO = cute.make_tensor(cute.recast_ptr(sQ.iterator, sO_layout.inner), sO_layout.outer) - sScale = storage.sScale.get_tensor(cute.make_layout( - 2 * self.m_block_size * (1 if const_expr(mLSE is None) else 2) - )) + sScale = storage.sScale.get_tensor(cute.make_layout(self.q_stage * self.m_block_size * 2)) thr_mma_qk = tiled_mma_qk.get_slice(0) # default 1SM thr_mma_pv = tiled_mma_pv.get_slice(0) # default 1SM From 8c348fd79f423923710cb5a949c8e79f6aa29f7f Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 9 Aug 2025 19:20:57 -0400 Subject: [PATCH 065/258] [FA3] Fix doc: page block size can be arbitrary --- benchmarks/benchmark_attn.py | 4 +++- hopper/flash_attn_interface.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/benchmarks/benchmark_attn.py b/benchmarks/benchmark_attn.py index d6379b43510..147b00f15b3 100644 --- a/benchmarks/benchmark_attn.py +++ b/benchmarks/benchmark_attn.py @@ -272,6 +272,8 @@ def run(*args, **kwargs): # headdim_v = 512 has_qv = headdim == 64 and headdim_v == 512 # has_qv = False + # sinks = torch.randn(nheads, dtype=torch.bfloat16, device=device) + sinks = None for batch_size, seqlen in bs_seqlen_vals: num_splits = 0 @@ -367,7 +369,7 @@ def run(*args, **kwargs): time_f[(causal, headdim, batch_size, seqlen), "Flash3"] = m1.mean if flash_attn_func_python is not None: if not varlen: - m1_py = time_fwd(flash_attn_func_python, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav3 python') + m1_py = time_fwd(flash_attn_func_python, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size, learnable_sink=sinks, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav3 python') else: m1_py = time_fwd(flash_attn_varlen_func_python, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav3 python') if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func_v3 is not None and has_backward: diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 0e93f234aa3..b753a0fba7b 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -706,7 +706,7 @@ def flash_attn_with_kvcache( q: (batch_size, seqlen, nheads, headdim) k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table, or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache) - page_block_size must be a multiple of 256. + page_block_size can be arbitrary (e.g, 1, 2, 3, 64, etc.). v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table, or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache) k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate From 81cdf4cec35d6e4e0c9bc3d89b507698b40ba7bb Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 11 Aug 2025 23:13:03 -0400 Subject: [PATCH 066/258] [Cute] Don't need i64_to_f32x2 anymore --- flash_attn/cute/utils.py | 96 +++++++++++++++++++--------------------- 1 file changed, 45 insertions(+), 51 deletions(-) diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 4f0adb8dd42..193b369eba7 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -485,59 +485,53 @@ def cvt_f16(src: cute.Tensor, dst: cute.Tensor): @dsl_user_op -def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float32]: - vec_i64x1 = vector.from_elements(T.vector(1, T.i64()), (c.ir_value(),), loc=loc, ip=ip) - vec_f32x2 = vector.bitcast(T.vector(2, T.f32()), vec_i64x1) - res0 = Float32( - vector.extract(vec_f32x2, dynamic_position=[], static_position=[0], loc=loc, ip=ip) +def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]: + out_f32x2 = llvm.inline_asm( + T.vector(2, T.f32()), + [Float32(x).ir_value(loc=loc, ip=ip), Float32(y, loc=loc, ip=ip).ir_value()], + "{\n\t" + ".reg .f32 f1, f2, f3, f4, f5, f6, f7;\n\t" + ".reg .b64 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;\n\t" + ".reg .s32 r1, r2, r3, r4, r5, r6, r7, r8;\n\t" + "max.ftz.f32 f1, $1, 0fC2FE0000;\n\t" + "max.ftz.f32 f2, $2, 0fC2FE0000;\n\t" + "mov.b64 l1, {f1, f2};\n\t" + "mov.f32 f3, 0f4B400000;\n\t" + "mov.b64 l2, {f3, f3};\n\t" + "add.rm.ftz.f32x2 l7, l1, l2;\n\t" + "sub.rn.ftz.f32x2 l8, l7, l2;\n\t" + "sub.rn.ftz.f32x2 l9, l1, l8;\n\t" + "mov.f32 f7, 0f3D9DF09D;\n\t" + "mov.b64 l6, {f7, f7};\n\t" + "mov.f32 f6, 0f3E6906A4;\n\t" + "mov.b64 l5, {f6, f6};\n\t" + "mov.f32 f5, 0f3F31F519;\n\t" + "mov.b64 l4, {f5, f5};\n\t" + "mov.f32 f4, 0f3F800000;\n\t" + "mov.b64 l3, {f4, f4};\n\t" + "fma.rn.ftz.f32x2 l10, l9, l6, l5;\n\t" + "fma.rn.ftz.f32x2 l10, l10, l9, l4;\n\t" + "fma.rn.ftz.f32x2 l10, l10, l9, l3;\n\t" + "mov.b64 {r1, r2}, l7;\n\t" + "mov.b64 {r3, r4}, l10;\n\t" + "shl.b32 r5, r1, 23;\n\t" + "add.s32 r7, r5, r3;\n\t" + "shl.b32 r6, r2, 23;\n\t" + "add.s32 r8, r6, r4;\n\t" + "mov.b64 $0, {r7, r8};\n\t" + "}\n", + "=l,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, ) - res1 = Float32( - vector.extract(vec_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip) + out0 = Float32( + vector.extract(out_f32x2, dynamic_position=[], static_position=[0], loc=loc, ip=ip) ) - return res0, res1 + out1 = Float32( + vector.extract(out_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip) + ) + return out0, out1 -@cute.jit -def e2e_asm2(x: Float32, y: Float32) -> Tuple[Float32, Float32]: - out_i64 = cutlass.Int64( - llvm.inline_asm( - T.i64(), - [Float32(x).ir_value(), Float32(y).ir_value()], - "{\n\t" - ".reg .f32 f1, f2, f3, f4, f5, f6, f7;\n\t" - ".reg .b64 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;\n\t" - ".reg .s32 r1, r2, r3, r4, r5, r6, r7, r8;\n\t" - "max.ftz.f32 f1, $1, 0fC2FE0000;\n\t" - "max.ftz.f32 f2, $2, 0fC2FE0000;\n\t" - "mov.b64 l1, {f1, f2};\n\t" - "mov.f32 f3, 0f4B400000;\n\t" - "mov.b64 l2, {f3, f3};\n\t" - "add.rm.ftz.f32x2 l7, l1, l2;\n\t" - "sub.rn.ftz.f32x2 l8, l7, l2;\n\t" - "sub.rn.ftz.f32x2 l9, l1, l8;\n\t" - "mov.f32 f7, 0f3D9DF09D;\n\t" - "mov.b64 l6, {f7, f7};\n\t" - "mov.f32 f6, 0f3E6906A4;\n\t" - "mov.b64 l5, {f6, f6};\n\t" - "mov.f32 f5, 0f3F31F519;\n\t" - "mov.b64 l4, {f5, f5};\n\t" - "mov.f32 f4, 0f3F800000;\n\t" - "mov.b64 l3, {f4, f4};\n\t" - "fma.rn.ftz.f32x2 l10, l9, l6, l5;\n\t" - "fma.rn.ftz.f32x2 l10, l10, l9, l4;\n\t" - "fma.rn.ftz.f32x2 l10, l10, l9, l3;\n\t" - "mov.b64 {r1, r2}, l7;\n\t" - "mov.b64 {r3, r4}, l10;\n\t" - "shl.b32 r5, r1, 23;\n\t" - "add.s32 r7, r5, r3;\n\t" - "shl.b32 r6, r2, 23;\n\t" - "add.s32 r8, r6, r4;\n\t" - "mov.b64 $0, {r7, r8};\n\t" - "}\n", - "=l,f,f", - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) ) - return i64_to_f32x2(out_i64) From c4be57875be56014d77f21000d52f4e8fb643f4d Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 12 Aug 2025 11:26:48 -0400 Subject: [PATCH 067/258] Remove old xentropy kernel This hasn't been used since 2023-09 --- csrc/xentropy/README.md | 14 - csrc/xentropy/interface.cpp | 59 --- csrc/xentropy/setup.py | 139 ------ csrc/xentropy/xentropy_kernel.cu | 758 ------------------------------- 4 files changed, 970 deletions(-) delete mode 100644 csrc/xentropy/README.md delete mode 100644 csrc/xentropy/interface.cpp delete mode 100644 csrc/xentropy/setup.py delete mode 100644 csrc/xentropy/xentropy_kernel.cu diff --git a/csrc/xentropy/README.md b/csrc/xentropy/README.md deleted file mode 100644 index 1bc90fdab77..00000000000 --- a/csrc/xentropy/README.md +++ /dev/null @@ -1,14 +0,0 @@ -This CUDA extension implements optimized cross-entropy loss, adapted from Apex's -[Xentropy](https://github.com/NVIDIA/apex/tree/master/apex/contrib/xentropy). -We make it work for bfloat16 and support in-place backward to save memory. - -It has only been tested on A100s. - -```sh -cd csrc/xentropy && pip install . -``` - -As of 2023-09-15, this extension is no longer used in the FlashAttention repo. -We've instead switched to a Triton-based -[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/cross_entropy.py). -See the CrossEntropyLoss [module](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/losses/cross_entropy.py) for more details. diff --git a/csrc/xentropy/interface.cpp b/csrc/xentropy/interface.cpp deleted file mode 100644 index 41a783fd0fc..00000000000 --- a/csrc/xentropy/interface.cpp +++ /dev/null @@ -1,59 +0,0 @@ -#include - -// CUDA forward declarations -std::vector softmax_xentropy_cuda( - const at::Tensor &input, - const at::Tensor &labels, - const float smoothing, - const int total_classes); - -at::Tensor softmax_xentropy_backward_cuda( - const at::Tensor &grad_loss, - at::Tensor &logits, - const at::Tensor &max_log_sum_exp, - const at::Tensor &labels, - const float smoothing, - const bool inplace, - const int total_classes); - -// C++ interface - -#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) - -std::vector softmax_xentropy_forward( - const at::Tensor &input, - const at::Tensor &labels, - const float smoothing, - const int total_classes=-1) { - // For tensor parallel cross entropy with smoothing, we want to pass in the total number - // of classes so that smoothing can be applied correctly. If total_classes=-1, use the - // last dimension of the input tensor. - CHECK_INPUT(input); - CHECK_INPUT(labels); - - return softmax_xentropy_cuda(input, labels, smoothing, total_classes); -} - -at::Tensor softmax_xentropy_backward( - const at::Tensor &grad_loss, - at::Tensor &logits, - const at::Tensor &max_log_sum_exp, - const at::Tensor &labels, - const float smoothing, - const bool inplace, - const int total_classes=-1) { - CHECK_INPUT(grad_loss); - CHECK_INPUT(logits); - CHECK_INPUT(max_log_sum_exp); - CHECK_INPUT(labels); - - return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, - smoothing, inplace, total_classes); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &softmax_xentropy_forward, "Softmax cross entropy loss with label smoothing forward (CUDA)", py::arg("input"), py::arg("labels"), py::arg("smoothing"), py::arg("total_classes")=-1); - m.def("backward", &softmax_xentropy_backward, "Softmax cross entropy loss with label smoothing backward (CUDA)", py::arg("grad_loss"), py::arg("logits"), py::arg("max_log_sum_exp"), py::arg("labels"), py::arg("smoothing"), py::arg("inplace"), py::arg("total_classes")=-1); -} diff --git a/csrc/xentropy/setup.py b/csrc/xentropy/setup.py deleted file mode 100644 index 5079b4f3847..00000000000 --- a/csrc/xentropy/setup.py +++ /dev/null @@ -1,139 +0,0 @@ -# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py -import sys -import warnings -import os -from packaging.version import parse, Version - -import torch -from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME -from setuptools import setup, find_packages -import subprocess - -# ninja build does not work unless include_dirs are abs path -this_dir = os.path.dirname(os.path.abspath(__file__)) - - -def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - bare_metal_version = parse(output[release_idx].split(",")[0]) - - return raw_output, bare_metal_version - - -def check_cuda_torch_binary_vs_bare_metal(cuda_dir): - raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) - torch_binary_version = parse(torch.version.cuda) - - print("\nCompiling cuda extensions with") - print(raw_output + "from " + cuda_dir + "/bin\n") - - if (bare_metal_version != torch_binary_version): - raise RuntimeError( - "Cuda extensions are being compiled with a version of Cuda that does " - "not match the version used to compile Pytorch binaries. " - "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) - + "In some cases, a minor-version mismatch will not cause later errors: " - "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " - "You can try commenting out this check (at your own risk)." - ) - - -def raise_if_cuda_home_none(global_option: str) -> None: - if CUDA_HOME is not None: - return - raise RuntimeError( - f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " - "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " - "only images whose names contain 'devel' will provide nvcc." - ) - - -def append_nvcc_threads(nvcc_extra_args): - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.2"): - nvcc_threads = os.getenv("NVCC_THREADS") or "4" - return nvcc_extra_args + ["--threads", nvcc_threads] - return nvcc_extra_args - - -if not torch.cuda.is_available(): - # https://github.com/NVIDIA/apex/issues/486 - # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), - # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). - print( - "\nWarning: Torch did not find available GPUs on this system.\n", - "If your intention is to cross-compile, this is not an error.\n" - "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n" - "Volta (compute capability 7.0), Turing (compute capability 7.5),\n" - "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n" - "If you wish to cross-compile for a single specific architecture,\n" - 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', - ) - if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.8"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0" - elif bare_metal_version >= Version("11.1"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6" - elif bare_metal_version == Version("11.0"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" - else: - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" - - -print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) -TORCH_MAJOR = int(torch.__version__.split(".")[0]) -TORCH_MINOR = int(torch.__version__.split(".")[1]) - -cmdclass = {} -ext_modules = [] - -# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h -# See https://github.com/pytorch/pytorch/pull/70650 -generator_flag = [] -torch_dir = torch.__path__[0] -if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): - generator_flag = ["-DOLD_GENERATOR_PATH"] - -raise_if_cuda_home_none("--xentropy") -# Check, if CUDA11 is installed for compute capability 8.0 -cc_flag = [] -_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) -if bare_metal_version < Version("11.0"): - raise RuntimeError("xentropy is only supported on CUDA 11 and above") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_70,code=sm_70") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_80,code=sm_80") -if bare_metal_version >= Version("11.8"): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_90,code=sm_90") - -ext_modules.append( - CUDAExtension( - name="xentropy_cuda_lib", - sources=[ - "interface.cpp", - "xentropy_kernel.cu" - ], - extra_compile_args={ - "cxx": ["-O3"] + generator_flag, - "nvcc": append_nvcc_threads( - ["-O3"] - + generator_flag - + cc_flag - ), - }, - include_dirs=[this_dir], - ) -) - -setup( - name="xentropy_cuda_lib", - version="0.1", - description="Cross-entropy loss", - ext_modules=ext_modules, - cmdclass={"build_ext": BuildExtension} if ext_modules else {}, -) diff --git a/csrc/xentropy/xentropy_kernel.cu b/csrc/xentropy/xentropy_kernel.cu deleted file mode 100644 index 66aab0007ba..00000000000 --- a/csrc/xentropy/xentropy_kernel.cu +++ /dev/null @@ -1,758 +0,0 @@ -// Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/csrc/xentropy/xentropy_kernel.cu -// TD [2022-09-17]: We make it work for bfloat16, and add an option to do the backward inplace (to save memory). -/** - * From PyTorch: - * - * Copyright (c) 2016- Facebook, Inc (Adam Paszke) - * Copyright (c) 2014- Facebook, Inc (Soumith Chintala) - * Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) - * Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) - * Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) - * Copyright (c) 2011-2013 NYU (Clement Farabet) - * Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) - * Copyright (c) 2006 Idiap Research Institute (Samy Bengio) - * Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) - * - * From Caffe2: - * - * Copyright (c) 2016-present, Facebook Inc. All rights reserved. - * - * All contributions by Facebook: - * Copyright (c) 2016 Facebook Inc. - * - * All contributions by Google: - * Copyright (c) 2015 Google Inc. - * All rights reserved. - * - * All contributions by Yangqing Jia: - * Copyright (c) 2015 Yangqing Jia - * All rights reserved. - * - * All contributions from Caffe: - * Copyright(c) 2013, 2014, 2015, the respective contributors - * All rights reserved. - * - * All other contributions: - * Copyright(c) 2015, 2016 the respective contributors - * All rights reserved. - * - * Caffe2 uses a copyright model similar to Caffe: each contributor holds - * copyright over their contributions to Caffe2. The project versioning records - * all such contribution and copyright details. If a contributor wants to further - * mark their specific copyright on a particular contribution, they should - * indicate their copyright solely in the commit message of the change when it is - * committed. - * - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * - * 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America - * and IDIAP Research Institute nor the names of its contributors may be - * used to endorse or promote products derived from this software without - * specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE - * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - * POSSIBILITY OF SUCH DAMAGE. - */ -#include -#include -#include - -#include -#include - -// https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h -// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...) \ - switch(TYPE) \ - { \ - case at::ScalarType::Float: \ - { \ - using scalar_t_##LEVEL = float; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Half: \ - { \ - using scalar_t_##LEVEL = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::BFloat16: \ - { \ - using scalar_t_##LEVEL = at::BFloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } -// #else -// #define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...) \ -// switch(TYPE) \ -// { \ -// case at::ScalarType::Float: \ -// { \ -// using scalar_t_##LEVEL = float; \ -// __VA_ARGS__; \ -// break; \ -// } \ -// case at::ScalarType::Half: \ -// { \ -// using scalar_t_##LEVEL = at::Half; \ -// __VA_ARGS__; \ -// break; \ -// } \ -// default: \ -// AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ -// } -// #endif - -#define ALIGN_BYTES 16 - -using Tensor = at::Tensor; -using TensorList = at::TensorList; -using ScalarType = at::ScalarType; -using at::acc_type; - -template -struct LogSoftMaxForwardEpilogue { - __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_input, AccumT sum) - : logsum(max_input + std::log(sum)) {} - - __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_log_sum_exp) - : logsum(max_log_sum_exp) {} - - __device__ __forceinline__ OutT operator()(T input) const { - return static_cast(input - logsum); - } - - const AccumT logsum; -}; - -template -struct LogSoftMaxBackwardEpilogue { - __device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum) - : sum(sum) {} - - __device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const { - return static_cast(gradOutput - std::exp(static_cast(output)) * sum); - } - - const AccumT sum; -}; - - - -const int max_threads = 1024; - -inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) { - uint64_t block_size = 1; - uint64_t max_block_size = std::min(dim_size / ILP, static_cast(max_threads)); - while (block_size < (max_block_size/2)) block_size *= 2; - // Launch at least a single warp - the kernel assumes that. - block_size = std::max(block_size, static_cast(32)); - return dim3(block_size); -} - -template -struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } -}; - -template -struct Max { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? b : a; - } -}; - - -//////////////////////////////////////////////////////////////////////////////// -// Regular kernel (fast when dim_size is large; requires inner_size == 1) -//////////////////////////////////////////////////////////////////////////////// - - -template -struct MaxFloat -{ - __device__ __forceinline__ AccumT operator()(AccumT max, T v) const { - return ::max(max, (AccumT)v); - } -}; - -template -struct AddFloat -{ - __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const { - return sum + v; - } -}; - -template -struct SumExpFloat -{ - __device__ __forceinline__ SumExpFloat(AccumT v) - : max_k(v) {} - - __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const { - return sum + std::exp(v - max_k); - } - - const AccumT max_k; -}; - -template class Reduction, typename AccumT> -__device__ __forceinline__ AccumT -blockReduce(AccumT* smem, AccumT val, - const Reduction& r, - AccumT defaultVal) -{ - // To avoid RaW races from chaining blockReduce calls together, we need a sync here - __syncthreads(); - - smem[threadIdx.x] = val; - - __syncthreads(); - - AccumT warpVal = defaultVal; - - // First warp will perform per-warp reductions for the remaining warps - uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1; - if (threadIdx.x < 32) { - int lane = threadIdx.x % 32; - if (lane < blockDim.x / 32) { -#pragma unroll - for (int i = 0; i < 32; ++i) { - warpVal = r(warpVal, smem[lane * 32 + i]); - } - __syncwarp(mask); - smem[lane] = warpVal; - } - } - - __syncthreads(); - - // First thread will perform a reduction of the above per-warp reductions - AccumT blockVal = defaultVal; - - if (threadIdx.x == 0) { - for (int i = 0; i < blockDim.x / 32; ++i) { - blockVal = r(blockVal, smem[i]); - } - smem[0] = blockVal; - } - - // Sync and broadcast - __syncthreads(); - return smem[0]; -} - -template class Reduction1, template class Reduction2, typename AccumT> -__device__ __forceinline__ void -blockReduce(AccumT* smem, - AccumT* reducVal1, - AccumT val1, - const Reduction1& r1, - AccumT defaultVal1, - AccumT* reducVal2, - AccumT val2, - const Reduction2& r2, - AccumT defaultVal2) -{ - // To avoid RaW races from chaining blockReduce calls together, we need a sync here - __syncthreads(); - - smem[threadIdx.x] = val1; - smem[blockDim.x + threadIdx.x] = val2; - - __syncthreads(); - - AccumT warpVal1 = defaultVal1; - AccumT warpVal2 = defaultVal2; - - // First warp will perform per-warp reductions for the remaining warps - uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1; - if (threadIdx.x < 32) { - int lane = threadIdx.x % 32; - if (lane < blockDim.x / 32) { -#pragma unroll - for (int i = 0; i < 32; ++i) { - warpVal1 = r1(warpVal1, smem[lane * 32 + i]); - warpVal2 = r2(warpVal2, smem[lane * 32 + i + blockDim.x]); - } - __syncwarp(mask); - smem[lane] = warpVal1; - smem[lane + blockDim.x] = warpVal2; - } - } - - __syncthreads(); - - // First thread will perform a reduction of the above per-warp reductions - AccumT blockVal1 = defaultVal1; - AccumT blockVal2 = defaultVal2; - - if (threadIdx.x == 0) { - for (int i = 0; i < blockDim.x / 32; ++i) { - blockVal1 = r1(blockVal1, smem[i]); - blockVal2 = r2(blockVal2, smem[i + blockDim.x]); - } - smem[0] = blockVal1; - smem[blockDim.x] = blockVal2; - } - - // Sync and broadcast - __syncthreads(); - *reducVal1 = smem[0]; - *reducVal2 = smem[blockDim.x]; - __syncthreads(); -} - -template class Reduction, int ILP, typename T, typename AccumT> -__device__ __forceinline__ AccumT -ilpReduce(int shift, - T* data, - int size, - const Reduction& r, - AccumT defaultVal) -{ - typedef typename std::aligned_storage::type LoadT; - AccumT threadVal = defaultVal; - int offset = threadIdx.x; - - // shift and do 1 - if(shift > 0){ - data -= shift; - size += shift; - if(threadIdx.x >= shift){ - threadVal = r(threadVal, data[offset]); - } - size -= blockDim.x; - data += blockDim.x; - } - int last = size % (ILP * blockDim.x); - - T v[ILP]; - LoadT* value = reinterpret_cast(&v); - - for (; offset * ILP < (size - last); offset += blockDim.x) { - *value = reinterpret_cast(data)[offset]; - - for (int j = 0; j < ILP; ++j) { - threadVal = r(threadVal, v[j]); - } - } - - offset = size - last + threadIdx.x; - // Epilogue - for (; offset < size; offset += blockDim.x) - threadVal = r(threadVal, data[offset]); - - return threadVal; -} - -template class Reduction1, template class Reduction2, int ILP, typename T, typename AccumT> -__device__ __forceinline__ void -ilpReduce(int shift, - T* data, - int size, - AccumT* reducVal1, - const Reduction1& r1, - AccumT defaultVal1, - AccumT* reducVal2, - const Reduction2& r2, - AccumT defaultVal2) -{ - typedef typename std::aligned_storage::type LoadT; - - AccumT threadVal1 = defaultVal1; - AccumT threadVal2 = defaultVal2; - int offset = threadIdx.x; - - // shift and do 1 - if(shift > 0){ - data -= shift; - size += shift; - if(threadIdx.x >= shift){ - threadVal1 = r1(threadVal1, data[offset]); - threadVal2 = r2(threadVal2, data[offset]); - } - size -= blockDim.x; - data += blockDim.x; - } - int last = size % (ILP * blockDim.x); - - T v[ILP]; - LoadT* value = reinterpret_cast(&v); - - for (; offset * ILP < (size - last); offset += blockDim.x) { - *value = reinterpret_cast(data)[offset]; - - for (int j = 0; j < ILP; ++j) { - threadVal1 = r1(threadVal1, v[j]); - threadVal2 = r2(threadVal2, v[j]); - } - } - - offset = size - last + threadIdx.x; - // Epilogue - for (; offset < size; offset += blockDim.x) { - threadVal1 = r1(threadVal1, data[offset]); - threadVal2 = r2(threadVal2, data[offset]); - } - - *reducVal1 = threadVal1; - *reducVal2 = threadVal2; -} - -template class Epilogue> -__global__ void -cunn_SoftMaxXEntropyForward( - accscalar_t *losses, - outscalar_t *max_log_sum_exp, - scalar_t *input, - int64_t *labels, - int64_t classes, - const float smoothing, - const int total_classes) -{ - extern __shared__ unsigned char smem[]; - auto sdata = reinterpret_cast(smem); - // forward pointers to batch[blockIdx.x] - // each block handles a sample in the mini-batch - input += blockIdx.x * classes; - //output += blockIdx.x * classes; - const int shift = ((uint64_t)input) % ALIGN_BYTES / sizeof(scalar_t); - - int64_t label = labels[blockIdx.x]; - - // find the max and sum - accscalar_t threadMax, threadSum, max_k, sum_k; - ilpReduce( - shift, input, classes, - &threadMax, MaxFloat(), - -at::numeric_limits::max(), - &threadSum, AddFloat(), - static_cast(0)); - - blockReduce( - sdata, - &max_k, threadMax, Max(), - -at::numeric_limits::max(), - &sum_k, threadSum, Add(), - static_cast(0)); - - accscalar_t threadExp = ilpReduce(shift, input, classes, SumExpFloat(max_k), static_cast(0)); - accscalar_t sumAll = blockReduce( - sdata, threadExp, Add(), static_cast(0)); - - Epilogue epilogue(max_k, sumAll); - - // calculate per element loss with label smoothing - // reserve max + log_sum_exp for bprop - if (threadIdx.x == 0) { - accscalar_t lse = max_k + std::log(sumAll); - accscalar_t log_prob = (label >= 0 && label < classes) ? epilogue(static_cast(input[label])) : 0.f; - losses[blockIdx.x] = (lse - sum_k / total_classes) * smoothing - log_prob * (1 - smoothing); - max_log_sum_exp[blockIdx.x] = lse; - } -} - -template -__device__ __forceinline__ void -apply(scalar_t *gradInput, - scalar_t *logits, - outscalar_t *max_log_sum_exp, - outscalar_t *gradOutput, - int64_t *labels, - const float smoothing, - int classes, - const int total_classes) -{ - accscalar_t smooth_positives = 1.0 - smoothing; - accscalar_t smooth_negatives = smoothing / total_classes; - accscalar_t tmpGradOutput = gradOutput[blockIdx.x]; - int64_t label = labels[blockIdx.x]; - accscalar_t coeff = max_log_sum_exp[blockIdx.x]; - - int offset = threadIdx.x; - int last = classes % (ILP * blockDim.x); - - for (; offset < classes - last; offset += blockDim.x * ILP) { - accscalar_t tmpLogits[ILP]; - -#pragma unroll - for (int j = 0; j < ILP; ++j) { - tmpLogits[j] = static_cast(logits[offset + j * blockDim.x]); - } - -#pragma unroll - for (int j = 0; j < ILP; ++j) - gradInput[offset + j * blockDim.x] = tmpGradOutput * ( - std::exp(tmpLogits[j] - coeff) - static_cast( - (offset + j * blockDim.x == label) ? 1 : 0) * - smooth_positives - smooth_negatives); - } - - for (; offset < classes; offset += blockDim.x) - gradInput[offset] = tmpGradOutput * (std::exp( - static_cast(logits[offset]) - coeff) - - static_cast((offset == label) ? 1 : 0) * - smooth_positives - smooth_negatives); -} - - -template -__device__ __forceinline__ void -aligned_apply(int shift, - scalar_t *gradInput, - scalar_t *logits, - outscalar_t *max_log_sum_exp, - outscalar_t *gradOutput, - int64_t *labels, - const float smoothing, - int classes, - const int total_classes) -{ - accscalar_t smooth_positives = 1.0 - smoothing; - accscalar_t smooth_negatives = smoothing / total_classes; - accscalar_t tmpGradOutput = gradOutput[blockIdx.x]; - int64_t label = labels[blockIdx.x]; - accscalar_t coeff = max_log_sum_exp[blockIdx.x]; - - int offset = threadIdx.x; - - // shift and do 1 - if(shift > 0){ - logits -= shift; - gradInput -= shift; - classes += shift; - if(threadIdx.x >= shift){ - gradInput[offset] = tmpGradOutput * (std::exp( - static_cast(logits[offset]) - coeff) - - static_cast(((offset - shift) == label) ? 1 : 0) * - smooth_positives - smooth_negatives); - } - classes -= blockDim.x; - gradInput += blockDim.x; - logits += blockDim.x; - shift -= blockDim.x; - } - - int last = classes % (ILP * blockDim.x); - - typedef typename std::aligned_storage::type LoadT; - // input - scalar_t v[ILP]; - LoadT* value = reinterpret_cast(&v); - // output - scalar_t r[ILP]; - LoadT* result = reinterpret_cast(&r); - - for (; offset * ILP < (classes - last); offset += blockDim.x) { - *value = reinterpret_cast(logits)[offset]; - -#pragma unroll - for (int j = 0; j < ILP; ++j) { - r[j] = tmpGradOutput * (std::exp( - static_cast(v[j]) - coeff) - - static_cast(((ILP * offset + j - shift) == label) ? 1 : 0) * - smooth_positives - smooth_negatives); - } - reinterpret_cast(gradInput)[offset] = *result; - } - - offset = classes - last + threadIdx.x; - for (; offset < classes; offset += blockDim.x) - gradInput[offset] = tmpGradOutput * (std::exp( - static_cast(logits[offset]) - coeff) - - static_cast(((offset - shift) == label) ? 1 : 0) * - smooth_positives - smooth_negatives); - -} - -template class Epilogue> -__global__ void -cunn_SoftMaxXEntropyBackward( - scalar_t *gradInput, - scalar_t *logits, - outscalar_t *max_log_sum_exp, - outscalar_t *gradOutput, - int64_t *labels, - const float smoothing, - int classes, - const int total_classes) -{ - gradInput += blockIdx.x * classes; - logits += blockIdx.x * classes; - - // Do vectorized load/store when input/output have same alignment - const int shift = ((uint64_t)logits) % ALIGN_BYTES / sizeof(scalar_t); - const int shift_ = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t); - if (shift == shift_){ - aligned_apply(shift, gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes, total_classes <= 0 ? classes : total_classes); - } - else { - apply(gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes, total_classes <= 0 ? classes : total_classes); - } - -} - -template class Epilogue> -std::vector host_softmax_xentropy( - const Tensor & input_, - const Tensor & labels_, - const float smoothing, - const int total_classes) { - // For tensor parallel cross entropy with smoothing, we want to pass in the total number - // of classes so that smoothing can be applied correctly. If total_classes=-1, use the - // last dimension of the input tensor. - AT_ASSERTM(labels_.scalar_type() == ScalarType::Long,"Label type should be CUDA Long"); - - // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{input_.device()}; - - auto input = input_.contiguous(); - Tensor max_log_sum_exp = at::empty_like(labels_, input.options().dtype(ScalarType::Float)); - Tensor losses = at::empty_like(labels_, input_.options().dtype(ScalarType::Float)); - - static_assert(std::is_same, float>::value || - std::is_same, double>::value, - "accscalar_t for half should be float or double"); - AT_ASSERTM(input.dim() == 2, "Currently only 2 dim input supported"); - AT_ASSERTM(labels_.dim() == 1, "Labels should be 1 dimensional"); - AT_ASSERTM(input.size(0) == labels_.size(0), "Input and label should have same number of examples"); - AT_ASSERTM(input.numel() > 0, "Number of classes in input should not be 0"); - - const int64_t dim = 1; - int64_t outer_size = 1; - int64_t dim_size = input.size(dim); - int64_t inner_size = 1; - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - for (int64_t i = 0; i < dim; ++i) - outer_size *= input.size(i); - for (int64_t i = dim + 1; i < input.dim(); ++i) - inner_size *= input.size(i); - // This kernel spawns a block per each element in the batch. - // XXX: it assumes that inner_size == 1 - TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported"); - - dim3 grid(outer_size); - - using namespace at; - DISPATCH_FLOAT_AND_HALF_AND_BF16(input.scalar_type(), 0, "host_softmax_xentropy", - using accscalar_t = at::acc_type; - const int ILP = sizeof(float4)/sizeof(scalar_t_0); - dim3 block = SoftMax_getBlockSize(ILP, dim_size); - cunn_SoftMaxXEntropyForward - <<>>( - losses.data_ptr(), max_log_sum_exp.data_ptr(), - input.data_ptr(), labels_.data_ptr(), - dim_size, smoothing, total_classes <= 0 ? dim_size : total_classes - ); - ); - - C10_CUDA_CHECK(cudaGetLastError()); - - std::vector ret = {losses, max_log_sum_exp}; - return ret; -} - -template class Epilogue> -Tensor host_softmax_xentropy_backward( - const at::Tensor &grad_loss, - at::Tensor &logits_, - const at::Tensor &max_log_sum_exp, - const at::Tensor &labels, - const float smoothing, - bool inplace, - const int total_classes) { - // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{grad_loss.device()}; - - const int64_t dim = 1; - Tensor gI = inplace ? logits_ : at::empty_like(logits_); - if (grad_loss.numel() == 0) { - return gI; - } - - auto grad = grad_loss.contiguous(); - auto logits = logits_.contiguous(); - - static_assert(std::is_same, float>::value || - std::is_same, double>::value, - "accscalar_t for half should be float or double"); - if (grad.dim() == 0) grad = grad.view(1); - - AT_ASSERTM(logits_.dim() == 2, "Currently only 2 dim input supported"); - AT_ASSERTM(labels.dim() == 1, "Labels should be 1 dimensional"); - AT_ASSERTM(logits_.numel() > 0, "Number of classes in input should not be 0"); - AT_ASSERTM(logits_.size(0) == labels.size(0), "Input and label should have same number of examples"); - AT_ASSERTM(labels.size(0) == grad.size(0), "Label and loss should have same number of examples"); - - int64_t outer_size = 1; - int64_t dim_size = logits.size(dim); - int64_t inner_size = 1; - for (int64_t i = 0; i < dim; ++i) - outer_size *= logits.size(i); - for (int64_t i = dim + 1; i < logits.dim(); ++i) - inner_size *= logits.size(i); - // See descriptions of kernels above. - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported"); - - dim3 grid(outer_size); - - DISPATCH_FLOAT_AND_HALF_AND_BF16(gI.scalar_type(), 0, "host_softmax_xentropy_backward", - using accscalar_t = acc_type; - const int ILP = sizeof(float4)/sizeof(scalar_t_0); - dim3 block = SoftMax_getBlockSize(ILP, dim_size); - cunn_SoftMaxXEntropyBackward - <<>>( - gI.data_ptr(), logits.data_ptr(), - max_log_sum_exp.data_ptr(), - grad.data_ptr(), labels.data_ptr(), - smoothing, dim_size, total_classes - ); - ); - - C10_CUDA_CHECK(cudaGetLastError()); - return gI; -} - -std::vector softmax_xentropy_cuda(const Tensor &input, const Tensor &labels, const float smoothing, const int total_classes){ - return host_softmax_xentropy(input, labels, smoothing, total_classes); -} - -at::Tensor softmax_xentropy_backward_cuda( - const at::Tensor &grad_loss, - at::Tensor &logits, - const at::Tensor &max_log_sum_exp, - const at::Tensor &labels, - const float smoothing, - const bool inplace, - const int total_classes) { - AT_ASSERTM((grad_loss.scalar_type() == ScalarType::Float), "expected grad types to be at::Float"); - return host_softmax_xentropy_backward(grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace, total_classes); -} From 3edef7c07220a1ec44c8729d61e9c5afc53928a4 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 12 Aug 2025 11:30:00 -0400 Subject: [PATCH 068/258] Remove old fused softmax kernel from apex/Megatron --- benchmarks/benchmark_causal.py | 30 - csrc/fused_softmax/fused_softmax.cpp | 148 ----- csrc/fused_softmax/scaled_masked_softmax.h | 528 ----------------- .../scaled_masked_softmax_cuda.cu | 121 ---- .../scaled_upper_triang_masked_softmax.h | 529 ------------------ ...scaled_upper_triang_masked_softmax_cuda.cu | 98 ---- csrc/fused_softmax/setup.py | 50 -- csrc/fused_softmax/type_shim.h | 20 - flash_attn/fused_softmax.py | 201 ------- 9 files changed, 1725 deletions(-) delete mode 100644 csrc/fused_softmax/fused_softmax.cpp delete mode 100644 csrc/fused_softmax/scaled_masked_softmax.h delete mode 100644 csrc/fused_softmax/scaled_masked_softmax_cuda.cu delete mode 100644 csrc/fused_softmax/scaled_upper_triang_masked_softmax.h delete mode 100644 csrc/fused_softmax/scaled_upper_triang_masked_softmax_cuda.cu delete mode 100644 csrc/fused_softmax/setup.py delete mode 100644 csrc/fused_softmax/type_shim.h delete mode 100644 flash_attn/fused_softmax.py diff --git a/benchmarks/benchmark_causal.py b/benchmarks/benchmark_causal.py index 6c4797c83e0..c97581c6581 100644 --- a/benchmarks/benchmark_causal.py +++ b/benchmarks/benchmark_causal.py @@ -17,12 +17,6 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_kvpacked_func -try: - from flash_attn.fused_softmax import scaled_upper_triang_masked_softmax -except ImportError: - scaled_upper_triang_masked_softmax = None - - def attention_pytorch(qkv, dropout_p=0.0, causal=True): """ Arguments: @@ -52,27 +46,6 @@ def attention_pytorch(qkv, dropout_p=0.0, causal=True): return output.to(dtype=qkv.dtype) -def attention_megatron(qkv): - """ - Arguments: - qkv: (batch_size, seqlen, 3, nheads, head_dim) - Output: - output: (batch_size, seqlen, nheads, head_dim) - """ - batch_size, seqlen, _, nheads, d = qkv.shape - q, k, v = qkv.unbind(dim=2) - q = rearrange(q, 'b t h d -> (b h) t d') - k = rearrange(k, 'b s h d -> (b h) d s') - softmax_scale = 1.0 / math.sqrt(d) - # Preallocate attn_weights for `baddbmm` - scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device) - scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale), - '(b h) t s -> b h t s', h=nheads) - attention = scaled_upper_triang_masked_softmax(scores, None, scale=1.0) - output = torch.einsum('bhts,bshd->bthd', attention, v) - return output.to(dtype=qkv.dtype) - - torch.manual_seed(0) repeats = 30 batch_size = 8 @@ -130,9 +103,6 @@ def attention_megatron(qkv): # benchmark_all(attention_og, q, k, v, 1.0, repeats=repeats, desc='FlashAttention Triton OG') # # pytorch_profiler(attention, q, k, v, 1.0, backward=True) -# if scaled_upper_triang_masked_softmax is not None: -# benchmark_all(attention_megatron, qkv, repeats=repeats, desc='Megatron Attention') - # from src.ops.fftconv import fftconv_func # dim = nheads * headdim diff --git a/csrc/fused_softmax/fused_softmax.cpp b/csrc/fused_softmax/fused_softmax.cpp deleted file mode 100644 index 2aaed913314..00000000000 --- a/csrc/fused_softmax/fused_softmax.cpp +++ /dev/null @@ -1,148 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace multihead_attn { -namespace fused_softmax { -namespace scaled_masked_softmax { - -torch::Tensor fwd_cuda( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor); - -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor); - -int get_batch_per_block_cuda( - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads); - -torch::Tensor fwd( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor) { - AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); - AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); - - return fwd_cuda(input, mask, scale_factor); -} - -torch::Tensor bwd( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor) { - - AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); - AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); - - AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - - return bwd_cuda(output_grads, softmax_results, scale_factor); -} - -int get_batch_per_block( - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads) { - return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads); -} - -} // end namespace scaled_masked_softmax -} // end namespace fused_softmax -} // end namespace multihead_attn - -namespace multihead_attn { -namespace fused_softmax { -namespace scaled_upper_triang_masked_softmax { - -torch::Tensor fwd_cuda( - torch::Tensor const& input, - float scale_factor); - -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor); - -torch::Tensor fwd(torch::Tensor const& input, float scale_factor) { - AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); - AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - - return fwd_cuda(input, scale_factor); -} - -torch::Tensor bwd( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor) { - - AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); - - AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - - return bwd_cuda(output_grads, softmax_results, scale_factor); -} - -} // end namespace scaled_upper_triang_masked_softmax -} // end namespace fused_softmax -} // end namespace multihead_attn - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("scaled_masked_softmax_forward", - &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, - "Self Multihead Attention scaled, time masked softmax -- Forward."); - - m.def("scaled_masked_softmax_backward", - &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, - "Self Multihead Attention scaled, time masked softmax -- Backward."); - - m.def("scaled_masked_softmax_get_batch_per_block", - &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block, - "Return Batch per block size." - ); - - m.def("scaled_upper_triang_masked_softmax_forward", - &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd, - "Self Multihead Attention scaled, time masked softmax -- Forward."); - m.def("scaled_upper_triang_masked_softmax_backward", - &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd, - "Self Multihead Attention scaled, time masked softmax -- Backward."); -} diff --git a/csrc/fused_softmax/scaled_masked_softmax.h b/csrc/fused_softmax/scaled_masked_softmax.h deleted file mode 100644 index 14b9f6e4242..00000000000 --- a/csrc/fused_softmax/scaled_masked_softmax.h +++ /dev/null @@ -1,528 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -namespace { - -template -__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); - -template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } - -int log2_ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; -} - -template -struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } -}; - -template -struct Max { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? b : a; - } -}; - -template -__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) -{ -#if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); -#else - return __shfl_xor(value, laneMask, width); -#endif -} - -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t* sum) { - ReduceOp r; - #pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); - } - } -} - -/* - * Extended softmax (from native aten pytorch) with following additional features - * 1) input scaling - * 2) Explicit masking - */ -template -__global__ void scaled_masked_softmax_warp_forward( - output_t *dst, - const input_t *src, - const uint8_t *mask, - const acc_t scale, - int micro_batch_size, - int element_count, - int pad_batches) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_forward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) - int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH; - int pad_first_batch = 0; - if (pad_batches != 1) { // bert style - pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH; - } else { // gpt2 style - pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - } - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - - // load data from global memory - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - input_t temp_data[ELEMENTS_PER_LDG_STG]; - uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < batch_element_count) { - int itr_idx = i*element_count+it*WARP_SIZE; - copy_vector(temp_data, src + itr_idx); - copy_vector(temp_mask, mask + itr_idx); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (temp_mask[element] != 1) { - elements[i][it + element] = (acc_t)temp_data[element] * scale; - } else { - elements[i][it + element] = -10000.0; - } - } - } else { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } - } - } - - // compute max_value - acc_t max_value[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; - } - } - warp_reduce(max_value); - - // compute scale value to account for full mask - acc_t scale_value[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - scale_value[i] = (max_value[i] == -10000.0) ? 0.0 : 1.0; - } - - acc_t sum[WARP_BATCH] { 0.0f }; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - elements[i][it] = std::exp((elements[i][it] - max_value[i])); - sum[i] += elements[i][it]; - } - } - warp_reduce(sum); - - // store result - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = elements[i][it + element] * scale_value[i]/ sum[i]; - } - copy_vector(dst + i * element_count + it * WARP_SIZE, out); - } else { - break; - } - } - } -} - -template -__global__ void scaled_masked_softmax_warp_backward( - output_t *gradInput, - input_t *grad, - const input_t *output, - acc_t scale, - int micro_batch_size, - int element_count) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_backward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - // the first element to process by the current thread - int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - - // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - input_t temp_grad[ELEMENTS_PER_LDG_STG]; - input_t temp_output[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - copy_vector(temp_grad, grad + i * element_count + it * WARP_SIZE); - copy_vector(temp_output, output + i * element_count + it * WARP_SIZE); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - output_reg[i][it + element] = (acc_t)temp_output[element]; - } - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; - } - } - } - } - - acc_t sum[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; - } - } - warp_reduce(sum); - - // store result - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); - } - copy_vector(gradInput + i * element_count + it * WARP_SIZE, out); - } - } - } -} -} // end of anonymous namespace - -int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){ - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - constexpr int threads_per_block = 128; - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - - return batches_per_block; -} - -template -void dispatch_scaled_masked_softmax_forward( - output_t *dst, - const input_t *src, - const uint8_t *mask, - const input_t scale, - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads, - int pad_batches) -{ - TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 8192 ); - if (key_seq_len == 0) { - return; - } else { - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - int batch_count = batches * attn_heads * query_seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0); - dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 1: // 2 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 2: // 4 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 3: // 8 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 4: // 16 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 5: // 32 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 6: // 64 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 7: // 128 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 8: // 256 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 9: // 512 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 10: // 1024 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 11: // 2048 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 12: // 4096 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 13: // 8192 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - default: - break; - } - } -} - -template -void dispatch_scaled_masked_softmax_backward( - output_t *grad_input, - input_t *grad, - const input_t *output, - const acc_t scale, - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads) -{ - TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 8192 ); - if (key_seq_len == 0) { - return; - } else { - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - int batch_count = batches * attn_heads * query_seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = batch_count/batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 1: // 2 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 2: // 4 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 3: // 8 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 4: // 16 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 5: // 32 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 6: // 64 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 7: // 128 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 8: // 256 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 9: // 512 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 10: // 1024 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 11: // 2048 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 12: // 4096 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 13: // 8192 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - default: - break; - } - } -} diff --git a/csrc/fused_softmax/scaled_masked_softmax_cuda.cu b/csrc/fused_softmax/scaled_masked_softmax_cuda.cu deleted file mode 100644 index a08e752699c..00000000000 --- a/csrc/fused_softmax/scaled_masked_softmax_cuda.cu +++ /dev/null @@ -1,121 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include -#include "scaled_masked_softmax.h" -#include "type_shim.h" - -namespace multihead_attn { -namespace fused_softmax { -namespace scaled_masked_softmax { - -int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){ - return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads); -} - - -torch::Tensor fwd_cuda( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor) -{ - // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] - const int batches = input.size(0); - const int pad_batches = mask.size(0); - const int attn_heads = input.size(1); - const int query_seq_len = input.size(2); - const int key_seq_len = input.size(3); - TORCH_INTERNAL_ASSERT(key_seq_len <= 8192); - TORCH_INTERNAL_ASSERT(query_seq_len > 1); - TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches); - TORCH_INTERNAL_ASSERT(mask.size(1) == 1); - TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); - TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); - - // Output - auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = - torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); - - // Softmax Intermediate Result Ptr - void* input_ptr = static_cast(input.data_ptr()); - void* mask_ptr = static_cast(mask.data_ptr()); - void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); - - DISPATCH_HALF_AND_BFLOAT( - input.scalar_type(), - "dispatch_scaled_masked_softmax_forward", - dispatch_scaled_masked_softmax_forward( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(input_ptr), - reinterpret_cast(mask_ptr), - scale_factor, - query_seq_len, - key_seq_len, - batches, - attn_heads, - pad_batches - ); - ); - return softmax_results; -} - -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads_, - torch::Tensor const& softmax_results_, - float scale_factor) { - - auto output_grads = output_grads_.contiguous(); - auto softmax_results = softmax_results_.contiguous(); - - //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] - const int batches = output_grads.size(0); - const int attn_heads = output_grads.size(1); - const int query_seq_len = output_grads.size(2); - const int key_seq_len = output_grads.size(3); - - auto act_options = output_grads.options().requires_grad(false); - torch::Tensor input_grads = - torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); - void* input_grads_ptr = static_cast(input_grads.data_ptr()); - void* output_grads_ptr = static_cast(output_grads.data_ptr()); - - //Softmax Grad - DISPATCH_HALF_AND_BFLOAT( - output_grads_.scalar_type(), - "dispatch_scaled_masked_softmax_backward", - dispatch_scaled_masked_softmax_backward( - reinterpret_cast(input_grads_ptr), - reinterpret_cast(output_grads_ptr), - reinterpret_cast(softmax_results.data_ptr()), - scale_factor, - query_seq_len, - key_seq_len, - batches, - attn_heads - ); - ); - return input_grads; -} -} -} -} diff --git a/csrc/fused_softmax/scaled_upper_triang_masked_softmax.h b/csrc/fused_softmax/scaled_upper_triang_masked_softmax.h deleted file mode 100644 index 21e93fb313a..00000000000 --- a/csrc/fused_softmax/scaled_upper_triang_masked_softmax.h +++ /dev/null @@ -1,529 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include - -namespace { - -template -__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); - -template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } - -template -__device__ __inline__ void copy_zero_vector(Datatype *dst); - -template <> -__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *dst = 0.0; } - -template <> -__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } - -template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *dst = 0.0; } - -template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } - - -int log2_ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; -} - -template -struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } -}; - -template -struct Max { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? b : a; - } -}; - -template -__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) -{ -#if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); -#else - return __shfl_xor(value, laneMask, width); -#endif -} - -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t* sum) { - ReduceOp r; - #pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); - } - } -} - -/* - * Extended softmax (from native aten pytorch) with following additional features - * 1) input scaling - * 2) Implicit time (diagonal masking) - */ -template -__global__ void scaled_upper_triang_masked_softmax_warp_forward( - output_t *dst, - const input_t *src, - const acc_t scale, - int micro_batch_size, - int stride, - int element_count) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_forward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; - int local_seq = blockIdx.x + 1; - int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - - // load data from global memory - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - input_t temp_data[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : local_seq; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < batch_element_count) { - copy_vector(temp_data, src + i*element_count*stride + it*WARP_SIZE); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if ((element_index + element) < batch_element_count) { - elements[i][it+element] = (acc_t)temp_data[element] * scale; - } else { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } - } else { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } - } - } - - // compute max_value - acc_t max_value[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; - } - } - warp_reduce(max_value); - - acc_t sum[WARP_BATCH] { 0.0f }; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - if (it < warp_iteration_limit) { - elements[i][it] = std::exp((elements[i][it] - max_value[i])); - sum[i] += elements[i][it]; - } - } - } - warp_reduce(sum); - - // store result - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < local_seq) { - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < local_seq) { - out[element] = elements[i][it + element] / sum[i]; - } else { - out[element] = 0; - } - } - copy_vector(dst + i * element_count * stride + it * WARP_SIZE, out); - } else if (element_index < element_count) { - copy_zero_vector(dst + i * element_count * stride + it * WARP_SIZE); - } else { - break; - } - } - } -} - -template -__global__ void scaled_upper_triang_masked_softmax_warp_backward( - output_t *gradInput, - input_t *grad, - const input_t *output, - acc_t scale, - int micro_batch_size, - int stride, - int element_count) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_backward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; - int local_seq = blockIdx.x + 1; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - // the first element to process by the current thread - int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - - // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - input_t temp_grad[ELEMENTS_PER_LDG_STG]; - input_t temp_output[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : local_seq; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - copy_vector(temp_grad, grad + i * element_count * stride + it * WARP_SIZE); - copy_vector(temp_output, output + i * element_count * stride + it * WARP_SIZE); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < batch_element_count) { - output_reg[i][it + element] = (acc_t)temp_output[element]; - } - } - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < batch_element_count) { - grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; - } - } - } - } - } - - acc_t sum[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; - } - } - warp_reduce(sum); - - // store result - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); - } - copy_vector(gradInput + i * element_count * stride + it * WARP_SIZE, out); - } - } - } -} - -} // end of anonymous namespace - -template -void dispatch_scaled_upper_triang_masked_softmax_forward( - output_t *dst, - const input_t *src, - const input_t scale, - int softmax_elements, - int softmax_elements_stride, - int attn_batches) -{ - TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 8192 ); - if (softmax_elements == 0) { - return; - } else { - int log2_elements = log2_ceil(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - int seq_len = softmax_elements; - int batch_count = attn_batches * seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); - - int blocks_per_seq = attn_batches / batches_per_block; - dim3 blocks(seq_len, blocks_per_seq, 1); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 11: // 2048 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 12: // 4096 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 13: // 8192 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - default: - break; - } - } -} - -template -void dispatch_scaled_upper_triang_masked_softmax_backward( - output_t *grad_input, - input_t *grad, - const input_t *output, - const acc_t scale, - int softmax_elements, - int softmax_elements_stride, - int attn_batches) -{ - TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 8192 ); - if (softmax_elements == 0) { - return; - } else { - int log2_elements = log2_ceil(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - int seq_len = softmax_elements; - int batch_count = attn_batches * seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); - - int blocks_per_seq = attn_batches / batches_per_block; - dim3 blocks(seq_len, blocks_per_seq, 1); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 11: // 2048 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 12: // 4096 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 13: // 8192 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - default: - break; - } - } -} diff --git a/csrc/fused_softmax/scaled_upper_triang_masked_softmax_cuda.cu b/csrc/fused_softmax/scaled_upper_triang_masked_softmax_cuda.cu deleted file mode 100644 index 79ec30be364..00000000000 --- a/csrc/fused_softmax/scaled_upper_triang_masked_softmax_cuda.cu +++ /dev/null @@ -1,98 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include -#include "scaled_upper_triang_masked_softmax.h" -#include "type_shim.h" - -namespace multihead_attn { -namespace fused_softmax { -namespace scaled_upper_triang_masked_softmax { - -torch::Tensor fwd_cuda( - torch::Tensor const& input, - float scale_factor) -{ - // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] - const int attn_batches = input.size(0); - const int seq_len = input.size(1); - TORCH_INTERNAL_ASSERT(seq_len <= 8192); - - // Output - auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = - torch::empty({attn_batches, seq_len, seq_len}, act_options); - - // Softmax Intermediate Result Ptr - void* input_ptr = static_cast(input.data_ptr()); - void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); - - DISPATCH_HALF_AND_BFLOAT( - input.scalar_type(), - "dispatch_scaled_upper_triang_masked_softmax_forward", - dispatch_scaled_upper_triang_masked_softmax_forward( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(input_ptr), - scale_factor, - seq_len, - seq_len, - attn_batches); - ); - return softmax_results; -} - - -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads_, - torch::Tensor const& softmax_results_, - float scale_factor) { - - auto output_grads = output_grads_.contiguous(); - auto softmax_results = softmax_results_.contiguous(); - - //output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] - const int attn_batches = output_grads.size(0); - const int seq_len = output_grads.size(1); - TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2)); - - void* output_grads_ptr = static_cast(output_grads.data_ptr()); - - //Softmax Grad - DISPATCH_HALF_AND_BFLOAT( - output_grads_.scalar_type(), - "dispatch_scaled_upper_triang_masked_softmax_backward", - dispatch_scaled_upper_triang_masked_softmax_backward( - reinterpret_cast(output_grads_ptr), - reinterpret_cast(output_grads_ptr), - reinterpret_cast(softmax_results.data_ptr()), - scale_factor, - seq_len, - seq_len, - attn_batches); - ); - - //backward pass is completely in-place - return output_grads; -} -} -} -} diff --git a/csrc/fused_softmax/setup.py b/csrc/fused_softmax/setup.py deleted file mode 100644 index 9c1c6ed76e9..00000000000 --- a/csrc/fused_softmax/setup.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copied from https://github.com/NVIDIA/apex/tree/master/csrc/megatron -# We add the case where seqlen = 4k and seqlen = 8k -import os -import subprocess - -import torch -from setuptools import setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME - - -def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - release = output[release_idx].split(".") - bare_metal_major = release[0] - bare_metal_minor = release[1][0] - - return raw_output, bare_metal_major, bare_metal_minor - - -def append_nvcc_threads(nvcc_extra_args): - _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) - if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2: - nvcc_threads = os.getenv("NVCC_THREADS") or "4" - return nvcc_extra_args + ["--threads", nvcc_threads] - return nvcc_extra_args - - -cc_flag = [] -cc_flag.append("-gencode") -cc_flag.append("arch=compute_70,code=sm_70") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_80,code=sm_80") - -setup( - name='fused_softmax_lib', - ext_modules=[ - CUDAExtension( - name='fused_softmax_lib', - sources=['fused_softmax.cpp', 'scaled_masked_softmax_cuda.cu', 'scaled_upper_triang_masked_softmax_cuda.cu'], - extra_compile_args={ - 'cxx': ['-O3',], - 'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + cc_flag) - } - ) - ], - cmdclass={ - 'build_ext': BuildExtension -}) diff --git a/csrc/fused_softmax/type_shim.h b/csrc/fused_softmax/type_shim.h deleted file mode 100644 index 815ec7ec889..00000000000 --- a/csrc/fused_softmax/type_shim.h +++ /dev/null @@ -1,20 +0,0 @@ -#include - -#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ -switch(TYPE) \ -{ \ -case at::ScalarType::Half: \ - { \ -using scalar_t = at::Half; \ -__VA_ARGS__; \ -break; \ - } \ -case at::ScalarType::BFloat16: \ - { \ -using scalar_t = at::BFloat16; \ -__VA_ARGS__; \ -break; \ - } \ -default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ -} diff --git a/flash_attn/fused_softmax.py b/flash_attn/fused_softmax.py deleted file mode 100644 index 382f94f092c..00000000000 --- a/flash_attn/fused_softmax.py +++ /dev/null @@ -1,201 +0,0 @@ -# [2022-10-23] Copied from https://github.com/NVIDIA/apex/blob/master/apex/transformer/functional/fused_softmax.py -# for benchmarking. -# We added support for seqlen=2k and seqlen=4k - -# coding=utf-8 -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import torch -from apex._autocast_utils import _cast_if_autocast_enabled -from apex.transformer.enums import AttnMaskType -from fused_softmax_lib import ( - scaled_masked_softmax_backward, - scaled_masked_softmax_forward, - scaled_masked_softmax_get_batch_per_block, - scaled_upper_triang_masked_softmax_backward, - scaled_upper_triang_masked_softmax_forward, -) - - -class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): - """ - Fused operation which performs following three operations in sequence - 1. Scale the tensor. - 2. Apply upper triangular mask (typically used in gpt models). - 3. Perform softmax. - """ - - @staticmethod - def forward(ctx, inputs, scale): - scale_t = torch.tensor([scale]) - softmax_results = scaled_upper_triang_masked_softmax_forward(inputs, scale_t[0]) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - softmax_results, scale_t = ctx.saved_tensors - input_grads = scaled_upper_triang_masked_softmax_backward( - output_grads, softmax_results, scale_t[0] - ) - return input_grads, None - - -def scaled_upper_triang_masked_softmax(inputs, _, scale): - b, np, sq, sk = inputs.size() - assert sq == sk, "causal mask is only for self attention" - # Reshaping input to 3D tensor (attn_batches, sq, sk) - inputs = inputs.view(-1, sq, sk) - args = _cast_if_autocast_enabled(inputs, scale) - with torch.cuda.amp.autocast(enabled=False): - probs = ScaledUpperTriangMaskedSoftmax.apply(*args) - return probs.view(b, np, sq, sk) - - -# NOTE (mkozuki): `ScaledMaskedSoftmax` somehow doesn't work well with `torch.cuda.amp.custom_fwd`. -# Without `cast_inputs` kwarg, somehow inputs are not cast to dtype used in the autocast context. -# So I needed to manually write two `torch.autograd.Function` inheritances. -# Fused operation which performs following three operations in sequence -# 1. Scale the tensor. -# 2. Apply the mask. -# 3. Perform softmax. -class ScaledMaskedSoftmax(torch.autograd.Function): - @staticmethod - def forward(ctx, inputs, mask, scale): - scale_t = torch.tensor([scale]) - softmax_results = scaled_masked_softmax_forward(inputs, mask, scale_t[0]) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - softmax_results, scale_t = ctx.saved_tensors - input_grads = scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0]) - return input_grads, None, None - - -def scaled_masked_softmax(inputs, mask, scale): - # input is 4D tensor (b, np, sq, sk) - args = _cast_if_autocast_enabled(inputs, mask, scale) - with torch.cuda.amp.autocast(enabled=False): - return ScaledMaskedSoftmax.apply(*args) - - -class FusedScaleMaskSoftmax(torch.nn.Module): - """ - fused operation: scaling + mask + softmax - - Arguments: - input_in_fp16: flag to indicate if input in fp16 data format. - input_in_bf16: flag to indicate if input in bf16 data format. - attn_mask_type: attention mask type (pad or causal) - scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion - mask_func: mask function to be applied. - softmax_in_fp32: if true, softmax in performed at fp32 precision. - scale: scaling factor used in input tensor scaling. - """ - - def __init__( - self, - input_in_fp16, - input_in_bf16, - attn_mask_type, - scaled_masked_softmax_fusion, - mask_func, - softmax_in_fp32, - scale, - ): - super().__init__() - self.input_in_fp16 = input_in_fp16 - self.input_in_bf16 = input_in_bf16 - if self.input_in_fp16 and self.input_in_bf16: - raise RuntimeError("both fp16 and bf16 flags cannot be active at the same time.") - self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 - self.attn_mask_type = attn_mask_type - self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion - self.mask_func = mask_func - self.softmax_in_fp32 = softmax_in_fp32 - self.scale = scale - - if not (self.scale is None or softmax_in_fp32): - raise RuntimeError("softmax should be in fp32 when scaled") - - if self.scaled_masked_softmax_fusion: - if self.attn_mask_type == AttnMaskType.causal: - self.fused_softmax_func = scaled_upper_triang_masked_softmax - elif self.attn_mask_type == AttnMaskType.padding: - self.fused_softmax_func = scaled_masked_softmax - else: - raise ValueError("Invalid attn_mask_type.") - - def forward(self, input, mask): - # [b, np, sq, sk] - assert input.dim() == 4 - - if self.is_kernel_available(mask, *input.size()): - return self.forward_fused_softmax(input, mask) - else: - return self.forward_torch_softmax(input, mask) - - def is_kernel_available(self, mask, b, np, sq, sk): - attn_batches = b * np - - if ( - self.scaled_masked_softmax_fusion # user want to fuse - and self.input_in_float16 # input must be fp16 - and ( - self.attn_mask_type == AttnMaskType.causal - or (self.attn_mask_type == AttnMaskType.padding and mask is not None) - ) - and 16 < sk <= 8192 # sk must be 16 ~ 8192 - and sq % 4 == 0 # sq must be divisor of 4 - and sk % 4 == 0 # sk must be divisor of 4 - and attn_batches % 4 == 0 # np * b must be divisor of 4 - ): - if 0 <= sk <= 8192: - batch_per_block = self.get_batch_per_block(sq, sk, b, np) - - if self.attn_mask_type == AttnMaskType.causal: - if attn_batches % batch_per_block == 0: - return True - else: - if sq % batch_per_block == 0: - return True - return False - - def forward_fused_softmax(self, input, mask): - # input.shape = [b, np, sq, sk] - scale = self.scale if self.scale is not None else 1.0 - return self.fused_softmax_func(input, mask, scale) - - def forward_torch_softmax(self, input, mask): - if self.input_in_float16 and self.softmax_in_fp32: - input = input.float() - - if self.scale is not None: - input = input * self.scale - mask_output = self.mask_func(input, mask) if mask is not None else input - probs = torch.nn.Softmax(dim=-1)(mask_output) - - if self.input_in_float16 and self.softmax_in_fp32: - if self.input_in_fp16: - probs = probs.half() - else: - probs = probs.bfloat16() - - return probs - - @staticmethod - def get_batch_per_block(sq, sk, b, np): - return scaled_masked_softmax_get_batch_per_block(sq, sk, b, np) From 2715c53932c28e81c15ad4d1690639b77ddda6c1 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 12 Aug 2025 11:32:00 -0400 Subject: [PATCH 069/258] Remove old attn decode kernel from FasterTransformer --- csrc/ft_attention/README.md | 14 - csrc/ft_attention/cuda_bf16_fallbacks.cuh | 257 --- csrc/ft_attention/cuda_bf16_wrapper.h | 23 - .../decoder_masked_multihead_attention.cu | 149 -- .../decoder_masked_multihead_attention.h | 192 -- ...er_masked_multihead_attention_template.hpp | 1619 ------------- ...decoder_masked_multihead_attention_utils.h | 2017 ----------------- csrc/ft_attention/ft_attention.cpp | 231 -- csrc/ft_attention/setup.py | 153 -- 9 files changed, 4655 deletions(-) delete mode 100644 csrc/ft_attention/README.md delete mode 100644 csrc/ft_attention/cuda_bf16_fallbacks.cuh delete mode 100644 csrc/ft_attention/cuda_bf16_wrapper.h delete mode 100644 csrc/ft_attention/decoder_masked_multihead_attention.cu delete mode 100644 csrc/ft_attention/decoder_masked_multihead_attention.h delete mode 100644 csrc/ft_attention/decoder_masked_multihead_attention_template.hpp delete mode 100644 csrc/ft_attention/decoder_masked_multihead_attention_utils.h delete mode 100644 csrc/ft_attention/ft_attention.cpp delete mode 100644 csrc/ft_attention/setup.py diff --git a/csrc/ft_attention/README.md b/csrc/ft_attention/README.md deleted file mode 100644 index 97feb78cc1c..00000000000 --- a/csrc/ft_attention/README.md +++ /dev/null @@ -1,14 +0,0 @@ -# Attention kernel from FasterTransformer - -This CUDA extension wraps the single-query attention [kernel](https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp) from -FasterTransformer v5.2.1 for benchmarking purpose. - -```sh -cd csrc/ft_attention && pip install . -``` - -As of 2023-09-17, this extension is no longer used in the FlashAttention repo. -FlashAttention now has implemented -[`flash_attn_with_kvcache`](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attention_interface.py) -with all the features of this `ft_attention` kernel (and more). - diff --git a/csrc/ft_attention/cuda_bf16_fallbacks.cuh b/csrc/ft_attention/cuda_bf16_fallbacks.cuh deleted file mode 100644 index f5641f61609..00000000000 --- a/csrc/ft_attention/cuda_bf16_fallbacks.cuh +++ /dev/null @@ -1,257 +0,0 @@ -// Downloaded from from FasterTransformer v5.2.1 -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_fallbacks.cuh -/* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "cuda_bf16_wrapper.h" -#include - -namespace fastertransformer { - -#ifdef ENABLE_BF16 -inline __device__ float2 bf1622float2(const __nv_bfloat162 val) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float2 f_val; - f_val.x = __low2float(val); - f_val.y = __high2float(val); - return f_val; -#else - return __bfloat1622float2(val); -#endif -} - -inline __device__ int16_t bf1622int16(__nv_bfloat162 val) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float2 f_val; - f_val.x = max(min(__low2float(val), 127.f), -128.f); - f_val.y = max(min(__high2float(val), 127.f), -128.f); - union { int8_t int8[2]; int16_t int16; }; - int8[0] = static_cast(static_cast(f_val.x)); - int8[1] = static_cast(static_cast(f_val.y)); - return int16; -#else - val = __hmin2(val, make_bfloat162(127., 127.)); - val = __hmax2(val, make_bfloat162(-128., -128.)); - union { int8_t int8[2]; int16_t int16; }; - int8[0] = static_cast(static_cast(val.x)); - int8[1] = static_cast(static_cast(val.y)); - return int16; -#endif -} - -inline __device__ __nv_bfloat162 float22bf162(const float2 val) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __floats2bfloat162_rn(val.x, val.y); -#else - return __float22bfloat162_rn(val); -#endif -} - -inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - __nv_bfloat162 val2; - val2.x = val; - val2.y = val; - return val2; -#else - return __bfloat162bfloat162(val); -#endif -} - -inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fxl + fyl, fxh + fyh); -#else - return __hadd2(x, y); -#endif -} - -inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16( __bfloat162float(x) + __bfloat162float(y) ); -#else - return __hadd(x, y); -#endif -} - -inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fxl - fyl, fxh - fyh); -#else - return __hsub2(x, y); -#endif -} - -inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16( __bfloat162float(x) - __bfloat162float(y) ); -#else - return __hsub(x, y); -#endif -} - -inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fxl * fyl, fxh * fyh); -#else - return __hmul2(x, y); -#endif -} - -inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) ); -#else - return __hmul(x, y); -#endif -} - -inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh, fzl, fzh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - fzl = __low2float(z); - fzh = __high2float(z); - return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh); -#else - return __hfma2(x, y, z); -#endif -} - -inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z)); -#else - return __hfma(x, y, z); -#endif -} - -inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh; - fxl = __low2float(x); - fxh = __high2float(x);; - return __floats2bfloat162_rn(expf(fxl), expf(fxh)); -#else - return h2exp(x); -#endif -} - -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) -inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hmul2(x, y); }; -inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hadd2(x, y); }; - -inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y) -{ - __nv_bfloat162 t; t.x = x; t.y = y; return t; -} - -#endif - -inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c)); -#else - return a + b + c; -#endif -} - -inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d)); -#else - return (__nv_bfloat16)((float)a + (float)b + (float)c + (float)d); -#endif -} - -inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fal, fah, fbl, fbh, fcl, fch; - fal = __low2float(a); - fah = __high2float(a); - fbl = __low2float(b); - fbh = __high2float(b); - fcl = __low2float(c); - fch = __high2float(c); - return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch); -#else - return a + b + c; -#endif -} - -inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c)); -#else - return a * b * c; -#endif -} - -inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fal, fah, fbl, fbh, fcl, fch; - fal = __low2float(a); - fah = __high2float(a); - fbl = __low2float(b); - fbh = __high2float(b); - fcl = __low2float(c); - fch = __high2float(c); - return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch); -#else - return a * b * c; -#endif -} - -inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fal, fah, fbl, fbh, fcl, fch, fdl, fdh; - fal = __low2float(a); - fah = __high2float(a); - fbl = __low2float(b); - fbh = __high2float(b); - fcl = __low2float(c); - fch = __high2float(c); - fdl = __low2float(d); - fdh = __high2float(d); - return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh); -#else - return a * b * c + d; -#endif -} - -#endif // ENABLE_BF16 - -} // namespace fastertransformer diff --git a/csrc/ft_attention/cuda_bf16_wrapper.h b/csrc/ft_attention/cuda_bf16_wrapper.h deleted file mode 100644 index efb6e798730..00000000000 --- a/csrc/ft_attention/cuda_bf16_wrapper.h +++ /dev/null @@ -1,23 +0,0 @@ -// Downloaded from from FasterTransformer v5.2.1 -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_wrapper.h -/* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#ifdef ENABLE_BF16 -#include -#endif diff --git a/csrc/ft_attention/decoder_masked_multihead_attention.cu b/csrc/ft_attention/decoder_masked_multihead_attention.cu deleted file mode 100644 index 13306f76868..00000000000 --- a/csrc/ft_attention/decoder_masked_multihead_attention.cu +++ /dev/null @@ -1,149 +0,0 @@ -// Adapted from from FasterTransformer v5.2.1 -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu -/* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "decoder_masked_multihead_attention.h" -#include "decoder_masked_multihead_attention_utils.h" -#include "cuda_bf16_wrapper.h" -#include -#include -#include - -#include "decoder_masked_multihead_attention_template.hpp" - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \ - size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ - auto kernel = mmha::masked_multihead_attention_kernel; \ - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ - dim3 grid(params.nnz_head_idx == nullptr ? params.num_heads : params.nnz_heads, params.batch_size); \ - kernel<<>>(params) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// !!! Specialize the launcher for Cross attention -template -void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) -{ - constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; - constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; - int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; - // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION); - if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream); - } - else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream); - } - else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#undef MMHA_LAUNCH_KERNEL - -template -void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) -{ - switch (params.hidden_size_per_head) { - case 32: - mmha_launch_kernel(params, stream); - break; - case 48: - mmha_launch_kernel(params, stream); - break; - case 64: - mmha_launch_kernel(params, stream); - break; - case 80: - mmha_launch_kernel(params, stream); - break; - case 96: - mmha_launch_kernel(params, stream); - break; - case 128: - mmha_launch_kernel(params, stream); - break; - case 160: - mmha_launch_kernel(params, stream); - break; - case 192: - mmha_launch_kernel(params, stream); - break; - case 224: - mmha_launch_kernel(params, stream); - break; - case 256: - mmha_launch_kernel(params, stream); - break; - default: - assert(false); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) -{ - multihead_attention_>(params, stream); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) -{ - multihead_attention_>(params, stream); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, - const cudaStream_t& stream) -{ - multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream); -} -#endif -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream) -{ - multihead_attention_>(params, stream); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream) -{ - multihead_attention_>(params, stream); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params, - const cudaStream_t& stream) -{ - multihead_attention_<__nv_bfloat16, Cross_multihead_attention_params<__nv_bfloat16>>(params, stream); -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/ft_attention/decoder_masked_multihead_attention.h b/csrc/ft_attention/decoder_masked_multihead_attention.h deleted file mode 100644 index 3c79f88b856..00000000000 --- a/csrc/ft_attention/decoder_masked_multihead_attention.h +++ /dev/null @@ -1,192 +0,0 @@ -// Downloaded from from FasterTransformer v5.2.1 -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention.h -/* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "cuda_bf16_wrapper.h" -#include -#include -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define CHECK_CUDA(call) \ - do { \ - cudaError_t status_ = call; \ - if (status_ != cudaSuccess) { \ - fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ - exit(1); \ - } \ - } while (0) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// The structure of parameters for the masked multihead attention kernel. -// -// We use the following terminology to describe the different dimensions. -// -// B: Batch size (number of sequences), -// L: Sequence length, -// D: Hidden dimension, -// H: Number of heads, -// Dh: Hidden dimension per head - Dh = D / H. - -template -struct Multihead_attention_params_base { - - // The output buffer. Dimensions B x D. - T* out = nullptr; - - // The input Qs and the associated bias. Dimensions B x D and D, resp. - const T *q = nullptr, *q_bias = nullptr; - // The input Ks and the associated bias. Dimensions B x D and D, resp. - const T *k = nullptr, *k_bias = nullptr; - // The input Vs and the associated bias. Dimensions B x D and D, resp. - const T *v = nullptr, *v_bias = nullptr; - - // The cache for the Ks. The size must be at least B x L x D. - T* k_cache = nullptr; - // The cache for the Vs. The size must be at least B x L x D. - T* v_cache = nullptr; - // The indirections to use for cache when beam sampling. - const int* cache_indir = nullptr; - - // Stride to handle the case when KQV is a single buffer - int stride_q = 0; - int stride_k = 0; - int stride_v = 0; - - // The batch size. - int batch_size = 0; - // The beam width - int beam_width = 0; - // The sequence length. - int memory_max_len = 0; - // The number of heads (H). - int num_heads = 0; - int num_heads_kv = 0; - int num_heads_q_kv_ratio = 0; - // The hidden dimension per head (Dh). - int hidden_size_per_head = 0; - // The per-head latent space reserved for rotary embeddings. - int rotary_embedding_dim = 0; - bool neox_rotary_style = false; - float rotary_base = 0.0f; - // The maximum length of input sentences. - int max_input_length = 0; - // The current timestep. TODO(bhsueh) Check that do we only this param in cross attention? - int timestep = 0; - // The current timestep of each sentences (support different timestep for different sentences) - - // The 1.f / sqrt(Dh). Computed on the host. - float inv_sqrt_dh = 0.0f; - - // Used when we have some input context like gpt - const int* total_padding_tokens = nullptr; - - const bool* masked_tokens = nullptr; - const int* prefix_prompt_lengths = nullptr; - int max_prefix_prompt_length = 0; - - const T* relative_attention_bias = nullptr; - int relative_attention_bias_stride = 0; - // The slope per head of linear position bias to attention score (H). - const T* linear_bias_slopes = nullptr; - - const T* ia3_key_weights = nullptr; - const T* ia3_value_weights = nullptr; - const int* ia3_tasks = nullptr; - - const float* qkv_scale_out = nullptr; - const float* attention_out_scale = nullptr; - int int8_mode = 0; - - const T *rotary_cos = nullptr; - const T *rotary_sin = nullptr; - - const int *nnz_head_idx = nullptr; - int nnz_heads = 0; -}; - -template -struct Multihead_attention_params: public Multihead_attention_params_base { - // output cross attentions - float* cross_attention_out = nullptr; - int max_decoder_seq_len = 0; - bool is_return_cross_attentions = false; - - // allows to exist attention eary - bool* finished = nullptr; - - // required in case of cross attention - // will need it here till if constexpr in c++17 - int* memory_length_per_sample = nullptr; - - // required in case of masked attention with different length - const int* length_per_sample = nullptr; -}; - -template -struct Multihead_attention_params: public Multihead_attention_params_base { - // output cross attentions - float* cross_attention_out = nullptr; - int max_decoder_seq_len = 0; - bool is_return_cross_attentions = false; - - // allows to exist attention eary - bool* finished = nullptr; - - // required in case of cross attention - int* memory_length_per_sample = nullptr; - - // required in case of masked attention with different length - const int* length_per_sample = nullptr; -}; - -template -using Masked_multihead_attention_params = Multihead_attention_params; - -template -using Cross_multihead_attention_params = Multihead_attention_params; - -template -struct outputCrossAttentionParam { - // max decoder output length - int max_decoder_seq_len = 0; - T* cross_attention_out = nullptr; - bool is_return_cross_attentions = false; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); -void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, - const cudaStream_t& stream); -#endif -void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream); -void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params, - const cudaStream_t& stream); -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp b/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp deleted file mode 100644 index 2ae1b2425b8..00000000000 --- a/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp +++ /dev/null @@ -1,1619 +0,0 @@ -// Downloaded from from FasterTransformer v5.2.1 -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp -/* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "decoder_masked_multihead_attention.h" -#include "decoder_masked_multihead_attention_utils.h" -#include "cuda_bf16_wrapper.h" -#include "cuda_bf16_fallbacks.cuh" -#include -#include -#include - -// #define MMHA_USE_HMMA_FOR_REDUCTION - -// Below are knobs to extend FP32 accumulation for higher FP16 accuracy - -// Does not seem to affect the accuracy that much -#define MMHA_USE_FP32_ACUM_FOR_FMA - -// Seems to slightly improve the accuracy -#define MMHA_USE_FP32_ACUM_FOR_OUT - -#if 0 && defined(MMHA_USE_FP32_ACUM_FOR_OUT) - // Does not seem to improve the accuracy - //#define MMHA_USE_FP32_ACUM_FOR_LOGITS -#endif - -namespace mmha { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// -// We use the following terminology to describe the different dimensions. -// -// B: Batch size (number of sequences), -// L: Sequence length, -// D: Hidden dimension, -// H: Number of heads, -// Dh: Hidden dimension per head - Dh = D / H. -// -// The different kernels assign a threadblock for B x H pair. The grid has size (1, B, H). We use -// 64, 128 and 256 threads per block. -// -// Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to -// compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The -// cache buffer helps with memory accesses and contains keys with bias. -// -// The layout of the cache buffer for the keys is [B, H, Dh/x, L, x] where x == 8 for FP16 and -// x == 4 for FP32 where the fastest moving dimension (contiguous data) is the rightmost one. The -// values for x are chosen to create chunks of 16 bytes. -// -// The different kernels use 1, 2 or 4 threads per key (THREADS_PER_KEY). The size of the LDGs -// depends on the number of threads per key. Each thread sums Dh / THREADS_PER_KEY elements. At -// the end of each iteration of the Q * K^T loop, we perform a reduction between lanes using an -// HMMA instruction (Tensor Core). Each Q * K^T valuey is stored in shared memory in FP32. -// -// After that loop, a parallel softmax is computed across the different Q * K^T values stored in -// shared memory. -// -// The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many -// timesteps are computed by loop iteration. As with the keys, the values are read from a cache -// except for the current timestep. The layout of the cache buffer for the values is much simpler -// as it is [B, H, L, Dh]. -// - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Qk_vec_ { -}; - -template<> -struct Qk_vec_ { - using Type = float; -}; -template<> -struct Qk_vec_ { - using Type = float2; -}; -template<> -struct Qk_vec_ { - using Type = float4; -}; -template<> -struct Qk_vec_ { - using Type = float4; -}; -template<> -struct Qk_vec_ { - using Type = uint32_t; -}; -template<> -struct Qk_vec_ { - using Type = uint32_t; -}; -template<> -struct Qk_vec_ { - using Type = uint2; -}; -template<> -struct Qk_vec_ { - using Type = uint4; -}; -#ifdef ENABLE_BF16 -template<> -struct Qk_vec_<__nv_bfloat16, 32> { - using Type = __nv_bfloat162; -}; -template<> -struct Qk_vec_<__nv_bfloat16, 64> { - using Type = __nv_bfloat162; -}; -template<> -struct Qk_vec_<__nv_bfloat16, 128> { - using Type = bf16_4_t; -}; -template<> -struct Qk_vec_<__nv_bfloat16, 256> { - using Type = bf16_8_t; -}; -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct K_vec_ { -}; - -template<> -struct K_vec_ { - using Type = float; -}; -template<> -struct K_vec_ { - using Type = float2; -}; -template<> -struct K_vec_ { - using Type = float4; -}; -template<> -struct K_vec_ { - using Type = uint32_t; -}; -template<> -struct K_vec_ { - using Type = uint2; -}; -template<> -struct K_vec_ { - using Type = uint4; -}; -#ifdef ENABLE_BF16 -template<> -struct K_vec_<__nv_bfloat16, 4> { - using Type = __nv_bfloat162; -}; -template<> -struct K_vec_<__nv_bfloat16, 2> { - using Type = bf16_4_t; -}; -template<> -struct K_vec_<__nv_bfloat16, 1> { - using Type = bf16_8_t; -}; -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct V_vec_ { -}; - -template<> -struct V_vec_ { - using Type = float; -}; -template<> -struct V_vec_ { - using Type = float2; -}; -template<> -struct V_vec_ { - using Type = float4; -}; -template<> -struct V_vec_ { - using Type = uint32_t; -}; -template<> -struct V_vec_ { - using Type = uint2; -}; -template<> -struct V_vec_ { - using Type = uint4; -}; -#ifdef ENABLE_BF16 -template<> -struct V_vec_<__nv_bfloat16, 2> { - using Type = __nv_bfloat162; -}; -template<> -struct V_vec_<__nv_bfloat16, 4> { - using Type = bf16_4_t; -}; -template<> -struct V_vec_<__nv_bfloat16, 8> { - using Type = bf16_8_t; -}; -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA -template -struct Qk_vec_acum_fp32_ { -}; - -template<> -struct Qk_vec_acum_fp32_ { - using Type = float; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = float4; -}; -// template<> struct Qk_vec_acum_fp32_ { using Type = float; }; -template<> -struct Qk_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = Float8_; -}; -template<> -struct Qk_vec_acum_fp32_<__nv_bfloat16> { - using Type = float; -}; -template<> -struct Qk_vec_acum_fp32_<__nv_bfloat162> { - using Type = float2; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = Float8_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct K_vec_acum_fp32_ { -}; - -template<> -struct K_vec_acum_fp32_ { - using Type = float; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = float4; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = Float8_; -}; -template<> -struct K_vec_acum_fp32_<__nv_bfloat16> { - using Type = float; -}; -template<> -struct K_vec_acum_fp32_<__nv_bfloat162> { - using Type = float2; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = Float8_; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT -template -struct V_vec_acum_fp32_ { -}; - -template<> -struct V_vec_acum_fp32_ { - using Type = float; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = float4; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = Float8_; -}; -#ifdef ENABLE_BF16 -template<> -struct V_vec_acum_fp32_<__nv_bfloat162> { - using Type = float2; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = Float8_; -}; -#endif // ENABLE_BF16 -#endif -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) -{ -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - using K_vec_acum = typename K_vec_acum_fp32_::Type; -#else - using K_vec_acum = K_vec; -#endif - // Compute the parallel products for Q*K^T (treat vector lanes separately). - K_vec_acum qk_vec = mul(q[0], k[0]); -#pragma unroll - for (int ii = 1; ii < N; ++ii) { - qk_vec = fma(q[ii], k[ii], qk_vec); - } - - // Finalize the reduction across lanes. - float qk = sum(qk_vec); -#pragma unroll - for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) { - qk += __shfl_xor_sync(uint32_t(-1), qk, mask); - } - return qk; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Qk_dot { - template - static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N]) - { - return qk_dot_(q, k); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b) -{ - float4 c; - float zero = 0.f; - asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" - " {%0, %1, %2, %3}, \n" - " {%4, %5}, \n" - " {%6}, \n" - " {%7, %7, %7, %7}; \n" - - : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w) - : "r"(a.x) "r"(a.y), "r"(b), "f"(zero)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N]) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - using K_vec_acum = typename K_vec_acum_fp32_::Type; -#else - using K_vec_acum = uint32_t; -#endif - K_vec_acum qk_vec = mul(q[0], k[0]); -#pragma unroll - for (int ii = 1; ii < N; ++ii) { - qk_vec = fma(q[ii], k[ii], qk_vec); - } -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - uint32_t qk_vec_ = float2_to_half2(qk_vec); - return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x; -#else - return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x; -#endif -#else - return 0.f; -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -struct Qk_dot { - template - static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N]) - { -#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION) - return qk_hmma_dot_(q, k); -#else - return qk_dot_<4>(q, k); -#endif // defined MMHA_USE_HMMA_FOR_REDUCTION - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float block_sum(float* red_smem, float sum) -{ - - // Decompose the thread index into warp / lane. - int warp = threadIdx.x / WARP_SIZE; - int lane = threadIdx.x % WARP_SIZE; - -// Compute the sum per warp. -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); - } - - // Warp leaders store the data to shared memory. - if (lane == 0) { - red_smem[warp] = sum; - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // The warps compute the final sums. - if (lane < WARPS_PER_BLOCK) { - sum = red_smem[lane]; - } - -// Parallel reduction inside the warp. -#pragma unroll - for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); - } - - // Broadcast to other threads. - return __shfl_sync(uint32_t(-1), sum, 0); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(float& dst, float src) -{ - dst = src; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint16_t& dst, float src) -{ - dst = float_to_half(src); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint32_t& dst, float2 src) -{ - dst = float2_to_half2(src); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -#ifdef ENABLE_BF16 -inline __device__ void convert_from_float(__nv_bfloat16& dst, float src) -{ - dst = __float2bfloat16(src); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(__nv_bfloat162& dst, float2 src) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - dst = __float22bfloat162_rn(src); -#else - dst = __floats2bfloat162_rn(src.x, src.y); -#endif -} -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint2& dst, Float4_ src) -{ - dst.x = float2_to_half2(src.x); - dst.y = float2_to_half2(src.y); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint2& dst, float4 src) -{ - convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)}); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint4& dst, Float8_ src) -{ - dst.x = float2_to_half2(src.x); - dst.y = float2_to_half2(src.y); - dst.z = float2_to_half2(src.z); - dst.w = float2_to_half2(src.w); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ void convert_from_float(bf16_4_t& dst, Float4_ src) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - dst.x = __float22bfloat162_rn(src.x); - dst.y = __float22bfloat162_rn(src.y); -#else - dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); - dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(bf16_4_t& dst, float4 src) -{ - convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)}); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(bf16_8_t& dst, Float8_ src) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - dst.x = __float22bfloat162_rn(src.x); - dst.y = __float22bfloat162_rn(src.y); - dst.z = __float22bfloat162_rn(src.z); - dst.w = __float22bfloat162_rn(src.w); -#else - dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); - dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); - dst.z = __floats2bfloat162_rn(src.z.x, src.z.y); - dst.w = __floats2bfloat162_rn(src.w.x, src.w.y); -#endif -} -#endif // ENABLE_BF16 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(float2& dst, float2 src) -{ - dst = src; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(float4& dst, float4 src) -{ - dst = src; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float convert_to_float(float4 u) -{ - return u.x; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float convert_to_float(uint4 u) -{ - float2 tmp = half2_to_float2(u.x); - return tmp.x; -} - -#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float cast_to_float(float u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 cast_to_float(float2 u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 cast_to_float(float4 u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ cast_to_float(Float4_ u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ cast_to_float(Float8_ u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 cast_to_float(uint32_t u) -{ - return half2_to_float2(u); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ cast_to_float(uint2 u) -{ - Float4_ tmp; - tmp.x = half2_to_float2(u.x); - tmp.y = half2_to_float2(u.y); - return tmp; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ cast_to_float(uint4 u) -{ - Float8_ tmp; - tmp.x = half2_to_float2(u.x); - tmp.y = half2_to_float2(u.y); - tmp.z = half2_to_float2(u.z); - tmp.w = half2_to_float2(u.w); - return tmp; -} - -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float float_from_int8(int8_t u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 float_from_int8(int16_t u) -{ - union { - int16_t int16; - int8_t int8[2]; - }; - int16 = u; - return make_float2(int8[0], int8[1]); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 float_from_int8(int32_t u) -{ - union { - int32_t int32; - int8_t int8[4]; - }; - int32 = u; - return make_float4(int8[0], int8[1], int8[2], int8[3]); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// clang-format off -inline __device__ Float8_ float_from_int8(int64_t u) -{ - union { - int64_t int64; - int16_t int16[4]; - }; - int64 = u; - return Float8_ {float_from_int8(int16[0]), - float_from_int8(int16[1]), - float_from_int8(int16[2]), - float_from_int8(int16[3])}; -} -// clang-format on - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ int8_t cast_to_int8(float val) -{ - union { - int8_t int8[2]; - int16_t int16; - }; - asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val)); - return int8[0]; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ int32_t cast_to_int8(float4 val) -{ - union { - int8_t int8[4]; - int32_t int32; - }; - int8[0] = cast_to_int8(val.x); - int8[1] = cast_to_int8(val.y); - int8[2] = cast_to_int8(val.z); - int8[3] = cast_to_int8(val.w); - return int32; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ int64_t cast_to_int8(Float8_ val) -{ - union { - int8_t int8[8]; - int64_t int64; - }; - int8[0] = cast_to_int8(val.x.x); - int8[1] = cast_to_int8(val.x.y); - int8[2] = cast_to_int8(val.y.x); - int8[3] = cast_to_int8(val.y.y); - int8[4] = cast_to_int8(val.z.x); - int8[5] = cast_to_int8(val.z.y); - int8[6] = cast_to_int8(val.w.x); - int8[7] = cast_to_int8(val.w.y); - return int64; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ __host__ T div_up(T m, T n) -{ - return (m + n - 1) / n; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline size_t smem_size_in_bytes(const Multihead_attention_params& params, - int threads_per_value, - int threads_per_block) -{ - // The amount of shared memory needed to store the Q*K^T values in float. - const int max_timesteps = min(params.timestep, params.memory_max_len); - size_t qk_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16; - - // The extra memory needed if we are not using floats for the final logits. - size_t logits_sz = 0; -#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS - if (sizeof(T) != 4) { - // TDOD - logits_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 4 * sizeof(T) : - div_up(max_timesteps + 1, 4) * 4 * sizeof(T); - } -#endif - - // The total size needed during softmax. - size_t softmax_sz = qk_sz + logits_sz; - - // The number of partial rows to reduce in the final reduction. - int rows_per_red = threads_per_block / threads_per_value; - // The amount of storage needed to finalize the outputs. - size_t red_sz = rows_per_red * params.hidden_size_per_head * sizeof(T) / 2; - - size_t transpose_rotary_size = 0; - if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) { - transpose_rotary_size = 2 * params.rotary_embedding_dim * sizeof(T); - } - - // The max. - return max(max(softmax_sz, red_sz), transpose_rotary_size); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ constexpr uint32_t shfl_mask(int threads) -{ - return threads == 32 ? uint32_t(-1) : (1u << threads) - 1u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // The type of the inputs. Supported types: float and half. - typename T, - // The hidden dimension per head. - int Dh, - int Dh_MAX, - // The number of threads per key. - int THREADS_PER_KEY, - // The number of threads per value. - int THREADS_PER_VALUE, - // The number of threads in a threadblock. - int THREADS_PER_BLOCK, - bool DO_CROSS_ATTENTION> -__global__ void masked_multihead_attention_kernel(Multihead_attention_params params) -{ - - // Make sure the hidden dimension per head is a multiple of the number of threads per key. - static_assert(Dh_MAX % THREADS_PER_KEY == 0, ""); - // Make sure the hidden dimension per head is a multiple of the number of threads per value. - static_assert(Dh_MAX % THREADS_PER_VALUE == 0, ""); - - // The size of a warp. - constexpr int WARP_SIZE = 32; - // The number of warps in a threadblock. - constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; - - // Use smem_size_in_bytes (above) to determine the amount of shared memory. - extern __shared__ char smem_[]; - - // The shared memory for the Q*K^T values and partial logits in softmax. - float* qk_smem = reinterpret_cast(smem_); - - // The shared memory for the logits. For FP32, that's the same buffer as qk_smem. - char* logits_smem_ = smem_; -#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS - if (sizeof(T) != 4) { - // TODO - change to tlength - const int max_timesteps = min(params.timestep, params.memory_max_len); - logits_smem_ += - (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16; - } - T* logits_smem = reinterpret_cast(logits_smem_); -#else - float* logits_smem = reinterpret_cast(logits_smem_); -#endif - - // The shared memory to do the final reduction for the output values. Reuse qk_smem. - T* out_smem = reinterpret_cast(smem_); - - // The shared memory buffers for the block-wide reductions. One for max, one for sum. - __shared__ float red_smem[WARPS_PER_BLOCK * 2]; - - // A vector of Q or K elements for the current timestep. - using Qk_vec = typename Qk_vec_::Type; - - // Use alignment for safely casting the shared buffers as Qk_vec. - // Shared memory to store Q inputs. - __shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX]; - - // This is one of the reasons we should have a separate kernel for cross attention - __shared__ __align__(sizeof(Qk_vec)) T bias_smem[DO_CROSS_ATTENTION ? Dh_MAX : 1]; - - // A vector of Q or K elements for the current timestep. - using Qk_vec = typename Qk_vec_::Type; - // The number of elements per vector. - constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T); - // Make sure the hidden size per head is a multiple of the vector size. - static_assert(Dh_MAX % QK_VEC_SIZE == 0, ""); - // We will use block wide reduction if needed - // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, ""); - // The number of vectors per warp. - constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE; - - // The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8 for FP32/FP16. Since each thread - // owns x elements, we have to decompose the linear index into chunks of x values and the posi- - // tion of the thread in that chunk. - - // The number of elements in a chunk of 16B (that's the x in the above formula). - constexpr int QK_ELTS_IN_16B = 16 / sizeof(T); - // The number of K vectors in 16B. - constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec); - - // The batch/beam idx - const int bi = blockIdx.y; - if (params.finished != nullptr && params.finished[bi] == true) { - return; - } - // The beam idx - const int beami = bi % params.beam_width; - // The "beam-aware" batch idx - const int bbi = bi / params.beam_width; - // The head. - // const int hi = blockIdx.x; - const int hi = params.nnz_head_idx == nullptr ? blockIdx.x : params.nnz_head_idx[blockIdx.x]; - const int hi_kv = hi / params.num_heads_q_kv_ratio; - // Combine the batch and the head indices. - const int bhi = bi * params.num_heads + hi; - const int bhi_kv = bi * params.num_heads_kv + hi_kv; - // Combine the "beam-aware" batch idx and the head indices. - const int bbhi = bbi * params.beam_width * params.num_heads_kv + hi_kv; - // The thread in the block. - const int tidx = threadIdx.x; - - const bool handle_kv = !DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0); - - // While doing the product Q*K^T for the different keys we track the max. - float qk_max = -FLT_MAX; - - float qk = 0.0F; - - int q_base_offset = (params.stride_q == 0) ? bhi * Dh : bi * params.stride_q + hi * Dh; - int k_base_offset = (params.stride_k == 0) ? bhi_kv * Dh : bi * params.stride_k + hi_kv * Dh; - int v_base_offset = (params.stride_v == 0) ? bhi_kv * Dh : bi * params.stride_v + hi_kv * Dh; - - const size_t bi_seq_len_offset = bi * params.memory_max_len; - - // int tlength = (DO_CROSS_ATTENTION)? params.memory_length_per_sample[bi] - 1 : params.timestep; - int tlength = (DO_CROSS_ATTENTION) ? params.memory_length_per_sample[bi] - 1 : - (params.length_per_sample == nullptr) ? - params.timestep : - params.length_per_sample[bi] + params.max_prefix_prompt_length; - const int first_step = max(0, tlength + 1 - params.memory_max_len); - const int tlength_circ = tlength % params.memory_max_len; - - // First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep. - const bool is_masked = tidx >= QK_VECS_PER_WARP; - - // The offset in the Q and K buffer also accounts for the batch. - int q_offset = q_base_offset + tidx * QK_VEC_SIZE; - int k_offset = k_base_offset + tidx * QK_VEC_SIZE; - // The offset in the bias buffer. - int q_bias_offset = hi * Dh + tidx * QK_VEC_SIZE; - int k_bias_offset = hi_kv * Dh + tidx * QK_VEC_SIZE; - - const bool do_ia3 = handle_kv && params.ia3_tasks != nullptr; - const int ia3_task_id = do_ia3 ? params.ia3_tasks[bbi] : 0; - - // Trigger the loads from the Q and K buffers. - Qk_vec q; - zero(q); - if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) { - if (params.int8_mode == 2) { - using Packed_Int8_t = typename packed_type::value>::type; - using Packed_Float_t = typename packed_type::value>::type; - const auto q_scaling = params.qkv_scale_out[0]; - const auto q_quant = - *reinterpret_cast(&reinterpret_cast(params.q)[q_offset]); - - convert_from_float(q, mul(q_scaling, float_from_int8(q_quant))); - } - else { - q = *reinterpret_cast(¶ms.q[q_offset]); - } - } - - Qk_vec k; - zero(k); - if (DO_CROSS_ATTENTION) { - // The 16B chunk written by the thread. - int co = tidx / QK_VECS_IN_16B; - // The position of the thread in that 16B chunk. - int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; - - // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. - int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + - // params.timestep*QK_ELTS_IN_16B + - tlength * QK_ELTS_IN_16B + ci; - k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? - *reinterpret_cast(¶ms.k_cache[offset]) : - k; - } - else { - if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) { - if (params.int8_mode == 2) { - using Packed_Int8_t = typename packed_type::value>::type; - using Packed_Float_t = typename packed_type::value>::type; - const auto k_scaling = params.qkv_scale_out[1]; - const auto k_quant = - *reinterpret_cast(&reinterpret_cast(params.k)[k_offset]); - - convert_from_float(k, mul(k_scaling, float_from_int8(k_quant))); - } - else { - k = *reinterpret_cast(¶ms.k[k_offset]); - } - } - } - - // Trigger the loads from the Q and K bias buffers. - Qk_vec q_bias; - zero(q_bias); - q_bias = (!is_masked && Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ? - *reinterpret_cast(¶ms.q_bias[q_bias_offset]) : - q_bias; - - Qk_vec k_bias; - zero(k_bias); - if (handle_kv) { - k_bias = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ? - *reinterpret_cast(¶ms.k_bias[k_bias_offset]) : - k_bias; - } - - // Computes the Q/K values with bias. - q = add(q, q_bias); - if (handle_kv) { - k = add(k, k_bias); - } - if (do_ia3 && !is_masked) { - k = mul( - k, - *reinterpret_cast( - ¶ms.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + tidx * QK_VEC_SIZE])); - } - - // Padded len - const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi]; - if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) { - if (handle_kv) { - if (params.rotary_cos == nullptr) { - apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base); - } else { - apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, - params.rotary_cos + bi * params.rotary_embedding_dim / 2, - params.rotary_sin + bi * params.rotary_embedding_dim / 2); - } - } - else { - if (params.rotary_cos == nullptr) { - apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base); - } else { - apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, - params.rotary_cos + bi * params.rotary_embedding_dim / 2, - params.rotary_sin + bi * params.rotary_embedding_dim / 2); - } - } - } - else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) { - const bool do_rotary = !is_masked && QK_VEC_SIZE * tidx < params.rotary_embedding_dim; - - T* q_smem = reinterpret_cast(smem_); - T* k_smem = q_smem + params.rotary_embedding_dim; - - const int half_rotary_dim = params.rotary_embedding_dim / 2; - const int half_idx = (tidx * QK_VEC_SIZE) / half_rotary_dim; - const int intra_half_idx = (tidx * QK_VEC_SIZE) % half_rotary_dim; - const int smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts - - assert(half_rotary_dim % QK_VEC_SIZE == 0); - - if (do_rotary) { - *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx) = q; - - if (handle_kv) { - *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx) = k; - } - } - - __syncthreads(); - - const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2; - constexpr int tidx_factor = (QK_VEC_SIZE > 1) ? QK_VEC_SIZE / 2 : 1; - if (do_rotary) { - mmha::vec_from_smem_transpose(q, q_smem, transpose_idx, smem_pitch); - - if (handle_kv) { - mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch); - - if (params.rotary_cos == nullptr) { - mmha::apply_rotary_embedding( - q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base); - } else { - mmha::apply_rotary_embedding( - q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, - params.rotary_cos + bi * params.rotary_embedding_dim / 2, - params.rotary_sin + bi * params.rotary_embedding_dim / 2); - } - - mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch); - } - else { - if (params.rotary_cos == nullptr) { - mmha::apply_rotary_embedding( - q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_base); - } else { - mmha::apply_rotary_embedding( - q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, - params.rotary_cos + bi * params.rotary_embedding_dim / 2, - params.rotary_sin + bi * params.rotary_embedding_dim / 2); - } - } - mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch); - } - - __syncthreads(); - - if (do_rotary) { - q = *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx); - if (handle_kv) { - k = *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx); - } - } - - __syncthreads(); - } - - if (!is_masked) { - // Store the Q values to shared memory. - *reinterpret_cast(&q_smem[tidx * QK_VEC_SIZE]) = q; - - // Store Dh values of k_bias into smem, since will need to add later - // if params.timestep == 0 - if (DO_CROSS_ATTENTION && params.timestep == 0) { - *reinterpret_cast(&bias_smem[tidx * QK_VEC_SIZE]) = k_bias; - } - - // Write the K values to the global memory cache. - // - // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory - // system. We designed it this way as it allows much better memory loads (and there are many - // more loads) + the stores are really "write and forget" since we won't need the ack before - // the end of the kernel. There's plenty of time for the transactions to complete. - - // The 16B chunk written by the thread. - int co = tidx / QK_VECS_IN_16B; - // The position of the thread in that 16B chunk. - int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; - - // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. - int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + - // params.timestep*QK_ELTS_IN_16B + - tlength_circ * QK_ELTS_IN_16B + ci; - - if (handle_kv && hi % params.num_heads_q_kv_ratio == 0) { - // Trigger the stores to global memory. - if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { - *reinterpret_cast(¶ms.k_cache[offset]) = k; - } - } - - // Compute \sum_i Q[i] * K^T[i] for the current timestep. -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - using Qk_vec_acum = typename Qk_vec_acum_fp32_::Type; -#else - using Qk_vec_acum = Qk_vec; -#endif - qk = dot(q, k); - if (QK_VECS_PER_WARP <= WARP_SIZE) { -#pragma unroll - for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) { - qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask); - } - } - } - - if (QK_VECS_PER_WARP > WARP_SIZE) { - constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE; - qk = block_sum(&red_smem[WARPS_PER_RED], qk); - } - - // Store that value in shared memory. Keep the Q*K^T value in register for softmax. - if (tidx == 0) { - // Normalize qk. - qk *= params.inv_sqrt_dh; - if (params.relative_attention_bias != nullptr) { - qk = add(qk, - params.relative_attention_bias[hi * params.relative_attention_bias_stride - * params.relative_attention_bias_stride - + (tlength - padd_len) * params.relative_attention_bias_stride - + (tlength - padd_len)]); - } - // We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0. - - qk_max = qk; - qk_smem[tlength - first_step] = qk; - // qk_smem[params.timestep] = qk; - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // The type of queries and keys for the math in the Q*K^T product. - using K_vec = typename K_vec_::Type; - // The number of elements per vector. - constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T); - // Make sure the hidden size per head is a multiple of the vector size. - static_assert(Dh_MAX % K_VEC_SIZE == 0, ""); - // The number of elements per thread. - constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY; - // The number of vectors per thread. - constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; - - // The position the first key loaded by each thread from the cache buffer (for this B * H). - int ko = tidx / THREADS_PER_KEY; - // The position of the thread in the chunk of keys. - int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE; - - static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD); - - // Load the Q values from shared memory. The values are reused during the loop on K. - K_vec q_vec[K_VECS_PER_THREAD]; -#pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - q_vec[ii] = *reinterpret_cast(&q_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]); - } - - K_vec k_bias_vec[DO_CROSS_ATTENTION ? K_VECS_PER_THREAD : 1]; - if (DO_CROSS_ATTENTION && params.timestep == 0) { -#pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - k_bias_vec[ii] = *reinterpret_cast(&bias_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]); - } - } - - // The number of timesteps loaded per iteration. - constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; - // The number of keys per warp. - constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; - - // The base pointer for the key in the cache buffer. - T* k_cache = ¶ms.k_cache[bhi_kv * params.memory_max_len * Dh + ki]; - // Base pointer for the beam's batch, before offsetting with indirection buffer - T* k_cache_batch = ¶ms.k_cache[bbhi * params.memory_max_len * Dh + ki]; - - // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync). - // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP; - int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step; - - // prefix prompt length if has - const int prefix_prompt_length = (params.prefix_prompt_lengths == nullptr) ? 0 : params.prefix_prompt_lengths[bi]; - - // Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values. - const bool has_beams = params.cache_indir != nullptr; - const int* beam_indices = has_beams ? ¶ms.cache_indir[bi_seq_len_offset] : nullptr; - - for (int ti = first_step + ko; ti < ti_end; ti += K_PER_ITER) { - const int ti_circ = ti % params.memory_max_len; - - // The keys loaded from the key cache. - K_vec k[K_VECS_PER_THREAD]; - K_vec k_vec_zero; - zero(k_vec_zero); -#pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - int jj = ii * params.memory_max_len + ti_circ; - // if( ti < params.timestep ) { - const bool within_bounds = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len); - if (ti < tlength) { - if (!within_bounds) { - k[ii] = k_vec_zero; - } - else { - if (has_beams) { - const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh; - k[ii] = *reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]); - } - else { - k[ii] = *reinterpret_cast(&k_cache_batch[jj * QK_ELTS_IN_16B]); - } - } - // add bias and update k_cache - if (DO_CROSS_ATTENTION && params.timestep == 0) { - k[ii] = add(k[ii], k_bias_vec[ii]); - - if (do_ia3) { - k[ii] = mul( - k[ii], - *reinterpret_cast( - ¶ms.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + ki - + ii * THREADS_PER_KEY * K_VEC_SIZE])); - } - - if (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len) { - *reinterpret_cast(&k_cache[jj * QK_ELTS_IN_16B]) = k[ii]; - } - } - } - } - - // Perform the dot product and normalize qk. - // - // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!! - float qk = Qk_dot::dot(q_vec, k) * params.inv_sqrt_dh; - bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti]; - - // Store the product to shared memory. There's one qk value per timestep. Update the max. - // if( ti < params.timestep && tidx % THREADS_PER_KEY == 0 ) { - if (ti < tlength && tidx % THREADS_PER_KEY == 0) { - if (params.relative_attention_bias != nullptr) { - qk = add(qk, - params.relative_attention_bias[hi * params.relative_attention_bias_stride - * params.relative_attention_bias_stride - + tlength * params.relative_attention_bias_stride + ti]); - } - if (params.linear_bias_slopes != nullptr) { - // Apply the linear position bias: (ki - qi) * slope[hi]. - // The padding token locates between the input context and the generated tokens. - // We need to remove the number of padding tokens in the distance computation. - // ti : 0 1 2 3 4 5 6 7 8 9(tlength) - // token: i i i i p p p o o o where i=input, p=pad, o=output. - // e.g. ti = 2, dist = (9 - 3) - 2 = 4. - int max_context_length = params.max_prefix_prompt_length + params.max_input_length; - float dist = (ti < max_context_length ? ti + padd_len : ti) - tlength; - - qk += mul(params.linear_bias_slopes[hi], dist); - } - qk_max = is_mask ? qk_max : fmaxf(qk_max, qk); - qk_smem[ti - first_step] = qk; - } - } - -// Perform the final reduction to compute the max inside each warp. -// -// NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the -// group so it's not needed to run the reduction inside the group (again). -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); - } - - // Decompose the thread index into warp and lane. - const int warp = tidx / WARP_SIZE; - const int lane = tidx % WARP_SIZE; - - // The warp leader writes the max to shared memory. - if (lane == 0) { - red_smem[warp] = qk_max; - } - - // Make sure the products are in shared memory. - __syncthreads(); - - // The warps finalize the reduction. - qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX; -#pragma unroll - for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); - } - - // Broadcast to all the threads in the warp. - qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); - - // Compute the logits and start the sum. - float sum = 0.f; - // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) { - for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { - bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti]; - float logit = is_mask ? 0.f : __expf(qk_smem[ti - first_step] - qk_max); - sum += logit; - qk_smem[ti - first_step] = logit; - } - - // Compute the sum. - sum = block_sum(&red_smem[WARPS_PER_BLOCK], sum); - - // Normalize the logits. - float inv_sum = __fdividef(1.f, sum + 1.e-6f); - // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) { - const size_t cross_attention_out_offset = - params.is_return_cross_attentions ? - bhi * params.max_decoder_seq_len * params.memory_max_len + params.timestep * params.memory_max_len : - 0; - for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { - float logit = qk_smem[ti - first_step] * inv_sum; - if (params.is_return_cross_attentions) { - params.cross_attention_out[cross_attention_out_offset + ti] = logit; - } - convert_from_float(logits_smem[ti - first_step], logit); - } - - // Put Values part below so we leverage __syncthreads - // from the previous step - - // The number of elements per vector. - constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE; - // A vector of V elements for the current timestep. - using V_vec = typename V_vec_::Type; - - // The value computed by this thread. - int vo = tidx / THREADS_PER_VALUE; - // The hidden dimensions computed by this particular thread. - int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; - - // The base pointer for the value in the cache buffer. - T* v_cache = ¶ms.v_cache[bhi_kv * params.memory_max_len * Dh + vi]; - // Base pointer for the beam's batch, before offsetting with indirection buffer - T* v_cache_batch = ¶ms.v_cache[bbhi * params.memory_max_len * Dh + vi]; - - // The number of values processed per iteration of the loop. - constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; - - // One group of threads computes the product(s) for the current timestep. - V_vec v_bias; - zero(v_bias); - // if( vo == params.timestep % V_PER_ITER ) { - if (Dh == Dh_MAX || vi < Dh) { - if (handle_kv) { - if (vo == tlength % V_PER_ITER) { - // Trigger the loads from the V bias buffer. - if (params.v_bias != nullptr) { - v_bias = *reinterpret_cast(¶ms.v_bias[hi_kv * Dh + vi]); - } - if (DO_CROSS_ATTENTION) { - *reinterpret_cast(&bias_smem[vi]) = v_bias; - } - } - } - } - - // From previous, before values, step - // Also make sure the logits are in shared memory. - __syncthreads(); - - // Values continued -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT - using V_vec_acum = typename V_vec_acum_fp32_::Type; -#else - using V_vec_acum = V_vec; -#endif - // The partial outputs computed by each thread. - V_vec_acum out; - zero(out); - - // Loop over the timesteps to compute the partial outputs. - // for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) { - if (Dh == Dh_MAX || vi < Dh) { - for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) { - const int ti_circ = ti % params.memory_max_len; - - // Fetch offset based on cache_indir when beam sampling - const int beam_src = (params.cache_indir != nullptr) ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0; - const int beam_offset = beam_src * params.num_heads * params.memory_max_len * Dh; - // Load the values from the cache. - V_vec v = *reinterpret_cast(&v_cache_batch[beam_offset + ti_circ * Dh]); - if (DO_CROSS_ATTENTION && params.timestep == 0) { - v = add(v, *reinterpret_cast(&bias_smem[vi])); - if (do_ia3) { - v = mul( - v, - *reinterpret_cast( - ¶ms.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi])); - } - *reinterpret_cast(&v_cache[ti * Dh]) = v; - } - // Load the logits from shared memory. -#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) - float logit = logits_smem[ti - first_step]; - out = fma(logit, cast_to_float(v), out); -#else - T logit = logits_smem[ti - first_step]; - - // Update the partial sums. - out = fma(logit, v, out); -#endif - } - } - - // One group of threads computes the product(s) for the current timestep. - // if( vo == params.timestep % V_PER_ITER ) { - if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) { - - V_vec v; - if (DO_CROSS_ATTENTION) { - v = *reinterpret_cast(&v_cache[tlength * Dh]); - } - else { - // Trigger the loads from the V buffer. - const auto v_offset = v_base_offset + vi; - if (params.int8_mode == 2) { - using Packed_Int8_t = typename packed_type::value>::type; - using Packed_Float_t = typename packed_type::value>::type; - const auto v_scaling = params.qkv_scale_out[2]; - const auto v_quant = - *reinterpret_cast(&reinterpret_cast(params.v)[v_offset]); - - convert_from_float(v, mul(v_scaling, float_from_int8(v_quant))); - } - else { - v = *reinterpret_cast(¶ms.v[v_offset]); - } - // Trigger the loads from the V bias buffer. - // V_vec v_bias = *reinterpret_cast(¶ms.v_bias[hi*Dh + vi]); - } - - // Compute the V values with bias. - if (handle_kv) { - v = add(v, v_bias); - - if (do_ia3) { - v = mul( - v, - *reinterpret_cast( - ¶ms.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi])); - } - - // Store the values with bias back to global memory in the cache for V. - if (hi % params.num_heads_q_kv_ratio == 0) { - //*reinterpret_cast(&v_cache[params.timestep*Dh]) = v; - *reinterpret_cast(&v_cache[tlength_circ * Dh]) = v; - } - } - - // Initialize the output value with the current timestep. -#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) - // out = fma(logits_smem[params.timestep], cast_to_float(v), out); - out = fma(logits_smem[tlength - first_step], cast_to_float(v), out); -#else - // out = fma(logits_smem[params.timestep], v, out); - out = fma(logits_smem[tlength - first_step], v, out); -#endif - } - - // Make sure we can start writing to shared memory. - __syncthreads(); - - // Run the final reduction amongst the different groups computing different partial outputs. - if (Dh == Dh_MAX || vi < Dh) { -#pragma unroll - for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) { - - // The midpoint in the number of active groups. - int midpoint = active_groups / 2; - - // The upper part of active threads store to shared memory. - if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) { -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT - convert_from_float(*reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]), out); -#else - *reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]) = out; -#endif - } - __syncthreads(); - - // The bottom warps update their values. - if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) { - out = add(*reinterpret_cast(&out_smem[vo * Dh + vi]), out); - } - __syncthreads(); - } - } - - // Output the final values. - if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT - if (params.int8_mode == 2) { - using Packed_Int8_t = typename packed_type::value>::type; - out = mul(*params.attention_out_scale, out); - *reinterpret_cast(&(reinterpret_cast(params.out)[bhi * Dh + vi])) = - cast_to_int8(out); - } - else { - convert_from_float(*reinterpret_cast(¶ms.out[bhi * Dh + vi]), out); - } -#else - // TODO: support int8_mode? - *reinterpret_cast(¶ms.out[bhi * Dh + vi]) = out; -#endif - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace mmha - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream); diff --git a/csrc/ft_attention/decoder_masked_multihead_attention_utils.h b/csrc/ft_attention/decoder_masked_multihead_attention_utils.h deleted file mode 100644 index 98875aba9b8..00000000000 --- a/csrc/ft_attention/decoder_masked_multihead_attention_utils.h +++ /dev/null @@ -1,2017 +0,0 @@ -// Downloaded from from FasterTransformer v5.2.1 -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h -/* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "cuda_bf16_wrapper.h" -#include "cuda_bf16_fallbacks.cuh" -#include - -using namespace fastertransformer; - -namespace mmha { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Float8_ { - float2 x; - float2 y; - float2 z; - float2 w; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Float4_ { - float2 x; - float2 y; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -struct bf16_4_t { - __nv_bfloat162 x; - __nv_bfloat162 y; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct bf16_8_t { - __nv_bfloat162 x; - __nv_bfloat162 y; - __nv_bfloat162 z; - __nv_bfloat162 w; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct num_elems; -template<> -struct num_elems { - static constexpr int value = 1; -}; -template<> -struct num_elems { - static constexpr int value = 2; -}; -template<> -struct num_elems { - static constexpr int value = 4; -}; -template<> -struct num_elems { - static constexpr int value = 4; -}; -template<> -struct num_elems { - static constexpr int value = 8; -}; - -template<> -struct num_elems { - static constexpr int value = 2; -}; -template<> -struct num_elems { - static constexpr int value = 4; -}; -template<> -struct num_elems { - static constexpr int value = 8; -}; - -#ifdef ENABLE_BF16 -template<> -struct num_elems<__nv_bfloat162> { - static constexpr int value = 2; -}; -template<> -struct num_elems { - static constexpr int value = 4; -}; -template<> -struct num_elems { - static constexpr int value = 8; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct packed_type; -template -struct packed_type { - using type = T; -}; -template<> -struct packed_type { - using type = int16_t; -}; -template<> -struct packed_type { - using type = int32_t; -}; -template<> -struct packed_type { - using type = int64_t; -}; - -template<> -struct packed_type { - using type = float2; -}; -template<> -struct packed_type { - using type = float4; -}; -template<> -struct packed_type { - using type = Float8_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float add(float a, float b) -{ - return a + b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 add(float2 a, float2 b) -{ - float2 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 add(float4 a, float4 b) -{ - float4 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - c.z = add(a.z, b.z); - c.w = add(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) -{ - return a + b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) -{ - return bf16hadd2(a, b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) -{ - bf16_4_t c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) -{ - bf16_8_t c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - c.z = add(a.z, b.z); - c.w = add(a.w, b.w); - return c; -} -#endif // ENABLE_BF16 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint16_t add(uint16_t a, uint16_t b) -{ - uint16_t c; - asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t add(uint32_t a, uint32_t b) -{ - uint32_t c; - asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint2 add(uint2 a, uint2 b) -{ - uint2 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint4 add(uint4 a, uint4 b) -{ - uint4 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - c.z = add(a.z, b.z); - c.w = add(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint16_t float_to_half(float f) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; -#if 0 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 // Is it better? - float zero = 0.f; - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(zero), "f"(f)); -#else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); -#endif - return tmp.u16[0]; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t float2_to_half2(float2 f) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); -#else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); -#endif - return tmp.u32; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float half_to_float(uint16_t h) -{ - float f; - asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); - return f; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 half2_to_float2(uint32_t v) -{ - uint16_t lo, hi; - asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); - return make_float2(half_to_float(lo), half_to_float(hi)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float add(float a, uint16_t b) -{ - return a + half_to_float(b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ float add(float a, __nv_bfloat16 b) -{ - return a + __bfloat162float(b); -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 add(uint32_t a, float2 fb) -{ - float2 fa = half2_to_float2(a); - return add(fa, fb); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ add(uint2 a, Float4_ fb) -{ - Float4_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ add(uint4 a, Float8_ fb) -{ - Float8_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - fc.z = add(a.z, fb.z); - fc.w = add(a.w, fb.w); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t h0_h0(uint16_t a) -{ - uint32_t b; - asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); - return b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float fma(float a, float b, float c) -{ - return a * b + c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(float2 a, float2 b, float2 c) -{ - float2 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(float a, float2 b, float2 c) -{ - float2 d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 fma(float4 a, float4 b, float4 c) -{ - float4 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - d.z = fma(a.z, b.z, c.z); - d.w = fma(a.w, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 fma(float a, float4 b, float4 c) -{ - float4 d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - d.z = fma(a, b.z, c.z); - d.w = fma(a, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) -{ - Float4_ d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) -{ - Float8_ d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - d.z = fma(a, b.z, c.z); - d.w = fma(a, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ float2 add(__nv_bfloat162 a, float2 fb) -{ - float2 fa = bf1622float2(a); - return add(fa, fb); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) -{ - Float4_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) -{ - Float8_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - fc.z = add(a.z, fb.z); - fc.w = add(a.w, fb.w); - return fc; -} -#endif // ENABLE_BF16 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) -{ - uint32_t d; - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) -{ - return fma(h0_h0(a), b, c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) -{ - uint2 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) -{ - uint32_t s = h0_h0(a); - uint2 d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) -{ - uint4 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - d.z = fma(a.z, b.z, c.z); - d.w = fma(a.w, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) -{ - uint32_t s = h0_h0(a); - uint4 d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - d.z = fma(s, b.z, c.z); - d.w = fma(s, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float fma(uint16_t a, uint16_t b, float fc) -{ - float fa = half_to_float(a); - float fb = half_to_float(b); - return fa * fb + fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) -{ - float2 fa = half2_to_float2(a); - float2 fb = half2_to_float2(b); - return fma(fa, fb, fc); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) -{ - return fma(h0_h0(a), b, fc); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) -{ - Float4_ fd; - fd.x = fma(a.x, b.x, fc.x); - fd.y = fma(a.y, b.y, fc.y); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) -{ - uint32_t s = h0_h0(a); - Float4_ fd; - fd.x = fma(s, b.x, fc.x); - fd.y = fma(s, b.y, fc.y); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) -{ - Float8_ fd; - fd.x = fma(a.x, b.x, fc.x); - fd.y = fma(a.y, b.y, fc.y); - fd.z = fma(a.z, b.z, fc.z); - fd.w = fma(a.w, b.w, fc.w); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) -{ - uint32_t s = h0_h0(a); - Float8_ fd; - fd.x = fma(s, b.x, fc.x); - fd.y = fma(s, b.y, fc.y); - fd.z = fma(s, b.z, fc.z); - fd.w = fma(s, b.w, fc.w); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -#ifdef ENABLE_BF16 -inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) -{ - return bf16hfma2(a, b, c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) -{ - return bf16hfma2(bf162bf162(a), b, c); -} -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) -{ - bf16_4_t d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) -{ - __nv_bfloat162 s = bf162bf162(a); - bf16_4_t d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) -{ - bf16_8_t d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - d.z = fma(a.z, b.z, c.z); - d.w = fma(a.w, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) -{ - __nv_bfloat162 s = bf162bf162(a); - bf16_8_t d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - d.z = fma(s, b.z, c.z); - d.w = fma(s, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) -{ - return __bfloat162float(a) * __bfloat162float(b) + fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) -{ - float2 fa = bf1622float2(a); - float2 fb = bf1622float2(b); - return fma(fa, fb, fc); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) -{ - return fma(bf162bf162(a), b, fc); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) -{ - Float4_ fd; - fd.x = fma(a.x, b.x, fc.x); - fd.y = fma(a.y, b.y, fc.y); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) -{ - __nv_bfloat162 s = bf162bf162(a); - Float4_ fd; - fd.x = fma(s, b.x, fc.x); - fd.y = fma(s, b.y, fc.y); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) -{ - Float8_ fd; - fd.x = fma(a.x, b.x, fc.x); - fd.y = fma(a.y, b.y, fc.y); - fd.z = fma(a.z, b.z, fc.z); - fd.w = fma(a.w, b.w, fc.w); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) -{ - __nv_bfloat162 s = bf162bf162(a); - Float8_ fd; - fd.x = fma(s, b.x, fc.x); - fd.y = fma(s, b.y, fc.y); - fd.z = fma(s, b.z, fc.z); - fd.w = fma(s, b.w, fc.w); - return fd; -} -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ Acc mul(A a, B b) -{ - return a * b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(float a, float b) -{ - return a * b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(float2 a, float2 b) -{ - float2 c; - c.x = a.x * b.x; - c.y = a.y * b.y; - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(float a, float2 b) -{ - float2 c; - c.x = a * b.x; - c.y = a * b.y; - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float4 mul(float4 a, float4 b) -{ - float4 c; - c.x = a.x * b.x; - c.y = a.y * b.y; - c.z = a.z * b.z; - c.w = a.w * b.w; - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float4 mul(float a, float4 b) -{ - float4 c; - c.x = a * b.x; - c.y = a * b.y; - c.z = a * b.z; - c.w = a * b.w; - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(float a, Float8_ b) -{ - Float8_ c; - c.x = make_float2(a * b.x.x, a * b.x.y); - c.y = make_float2(a * b.y.x, a * b.y.y); - c.z = make_float2(a * b.z.x, a * b.z.y); - c.w = make_float2(a * b.w.x, a * b.w.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint16_t mul(uint16_t a, uint16_t b) -{ - uint16_t c; - asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint32_t mul(uint32_t a, uint32_t b) -{ - uint32_t c; - asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint32_t mul(uint16_t a, uint32_t b) -{ - return mul(h0_h0(a), b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint2 mul(uint2 a, uint2 b) -{ - uint2 c; - c.x = mul(a.x, b.x); - c.y = mul(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint2 mul(uint16_t a, uint2 b) -{ - uint32_t s = h0_h0(a); - uint2 c; - c.x = mul(s, b.x); - c.y = mul(s, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint4 mul(uint4 a, uint4 b) -{ - uint4 c; - c.x = mul(a.x, b.x); - c.y = mul(a.y, b.y); - c.z = mul(a.z, b.z); - c.w = mul(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint4 mul(uint16_t a, uint4 b) -{ - uint32_t s = h0_h0(a); - uint4 c; - c.x = mul(s, b.x); - c.y = mul(s, b.y); - c.z = mul(s, b.z); - c.w = mul(s, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(uint16_t a, uint16_t b) -{ - float fa = half_to_float(a); - float fb = half_to_float(b); - return fa * fb; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(uint16_t a, float b) -{ - return half_to_float(a) * b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(uint32_t a, uint32_t b) -{ - float2 fa = half2_to_float2(a); - float2 fb = half2_to_float2(b); - return mul(fa, fb); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(uint16_t a, uint32_t b) -{ - return mul(h0_h0(a), b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float4_ mul(uint2 a, uint2 b) -{ - Float4_ fc; - fc.x = mul(a.x, b.x); - fc.y = mul(a.y, b.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float4_ mul(uint16_t a, uint2 b) -{ - uint32_t s = h0_h0(a); - Float4_ fc; - fc.x = mul(s, b.x); - fc.y = mul(s, b.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(uint4 a, uint4 b) -{ - Float8_ fc; - fc.x = mul(a.x, b.x); - fc.y = mul(a.y, b.y); - fc.z = mul(a.z, b.z); - fc.w = mul(a.w, b.w); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(uint16_t a, uint4 b) -{ - uint32_t s = h0_h0(a); - Float8_ fc; - fc.x = mul(s, b.x); - fc.y = mul(s, b.y); - fc.z = mul(s, b.z); - fc.w = mul(s, b.w); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -template<> -inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - return __hmul(a, b); -#else - return bf16hmul(a, b); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) -{ - return bf16hmul2(a, b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) -{ - return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) -{ - bf16_4_t c; - c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); - c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) -{ - __nv_bfloat162 s = bf162bf162(a); - bf16_4_t c; - c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); - c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) -{ - bf16_8_t c; - c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); - c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); - c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z); - c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) -{ - __nv_bfloat162 s = bf162bf162(a); - bf16_8_t c; - c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); - c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); - c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z); - c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) -{ - float fa = (float)a; - float fb = (float)b; - return fa * fb; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(__nv_bfloat16 a, float b) -{ - return __bfloat162float(a) * b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) -{ - float2 fa = bf1622float2(a); - float2 fb = bf1622float2(b); - return mul(fa, fb); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) -{ - return mul(bf162bf162(a), b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) -{ - Float4_ fc; - fc.x = mul(a.x, b.x); - fc.y = mul(a.y, b.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) -{ - __nv_bfloat162 s = bf162bf162(a); - Float4_ fc; - fc.x = mul(s, b.x); - fc.y = mul(s, b.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) -{ - Float8_ fc; - fc.x = mul(a.x, b.x); - fc.y = mul(a.y, b.y); - fc.z = mul(a.z, b.z); - fc.w = mul(a.w, b.w); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) -{ - __nv_bfloat162 s = bf162bf162(a); - Float8_ fc; - fc.x = mul(s, b.x); - fc.y = mul(s, b.y); - fc.z = mul(s, b.z); - fc.w = mul(s, b.w); - return fc; -} -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(float v) -{ - return v; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(float2 v) -{ - return v.x + v.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(float4 v) -{ - return v.x + v.y + v.z + v.w; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ float sum(__nv_bfloat162 v) -{ - float2 vf = bf1622float2(v); - return vf.x + vf.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(bf16_4_t v) -{ - return sum(v.x) + sum(v.y); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(bf16_8_t v) -{ - return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); -} -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(uint16_t v) -{ - return half_to_float(v); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(uint32_t v) -{ - float2 tmp = half2_to_float2(v); - return tmp.x + tmp.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(uint2 v) -{ - uint32_t c = add(v.x, v.y); - return sum(c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(uint4 v) -{ -#if 1 - uint32_t c = add(v.x, v.y); - c = add(c, v.z); - c = add(c, v.w); -#else - uint32_t c = add(v.x, v.y); - uint32_t d = add(v.z, v.w); - c = add(c, d); -#endif - return sum(c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(Float4_ v) -{ - return v.x.x + v.x.y + v.y.x + v.y.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(Float8_ v) -{ - return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float dot(T a, T b) -{ - return sum(mul(a, b)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float dot(T a, T b) -{ - return sum(mul(a, b)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void zero(uint16_t& dst) -{ - dst = uint16_t(0); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void zero(T& dst) -{ - constexpr int WORDS = sizeof(T) / 4; - union { - T raw; - uint32_t words[WORDS]; - } tmp; -#pragma unroll - for (int ii = 0; ii < WORDS; ++ii) { - tmp.words[ii] = 0u; - } - dst = tmp.raw; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const int t_step, const float base) -{ - const float pos_idx_inv_freq = t_step / pow(base, zid / (float)rot_embed_dim); - return {cos(pos_idx_inv_freq), sin(pos_idx_inv_freq)}; -} - -inline __device__ float2 rotary_embedding_transform(const float2 v, const float2 coef) -{ - float2 rot_v; - rot_v.x = coef.x * v.x - coef.y * v.y; - rot_v.y = coef.x * v.y + coef.y * v.x; - return rot_v; -} - -inline __device__ uint32_t rotary_embedding_transform(const uint32_t v, const float2 coef) -{ - float2 fv = half2_to_float2(v); - float2 rot_fv = rotary_embedding_transform(fv, coef); - return float2_to_half2(rot_fv); -} - -#ifdef ENABLE_BF16 -inline __device__ __nv_bfloat162 rotary_embedding_transform(const __nv_bfloat162 v, const float2 coef) -{ - float2 fv = bf1622float2(v); - float2 rot_fv = rotary_embedding_transform(fv, coef); - return __floats2bfloat162_rn(rot_fv.x, rot_fv.y); -} -#endif - -inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - return; -} - -inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - return; -} - -inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); - q = rotary_embedding_transform(q, coef); -} - -inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); - q = rotary_embedding_transform(q, coef); - k = rotary_embedding_transform(k, coef); -} - -inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - - Float4_& q_ = *reinterpret_cast(&q); - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); - q_.x = rotary_embedding_transform(q_.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); - q_.y = rotary_embedding_transform(q_.y, coef1); -} - -inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - - Float4_& q_ = *reinterpret_cast(&q); - Float4_& k_ = *reinterpret_cast(&k); - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); - q_.x = rotary_embedding_transform(q_.x, coef0); - k_.x = rotary_embedding_transform(k_.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); - q_.y = rotary_embedding_transform(q_.y, coef1); - k_.y = rotary_embedding_transform(k_.y, coef1); -} - -inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); - q = rotary_embedding_transform(q, coef); -} - -inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); - q = rotary_embedding_transform(q, coef); - k = rotary_embedding_transform(k, coef); -} - -inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); -} - -inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); -} - -inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base); - q.z = rotary_embedding_transform(q.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base); - q.w = rotary_embedding_transform(q.w, coef3); -} - -inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base); - q.z = rotary_embedding_transform(q.z, coef2); - k.z = rotary_embedding_transform(k.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base); - q.w = rotary_embedding_transform(q.w, coef3); - k.w = rotary_embedding_transform(k.w, coef3); -} - -#ifdef ENABLE_BF16 -inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); - q = rotary_embedding_transform(q, coef); -} - -inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); - q = rotary_embedding_transform(q, coef); - k = rotary_embedding_transform(k, coef); -} - -inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); -} - -inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); -} - -inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base); - q.z = rotary_embedding_transform(q.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base); - q.w = rotary_embedding_transform(q.w, coef3); -} - -inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base); - q.z = rotary_embedding_transform(q.z, coef2); - k.z = rotary_embedding_transform(k.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base); - q.w = rotary_embedding_transform(q.w, coef3); - k.w = rotary_embedding_transform(k.w, coef3); -} -#endif // ENABLE_BF16 - -template -inline __device__ float2 rotary_embedding_coefficient(const int zid, const int t_step, const T* rotary_cos, const T* rotary_sin) -{ - // zid is the index of the dimension (0, 2, 4, ..., rotary_dim). - // rotary_cos/sin stores those at index 0, 1, 2, ..., rotary_dim / 2. - return {float(rotary_cos[zid / 2]), float(rotary_sin[zid / 2])}; -} - -// fp16 is special because we use uint16_t for reading the data, for backward compatibility. -template <> -inline __device__ float2 rotary_embedding_coefficient(const int zid, const int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin) -{ - // zid is the index of the dimension (0, 2, 4, ..., rotary_dim). - // rotary_cos/sin stores those at index 0, 1, 2, ..., rotary_dim / 2. - return {float(reinterpret_cast(rotary_cos)[zid / 2]), - float(reinterpret_cast(rotary_sin)[zid / 2])}; -} - -inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin) -{ - return; -} - -inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin) -{ - return; -} - -inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin); - q = rotary_embedding_transform(q, coef); -} - -inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin); - q = rotary_embedding_transform(q, coef); - k = rotary_embedding_transform(k, coef); -} - -inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - - Float4_& q_ = *reinterpret_cast(&q); - const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin); - q_.x = rotary_embedding_transform(q_.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin); - q_.y = rotary_embedding_transform(q_.y, coef1); -} - -inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - - Float4_& q_ = *reinterpret_cast(&q); - Float4_& k_ = *reinterpret_cast(&k); - const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin); - q_.x = rotary_embedding_transform(q_.x, coef0); - k_.x = rotary_embedding_transform(k_.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin); - q_.y = rotary_embedding_transform(q_.y, coef1); - k_.y = rotary_embedding_transform(k_.y, coef1); -} - -inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin); - q = rotary_embedding_transform(q, coef); -} - -inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin); - q = rotary_embedding_transform(q, coef); - k = rotary_embedding_transform(k, coef); -} - -inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); -} - -inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); -} - -inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin); - q.z = rotary_embedding_transform(q.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin); - q.w = rotary_embedding_transform(q.w, coef3); -} - -inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin); - q.z = rotary_embedding_transform(q.z, coef2); - k.z = rotary_embedding_transform(k.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin); - q.w = rotary_embedding_transform(q.w, coef3); - k.w = rotary_embedding_transform(k.w, coef3); -} - -#ifdef ENABLE_BF16 -inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin); - q = rotary_embedding_transform(q, coef); -} - -inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin); - q = rotary_embedding_transform(q, coef); - k = rotary_embedding_transform(k, coef); -} - -inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); -} - -inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); -} - -inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin); - q.z = rotary_embedding_transform(q.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin); - q.w = rotary_embedding_transform(q.w, coef3); -} - -inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin); - q.z = rotary_embedding_transform(q.z, coef2); - k.z = rotary_embedding_transform(k.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin); - q.w = rotary_embedding_transform(q.w, coef3); - k.w = rotary_embedding_transform(k.w, coef3); -} -#endif // ENABLE_BF16 - -template -__device__ __inline__ void vec_from_smem_transpose(Vec_T& vec, T* smem, int transpose_idx, int smem_pitch); - -template<> -__device__ __inline__ void vec_from_smem_transpose(float& vec, float* smem, int transpose_idx, int smem_pitch) -{ - return; -} - -template<> -__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; - tmp.u16[0] = smem[transpose_idx]; - tmp.u16[1] = smem[smem_pitch + transpose_idx]; - - vec = tmp.u32; -} - -template<> -__device__ __inline__ void vec_from_smem_transpose(uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp_1, tmp_2; - tmp_1.u32 = *reinterpret_cast(&smem[transpose_idx]); - tmp_2.u32 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); - - union { - uint2 u32x2; - uint16_t u16[4]; - } tmp_3; - tmp_3.u16[0] = tmp_1.u16[0]; - tmp_3.u16[1] = tmp_2.u16[0]; - tmp_3.u16[2] = tmp_1.u16[1]; - tmp_3.u16[3] = tmp_2.u16[1]; - - vec = tmp_3.u32x2; -} - -template<> -__device__ __inline__ void vec_from_smem_transpose(uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch) -{ - union { - uint64_t u64; - uint16_t u16[4]; - } tmp_1, tmp_2; - tmp_1.u64 = *reinterpret_cast(&smem[transpose_idx]); - tmp_2.u64 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); - - union { - uint4 u32x4; - uint16_t u16[8]; - } tmp_3; - tmp_3.u16[0] = tmp_1.u16[0]; - tmp_3.u16[1] = tmp_2.u16[0]; - tmp_3.u16[2] = tmp_1.u16[1]; - tmp_3.u16[3] = tmp_2.u16[1]; - tmp_3.u16[4] = tmp_1.u16[2]; - tmp_3.u16[5] = tmp_2.u16[2]; - tmp_3.u16[6] = tmp_1.u16[3]; - tmp_3.u16[7] = tmp_2.u16[3]; - - vec = tmp_3.u32x4; -} - -#ifdef ENABLE_BF16 -template<> -__device__ __inline__ void -vec_from_smem_transpose(bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - __nv_bfloat16 bf16[2]; - } tmp_1, tmp_2; - tmp_1.u32 = *reinterpret_cast(&smem[transpose_idx]); - tmp_2.u32 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); - - vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]}; - vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]}; -} - -template<> -__device__ __inline__ void -vec_from_smem_transpose(bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) -{ - union { - uint64_t u64; - __nv_bfloat16 bf16[4]; - } tmp_1, tmp_2; - tmp_1.u64 = *reinterpret_cast(&smem[transpose_idx]); - tmp_2.u64 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); - - vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]}; - vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]}; - vec.z = __nv_bfloat162{tmp_1.bf16[2], tmp_2.bf16[2]}; - vec.w = __nv_bfloat162{tmp_1.bf16[3], tmp_2.bf16[3]}; -} -#endif // ENABLE_BF16 - -template<> -__device__ __inline__ void vec_from_smem_transpose(float4& vec, float* smem, int transpose_idx, int smem_pitch) -{ - vec.x = smem[transpose_idx]; - vec.z = smem[transpose_idx + 1]; - vec.y = smem[smem_pitch + transpose_idx]; - vec.w = smem[smem_pitch + transpose_idx + 1]; -} - -template<> -__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, half* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - half u16[2]; - } tmp; - tmp.u16[0] = smem[transpose_idx]; - tmp.u16[1] = smem[smem_pitch + transpose_idx]; - - vec = tmp.u32; -} - -#ifdef ENABLE_BF16 -template<> -__device__ __inline__ void -vec_from_smem_transpose(__nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) -{ - vec.x = smem[transpose_idx]; - vec.y = smem[smem_pitch + transpose_idx]; -} -#endif - -template<> -__device__ __inline__ void vec_from_smem_transpose(float2& vec, float* smem, int transpose_idx, int smem_pitch) -{ - vec.x = smem[transpose_idx]; - vec.y = smem[smem_pitch + transpose_idx]; -} - -template -__device__ __inline__ void write_smem_transpose(const Vec_T& vec, T* smem, int transpose_idx, int smem_pitch); - -template<> -__device__ __inline__ void write_smem_transpose(const float& vec, float* smem, int transpose_idx, int smem_pitch) -{ - return; -} - -template<> -__device__ __inline__ void write_smem_transpose(const uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch) -{ - union { - uint64_t u64; - uint16_t u16[4]; - } tmp_1, tmp_2; - - union { - uint4 u32x4; - uint16_t u16[8]; - } tmp_3; - tmp_3.u32x4 = vec; - tmp_1.u16[0] = tmp_3.u16[0]; - tmp_2.u16[0] = tmp_3.u16[1]; - tmp_1.u16[1] = tmp_3.u16[2]; - tmp_2.u16[1] = tmp_3.u16[3]; - tmp_1.u16[2] = tmp_3.u16[4]; - tmp_2.u16[2] = tmp_3.u16[5]; - tmp_1.u16[3] = tmp_3.u16[6]; - tmp_2.u16[3] = tmp_3.u16[7]; - - *reinterpret_cast(&smem[transpose_idx]) = tmp_1.u64; - *reinterpret_cast(&smem[smem_pitch + transpose_idx]) = tmp_2.u64; -} - -template<> -__device__ __inline__ void write_smem_transpose(const uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp_1, tmp_2; - - union { - uint2 u32x2; - uint16_t u16[4]; - } tmp_3; - tmp_3.u32x2 = vec; - tmp_1.u16[0] = tmp_3.u16[0]; - tmp_2.u16[0] = tmp_3.u16[1]; - tmp_1.u16[1] = tmp_3.u16[2]; - tmp_2.u16[1] = tmp_3.u16[3]; - - *reinterpret_cast(&smem[transpose_idx]) = tmp_1.u32; - *reinterpret_cast(&smem[smem_pitch + transpose_idx]) = tmp_2.u32; -} - -template<> -__device__ __inline__ void write_smem_transpose(const uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; - tmp.u32 = vec; - - smem[transpose_idx] = tmp.u16[0]; - smem[smem_pitch + transpose_idx] = tmp.u16[1]; -} - -template<> -__device__ __inline__ void write_smem_transpose(const float4& vec, float* smem, int transpose_idx, int smem_pitch) -{ - smem[transpose_idx] = vec.x; - smem[transpose_idx + 1] = vec.z; - smem[smem_pitch + transpose_idx] = vec.y; - smem[smem_pitch + transpose_idx + 1] = vec.w; -} - -template<> -__device__ __inline__ void write_smem_transpose(const uint32_t& vec, half* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - half u16[2]; - } tmp; - - tmp.u32 = vec; - smem[transpose_idx] = tmp.u16[0]; - smem[smem_pitch + transpose_idx] = tmp.u16[1]; -} - -#ifdef ENABLE_BF16 -template<> -__device__ __inline__ void -write_smem_transpose(const __nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) -{ - smem[transpose_idx] = vec.x; - smem[smem_pitch + transpose_idx] = vec.y; -} - -template<> -__device__ __inline__ void -write_smem_transpose(const bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) -{ - write_smem_transpose(reinterpret_cast(vec), reinterpret_cast(smem), transpose_idx, smem_pitch); -} - -template<> -__device__ __inline__ void -write_smem_transpose(const bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) -{ - write_smem_transpose(reinterpret_cast(vec), reinterpret_cast(smem), transpose_idx, smem_pitch); -} -#endif - -template<> -__device__ __inline__ void write_smem_transpose(const float2& vec, float* smem, int transpose_idx, int smem_pitch) -{ - smem[transpose_idx] = vec.x; - smem[smem_pitch + transpose_idx] = vec.y; -} - -} // namespace mmha diff --git a/csrc/ft_attention/ft_attention.cpp b/csrc/ft_attention/ft_attention.cpp deleted file mode 100644 index 886da9729ba..00000000000 --- a/csrc/ft_attention/ft_attention.cpp +++ /dev/null @@ -1,231 +0,0 @@ -#include -#include "ATen/cuda/CUDAContext.h" -#include - - -#include "decoder_masked_multihead_attention.h" - -#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA") -#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") - -#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, NAME, ...) \ - if (TYPE == at::ScalarType::Half) { \ - using scalar_t = at::Half; \ - __VA_ARGS__(); \ - } else if (TYPE == at::ScalarType::BFloat16) { \ - using scalar_t = at::BFloat16; \ - __VA_ARGS__(); \ - } else if (TYPE == at::ScalarType::Float) { \ - using scalar_t = float; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for type '", toString(TYPE), "'"); \ - } - -template -void masked_multihead_attention(const Masked_multihead_attention_params& params, - const cudaStream_t& stream); - -template -void cross_multihead_attention(const Masked_multihead_attention_params& params, - const cudaStream_t& stream); - -template -struct SATypeConverter { - using Type = T; -}; - -template<> -struct SATypeConverter { - using Type = uint16_t; -}; - -template<> -struct SATypeConverter { - using Type = __nv_bfloat16; -}; - -template -void set_params(Masked_multihead_attention_params ¶ms, - const size_t batch_size, - const size_t nheads, - const size_t nheads_kv, - const size_t memory_max_seqlen, - const size_t headdim, - const int timestep, - const int rotary_embedding_dim, - const float rotary_base, - const bool neox_rotary_style, - const int q_batch_stride, - const int k_batch_stride, - const int v_batch_stride, - const int nnz_heads, - T *q_ptr, - T *k_ptr, - T *v_ptr, - T *k_cache_ptr, - T *v_cache_ptr, - int *length_per_sample, - T *rotary_cos, - T *rotary_sin, - T *out_ptr, - int *nnz_head_idx) { - // Reset the parameters - memset(¶ms, 0, sizeof(params)); - params.q = q_ptr; - params.k = k_ptr; - params.v = v_ptr; - params.q_bias = nullptr; - params.k_bias = nullptr; - params.v_bias = nullptr; - params.k_cache = k_cache_ptr; - params.v_cache = v_cache_ptr; - params.out = out_ptr; - params.cache_indir = nullptr; - params.stride_q = q_batch_stride; - params.stride_k = k_batch_stride; - params.stride_v = v_batch_stride; - params.batch_size = batch_size; - params.beam_width = 1; - params.memory_max_len = memory_max_seqlen; - params.num_heads = nheads; - params.num_heads_kv = nheads_kv; - params.num_heads_q_kv_ratio = nheads / nheads_kv; - params.nnz_heads = nnz_heads; - params.hidden_size_per_head = headdim; - params.rotary_embedding_dim = rotary_embedding_dim; - params.rotary_base = rotary_base; - params.neox_rotary_style = neox_rotary_style; - params.timestep = timestep; - params.inv_sqrt_dh = 1.f / sqrt(float(headdim)); - params.total_padding_tokens = nullptr; - params.masked_tokens = nullptr; - params.prefix_prompt_lengths = nullptr; - params.max_prefix_prompt_length = 0; - params.relative_attention_bias = nullptr; - params.relative_attention_bias_stride = 0; - params.cross_attention_out = nullptr; - params.max_decoder_seq_len = 0; - params.is_return_cross_attentions = false; - params.finished = nullptr; - params.memory_length_per_sample = nullptr; - params.length_per_sample = length_per_sample; - params.rotary_cos = rotary_cos; - params.rotary_sin = rotary_sin; - params.nnz_head_idx = nnz_head_idx; -} - -torch::Tensor single_query_attention(const torch::Tensor q, - const torch::Tensor k, - const torch::Tensor v, - torch::Tensor k_cache, - torch::Tensor v_cache, - std::optional length_per_sample_, - std::optional rotary_cos_, - std::optional rotary_sin_, - std::optional nnz_head_idx_, - const int timestep, - int rotary_embedding_dim = 0, - const float rotary_base = 10000.0f, - const bool neox_rotary_style=true) { - CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache); - int batch_size = v_cache.size(0); - int nheads = q.size(1); - int nheads_kv = v_cache.size(1); - int memory_max_seqlen = v_cache.size(2); - int headdim = v_cache.size(3); - auto input_type = q.scalar_type(); - TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); - - CHECK_SHAPE(q, batch_size, nheads, headdim); - CHECK_SHAPE(k, batch_size, nheads_kv, headdim); - CHECK_SHAPE(v, batch_size, nheads_kv, headdim); - CHECK_SHAPE(v_cache, batch_size, nheads_kv, memory_max_seqlen, headdim); - // k_cache shape: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32 - int packsize = k_cache.dtype() == torch::kFloat32 ? 4 : 8; - CHECK_SHAPE(k_cache, batch_size, nheads_kv, headdim / packsize, memory_max_seqlen, packsize); - TORCH_CHECK(q.stride(2) == 1 && q.stride(1) == headdim); - TORCH_CHECK(k.stride(2) == 1 && k.stride(1) == headdim); - TORCH_CHECK(v.stride(2) == 1 && v.stride(1) == headdim); - CHECK_CONTIGUOUS(v_cache); CHECK_CONTIGUOUS(k_cache); - - TORCH_CHECK(q.scalar_type() == input_type); - TORCH_CHECK(k.scalar_type() == input_type); - TORCH_CHECK(v.scalar_type() == input_type); - TORCH_CHECK(k_cache.scalar_type() == input_type); - TORCH_CHECK(v_cache.scalar_type() == input_type); - - if (length_per_sample_.has_value()) { - auto length_per_sample = length_per_sample_.value(); - CHECK_DEVICE(length_per_sample); - CHECK_SHAPE(length_per_sample, batch_size); - CHECK_CONTIGUOUS(length_per_sample); - TORCH_CHECK(length_per_sample.dtype() == torch::kInt32); - } - - if (rotary_cos_.has_value()) { - auto rotary_cos = rotary_cos_.value(); - CHECK_DEVICE(rotary_cos); - rotary_embedding_dim = rotary_cos.size(-1) * 2; - CHECK_SHAPE(rotary_cos, batch_size, rotary_embedding_dim / 2); - CHECK_CONTIGUOUS(rotary_cos); - TORCH_CHECK(rotary_cos.scalar_type() == input_type); - - TORCH_CHECK(rotary_sin_.has_value()); - auto rotary_sin = rotary_sin_.value(); - CHECK_DEVICE(rotary_sin); - CHECK_SHAPE(rotary_sin, batch_size, rotary_embedding_dim / 2); - CHECK_CONTIGUOUS(rotary_sin); - TORCH_CHECK(rotary_sin.scalar_type() == input_type); - } - - if (nnz_head_idx_.has_value()) { - auto nnz_head_idx = nnz_head_idx_.value(); - CHECK_DEVICE(nnz_head_idx); - int nnz_heads = nnz_head_idx.size(0); - CHECK_SHAPE(nnz_head_idx, nnz_heads); - CHECK_CONTIGUOUS(nnz_head_idx); - TORCH_CHECK(nnz_head_idx.dtype() == torch::kInt32); - } - - // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{q.device()}; - - torch::Tensor out = torch::empty_like(q); - - DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), "single_query_attention", [&] { - using DataType = typename SATypeConverter::Type; - Masked_multihead_attention_params params; - set_params(params, batch_size, nheads, nheads_kv, memory_max_seqlen, headdim, timestep, - rotary_embedding_dim, rotary_base, neox_rotary_style, - q.stride(0), k.stride(0), v.stride(0), - nnz_head_idx_.has_value() ? nnz_head_idx_.value().size(0) : 0, - reinterpret_cast(q.data_ptr()), - reinterpret_cast(k.data_ptr()), - reinterpret_cast(v.data_ptr()), - reinterpret_cast(k_cache.data_ptr()), - reinterpret_cast(v_cache.data_ptr()), - length_per_sample_.has_value() - ? length_per_sample_.value().data_ptr() : nullptr, - rotary_cos_.has_value() - ? reinterpret_cast(rotary_cos_.value().data_ptr()) : nullptr, - rotary_sin_.has_value() - ? reinterpret_cast(rotary_sin_.value().data_ptr()) : nullptr, - reinterpret_cast(out.data_ptr()), - nnz_head_idx_.has_value() ? nnz_head_idx_.value().data_ptr() : nullptr - ); - auto stream = at::cuda::getCurrentCUDAStream(); - masked_multihead_attention(params, stream); - }); - return out; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("single_query_attention", &single_query_attention, "Attention with a single query", - py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"), - py::arg("length_per_sample_"), py::arg("rotary_cos_"), - py::arg("rotary_sin_"), py::arg("nnz_head_idx_"), - py::arg("timestep"), py::arg("rotary_embedding_dim")=0, - py::arg("rotary_base")=10000.0f, py::arg("neox_rotary_style")=true); -} diff --git a/csrc/ft_attention/setup.py b/csrc/ft_attention/setup.py deleted file mode 100644 index fa385ad768c..00000000000 --- a/csrc/ft_attention/setup.py +++ /dev/null @@ -1,153 +0,0 @@ -# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py -import sys -import warnings -import os -from packaging.version import parse, Version - -from setuptools import setup, find_packages -import subprocess - -import torch -from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME - - -# ninja build does not work unless include_dirs are abs path -this_dir = os.path.dirname(os.path.abspath(__file__)) - - -def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - bare_metal_version = parse(output[release_idx].split(",")[0]) - - return raw_output, bare_metal_version - - -def check_cuda_torch_binary_vs_bare_metal(cuda_dir): - raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) - torch_binary_version = parse(torch.version.cuda) - - print("\nCompiling cuda extensions with") - print(raw_output + "from " + cuda_dir + "/bin\n") - - if (bare_metal_version != torch_binary_version): - raise RuntimeError( - "Cuda extensions are being compiled with a version of Cuda that does " - "not match the version used to compile Pytorch binaries. " - "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) - + "In some cases, a minor-version mismatch will not cause later errors: " - "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " - "You can try commenting out this check (at your own risk)." - ) - - -def raise_if_cuda_home_none(global_option: str) -> None: - if CUDA_HOME is not None: - return - raise RuntimeError( - f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " - "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " - "only images whose names contain 'devel' will provide nvcc." - ) - - -def append_nvcc_threads(nvcc_extra_args): - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.2"): - nvcc_threads = os.getenv("NVCC_THREADS") or "4" - return nvcc_extra_args + ["--threads", nvcc_threads] - return nvcc_extra_args - - -if not torch.cuda.is_available(): - # https://github.com/NVIDIA/apex/issues/486 - # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), - # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). - print( - "\nWarning: Torch did not find available GPUs on this system.\n", - "If your intention is to cross-compile, this is not an error.\n" - "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n" - "Volta (compute capability 7.0), Turing (compute capability 7.5),\n" - "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n" - "If you wish to cross-compile for a single specific architecture,\n" - 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', - ) - if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.8"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0" - elif bare_metal_version >= Version("11.1"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6" - elif bare_metal_version == Version("11.0"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" - else: - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" - - -print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) -TORCH_MAJOR = int(torch.__version__.split(".")[0]) -TORCH_MINOR = int(torch.__version__.split(".")[1]) - -cmdclass = {} -ext_modules = [] - -# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h -# See https://github.com/pytorch/pytorch/pull/70650 -generator_flag = [] -torch_dir = torch.__path__[0] -if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): - generator_flag = ["-DOLD_GENERATOR_PATH"] - -raise_if_cuda_home_none("--ft_attention") -# Check, if CUDA11 is installed for compute capability 8.0 -cc_flag = [] -_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) -if bare_metal_version < Version("11.0"): - raise RuntimeError("ft_attention is only supported on CUDA 11 and above") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_70,code=sm_70") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_80,code=sm_80") -if bare_metal_version >= Version("11.8"): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_90,code=sm_90") - -ext_modules.append( - CUDAExtension( - name="ft_attention", - sources=[ - "ft_attention.cpp", - "decoder_masked_multihead_attention.cu", - ], - extra_compile_args={ - "cxx": ["-O3", "-DENABLE_BF16"] + generator_flag, - "nvcc": append_nvcc_threads( - [ - "-DENABLE_BF16", # TODO - "-O3", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT16_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT162_OPERATORS__", - "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - ] - + generator_flag - + cc_flag - ), - }, - include_dirs=[this_dir], - ) -) - -setup( - name="ft_attention", - version="0.1", - description="Attention for single query from FasterTransformer", - ext_modules=ext_modules, - cmdclass={"build_ext": BuildExtension} if ext_modules else {}, -) From f28841db5043c6a329869b6c3e4e3f5f5ebdc1a0 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 12 Aug 2025 11:33:51 -0400 Subject: [PATCH 070/258] Remove old rotary kernel --- csrc/rotary/rotary.cpp | 40 ------------ csrc/rotary/rotary_cuda.cu | 45 ------------- csrc/rotary/setup.py | 126 ------------------------------------- 3 files changed, 211 deletions(-) delete mode 100644 csrc/rotary/rotary.cpp delete mode 100644 csrc/rotary/rotary_cuda.cu delete mode 100644 csrc/rotary/setup.py diff --git a/csrc/rotary/rotary.cpp b/csrc/rotary/rotary.cpp deleted file mode 100644 index 640eea423ac..00000000000 --- a/csrc/rotary/rotary.cpp +++ /dev/null @@ -1,40 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#include -#include - -#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA") -#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") - -void apply_rotary_cuda(const torch::Tensor x1, const torch::Tensor x2, - const torch::Tensor cos, const torch::Tensor sin, - torch::Tensor out1, torch::Tensor out2, - const bool conj); - -void apply_rotary(const torch::Tensor x1, const torch::Tensor x2, - const torch::Tensor cos, const torch::Tensor sin, - torch::Tensor out1, torch::Tensor out2, - const bool conj) { - CHECK_DEVICE(x1); CHECK_DEVICE(x2); - CHECK_DEVICE(cos); CHECK_DEVICE(sin); - CHECK_DEVICE(out1); CHECK_DEVICE(out1); - TORCH_CHECK(x1.dtype() == x2.dtype()); - TORCH_CHECK(cos.dtype() == sin.dtype()); - TORCH_CHECK(out1.dtype() == out2.dtype()); - TORCH_CHECK(x1.dtype() == cos.dtype()); - TORCH_CHECK(x1.dtype() == out1.dtype()); - TORCH_CHECK(x1.sizes() == x2.sizes()); - TORCH_CHECK(cos.sizes() == sin.sizes()); - TORCH_CHECK(out1.sizes() == out2.sizes()); - - // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{x1.device()}; - - apply_rotary_cuda(x1, x2, cos, sin, out1, out2, conj); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("apply_rotary", &apply_rotary, "Apply rotary embedding"); -} diff --git a/csrc/rotary/rotary_cuda.cu b/csrc/rotary/rotary_cuda.cu deleted file mode 100644 index 2dd0ff3f6e2..00000000000 --- a/csrc/rotary/rotary_cuda.cu +++ /dev/null @@ -1,45 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#include -#include -#include - -void apply_rotary_cuda(const torch::Tensor x1, const torch::Tensor x2, - const torch::Tensor cos, const torch::Tensor sin, - torch::Tensor out1, torch::Tensor out2, - const bool conj) { - auto iter = at::TensorIteratorConfig() - .add_output(out1) - .add_output(out2) - .add_input(x1) - .add_input(x2) - .add_input(cos) - .add_input(sin) - .check_all_same_dtype(false) - .promote_inputs_to_common_dtype(false) - .build(); - - if (!conj) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] { - at::native::gpu_kernel_multiple_outputs( - iter, [] GPU_LAMBDA (scalar_t x1, scalar_t x2, scalar_t cos, - scalar_t sin) -> thrust::tuple { - scalar_t out1 = float(x1) * float(cos) - float(x2) * float(sin); - scalar_t out2 = float(x1) * float(sin) + float(x2) * float(cos); - return {out1, out2}; - }); - }); - } else { - AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] { - at::native::gpu_kernel_multiple_outputs( - iter, [] GPU_LAMBDA (scalar_t x1, scalar_t x2, scalar_t cos, - scalar_t sin) -> thrust::tuple { - scalar_t out1 = float(x1) * float(cos) + float(x2) * float(sin); - scalar_t out2 = -float(x1) * float(sin) + float(x2) * float(cos); - return {out1, out2}; - }); - }); - } -} \ No newline at end of file diff --git a/csrc/rotary/setup.py b/csrc/rotary/setup.py deleted file mode 100644 index 24d328d9c6a..00000000000 --- a/csrc/rotary/setup.py +++ /dev/null @@ -1,126 +0,0 @@ -# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py -import sys -import warnings -import os -from packaging.version import parse, Version - -import torch -from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME -from setuptools import setup, find_packages -import subprocess - -# ninja build does not work unless include_dirs are abs path -this_dir = os.path.dirname(os.path.abspath(__file__)) - - -def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - bare_metal_version = parse(output[release_idx].split(",")[0]) - - return raw_output, bare_metal_version - - -def check_cuda_torch_binary_vs_bare_metal(cuda_dir): - raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) - torch_binary_version = parse(torch.version.cuda) - - print("\nCompiling cuda extensions with") - print(raw_output + "from " + cuda_dir + "/bin\n") - - if (bare_metal_version != torch_binary_version): - raise RuntimeError( - "Cuda extensions are being compiled with a version of Cuda that does " - "not match the version used to compile Pytorch binaries. " - "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) - + "In some cases, a minor-version mismatch will not cause later errors: " - "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " - "You can try commenting out this check (at your own risk)." - ) - - -def raise_if_cuda_home_none(global_option: str) -> None: - if CUDA_HOME is not None: - return - raise RuntimeError( - f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " - "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " - "only images whose names contain 'devel' will provide nvcc." - ) - - -def append_nvcc_threads(nvcc_extra_args): - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.2"): - nvcc_threads = os.getenv("NVCC_THREADS") or "4" - return nvcc_extra_args + ["--threads", nvcc_threads] - return nvcc_extra_args - - -if not torch.cuda.is_available(): - # https://github.com/NVIDIA/apex/issues/486 - # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), - # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). - print( - "\nWarning: Torch did not find available GPUs on this system.\n", - "If your intention is to cross-compile, this is not an error.\n" - "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n" - "Volta (compute capability 7.0), Turing (compute capability 7.5),\n" - "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n" - "If you wish to cross-compile for a single specific architecture,\n" - 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', - ) - if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.8"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0" - elif bare_metal_version >= Version("11.1"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6" - elif bare_metal_version == Version("11.0"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" - else: - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" - - -print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) -TORCH_MAJOR = int(torch.__version__.split(".")[0]) -TORCH_MINOR = int(torch.__version__.split(".")[1]) - -cmdclass = {} -ext_modules = [] - -raise_if_cuda_home_none("rotary_emb") -# Check, if CUDA11 is installed for compute capability 8.0 -cc_flag = [] -_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) -if bare_metal_version < Version("11.0"): - raise RuntimeError("rotary_emb is only supported on CUDA 11 and above") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_70,code=sm_70") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_80,code=sm_80") -if bare_metal_version >= Version("11.8"): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_90,code=sm_90") - -ext_modules.append( - CUDAExtension( - 'rotary_emb', [ - 'rotary.cpp', - 'rotary_cuda.cu', - ], - extra_compile_args={'cxx': ['-g', '-march=native', '-funroll-loops'], - 'nvcc': append_nvcc_threads([ - '-O3', '--use_fast_math', '--expt-extended-lambda' - ] + cc_flag) - } - ) -) - -setup( - name="rotary_emb", - version="0.1", - ext_modules=ext_modules, - cmdclass={"build_ext": BuildExtension} if ext_modules else {}, -) From a1c2e22817960fd68933d46747db39d930ac2c8f Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 12 Aug 2025 14:51:16 -0400 Subject: [PATCH 071/258] [Cute] Implement page table with TMA for fwd_sm100 --- flash_attn/cute/flash_fwd.py | 11 +- flash_attn/cute/flash_fwd_sm100.py | 67 +++-- flash_attn/cute/interface.py | 38 ++- flash_attn/cute/tile_scheduler.py | 29 +- tests/cute/test_flash_attn.py | 430 ++++++++++++++++++++++++++++- 5 files changed, 525 insertions(+), 50 deletions(-) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 61333ca7357..c71a049c752 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -1058,10 +1058,10 @@ class SharedStorageSharedQV: @cute.jit def __call__( self, - mQ: cute.Tensor, - mK: cute.Tensor, - mV: cute.Tensor, - mO: cute.Tensor, + mQ: cute.Tensor, # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + mK: cute.Tensor, # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table + mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table + mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q mLSE: Optional[cute.Tensor], softmax_scale: cutlass.Float32, stream: cuda.CUstream, @@ -1069,6 +1069,7 @@ def __call__( mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, + mPageTable: Optional[cute.Tensor] = None, # (b_k, max_num_pages_per_seq) softcap: cutlass.Float32 | float | None = None, window_size_left: cutlass.Int32 | int | None = None, window_size_right: cutlass.Int32 | int | None = None, @@ -1169,7 +1170,7 @@ def __call__( mQ.shape[1], mV.shape[1], total_q=cute.size(mQ.shape[0]) if const_expr(mCuSeqlensQ is not None) else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), - block_size=self.m_block_size, + tile_shape_mn=(self.m_block_size, self.n_block_size), mCuSeqlensQ=mCuSeqlensQ, mSeqUsedQ=mSeqUsedQ, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index d630668aa8d..0a0dae7eb12 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -179,10 +179,10 @@ def _setup_attributes(self): @cute.jit def __call__( self, - mQ: cute.Tensor, - mK: cute.Tensor, - mV: cute.Tensor, - mO: cute.Tensor, + mQ: cute.Tensor, # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + mK: cute.Tensor, # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table + mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table + mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q mLSE: Optional[cute.Tensor], softmax_scale: Float32, stream: cuda.CUstream, @@ -190,6 +190,7 @@ def __call__( mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, + mPageTable: Optional[cute.Tensor] = None, # (b_k, max_num_pages_per_seq) softcap: Float32 | float | None = None, window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, @@ -222,6 +223,7 @@ def __call__( cute.make_tensor(t.iterator, cute.select(t.layout, mode=QO_layout_transpose)) for t in (mQ, mO) ] + # (s_k, d, h_k, b_k) or (total_k, d, h_k) if there's cu_seqlens_k or (page_size, d, h_k, num_pages) if there's page_table KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1] mK, mV = [ cute.make_tensor(t.iterator, cute.select(t.layout, mode=KV_layout_transpose)) @@ -384,11 +386,11 @@ def __call__( cute.ceil_div(cute.size(mQ.shape[0]), self.cta_tiler[0]), cute.size(mQ.shape[2]), cute.size(mQ.shape[3]) if const_expr(mCuSeqlensQ is None) else cute.size(mCuSeqlensQ.shape[0] - 1), - cute.size(mK.shape[0]), + cute.size(mK.shape[0]) if const_expr(mPageTable is None) else mK.shape[0] * mPageTable.shape[1], mQ.shape[1], mV.shape[0], # Note that this is different from Sm90 since we transpose mV in Sm100 total_q=cute.size(mQ.shape[0]) if const_expr(mCuSeqlensQ is not None) else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), - block_size=self.cta_tiler[0], + tile_shape_mn=self.cta_tiler[:2], mCuSeqlensQ=mCuSeqlensQ, mSeqUsedQ=mSeqUsedQ, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, @@ -470,6 +472,7 @@ class SharedStorage: mCuSeqlensK, mSeqUsedQ, mSeqUsedK, + mPageTable, tma_atom_Q, tma_atom_K, tma_atom_V, @@ -501,15 +504,16 @@ class SharedStorage: @cute.kernel def kernel( self, - mQ: cute.Tensor, - mK: cute.Tensor, - mV: cute.Tensor, + mQ: cute.Tensor, # (s_q, d, h, b) or (total_q, d, h) if there is cu_seqlens_q + mK: cute.Tensor, # (s_k, d, h_k, b_k) or (total_k, d, h_k) if there is cu_seqlens_k or (page_size, d, h_k, num_pages) if there is page_table + mV: cute.Tensor, # (d, s_k, h_k, b_k) or (d, total_k, h_k) if there is cu_seqlens_k or (d, page_size, h_k, num_pages) if there is page_table mO: cute.Tensor, mLSE: Optional[cute.Tensor], mCuSeqlensQ: Optional[cute.Tensor], mCuSeqlensK: Optional[cute.Tensor], mSeqUsedQ: Optional[cute.Tensor], mSeqUsedK: Optional[cute.Tensor], + mPageTable: Optional[cute.Tensor], tma_atom_Q: cute.CopyAtom, tma_atom_K: cute.CopyAtom, tma_atom_V: cute.CopyAtom, @@ -651,8 +655,9 @@ def kernel( qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) SeqlenInfoCls = partial( - SeqlenInfo, seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1], - seqlen_k_static=mK.shape[0], + SeqlenInfo, + seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1], + seqlen_k_static=mK.shape[0] if const_expr(mPageTable is None) else mK.shape[0] * mPageTable.shape[1], mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK, ) @@ -684,6 +689,7 @@ def kernel( sQ, sK, sV, + mPageTable, tma_atom_Q, tma_atom_K, tma_atom_V, @@ -819,6 +825,7 @@ def load( sQ: cute.Tensor, sK: cute.Tensor, sV: cute.Tensor, + mPageTable: Optional[cute.Tensor], tma_atom_Q: cute.CopyAtom, tma_atom_K: cute.CopyAtom, tma_atom_V: cute.CopyAtom, @@ -841,18 +848,24 @@ def load( else: offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) mQ_cur = cute.domain_offset((offset, 0), mQ[None, None, head_idx]) + gQ = cute.local_tile(mQ_cur, cute.select(self.mma_tiler_qk, mode=[0, 2]), (None, 0)) + head_idx_kv = head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx - if const_expr(not seqlen.has_cu_seqlens_k): - mK_cur, mV_cur = [t[None, None, head_idx_kv, batch_idx] for t in (mK, mV)] + if const_expr(mPageTable is None): + if const_expr(not seqlen.has_cu_seqlens_k): + mK_cur, mV_cur = [t[None, None, head_idx_kv, batch_idx] for t in (mK, mV)] + else: + mK_cur = cute.domain_offset((seqlen.offset_k, 0), mK[None, None, head_idx_kv]) + mV_cur = cute.domain_offset((0, seqlen.offset_k), mV[None, None, head_idx_kv]) + gK = cute.local_tile(mK_cur, cute.select(self.mma_tiler_qk, mode=[1, 2]), (None, 0)) + gV = cute.local_tile(mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None)) else: - mK_cur = cute.domain_offset((seqlen.offset_k, 0), mK[None, None, head_idx_kv]) - mV_cur = cute.domain_offset((0, seqlen.offset_k), mV[None, None, head_idx_kv]) - - gQ = cute.local_tile(mQ_cur, cute.select(self.mma_tiler_qk, mode=[0, 2]), (None, 0)) + # Need to keep batch coord None since we'll index into it with page idx + mK_cur, mV_cur = [t[None, None, head_idx_kv, None] for t in (mK, mV)] + gK = cute.local_tile(mK_cur, cute.select(self.mma_tiler_qk, mode=[1, 2]), (None, 0, None)) + gV = cute.local_tile(mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None, None)) tSgQ = thr_mma_qk.partition_A(gQ) - gK = cute.local_tile(mK_cur, cute.select(self.mma_tiler_qk, mode=[1, 2]), (None, 0)) tSgK = thr_mma_qk.partition_B(gK) - gV = cute.local_tile(mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None)) tOgV = thr_mma_pv.partition_B(gV) tQsQ, tQgQ = cpasync.tma_partition( tma_atom_Q, @@ -896,18 +909,21 @@ def load( n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) load_Q(block=self.q_stage * m_block + 0, stage=0) # Q0 - load_K(block=n_block_max - 1, producer_state=kv_producer_state) # K0 + page_idx = mPageTable[batch_idx, n_block_max - 1] if const_expr(mPageTable is not None) else None + load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # K0 kv_producer_state.advance() if const_expr(self.q_stage == 2): load_Q(block=self.q_stage * m_block + 1, stage=1) # Q1 q_producer_phase ^= 1 - load_V(block=n_block_max - 1, producer_state=kv_producer_state) # V0 + load_V(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # V0 kv_producer_state.advance() for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): n_block = n_block_max - 2 - i - load_K(block=n_block, producer_state=kv_producer_state) # Ki + page_idx = mPageTable[batch_idx, n_block] if const_expr(mPageTable is not None) else None + # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("n_block = {}, page_idx = {}", n_block, page_idx) + load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Ki kv_producer_state.advance() - load_V(block=n_block, producer_state=kv_producer_state) # Vi + load_V(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Vi kv_producer_state.advance() tile_scheduler.prefetch_next_work() tile_scheduler.advance_to_next_work() @@ -1792,6 +1808,7 @@ def load_KV( block: Int32, producer_state: cutlass.pipeline.PipelineState, K_or_V: str, + page_idx: Optional[Int32] = None, ): assert K_or_V in ("K", "V") tma_copy_bytes = self.tma_copy_k_bytes if const_expr(K_or_V == "K") else self.tma_copy_v_bytes @@ -1808,7 +1825,9 @@ def load_KV( if const_expr(self.uneven_kv_smem): # Since this is the producer_state, the phase starts at 1, so we have to invert it tXsX_cur = self.offset_kv_smem(tXsX_cur, stage, phase ^ 1) - cute.copy(tma_atom, tXgX[None, block], tXsX_cur, tma_bar_ptr=mbar_full_ptr + stage) + # Currently we assume that page_size == n_block_size so we index into tXgX with block = 0 + tXgX_cur = tXgX[None, block] if const_expr(page_idx is None) else tXgX[None, 0, page_idx] + cute.copy(tma_atom, tXgX_cur, tXsX_cur, tma_bar_ptr=mbar_full_ptr + stage) @cute.jit def offset_kv_smem(self, sX: cute.Tensor, stage: Int32, phase: Int32): diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 3e154ace813..4a7b903a175 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -57,6 +57,7 @@ def _flash_attn_fwd( cu_seqlens_k: Optional[torch.Tensor] = None, seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, + page_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, softcap: Optional[float] = None, @@ -80,11 +81,26 @@ def _flash_attn_fwd( batch_size = cu_seqlens_q.shape[0] - 1 seqlen_q = None total_q = q.shape[0] - seqlen_k, num_head_kv, _ = k.shape[-3:] + if page_table is not None: + assert cu_seqlens_k is None, "page_table is not supported with cu_seqlens_k" + assert page_table.dtype == torch.int32, "page_table must be int32" + assert page_table.stride(-1) == 1, "page_table must be contiguous in the last dimension" + max_num_pages_per_seq = page_table.shape[1] + assert page_table.shape == (batch_size, max_num_pages_per_seq) + num_pages, page_size = k.shape[:2] + seqlen_k = num_pages * page_size + else: + num_pages, page_size = None, None + seqlen_k = k.shape[-3] + num_head_kv = k.shape[-2] head_dim_v = v.shape[-1] if cu_seqlens_k is None: - assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim) - assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v) + if page_table is None: + assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim) + assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v) + else: + assert k.shape == (num_pages, page_size, num_head_kv, head_dim) + assert v.shape == (num_pages, page_size, num_head_kv, head_dim_v) else: assert k.shape == (seqlen_k, num_head_kv, head_dim) assert v.shape == (seqlen_k, num_head_kv, head_dim_v) @@ -102,7 +118,7 @@ def _flash_attn_fwd( if learnable_sink is not None: assert learnable_sink.shape == (num_head,) assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16" - assert all(t is None or t.is_cuda for t in (q, k, v, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink)), "inputs must be on CUDA device" + assert all(t is None or t.is_cuda for t in (q, k, v, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, page_table, learnable_sink)), "inputs must be on CUDA device" assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" assert head_dim <= 256, "head_dim must be less than or equal to 256" alignment = 16 // q.element_size() @@ -132,6 +148,7 @@ def _flash_attn_fwd( from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) if t is not None else None for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink) ] + page_table_tensor = from_dlpack(page_table.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=1) if page_table is not None else None if causal: window_size_right = 0 local = window_size_left is not None or window_size_right is not None @@ -151,6 +168,7 @@ def _flash_attn_fwd( compile_key = ( dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, softcap is not None, lse is None, cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, seqused_k is None, + page_table is not None, window_size_left is not None, window_size_right is not None, learnable_sink is not None, m_block_size, n_block_size, num_threads, @@ -158,6 +176,7 @@ def _flash_attn_fwd( ) if compile_key not in _flash_attn_fwd.compile_cache: if compute_capability == 9: + assert page_table is None, "paged KV not supported on SM 9.0" assert learnable_sink is None, "Sm90 doesn't support additive sink" # fa_fwd = FlashAttentionForwardSm80( fa_fwd = FlashAttentionForwardSm90( @@ -176,6 +195,7 @@ def _flash_attn_fwd( Q_in_regs=False, ) elif compute_capability == 10: + assert page_size in [None, 128], "Only page_size=128 is supported for paged KV on SM 10.0" fa_fwd = FlashAttentionForwardSm100( head_dim, head_dim_v, @@ -190,11 +210,13 @@ def _flash_attn_fwd( _flash_attn_fwd.compile_cache[compile_key] = cute.compile( fa_fwd, q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream, cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, + page_table_tensor, softcap, window_size_left, window_size_right, additive_sink_tensor, ) _flash_attn_fwd.compile_cache[compile_key]( q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream, cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, + page_table_tensor, softcap, window_size_left, window_size_right, additive_sink_tensor, ) return out, lse @@ -446,8 +468,9 @@ def forward( v: torch.Tensor, cu_seqlens_q: Optional[torch.Tensor], cu_seqlens_k: Optional[torch.Tensor], - seqused_q: Optional[torch.Tensor], - seqused_k: Optional[torch.Tensor], + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + page_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, window_size: Tuple[Optional[int], Optional[int]] = (None, None), @@ -462,6 +485,7 @@ def forward( cu_seqlens_k, seqused_q, seqused_k, + page_table=page_table, softmax_scale=softmax_scale, causal=causal, window_size_left=window_size[0], @@ -514,6 +538,7 @@ def flash_attn_varlen_func( cu_seqlens_k: Optional[torch.Tensor] = None, seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, + page_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, window_size: Tuple[Optional[int], Optional[int]] = (None, None), @@ -528,6 +553,7 @@ def flash_attn_varlen_func( cu_seqlens_k, seqused_q, seqused_k, + page_table, softmax_scale, causal, window_size, diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index c7fad36b22a..58e9d776df2 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -44,7 +44,7 @@ class TileSchedulerArguments(ParamsBase): headdim: Int32 headdim_v: Int32 total_q: Int32 - block_size: cutlass.Constexpr[int] + tile_shape_mn: cutlass.Constexpr[Tuple[int, int]] mCuSeqlensQ: Optional[cute.Tensor] = None mSeqUsedQ: Optional[cute.Tensor] = None qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 @@ -235,7 +235,7 @@ class Params(ParamsBase): def create( args: TileSchedulerArguments, *, loc=None, ip=None ) -> "SingleTileLPTScheduler.Params": - # cute.printf(args.num_block, args.num_head, args.num_batch, args.seqlen_k, args.headdim, args.headdim_v, args.total_q, args.block_size, args.qhead_per_kvhead_packgqa, args.element_size) + # cute.printf(args.num_block, args.num_head, args.num_batch, args.seqlen_k, args.headdim, args.headdim_v, args.total_q, args.tile_shape_mn, args.qhead_per_kvhead_packgqa, args.element_size) size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size size_one_head = size_one_kv_head size_l2 = 50 * 1024 * 1024 # 40 MB for K & V @@ -393,7 +393,7 @@ class Params(ParamsBase): num_batch: Int32 total_q: Int32 max_kvblock_in_l2: Int32 - block_size: cutlass.Constexpr[int] + tile_shape_mn: cutlass.Constexpr[Tuple[int, int]] mCuSeqlensQ: Optional[cute.Tensor] = None mSeqUsedQ: Optional[cute.Tensor] = None qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 @@ -405,13 +405,13 @@ def create( args: TileSchedulerArguments, *, loc=None, ip=None ) -> "SingleTileVarlenScheduler.Params": size_l2 = 50 * 1024 * 1024 # 50 MB for K & V - max_kvblock_in_l2 = size_l2 // ((args.headdim + args.headdim_v) * args.element_size * args.block_size) + max_kvblock_in_l2 = size_l2 // ((args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1]) return SingleTileVarlenScheduler.Params( num_head=args.num_head, num_batch=args.num_batch, total_q=args.total_q, max_kvblock_in_l2=max_kvblock_in_l2, - block_size=args.block_size, + tile_shape_mn=args.tile_shape_mn, mCuSeqlensQ=args.mCuSeqlensQ, mSeqUsedQ=args.mSeqUsedQ, qhead_per_kvhead_packgqa=args.qhead_per_kvhead_packgqa, @@ -426,7 +426,7 @@ def __init__( tile_idx: Int32, mCuSeqlensQ: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, - block_size: cutlass.Constexpr[int] = 128, + tile_shape_mn: cutlass.Constexpr[[int, int]] = (128, 128), qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1, lpt: cutlass.Constexpr[bool] = False, *, @@ -441,7 +441,7 @@ def __init__( assert self.mCuSeqlensQ is not None or self.mSeqUsedQ is not None, ( "At least one of mCuSeqlensQ or mSeqUsedQ must be provided" ) - self.block_size = block_size + self.tile_shape_mn = tile_shape_mn self.qhead_per_kvhead_packgqa = qhead_per_kvhead_packgqa self.lpt = lpt self._tile_idx = tile_idx @@ -463,7 +463,7 @@ def create(params: Params, *, loc=None, ip=None) -> "SingleTileVarlenScheduler": tile_idx, mCuSeqlensQ=params.mCuSeqlensQ, mSeqUsedQ=params.mSeqUsedQ, - block_size=params.block_size, + tile_shape_mn=params.tile_shape_mn, qhead_per_kvhead_packgqa=params.qhead_per_kvhead_packgqa, lpt=params.lpt, loc=loc, @@ -479,8 +479,8 @@ def get_grid_shape( ip=None, ) -> Tuple[Int32, Int32, Int32]: total_blocks_max = ( - params.total_q + params.num_batch * (params.block_size - 1) - ) // params.block_size + params.total_q + params.num_batch * (params.tile_shape_mn[0] - 1) + ) // params.tile_shape_mn[0] return (total_blocks_max * params.num_head, Int32(1), Int32(1)) @cute.jit @@ -500,7 +500,7 @@ def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32: if cutlass.const_expr(self.qhead_per_kvhead_packgqa > 1): seqlen *= self.qhead_per_kvhead_packgqa return ( - cute.ceil_div(seqlen, self.block_size) + cute.ceil_div(seqlen, self.tile_shape_mn[0]) if batch_idx < self.num_batch and lane < cute.arch.WARP_SIZE - 1 else Int32(0) ) @@ -555,9 +555,10 @@ def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: # the seqlen can vary per batch. # TODO: is there any case where num_m_blocks is 0? # TODO: by right we should read the seqlen_kv but we're assuming seqlen_q == seqlen_k here - # nheads_in_l2 = min(max(self.max_kvblock_in_l2 // num_m_blocks, 1), self.num_head) + num_n_blocks = num_m_blocks * self.tile_shape_mn[0] // self.tile_shape_mn[1] + # nheads_in_l2 = min(max(self.max_kvblock_in_l2 // num_n_blocks, 1), self.num_head) # Seems faster to have this be a power of 2 - nheads_in_l2 = 16 if num_m_blocks * 16 <= self.max_kvblock_in_l2 else (8 if num_m_blocks * 8 <= self.max_kvblock_in_l2 else (4 if num_m_blocks * 4 <= self.max_kvblock_in_l2 else (2 if num_m_blocks * 2 <= self.max_kvblock_in_l2 else 1))) + nheads_in_l2 = 16 if num_n_blocks * 16 <= self.max_kvblock_in_l2 else (8 if num_n_blocks * 8 <= self.max_kvblock_in_l2 else (4 if num_n_blocks * 4 <= self.max_kvblock_in_l2 else (2 if num_n_blocks * 2 <= self.max_kvblock_in_l2 else 1))) nheads_in_l2 = min(nheads_in_l2, self.num_head) mh_in_l2 = nheads_in_l2 * num_m_blocks section_idx = mh_block // mh_in_l2 @@ -619,7 +620,7 @@ def __new_from_mlir_values__(self, values): values = values[n_items:] return SingleTileVarlenScheduler( *(tuple(obj_list)), - block_size=self.block_size, + tile_shape_mn=self.tile_shape_mn, qhead_per_kvhead_packgqa=self.qhead_per_kvhead_packgqa, lpt=self.lpt, loc=self._loc, diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 61da6991c79..eaf351f3977 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -12,7 +12,7 @@ except ImportError: apply_rotary_emb = None -# from padding import pad_input, unpad_input +from flash_attn.bert_padding import pad_input, unpad_input from flash_attn.utils.testing import attention_ref, generate_qkv, generate_random_padding_mask from flash_attn.cute.interface import flash_attn_func, flash_attn_varlen_func @@ -549,3 +549,431 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +@pytest.mark.parametrize("has_learnable_sink", [False, True]) +# @pytest.mark.parametrize("has_learnable_sink", [False]) +# @pytest.mark.parametrize("new_kv", [False, True]) +@pytest.mark.parametrize("new_kv", [False]) +@pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("local", [False]) +# @pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("causal", [True]) +# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) +@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [False]) +# @pytest.mark.parametrize("has_rotary_seqlens", [False, True]) +@pytest.mark.parametrize("has_rotary_seqlens", [False]) +# @pytest.mark.parametrize("rotary_interleaved", [False, True]) +@pytest.mark.parametrize("rotary_interleaved", [True]) +# @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) +@pytest.mark.parametrize("rotary_fraction", [0.0]) +# @pytest.mark.parametrize("page_size", [None] + ([1, 4, 128])) +@pytest.mark.parametrize("page_size", [None, 128]) +# @pytest.mark.parametrize("page_size", [128]) +# @pytest.mark.parametrize("has_leftpad", [False, True]) +@pytest.mark.parametrize("has_leftpad", [False]) +# @pytest.mark.parametrize("has_batch_idx", [False, True]) +@pytest.mark.parametrize("has_batch_idx", [False]) +# @pytest.mark.parametrize("varlen_q", [False, True]) +@pytest.mark.parametrize("varlen_q", [False]) +# @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("d", [64]) +# @pytest.mark.parametrize("d", [192]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 128), + (1, 339), + (3, 1024), + (64, 800), + (64, 256), + (3, 799), + (64, 2048), + (16, 20000), + # # (1, 128 * 1024), + # # (16, 128 * 1024), + # (128, 128), + # (256, 512), # To test appending KV with more than 1 block + # (2048, 3577), # Enough tile to test persistent scheduler + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +def test_flash_attn_kvcache( + seqlen_q, + seqlen_k, + d, + varlen_q, + has_batch_idx, + has_leftpad, + page_size, + rotary_fraction, + rotary_interleaved, + has_rotary_seqlens, + seqlen_new_eq_seqlen_q, + causal, + local, + new_kv, + has_learnable_sink, + mha_type, + dtype, +): + if page_size is not None and seqlen_k % page_size != 0: + pytest.skip() + if seqlen_q > seqlen_k and new_kv: + pytest.skip() + if not new_kv and rotary_fraction > 0.0: + pytest.skip() + if rotary_fraction == 0.0 and has_rotary_seqlens: + pytest.skip() + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 5 + # batch_size = 1 + batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 + nheads = 6 + # nheads = 1 + # rotary_dim must be a multiple of 16, and must be <= d + rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) + assert nheads % nheads_k == 0 + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + dv_vals = [d] + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if (causal or local) else [0] + attention_chunk_vals = [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + # has_qv = d == 64 and dv >= 256 + has_qv = False + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + if has_qv: + qv = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv = None + if varlen_q: + query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(q, query_padding_mask) + output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q) + qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if has_qv else None + else: + query_padding_mask = None + q_unpad = q + qv_unpad = qv + cu_seqlens_q, max_seqlen_q = None, None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + if has_learnable_sink: + learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) + else: + learnable_sink = None + + seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() + cu_seqlens_k_new = None + key_new_padding_mask = None + if new_kv: + k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + v = torch.randn(batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + if varlen_q: # k & v are also varlen + key_new_padding_mask = generate_random_padding_mask(seqlen_new, batch_size, device, mode="random") + k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input(k, key_new_padding_mask) + v_unpad, *rest = unpad_input(v, key_new_padding_mask) + else: + k_unpad, v_unpad = k, v + else: + k, v, k_unpad, v_unpad = None, None, None, None + if page_size is None: + k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + page_table = None + else: + ( + k_cache, + v_cache, + page_table, + k_cache_paged, + v_cache_paged, + num_blocks, + ) = _generate_block_kvcache( + seqlen_k, page_size, batch_size_cache, nheads_k, d, dv, device, dtype, dtype_ref + ) + cache_seqlens = torch.randint( + 0 if new_kv else 1, + # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough + ( + (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1) + if new_kv + else (seqlen_k + 1) + ), + (batch_size,), + dtype=torch.int32, + device=device, + ) + if has_leftpad: + cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device) + if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device) + for i in range(batch_size)]) + else: + cache_leftpad = None + if has_batch_idx: + cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[ + :batch_size + ] + else: + cache_batch_idx = None + arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + if not new_kv: + key_padding_mask = arange < cache_seqlens_expanded + else: + k_new_seqlens = key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new + key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens + if has_leftpad: + key_padding_mask = torch.logical_and( + key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k) + ) + # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) + rotary_seqlens = cache_seqlens if not has_rotary_seqlens else cache_seqlens // 2 + if rotary_dim > 0: + angle = ( + torch.rand( + seqlen_k if page_size is None else num_blocks * page_size, + rotary_dim // 2, + device=device, + ) + * 2 + * math.pi + ) + cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + if causal or local: + q_ro = apply_rotary_emb( + q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved + ) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=seqlen_q, + ) + # q_ro = q + k_ro = apply_rotary_emb( + k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved + ) + else: + cos, sin = None, None + q_ro, k_ro = q, k + # k_cache[:, 64:] = -1 + k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone() + v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone() + if new_kv: + update_mask = torch.logical_and( + cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + k_new_seqlens + ) + k_to_update = rearrange(k_ro, "b s ... -> (b s) ...") + v_to_update = rearrange(v, "b s ... -> (b s) ...") + if varlen_q: + k_to_update = k_to_update[indices_k] + v_to_update = v_to_update[indices_k] + k_cache_ref[update_mask] = k_to_update + v_cache_ref[update_mask] = v_to_update + k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) + v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) + out_ref, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + learnable_sink=learnable_sink, + attention_chunk=attention_chunk, + key_leftpad=cache_leftpad, + ) + out_pt, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + learnable_sink=learnable_sink, + attention_chunk=attention_chunk, + upcast=False, + reorder_ops=True, + key_leftpad=cache_leftpad, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None + ) + q = q.to(dtype) + q_unpad = q_unpad.to(dtype) if varlen_q else None + k_cache = k_cache.to(dtype) + v_cache = v_cache.to(dtype) + k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None + v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None + k = k.to(dtype) if k is not None else None + v = v.to(dtype) if v is not None else None + k_unpad = k_unpad.to(dtype) if k_unpad is not None else None + v_unpad = v_unpad.to(dtype) if v_unpad is not None else None + qv = qv.to(dtype) if qv is not None else None + qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None + cos = cos.to(dtype) if cos is not None else None + sin = sin.to(dtype) if sin is not None else None + k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone() + v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone() + # num_splits_vals = [1, 0] + num_splits_vals = [1] + # precompute_metadata_vals = [False, True] + precompute_metadata_vals = [False] + for num_splits, precompute_metadata in itertools.product(num_splits_vals, precompute_metadata_vals): + # if precompute_metadata: + # scheduler_metadata = get_scheduler_metadata( + # batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d, + # cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q, + # cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad, + # max_seqlen_k_new=seqlen_new, page_size=page_size, + # causal=causal, window_size=window_size, attention_chunk=attention_chunk, + # num_splits=num_splits + # ) + # else: + # scheduler_metadata = None + scheduler_metadata = None + # Repeat to test metadata reuse + for _ in range(1 if not precompute_metadata else 2): + if page_size is None: + k_cache.copy_(k_cache_saved) + v_cache.copy_(v_cache_saved) + else: + k_cache_paged.copy_(k_cache_saved) + v_cache_paged.copy_(v_cache_saved) + # out, lse, *rest = flash_attn_with_kvcache( + out, lse, *rest = flash_attn_varlen_func( + q if not varlen_q else q_unpad, + k_cache if page_size is None else k_cache_paged, + v_cache if page_size is None else v_cache_paged, + # k if not new_kv or not varlen_q else k_unpad, + # v if not new_kv or not varlen_q else v_unpad, + # qv=qv if not varlen_q else qv_unpad, + # rotary_cos=cos, + # rotary_sin=sin, + seqused_k=cache_seqlens, + # cache_batch_idx=cache_batch_idx, + # cache_leftpad=cache_leftpad, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + # cu_seqlens_k_new=cu_seqlens_k_new, + # rotary_seqlens=rotary_seqlens, + causal=causal, + window_size=window_size, + learnable_sink=learnable_sink, + # attention_chunk=attention_chunk, + # rotary_interleaved=rotary_interleaved, + # scheduler_metadata=scheduler_metadata, + # num_splits=num_splits, + # return_softmax_lse=True + ) + if varlen_q: + out = output_pad_fn(out) + # out = flash_attn_with_kvcache( + # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size + # ) + # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) + # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) + # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) + # probs = torch.softmax(qk, dim=-1) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + if new_kv: + if page_size is None: + k_cache_select = ( + k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx] + ) + v_cache_select = ( + v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx] + ) + else: + k_cache_select = rearrange( + k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + v_cache_select = rearrange( + v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) + v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) + if dtype is not torch.float8_e4m3fn: + assert torch.equal(v_cache_select, v_cache_ref) + else: + assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3) + # breakpoint() + # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: + if rotary_dim == 0: + assert torch.equal(k_cache_select, k_cache_ref) + else: + # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): + # breakpoint() + if dtype is not torch.float8_e4m3fn: + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) + else: + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1) + mult = 4 if dtype == torch.float8_e4m3fn else 2 + assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 + mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 + assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item() + + +def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref): + num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 + k_cache_paged = torch.randn( + num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref + ).to(dtype).to(dtype_ref) + v_cache_paged = torch.randn( + num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref + ).to(dtype).to(dtype_ref) + page_table = rearrange( + torch.randperm(num_blocks, dtype=torch.int32, device=device), + "(b nblocks) -> b nblocks", + b=batch_size, + ) + k_cache = rearrange( + k_cache_paged[page_table.flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + v_cache = rearrange( + v_cache_paged[page_table.flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks From 581b68d5a9cabbae959d4a4f99b13c30cdbbf689 Mon Sep 17 00:00:00 2001 From: Jean-Luc Duprat Date: Tue, 12 Aug 2025 17:59:35 -0700 Subject: [PATCH 072/258] [Cute] Remove trailing bracket (#1809) This fixes Commit 81cdf4c --- flash_attn/cute/utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 193b369eba7..81c0caeb431 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -532,6 +532,3 @@ def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Flo vector.extract(out_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip) ) return out0, out1 - - - ) From 3c51f15dc04c05e97cae1cfbd494e1f02962516a Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 13 Aug 2025 12:33:12 -0400 Subject: [PATCH 073/258] [Cute] Make sure R2P happen --- flash_attn/cute/mask.py | 12 ++++++++---- flash_attn/cute/utils.py | 19 ++++++++----------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index 1415cf1b65c..d5cb09db7b4 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -156,12 +156,14 @@ def apply_mask_sm100( # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11 mask = cutlass.Uint32((1 << col_limit_right_cur) - 1) # if tidx == 0: cute.printf("mask = 0x%x, col_limit_right_s = %d, col_limit_right_cur = %d", mask, col_limit_right_s, col_limit_right_cur) - for i in cutlass.range(16, unroll_full=True): + # This needs to be range_constexpr, otherwise the compiler can't generate + # the R2P instruction + for i in cutlass.range_constexpr(16): # mask >> i does not produce correct result for 0b11..11 >> 31 # However, if we use utils.shr_u32, the compiler doesn't generate # the R2P instruction, so it's slower. # Instead we just move by 16 instead of 32. - mask_i_bit = cutlass.Boolean((mask >> i) & 1) + mask_i_bit = cutlass.Boolean(mask & (1 << i)) # mask_i_bit = cutlass.Boolean(utils.shr_u32(mask, i) & 1) # if tidx == 0: cute.printf("mask_i_bit = %d, after shift = 0x%x, i = %d, s = %d", mask_i_bit, utils.shr_u32(mask, i), i, s) acc_S[s * 16 + i] = acc_S[s * 16 + i] if mask_i_bit else -cutlass.Float32.inf @@ -193,9 +195,11 @@ def apply_mask_sm100( col_limit_right_cur = cutlass.Uint32(max(col_limit_right_s, 0)) # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11 mask = cutlass.Uint32((1 << col_limit_right_cur) - 1) - for i in cutlass.range(16, unroll_full=True): + # This needs to be range_constexpr, otherwise the compiler can't generate + # the R2P instruction + for i in cutlass.range_constexpr(16): # mask_i_bit = cutlass.Boolean(utils.shr_u32(mask, i) & 1) - mask_i_bit = cutlass.Boolean((mask >> i) & 1) + mask_i_bit = cutlass.Boolean(mask & (1 << i)) acc_S[s * 16 + i] = acc_S[s * 16 + i] if mask_i_bit else -cutlass.Float32.inf # This is the equivalent of: # acc_S[s * 16 + i] = acc_S[s * 16 + i] if col_limit_right_s <= i else -cutlass.Float32.inf diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 81c0caeb431..02e19ad4cda 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -487,14 +487,14 @@ def cvt_f16(src: cute.Tensor, dst: cute.Tensor): @dsl_user_op def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]: out_f32x2 = llvm.inline_asm( - T.vector(2, T.f32()), + llvm.StructType.get_literal([T.f32(), T.f32()]), [Float32(x).ir_value(loc=loc, ip=ip), Float32(y, loc=loc, ip=ip).ir_value()], "{\n\t" ".reg .f32 f1, f2, f3, f4, f5, f6, f7;\n\t" ".reg .b64 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;\n\t" ".reg .s32 r1, r2, r3, r4, r5, r6, r7, r8;\n\t" - "max.ftz.f32 f1, $1, 0fC2FE0000;\n\t" - "max.ftz.f32 f2, $2, 0fC2FE0000;\n\t" + "max.ftz.f32 f1, $2, 0fC2FE0000;\n\t" + "max.ftz.f32 f2, $3, 0fC2FE0000;\n\t" "mov.b64 l1, {f1, f2};\n\t" "mov.f32 f3, 0f4B400000;\n\t" "mov.b64 l2, {f3, f3};\n\t" @@ -518,17 +518,14 @@ def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Flo "add.s32 r7, r5, r3;\n\t" "shl.b32 r6, r2, 23;\n\t" "add.s32 r8, r6, r4;\n\t" - "mov.b64 $0, {r7, r8};\n\t" + "mov.b32 $0, r7;\n\t" + "mov.b32 $1, r8;\n\t" "}\n", - "=l,f,f", + "=r,=r,f,f", has_side_effects=False, is_align_stack=False, asm_dialect=llvm.AsmDialect.AD_ATT, ) - out0 = Float32( - vector.extract(out_f32x2, dynamic_position=[], static_position=[0], loc=loc, ip=ip) - ) - out1 = Float32( - vector.extract(out_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip) - ) + out0 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [0], loc=loc, ip=ip)) + out1 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [1], loc=loc, ip=ip)) return out0, out1 From d2e3fc30f02426e0c2a06ad45791b19491c92760 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Thu, 14 Aug 2025 03:45:49 +0700 Subject: [PATCH 074/258] feat: add support for pytorch2.8 (#1801) --- .github/workflows/publish.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 0a6a57510d7..8d2ea71e4df 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -44,7 +44,7 @@ jobs: # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. os: [ubuntu-22.04] python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] - torch-version: ['2.4.0', '2.5.1', '2.6.0', '2.7.1'] + torch-version: ['2.4.0', '2.5.1', '2.6.0', '2.7.1', '2.8.0'] cuda-version: ['12.9.1'] # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. @@ -111,8 +111,8 @@ jobs: # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix # This code is ugly, maybe there's a better way to do this. export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \ - minv = {'2.4': 118, '2.5': 118, '2.6': 118, '2.7': 118}[env['MATRIX_TORCH_VERSION']]; \ - maxv = {'2.4': 124, '2.5': 124, '2.6': 126, '2.7': 128}[env['MATRIX_TORCH_VERSION']]; \ + minv = {'2.4': 118, '2.5': 118, '2.6': 118, '2.7': 118, '2.8': 126}[env['MATRIX_TORCH_VERSION']]; \ + maxv = {'2.4': 124, '2.5': 124, '2.6': 126, '2.7': 128, '2.8': 129}[env['MATRIX_TORCH_VERSION']]; \ print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \ ) if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then From 69b33b5324938278eb669056daf19bb205d782d7 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Thu, 14 Aug 2025 12:36:04 -0400 Subject: [PATCH 075/258] [Cute] Implement PackGQA with TMA for fwd_sm100 Credit: Jay Shah's idea --- benchmarks/benchmark_attn.py | 14 +++--- flash_attn/cute/flash_fwd.py | 2 +- flash_attn/cute/flash_fwd_sm100.py | 77 ++++++++++++++++++++++-------- flash_attn/cute/interface.py | 22 +++++++-- tests/cute/test_flash_attn.py | 27 +++++------ 5 files changed, 97 insertions(+), 45 deletions(-) diff --git a/benchmarks/benchmark_attn.py b/benchmarks/benchmark_attn.py index 147b00f15b3..b3902110eea 100644 --- a/benchmarks/benchmark_attn.py +++ b/benchmarks/benchmark_attn.py @@ -228,6 +228,7 @@ def run(*args, **kwargs): varlen = False has_backward = False page_size = None +# page_size = 128 softcap = 0.0 V_colmajor = False deterministic = False @@ -257,15 +258,16 @@ def run(*args, **kwargs): # for headdim in [64, 128, 256]: # for headdim in [64, 96, 128, 192, 256]: for headdim in [128]: - nheads = dim // headdim + # nheads = dim // headdim + nheads = 32 if headdim <= 64 else 16 if headdim <= 192 else 8 # nheads = 128 # headdim = 64 # batch_size = 64 # seqlen = 512 # nheads = 8 # headdim = 128 - nheads_kv = nheads - # nheads_kv = nheads // 4 + # nheads_kv = nheads + nheads_kv = nheads // 8 # nheads_kv = 1 # headdim_v = headdim headdim_v = 128 if headdim == 192 else headdim @@ -302,7 +304,7 @@ def run(*args, **kwargs): if varlen: q_unpad, k_unpad, v_unpad = [rearrange(x.detach(), "b s h d -> (b s) h d").requires_grad_(has_backward) for x in [q, k, v]] cu_seqlens_q = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen_q - cu_seqlens_k = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen + cu_seqlens_k = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen if page_size is None else None # cu_seqlens_q = torch.tensor([0, 248, 249, 250, 251, 252, 253, 254, 255, 256], device=device, dtype=torch.int32) # q_unpad = q_unpad[:256] # seqlen_q = 256 @@ -369,9 +371,9 @@ def run(*args, **kwargs): time_f[(causal, headdim, batch_size, seqlen), "Flash3"] = m1.mean if flash_attn_func_python is not None: if not varlen: - m1_py = time_fwd(flash_attn_func_python, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size, learnable_sink=sinks, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav3 python') + m1_py = time_fwd(flash_attn_func_python, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size, learnable_sink=sinks, softcap=softcap, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3 python') else: - m1_py = time_fwd(flash_attn_varlen_func_python, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav3 python') + m1_py = time_fwd(flash_attn_varlen_func_python, q_unpad, k_unpad if page_size is None else k_paged, v_unpad if page_size is None else v_paged, cu_seqlens_q, cu_seqlens_k, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3 python') if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func_v3 is not None and has_backward: time.sleep(1) if not varlen: diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index c71a049c752..ddd5cfc13d9 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -296,7 +296,7 @@ def epilogue( cute.copy(smem_copy_atom_O, taccOrO, taccOsO) cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded)) - pack_gqa = PackGQA(self.m_block_size, self.head_dim_padded, self.check_hdim_oob, self.qhead_per_kvhead) + pack_gqa = PackGQA(self.m_block_size, self.head_dim_v_padded, self.check_hdim_v_oob, self.qhead_per_kvhead) # Write LSE from rmem -> gmem if const_expr(mLSE is not None): diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 0a0dae7eb12..8309a19f89c 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -32,6 +32,7 @@ from flash_attn.cute.softmax import SoftmaxSm100 from flash_attn.cute.seqlen_info import SeqlenInfo from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute.pack_gqa import PackGQA from flash_attn.cute import mma_sm100_desc as sm100_desc from flash_attn.cute import blackwell_helpers as sm100_utils from flash_attn.cute.fast_math import FastDivmod @@ -56,9 +57,10 @@ def __init__( # dtype: Type[cutlass.Numeric], head_dim: int, head_dim_v: Optional[int] = None, + qhead_per_kvhead: cutlass.Constexpr[int] = 1, is_causal: bool = False, is_local: bool = False, - qhead_per_kvhead: cutlass.Constexpr[int] = 1, + pack_gqa: bool = False, m_block_size: int = 128, n_block_size: int = 128, is_persistent: bool = True, @@ -89,7 +91,9 @@ def __init__( self.is_causal = is_causal self.is_local = is_local self.qhead_per_kvhead = qhead_per_kvhead - self.pack_gqa = False + self.pack_gqa = pack_gqa + if pack_gqa: + assert m_block_size % self.qhead_per_kvhead == 0, "For PackGQA, m_block_size must be divisible by qhead_per_kvhead" # Does S1 need to wait for S0 to finish # self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local) self.s0_s1_barrier = False @@ -253,7 +257,11 @@ def __call__( if const_expr(self.q_dtype != self.v_dtype): raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}") self._setup_attributes() - self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None and not self.pack_gqa + self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None + # This can be tuned + self.e2e_freq = 16 + if const_expr(self.head_dim_padded > 64 and not self.is_causal and not self.is_local and self.pack_gqa): + self.e2e_freq = 32 if mCuSeqlensQ is not None or mSeqUsedQ is not None else 10 cta_group = tcgen05.CtaGroup.ONE # the intermediate tensor p is from tmem & mK-major @@ -308,6 +316,18 @@ def __call__( sK_layout = cute.make_composed_layout(sK_layout.inner, 0, cute.make_layout((*sK_layout.outer.shape[:-1], self.kv_stage), stride=(*sK_layout.outer.stride[:-1], stage_stride))) sV_layout = cute.make_composed_layout(sV_layout.inner, 0, cute.make_layout((*sV_layout.outer.shape[:-1], self.kv_stage), stride=(*sV_layout.outer.stride[:-1], stage_stride))) + if const_expr(self.pack_gqa): + shape_Q_packed = ((self.qhead_per_kvhead, mQ.shape[0]), mQ.shape[1], mK.shape[2], *mQ.shape[3:]) + stride_Q_packed = ((mQ.stride[2], mQ.stride[0]), mQ.stride[1], mQ.stride[2] * self.qhead_per_kvhead, *mQ.stride[3:]) + mQ = cute.make_tensor(mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed)) + shape_O_packed = ((self.qhead_per_kvhead, mO.shape[0]), mK.shape[1], mK.shape[2], *mO.shape[3:]) + stride_O_packed = ((mO.stride[2], mO.stride[0]), mO.stride[1], mO.stride[2] * self.qhead_per_kvhead, *mO.stride[3:]) + mO = cute.make_tensor(mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed)) + if const_expr(mLSE is not None): + shape_LSE_packed = ((self.qhead_per_kvhead, mLSE.shape[0]), mK.shape[2], *mLSE.shape[2:]) + stride_LSE_packed = ((mLSE.stride[1], mLSE.stride[0]), mLSE.stride[1] * self.qhead_per_kvhead, *mLSE.stride[2:]) + mLSE = cute.make_tensor(mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed)) + # TMA load for Q tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) tma_store_op = cpasync.CopyBulkTensorTileS2GOp() @@ -517,7 +537,7 @@ def kernel( tma_atom_Q: cute.CopyAtom, tma_atom_K: cute.CopyAtom, tma_atom_V: cute.CopyAtom, - tma_atom_O: cute.CopyAtom, + tma_atom_O: Optional[cute.CopyAtom], softmax_scale_log2: Float32, softcap_val: Optional[Float32], window_size_left: Optional[Int32], @@ -551,11 +571,10 @@ def kernel( # Prefetch tma descriptor if warp_idx == 0: - if const_expr(not self.pack_gqa): - cpasync.prefetch_descriptor(tma_atom_Q) + cpasync.prefetch_descriptor(tma_atom_Q) cpasync.prefetch_descriptor(tma_atom_K) cpasync.prefetch_descriptor(tma_atom_V) - if const_expr(self.use_tma_O): + if const_expr(tma_atom_O is not None): cpasync.prefetch_descriptor(tma_atom_O) # Alloc @@ -1369,7 +1388,7 @@ def softmax_step( ) # softmax.scale_apply_exp2_convert(tSrS_t2r, row_max, tSrP_r2t) softmax.apply_exp2_convert(tSrS_t2r, tSrP_r2t, e2e=mask_fn is None and self.head_dim_padded <= 128, - e2e_freq=16 if self.head_dim_padded <= 64 else 16) + e2e_freq=self.e2e_freq) # Sequence barrier arrive if const_expr(self.s0_s1_barrier): cute.arch.mbarrier_arrive(mbar_ptr + mbar_s0_s1_sequence_offset + (1 - stage) * 4) @@ -1477,7 +1496,15 @@ def correction_loop( # additional sync because the MMA in the top half must have been done. # Similarly we can write to stage 1 of sO without additional sync. stats = [None] * self.q_stage - learnable_sink_val = Float32(learnable_sink[head_idx]) if const_expr(learnable_sink is not None) else None + learnable_sink_val = [None] * self.q_stage + if const_expr(learnable_sink is not None): + if const_expr(not self.pack_gqa): + sink_val = Float32(learnable_sink[head_idx]) + learnable_sink_val = [sink_val] * self.q_stage + else: # Each thread might have a different sink value due to different q_head + for stage in cutlass.range_constexpr(self.q_stage): + q_head_idx = ((self.q_stage * m_block + stage) * self.m_block_size + tidx) % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead + learnable_sink_val[stage] = Float32(learnable_sink[q_head_idx]) for stage in cutlass.range_constexpr(self.q_stage): cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_full_offset + stage, softmax_corr_consumer_phase) # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r) @@ -1491,7 +1518,7 @@ def correction_loop( cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage) if const_expr(learnable_sink is not None): LOG2_E = math.log2(math.e) - row_sum += utils.exp2f(learnable_sink_val * LOG2_E - row_max * softmax_scale_log2) + row_sum += utils.exp2f(learnable_sink_val[stage] * LOG2_E - row_max * softmax_scale_log2) acc_O_mn_row_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum stats[stage] = (row_sum, row_max, acc_O_mn_row_is_zero_or_nan) scale = cute.arch.rcp_approx(row_sum if not acc_O_mn_row_is_zero_or_nan else 1.0) @@ -1511,8 +1538,8 @@ def correction_loop( else: offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx]) - gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (self.q_stage * m_block,)) for stage in cutlass.range_constexpr(self.q_stage): + gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (self.q_stage * m_block + stage,)) row_sum, row_max, acc_O_mn_row_is_zero_or_nan = stats[stage] # if tidx == 0 and stage <= 1: # cute.printf("row_sum = {}, row_max = {}, acc_O_mn_row_is_zero_or_nan = {}\n", row_sum, row_max, acc_O_mn_row_is_zero_or_nan) @@ -1521,8 +1548,10 @@ def correction_loop( (row_max * softmax_scale_log2 + utils.log2f(row_sum)) * LN2 if not acc_O_mn_row_is_zero_or_nan else -Float32.inf ) - if tidx < seqlen.seqlen_q - (self.q_stage * m_block + stage) * self.m_block_size: - gLSE[tidx + stage * self.m_block_size] = lse + seqlen_q = seqlen.seqlen_q if const_expr(not self.pack_gqa) else seqlen.seqlen_q * self.qhead_per_kvhead + if tidx < seqlen_q - (self.q_stage * m_block + stage) * self.m_block_size: + # This actually just works with PackGQA too + gLSE[tidx] = lse o_corr_consumer_phase ^= 1 softmax_corr_consumer_phase ^= 1 @@ -1755,6 +1784,9 @@ def epilogue_s2g( tOcO = gmem_thr_copy_O.partition_S(cO) t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO) tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) + # TODO: the packgqa case isn't correct rn (sometimes IMA), disabling it + assert not self.pack_gqa + pack_gqa = PackGQA(self.m_block_size, self.head_dim_v_padded, self.check_hdim_v_oob, self.qhead_per_kvhead) for stage in cutlass.range_constexpr(self.q_stage): # wait from corr, issue tma store on smem # 1. wait for O0 / O1 final @@ -1764,14 +1796,17 @@ def epilogue_s2g( tOrO = cute.make_fragment_like(tOsO[None, None, None, 0], self.o_dtype) cute.autovec_copy(tOsO[None, None, None, stage], tOrO) # copy acc O from rmem to gmem - for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): - if t0OcO[0, rest_m, 0][0] < seqlen.seqlen_q - (self.q_stage * m_block + stage) * self.m_block_size - tOcO[0][0]: - cute.copy( - gmem_tiled_copy_O, - tOrO[None, rest_m, None], - tOgO[None, rest_m, None, self.q_stage * m_block + stage], - pred=tOpO[None, rest_m, None] if self.check_hdim_v_oob else None, - ) + if const_expr(not self.pack_gqa): + for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): + if t0OcO[0, rest_m, 0][0] < seqlen.seqlen_q - (self.q_stage * m_block + stage) * self.m_block_size - tOcO[0][0]: + cute.copy( + gmem_tiled_copy_O, + tOrO[None, rest_m, None], + tOgO[None, rest_m, None, self.q_stage * m_block + stage], + pred=tOpO[None, rest_m, None] if self.check_hdim_v_oob else None, + ) + else: + pack_gqa.store_O(mO_cur, tOrO, gmem_tiled_copy_O, tidx, self.q_stage * m_block + stage, seqlen.seqlen_q) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) # Advance to next tile diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 4a7b903a175..8a54c152185 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -70,6 +70,7 @@ def _flash_attn_fwd( m_block_size: int = 128, n_block_size: int = 128, num_threads: int = 384, + pack_gqa: Optional[bool] = None, _compute_capability: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: q, k, v = [maybe_contiguous(t) for t in (q, k, v)] @@ -129,6 +130,8 @@ def _flash_attn_fwd( if softcap == 0.0: softcap = None qhead_per_kvhead = num_head // num_head_kv + if pack_gqa is None: + pack_gqa = qhead_per_kvhead > 1 out_torch_dtype = q.dtype device = q.device @@ -164,6 +167,10 @@ def _flash_attn_fwd( if compute_capability == 9: # TODO: tune block size according to hdim if not causal and not local: n_block_size = 192 + if compute_capability == 10: + # TODO: fix the varlen case + if pack_gqa and (128 % qhead_per_kvhead != 0) or (cu_seqlens_q is not None or seqused_q is not None): + pack_gqa = False compile_key = ( dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, softcap is not None, @@ -171,7 +178,7 @@ def _flash_attn_fwd( page_table is not None, window_size_left is not None, window_size_right is not None, learnable_sink is not None, - m_block_size, n_block_size, num_threads, + m_block_size, n_block_size, num_threads, pack_gqa, compute_capability, ) if compile_key not in _flash_attn_fwd.compile_cache: @@ -186,7 +193,7 @@ def _flash_attn_fwd( qhead_per_kvhead, is_causal=causal, is_local=local, - pack_gqa=False, + pack_gqa=pack_gqa, m_block_size=m_block_size, n_block_size=n_block_size, # num_stages=1, @@ -199,9 +206,10 @@ def _flash_attn_fwd( fa_fwd = FlashAttentionForwardSm100( head_dim, head_dim_v, + qhead_per_kvhead=qhead_per_kvhead, is_causal=causal, is_local=local, - qhead_per_kvhead=qhead_per_kvhead, + pack_gqa=pack_gqa, is_persistent=not causal and not local and cu_seqlens_q is None and seqused_q is None, ) else: @@ -422,6 +430,7 @@ def forward( window_size: Tuple[Optional[int], Optional[int]] = (None, None), learnable_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, + pack_gqa: Optional[bool] = None, ): out, lse = _flash_attn_fwd( q, @@ -433,6 +442,7 @@ def forward( window_size_right=window_size[1], learnable_sink=learnable_sink, softcap=softcap, + pack_gqa=pack_gqa, ) ctx.save_for_backward(q, k, v, out, lse) ctx.softmax_scale = softmax_scale @@ -476,6 +486,7 @@ def forward( window_size: Tuple[Optional[int], Optional[int]] = (None, None), learnable_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, + pack_gqa: Optional[bool] = None, ): out, lse = _flash_attn_fwd( q, @@ -492,6 +503,7 @@ def forward( window_size_right=window_size[1], learnable_sink=learnable_sink, softcap=softcap, + pack_gqa=pack_gqa, ) ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) ctx.softmax_scale = softmax_scale @@ -517,6 +529,7 @@ def flash_attn_func( window_size: Tuple[Optional[int], Optional[int]] = (None, None), learnable_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, + pack_gqa: Optional[bool] = None, ): return FlashAttnFunc.apply( q, @@ -527,6 +540,7 @@ def flash_attn_func( window_size, learnable_sink, softcap, + pack_gqa, ) @@ -544,6 +558,7 @@ def flash_attn_varlen_func( window_size: Tuple[Optional[int], Optional[int]] = (None, None), learnable_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, + pack_gqa: Optional[bool] = None, ): return FlashAttnVarlenFunc.apply( q, @@ -559,4 +574,5 @@ def flash_attn_varlen_func( window_size, learnable_sink, softcap, + pack_gqa, ) diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index eaf351f3977..879fd0a2c27 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -32,7 +32,7 @@ @pytest.mark.parametrize("local", [False, True]) # @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize("causal", [False]) +# @pytest.mark.parametrize("causal", [True]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) @@ -81,13 +81,13 @@ def test_flash_attn_output( # batch_size = 1 nheads = 6 # nheads = 1 - nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d]) if dtype == torch.float8_e4m3fn: dv_vals = [d] - # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if not DISABLE_LOCAL else [0] + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] attention_chunk_vals = [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) @@ -162,9 +162,8 @@ def test_flash_attn_output( print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - # pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] - # num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] - pack_gqa_vals = [False] + # num_splits_vals = [1, 3] + pack_gqa_vals = [False, True, None] num_splits_vals = [1] for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): out, lse = flash_attn_func( @@ -243,7 +242,7 @@ def test_flash_attn_output( # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) -# @pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("mha_type", ["mqa"]) @pytest.mark.parametrize("has_learnable_sink", [False, True]) # @pytest.mark.parametrize("has_learnable_sink", [False]) # @pytest.mark.parametrize("has_qv", [False, True]) @@ -265,7 +264,7 @@ def test_flash_attn_output( # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128]) @pytest.mark.parametrize("d", [128, 192]) -# @pytest.mark.parametrize("d", [128]) +# @pytest.mark.parametrize("d", [192]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -299,17 +298,17 @@ def test_flash_attn_varlen_output( device = "cuda" # set seed torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) - batch_size = 49 if seqlen_q <= 2048 else 2 + batch_size = 49 if seqlen_q <= 1024 else 7 nheads = 6 # batch_size = 1 # nheads = 1 - nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d]) if dtype == torch.float8_e4m3fn: dv_vals = [d] - # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k and not DISABLE_LOCAL else [0] + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k else [0] attention_chunk_vals = [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) @@ -431,9 +430,8 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() rtol = 2 if softcap == 0.0 else 3 - # pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] - # num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] - pack_gqa_vals = [False] + pack_gqa_vals = [False, True, None] + # num_splits_vals = [1, 3] num_splits_vals = [1] for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): out_unpad, lse = flash_attn_varlen_func( @@ -453,6 +451,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # attention_chunk=attention_chunk, learnable_sink=learnable_sink, softcap=softcap, + pack_gqa=pack_gqa, ) out = output_pad_fn(out_unpad) if query_unused_mask is not None: From 060c9188beec3a8b62b33a3bfa6d5d2d44975fab Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Thu, 14 Aug 2025 13:11:47 -0400 Subject: [PATCH 076/258] Bump to v2.8.3 --- flash_attn/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index 69eae460e36..4a8a7c33f46 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.8.2" +__version__ = "2.8.3" from flash_attn.flash_attn_interface import ( flash_attn_func, From cd9383f314b6bb81c79f56139da9c405f0e397dd Mon Sep 17 00:00:00 2001 From: Chao Shi Date: Fri, 15 Aug 2025 23:38:10 +0800 Subject: [PATCH 077/258] [BugFix] Fix flash_attn_with_kvcache with scalar cache_seqlen (#1795) When the parameter `cache_seqlen` is scalar, it should be expand to vector of shape (batch_size). In the original code, whenever `block_table` is used, the shape of `k_cache` is (num_blocks, page_size, ...), and thus `cache_seqlen` is expanded to shape (num_blocks) instead of (batch_size), which is wrong. This fix uses the shape of `q`, which is always `batch_size`. --- flash_attn/flash_attn_interface.py | 2 +- hopper/flash_attn_interface.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 1e041e4538d..535bd416745 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -1576,7 +1576,7 @@ def flash_attn_with_kvcache( softmax_scale = q.shape[-1] ** (-0.5) if cache_seqlens is not None and isinstance(cache_seqlens, int): cache_seqlens = torch.full( - (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device + (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device ) cache_seqlens = maybe_contiguous(cache_seqlens) cache_batch_idx = maybe_contiguous(cache_batch_idx) diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index b753a0fba7b..5547f426da5 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -751,7 +751,7 @@ def flash_attn_with_kvcache( softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) if cache_seqlens is not None and isinstance(cache_seqlens, int): cache_seqlens = torch.full( - (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device + (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device ) cache_seqlens = maybe_contiguous(cache_seqlens) out, softmax_lse, *rest = _flash_attn_forward( From b31ae1e4cd22cf5f820a2995b74b7cd3bd54355a Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 17 Aug 2025 00:03:26 -0400 Subject: [PATCH 078/258] [Cute] Port fwd_combine kernel from C++ to cute-dsl --- flash_attn/cute/block_info.py | 8 +- flash_attn/cute/flash_bwd.py | 4 +- flash_attn/cute/flash_fwd.py | 8 +- flash_attn/cute/flash_fwd_combine.py | 644 +++++++++++++++++++++++++++ flash_attn/cute/flash_fwd_sm100.py | 4 +- flash_attn/cute/interface.py | 223 ++++++++++ flash_attn/cute/seqlen_info.py | 22 + flash_attn/cute/tile_scheduler.py | 2 +- flash_attn/cute/utils.py | 60 +++ tests/cute/test_flash_attn.py | 62 ++- 10 files changed, 1023 insertions(+), 14 deletions(-) create mode 100644 flash_attn/cute/flash_fwd_combine.py diff --git a/flash_attn/cute/block_info.py b/flash_attn/cute/block_info.py index 2739a31c4ef..2914e42e2ab 100644 --- a/flash_attn/cute/block_info.py +++ b/flash_attn/cute/block_info.py @@ -5,7 +5,7 @@ import cutlass import cutlass.cute as cute -from flash_attn.cute.seqlen_info import SeqlenInfo +from flash_attn.cute.seqlen_info import SeqlenInfoQK @dataclass(frozen=True) @@ -20,7 +20,7 @@ class BlockInfo: @cute.jit def get_n_block_min_max( - self, seqlen_info: SeqlenInfo, m_block: cutlass.Int32 + self, seqlen_info: SeqlenInfoQK, m_block: cutlass.Int32 ) -> Tuple[cutlass.Int32, cutlass.Int32]: n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.n_block_size) if cutlass.const_expr( @@ -45,7 +45,7 @@ def get_n_block_min_max( @cute.jit def get_n_block_min_causal_local_mask( self, - seqlen_info: SeqlenInfo, + seqlen_info: SeqlenInfoQK, m_block: cutlass.Int32, n_block_min: cutlass.Int32, ) -> cutlass.Int32: @@ -64,7 +64,7 @@ def get_n_block_min_causal_local_mask( @cute.jit def get_n_block_min_before_local_mask( self, - seqlen_info: SeqlenInfo, + seqlen_info: SeqlenInfoQK, m_block: cutlass.Int32, n_block_min: cutlass.Int32, ) -> cutlass.Int32: diff --git a/flash_attn/cute/flash_bwd.py b/flash_attn/cute/flash_bwd.py index 79f5ee8ec13..619e0408cd4 100644 --- a/flash_attn/cute/flash_bwd.py +++ b/flash_attn/cute/flash_bwd.py @@ -16,7 +16,7 @@ from flash_attn.cute import ampere_helpers as sm80_utils from flash_attn.cute import utils from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.seqlen_info import SeqlenInfo +from flash_attn.cute.seqlen_info import SeqlenInfoQK class FlashAttentionBackwardSm80: @@ -631,7 +631,7 @@ def kernel( gmem_copy_params = SimpleNamespace( gmem_thr_copy_dQaccum=gmem_thr_copy_dQaccum, tdQgdQaccum=tdQgdQaccum ) - seqlen = SeqlenInfo(batch_idx, mQ.shape[1], mK.shape[1]) + seqlen = SeqlenInfoQK(batch_idx, mQ.shape[1], mK.shape[1]) load_Q_LSE = partial( self.load_Q_LSE, gmem_tiled_copy_QK, gmem_tiled_copy_LSE, tQgQ, tQsQ, tQcQ, t0QcQ, tQpQ, diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index ddd5cfc13d9..48a4a3203ff 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -24,7 +24,7 @@ from flash_attn.cute import utils from flash_attn.cute.mask import AttentionMask from flash_attn.cute.softmax import Softmax -from flash_attn.cute.seqlen_info import SeqlenInfo +from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo from flash_attn.cute import pipeline from flash_attn.cute.pack_gqa import PackGQA @@ -274,7 +274,7 @@ def epilogue( mO: cute.Tensor, mLSE: Optional[cute.Tensor], sO: cute.Tensor, - seqlen: SeqlenInfo, + seqlen: SeqlenInfoQK, gmem_tiled_copy_O: cute.TiledCopy, tma_atom_O: Optional[cute.CopyAtom], tiled_mma: cute.TiledMma, @@ -655,7 +655,7 @@ def kernel( window_size_left, window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) - seqlen = SeqlenInfo(seqlen_q=mQ.shape[0], seqlen_k=mK.shape[0]) + seqlen = SeqlenInfoQK(seqlen_q=mQ.shape[0], seqlen_k=mK.shape[0]) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) # TODO: return early if n_block_max == 0 # if self.is_causal: @@ -1343,7 +1343,7 @@ def kernel( qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) SeqlenInfoCls = partial( - SeqlenInfo, seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1], + SeqlenInfoQK, seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1], seqlen_k_static=mK.shape[0], mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK, diff --git a/flash_attn/cute/flash_fwd_combine.py b/flash_attn/cute/flash_fwd_combine.py new file mode 100644 index 00000000000..4c423b80968 --- /dev/null +++ b/flash_attn/cute/flash_fwd_combine.py @@ -0,0 +1,644 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +# A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_fwd_combine_kernel.h +# from Cutlass C++ to Cute-DSL. +import math +import operator +from typing import Type, Optional +from functools import partial + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync +from cutlass import Float32, Int32, const_expr + +from flash_attn.cute import utils +from flash_attn.cute.fast_math import FastDivmod +from flash_attn.cute.seqlen_info import SeqlenInfo + + +class FlashAttentionForwardCombine: + def __init__( + self, + dtype: Type[cutlass.Numeric], + dtype_partial: Type[cutlass.Numeric], + head_dim: int, + m_block_size: int = 8, + k_block_size: int = 64, + log_max_splits: int = 4, + num_threads: int = 256, + stages: int = 4, + ): + """ + Forward combine kernel for split attention computation. + + :param dtype: output data type + :param dtype_partial: partial accumulation data type + :param head_dim: head dimension + :param m_block_size: m block size + :param k_block_size: k block size + :param log_max_splits: log2 of maximum splits + :param num_threads: number of threads + :param varlen: whether using variable length sequences + :param stages: number of pipeline stages + """ + self.dtype = dtype + self.dtype_partial = dtype_partial + self.head_dim = head_dim + self.m_block_size = m_block_size + self.k_block_size = k_block_size + self.max_splits = 1 << log_max_splits + self.num_threads = num_threads + self.is_even_k = head_dim % k_block_size == 0 + self.stages = stages + + @staticmethod + def can_implement( + dtype, dtype_partial, head_dim, m_block_size, k_block_size, + log_max_splits, num_threads, + ) -> bool: + """Check if the kernel can be implemented with the given parameters.""" + if dtype not in [cutlass.Float16, cutlass.BFloat16, cutlass.Float32]: + return False + if dtype_partial not in [cutlass.Float16, cutlass.BFloat16, Float32]: + return False + if head_dim % 8 != 0: + return False + if num_threads % 32 != 0: + return False + if m_block_size % 8 != 0: + return False + max_splits = 1 << log_max_splits + if max_splits > 256: + return False + if (m_block_size * max_splits) % num_threads != 0: + return False + return True + + def _setup_attributes(self): + # GMEM copy setup for O partial + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // self.dtype_partial.width + assert self.k_block_size % async_copy_elems == 0 + + k_block_gmem = ( + 128 if self.k_block_size % 128 == 0 else + (64 if self.k_block_size % 64 == 0 else 32) + ) + gmem_threads_per_row = k_block_gmem // async_copy_elems + assert self.num_threads % gmem_threads_per_row == 0 + + # Async copy atom for O partial load + atom_async_copy_partial = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + self.dtype_partial, + num_bits_per_copy=universal_copy_bits, + ) + tOpartial_layout = cute.make_ordered_layout( + (self.num_threads // gmem_threads_per_row, gmem_threads_per_row), + order=(1, 0), + ) + vOpartial_layout = cute.make_layout((1, async_copy_elems)) # 4 vals per load + self.gmem_tiled_copy_O_partial = cute.make_tiled_copy_tv( + atom_async_copy_partial, tOpartial_layout, vOpartial_layout + ) + + # GMEM copy setup for final O (use universal copy for store) + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=async_copy_elems * self.dtype.width, + ) + self.gmem_tiled_copy_O = cute.make_tiled_copy_tv( + atom_universal_copy, tOpartial_layout, vOpartial_layout # 4 vals per store + ) + + # LSE copy setup with async copy (alignment = 1) + lse_copy_bits = Float32.width # 1 element per copy, width is in bits + m_block_smem = ( + 128 if self.m_block_size % 128 == 0 else + (64 if self.m_block_size % 64 == 0 else + (32 if self.m_block_size % 32 == 0 else + (16 if self.m_block_size % 16 == 0 else 8))) + ) + gmem_threads_per_row_lse = m_block_smem + assert self.num_threads % gmem_threads_per_row_lse == 0 + + # Async copy atom for LSE load + atom_async_copy_lse = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS), + Float32, + num_bits_per_copy=lse_copy_bits, + ) + tLSE_layout = cute.make_ordered_layout( + (self.num_threads // gmem_threads_per_row_lse, gmem_threads_per_row_lse), + order=(1, 0), + ) + vLSE_layout = cute.make_layout(1) + self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv( + atom_async_copy_lse, tLSE_layout, vLSE_layout + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Shared memory + # /////////////////////////////////////////////////////////////////////////////// + + # Shared memory to register copy for LSE + self.smem_threads_per_col_lse = self.num_threads // m_block_smem + assert 32 % self.smem_threads_per_col_lse == 0 # Must divide warp size + + s2r_layout_atom_lse = cute.make_ordered_layout( + (self.smem_threads_per_col_lse, self.num_threads // self.smem_threads_per_col_lse), + order=(0, 1), + ) + self.s2r_tiled_copy_LSE = cute.make_tiled_copy_tv( + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32), + s2r_layout_atom_lse, + cute.make_layout(1), + ) + + # LSE shared memory layout with swizzling to avoid bank conflicts + # This works for kBlockMSmem = 8, 16, 32, 64, 128, no bank conflicts + if const_expr(m_block_smem == 8): + smem_lse_swizzle = cute.make_swizzle(5, 0, 5) + elif const_expr(m_block_smem == 16): + smem_lse_swizzle = cute.make_swizzle(4, 0, 4) + else: + smem_lse_swizzle = cute.make_swizzle(3, 2, 3) + smem_layout_atom_lse = cute.make_composed_layout( + smem_lse_swizzle, + 0, + cute.make_ordered_layout((8, m_block_smem), order=(1, 0)) + ) + self.smem_layout_lse = cute.tile_to_shape( + smem_layout_atom_lse, (self.max_splits, self.m_block_size), (0, 1) + ) + + # O partial shared memory layout (simple layout for pipeline stages) + self.smem_layout_o = cute.make_ordered_layout( + (self.m_block_size, self.k_block_size, self.stages), + order=(1, 0, 2) + ) + + + @cute.jit + def __call__( + self, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor] = None, + cu_seqlens: Optional[cute.Tensor] = None, + seqused: Optional[cute.Tensor] = None, + num_splits_dynamic_ptr: Optional[cute.Tensor] = None, + semaphore_to_reset: Optional[cute.Tensor] = None, + stream: cuda.CUstream = None, + ): + # Type checking + if const_expr(not (mO_partial.element_type == self.dtype_partial)): + raise TypeError("O partial tensor must match dtype_partial") + if const_expr(not (mO.element_type == self.dtype)): + raise TypeError("O tensor must match dtype") + if const_expr(not mLSE_partial.element_type in [Float32]): + raise TypeError("LSE partial tensor must be Float32") + if const_expr(mLSE is not None and not mLSE.element_type in [Float32]): + raise TypeError("LSE tensor must be Float32") + + # Shape validation - input tensors are in user format, need to be converted to kernel format + if const_expr(len(mO_partial.shape) not in [4, 5]): + raise ValueError("O partial tensor must have 4 or 5 dimensions: (num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim)") + if const_expr(len(mLSE_partial.shape) not in [3, 4]): + raise ValueError("LSE partial tensor must have 3 or 4 dimensions: (num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)") + if const_expr(len(mO.shape) not in [3, 4]): + raise ValueError("O tensor must have 3 or 4 dimensions: (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim)") + if const_expr(mLSE is not None and len(mLSE.shape) not in [2, 3]): + raise ValueError("LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)") + + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1]) + mO_partial, mO = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) for t in (mO_partial, mO)] + # (num_splits, b, seqlen, h, d) -> (seqlen, d, num_splits, h, b) + # or (num_splits, total_q, h, d) -> (total_q, d, num_splits, h) + O_partial_layout_transpose = [2, 4, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 3, 0, 2] + # (b, seqlen, h, d) -> (seqlen, d, h, b) or (total_q, h, d) -> (total_q, d, h) + mO_partial = cute.make_tensor(mO_partial.iterator, cute.select(mO_partial.layout, mode=O_partial_layout_transpose)) + O_layout_transpose = [1, 3, 2, 0] if const_expr(cu_seqlens is None) else [0, 2, 1] + mO = cute.make_tensor(mO.iterator, cute.select(mO.layout, mode=O_layout_transpose)) + # (num_splits, b, seqlen, h) -> (seqlen, num_splits, h, b) + # or (num_splits, total_q, h) -> (total_q, num_splits, h) + LSE_partial_layout_transpose = [2, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 0, 2] + mLSE_partial = cute.make_tensor(mLSE_partial.iterator, cute.select(mLSE_partial.layout, mode=LSE_partial_layout_transpose)) + # (b, seqlen, h) -> (seqlen, h, b) or (total_q, h) -> (total_q, h) + LSE_layout_transpose = [1, 2, 0] if const_expr(cu_seqlens is None) else [0, 1] + mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) if mLSE is not None else None + + # Determine if we have variable length sequences + varlen = const_expr(cu_seqlens is not None or seqused is not None) + + self._setup_attributes() + + @cute.struct + class SharedStorage: + sLSE: cute.struct.Align[ + cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128 + ] + sMaxValidSplit: cute.struct.Align[ + cute.struct.MemRange[Int32, self.m_block_size], 128 + ] + sO: cute.struct.Align[ + cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128 + ] + + smem_size = SharedStorage.size_in_bytes() + + # Grid dimensions: (ceil_div(seqlen, m_block), ceil_div(head_dim, k_block), num_head * batch) + seqlen = mO_partial.shape[0] + num_head = mO_partial.shape[3] + batch_size = mO_partial.shape[4] + + # Create FastDivmod objects for efficient division + seqlen_divmod = FastDivmod.create(seqlen) + head_divmod = FastDivmod.create(num_head) + + grid_dim = ( + cute.ceil_div(seqlen * num_head, self.m_block_size), + cute.ceil_div(self.head_dim, self.k_block_size), + batch_size, + ) + + self.kernel( + mO_partial, + mLSE_partial, + mO, + mLSE, + cu_seqlens, + seqused, + num_splits_dynamic_ptr, + semaphore_to_reset, + SharedStorage, + self.smem_layout_lse, + self.smem_layout_o, + self.gmem_tiled_copy_O_partial, + self.gmem_tiled_copy_O, + self.gmem_tiled_copy_LSE, + self.s2r_tiled_copy_LSE, + seqlen_divmod, + head_divmod, + varlen, + ).launch( + grid=grid_dim, + block=[self.num_threads, 1, 1], + smem=smem_size, + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + cu_seqlens: Optional[cute.Tensor], + seqused: Optional[cute.Tensor], + num_splits_dynamic_ptr: Optional[cute.Tensor], + semaphore_to_reset: Optional[cute.Tensor], + SharedStorage: cutlass.Constexpr, + smem_layout_lse: cute.Layout | cute.ComposedLayout, + smem_layout_o: cute.Layout, + gmem_tiled_copy_O_partial: cute.TiledCopy, + gmem_tiled_copy_O: cute.TiledCopy, + gmem_tiled_copy_LSE: cute.TiledCopy, + s2r_tiled_copy_LSE: cute.TiledCopy, + seqlen_divmod: FastDivmod, + head_divmod: FastDivmod, + varlen: cutlass.Constexpr[bool], + ): + # Thread and block indices + tidx, _, _ = cute.arch.thread_idx() + m_block, k_block, batch_idx = cute.arch.block_idx() + + # /////////////////////////////////////////////////////////////////////////////// + # Get shared memory buffer + # /////////////////////////////////////////////////////////////////////////////// + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sLSE = storage.sLSE.get_tensor(smem_layout_lse) + sMaxValidSplit = storage.sMaxValidSplit.get_tensor((self.m_block_size,)) + sO = storage.sO.get_tensor(smem_layout_o) + + # Handle semaphore reset + if const_expr(semaphore_to_reset is not None): + if (tidx == 0 and m_block == cute.arch.grid_dim()[0] - 1 and + k_block == cute.arch.grid_dim()[1] - 1 and + batch_idx == cute.arch.grid_dim()[2] - 1): + semaphore_to_reset[0] = 0 + + # Get number of splits + num_splits = ( + num_splits_dynamic_ptr[batch_idx] if const_expr(num_splits_dynamic_ptr is not None) + else mLSE_partial.shape[1] + ) + # Handle variable length sequences using SeqlenInfo + seqlen_info = SeqlenInfo( + batch_idx=batch_idx, + seqlen_static=mO_partial.shape[0], + cu_seqlens=cu_seqlens, + seqused=seqused + ) + seqlen, offset = seqlen_info.seqlen, seqlen_info.offset + + # Extract number of heads (head index will be determined dynamically) + num_head = mO_partial.shape[3] + max_idx = seqlen * num_head + + # Early exit for single split if dynamic + if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and (const_expr(not varlen) or m_block * self.m_block_size < max_idx): + + # =============================== + # Step 1: Load LSE_partial from gmem to shared memory + # =============================== + + if const_expr(cu_seqlens is None): + # mLSE_partial_cur = mLSE_partial[None, None, None, batch_idx] + mLSE_partial_cur = utils.coord_offset_i64(mLSE_partial, batch_idx, dim=3) + else: + # mLSE_partial_cur = cute.domain_offset((offset, 0, 0), mLSE_partial) + mLSE_partial_cur = utils.domain_offset_i64((offset, 0, 0), mLSE_partial) + mLSE_partial_copy = cute.tiled_divide(mLSE_partial_cur, (1,)) + + gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_slice(tidx) + tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE) + + # Create identity tensor for coordinate tracking + cLSE = cute.make_identity_tensor((self.max_splits, self.m_block_size)) + tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE) + + # Load LSE partial values + for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True): + mi = tLSEcLSE[0, 0, m][1] # Get m coordinate + idx = m_block * self.m_block_size + mi + if idx < max_idx: + # Calculate actual sequence position and head using FastDivmod + if const_expr(not varlen): + head_idx, m_idx = seqlen_divmod.divmod(idx) + else: + head_idx = idx // seqlen + m_idx = idx - head_idx * seqlen + mLSE_partial_cur_copy = mLSE_partial_copy[None, m_idx, None, head_idx] + for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): + si = tLSEcLSE[0, s, 0][0] # Get split coordinate + if si < num_splits: + cute.copy(gmem_thr_copy_LSE, mLSE_partial_cur_copy[None, si], tLSEsLSE[None, s, m]) + else: + tLSEsLSE[None, s, m].fill(-Float32.inf) + # Don't need to zero out the rest of the LSEs, as we will not write the output to gmem + cute.arch.cp_async_commit_group() + + # =============================== + # Step 2: Load O_partial for pipeline stages + # =============================== + + gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_slice(tidx) + cO = cute.make_identity_tensor((self.m_block_size, self.k_block_size)) + tOcO = gmem_thr_copy_O_partial.partition_D(cO) + tOsO_partial = gmem_thr_copy_O_partial.partition_D(sO) + if const_expr(cu_seqlens is None): + # mO_partial_cur = mO_partial[None, None, None, None, batch_idx] + mO_partial_cur = utils.coord_offset_i64(mO_partial, batch_idx, dim=4) + else: + # mO_partial_cur = cute.domain_offset((offset, 0, 0, 0), mO_partial) + mO_partial_cur = utils.domain_offset_i64((offset, 0, 0, 0), mO_partial) + + # Precompute these values to avoid recomputing them in the loop + num_rows = const_expr(cute.size(tOcO, mode=[1])) + tOmidx = cute.make_fragment(num_rows, cutlass.Int32) + tOhidx = cute.make_fragment(num_rows, cutlass.Int32) + tOrOptr = cute.make_fragment(num_rows, cutlass.Int64) + for m in cutlass.range(num_rows, unroll_full=True): + mi = tOcO[0, m, 0][0] # m coordinate + idx = m_block * self.m_block_size + mi + if const_expr(not varlen): + tOhidx[m], tOmidx[m] = seqlen_divmod.divmod(idx) + else: + tOhidx[m] = idx // seqlen + tOmidx[m] = idx - tOhidx[m] * seqlen + tOrOptr[m] = utils.elem_pointer_i64(mO_partial_cur, (tOmidx[m], k_block * self.k_block_size, 0, tOhidx[m])).toint() + if idx >= max_idx: + tOhidx[m] = -1 + + tOpO = cute.make_fragment(cute.size(tOcO, [2]), cutlass.Boolean) + if const_expr(not self.is_even_k): + for k in cutlass.range(cute.size(tOpO), unroll_full=True): + tOpO[k] = tOcO[0, 0, k][1] < mO_partial.shape[1] - k_block * self.k_block_size + # if cute.arch.thread_idx()[0] == 0 and k_block == 1: cute.print_tensor(tOpO) + + load_O_partial = partial( + self.load_O_partial, + gmem_tiled_copy_O_partial, + tOrOptr, + tOsO_partial, + tOhidx, + tOpO, + tOcO, + mO_partial_cur.layout, + ) + + # Load first few stages of O_partial + for stage in cutlass.range(self.stages - 1, unroll_full=True): + if stage < num_splits: + load_O_partial(stage, stage) + cute.arch.cp_async_commit_group() + + # =============================== + # Step 3: Load and transpose LSE from smem to registers + # =============================== + + # Wait for LSE and initial O partial stages to complete + cute.arch.cp_async_wait_group(self.stages - 1) + cute.arch.sync_threads() + # if cute.arch.thread_idx()[0] == 0: + # # cute.print_tensor(sLSE) + # for i in range(64): + # cute.printf("sLSE[%d, 0] = %f", i, sLSE[i, 0]) + # cute.arch.sync_threads() + + s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_slice(tidx) + ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE) + ts2rrLSE = cute.make_fragment_like(ts2rsLSE) + cute.copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE) + + # =============================== + # Step 4: Compute final LSE along split dimension + # =============================== + + lse_sum = cute.make_fragment(cute.size(ts2rrLSE, mode=[2]), Float32) + ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE) + # We compute the max valid split for each row to short-circuit the computation later + max_valid_split = cute.make_fragment(cute.size(ts2rrLSE, mode=[2]), Int32) + assert cute.size(ts2rrLSE, mode=[0]) == 1 + # Compute max, scales, and final LSE for each row + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + # Find max LSE value across splits + threads_per_col = const_expr(self.smem_threads_per_col_lse) + lse_max = utils.warp_reduce( + ts2rrLSE[None, None, m].load().reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0), + op=cute.arch.fmax, + width=threads_per_col, + ) + # if cute.arch.thread_idx()[0] == 0: cute.printf(lse_max) + # Find max valid split index + max_valid_idx = -1 + for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True): + if ts2rrLSE[0, s, m] != -Float32.inf: + max_valid_idx = ts2rcLSE[0, s, 0][0] # Get split coordinate + # if cute.arch.thread_idx()[0] < 32: cute.printf(max_valid_idx) + max_valid_split[m] = utils.warp_reduce(max_valid_idx, max, width=threads_per_col) + # Compute exp scales and sum + lse_max_cur = 0.0 if lse_max == -Float32.inf else lse_max # In case all local LSEs are -inf + LOG2_E = math.log2(math.e) + lse_sum_cur = 0.0 + for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True): + scale = utils.exp2f(ts2rrLSE[0, s, m] * LOG2_E - (lse_max_cur * LOG2_E)) + lse_sum_cur += scale + ts2rrLSE[0, s, m] = scale # Store scale for later use + lse_sum_cur = utils.warp_reduce(lse_sum_cur, operator.add, width=threads_per_col) + lse_sum[m] = utils.logf(lse_sum_cur) + lse_max + # Normalize scales + inv_sum = 0.0 if (lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur) else 1.0 / lse_sum_cur + ts2rrLSE[None, None, m].store(ts2rrLSE[None, None, m].load() * inv_sum) + # Store the scales exp(lse - lse_logsum) back to smem + cute.copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE) + + # Store max valid split to smem + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes + mi = ts2rcLSE[0, 0, m][1] + if mi < self.m_block_size: + sMaxValidSplit[mi] = max_valid_split[m] + + # =============================== + # Step 5: Store final LSE to gmem + # =============================== + + if const_expr(mLSE is not None): + if const_expr(cu_seqlens is None): + # mLSE_cur = mLSE[None, None, batch_idx] + mLSE_cur = utils.coord_offset_i64(mLSE, batch_idx, dim=2) + else: + # mLSE_cur = cute.domain_offset((offset, 0), mLSE) + mLSE_cur = utils.domain_offset_i64((offset, 0), mLSE) + if k_block == 0: # Only first k_block writes LSE when mLSE is provided + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes + mi = ts2rcLSE[0, 0, m][1] + idx = m_block * self.m_block_size + mi + if idx < max_idx: + if const_expr(not varlen): + head_idx, m_idx = seqlen_divmod.divmod(idx) + else: + head_idx = idx // seqlen + m_idx = idx - head_idx * seqlen + mLSE_cur[m_idx, head_idx] = lse_sum[m] + + # =============================== + # Step 6: Read O_partial and accumulate final O + # =============================== + + cute.arch.sync_threads() + + # Get max valid split for this thread + thr_max_valid_split = sMaxValidSplit[tOcO[0, 0, 0][0]] + for m in cutlass.range(1, cute.size(tOcO, mode=[1])): + thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[tOcO[0, m, 0][0]]) + + tOrO_partial = cute.make_fragment_like(tOsO_partial[None, None, None, 0]) + tOrO = cute.make_fragment_like(tOrO_partial, Float32) + tOrO.fill(0.0) + + stage_load = self.stages - 1 + stage_compute = 0 + + # Main accumulation loop + for s in cutlass.range(thr_max_valid_split + 1, unroll=4): + # Get scales for this split + scale = cute.make_fragment(num_rows, Float32) + for m in cutlass.range(num_rows, unroll_full=True): + scale[m] = sLSE[s, tOcO[0, m, 0][0]] # Get scale from smem + + # Load next stage if needed + split_to_load = s + self.stages - 1 + if split_to_load <= thr_max_valid_split: + load_O_partial(split_to_load, stage_load) + cute.arch.cp_async_commit_group() + stage_load = 0 if stage_load == self.stages - 1 else stage_load + 1 + + # Wait for the current stage to be ready + cute.arch.cp_async_wait_group(self.stages - 1) + # We don't need __syncthreads() because each thread is just reading its own data from smem + # Copy from smem to registers + cute.autovec_copy(tOsO_partial[None, None, None, stage_compute], tOrO_partial) + stage_compute = 0 if stage_compute == self.stages - 1 else stage_compute + 1 + + # Accumulate scaled partial results + for m in cutlass.range(num_rows, unroll_full=True): + if tOhidx[m] >= 0 and scale[m] > 0.0: + tOrO[None, m, None].store(tOrO[None, m, None].load() + scale[m] * tOrO_partial[None, m, None].load().to(Float32)) + + # =============================== + # Step 7: Write final O to gmem + # =============================== + + rO = cute.make_fragment_like(tOrO, self.dtype) + rO.store(tOrO.load().to(self.dtype)) + if const_expr(cu_seqlens is None): + # mO_cur = mO[None, None, None, batch_idx] + mO_cur = utils.coord_offset_i64(mO, batch_idx, dim=3) + else: + # mO_cur = cute.domain_offset((offset, 0, 0), mO) + mO_cur = utils.domain_offset_i64((offset, 0, 0), mO) + mO_cur = utils.domain_offset_aligned((0, k_block * self.k_block_size, 0), mO_cur) + elems_per_store = const_expr(cute.size(gmem_tiled_copy_O.layout_tv_tiled[1])) + # mO_cur_copy = cute.tiled_divide(mO_cur, (1, elems_per_store,)) + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + # Write final results + for m in cutlass.range(num_rows, unroll_full=True): + if tOhidx[m] >= 0: + mO_cur_copy = cute.tiled_divide(mO_cur[tOmidx[m], None, tOhidx[m]], (elems_per_store,)) + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + k_idx = tOcO[0, 0, k][1] // elems_per_store + if const_expr(self.is_even_k) or tOpO[k]: + cute.copy(gmem_thr_copy_O, rO[None, m, k], mO_cur_copy[None, k_idx]) + + @cute.jit + def load_O_partial( + self, + gmem_tiled_copy_O_partial: cute.TiledCopy, + tOrOptr: cute.Tensor, + tOsO_partial: cute.Tensor, + tOhidx: cute.Tensor, + tOpO: cute.Tensor, + tOcO: cute.Tensor, + mO_cur_partial_layout: cute.Layout, + split: Int32, + stage: Int32, + ) -> None: + elems_per_load = const_expr(cute.size(gmem_tiled_copy_O_partial.layout_tv_tiled[1])) + tOsO_partial_cur = tOsO_partial[None, None, None, stage] + for m in cutlass.range(cute.size(tOcO, [1]), unroll_full=True): + if tOhidx[m] >= 0: + o_gmem_ptr = cute.make_ptr( + tOsO_partial.element_type, tOrOptr[m], cute.AddressSpace.gmem, assumed_align=16 + ) + mO_partial_cur = cute.make_tensor(o_gmem_ptr, cute.slice_(mO_cur_partial_layout, (0, None, None, 0))) + mO_partial_cur_copy = cute.tiled_divide(mO_partial_cur, (elems_per_load,)) + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + k_idx = tOcO[0, 0, k][1] // elems_per_load + if const_expr(self.is_even_k) or tOpO[k]: + cute.copy( + gmem_tiled_copy_O_partial, + # mO_partial_cur_copy[None, k_idx, split], + utils.coord_offset_i64(mO_partial_cur_copy, split, dim=2)[None, k_idx], + tOsO_partial_cur[None, m, k] + ) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 8309a19f89c..186b2190318 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -30,7 +30,7 @@ # import flash_attn.cute.pipeline as pipeline from flash_attn.cute.mask import AttentionMask from flash_attn.cute.softmax import SoftmaxSm100 -from flash_attn.cute.seqlen_info import SeqlenInfo +from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo from flash_attn.cute.pack_gqa import PackGQA from flash_attn.cute import mma_sm100_desc as sm100_desc @@ -674,7 +674,7 @@ def kernel( qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) SeqlenInfoCls = partial( - SeqlenInfo, + SeqlenInfoQK, seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1], seqlen_k_static=mK.shape[0] if const_expr(mPageTable is None) else mK.shape[0] * mPageTable.shape[1], mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 8a54c152185..8d24b5623d2 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -36,6 +36,7 @@ from flash_attn.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80 from flash_attn.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess +from flash_attn.cute.flash_fwd_combine import FlashAttentionForwardCombine def maybe_contiguous(x): @@ -576,3 +577,225 @@ def flash_attn_varlen_func( softcap, pack_gqa, ) + + +def _flash_attn_fwd_combine( + out_partial: torch.Tensor, + lse_partial: torch.Tensor, + out: torch.Tensor, + lse: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + seqused: Optional[torch.Tensor] = None, + num_splits_dynamic_ptr: Optional[torch.Tensor] = None, + semaphore_to_reset: Optional[torch.Tensor] = None, +) -> None: + """Forward combine kernel for split attention computation. + + Combines partial outputs and log-sum-exp values from multiple splits + of attention computation into final outputs. + + Args: + out_partial: Partial outputs tensor (num_splits, batch, seqlen, nheads, headdim) or + (num_splits, total_q, nheads, headdim) if there's cu_seqlens + lse_partial: Partial LSE tensor (num_splits, batch, seqlen, nheads) or + (num_splits, total_q, nheads) if there's cu_seqlens + out: Output tensor (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim) if there's cu_seqlens + lse: Output LSE tensor (batch, seqlen, nheads) or (total_q, nheads) if there's cu_seqlens. + cu_seqlens: Cumulative sequence lengths for variable length sequences + seqused: Used sequence lengths for each batch + num_splits_dynamic_ptr: Dynamic number of splits per batch + semaphore_to_reset: Semaphore for synchronization + k_block_size: Block size for head dimension + + Returns: + None + """ + # Input validation + assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions" + assert lse_partial.dim() in [3, 4], "lse_partial must have 3 or 4 dimensions" + assert out_partial.dtype in [torch.float16, torch.bfloat16, torch.float32], "out_partial must be fp16, bf16, or fp32" + assert lse_partial.dtype == torch.float32, "lse_partial must be fp32" + assert out_partial.is_cuda and lse_partial.is_cuda, "tensors must be on CUDA device" + assert out_partial.stride(-1) == 1, "out_partial must be contiguous in the last dimension" + assert lse_partial.stride(-2) == 1, "lse_partial must be contiguous in the seqlen dimension" + assert lse_partial.shape == out_partial.shape[:-1] + + # Determine if this is variable length based on dimensions + is_varlen = out_partial.dim() == 4 + + # Validate output tensor shapes and types + assert out.shape == out_partial.shape[1:], "out shape mismatch" + if lse is not None: + assert lse.shape == lse_partial.shape[1:], "lse shape mismatch" + assert lse.dtype == torch.float32, "lse must be fp32" + + # Validate optional tensors + for t, name in [(cu_seqlens, "cu_seqlens"), (seqused, "seqused"), (num_splits_dynamic_ptr, "num_splits_dynamic_ptr")]: + if t is not None: + assert t.dtype == torch.int32, f"{name} must be int32" + assert t.is_cuda, f"{name} must be on CUDA device" + assert t.is_contiguous(), f"{name} must be contiguous" + + head_dim = out_partial.shape[-1] + num_splits = out_partial.shape[0] + assert num_splits <= 256 + # If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively + # so that kBlockM is smaller and we have more parallelism. + k_block_size = 64 if head_dim <= 64 else 128 + # We want kBlockM to be as small as possible to maximize parallelism. + # E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats). + m_block_size = 8 if k_block_size % 128 == 0 else (16 if k_block_size % 64 == 0 else 32) + log_max_splits = max(math.ceil(math.log2(num_splits)), 4) + if m_block_size == 8: + # If kBlockM == 8 then the minimum number of splits is 32. + # TODO: we can deal w this by using 128 threads instead + log_max_splits = max(log_max_splits, 5) + + # Convert to cute tensors (using kernel-formatted tensors) + out_partial_tensor = from_dlpack(out_partial.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=4) + lse_partial_tensor = from_dlpack(lse_partial.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse_partial.ndim - 2) + out_tensor = from_dlpack(out.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=3) + lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 2) if lse is not None else None + + optional_tensors = [ + from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) if t is not None else None + for t in (cu_seqlens, seqused, num_splits_dynamic_ptr, semaphore_to_reset) + ] + cu_seqlens_tensor, seqused_tensor, num_splits_dynamic_tensor, semaphore_tensor = optional_tensors + + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + # Create combine kernel configuration + dtype = torch2cute_dtype_map[out.dtype] + dtype_partial = torch2cute_dtype_map[out_partial.dtype] + + compile_key = ( + dtype, dtype_partial, head_dim, m_block_size, k_block_size, + log_max_splits, + cu_seqlens is not None, seqused is not None, lse is not None, + ) + + if compile_key not in _flash_attn_fwd_combine.compile_cache: + fa_combine = FlashAttentionForwardCombine( + dtype=dtype, + dtype_partial=dtype_partial, + head_dim=head_dim, + m_block_size=m_block_size, + k_block_size=k_block_size, + log_max_splits=log_max_splits, + ) + + # Check if implementation is supported + if not fa_combine.can_implement( + dtype, dtype_partial, head_dim, m_block_size, k_block_size, log_max_splits, num_threads=256 + ): + raise RuntimeError(f"FlashAttention combine kernel cannot be implemented with given parameters") + + _flash_attn_fwd_combine.compile_cache[compile_key] = cute.compile( + fa_combine, + out_partial_tensor, + lse_partial_tensor, + out_tensor, + lse_tensor, + cu_seqlens_tensor, + seqused_tensor, + num_splits_dynamic_tensor, + semaphore_tensor, + current_stream + ) + + _flash_attn_fwd_combine.compile_cache[compile_key]( + out_partial_tensor, + lse_partial_tensor, + out_tensor, + lse_tensor, + cu_seqlens_tensor, + seqused_tensor, + num_splits_dynamic_tensor, + semaphore_tensor, + current_stream + ) + + +_flash_attn_fwd_combine.compile_cache = {} + + +def flash_attn_combine( + out_partial: torch.Tensor, + lse_partial: torch.Tensor, + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + return_lse: bool = True, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Flash Attention combine function for split attention computation. + + Combines partial outputs and log-sum-exp values from multiple splits + of attention computation into final outputs. This is the main user-facing + interface for the combine kernel. + + Args: + out_partial: Partial outputs tensor with shape: + - (num_splits, batch_size, seqlen, num_heads, head_size) for regular batched input + - (num_splits, total_q, num_heads, head_size) for variable length input + lse_partial: Partial LSE tensor with shape: + - (num_splits, batch_size, seqlen, num_heads) for regular batched input + - (num_splits, total_q, num_heads) for variable length input + out: Optional output tensor. If None, will be created automatically. + out_dtype: Optional output dtype. If None, will use fp16/bf16 based on input. + return_lse: Whether to return the combined LSE tensor. Default is True. + + Returns: + Tuple of (out, lse) where: + - out: Combined output tensor with shape (batch_size, seqlen, num_heads, head_size) + or (total_q, num_heads, head_size) for varlen + - lse: Combined log-sum-exp tensor with shape (batch_size, seqlen, num_heads) + or (total_q, num_heads) for varlen. None if return_lse=False + + Note: + This function expects the input tensors to be in the format produced by + split attention computation, where the first dimension is num_splits. + The permuting from user format to kernel format is now done inside the kernel. + """ + # Input validation + assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions" + assert lse_partial.dim() in [3, 4], "lse_partial must have 3 or 4 dimensions" + assert out_partial.dtype == torch.float32, "out_partial must be fp32 (from accumulation)" + assert lse_partial.dtype == torch.float32, "lse_partial must be fp32" + + # Determine if this is variable length based on dimensions + is_varlen = out_partial.dim() == 4 + + if is_varlen: + # Variable length: (num_splits, total_q, num_heads, head_size) + num_splits, total_q, num_heads, head_size = out_partial.shape + assert lse_partial.shape == (num_splits, total_q, num_heads), "lse_partial shape mismatch for varlen" + batch_size = 1 # Treat as single batch for varlen + seqlen = total_q + else: + # Regular batched: (num_splits, batch_size, seqlen, num_heads, head_size) + num_splits, batch_size, seqlen, num_heads, head_size = out_partial.shape + assert lse_partial.shape == (num_splits, batch_size, seqlen, num_heads), "lse_partial shape mismatch" + + # Determine output dtype + if out_dtype is None: + out_dtype = out_partial.dtype + + # Create output if not provided + device = out_partial.device + if out is None: + if is_varlen: + out = torch.empty(total_q, num_heads, head_size, dtype=out_dtype, device=device) + else: + out = torch.empty(batch_size, seqlen, num_heads, head_size, dtype=out_dtype, device=device) + + # Create lse output only if requested + if return_lse: + if is_varlen: + lse = torch.empty(num_heads, total_q, dtype=torch.float32, device=device).transpose(0, 1) + else: + lse = torch.empty(batch_size, num_heads, seqlen, dtype=torch.float32, device=device).transpose(1, 2) + else: + lse = None + + _flash_attn_fwd_combine(out_partial, lse_partial, out, lse) + return out, lse diff --git a/flash_attn/cute/seqlen_info.py b/flash_attn/cute/seqlen_info.py index 8d7eb904c8b..dee63db6bf4 100644 --- a/flash_attn/cute/seqlen_info.py +++ b/flash_attn/cute/seqlen_info.py @@ -3,8 +3,30 @@ import cutlass import cutlass.cute as cute +""" +This consolidates all the info related to sequence length. This is so that we can do all +the gmem reads once at the beginning of each tile, rather than having to repeat these reads +to compute various things like n_block_min, n_block_max, etc. +""" class SeqlenInfo: + def __init__( + self, + batch_idx: cutlass.Int32, + seqlen_static: cutlass.Int32, + cu_seqlens: Optional[cute.Tensor] = None, + seqused: Optional[cute.Tensor] = None, + ): + self.offset = 0 if cutlass.const_expr(cu_seqlens is None) else cu_seqlens[batch_idx] + if cutlass.const_expr(seqused is not None): + self.seqlen = seqused[batch_idx] + elif cutlass.const_expr(cu_seqlens is not None): + self.seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx] + else: + self.seqlen = seqlen_static + + +class SeqlenInfoQK: def __init__( self, batch_idx: cutlass.Int32, diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index 58e9d776df2..747d5392c9a 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -555,7 +555,7 @@ def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: # the seqlen can vary per batch. # TODO: is there any case where num_m_blocks is 0? # TODO: by right we should read the seqlen_kv but we're assuming seqlen_q == seqlen_k here - num_n_blocks = num_m_blocks * self.tile_shape_mn[0] // self.tile_shape_mn[1] + num_n_blocks = num_m_blocks * self.tile_shape_mn[0] // self.qhead_per_kvhead_packgqa // self.tile_shape_mn[1] # nheads_in_l2 = min(max(self.max_kvblock_in_l2 // num_n_blocks, 1), self.num_head) # Seems faster to have this be a power of 2 nheads_in_l2 = 16 if num_n_blocks * 16 <= self.max_kvblock_in_l2 else (8 if num_n_blocks * 8 <= self.max_kvblock_in_l2 else (4 if num_n_blocks * 4 <= self.max_kvblock_in_l2 else (2 if num_n_blocks * 2 <= self.max_kvblock_in_l2 else 1))) diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 02e19ad4cda..0a26fc9866f 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -219,6 +219,10 @@ def log2f(a: float | Float32, *, loc=None, ip=None) -> Float32: ) ) +@dsl_user_op +def logf(a: float | Float32, *, loc=None, ip=None) -> Float32: + return log2f(a, loc=loc, ip=ip) * math.log(2.0) + @dsl_user_op def fmax( @@ -352,6 +356,15 @@ def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cut return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip) +@dsl_user_op +def elem_pointer_i64(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer: + flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord)) + flat_stride = cute.flatten_to_tuple(x.stride) + assert len(flat_coord_i64) == len(flat_stride), "Coordinate and stride must have the same length" + offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride)) + return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip) + + @cute.jit def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor: # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if" @@ -529,3 +542,50 @@ def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Flo out0 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [0], loc=loc, ip=ip)) out1 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [1], loc=loc, ip=ip)) return out0, out1 +@dsl_user_op +def domain_offset_aligned(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: + assert isinstance(tensor.iterator, cute.Pointer) + # We assume that applying the offset does not change the pointer alignment + new_ptr = cute.make_ptr( + tensor.element_type, + elem_pointer(tensor, coord).toint(), + tensor.memspace, + assumed_align=tensor.iterator.alignment, + ) + return cute.make_tensor(new_ptr, tensor.layout) + + +@dsl_user_op +def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: + flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord)) + flat_stride = cute.flatten_to_tuple(tensor.stride) + assert len(flat_coord_i64) == len( + flat_stride + ), "Coordinate and stride must have the same length" + offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride)) + assert isinstance(tensor.iterator, cute.Pointer) + # HACK: we assume that applying the offset does not change the pointer alignment + new_ptr = cute.make_ptr( + tensor.element_type, + tensor.iterator.toint() + offset * tensor.element_type.width // 8, + tensor.memspace, + assumed_align=tensor.iterator.max_alignment, + ) + return cute.make_tensor(new_ptr, tensor.layout) + + +@dsl_user_op +def coord_offset_i64( + tensor: cute.Tensor, idx: cute.typing.Int, dim: int, *, loc=None, ip=None +) -> cute.Tensor: + offset = cutlass.Int64(idx) * cute.size(tensor.stride[dim]) + assert isinstance(tensor.iterator, cute.Pointer) + # HACK: we assume that applying the offset does not change the pointer alignment + new_ptr = cute.make_ptr( + tensor.element_type, + tensor.iterator.toint() + offset * tensor.element_type.width // 8, + tensor.memspace, + assumed_align=tensor.iterator.max_alignment, + ) + new_layout = cute.slice_(tensor.layout, (*[None] * dim, 0, *[None] * (cute.rank(tensor) - dim - 1))) + return cute.make_tensor(new_ptr, new_layout) diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 879fd0a2c27..f3042f07635 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -14,7 +14,7 @@ from flash_attn.bert_padding import pad_input, unpad_input from flash_attn.utils.testing import attention_ref, generate_qkv, generate_random_padding_mask -from flash_attn.cute.interface import flash_attn_func, flash_attn_varlen_func +from flash_attn.cute.interface import flash_attn_func, flash_attn_varlen_func, flash_attn_combine # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @@ -976,3 +976,63 @@ def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, dv, de b=batch_size, )[:, :seqlen_k] return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks + + +def attention_combine_ref(out_partial, lse_partial): + """ + out_partial: (num_splits, batch_size, seqlen, nheads, d) + lse_partial: (num_splits, batch_size, seqlen, nheads) + """ + lse = torch.logsumexp(lse_partial, dim=0) + scale = torch.exp(lse_partial - lse) + scale = torch.where(torch.isinf(scale) | torch.isnan(scale), torch.zeros_like(scale), scale) + out = (scale.unsqueeze(-1) * out_partial).sum(0) + return out, lse + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float32]) +# @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +@pytest.mark.parametrize("d", [64, 96, 128, 192, 256, 512]) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("seqlen", [1, 2, 3, 32, 64, 256, 113, 108, 640, 1024]) +# @pytest.mark.parametrize("seqlen", [12, 32, 64, 256, 112, 108, 640, 1024, 2048, 8192]) +# @pytest.mark.parametrize("seqlen", [15]) +@pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 17, 32, 55, 97, 133]) +# @pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 11]) +# @pytest.mark.parametrize("num_splits", [11]) +def test_flash_attn_combine(num_splits, seqlen, d, dtype): + device = "cuda" + # set seed + torch.random.manual_seed(1) + batch_size = 5 + nheads = 16 + # batch_size = 1 + # nheads = 1 + # Create tensors in the expected format: (num_splits, batch_size, seqlen, nheads, d) and (num_splits, batch_size, seqlen, nheads) + out_partial = torch.randn(num_splits * 2, batch_size, nheads, seqlen, d, device=device, dtype=torch.float32).transpose(2, 3)[:num_splits] # To test non-contiguous tensor + lse_partial = torch.randn(num_splits, batch_size, nheads * 2, seqlen, device=device, dtype=torch.float32).transpose(-1, -2)[:, :, :, :nheads] # To test non-contiguous tensor + # To test short-circuiting based on num_splits + lse_partial[num_splits // 2:, :batch_size // 3] = -float("inf") + + # Test with LSE returned (default behavior) + out, lse = flash_attn_combine(out_partial, lse_partial, out_dtype=dtype, return_lse=True) + out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial) + out_pt = out_ref.to(dtype) + + print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + print(f"LSE mean diff: {(lse - lse_ref).abs().mean().item()}") + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # breakpoint() + + assert torch.allclose(lse, lse_ref, atol=1e-5, rtol=1e-5) + multiple = 2 + assert ((out - out_ref).abs().max().item() <= multiple * (out_pt - out_ref).abs().max().item()) or torch.allclose(out, out_pt, atol=1e-5, rtol=1e-5) + + # Test with LSE not returned + out_no_lse, lse_no_lse = flash_attn_combine(out_partial, lse_partial, out_dtype=dtype, return_lse=False) + assert lse_no_lse is None, "LSE should be None when return_lse=False" + assert torch.allclose(out_no_lse, out, atol=1e-5, rtol=1e-5), "Output should be the same regardless of return_lse" From 591dc7eb1c8057ec9ee915cb210edc5d35a03bef Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 17 Aug 2025 12:04:45 -0400 Subject: [PATCH 079/258] [Cute] Simplify tile scheduler storing params --- flash_attn/cute/interface.py | 2 +- flash_attn/cute/tile_scheduler.py | 242 ++++++++---------------------- 2 files changed, 60 insertions(+), 184 deletions(-) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 8d24b5623d2..da7690d9427 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -166,7 +166,7 @@ def _flash_attn_fwd( current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) if compute_capability == 9: # TODO: tune block size according to hdim - if not causal and not local: + if head_dim == head_dim_v == 128 and not causal and not local: n_block_size = 192 if compute_capability == 10: # TODO: fix the varlen case diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index 747d5392c9a..1d7e2dbb32f 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -135,19 +135,8 @@ def create( FastDivmod.create(args.num_block), FastDivmod.create(args.num_head), total_blocks ) - def __init__( - self, - num_block_divmod: FastDivmod, - num_head_divmod: FastDivmod, - total_blocks: Int32, - tile_idx: Int32, - *, - loc=None, - ip=None, - ): - self.num_block_divmod = num_block_divmod - self.num_head_divmod = num_head_divmod - self.total_blocks = total_blocks + def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None): + self.params = params self._tile_idx = tile_idx self._loc = loc self._ip = ip @@ -159,14 +148,7 @@ def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) @staticmethod def create(params: Params, *, loc=None, ip=None) -> "StaticPersistentTileScheduler": tile_idx = cute.arch.block_idx()[0] - return StaticPersistentTileScheduler( - params.num_block_divmod, - params.num_head_divmod, - params.total_blocks, - tile_idx, - loc=loc, - ip=ip, - ) + return StaticPersistentTileScheduler(params, tile_idx, loc=loc, ip=ip) # called by host @staticmethod @@ -182,9 +164,9 @@ def get_grid_shape( # @cute.jit def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: - hn_idx, block_idx = self.num_block_divmod.divmod(self._tile_idx) - batch_idx, head_idx = self.num_head_divmod.divmod(hn_idx) - is_valid = self._tile_idx < self.total_blocks + hn_idx, block_idx = self.params.num_block_divmod.divmod(self._tile_idx) + batch_idx, head_idx = self.params.num_head_divmod.divmod(hn_idx) + is_valid = self._tile_idx < self.params.total_blocks # if cute.arch.thread_idx()[0] == 0: # cute.printf("TileScheduler: tile_idx=%d, hn_idx=%d, block_idx=%d, batch_idx=%d, head_idx=%d, is_valid=%d", self._tile_idx, hn_idx, block_idx, batch_idx, head_idx, is_valid) return cutlass.utils.WorkTileInfo( @@ -202,7 +184,7 @@ def advance_to_next_work(self, *, loc=None, ip=None): def __extract_mlir_values__(self): values, self._values_pos = [], [] - for obj in [self.num_block_divmod, self.num_head_divmod, self.total_blocks, self._tile_idx]: + for obj in [self.params, self._tile_idx]: obj_values = cutlass.extract_mlir_values(obj) values += obj_values self._values_pos.append(len(obj_values)) @@ -210,10 +192,7 @@ def __extract_mlir_values__(self): def __new_from_mlir_values__(self, values): obj_list = [] - for obj, n_items in zip( - [self.num_block_divmod, self.num_head_divmod, self.total_blocks, self._tile_idx], - self._values_pos, - ): + for obj, n_items in zip([self.params, self._tile_idx], self._values_pos,): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] return StaticPersistentTileScheduler(*(tuple(obj_list)), loc=self._loc) @@ -263,27 +242,8 @@ def create( num_hb_quotient=Int32(num_hb_quotient), ) - def __init__( - self, - total_blocks: Int32, - num_block_divmod: FastDivmod, - num_head_divmod: FastDivmod, - l2_minor_divmod: FastDivmod, - l2_major_divmod: FastDivmod, - l2_minor_residual_divmod: FastDivmod, - num_hb_quotient: Int32, - tile_idx: Int32, - *, - loc=None, - ip=None, - ): - self.total_blocks = total_blocks - self.num_block_divmod = num_block_divmod - self.num_head_divmod = num_head_divmod - self.l2_minor_divmod = l2_minor_divmod - self.l2_major_divmod = l2_major_divmod - self.l2_minor_residual_divmod = l2_minor_residual_divmod - self.num_hb_quotient = num_hb_quotient + def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None): + self.params = params self._tile_idx = tile_idx self._loc = loc self._ip = ip @@ -296,18 +256,7 @@ def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) @cute.jit def create(params: Params, *, loc=None, ip=None) -> "SingleTileLPTScheduler": tile_idx = cute.arch.block_idx()[0] - return SingleTileLPTScheduler( - params.total_blocks, - params.num_block_divmod, - params.num_head_divmod, - params.l2_minor_divmod, - params.l2_major_divmod, - params.l2_minor_residual_divmod, - params.num_hb_quotient, - tile_idx, - loc=loc, - ip=ip, - ) + return SingleTileLPTScheduler(params, tile_idx, loc=loc, ip=ip) # called by host @staticmethod @@ -321,20 +270,21 @@ def get_grid_shape( @cute.jit def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + params = self.params # Implement LPT scheduling coordinate calculation - bidhb, l2_mod = self.l2_major_divmod.divmod(self._tile_idx) + bidhb, l2_mod = params.l2_major_divmod.divmod(self._tile_idx) # If we're in the last section (called residual), we don't want to divide by # swizzle. Instead we want to divide by the remainder. block, bidhb_residual = 0, 0 - if bidhb < self.num_hb_quotient: - block, bidhb_residual = self.l2_minor_divmod.divmod(l2_mod) + if bidhb < params.num_hb_quotient: + block, bidhb_residual = params.l2_minor_divmod.divmod(l2_mod) else: - block, bidhb_residual = self.l2_minor_residual_divmod.divmod(l2_mod) - bidhb_actual = bidhb * self.l2_minor_divmod.divisor + bidhb_residual - batch_idx, head_idx = self.num_head_divmod.divmod(bidhb_actual) + block, bidhb_residual = params.l2_minor_residual_divmod.divmod(l2_mod) + bidhb_actual = bidhb * params.l2_minor_divmod.divisor + bidhb_residual + batch_idx, head_idx = params.num_head_divmod.divmod(bidhb_actual) # Longest-processing-time-first - block = self.num_block_divmod.divisor - 1 - block - is_valid = self._tile_idx < self.total_blocks + block = params.num_block_divmod.divisor - 1 - block + is_valid = self._tile_idx < params.total_blocks return cutlass.utils.WorkTileInfo( (Int32(block), Int32(head_idx), Int32(batch_idx)), is_valid ) @@ -347,20 +297,11 @@ def prefetch_next_work(self, *, loc=None, ip=None): def advance_to_next_work(self, *, loc=None, ip=None): # Single tile scheduler - set to invalid tile_idx to indicate no more work - self._tile_idx = self.total_blocks + self._tile_idx = self.params.total_blocks def __extract_mlir_values__(self): values, self._values_pos = [], [] - for obj in [ - self.total_blocks, - self.num_block_divmod, - self.num_head_divmod, - self.l2_minor_divmod, - self.l2_major_divmod, - self.l2_minor_residual_divmod, - self.num_hb_quotient, - self._tile_idx, - ]: + for obj in [self.params, self._tile_idx]: obj_values = cutlass.extract_mlir_values(obj) values += obj_values self._values_pos.append(len(obj_values)) @@ -368,19 +309,7 @@ def __extract_mlir_values__(self): def __new_from_mlir_values__(self, values): obj_list = [] - for obj, n_items in zip( - [ - self.total_blocks, - self.num_block_divmod, - self.num_head_divmod, - self.l2_minor_divmod, - self.l2_major_divmod, - self.l2_minor_residual_divmod, - self.num_hb_quotient, - self._tile_idx, - ], - self._values_pos, - ): + for obj, n_items in zip([self.params, self._tile_idx], self._values_pos): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] return SingleTileLPTScheduler(*(tuple(obj_list)), loc=self._loc) @@ -406,6 +335,9 @@ def create( ) -> "SingleTileVarlenScheduler.Params": size_l2 = 50 * 1024 * 1024 # 50 MB for K & V max_kvblock_in_l2 = size_l2 // ((args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1]) + assert self.mCuSeqlensQ is not None or self.mSeqUsedQ is not None, ( + "At least one of mCuSeqlensQ or mSeqUsedQ must be provided" + ) return SingleTileVarlenScheduler.Params( num_head=args.num_head, num_batch=args.num_batch, @@ -418,32 +350,8 @@ def create( lpt=args.lpt, ) - def __init__( - self, - num_head: Int32, - num_batch: Int32, - max_kvblock_in_l2: Int32, - tile_idx: Int32, - mCuSeqlensQ: Optional[cute.Tensor] = None, - mSeqUsedQ: Optional[cute.Tensor] = None, - tile_shape_mn: cutlass.Constexpr[[int, int]] = (128, 128), - qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1, - lpt: cutlass.Constexpr[bool] = False, - *, - loc=None, - ip=None, - ): - self.num_head = num_head - self.num_batch = num_batch - self.max_kvblock_in_l2 = max_kvblock_in_l2 - self.mCuSeqlensQ = mCuSeqlensQ - self.mSeqUsedQ = mSeqUsedQ - assert self.mCuSeqlensQ is not None or self.mSeqUsedQ is not None, ( - "At least one of mCuSeqlensQ or mSeqUsedQ must be provided" - ) - self.tile_shape_mn = tile_shape_mn - self.qhead_per_kvhead_packgqa = qhead_per_kvhead_packgqa - self.lpt = lpt + def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None): + self.params = params self._tile_idx = tile_idx self._is_first_block = True self._loc = loc @@ -456,19 +364,7 @@ def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) @staticmethod def create(params: Params, *, loc=None, ip=None) -> "SingleTileVarlenScheduler": tile_idx = cute.arch.block_idx()[0] - return SingleTileVarlenScheduler( - params.num_head, - params.num_batch, - params.max_kvblock_in_l2, - tile_idx, - mCuSeqlensQ=params.mCuSeqlensQ, - mSeqUsedQ=params.mSeqUsedQ, - tile_shape_mn=params.tile_shape_mn, - qhead_per_kvhead_packgqa=params.qhead_per_kvhead_packgqa, - lpt=params.lpt, - loc=loc, - ip=ip, - ) + return SingleTileVarlenScheduler(params, tile_idx, loc=loc, ip=ip) # called by host @staticmethod @@ -485,42 +381,44 @@ def get_grid_shape( @cute.jit def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32: + params = self.params batch_idx = lane + bidb_start - if cutlass.const_expr(self.mSeqUsedQ is not None): + if cutlass.const_expr(params.mSeqUsedQ is not None): seqlen = Int32(0) - if batch_idx < self.num_batch: - seqlen = self.mSeqUsedQ[batch_idx] + if batch_idx < params.num_batch: + seqlen = params.mSeqUsedQ[batch_idx] else: - assert self.mCuSeqlensQ is not None + assert params.mCuSeqlensQ is not None cur_cu_seqlen = Int32(0) - if batch_idx <= self.num_batch: - cur_cu_seqlen = self.mCuSeqlensQ[batch_idx] + if batch_idx <= params.num_batch: + cur_cu_seqlen = params.mCuSeqlensQ[batch_idx] next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1) seqlen = next_cu_seqlen - cur_cu_seqlen - if cutlass.const_expr(self.qhead_per_kvhead_packgqa > 1): - seqlen *= self.qhead_per_kvhead_packgqa + if cutlass.const_expr(params.qhead_per_kvhead_packgqa > 1): + seqlen *= params.qhead_per_kvhead_packgqa return ( - cute.ceil_div(seqlen, self.tile_shape_mn[0]) - if batch_idx < self.num_batch and lane < cute.arch.WARP_SIZE - 1 + cute.ceil_div(seqlen, params.tile_shape_mn[0]) + if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1 else Int32(0) ) @cute.jit def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + params = self.params lane_idx = cute.arch.lane_idx() num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=0) num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx) # Total number of blocks for the next 31 batches m_blocks_in_group = cute.arch.shuffle_sync(num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1) # Same for all lanes - group_end_tile = m_blocks_in_group * self.num_head + group_end_tile = m_blocks_in_group * params.num_head # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, num_m_blocks_cumulative = %d, m_blocks_in_group = %d", self._tile_idx, group_end_tile, num_m_blocks, num_m_blocks_cumulative, m_blocks_in_group) block, head_idx, batch_idx = Int32(0), Int32(0), Int32(0) next_tile_idx = self._tile_idx while group_end_tile <= next_tile_idx: batch_idx += cute.arch.WARP_SIZE - 1 - if batch_idx >= self.num_batch: - batch_idx = Int32(self.num_batch) + if batch_idx >= params.num_batch: + batch_idx = Int32(params.num_batch) group_end_tile = next_tile_idx + 1 else: num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=batch_idx) @@ -528,18 +426,18 @@ def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: m_blocks_in_group = cute.arch.shuffle_sync( num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1 ) - group_end_tile += m_blocks_in_group * self.num_head + group_end_tile += m_blocks_in_group * params.num_head is_valid = False - if batch_idx >= self.num_batch: - block, head_idx, batch_idx = Int32(0), Int32(0), Int32(self.num_batch) + if batch_idx >= params.num_batch: + block, head_idx, batch_idx = Int32(0), Int32(0), Int32(params.num_batch) else: - group_start_tile = group_end_tile - m_blocks_in_group * self.num_head + group_start_tile = group_end_tile - m_blocks_in_group * params.num_head # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, batch_idx = %d", self._tile_idx, group_end_tile, num_m_blocks, batch_idx) # The next problem to process is the first one that does not have ending tile position # that is greater than or equal to tile index. batch_idx_in_group = cute.arch.popc( cute.arch.vote_ballot_sync( - group_start_tile + num_m_blocks_cumulative * self.num_head <= next_tile_idx + group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx ) ) batch_idx += batch_idx_in_group @@ -549,22 +447,22 @@ def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: else cute.arch.shuffle_sync(num_m_blocks_cumulative, batch_idx_in_group - 1) ) num_m_blocks = cute.arch.shuffle_sync(num_m_blocks, batch_idx_in_group) - mh_block = next_tile_idx - group_start_tile - num_m_blocks_prev_lane * self.num_head - if cutlass.const_expr(self.lpt): + mh_block = next_tile_idx - group_start_tile - num_m_blocks_prev_lane * params.num_head + if cutlass.const_expr(params.lpt): # This is a version of the SingleTileLPTScheduler, complicated by the fact that # the seqlen can vary per batch. # TODO: is there any case where num_m_blocks is 0? # TODO: by right we should read the seqlen_kv but we're assuming seqlen_q == seqlen_k here - num_n_blocks = num_m_blocks * self.tile_shape_mn[0] // self.qhead_per_kvhead_packgqa // self.tile_shape_mn[1] + num_n_blocks = num_m_blocks * params.tile_shape_mn[0] // params.qhead_per_kvhead_packgqa // params.tile_shape_mn[1] # nheads_in_l2 = min(max(self.max_kvblock_in_l2 // num_n_blocks, 1), self.num_head) # Seems faster to have this be a power of 2 - nheads_in_l2 = 16 if num_n_blocks * 16 <= self.max_kvblock_in_l2 else (8 if num_n_blocks * 8 <= self.max_kvblock_in_l2 else (4 if num_n_blocks * 4 <= self.max_kvblock_in_l2 else (2 if num_n_blocks * 2 <= self.max_kvblock_in_l2 else 1))) - nheads_in_l2 = min(nheads_in_l2, self.num_head) + nheads_in_l2 = 16 if num_n_blocks * 16 <= params.max_kvblock_in_l2 else (8 if num_n_blocks * 8 <= params.max_kvblock_in_l2 else (4 if num_n_blocks * 4 <= params.max_kvblock_in_l2 else (2 if num_n_blocks * 2 <= params.max_kvblock_in_l2 else 1))) + nheads_in_l2 = min(nheads_in_l2, params.num_head) mh_in_l2 = nheads_in_l2 * num_m_blocks section_idx = mh_block // mh_in_l2 l2_mod = mh_block - section_idx * mh_in_l2 # Deal with tail section - nheads_in_this_section = nheads_in_l2 if nheads_in_l2 * (section_idx + 1) <= self.num_head else self.num_head - section_idx * nheads_in_l2 + nheads_in_this_section = nheads_in_l2 if nheads_in_l2 * (section_idx + 1) <= params.num_head else params.num_head - section_idx * nheads_in_l2 block = l2_mod // nheads_in_this_section head_idx_residual = l2_mod - block * nheads_in_this_section head_idx = section_idx * nheads_in_l2 + head_idx_residual @@ -572,7 +470,7 @@ def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: else: head_idx = mh_block // num_m_blocks block = mh_block - head_idx * num_m_blocks - is_valid = self._is_first_block and batch_idx < self.num_batch + is_valid = self._is_first_block and batch_idx < params.num_batch # if cute.arch.thread_idx()[0] == 128: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, head_idx=%d, block=%d, is_valid = %d", self._tile_idx, batch_idx, head_idx, block, is_valid) return cutlass.utils.WorkTileInfo( (Int32(block), Int32(head_idx), Int32(batch_idx)), is_valid @@ -590,14 +488,7 @@ def advance_to_next_work(self, *, loc=None, ip=None): def __extract_mlir_values__(self): values, self._values_pos = [], [] - for obj in [ - self.num_head, - self.num_batch, - self.max_kvblock_in_l2, - self._tile_idx, - self.mCuSeqlensQ, - self.mSeqUsedQ, - ]: + for obj in [self.params, self._tile_idx]: obj_values = cutlass.extract_mlir_values(obj) values += obj_values self._values_pos.append(len(obj_values)) @@ -605,23 +496,8 @@ def __extract_mlir_values__(self): def __new_from_mlir_values__(self, values): obj_list = [] - for obj, n_items in zip( - [ - self.num_head, - self.num_batch, - self.max_kvblock_in_l2, - self._tile_idx, - self.mCuSeqlensQ, - self.mSeqUsedQ, - ], - self._values_pos, + for obj, n_items in zip([self.params, self._tile_idx], self._values_pos, ): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] - return SingleTileVarlenScheduler( - *(tuple(obj_list)), - tile_shape_mn=self.tile_shape_mn, - qhead_per_kvhead_packgqa=self.qhead_per_kvhead_packgqa, - lpt=self.lpt, - loc=self._loc, - ) + return SingleTileVarlenScheduler(*(tuple(obj_list)), loc=self._loc) From f8b4f155c9ecab05561ed915c6fe393f7a1fbfe5 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 17 Aug 2025 12:34:40 -0400 Subject: [PATCH 080/258] [Cute] Implement sink for fwd_sm90 --- flash_attn/cute/flash_fwd.py | 151 +++++++++++++++++++---------------- flash_attn/cute/interface.py | 7 +- flash_attn/cute/softmax.py | 8 +- 3 files changed, 94 insertions(+), 72 deletions(-) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 48a4a3203ff..390a451f5c9 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -14,7 +14,7 @@ import cutlass import cutlass.cute as cute -from cutlass import const_expr +from cutlass import Float32, Int32, const_expr from cutlass.cute.nvgpu import cpasync, warp, warpgroup import cutlass.utils.ampere_helpers as sm80_utils_basic import cutlass.utils.hopper_helpers as sm90_utils_basic @@ -152,15 +152,15 @@ def _check_type( raise TypeError("All tensors must have the same data type") if const_expr(mQ_type not in [cutlass.Float16, cutlass.BFloat16]): raise TypeError("Only Float16 or BFloat16 is supported") - if const_expr(mLSE_type not in [None, cutlass.Float32]): + if const_expr(mLSE_type not in [None, Float32]): raise TypeError("LSE tensor must be Float32") - if const_expr(mCuSeqlensQ_type not in [None, cutlass.Int32]): + if const_expr(mCuSeqlensQ_type not in [None, Int32]): raise TypeError("cu_seqlens_q tensor must be Int32") - if const_expr(mCuSeqlensK_type not in [None, cutlass.Int32]): + if const_expr(mCuSeqlensK_type not in [None, Int32]): raise TypeError("cu_seqlens_k tensor must be Int32") - if const_expr(mSeqUsedQ_type not in [None, cutlass.Int32]): + if const_expr(mSeqUsedQ_type not in [None, Int32]): raise TypeError("seqused_q tensor must be Int32") - if const_expr(mSeqUsedK_type not in [None, cutlass.Int32]): + if const_expr(mSeqUsedK_type not in [None, Int32]): raise TypeError("seqused_k tensor must be Int32") assert mQ_type == self.dtype @@ -255,8 +255,8 @@ def __call__( mV: cute.Tensor, mO: cute.Tensor, mLSE: Optional[cute.Tensor], - softmax_scale: cutlass.Float32, - softcap: cutlass.Float32, + softmax_scale: Float32, + softcap: Float32, stream: cuda.CUstream, ): """Configures and launches the flash attention kernel. @@ -278,10 +278,10 @@ def epilogue( gmem_tiled_copy_O: cute.TiledCopy, tma_atom_O: Optional[cute.CopyAtom], tiled_mma: cute.TiledMma, - tidx: cutlass.Int32, - m_block: cutlass.Int32, - head_idx: cutlass.Int32, - batch_idx: cutlass.Int32, + tidx: Int32, + m_block: Int32, + head_idx: Int32, + batch_idx: Int32, ): # store acc_O rO = cute.make_fragment_like(acc_O, self.dtype) @@ -386,9 +386,9 @@ def load_Q( gmem_thr_copy: cute.TiledCopy, gQ: cute.Tensor, sQ: cute.Tensor, - block: cutlass.Int32, - seqlen: cutlass.Int32, - headdim: cutlass.Int32, + block: Int32, + seqlen: Int32, + headdim: Int32, ): tQsQ, tQgQ = gmem_thr_copy.partition_D(sQ), gmem_thr_copy.partition_S(gQ) cQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) @@ -416,9 +416,9 @@ def load_K( tKcK: cute.Tensor, t0KcK: cute.Tensor, tKpK: cute.Tensor, - block: cutlass.Int32, - smem_pipe_write: cutlass.Int32, - seqlen: cutlass.Int32, + block: Int32, + smem_pipe_write: Int32, + seqlen: Int32, need_predicates: cutlass.Constexpr, ): # Do we need to check if we overshoot kBlockN when we load K? @@ -460,9 +460,9 @@ def load_V( tVcV: cute.Tensor, t0VcV: cute.Tensor, tVpV: cute.Tensor, - block: cutlass.Int32, - smem_pipe_write: cutlass.Int32, - seqlen: cutlass.Int32, + block: Int32, + smem_pipe_write: Int32, + seqlen: Int32, need_predicates: cutlass.Constexpr, ): # Do we need to check if we overshoot kBlockN when we load V? @@ -506,12 +506,12 @@ def _get_smem_layout_atom(self): def _get_tiled_mma(self): tiled_mma_qk = cute.make_tiled_mma( - warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)), + warp.MmaF16BF16Op(self.dtype, Float32, (16, 8, 16)), (self.num_threads // 32, 1, 1), permutation_mnk=(self.num_threads // 32 * 16, 16, 16), ) tiled_mma_pv = cute.make_tiled_mma( - warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)), + warp.MmaF16BF16Op(self.dtype, Float32, (16, 8, 16)), (self.num_threads // 32, 1, 1), permutation_mnk=(self.num_threads // 32 * 16, 16, 16), ) @@ -547,10 +547,10 @@ def __call__( mO: cute.Tensor, mLSE: Optional[cute.Tensor], stream: cuda.CUstream, - softmax_scale: Optional[cutlass.Float32] = None, - softcap: Optional[cutlass.Float32] = None, - window_size_left: Optional[cutlass.Int32] = None, - window_size_right: Optional[cutlass.Int32] = None, + softmax_scale: Optional[Float32] = None, + softcap: Optional[Float32] = None, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, learnable_sink: Optional[cute.Tensor] = None, ): """Configures and launches the flash attention kernel. @@ -591,7 +591,7 @@ def __call__( softcap_val = None else: softmax_scale_log2 = softcap * LOG2_E - softcap_val = cutlass.Float32(softmax_scale / softcap) + softcap_val = Float32(softmax_scale / softcap) self.kernel( mQ, mK, @@ -629,10 +629,10 @@ def kernel( mV: cute.Tensor, mO: cute.Tensor, mLSE: Optional[cute.Tensor], - softmax_scale_log2: cutlass.Float32, - softcap_val: Optional[cutlass.Float32], - window_size_left: cutlass.Int32, - window_size_right: cutlass.Int32, + softmax_scale_log2: Float32, + softcap_val: Optional[Float32], + window_size_left: Int32, + window_size_right: Int32, sQ_layout: cute.ComposedLayout, sK_layout: cute.ComposedLayout, sV_layout: cute.ComposedLayout, @@ -704,7 +704,7 @@ def kernel( tSrK = thr_mma_qk.make_fragment_B(thr_mma_qk.partition_B(sK[None, None, 0])) tOrVt = thr_mma_pv.make_fragment_B(thr_mma_pv.partition_B(sVt[None, None, 0])) acc_shape_O = thr_mma_pv.partition_shape_C((self.m_block_size, self.head_dim_v_padded)) - acc_O = cute.make_fragment(acc_shape_O, cutlass.Float32) + acc_O = cute.make_fragment(acc_shape_O, Float32) acc_O.fill(0.0) # /////////////////////////////////////////////////////////////////////////////// @@ -833,8 +833,8 @@ def preprocess_Q(): ) # First iteration with seqlen masking - smem_pipe_read = cutlass.Int32(0) - smem_pipe_write = cutlass.Int32(self.num_stages - 1) + smem_pipe_read = Int32(0) + smem_pipe_write = Int32(self.num_stages - 1) compute_one_n_block(n_block, smem_pipe_read, smem_pipe_write, is_first_n_block=True, check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=True)) smem_pipe_read = self.advance_pipeline(smem_pipe_read) @@ -874,9 +874,9 @@ def preprocess_Q(): @cute.jit def compute_one_n_block( self, - n_block: cutlass.Int32, - smem_pipe_read: cutlass.Int32, - smem_pipe_write: cutlass.Int32, + n_block: Int32, + smem_pipe_read: Int32, + smem_pipe_write: Int32, mma_params: SimpleNamespace, smem_copy_params: SimpleNamespace, softmax: Softmax, @@ -897,7 +897,7 @@ def sync(): cute.arch.barrier() acc_shape_S = mma_params.thr_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)) - acc_S = cute.make_fragment(acc_shape_S, cutlass.Float32) + acc_S = cute.make_fragment(acc_shape_S, Float32) acc_S.fill(0.0) # wait for smem tile QK before mma calculation for S sync() @@ -987,7 +987,7 @@ def _get_tiled_mma(self): self.dtype, warpgroup.OperandMajorMode.K, warpgroup.OperandMajorMode.K, - cutlass.Float32, + Float32, atom_layout_mnk=(self.m_block_size // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 tiler_mn=(64, self.n_block_size), ) @@ -996,7 +996,7 @@ def _get_tiled_mma(self): self.dtype, warpgroup.OperandMajorMode.K, warpgroup.OperandMajorMode.MN, - cutlass.Float32, + Float32, atom_layout_mnk=(self.m_block_size // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 tiler_mn=(64, self.head_dim_v_padded), a_source=warpgroup.OperandSource.RMEM if self.mma_pv_is_rs else warpgroup.OperandSource.SMEM, @@ -1006,7 +1006,7 @@ def _get_tiled_mma(self): self.dtype, warpgroup.OperandMajorMode.K, warpgroup.OperandMajorMode.MN, - cutlass.Float32, + Float32, atom_layout_mnk=(self.m_block_size // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 tiler_mn=(64, self.head_dim_v_padded), a_source=warpgroup.OperandSource.RMEM @@ -1063,16 +1063,16 @@ def __call__( mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q mLSE: Optional[cute.Tensor], - softmax_scale: cutlass.Float32, + softmax_scale: Float32, stream: cuda.CUstream, mCuSeqlensQ: Optional[cute.Tensor] = None, mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, mPageTable: Optional[cute.Tensor] = None, # (b_k, max_num_pages_per_seq) - softcap: cutlass.Float32 | float | None = None, - window_size_left: cutlass.Int32 | int | None = None, - window_size_right: cutlass.Int32 | int | None = None, + softcap: Float32 | float | None = None, + window_size_left: Int32 | int | None = None, + window_size_right: Int32 | int | None = None, learnable_sink: Optional[cute.Tensor] = None, ): """Configures and launches the flash attention kernel. @@ -1080,7 +1080,6 @@ def __call__( mQ/mK/mV/mO has same data types(supports fp16 and bf16) and same layout: (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1) """ - assert learnable_sink is None, "Learnable sink is not supported in this kernel" self._check_type( *(t.element_type if t is not None else None for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)) @@ -1191,11 +1190,11 @@ def __call__( softcap_val = None else: softmax_scale_log2 = softcap * LOG2_E - softcap_val = cutlass.Float32(softmax_scale / softcap) + softcap_val = Float32(softmax_scale / softcap) if const_expr(window_size_left is not None): - window_size_left = cutlass.Int32(window_size_left) + window_size_left = Int32(window_size_left) if const_expr(window_size_right is not None): - window_size_right = cutlass.Int32(window_size_right) + window_size_right = Int32(window_size_right) self.kernel( tma_tensor_Q if const_expr(not self.pack_gqa) else mQ, tma_tensor_K, @@ -1214,6 +1213,7 @@ def __call__( softcap_val, window_size_left, window_size_right, + learnable_sink, self.sQ_layout, self.sK_layout, self.sV_layout, @@ -1253,10 +1253,11 @@ def kernel( tma_atom_K: Optional[cute.CopyAtom], tma_atom_V: Optional[cute.CopyAtom], tma_atom_O: Optional[cute.CopyAtom], - softmax_scale_log2: cutlass.Float32, - softcap_val: Optional[cutlass.Float32], - window_size_left: Optional[cutlass.Int32], - window_size_right: Optional[cutlass.Int32], + softmax_scale_log2: Float32, + softcap_val: Optional[Float32], + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], + learnable_sink: Optional[cute.Tensor], sQ_layout: cute.ComposedLayout, sK_layout: cute.ComposedLayout, sV_layout: cute.ComposedLayout, @@ -1394,6 +1395,7 @@ def kernel( sVt, sP, sO, + learnable_sink, pipeline_k, pipeline_v, mbar_ptr_Q, @@ -1430,7 +1432,7 @@ def load( ): warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 if warp_idx_in_wg == 0: - q_producer_phase = cutlass.Int32(1) + q_producer_phase = Int32(1) kv_producer_state = pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Producer, self.num_stages ) @@ -1514,15 +1516,16 @@ def mma( sVt: cute.Tensor, sP: Optional[cute.Tensor], sO: cute.Tensor, + learnable_sink: Optional[cute.Tensor], pipeline_k: cutlass.pipeline.PipelineAsync, pipeline_v: cutlass.pipeline.PipelineAsync, mbar_ptr_Q: cutlass.Pointer, gmem_tiled_copy_Q: cute.TiledCopy, gmem_tiled_copy_O: cute.TiledCopy, tma_atom_O: Optional[cute.CopyAtom], - tidx: cutlass.Int32, - softmax_scale_log2: cutlass.Float32, - softcap_val: cutlass.Float32, + tidx: Int32, + softmax_scale_log2: Float32, + softcap_val: Float32, block_info: BlockInfo, SeqlenInfoCls: Callable, AttentionMaskCls: Callable, @@ -1561,7 +1564,7 @@ def mma( self.mma_init() acc_shape_O = tiled_mma_pv.partition_shape_C((self.m_block_size, self.head_dim_v_padded)) - acc_O = cute.make_fragment(acc_shape_O, cutlass.Float32) + acc_O = cute.make_fragment(acc_shape_O, Float32) # group parameters for mma_one_n_block mma_params = SimpleNamespace(tSrQ=tSrQ, tSrK=tSrK, tOrP=tOrP, tOrVt=tOrVt, acc_O=acc_O) smem_copy_params = SimpleNamespace(smem_thr_copy_P=smem_thr_copy_P, tPsP=tPsP) @@ -1574,7 +1577,7 @@ def mma( check_inf=True, ) - q_consumer_phase = cutlass.Int32(0) + q_consumer_phase = Int32(0) kv_consumer_state = pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.num_stages ) @@ -1629,7 +1632,7 @@ def scoremod_premask_fn(acc_S): # First iteration with seqlen masking if const_expr(self.intra_wg_overlap): acc_S = cute.make_fragment( - tiled_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)), cutlass.Float32 + tiled_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)), Float32 ) pipeline_k.consumer_wait(kv_consumer_state) sm90_utils.gemm( @@ -1716,7 +1719,21 @@ def scoremod_premask_fn(acc_S): self.warp_scheduler_barrier_arrive() # normalize acc_O by row_sum and calculate the lse - row_scale = softmax.finalize() + if const_expr(learnable_sink is not None): + if const_expr(not self.pack_gqa): + sink_val = Float32(learnable_sink[head_idx]) + else: # Each thread might have a different sink value due to different q_head + sink_val = cute.make_fragment_like(softmax.row_max, Float32) + cS = cute.make_identity_tensor((self.m_block_size, self.n_block_size)) + tScS_mn = utils.make_acc_tensor_mn_view(thr_mma_qk.partition_C(cS)) + for r in cutlass.range(cute.size(sink_val), unroll_full=True): + row = m_block * self.m_block_size + tScS_mn[r][0] + q_head_idx = row % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead + sink_val[r] = Float32(learnable_sink[q_head_idx]) + else: + sink_val = None + + row_scale = softmax.finalize(sink_val=sink_val) softmax.rescale_O(acc_O, row_scale) # /////////////////////////////////////////////////////////////////////////////// @@ -1733,7 +1750,7 @@ def scoremod_premask_fn(acc_S): @cute.jit def mma_one_n_block( self, - n_block: cutlass.Int32, + n_block: Int32, smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, @@ -1750,7 +1767,7 @@ def mma_one_n_block( O_should_accumulate: cutlass.Boolean = True, ): acc_S = cute.make_fragment( - tiled_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)), cutlass.Float32 + tiled_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)), Float32 ) pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read)) sm90_utils.gemm( @@ -1792,7 +1809,7 @@ def mma_one_n_block( @cute.jit def mma_one_n_block_intrawg_overlap( self, - n_block: cutlass.Int32, + n_block: Int32, smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, @@ -1810,7 +1827,7 @@ def mma_one_n_block_intrawg_overlap( smem_pipe_read_v = smem_pipe_read.clone() smem_pipe_read.advance() acc_S = cute.make_fragment( - tiled_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)), cutlass.Float32 + tiled_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)), Float32 ) pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read)) self.warp_scheduler_barrier_sync() @@ -1884,7 +1901,7 @@ def load_K( tKgK: cute.Tensor, tKsK: cute.Tensor, pipeline: cutlass.pipeline.PipelineAsync, - block: cutlass.Int32, + block: Int32, producer_state: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, ): # TODO: mcast diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index da7690d9427..b02d1e91be6 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -148,7 +148,7 @@ def _flash_attn_fwd( for t in (q, k, v, out) ] lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 1) if lse is not None else None - cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, additive_sink_tensor = [ + cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, learnable_sink_tensor = [ from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) if t is not None else None for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink) ] @@ -185,7 +185,6 @@ def _flash_attn_fwd( if compile_key not in _flash_attn_fwd.compile_cache: if compute_capability == 9: assert page_table is None, "paged KV not supported on SM 9.0" - assert learnable_sink is None, "Sm90 doesn't support additive sink" # fa_fwd = FlashAttentionForwardSm80( fa_fwd = FlashAttentionForwardSm90( dtype, @@ -220,13 +219,13 @@ def _flash_attn_fwd( fa_fwd, q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream, cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, page_table_tensor, - softcap, window_size_left, window_size_right, additive_sink_tensor, + softcap, window_size_left, window_size_right, learnable_sink_tensor, ) _flash_attn_fwd.compile_cache[compile_key]( q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream, cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, page_table_tensor, - softcap, window_size_left, window_size_right, additive_sink_tensor, + softcap, window_size_left, window_size_right, learnable_sink_tensor, ) return out, lse diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index e0407e99cdf..6d8135d6461 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -84,12 +84,18 @@ def online_softmax( return row_scale @cute.jit - def finalize(self, final_scale: Float32 = 1.0) -> cute.Tensor: + def finalize(self, final_scale: Float32 = 1.0, sink_val: Float32 | cute.Tensor | None = None) -> cute.Tensor: """Finalize the online softmax by computing the scale and logsumexp.""" + if cutlass.const_expr(sink_val is not None and isinstance(sink_val, cute.Tensor)): + assert cute.size(sink_val) == cute.size(self.row_sum) # quad reduction for row_sum as we didn't do it during each iteration of online softmax self.row_sum.store(utils.warp_reduce(self.row_sum.load(), operator.add, width=4)) row_scale = cute.make_fragment_like(self.row_max, Float32) for r in cutlass.range(cute.size(self.row_sum), unroll_full=True): + if cutlass.const_expr(sink_val is not None): + sink_val_cur = sink_val if not isinstance(sink_val, cute.Tensor) else sink_val[r] + LOG2_E = math.log2(math.e) + self.row_sum[r] += utils.exp2f(sink_val_cur * LOG2_E - self.row_max[r] * self.scale_log2) # if row_sum is zero or nan, set acc_O_mn_row to 1.0 acc_O_mn_row_is_zero_or_nan = ( self.row_sum[r] == 0.0 or self.row_sum[r] != self.row_sum[r] From e1407dbe3f2025cda014ffce211c7f3b376c6c5b Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 17 Aug 2025 12:52:09 -0400 Subject: [PATCH 081/258] [Cute] Implement PackGQA with TMA for fwd_sm90 --- flash_attn/cute/flash_fwd.py | 55 +++++++++++++++++-------------- flash_attn/cute/tile_scheduler.py | 2 +- 2 files changed, 31 insertions(+), 26 deletions(-) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 390a451f5c9..de5fea43b99 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -1014,8 +1014,8 @@ def _get_tiled_mma(self): return tiled_mma_qk, tiled_mma_pv, tiled_mma_pv_rs def _get_shared_storage_cls(self): - # If PackGQA, we use cp.async to load Q, so we want sQ to align to 1024 bytes - sQ_alignment = 128 if const_expr(not self.pack_gqa) else 1024 + # If we use cp.async to load Q, we want sQ to align to 1024 bytes + sQ_alignment = 128 if const_expr(self.use_tma_Q) else 1024 sK_alignment = 128 sV_alignment = 128 sQ_struct, sK_struct, sV_struct = [ @@ -1104,17 +1104,31 @@ def __call__( self.num_threads_per_warp_group = 128 self.num_mma_warp_groups = self.num_mma_threads // self.num_threads_per_warp_group self.num_producer_threads = 32 - self.num_Q_load_threads = self.num_mma_threads # If PackGQA, MMA threads load Q + self.num_Q_load_threads = self.num_mma_threads # If not TMA_Q, MMA threads load Q self.num_epilogue_threads = self.num_mma_threads self.num_mma_regs = 240 self.num_producer_regs = 24 # self.num_mma_regs = 232 # self.num_producer_regs = 40 self.use_scheduler_barrier = (self.num_mma_warp_groups >= 2 and self.head_dim_padded <= 128) if const_expr(self.intra_wg_overlap) else (self.num_mma_warp_groups == 2) + self.use_tma_Q = self.arch >= 90 and not (self.pack_gqa and self.m_block_size % self.qhead_per_kvhead != 0) self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None and not self.pack_gqa # TODO: rescale_O_before_gemm self._setup_attributes() SharedStorage = self._get_shared_storage_cls() + + if const_expr(self.pack_gqa): + shape_Q_packed = ((self.qhead_per_kvhead, mQ.shape[0]), mQ.shape[1], mK.shape[2], *mQ.shape[3:]) + stride_Q_packed = ((mQ.stride[2], mQ.stride[0]), mQ.stride[1], mQ.stride[2] * self.qhead_per_kvhead, *mQ.stride[3:]) + mQ = cute.make_tensor(mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed)) + shape_O_packed = ((self.qhead_per_kvhead, mO.shape[0]), mK.shape[1], mK.shape[2], *mO.shape[3:]) + stride_O_packed = ((mO.stride[2], mO.stride[0]), mO.stride[1], mO.stride[2] * self.qhead_per_kvhead, *mO.stride[3:]) + mO = cute.make_tensor(mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed)) + if const_expr(mLSE is not None): + shape_LSE_packed = ((self.qhead_per_kvhead, mLSE.shape[0]), mK.shape[2], *mLSE.shape[2:]) + stride_LSE_packed = ((mLSE.stride[1], mLSE.stride[0]), mLSE.stride[1] * self.qhead_per_kvhead, *mLSE.stride[2:]) + mLSE = cute.make_tensor(mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed)) + # TMA gmem_tiled_copy_Q = cpasync.CopyBulkTensorTileG2SOp() gmem_tiled_copy_KV = cpasync.CopyBulkTensorTileG2SOp() # Might multicast @@ -1122,9 +1136,12 @@ def __call__( self.tma_copy_q_bytes = cute.size_in_bytes(mQ.element_type, cute.select(self.sQ_layout, mode=[0, 1])) self.tma_copy_k_bytes = cute.size_in_bytes(mK.element_type, cute.select(self.sK_layout, mode=[0, 1])) self.tma_copy_v_bytes = cute.size_in_bytes(mV.element_type, cute.select(self.sV_layout, mode=[0, 1])) - tma_atom_Q, tma_tensor_Q = cpasync.make_tiled_tma_atom( - gmem_tiled_copy_Q, mQ, self.sQ_layout, (self.m_block_size, self.head_dim_padded), # No mcast - ) + if const_expr(self.use_tma_Q): + tma_atom_Q, tma_tensor_Q = cpasync.make_tiled_tma_atom( + gmem_tiled_copy_Q, mQ, self.sQ_layout, (self.m_block_size, self.head_dim_padded), # No mcast + ) + else: + tma_atom_Q, tma_tensor_Q = None, None tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom( gmem_tiled_copy_KV, mK, @@ -1145,18 +1162,6 @@ def __call__( ) else: tma_atom_O = None - if const_expr(self.pack_gqa): - shape_Q_packed = ((self.qhead_per_kvhead, mQ.shape[0]), mQ.shape[1], mK.shape[2], *mQ.shape[3:]) - stride_Q_packed = ((mQ.stride[2], mQ.stride[0]), mQ.stride[1], mQ.stride[2] * self.qhead_per_kvhead, *mQ.stride[3:]) - mQ = cute.make_tensor(mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed)) - shape_O_packed = ((self.qhead_per_kvhead, mO.shape[0]), mK.shape[1], mK.shape[2], *mO.shape[3:]) - stride_O_packed = ((mO.stride[2], mO.stride[0]), mO.stride[1], mO.stride[2] * self.qhead_per_kvhead, *mO.stride[3:]) - mO = cute.make_tensor(mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed)) - if const_expr(mLSE is not None): - shape_LSE_packed = ((self.qhead_per_kvhead, mLSE.shape[0]), mK.shape[2], *mLSE.shape[2:]) - stride_LSE_packed = ((mLSE.stride[1], mLSE.stride[0]), mLSE.stride[1] * self.qhead_per_kvhead, *mLSE.stride[2:]) - mLSE = cute.make_tensor(mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed)) - if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): TileScheduler = SingleTileVarlenScheduler else: @@ -1196,7 +1201,7 @@ def __call__( if const_expr(window_size_right is not None): window_size_right = Int32(window_size_right) self.kernel( - tma_tensor_Q if const_expr(not self.pack_gqa) else mQ, + tma_tensor_Q if const_expr(self.use_tma_Q) else mQ, tma_tensor_K, tma_tensor_V, mO, @@ -1277,7 +1282,7 @@ def kernel( warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) # Prefetch tma descriptor if warp_idx == 0: - if const_expr(not self.pack_gqa): + if const_expr(tma_atom_Q is not None): cpasync.prefetch_descriptor(tma_atom_Q) cpasync.prefetch_descriptor(tma_atom_K) cpasync.prefetch_descriptor(tma_atom_V) @@ -1293,7 +1298,7 @@ def kernel( # if tidx < 2: # # barrierO num threads should be self.num_mma_threads # cute.arch.mbarrier_init(mbar_ptr_Q + tidx, 1 if tidx == 0 else self.num_mma_threads) - cute.arch.mbarrier_init(mbar_ptr_Q, 1 if const_expr(not self.pack_gqa) else self.num_Q_load_threads) + cute.arch.mbarrier_init(mbar_ptr_Q, 1 if const_expr(self.use_tma_Q) else self.num_Q_load_threads) # cute.arch.mbarrier_init(mbar_ptr_Q + 1, self.num_mma_threads) # We rely on pipeline_k and pipeline_v to initialize the mbarrier fence and sync pipeline_kv_producer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread) @@ -1454,7 +1459,7 @@ def load( mK_cur, mV_cur = [cute.domain_offset((seqlen.offset_k, 0), t[None, None, head_idx_kv]) for t in (mK, mV)] gK = cute.local_tile(mK_cur, (self.n_block_size, self.head_dim_padded), (None, 0)) gV = cute.local_tile(mV_cur, (self.n_block_size, self.head_dim_v_padded), (None, 0)) - if const_expr(not self.pack_gqa): + if const_expr(self.use_tma_Q): gQ = cute.local_tile(mQ_cur, (self.m_block_size, self.head_dim_padded), (m_block, 0)) tQsQ, tQgQ = cpasync.tma_partition( tma_atom_Q, @@ -1480,7 +1485,7 @@ def load( load_K = partial(self.load_K, tma_atom_K, tKgK, tKsK, pipeline_k) load_V = partial(self.load_K, tma_atom_V, tVgV, tVsV, pipeline_v) # load_Q - if const_expr(not self.pack_gqa): + if const_expr(self.use_tma_Q): # TODO: wait for Q to be empty q_producer_phase ^= 1 with cute.arch.elect_one(): @@ -1606,8 +1611,8 @@ def scoremod_premask_fn(acc_S): mask_causal=self.is_causal, mask_local=self.is_local, ) softmax.reset() - # Load Q if PackGQA - if const_expr(self.pack_gqa): + # Load Q if not TMA_Q + if const_expr(not self.use_tma_Q): pack_gqa = PackGQA(self.m_block_size, self.head_dim_padded, self.check_hdim_oob, self.qhead_per_kvhead) if const_expr(not seqlen.has_cu_seqlens_q): mQ_cur = mQ[None, None, head_idx, batch_idx] diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index 1d7e2dbb32f..bea4496ecc2 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -335,7 +335,7 @@ def create( ) -> "SingleTileVarlenScheduler.Params": size_l2 = 50 * 1024 * 1024 # 50 MB for K & V max_kvblock_in_l2 = size_l2 // ((args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1]) - assert self.mCuSeqlensQ is not None or self.mSeqUsedQ is not None, ( + assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, ( "At least one of mCuSeqlensQ or mSeqUsedQ must be provided" ) return SingleTileVarlenScheduler.Params( From 0e60e39473e8df549a20fb5353760f7a65b30e2d Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 17 Aug 2025 15:46:21 -0400 Subject: [PATCH 082/258] [Cute] Use R2P for masking in fwd_sm90 Actually doesn't seem to make it faster --- flash_attn/cute/mask.py | 87 ++++++++++++++++++++++++----------------- 1 file changed, 51 insertions(+), 36 deletions(-) diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index d5cb09db7b4..28c019db7b3 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -41,13 +41,26 @@ def apply_mask( seqlenk_col_limit = self.seqlen_k - n_block * self.n_block_size - thr_col_offset if cutlass.const_expr(not mask_causal and not mask_local): if cutlass.const_expr(mask_seqlen): - # traverse column index. - for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): - # if t0ScS_mn[0, c][1] >= seqlenk_col_limit: - # acc_S_mn[None, c].fill(-cutlass.Float32.inf) - oob = t0ScS_mn[0, c][1] >= seqlenk_col_limit - for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): - acc_S_mn[r, c] = -cutlass.Float32.inf if oob else acc_S_mn[r, c] + if cutlass.const_expr(False): + # traverse column index. + for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): + oob = t0ScS_mn[0, c][1] >= seqlenk_col_limit + for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): + acc_S_mn[r, c] = -cutlass.Float32.inf if oob else acc_S_mn[r, c] + else: # R2P trick, see apply_mask_sm100 + # Instead of comparing limit to 0, 1, 8, 9, 16, 17, ..., + # we compare a transformed version of limit to 0, 1, 2, 3, 4, 5, ... + # This is so that we can use the R2P instruction. + col_limit_transformed = seqlenk_col_limit // 8 * 2 + min(seqlenk_col_limit % 8, 2) + ncol = cutlass.const_expr(cute.size(tScS_mn.shape[1])) + for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)): + col_limit_right_s = max(col_limit_transformed - s * 24, 0) + mask = (1 << col_limit_right_s) - 1 + for i in cutlass.range_constexpr(min(24, ncol - s * 24)): + in_bound = cutlass.Boolean(mask & (1 << i)) + c = s * 24 + i + for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): + acc_S_mn[r, c] = acc_S_mn[r, c] if in_bound else -cutlass.Float32.inf else: # Causal or local # If PackGQA, we split the work of compute divmod among threads in the same row threads_per_row = thr_mma.tv_layout_C.shape[0][0] @@ -75,12 +88,20 @@ def apply_mask( col_limit_right = row_idx + causal_row_offset if cutlass.const_expr(mask_seqlen): col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) - # traverse column index. - for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): - # only consider the column index, so the row index sets to 0. - # if t0ScS_mn[0, c][1] >= col_limit_right: - # acc_S_mn[r, c] = -cutlass.Float32.inf - acc_S_mn[r, c] = -cutlass.Float32.inf if t0ScS_mn[0, c][1] >= col_limit_right else acc_S_mn[r, c] + if cutlass.const_expr(False): + # traverse column index. + for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): + acc_S_mn[r, c] = -cutlass.Float32.inf if t0ScS_mn[0, c][1] >= col_limit_right else acc_S_mn[r, c] + else: # R2P trick, see apply_mask_sm100 + col_limit_transformed = col_limit_right // 8 * 2 + min(col_limit_right % 8, 2) + ncol = cutlass.const_expr(cute.size(tScS_mn.shape[1])) + for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)): + col_limit_right_s = max(col_limit_transformed - s * 24, 0) + mask = (1 << col_limit_right_s) - 1 + for i in cutlass.range_constexpr(min(24, ncol - s * 24)): + in_bound = cutlass.Boolean(mask & (1 << i)) + c = s * 24 + i + acc_S_mn[r, c] = acc_S_mn[r, c] if in_bound else -cutlass.Float32.inf else: # Local local_row_offset_right = ( causal_row_offset + self.window_size_right @@ -136,7 +157,7 @@ def apply_mask_sm100( if cutlass.const_expr(not mask_causal and not mask_local): if cutlass.const_expr(mask_seqlen): ncol = cutlass.const_expr(cute.size(tScS_t2r.shape)) - if cutlass.const_expr(not ncol % 16 == 0): + if cutlass.const_expr(False): for i in cutlass.range(ncol, unroll_full=True): # if tScS_t2r[i][1] >= seqlenk_col_limit: # acc_S[i] = -cutlass.Float32.inf @@ -147,28 +168,25 @@ def apply_mask_sm100( else: # Bit manipulation, compiles down to the R2P instruction # We know that tScS_t2r[i][1] == i, for the particular tmem copy atom we're using - # Ideally we'd move by 32 instead of 16, but mask >> i isn't correct for i == 31 + # Ideally we'd move by 32 instead of 24, but mask >> i isn't correct for i == 31 # (see below). - for s in cutlass.range(ncol // 16, unroll_full=True): - col_limit_right_s = seqlenk_col_limit - s * 16 + for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)): # Don't need to clamp to 32 since the shr.u32 instruction does that already - col_limit_right_cur = cutlass.Uint32(max(col_limit_right_s, 0)) + col_limit_right_s = max(seqlenk_col_limit - s * 24, 0) # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11 - mask = cutlass.Uint32((1 << col_limit_right_cur) - 1) - # if tidx == 0: cute.printf("mask = 0x%x, col_limit_right_s = %d, col_limit_right_cur = %d", mask, col_limit_right_s, col_limit_right_cur) + mask = (1 << col_limit_right_s) - 1 + # if tidx == 0: cute.printf("mask = 0x%x, col_limit_right_s = %d, col_limit_right_s = %d", mask, col_limit_right_s, col_limit_right_s) # This needs to be range_constexpr, otherwise the compiler can't generate # the R2P instruction - for i in cutlass.range_constexpr(16): + for i in cutlass.range_constexpr(min(24, ncol - s * 24)): # mask >> i does not produce correct result for 0b11..11 >> 31 # However, if we use utils.shr_u32, the compiler doesn't generate # the R2P instruction, so it's slower. - # Instead we just move by 16 instead of 32. - mask_i_bit = cutlass.Boolean(mask & (1 << i)) - # mask_i_bit = cutlass.Boolean(utils.shr_u32(mask, i) & 1) + # Instead we just move by 24 instead of 32. # if tidx == 0: cute.printf("mask_i_bit = %d, after shift = 0x%x, i = %d, s = %d", mask_i_bit, utils.shr_u32(mask, i), i, s) - acc_S[s * 16 + i] = acc_S[s * 16 + i] if mask_i_bit else -cutlass.Float32.inf + acc_S[s * 24 + i] = acc_S[s * 24 + i] if cutlass.Boolean(mask & (1 << i)) else -cutlass.Float32.inf # This is the equivalent of: - # acc_S[s * 16 + i] = acc_S[s * 16 + i] if col_limit_right_s <= i else -cutlass.Float32.inf + # acc_S[s * 24 + i] = acc_S[s * 24 + i] if col_limit_right_s <= i else -cutlass.Float32.inf # if tidx == 0: cute.print_tensor(acc_S) else: # Causal or local causal_row_offset = 1 + self.seqlen_k - n_block * self.n_block_size - self.seqlen_q @@ -182,7 +200,7 @@ def apply_mask_sm100( # if cute.arch.thread_idx()[0] % 32 == 0: # cute.printf("tidx = %d, tidx tmem = %d, row_idx = %d, col_limit_right = %d, causal_row_offset = %d\n", cute.arch.thread_idx()[0], thr_tmem_load.thr_idx, row_idx, col_limit_right, causal_row_offset) ncol = cutlass.const_expr(cute.size(tScS_t2r.shape)) - if cutlass.const_expr(not ncol % 16 == 0): + if cutlass.const_expr(False): for i in cutlass.range(ncol, unroll_full=True): acc_S[i] = ( -cutlass.Float32.inf if tScS_t2r[i][1] >= col_limit_right else acc_S[i] @@ -190,19 +208,16 @@ def apply_mask_sm100( else: # Bit manipulation, compiles down to the R2P instruction # We know that tScS_t2r[i][1] == i, for the particular tmem copy atom we're using - for s in cutlass.range(ncol // 16, unroll_full=True): - col_limit_right_s = col_limit_right - s * 16 - col_limit_right_cur = cutlass.Uint32(max(col_limit_right_s, 0)) + for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)): + col_limit_right_s = max(col_limit_right - s * 24, 0) # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11 - mask = cutlass.Uint32((1 << col_limit_right_cur) - 1) + mask = (1 << col_limit_right_s) - 1 # This needs to be range_constexpr, otherwise the compiler can't generate # the R2P instruction - for i in cutlass.range_constexpr(16): - # mask_i_bit = cutlass.Boolean(utils.shr_u32(mask, i) & 1) - mask_i_bit = cutlass.Boolean(mask & (1 << i)) - acc_S[s * 16 + i] = acc_S[s * 16 + i] if mask_i_bit else -cutlass.Float32.inf + for i in cutlass.range_constexpr(min(24, ncol - s * 24)): + acc_S[s * 24 + i] = acc_S[s * 24 + i] if cutlass.Boolean(mask & (1 << i)) else -cutlass.Float32.inf # This is the equivalent of: - # acc_S[s * 16 + i] = acc_S[s * 16 + i] if col_limit_right_s <= i else -cutlass.Float32.inf + # acc_S[s * 24 + i] = acc_S[s * 24 + i] if col_limit_right_s <= i else -cutlass.Float32.inf else: local_row_offset_right = ( causal_row_offset + self.window_size_right From 199401d31f940d1f062eb9c0233b41ef62baa5ae Mon Sep 17 00:00:00 2001 From: jayhshah Date: Thu, 21 Aug 2025 19:44:03 -0700 Subject: [PATCH 083/258] Add sorting and head swizzle to varlen scheduler (#1823) * use LPT order in varlen kernel * add prefill decode benchmark script * add sort in prepare * add full implementation: * add varlen kvhead swizzle * add settings for swizzle ablation * add correction term for sort when causal * remove ablation options from frontend and clean up comments * add comments in prepare kernel * remove debug code and scripts * put back defaults in tests * remove excess Nones returned in python interface for varlen * revert opinionated change to setup.py on cuda version 12.9 * force inline sort op and make east const * more templating in varlen scheduler to cure some register spilling * fix exploding build by splitting compilation and add qol macros for hdimdiff * fix metadata mismatch with seqlenk in test script * extend prepare kernel to >992 batches and always call it for varlen * do inter-batch sort per every 992 batches * better names in combine and fix prepare condition in api --- hopper/flash.h | 8 +- hopper/flash_api.cpp | 85 +++++++-- hopper/flash_attn_interface.py | 3 +- hopper/flash_fwd_combine_kernel.h | 11 +- hopper/flash_fwd_combine_launch_template.h | 2 +- hopper/flash_fwd_launch_template.h | 17 +- hopper/flash_prepare_scheduler.cu | 204 +++++++++++++++++---- hopper/setup.py | 26 ++- hopper/static_switch.h | 23 +++ hopper/test_flash_attn.py | 74 +++++--- hopper/tile_scheduler.hpp | 188 +++++++++++++------ hopper/tile_size.h | 7 +- 12 files changed, 499 insertions(+), 149 deletions(-) diff --git a/hopper/flash.h b/hopper/flash.h index bee89e5f054..6848e8c9dbd 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -152,10 +152,16 @@ struct Flash_fwd_params : public Qkv_params { bool pack_gqa; int * __restrict__ tile_count_semaphore; - // int * __restrict__ num_m_blocks_ptr; + int * __restrict__ num_m_blocks_ptr; // int * __restrict__ num_n_blocks_ptr; int * __restrict__ num_splits_dynamic_ptr; + int * __restrict__ varlen_batch_idx_ptr; // virtual -> actual + int * __restrict__ num_nheads_in_l2_ptr; bool skip_scheduler_metadata_computation; + bool varlen_sort_batches; + int tile_count_semaphore_offset; + bool head_swizzle; + bool prepare_varlen_pdl; int arch; int num_sm; diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 33185bf2304..8ffd0d0baf9 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -39,6 +39,8 @@ PyObject* PyInit__C(void) #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define PREPARE_VARLEN_MAX_BATCHES_1CTA 992 + void set_params_fprop(Flash_fwd_params ¶ms, // sizes const size_t b, @@ -250,6 +252,7 @@ void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { if (params.is_bf16) { #ifndef FLASHATTENTION_DISABLE_HDIM64 if (params.d <= 64) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF64 if constexpr (Arch == 90) { if (params.dv > 256) { return run_mha_fwd_(params, stream); @@ -257,6 +260,7 @@ void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { return run_mha_fwd_(params, stream); } } + #endif return run_mha_fwd_(params, stream); } #endif @@ -268,11 +272,13 @@ void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 if (params.d <= 192) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 if constexpr (Arch == 90) { if (params.dv <= 128) { return run_mha_fwd_(params, stream); } } + #endif return run_mha_fwd_(params, stream); } #endif @@ -283,6 +289,7 @@ void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { #ifndef FLASHATTENTION_DISABLE_FP16 #ifndef FLASHATTENTION_DISABLE_HDIM64 if (params.d <= 64) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF64 if constexpr (Arch == 90) { if (params.dv > 256) { return run_mha_fwd_(params, stream); @@ -290,6 +297,7 @@ void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { return run_mha_fwd_(params, stream); } } + #endif return run_mha_fwd_(params, stream); } #endif @@ -301,11 +309,13 @@ void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 if (params.d <= 192) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 if constexpr (Arch == 90) { if (params.dv <= 128) { return run_mha_fwd_(params, stream); } } + #endif return run_mha_fwd_(params, stream); } #endif @@ -329,11 +339,13 @@ void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 if (params.d <= 192) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 if constexpr (Arch == 90) { if (params.dv <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } } + #endif return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } #endif @@ -525,8 +537,7 @@ mha_fwd_get_scheduler_metadata( bool has_softcap, int64_t num_splits, std::optional pack_gqa_, - int64_t sm_margin - ) { + int64_t sm_margin) { TORCH_CHECK(qkv_dtype == at::ScalarType::Half || qkv_dtype == at::ScalarType::BFloat16 || qkv_dtype == at::ScalarType::Float8_e4m3fn, "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type"); @@ -585,8 +596,9 @@ mha_fwd_get_scheduler_metadata( params.page_size = page_size.has_value() ? page_size.value() : 1; params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast(1); - bool const use_dynamic_split = params.b <= 992; - params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast(1); + bool const use_prepare_varlen = true; + params.prepare_varlen_pdl = use_prepare_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES_1CTA; + params.num_splits_dynamic_ptr = !use_prepare_varlen ? nullptr : reinterpret_cast(1); params.pagedkv_tma = get_pagedkv_tma(params); params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; @@ -603,18 +615,35 @@ mha_fwd_get_scheduler_metadata( // This needs to be set after get_num_splits at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1; - if (scheduler_needs_semaphore || use_dynamic_split) { - tile_count_semaphore = torch::empty({int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b}, opts.dtype(torch::kInt32)); + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + params.varlen_sort_batches = !params.is_local; // Use this value for Sort in scheduler template + params.head_swizzle = params.is_causal || params.is_local; // Use this value for LPT in scheduler template + if (scheduler_needs_semaphore || use_prepare_varlen) { + int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers + int num_prepare_batch_vectors = use_prepare_varlen ? 2 : 0; + if(params.varlen_sort_batches) { num_prepare_batch_vectors += 1; } + if(params.head_swizzle) { num_prepare_batch_vectors += 1; } + int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 3 : 2); + int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors; + // printf("(Metadata) num prepare batch vectors = %d.\n", num_prepare_batch_vectors); + tile_count_semaphore = torch::empty( + {int(scheduler_needs_semaphore) + tile_count_semaphore_offset}, + opts.dtype(torch::kInt32)); + // {num_splits_dynamic, num_m_blocks, varlen_batch_idx, num_nheads_in_l2} + params.num_splits_dynamic_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() : nullptr; + params.num_m_blocks_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() + b_rounded : nullptr; + params.varlen_batch_idx_ptr = use_prepare_varlen && params.varlen_sort_batches ? tile_count_semaphore.data_ptr() + b_rounded * 2 : nullptr; + // params.num_n_blocks_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr; + params.num_nheads_in_l2_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr; if (scheduler_needs_semaphore) { - if (!use_dynamic_split) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing - params.tile_count_semaphore = tile_count_semaphore.data_ptr(); + if (!use_prepare_varlen) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing + params.tile_count_semaphore = tile_count_semaphore.data_ptr() + tile_count_semaphore_offset; } else { params.tile_count_semaphore = nullptr; } - params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() + 1 : nullptr; } - if (params.num_splits_dynamic_ptr) { + if (use_prepare_varlen) { auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr); int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); @@ -938,11 +967,11 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql params.cu_seqlens_knew = static_cast(cu_seqlens_k_new.data_ptr()); } } - - // 992 = 32 * 31 is the max supported batch in prepare_varlen_num_blocks kernel - bool const use_dynamic_split = is_varlen && params.b <= 992; + + bool const use_prepare_varlen = is_varlen; + params.prepare_varlen_pdl = use_prepare_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES_1CTA; // Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it - params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast(1); + params.num_splits_dynamic_ptr = !use_prepare_varlen ? nullptr : reinterpret_cast(1); params.pagedkv_tma = get_pagedkv_tma(params); params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; @@ -955,8 +984,17 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql bool const scheduler_needs_semaphore = params.arch >= 90 ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen) : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1)); - if (scheduler_needs_semaphore || use_dynamic_split) { - int metadata_size = int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b; + params.varlen_sort_batches = !params.is_local; // Use this value for Sort in scheduler template + params.head_swizzle = params.is_causal || params.is_local; // Use this value for LPT in scheduler template + if (scheduler_needs_semaphore || use_prepare_varlen) { + int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers + int num_prepare_batch_vectors = use_prepare_varlen ? 2 : 0; + if(params.varlen_sort_batches) { num_prepare_batch_vectors += 1; } + if(params.head_swizzle) { num_prepare_batch_vectors += 1; } + int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 3 : 2); + int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors; + int metadata_size = int(scheduler_needs_semaphore) + tile_count_semaphore_offset; + // printf("Num prepare batch vectors = %d, metadata_size = %d.\n", num_prepare_batch_vectors, metadata_size); params.skip_scheduler_metadata_computation = scheduler_metadata_.has_value(); if (scheduler_metadata_.has_value()) { at::Tensor scheduler_metadata = scheduler_metadata_.value(); @@ -968,15 +1006,22 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql } else { tile_count_semaphore = torch::empty({metadata_size}, opts.dtype(torch::kInt32)); } - if (scheduler_needs_semaphore && !use_dynamic_split) { + if (scheduler_needs_semaphore && !use_prepare_varlen) { tile_count_semaphore.zero_(); // If varlen we'll manually do the zero-ing } - params.tile_count_semaphore = scheduler_needs_semaphore ? tile_count_semaphore.data_ptr() : nullptr; - params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() + 1 : nullptr; + // {num_splits_dynamic, num_m_blocks, varlen_batch_idx, num_nheads_in_l2} + params.num_splits_dynamic_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() : nullptr; + params.num_m_blocks_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() + b_rounded : nullptr; + params.varlen_batch_idx_ptr = use_prepare_varlen && params.varlen_sort_batches ? tile_count_semaphore.data_ptr() + b_rounded * 2 : nullptr; + // params.num_n_blocks_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr; + params.num_nheads_in_l2_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr; + params.tile_count_semaphore = scheduler_needs_semaphore ? tile_count_semaphore.data_ptr() + tile_count_semaphore_offset : nullptr; + params.tile_count_semaphore_offset = tile_count_semaphore_offset; // might need to zero out semaphore later } if (q_v_.has_value()) { TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); + TORCH_CHECK(head_size_v >= 256, "q_v is only supported for hdim_v >= 256."); TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, "q_v is only supported for fp16 and bf16 data type"); TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs"); @@ -1134,7 +1179,7 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql run_mha_fwd_combine(params, stream, true /*enable_pdl*/); } else if (scheduler_needs_semaphore && params.skip_scheduler_metadata_computation) { // need to zero out the semaphore in this case - tile_count_semaphore.index({torch::indexing::Slice(0, 1)}).zero_(); + tile_count_semaphore.index({torch::indexing::Slice(params.tile_count_semaphore_offset, params.tile_count_semaphore_offset + 1)}).zero_(); } } else if (total_q > 0 && num_heads_k > 0) { // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 5547f426da5..a2eb9594896 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -50,7 +50,8 @@ def _flash_attn_forward( scheduler_metadata=None, num_splits=1, pack_gqa=None, - sm_margin=0): + sm_margin=0, + ): q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)] v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [ diff --git a/hopper/flash_fwd_combine_kernel.h b/hopper/flash_fwd_combine_kernel.h index a22e05969d9..05667698006 100644 --- a/hopper/flash_fwd_combine_kernel.h +++ b/hopper/flash_fwd_combine_kernel.h @@ -145,6 +145,7 @@ class FlashAttnFwdCombine { int const* const cu_seqlens = nullptr; int const* const seqused = nullptr; int const* const num_splits_dynamic_ptr = nullptr; + int const* const varlen_batch_idx_ptr = nullptr; int* const semaphore_to_reset = nullptr; }; @@ -164,6 +165,7 @@ class FlashAttnFwdCombine { int const* const cu_seqlens = nullptr; int const* const seqused = nullptr; int const* const num_splits_dynamic_ptr = nullptr; + int const* const varlen_batch_idx_ptr = nullptr; int* const semaphore_to_reset = nullptr; }; @@ -187,7 +189,9 @@ class FlashAttnFwdCombine { args.cu_seqlens, args.seqused, args.num_splits_dynamic_ptr, - args.semaphore_to_reset + args.varlen_batch_idx_ptr, + args.semaphore_to_reset, + }; } @@ -203,8 +207,9 @@ class FlashAttnFwdCombine { int const thread_idx = threadIdx.x; int const m_block = blockIdx.x; int const k_block = blockIdx.y; - int const batch = blockIdx.z; - int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial); + int const maybe_virtual_batch = blockIdx.z; + int const batch = params.varlen_batch_idx_ptr ? params.varlen_batch_idx_ptr[maybe_virtual_batch] : maybe_virtual_batch; + int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[maybe_virtual_batch] : get<1>(params.shape_LSE_partial); if (params.semaphore_to_reset && threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1 && blockIdx.z == gridDim.z - 1) { cutlass::arch::wait_on_dependent_grids(); diff --git a/hopper/flash_fwd_combine_launch_template.h b/hopper/flash_fwd_combine_launch_template.h index 11d422924b4..a2ff25dcd5f 100644 --- a/hopper/flash_fwd_combine_launch_template.h +++ b/hopper/flash_fwd_combine_launch_template.h @@ -35,7 +35,7 @@ void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, bool e {params.o_row_stride, _1{}, params.o_head_stride, !Varlen ? params.o_batch_stride : 0}, // stride_O static_cast(params.softmax_lse_ptr), {_1{}, !Varlen ? params.seqlen_q : params.total_q, !Varlen ? params.h * params.seqlen_q : 0}, // stride_LSE - params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.tile_count_semaphore + params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.varlen_batch_idx_ptr, params.tile_count_semaphore }; typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args); diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index b8af2977f11..d48a4fd9562 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -57,8 +57,10 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { using CollectiveEpilogue = flash::CollectiveEpilogueFwd; static constexpr int NumProducerThreads = Arch >= 90 ? CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads; + static constexpr bool LPT = Is_causal || Is_local; + static constexpr bool Sort = !Is_local; using SchedulerPersistent = std::conditional_t= 90 /*WarpSpecialized*/>, + flash::VarlenDynamicPersistentTileScheduler= 90 /*WarpSpecialized*/, LPT, Sort, true /*Prepared*/>, std::conditional_t, flash::DynamicPersistentTileScheduler= 90 /*WarpSpecialized*/> @@ -149,14 +151,16 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { num_blocks_m, !PackGQA ? params.h : params.h_k, params.b, params.num_splits, params.h / params.h_k, params.seqlen_q, - params.seqlen_k, params.d, params.dv, sizeof(Element), + params.seqlen_k, params.d, params.dv, sizeof(Element), params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q, - // params.num_m_blocks_ptr, params.num_splits_dynamic_ptr, + params.num_m_blocks_ptr, + params.varlen_batch_idx_ptr, + params.num_nheads_in_l2_ptr }; - if (Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation) { - prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 /*enable_pdl*/); + if (Varlen && !params.skip_scheduler_metadata_computation) { + prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 && params.prepare_varlen_pdl /*enable_pdl*/); CHECK_CUDA_KERNEL_LAUNCH(); } @@ -189,7 +193,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { } // kernel<<>>(kernel_params); cutlass::kernel_launch(grid_dims, block_dims, smem_size, stream, kernel_params, - Arch >= 90 && Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation /*launch_with_pdl*/); + Arch >= 90 && Varlen && !params.skip_scheduler_metadata_computation && params.prepare_varlen_pdl /*launch_with_pdl*/); } CHECK_CUDA_KERNEL_LAUNCH(); } @@ -205,7 +209,6 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] { // Only needed here to decide if we should use cluster static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap)) : 128; - static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen; BOOL_SWITCH(params.qv_ptr, HasQV_, [&] { static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV >= 256; diff --git a/hopper/flash_prepare_scheduler.cu b/hopper/flash_prepare_scheduler.cu index 7093fff32b6..1d810c015ed 100644 --- a/hopper/flash_prepare_scheduler.cu +++ b/hopper/flash_prepare_scheduler.cu @@ -2,6 +2,7 @@ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. ******************************************************************************/ +#include #include "cutlass/fast_math.h" #include "cutlass/barrier.h" #include "cutlass/arch/barrier.h" @@ -10,8 +11,35 @@ #include "flash.h" +#include "static_switch.h" + namespace flash { +// Sort in descending order +template +struct PrepareSortOp +{ + __device__ __forceinline__ bool operator()(T const & lhs, T const & rhs) + { + return lhs > rhs; + } +}; + +template <> +struct PrepareSortOp { + __device__ __forceinline__ bool operator()(int2 const & lhs, int2 const & rhs) const { + return lhs.x > rhs.x; + } +}; + +template <> +struct PrepareSortOp { + __device__ __forceinline__ bool operator()(int4 const & lhs, int4 const & rhs) const { + return lhs.x > rhs.x; + } +}; + +template __global__ void prepare_varlen_num_blocks_kernel( int seqlen_q_static, int seqlen_k_static, int seqlen_k_new_static, int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new, @@ -19,16 +47,28 @@ __global__ void prepare_varlen_num_blocks_kernel( int num_batch, int num_head, int qhead_per_khead, int num_sm, int num_splits_static, cutlass::FastDivmod blockm_divmod, cutlass::FastDivmod blockn_divmod, int* const tile_count_semaphore, - // int* const num_m_blocks_ptr, + int* const num_m_blocks_ptr, int* const num_splits_dynamic_ptr, - bool enable_pdl) { + int* const varlen_batch_idx_ptr, + // int* const num_n_blocks_ptr, + int* const num_nheads_in_l2_ptr, + bool enable_pdl, + bool is_causal, + bool packgqa, + int max_kvblocks_in_l2) { static constexpr int kNumBatchPerWarp = cutlass::NumThreadsPerWarp - 1; static constexpr int kSmemSize = 1; - // Assume that there's only one block in the grid + static constexpr int BLOCK_DIM_X = NumWarps * 32; + static constexpr int ITEMS_PER_THREAD = 1; + static_assert(BLOCK_DIM_X * ITEMS_PER_THREAD == NumWarps * 32); + using BlockMergeSort = cub::BlockMergeSort; + __shared__ int total_blocks_smem[kSmemSize]; - // There's only 1 block in the grid, so might as well start launching the main attn kernel + // Allocate shared memory for BlockMergeSort operations + __shared__ typename BlockMergeSort::TempStorage temp_storage; + if (enable_pdl) { cutlass::arch::launch_dependent_grids(); } if (threadIdx.x < kSmemSize) { total_blocks_smem[threadIdx.x] = 0; } @@ -38,8 +78,7 @@ __global__ void prepare_varlen_num_blocks_kernel( int lane = threadIdx.x % cutlass::NumThreadsPerWarp; - auto get_num_m_blocks = [&](int bidb_start) { - int batch_idx = lane + bidb_start; + auto get_num_m_blocks = [&](int batch_idx) { int seqlen; if (seqused_q) { seqlen = batch_idx < num_batch ? seqused_q[batch_idx] : 0; @@ -50,13 +89,12 @@ __global__ void prepare_varlen_num_blocks_kernel( } else { seqlen = seqlen_q_static; } - seqlen *= qhead_per_khead; + if(packgqa) { seqlen *= qhead_per_khead; } return batch_idx < num_batch && lane < kNumBatchPerWarp ? blockm_divmod.div(seqlen + blockm_divmod.divisor - 1) : 0; }; - auto get_num_n_blocks = [&](int bidb_start) { - int batch_idx = lane + bidb_start; + auto get_num_n_blocks = [&](int batch_idx) { int leftpad_k = batch_idx < num_batch && leftpad_k_ptr != nullptr ? leftpad_k_ptr[batch_idx] : 0; int seqlen; if (seqused_k) { @@ -83,42 +121,130 @@ __global__ void prepare_varlen_num_blocks_kernel( }; int warp_idx = threadIdx.x / cutlass::NumThreadsPerWarp; - int bidb_start = kNumBatchPerWarp * warp_idx; - int num_m_blocks = get_num_m_blocks(bidb_start); - int num_n_blocks = get_num_n_blocks(bidb_start); - - int total_blocks = num_m_blocks * num_n_blocks; - // Warp sum - #pragma unroll - for (int i = cutlass::NumThreadsPerWarp / 2; i >= 1; i /= 2) { - total_blocks += __shfl_down_sync(0xffffffff, total_blocks, i); + int batch_cta_idx_offset = int(blockIdx.x) * 992; + int bidb_start = batch_cta_idx_offset + kNumBatchPerWarp * warp_idx; + int batch_idx = lane + bidb_start; + int num_m_blocks = get_num_m_blocks(batch_idx); + int num_n_blocks = get_num_n_blocks(batch_idx); + + auto get_nheads_in_l2 = [&](int n_blocks) { + int nheads_in_l2 = n_blocks * 16 <= max_kvblocks_in_l2 ? 16 + : n_blocks * 8 <= max_kvblocks_in_l2 ? 8 + : n_blocks * 4 <= max_kvblocks_in_l2 ? 4 + : n_blocks * 2 <= max_kvblocks_in_l2 ? 2 + : 1; + if(!packgqa) { nheads_in_l2 *= qhead_per_khead; } + return min(nheads_in_l2, num_head); + }; + + int num_splits_dynamic; + if (int(gridDim.x) > 1 || num_splits_static == 1) { + // set num splits for all batches to 1 (note that user expects num_splits_static to mean upper bound on splits) + // for batch size > 992, we expect GPU occupancy to not be an issue except in degenerate cases (e.g., most are zero-length) + num_splits_dynamic = 1; + } else { + int total_blocks = num_m_blocks * num_n_blocks; + // Warp sum + #pragma unroll + for (int i = cutlass::NumThreadsPerWarp / 2; i >= 1; i /= 2) { + total_blocks += __shfl_down_sync(0xffffffff, total_blocks, i); + } + if (lane == 0) { atomicAdd(total_blocks_smem, total_blocks); } + __syncthreads(); + total_blocks = total_blocks_smem[0]; + // 10% margin + int blocks_per_sm = static_cast(ceilf(float(total_blocks) * 1.1f * float(num_head) / float(num_sm))); + // blocks_per_sm = std::max(1, blocks_per_sm); // 1 is the minimum number of blocks per SM + num_splits_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1); + // num_n_blocks per work tile for the batch + num_n_blocks = cutlass::ceil_div(num_n_blocks, num_splits_dynamic); } - if (lane == 0) { atomicAdd(total_blocks_smem, total_blocks); } - __syncthreads(); - total_blocks = total_blocks_smem[0]; - // 10% margin - int blocks_per_sm = static_cast(ceilf(float(total_blocks) * 1.1f * float(num_head) / float(num_sm))); - // blocks_per_sm = std::max(1, blocks_per_sm); // 1 is the minimum number of blocks per SM - int num_splits_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1); - if (bidb_start + lane < num_batch && lane < kNumBatchPerWarp) { - num_splits_dynamic_ptr[bidb_start + lane] = num_splits_dynamic; - // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic); + + if constexpr (Sort) { + if(lane == kNumBatchPerWarp || batch_idx >= num_batch) { + num_n_blocks = INT_MIN; // sort last + } else if (is_causal) { + // sort by shortest member to process + num_n_blocks = num_n_blocks * blockn_divmod.divisor - num_m_blocks * blockm_divmod.divisor; + } + int4 batch_coords[ITEMS_PER_THREAD]; // 1 item per thread + batch_coords[0] = make_int4(num_n_blocks, num_m_blocks, num_splits_dynamic, batch_idx); + + // if (threadIdx.x == 0) { + // printf("Unsorted: num_n_blocks - num_m_blocks = %d, num_m_blocks = %d, num_splits = %d, batch_idx = %d.\n", + // batch_coords[0].x, batch_coords[0].y, batch_coords[0].z, batch_coords[0].w); + // } __syncthreads(); + + // Sort batches by num_n_blocks in descending order + BlockMergeSort(temp_storage).Sort(batch_coords, PrepareSortOp()); + + // if (threadIdx.x == 0) { + // printf("Sorted: num_n_blocks - num_m_blocks = %d, num_m_blocks = %d, num_splits = %d, batch_idx = %d.\n", + // batch_coords[0].x, batch_coords[0].y, batch_coords[0].z, batch_coords[0].w); + // } __syncthreads(); + + if (is_causal) { + // reset value to num_n_blocks + batch_coords[0].x = blockn_divmod.div(batch_coords[0].x + batch_coords[0].y * blockm_divmod.divisor); + } + + // When sorting, we re-index some metadata by 'virtual batch index' + // and also store the vbidx -> bidx mapping. + // 1. num_nheads_in_l2_ptr: virtual_batch_idx -> num_nheads_in_l2[batch_idx] + // 2. num_splits_dynamic_ptr: virtual_batch_idx -> num_splits[batch_idx] + // 3. num_m_blocks_ptr: virtual_batch_idx -> num_m_blocks[batch_idx] + // 4. varlen_batch_idx_ptr: virtual_batch_idx -> batch_idx + batch_idx = batch_cta_idx_offset + threadIdx.x; + if (batch_idx < num_batch && threadIdx.x < 992) { + // num_n_blocks_ptr[threadIdx.x] = max(batch_coords[0].x, 1); + if(num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[batch_idx] = get_nheads_in_l2(max(batch_coords[0].x, 1)); } + num_m_blocks_ptr[batch_idx] = batch_coords[0].y; + num_splits_dynamic_ptr[batch_idx] = batch_coords[0].z; + varlen_batch_idx_ptr[batch_idx] = batch_coords[0].w; + } + } else { + if (batch_idx < num_batch && lane < kNumBatchPerWarp) { + // num_n_blocks_ptr[batch_idx] = max(num_n_blocks, 1); + if(num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[batch_idx] = get_nheads_in_l2(max(num_n_blocks, 1)); } + num_splits_dynamic_ptr[batch_idx] = num_splits_dynamic; + num_m_blocks_ptr[batch_idx] = num_m_blocks; + // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic); + } } + } } // flash void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, int blockM, int blockN, bool enable_pdl) { - // Only support batch <= 992 (32 warps, each with 31 batches) - int qhead_per_khead = !packgqa ? 1 : cutlass::ceil_div(params.h, params.h_k); - flash::prepare_varlen_num_blocks_kernel<<<1 /*grid*/, 1024 /*block*/, 0, stream>>>( - params.seqlen_q, params.seqlen_k, params.seqlen_knew, - params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew, - params.seqused_q, params.seqused_k, params.leftpad_k, - params.b, !packgqa ? params.h : params.h_k, qhead_per_khead, params.num_sm, params.num_splits, - cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN), - params.tile_count_semaphore, - // params.num_m_blocks_ptr, - params.num_splits_dynamic_ptr, enable_pdl); + int qhead_per_khead = cutlass::ceil_div(params.h, params.h_k); + int num_warps = cutlass::ceil_div(params.b, 31); // warp switch will cap this at 32 + int num_ctas = cutlass::ceil_div(params.b, 31 * 32); + // int const size_l2 = 50 * 1024 * 1024; // 50 MB + int const size_l2 = 8 * 1024 * 1024; // underestimate seems better in practice + int const element_size = params.is_e4m3 ? 1 : 2; + int const size_one_kvblock = blockN * (params.d + params.dv) * element_size; + // printf("block size = %d, element size = %d, headdim = %d, headdim_v = %d, size 1 kblock = %d.\n", blockN, element_size, params.d, params.dv, size_one_kvblock); + int const max_kvblocks_in_l2 = size_l2 / size_one_kvblock; + BOOL_SWITCH(params.varlen_sort_batches, Sort, [&] { + NUM_WARP_SWITCH(num_warps, NumWarps, [&] { + flash::prepare_varlen_num_blocks_kernel<<>>( + params.seqlen_q, params.seqlen_k, params.seqlen_knew, + params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew, + params.seqused_q, params.seqused_k, params.leftpad_k, + params.b, !packgqa ? params.h : params.h_k, qhead_per_khead, params.num_sm, params.num_splits, + cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN), + params.tile_count_semaphore, + params.num_m_blocks_ptr, + params.num_splits_dynamic_ptr, + params.varlen_batch_idx_ptr, + // params.num_n_blocks_ptr, + params.num_nheads_in_l2_ptr, + enable_pdl, + params.is_causal, + packgqa, + max_kvblocks_in_l2); + }); + }); } diff --git a/hopper/setup.py b/hopper/setup.py index c15c438f56c..850fb0b520c 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -64,6 +64,8 @@ ENABLE_VCOLMAJOR = os.getenv("FLASH_ATTENTION_ENABLE_VCOLMAJOR", "FALSE") == "TRUE" +DISABLE_HDIMDIFF64 = os.getenv("FLASH_ATTENTION_DISABLE_HDIMDIFF64", "FALSE") == "TRUE" +DISABLE_HDIMDIFF192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIMDIFF192", "FALSE") == "TRUE" # HACK: we monkey patch pytorch's _write_ninja_file to pass # "-gencode arch=compute_sm90a,code=sm_90a" to files ending in '_sm90.cu', @@ -468,10 +470,13 @@ def nvcc_threads_args(): + (["-DFLASHATTENTION_DISABLE_HDIM256"] if DISABLE_HDIM256 else []) + (["-DFLASHATTENTION_DISABLE_SM8x"] if DISABLE_SM8x else []) + (["-DFLASHATTENTION_ENABLE_VCOLMAJOR"] if ENABLE_VCOLMAJOR else []) + + (["-DFLASHATTENTION_DISABLE_HDIMDIFF64"] if DISABLE_HDIMDIFF64 else []) + + (["-DFLASHATTENTION_DISABLE_HDIMDIFF192"] if DISABLE_HDIMDIFF192 else []) ) DTYPE_FWD_SM80 = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) DTYPE_FWD_SM90 = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) + (["e4m3"] if not DISABLE_FP8 else []) + HALF_DTYPE_FWD_SM90 = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) DTYPE_BWD = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) HEAD_DIMENSIONS_BWD = ( [] @@ -481,7 +486,18 @@ def nvcc_threads_args(): + ([192] if not DISABLE_HDIM192 else []) + ([256] if not DISABLE_HDIM256 else []) ) - HEAD_DIMENSIONS_FWD = ["all", "diff"] + # build will now explode with this compilation grouping given all our templating + # HEAD_DIMENSIONS_FWD = ["all", "diff"] + HEAD_DIMENSIONS_FWD = HEAD_DIMENSIONS_BWD + HEAD_DIMENSIONS_DIFF64_FWD = ( + [] + + (["64_256"] if not DISABLE_HDIMDIFF64 else []) + + (["64_512"] if not DISABLE_HDIMDIFF64 else []) + ) + HEAD_DIMENSIONS_DIFF192_FWD = ( + [] + + (["192_128"] if not DISABLE_HDIMDIFF192 else []) + ) HEAD_DIMENSIONS_FWD_SM80 = HEAD_DIMENSIONS_BWD SPLIT = [""] + (["_split"] if not DISABLE_SPLIT else []) PAGEDKV = [""] + (["_paged"] if not DISABLE_PAGEDKV else []) @@ -495,6 +511,14 @@ def nvcc_threads_args(): sources_fwd_sm90 = [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu" for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_FWD, DTYPE_FWD_SM90, SPLIT, PAGEDKV, SOFTCAP, PACKGQA) if not (packgqa and (paged or split))] + if not DISABLE_HDIMDIFF64: + sources_fwd_sm90 += [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu" + for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_DIFF64_FWD, HALF_DTYPE_FWD_SM90, SPLIT, PAGEDKV, SOFTCAP, PACKGQA) + if not (packgqa and (paged or split))] + if not DISABLE_HDIMDIFF192: + sources_fwd_sm90 += [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu" + for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_DIFF192_FWD, DTYPE_FWD_SM90, SPLIT, PAGEDKV, SOFTCAP, PACKGQA) + if not (packgqa and (paged or split))] sources_bwd_sm80 = [f"instantiations/flash_bwd_hdim{hdim}_{dtype}{softcap}_sm80.cu" for hdim, dtype, softcap in itertools.product(HEAD_DIMENSIONS_BWD, DTYPE_BWD, SOFTCAP)] sources_bwd_sm90 = [f"instantiations/flash_bwd_hdim{hdim}_{dtype}{softcap}_sm90.cu" diff --git a/hopper/static_switch.h b/hopper/static_switch.h index 5e13b5f93a8..15a7d51364b 100644 --- a/hopper/static_switch.h +++ b/hopper/static_switch.h @@ -179,3 +179,26 @@ return __VA_ARGS__(); \ } \ }() + +#define NUM_WARP_SWITCH(VALUE, CONST_NAME, ...) \ + [&] { \ + if (VALUE <= 1) { \ + constexpr static int CONST_NAME = 1; \ + return __VA_ARGS__(); \ + } else if (VALUE <= 2) { \ + constexpr static int CONST_NAME = 2; \ + return __VA_ARGS__(); \ + } else if (VALUE <= 4) { \ + constexpr static int CONST_NAME = 4; \ + return __VA_ARGS__(); \ + } else if (VALUE <= 8) { \ + constexpr static int CONST_NAME = 8; \ + return __VA_ARGS__(); \ + } else if (VALUE <= 16) { \ + constexpr static int CONST_NAME = 16; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static int CONST_NAME = 32; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index f1247e689da..0b5a0e2af98 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -55,8 +55,8 @@ # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) -# @pytest.mark.parametrize("has_qv", [False, True]) -@pytest.mark.parametrize("has_qv", [False]) +@pytest.mark.parametrize("has_qv", [False, True]) +# @pytest.mark.parametrize("has_qv", [True]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [False]) @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) @@ -75,7 +75,7 @@ # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128, 192]) @pytest.mark.parametrize("d", COMPILED_HDIMS) -# @pytest.mark.parametrize("d", [128]) +# @pytest.mark.parametrize("d", [64]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -107,6 +107,8 @@ def test_flash_attn_output( ): if V_colmajor and (seqlen_k % 16 != 0 or dtype != torch.float8_e4m3fn): pytest.skip("V_colmajor requires seqlen_k to be a multiple of 16 and dtype to be float8_e4m3fn") + if has_qv and (d != 64 or dtype == torch.float8_e4m3fn): + pytest.skip("Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)") device = "cuda" # set seed torch.random.manual_seed(0) @@ -121,8 +123,11 @@ def test_flash_attn_output( dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) if dtype == torch.float8_e4m3fn: dv_vals = [d] + if has_qv: + dv_vals = [256, 512] attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if not DISABLE_LOCAL else [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + print(f"{dv = }, {attention_chunk = }") q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) if softcap > 0.0: # Ensure the values of qk are at least within softcap range. @@ -193,6 +198,7 @@ def test_flash_attn_output( pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + print(f"{pack_gqa = }, {num_splits = }") out = flash_attn_func( q, k, @@ -286,8 +292,8 @@ def test_flash_attn_output( # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) -# @pytest.mark.parametrize("has_qv", [False, True]) -@pytest.mark.parametrize("has_qv", [False]) +@pytest.mark.parametrize("has_qv", [False, True]) +# @pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [False]) @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) @@ -295,7 +301,7 @@ def test_flash_attn_output( @pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) # @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize("causal", [False]) +# @pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("add_unused_qkv", [False, True]) # @pytest.mark.parametrize("add_unused_qkv", [True]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) @@ -305,7 +311,7 @@ def test_flash_attn_output( # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128]) @pytest.mark.parametrize("d", COMPILED_HDIMS) -# @pytest.mark.parametrize("d", [128]) +# @pytest.mark.parametrize("d", [64]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -328,28 +334,38 @@ def test_flash_attn_output( (1024, 1024), (1023, 1024), (1024, 1023), + (1024, 1024), (2048, 2048), + (4096, 4096), ], ) def test_flash_attn_varlen_output( - seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype + seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype, ): + if has_qv and (d != 64 or dtype == torch.float8_e4m3fn): + pytest.skip("Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)") device = "cuda" # set seed torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) # batch_size = 40 # nheads = 16 batch_size = 9 if seqlen_q <= 2048 else 2 + # batch_size = 32 nheads = 6 + nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) # batch_size = 2 # nheads = 1 - nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + # nheads_kv = nheads + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) if dtype == torch.float8_e4m3fn: dv_vals = [d] + if has_qv: + dv_vals = [256, 512] attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k and not DISABLE_LOCAL else [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + print(f"{dv = }, {attention_chunk = }") q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) if softcap > 0.0: # Ensure the values of qk are at least within softcap range. @@ -458,8 +474,15 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): rtol = 2 if softcap == 0.0 else 3 pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] - num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] + # pack_gqa_vals = [False] + num_splits_vals = [1, 3, 0] if not DISABLE_SPLIT else [1] + # num_splits_vals = [1] + # print("cu_seqlens_q: ", cu_seqlens_q) + # print("cu_seqlens_k: ", cu_seqlens_k) + # print("seqused_q: ", seqused_q) + # print("seqused_k: ", seqused_k) for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + print(f"{pack_gqa = }, {num_splits = }") out_unpad = flash_attn_varlen_func( q_unpad, k_unpad, @@ -477,6 +500,8 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): window_size=window_size, attention_chunk=attention_chunk, softcap=softcap, + pack_gqa=pack_gqa, + num_splits=num_splits, ) out = output_pad_fn(out_unpad) if query_unused_mask is not None: @@ -580,16 +605,16 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) @pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else [])) -# @pytest.mark.parametrize("new_kv", [True]) +# @pytest.mark.parametrize("new_kv", [False]) @pytest.mark.parametrize("causal,local", [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else [])) # @pytest.mark.parametrize("causal,local", [(False, False), (True, False)]) -# @pytest.mark.parametrize("causal,local", [(False, False)]) +# @pytest.mark.parametrize("causal,local", [(True, False)]) @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True]) -# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) -@pytest.mark.parametrize("has_rotary_seqlens", [False, True]) -# @pytest.mark.parametrize("has_rotary_seqlens", [False]) +# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [False]) +# @pytest.mark.parametrize("has_rotary_seqlens", [False, True]) +@pytest.mark.parametrize("has_rotary_seqlens", [False]) @pytest.mark.parametrize("rotary_interleaved", [False, True] if not DISABLE_APPENDKV else [False]) -# @pytest.mark.parametrize("rotary_interleaved", [True]) +# @pytest.mark.parametrize("rotary_interleaved", [False]) @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0] if (not DISABLE_APPENDKV) and (apply_rotary_emb is not None) else [0.0]) # @pytest.mark.parametrize("rotary_fraction", [0.0]) @pytest.mark.parametrize("page_size", [None] + ([1, 4, 128] if not DISABLE_PAGEDKV else [])) @@ -597,9 +622,9 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): @pytest.mark.parametrize("has_leftpad", [False, True]) # @pytest.mark.parametrize("has_leftpad", [False]) @pytest.mark.parametrize("has_batch_idx", [False, True]) -# @pytest.mark.parametrize("has_batch_idx", [False]) +# @pytest.mark.parametrize("has_batch_idx", [True]) @pytest.mark.parametrize("varlen_q", [False, True]) -# @pytest.mark.parametrize("varlen_q", [False]) +# @pytest.mark.parametrize("varlen_q", [True]) # @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) @@ -669,6 +694,7 @@ def test_flash_attn_kvcache( dv_vals = [d] attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if (causal or local) and not DISABLE_LOCAL else [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + print(f"{dv = }, {attention_chunk = }") has_qv = d == 64 and dv >= 256 q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) if has_qv: @@ -850,17 +876,21 @@ def test_flash_attn_kvcache( sin = sin.to(dtype) if sin is not None else None k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone() v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone() - num_splits_vals = [1, 0] if not DISABLE_SPLIT else [1] + num_splits_vals = [1, 3, 0] if not DISABLE_SPLIT else [1] precompute_metadata_vals = [False, True] for num_splits, precompute_metadata in itertools.product(num_splits_vals, precompute_metadata_vals): + print(f"{num_splits = }, {precompute_metadata = }") if precompute_metadata: scheduler_metadata = get_scheduler_metadata( - batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d, + batch_size, + max_seqlen_q if varlen_q else seqlen_q, + seqlen_k if page_size is None else page_table.shape[1] * page_size, + nheads, nheads_k, d, cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad, max_seqlen_k_new=seqlen_new, page_size=page_size, causal=causal, window_size=window_size, attention_chunk=attention_chunk, - num_splits=num_splits + num_splits=num_splits, ) else: scheduler_metadata = None @@ -895,7 +925,7 @@ def test_flash_attn_kvcache( rotary_interleaved=rotary_interleaved, scheduler_metadata=scheduler_metadata, num_splits=num_splits, - return_softmax_lse=True + return_softmax_lse=True, ) if varlen_q: out = output_pad_fn(out) diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index 1f90f66adc2..41e0bab1624 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -24,8 +24,11 @@ struct TileSchedulerArguments { int* const tile_count_semaphore = nullptr; int const* const cu_seqlens = nullptr; int const* const seqused = nullptr; - // int const* const num_m_blocks_ptr = nullptr; int const* const num_splits_dynamic_ptr = nullptr; + int const* const num_m_blocks_ptr = nullptr; + int const* const varlen_batch_idx_ptr = nullptr; + // int const* const num_n_blocks_ptr = nullptr; + int const* const num_nheads_in_l2_ptr = nullptr; }; /////////////////////////////////////////////////////////////////////////////// @@ -463,7 +466,8 @@ class SingleTileBwdLPTScheduler { /////////////////////////////////////////////////////////////////////////////// -template +template class VarlenDynamicPersistentTileScheduler { static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads); @@ -482,13 +486,17 @@ class VarlenDynamicPersistentTileScheduler { int num_head, num_batch; int const qhead_per_khead; int const seqlen; + // int const max_kvblocks_in_l2; cutlass::FastDivmod head_divmod; cutlass::FastDivmod nsplits_divmod; int* const tile_count_semaphore; int const* const cu_seqlens; int const* const seqused; - // int* const num_m_blocks_ptr; int const* const num_splits_dynamic_ptr; + int const* const num_m_blocks_ptr; + int const* const varlen_batch_idx_ptr; + // int const* const num_n_blocks_ptr; + int const* const num_nheads_in_l2_ptr; }; static Params @@ -498,13 +506,20 @@ class VarlenDynamicPersistentTileScheduler { assert(args.tile_count_semaphore != nullptr); assert(args.num_head < (1 << 16)); // We use the top 16 bits to store num_splits & split_idx assert(!Split || args.num_splits < (1 << 8)); // We use the top 8 bits to store num_splits + // int const size_l2 = 50 * 1024 * 1024; // 50 MB + // int const size_one_kvblock = kBlockN * (args.headdim + args.headdim_v) * args.element_size; + // int max_kvblocks_in_l2 = size_l2 / size_one_kvblock; return {args.num_head, args.num_batch, args.qhead_per_khead, args.seqlen, + // max_kvblocks_in_l2, cutlass::FastDivmod(args.num_head), cutlass::FastDivmod(!Split ? 1 : args.num_splits), args.tile_count_semaphore, args.cu_seqlens, args.seqused, - // args.num_m_blocks_ptr, - args.num_splits_dynamic_ptr}; + args.num_splits_dynamic_ptr, + args.num_m_blocks_ptr, + args.varlen_batch_idx_ptr, + // aras.num_n_blocks_ptr, + args.num_nheads_in_l2_ptr}; } static dim3 @@ -525,8 +540,15 @@ class VarlenDynamicPersistentTileScheduler { CUTLASS_DEVICE cute::tuple get_block_coord(Params const& params) const { + auto get_actual_batch = [&](int virtual_batch) { + if constexpr(Prepared && Sort) { + return params.varlen_batch_idx_ptr[virtual_batch]; + } else { + return virtual_batch; + } + }; if constexpr (!Split) { - return {block, bidh, bidb, 0 /*split_idx*/}; + return {block, bidh, get_actual_batch(bidb), 0 /*split_idx*/}; } else { // the top 8 bits of bidh store num_splits and the next 8 bits store split_idx // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift @@ -540,7 +562,7 @@ class VarlenDynamicPersistentTileScheduler { // if (threadIdx.x == 128) { // printf("blockIdx.x = %d, bidb = %d, bidh = %d, bidh_actual = %d, split_idx = %d\n", blockIdx.x, bidb, bidh, bidh_actual, split_idx); // } - return {block, bidh_actual, bidb, split_idx}; + return {block, bidh_actual, get_actual_batch(bidb), split_idx}; } } }; @@ -554,31 +576,39 @@ class VarlenDynamicPersistentTileScheduler { int lane = threadIdx.x % cutlass::NumThreadsPerWarp; auto get_num_m_blocks = [&] (int bidb_start) { int batch_idx = lane + bidb_start; - int seqlen = params.seqlen * (!PackGQA ? 1 : params.qhead_per_khead); - if (seqlen > kBlock) { - if (params.seqused) { - seqlen = batch_idx < params.num_batch ? params.seqused[batch_idx] : 0; - } else if (params.cu_seqlens) { - int cur_cu_seqlen = batch_idx <= params.num_batch ? params.cu_seqlens[batch_idx] : 0; - int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); - seqlen = next_cu_seqlen - cur_cu_seqlen; - } else { - seqlen = params.seqlen; + if constexpr (Prepared) { + return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 + ? params.num_m_blocks_ptr[batch_idx] : 0; + } else { + int seqlen = params.seqlen * (!PackGQA ? 1 : params.qhead_per_khead); + if (seqlen > kBlockM) { + if (params.seqused) { + seqlen = batch_idx < params.num_batch ? params.seqused[batch_idx] : 0; + } else if (params.cu_seqlens) { + int cur_cu_seqlen = batch_idx <= params.num_batch ? params.cu_seqlens[batch_idx] : 0; + int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); + seqlen = next_cu_seqlen - cur_cu_seqlen; + } else { + seqlen = params.seqlen; + } + if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; } } - if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; } + return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 + ? cute::ceil_div(seqlen, kBlockM) : 0; + // ? params.num_m_blocks_ptr[batch_idx] : 0; } - return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 - ? cute::ceil_div(seqlen, kBlock) : 0; - // ? params.num_m_blocks_ptr[batch_idx] : 0; }; auto get_num_splits = [&] (int bidb_start) { int batch_idx = lane + bidb_start; - return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 - ? (!Split ? 1 : (params.num_splits_dynamic_ptr - ? params.num_splits_dynamic_ptr[batch_idx] - : params.nsplits_divmod.divisor)) - : 0; + bool is_valid = batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1; + if constexpr (!Split) { + return is_valid ? 1 : 0; + } else if constexpr(Prepared) { + return is_valid ? params.num_splits_dynamic_ptr[batch_idx] : 0; + } else { + return is_valid ? params.nsplits_divmod.divisor : 0; + } }; int num_m_blocks = get_num_m_blocks(current_work.bidb); // Different for each lane @@ -589,12 +619,14 @@ class VarlenDynamicPersistentTileScheduler { // Total number of blocks for the next 31 batches int m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1); // Only the lower 16 bits are the actual bidh - int current_bidh = !Split ? current_work.bidh : (current_work.bidh & 0x0000FFFF); - int group_end_tile = current_work.tile_idx - current_work.block - current_bidh * __shfl_sync(0xffffffff, num_split_m_blocks, 0 /*lane*/) + m_blocks_in_group * params.num_head; // Same for all lanes - if constexpr (Split) { - int current_split_idx = (current_work.bidh & 0x00FF0000) >> 16; - group_end_tile -= current_split_idx * __shfl_sync(0xffffffff, num_m_blocks, 0 /*lane*/); - } + // int current_bidh = !Split ? current_work.bidh : (current_work.bidh & 0x0000FFFF); + // int group_end_tile = current_work.tile_idx - current_work.block - current_bidh * __shfl_sync(0xffffffff, num_split_m_blocks, 0 /*lane*/) + m_blocks_in_group * params.num_head; // Same for all lanes + // if constexpr (Split) { + // int current_split_idx = (current_work.bidh & 0x00FF0000) >> 16; + // group_end_tile -= current_split_idx * __shfl_sync(0xffffffff, num_m_blocks, 0 /*lane*/); + // } + // NEW: current_work.tile_idx holds group_start_tile for starting batch + int group_end_tile = current_work.tile_idx + m_blocks_in_group * params.num_head; // Same for all lanes int bidb = current_work.bidb; // if (blockIdx.x <= 9 && threadIdx.x == 0) { // printf("Before while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, cur tile_idx = %d, cur block = %d, cur bidh = %d, num_split_m_blocks = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, current_work.bidb, num_m_blocks, next_tile_idx, current_work.tile_idx, current_work.block, current_bidh, num_split_m_blocks, group_end_tile, m_blocks_in_group); @@ -626,27 +658,81 @@ class VarlenDynamicPersistentTileScheduler { bidb += batch_idx_in_group; num_m_blocks = __shfl_sync(0xffffffff, num_m_blocks, batch_idx_in_group); if constexpr (Split) { num_splits = __shfl_sync(0xffffffff, num_splits, batch_idx_in_group); } - int mh_block = next_tile_idx - group_start_tile - (batch_idx_in_group == 0 ? 0 : __shfl_sync(0xffffffff, num_m_blocks_cumulative, batch_idx_in_group - 1)) * params.num_head; - int bidh = mh_block / num_m_blocks; - int block = mh_block - bidh * num_m_blocks; - if constexpr (Split) { - int bidh_actual = bidh / num_splits; - int split_idx = bidh - bidh_actual * num_splits; - // TODO: idk why this gives wrong answer nondeterministically - // int bidh_actual, split_idx; - // split_idx = params.head_divmod.divmod(bidh_actual, bidh); - // Use the top 8 bits to store num_splits and the next 8 bits to store split_idx - // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift - uint32_t bidh_packed = reinterpret_cast(bidh_actual) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); - // if (threadIdx.x == 0) { - // printf("blockIdx.x = %d, group_start_tiled = %d, bidb = %d, batch_idx_in_group = %d, mh_block = %d, num_m_blocks = %d, bidh = %d, bidh_actual = %d, split_idx = %d, num_splits = %d, bidh_packed = %d\n", blockIdx.x, group_start_tile, bidb, batch_idx_in_group, mh_block, num_m_blocks, bidh, bidh_actual, split_idx, num_splits, bidh_packed); + group_start_tile += (batch_idx_in_group == 0 ? 0 : __shfl_sync(0xffffffff, num_m_blocks_cumulative, batch_idx_in_group - 1)) * params.num_head; + int mh_block = next_tile_idx - group_start_tile; + int block, bidh; + if constexpr (LPT) { + if (!Split || num_splits == 1) { + // NOTE: code for computing nheads_in_l2 directly left as reference + // int num_n_blocks = params.num_n_blocks_ptr ? params.num_n_blocks_ptr[bidb] : num_m_blocks; + // auto find_log2_floor = [&](int n) { return 31 - cutlass::clz(n); }; + // int nheads_in_l2 = params.max_kvblocks_in_l2 < num_n_blocks + // ? 1 : 1 << find_log2_floor(params.max_kvblocks_in_l2 / num_n_blocks); + // if constexpr (!PackGQA) { nheads_in_l2 *= params.qhead_per_khead; } + // nheads_in_l2 = min(nheads_in_l2, params.num_head); + auto get_nheads_in_l2 = [&](int batch_idx) { + if constexpr(Prepared) { + return params.num_nheads_in_l2_ptr[batch_idx]; + } else { + return !PackGQA ? params.qhead_per_khead : 1; + } + }; + int nheads_in_l2 = get_nheads_in_l2(bidb); + int mh_in_l2 = nheads_in_l2 * num_m_blocks; + int section_idx = mh_block / mh_in_l2; + int l2_mod = mh_block - section_idx * mh_in_l2; + // tail section + int nheads_remainder = params.num_head - section_idx * nheads_in_l2; + int nheads_in_this_section = nheads_in_l2 <= nheads_remainder ? nheads_in_l2 : nheads_remainder; + block = l2_mod / nheads_in_this_section; + int bidh_residual = l2_mod - block * nheads_in_this_section; + bidh = section_idx * nheads_in_l2 + bidh_residual; + if constexpr(Split) { + // remember to set num_splits = 1 in work tile + uint32_t bidh_packed = reinterpret_cast(bidh) + (reinterpret_cast(num_splits) << 24); + bidh = reinterpret_cast(bidh_packed); + } + } else { + // NOTE: leave traverse heads first version for reference + // block = params.head_divmod.divmod(bidh, mh_block); + // if constexpr (Split) { + // int split_idx = block / num_m_blocks; + // block = block - split_idx * num_m_blocks; + // uint32_t bidh_packed = reinterpret_cast(bidh) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); + // bidh = reinterpret_cast(bidh_packed); + // } + bidh = mh_block / num_m_blocks; + block = mh_block - bidh * num_m_blocks; + if constexpr (Split) { + int bidh_actual = bidh / num_splits; + int split_idx = bidh - bidh_actual * num_splits; + uint32_t bidh_packed = reinterpret_cast(bidh_actual) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); + bidh = reinterpret_cast(bidh_packed); + } + } + block = num_m_blocks - 1 - block; + } else { + bidh = mh_block / num_m_blocks; + block = mh_block - bidh * num_m_blocks; + if constexpr (Split) { + int bidh_actual = bidh / num_splits; + int split_idx = bidh - bidh_actual * num_splits; + // TODO: idk why this gives wrong answer nondeterministically + // int bidh_actual, split_idx; + // split_idx = params.head_divmod.divmod(bidh_actual, bidh); + // Use the top 8 bits to store num_splits and the next 8 bits to store split_idx + // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift + uint32_t bidh_packed = reinterpret_cast(bidh_actual) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); + // if (threadIdx.x == 0) { + // printf("blockIdx.x = %d, group_start_tiled = %d, bidb = %d, batch_idx_in_group = %d, mh_block = %d, num_m_blocks = %d, bidh = %d, bidh_actual = %d, split_idx = %d, num_splits = %d, bidh_packed = %d\n", blockIdx.x, group_start_tile, bidb, batch_idx_in_group, mh_block, num_m_blocks, bidh, bidh_actual, split_idx, num_splits, bidh_packed); + // } + bidh = reinterpret_cast(bidh_packed); + } + // if (blockIdx.x <= 9 && threadIdx.x == 0) { + // printf("Before returning, blockIdx.x = %d, threadIdx.x = %d, group_start_tile = %d, batch_idx_in_group = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d, mh_block = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, group_start_tile, batch_idx_in_group, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group, mh_block, bidh, block); // } - bidh = reinterpret_cast(bidh_packed); } - // if (blockIdx.x <= 9 && threadIdx.x == 0) { - // printf("Before returning, blockIdx.x = %d, threadIdx.x = %d, group_start_tile = %d, batch_idx_in_group = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d, mh_block = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, group_start_tile, batch_idx_in_group, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group, mh_block, bidh, block); - // } - return {next_tile_idx, block, bidh, bidb}; + return {group_start_tile, block, bidh, bidb}; } template diff --git a/hopper/tile_size.h b/hopper/tile_size.h index e6cb31515c7..8353542c477 100644 --- a/hopper/tile_size.h +++ b/hopper/tile_size.h @@ -21,7 +21,7 @@ constexpr std::tuple tile_size_fwd_sm90( return {128, 96, true, false}; } else { // Switch to tile size 192 x 192 for now - bool const use_blockN_128 = is_causal || is_local; + bool const use_blockN_128 = is_causal || is_local || paged_kv_non_TMA; return {192, use_blockN_128 ? 128 : 192, use_blockN_128, true}; } // Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen @@ -29,8 +29,9 @@ constexpr std::tuple tile_size_fwd_sm90( } else if (headdim <= 96) { return {192, is_local || paged_kv_non_TMA ? 128 : 144, false, true}; } else if (headdim <= 128) { - return {128, is_causal || is_local || paged_kv_non_TMA ? 128 : 176, true, true}; - // {128, 192, false, false} and {192, 128, false, true} are quite good too + bool const use_blockN_128 = is_causal || is_local || paged_kv_non_TMA; + return {128, use_blockN_128 ? 128 : 176, true, true}; + // {128, 192, true, false} and {192, 128, false, true} are quite good too // 128 x 192 hits the limit of smem if MmaPV_is_RS, 128 x 144 hits the limit if !MmaPV_is_RS } else if (headdim <= 192) { return {128, paged_kv_non_TMA || is_local ? 96 : (headdim_v <= 128 ? 128 : 112), true, true}; // 128 x 112 hits the limit of smem From 632fe2a000a65bba523d7eec75b812efd5328d8e Mon Sep 17 00:00:00 2001 From: Jingze Shi Date: Sun, 24 Aug 2025 12:45:41 +0800 Subject: [PATCH 084/258] Fixes incorrect variable reference in comment (#1775) Corrects comment documentation to reference total_q instead of total_k for the output tensor dimensions, ensuring consistency with the actual parameter being described. --- csrc/flash_attn/flash_api.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index dd7a5c3f9b4..a7b5d36835d 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -515,7 +515,7 @@ std::vector mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. - std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &out_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. From 832d5448ce65c5fd163a446e51e93dbf770849db Mon Sep 17 00:00:00 2001 From: y-sq <58683402+y-sq@users.noreply.github.com> Date: Mon, 25 Aug 2025 04:44:22 -0700 Subject: [PATCH 085/258] Update the initialization of dk/dv_semaphore (#1839) When testing the deterministic option for the GQA case, we found it fell into deadlock issues. Initialization dk and dv_semaphore to zeros to fix this issue. --- hopper/flash_api.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 8ffd0d0baf9..adb53fdab6b 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -1529,9 +1529,9 @@ std::tuple(); if (num_heads_k != num_heads && params.deterministic) { - // TODO: do we need to zero them out? - at::Tensor dk_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); - at::Tensor dv_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); + // TODO: maybe also zero'ed out dk_semaphore and dv_semaphore in the backward preprocess kernel + at::Tensor dk_semaphore = torch::zeros({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); + at::Tensor dv_semaphore = torch::zeros({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); params.dk_semaphore = dk_semaphore.data_ptr(); params.dv_semaphore = dv_semaphore.data_ptr(); } From 478841a2c5b58870d533219e9d3c1d505ca9af4d Mon Sep 17 00:00:00 2001 From: Ravi Ghadia <40660742+ghadiaravi13@users.noreply.github.com> Date: Tue, 26 Aug 2025 13:49:29 -0700 Subject: [PATCH 086/258] Update tile_scheduler.hpp (#1841) --- hopper/tile_scheduler.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index 41e0bab1624..3c9e42996b0 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -251,7 +251,7 @@ class DynamicPersistentTileScheduler { static Params to_underlying_arguments(TileSchedulerArguments const& args) { - int const size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size; + long long const size_one_kv_head = long(args.seqlen_k) * long(args.headdim + args.headdim_v) * long(args.element_size); int const size_l2 = 32 * 1024 * 1024; // 32 MB for K & V // Swizzle is the size of each "section". Round swizzle to a power of 2 // If not PackGQA already, the size of each section can increase by qhead_per_khead @@ -382,9 +382,9 @@ class SingleTileBwdLPTScheduler { static Params to_underlying_arguments(TileSchedulerArguments const& args) { // Since it's the bwd pass, seqlen_k get passed to args.seqlen and seqlen_q is passed to args.seqlen_k - int const size_one_qdo_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size; - int const size_one_dqaccum_head = args.seqlen_k * args.headdim * sizeof(float); - int const size_one_head = size_one_qdo_head + size_one_dqaccum_head; + long long const size_one_qdo_head = long(args.seqlen_k) * long(args.headdim + args.headdim_v) * long(args.element_size); + long long const size_one_dqaccum_head = long(args.seqlen_k) * long(args.headdim) * sizeof(float); + long long const size_one_head = size_one_qdo_head + size_one_dqaccum_head; int const size_l2 = 40 * 1024 * 1024; // 40 MB for Q, dO, and dQaccum // Swizzle is the size of each "section". Round swizzle to a power of 2 // Need to be careful about the case where only one head will fit From 6f2b052488c8964e0e62380a4fbcff1ceb81492e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?oliver=20k=C3=B6nig?= Date: Wed, 27 Aug 2025 04:57:21 +0200 Subject: [PATCH 087/258] ci: Move build job to workflow template (#1835) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ci: Move build job to workflow template Signed-off-by: oliver könig * check out right tag Signed-off-by: oliver könig * fix Signed-off-by: oliver könig * fix Signed-off-by: oliver könig * fix Signed-off-by: oliver könig * fix Signed-off-by: oliver könig * revert Signed-off-by: oliver könig --------- Signed-off-by: oliver könig --- .github/workflows/_build.yml | 152 ++++++++++++++++++++++++++++++ .github/workflows/publish.yml | 172 +++++----------------------------- 2 files changed, 178 insertions(+), 146 deletions(-) create mode 100644 .github/workflows/_build.yml diff --git a/.github/workflows/_build.yml b/.github/workflows/_build.yml new file mode 100644 index 00000000000..d55c47fd910 --- /dev/null +++ b/.github/workflows/_build.yml @@ -0,0 +1,152 @@ +name: ~Build wheel template + +on: + workflow_call: + inputs: + runs-on: + description: "The runner to use for the build" + required: true + type: string + python-version: + description: "The Python version to use for the build" + required: true + type: string + cuda-version: + description: "The CUDA version to use for the build" + required: true + type: string + torch-version: + description: "The PyTorch version to use for the build" + required: true + type: string + cxx11_abi: + description: "The C++11 ABI to use for the build" + required: true + type: string + release-version: + description: "Upload wheel to this release" + required: false + type: string + +defaults: + run: + shell: bash -x -e -u -o pipefail {0} + +jobs: + build-wheel: + runs-on: ${{ inputs.runs-on }} + name: Build wheel (${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}) + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ inputs.python-version }} + + - name: Set CUDA and PyTorch versions + run: | + echo "MATRIX_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV + echo "MATRIX_TORCH_VERSION=$(echo ${{ inputs.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV + echo "WHEEL_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV + echo "MATRIX_PYTHON_VERSION=$(echo ${{ inputs.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV + + - name: Free up disk space + if: ${{ runner.os == 'Linux' }} + # https://github.com/easimon/maximize-build-space/blob/master/action.yml + # https://github.com/easimon/maximize-build-space/tree/test-report + run: | + sudo rm -rf /usr/share/dotnet + sudo rm -rf /opt/ghc + sudo rm -rf /opt/hostedtoolcache/CodeQL + + - name: Set up swap space + if: runner.os == 'Linux' + uses: pierotofy/set-swap-space@v1.0 + with: + swap-size-gb: 10 + + - name: Install CUDA ${{ inputs.cuda-version }} + if: ${{ inputs.cuda-version != 'cpu' }} + uses: Jimver/cuda-toolkit@v0.2.26 + id: cuda-toolkit + with: + cuda: ${{ inputs.cuda-version }} + linux-local-args: '["--toolkit"]' + # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1 + # method: ${{ (inputs.cuda-version == '11.8.0' || inputs.cuda-version == '12.1.0') && 'network' || 'local' }} + method: "network" + sub-packages: '["nvcc"]' + + - name: Install PyTorch ${{ inputs.torch-version }}+cu${{ inputs.cuda-version }} + run: | + pip install --upgrade pip + # With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error + # AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable + pip install typing-extensions==4.12.2 + # We want to figure out the CUDA version to download pytorch + # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116 + # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix + # This code is ugly, maybe there's a better way to do this. + export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \ + minv = {'2.4': 118, '2.5': 118, '2.6': 118, '2.7': 118, '2.8': 126}[env['MATRIX_TORCH_VERSION']]; \ + maxv = {'2.4': 124, '2.5': 124, '2.6': 126, '2.7': 128, '2.8': 129}[env['MATRIX_TORCH_VERSION']]; \ + print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \ + ) + if [[ ${{ inputs.torch-version }} == *"dev"* ]]; then + # pip install --no-cache-dir --pre torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} + # Can't use --no-deps because we need cudnn etc. + # Hard-coding this version of pytorch-triton for torch 2.6.0.dev20241001 + pip install jinja2 + pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2Bcf34004b8a-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl + pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ inputs.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl + else + pip install --no-cache-dir torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} + fi + nvcc --version + python --version + python -c "import torch; print('PyTorch:', torch.__version__)" + python -c "import torch; print('CUDA:', torch.version.cuda)" + python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)" + shell: bash + + - name: Build wheel + run: | + # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6 + # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810 + # However this still fails so I'm using a newer version of setuptools + pip install setuptools==75.8.0 + pip install ninja packaging wheel + export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH + export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH + # Limit MAX_JOBS otherwise the github runner goes OOM + # nvcc 11.8 can compile with 2 jobs, but nvcc 12.3 goes OOM + MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "129" ] && echo 1 || echo 2) NVCC_THREADS=2 FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ inputs.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist + tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ inputs.cxx11_abi }} + wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2") + ls dist/*whl |xargs -I {} mv {} dist/${wheel_name} + echo "wheel_name=${wheel_name}" >> $GITHUB_ENV + + - name: Log Built Wheels + run: | + ls dist + + - name: Get Release with tag + id: get_current_release + uses: joutvhu/get-release@v1 + with: + tag_name: ${{ inputs.release-version }} + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Upload Release Asset + id: upload_release_asset + uses: actions/upload-release-asset@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + upload_url: ${{ steps.get_current_release.outputs.upload_url }} + asset_path: ./dist/${{env.wheel_name}} + asset_name: ${{env.wheel_name}} + asset_content_type: application/* diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 8d2ea71e4df..0a668e291cb 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -13,16 +13,16 @@ on: - v* jobs: - setup_release: name: Create Release runs-on: ubuntu-latest + outputs: + release-version: ${{ steps.extract_branch.outputs.branch }} steps: - name: Get the tag version id: extract_branch run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/} shell: bash - - name: Create Release id: create_release uses: actions/create-release@v1 @@ -35,161 +35,43 @@ jobs: build_wheels: name: Build Wheel needs: setup_release - runs-on: ${{ matrix.os }} - strategy: fail-fast: false matrix: - # Using ubuntu-22.04 instead of 24.04 for more compatibility (glibc). Ideally we'd use the - # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. - os: [ubuntu-22.04] - python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] - torch-version: ['2.4.0', '2.5.1', '2.6.0', '2.7.1', '2.8.0'] - cuda-version: ['12.9.1'] - # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. - # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. - # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) - # when building without C++11 ABI and using it on nvcr images. - cxx11_abi: ['FALSE', 'TRUE'] - exclude: - # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix - # Pytorch < 2.5 does not support Python 3.13 - - torch-version: '2.4.0' - python-version: '3.13' - - steps: - - name: Checkout - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - - name: Set CUDA and PyTorch versions - run: | - echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV - echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV - echo "WHEEL_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV - echo "MATRIX_PYTHON_VERSION=$(echo ${{ matrix.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV - - - name: Free up disk space - if: ${{ runner.os == 'Linux' }} - # https://github.com/easimon/maximize-build-space/blob/master/action.yml - # https://github.com/easimon/maximize-build-space/tree/test-report - run: | - sudo rm -rf /usr/share/dotnet - sudo rm -rf /opt/ghc - sudo rm -rf /opt/hostedtoolcache/CodeQL - - - name: Set up swap space - if: runner.os == 'Linux' - uses: pierotofy/set-swap-space@v1.0 - with: - swap-size-gb: 10 - - - name: Install CUDA ${{ matrix.cuda-version }} - if: ${{ matrix.cuda-version != 'cpu' }} - uses: Jimver/cuda-toolkit@v0.2.26 - id: cuda-toolkit - with: - cuda: ${{ matrix.cuda-version }} - linux-local-args: '["--toolkit"]' - # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1 - # method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }} - method: 'network' - sub-packages: '["nvcc"]' - - - name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }} - run: | - pip install --upgrade pip - # With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error - # AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable - pip install typing-extensions==4.12.2 - # We want to figure out the CUDA version to download pytorch - # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116 + # Using ubuntu-22.04 instead of 24.04 for more compatibility (glibc). Ideally we'd use the + # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. + os: [ubuntu-22.04] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + torch-version: ["2.4.0", "2.5.1", "2.6.0", "2.7.1", "2.8.0"] + cuda-version: ["12.9.1"] + # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. + # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. + # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) + # when building without C++11 ABI and using it on nvcr images. + cxx11_abi: ["FALSE", "TRUE"] + exclude: # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix - # This code is ugly, maybe there's a better way to do this. - export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \ - minv = {'2.4': 118, '2.5': 118, '2.6': 118, '2.7': 118, '2.8': 126}[env['MATRIX_TORCH_VERSION']]; \ - maxv = {'2.4': 124, '2.5': 124, '2.6': 126, '2.7': 128, '2.8': 129}[env['MATRIX_TORCH_VERSION']]; \ - print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \ - ) - if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then - # pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} - # Can't use --no-deps because we need cudnn etc. - # Hard-coding this version of pytorch-triton for torch 2.6.0.dev20241001 - pip install jinja2 - pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2Bcf34004b8a-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl - pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl - else - pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} - fi - nvcc --version - python --version - python -c "import torch; print('PyTorch:', torch.__version__)" - python -c "import torch; print('CUDA:', torch.version.cuda)" - python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)" - shell: - bash - - - name: Build wheel - run: | - # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6 - # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810 - # However this still fails so I'm using a newer version of setuptools - pip install setuptools==75.8.0 - pip install ninja packaging wheel - export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH - export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH - # Limit MAX_JOBS otherwise the github runner goes OOM - # nvcc 11.8 can compile with 2 jobs, but nvcc 12.3 goes OOM - MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "129" ] && echo 1 || echo 2) NVCC_THREADS=2 FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist - tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }} - wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2") - ls dist/*whl |xargs -I {} mv {} dist/${wheel_name} - echo "wheel_name=${wheel_name}" >> $GITHUB_ENV - - - name: Log Built Wheels - run: | - ls dist - - - name: Get the tag version - id: extract_branch - run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/} - - - name: Get Release with tag - id: get_current_release - uses: joutvhu/get-release@v1 - with: - tag_name: ${{ steps.extract_branch.outputs.branch }} - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - - name: Upload Release Asset - id: upload_release_asset - uses: actions/upload-release-asset@v1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.get_current_release.outputs.upload_url }} - asset_path: ./dist/${{env.wheel_name}} - asset_name: ${{env.wheel_name}} - asset_content_type: application/* + # Pytorch < 2.5 does not support Python 3.13 + - torch-version: "2.4.0" + python-version: "3.13" + uses: ./.github/workflows/_build.yml + with: + runs-on: ${{ matrix.os }} + python-version: ${{ matrix.python-version }} + cuda-version: ${{ matrix.cuda-version }} + torch-version: ${{ matrix.torch-version }} + cxx11_abi: ${{ matrix.cxx11_abi }} + release-version: ${{ needs.setup_release.outputs.release-version }} publish_package: name: Publish package needs: [build_wheels] - runs-on: ubuntu-latest - steps: - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 with: - python-version: '3.10' - + python-version: "3.10" - name: Install dependencies run: | pip install ninja packaging wheel twine @@ -197,13 +79,11 @@ jobs: pip install setuptools==75.8.0 # We don't want to download anything CUDA-related here pip install torch --index-url https://download.pytorch.org/whl/cpu - - name: Build core package env: FLASH_ATTENTION_SKIP_CUDA_BUILD: "TRUE" run: | python setup.py sdist --dist-dir=dist - - name: Deploy env: TWINE_USERNAME: "__token__" From b2476552432fd6ac991003db4564eb289dd77332 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?oliver=20k=C3=B6nig?= Date: Wed, 27 Aug 2025 16:43:37 +0200 Subject: [PATCH 088/258] ci: Build via workflow template (#1844) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ci: Move build job to workflow template Signed-off-by: oliver könig * check out right tag Signed-off-by: oliver könig * fix Signed-off-by: oliver könig * fix Signed-off-by: oliver könig * fix Signed-off-by: oliver könig * revert Signed-off-by: oliver könig * ci: Allow build/deploy of arbitrary configurations (#1827) * ci: Allow build/deploy of arbitrary configurations Signed-off-by: oliver könig * add Signed-off-by: oliver könig * cleanui Signed-off-by: oliver könig * cxx11_abi Signed-off-by: oliver könig * fix Signed-off-by: oliver könig * fix Signed-off-by: oliver könig * test Signed-off-by: oliver könig * fix Signed-off-by: oliver könig * fix Signed-off-by: oliver könig * final Signed-off-by: oliver könig --------- Signed-off-by: oliver könig * upload Signed-off-by: oliver könig --------- Signed-off-by: oliver könig --- .github/workflows/_build.yml | 76 ++++++++++++++++++++++++++++++++--- .github/workflows/build.yml | 47 ++++++++++++++++++++++ .github/workflows/publish.yml | 1 + 3 files changed, 118 insertions(+), 6 deletions(-) create mode 100644 .github/workflows/build.yml diff --git a/.github/workflows/_build.yml b/.github/workflows/_build.yml index d55c47fd910..47d7bb49055 100644 --- a/.github/workflows/_build.yml +++ b/.github/workflows/_build.yml @@ -23,6 +23,11 @@ on: description: "The C++11 ABI to use for the build" required: true type: string + upload-to-release: + description: "Upload wheel to this release" + required: false + type: boolean + default: false release-version: description: "Upload wheel to this release" required: false @@ -39,6 +44,9 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 + with: + ref: ${{ inputs.release-version }} + submodules: recursive - name: Set up Python uses: actions/setup-python@v5 @@ -109,9 +117,34 @@ jobs: python -c "import torch; print('PyTorch:', torch.__version__)" python -c "import torch; print('CUDA:', torch.version.cuda)" python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)" - shell: bash + + - name: Restore build cache + uses: actions/cache/restore@v4 + with: + path: build.tar + key: build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-${{ github.run_number }}-${{ github.run_attempt }} + restore-keys: | + build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}- + + - name: Unpack build cache + run: | + echo ::group::Adjust timestamps + sudo find / -exec touch -t 197001010000 {} + || true + echo ::endgroup:: + + if [ -f build.tar ]; then + find . -mindepth 1 -maxdepth 1 ! -name 'build.tar' -exec rm -rf {} + + tar -xpvf build.tar -C . + else + echo "No build.tar found, skipping" + fi + + ls -al ./ + ls -al build/ || true + ls -al csrc/ || true - name: Build wheel + id: build_wheel run: | # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6 # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810 @@ -122,11 +155,41 @@ jobs: export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH # Limit MAX_JOBS otherwise the github runner goes OOM # nvcc 11.8 can compile with 2 jobs, but nvcc 12.3 goes OOM - MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "129" ] && echo 1 || echo 2) NVCC_THREADS=2 FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ inputs.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist - tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ inputs.cxx11_abi }} - wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2") - ls dist/*whl |xargs -I {} mv {} dist/${wheel_name} - echo "wheel_name=${wheel_name}" >> $GITHUB_ENV + + export MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "129" ] && echo 1 || echo 2) + export NVCC_THREADS=2 + export FLASH_ATTENTION_FORCE_BUILD="TRUE" + export FLASH_ATTENTION_FORCE_CXX11_ABI=${{ inputs.cxx11_abi }} + + # 5h timeout since GH allows max 6h and we want some buffer + EXIT_CODE=0 + timeout 5h python setup.py bdist_wheel --dist-dir=dist || EXIT_CODE=$? + + if [ $EXIT_CODE -eq 0 ]; then + tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ inputs.cxx11_abi }} + wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2") + ls dist/*whl |xargs -I {} mv {} dist/${wheel_name} + echo "wheel_name=${wheel_name}" >> $GITHUB_ENV + fi + + # Store exit code in GitHub env for later steps + echo "build_exit_code=$EXIT_CODE" | tee -a "$GITHUB_OUTPUT" + + # Do not fail the job if timeout killed the build + exit $EXIT_CODE + + - name: Log build logs after timeout + if: always() && steps.build_wheel.outputs.build_exit_code == 124 + run: | + ls -al ./ + tar -cvf build.tar . --atime-preserve=replace + + - name: Save build cache timeout + if: always() && steps.build_wheel.outputs.build_exit_code == 124 + uses: actions/cache/save@v4 + with: + key: build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-${{ github.run_number }}-${{ github.run_attempt }} + path: build.tar - name: Log Built Wheels run: | @@ -142,6 +205,7 @@ jobs: - name: Upload Release Asset id: upload_release_asset + if: inputs.upload-to-release uses: actions/upload-release-asset@v1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 00000000000..9a454b3fcde --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,47 @@ +name: Build wheels + +on: + workflow_call: + inputs: + runs-on: + description: "The runner to use for the build" + required: true + type: string + default: ubuntu-22.04 + python-version: + description: "The Python version to use for the build" + required: true + type: string + cuda-version: + description: "The CUDA version to use for the build" + required: true + type: string + torch-version: + description: "The PyTorch version to use for the build" + required: true + type: string + cxx11_abi: + description: "Enable torch flag C++11 ABI (TRUE/FALSE)" + required: true + type: string + upload-to-release: + description: "Upload wheel to this release" + required: false + type: boolean + default: false + release-version: + description: "Upload wheel to this release" + required: false + type: string + +jobs: + build-wheels: + uses: ./.github/workflows/_build.yml + with: + runs-on: ${{ inputs.runs-on }} + python-version: ${{ inputs.python-version }} + cuda-version: ${{ inputs.cuda-version }} + torch-version: ${{ inputs.torch-version }} + cxx11_abi: ${{ inputs.cxx11_abi }} + upload-to-release: ${{ inputs.upload-to-release }} + release-version: ${{ inputs.release-version }} diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 0a668e291cb..d11b703ef99 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -62,6 +62,7 @@ jobs: torch-version: ${{ matrix.torch-version }} cxx11_abi: ${{ matrix.cxx11_abi }} release-version: ${{ needs.setup_release.outputs.release-version }} + upload-to-release: true publish_package: name: Publish package From d0ed097d0089865a8ef027d54fadf9428a44fcee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?oliver=20k=C3=B6nig?= Date: Fri, 29 Aug 2025 23:00:41 +0200 Subject: [PATCH 089/258] ci: Switch to workflow_dispatch (#1847) --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 9a454b3fcde..25ea5e86b75 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,7 +1,7 @@ name: Build wheels on: - workflow_call: + workflow_dispatch: inputs: runs-on: description: "The runner to use for the build" From 203b9b3dba39d5d08dffb49c09aa622984dff07d Mon Sep 17 00:00:00 2001 From: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com> Date: Fri, 29 Aug 2025 23:25:35 +0200 Subject: [PATCH 090/258] [`FA3`] Allow returning LSE via kwarg (#1851) * lse output * style * style * revert test changes, introduce optional kwarg to output lse --- hopper/flash_attn_interface.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index a2eb9594896..a435e7a627d 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -168,6 +168,7 @@ def forward( deterministic=False, num_heads_q=None, sm_margin=0, + return_softmax=False, ): if softmax_scale is None: softmax_scale = qkv.shape[-1] ** (-0.5) @@ -210,8 +211,7 @@ def forward( ctx.deterministic = deterministic ctx.ndim = qkv.dim() ctx.sm_margin = sm_margin - # return out, softmax_lse - return out + return (out, softmax_lse) if return_softmax else out @staticmethod def backward(ctx, dout, *args): @@ -270,6 +270,7 @@ def forward( pack_gqa=None, deterministic=False, sm_margin=0, + return_softmax=False, ): if softmax_scale is None: softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) @@ -305,7 +306,7 @@ def forward( ctx.softcap = softcap ctx.deterministic = deterministic ctx.sm_margin = sm_margin - return out + return (out, softmax_lse) if return_softmax else out @staticmethod def backward(ctx, dout, *args): @@ -363,6 +364,7 @@ def forward( pack_gqa=None, deterministic=False, sm_margin=0, + return_softmax=False, ): if softmax_scale is None: softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) @@ -404,7 +406,7 @@ def forward( ctx.softcap = softcap ctx.deterministic = deterministic ctx.sm_margin = sm_margin - return out + return (out, softmax_lse) if return_softmax else out @staticmethod def backward(ctx, dout, *args): @@ -451,6 +453,7 @@ def flash_attn_qkvpacked_func( deterministic=False, num_heads_q=None, sm_margin=0, + return_attn_probs=False, ): """dropout_p should be set to 0.0 during evaluation If Q, K, V are already stacked into 1 tensor, this function will be faster than @@ -497,6 +500,7 @@ def flash_attn_qkvpacked_func( deterministic, num_heads_q, sm_margin, + return_attn_probs, ) @@ -515,6 +519,7 @@ def flash_attn_func( pack_gqa=None, deterministic=False, sm_margin=0, + return_attn_probs=False, ): """dropout_p should be set to 0.0 during evaluation Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads @@ -576,6 +581,7 @@ def flash_attn_func( pack_gqa, deterministic, sm_margin, + return_attn_probs, ) @@ -600,6 +606,7 @@ def flash_attn_varlen_func( pack_gqa=None, deterministic=False, sm_margin=0, + return_attn_probs=False, ): return FlashAttnVarlenFunc.apply( q, @@ -622,6 +629,7 @@ def flash_attn_varlen_func( pack_gqa, deterministic, sm_margin, + return_attn_probs, ) From 27b64c7c9b25a4d279b2a42257dd936a8dd2dc23 Mon Sep 17 00:00:00 2001 From: Mingyang Date: Tue, 2 Sep 2025 21:21:09 +0800 Subject: [PATCH 091/258] [BugFix] fix flash_fwd.FlashAttentionForwardSm80 bugs (#1856) * [BugFix] fix softcap condition softcap should only be referenced when its not none, currently the logic is reversed and will result in an error * [BugFix] fix sm80 cuteDSL error 1. Current condition on softcap is wrong and will result in RuntimeError. Change the code to align with sm_100 2. Make window_size_left and window_size_right optional to align with sm_100 and all other interfaces. * Fix typo of range_constexpr * Fix seqlen --- flash_attn/cute/flash_fwd.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index de5fea43b99..783e76866c5 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -434,7 +434,7 @@ def load_K( else: seqlen_limit = cutlass.min(seqlen - block * self.n_block_size, self.n_block_size) seqlen_limit -= tKcK[0][0] - for n in cutlass.range_constepxr(cute.size(tKsK.shape[1])): + for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])): if t0KcK[0, n, 0][0] < seqlen_limit: cute.copy( gmem_tiled_copy, @@ -468,7 +468,7 @@ def load_V( # Do we need to check if we overshoot kBlockN when we load V? is_even_n_smem_v = self.n_block_size % gmem_tiled_copy.tiler_mn[0].shape == 0 if const_expr(need_predicates or not is_even_n_smem_v): - for n in cutlass.range_constepxr(cute.size(tVsV.shape[1])): + for n in cutlass.range_constexpr(cute.size(tVsV.shape[1])): # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked if is_even_n_smem_v or n < cute.size(tVsV.shape[1]) - 1 or tVcV[0, n, 0][0] < self.n_block_size: predicate = tVpV[None, n, None] if const_expr(self.check_hdim_v_oob) else None @@ -476,8 +476,8 @@ def load_V( seqlen_limit = seqlen - block * self.n_block_size - tVcV[0][0] predicate_n = t0VcV[0, n, 0][0] < seqlen_limit predicate = cute.make_fragment_like(tVpV[None, 0, None]) - for k in cutlass.range_constepxr(cute.size(predicate.shape[1])): - for i in cutlass.range_constepxr(cute.size(predicate.shape[0])): + for k in cutlass.range_constexpr(cute.size(predicate.shape[1])): + for i in cutlass.range_constexpr(cute.size(predicate.shape[0])): predicate[i, k] = (tVpV[i, n, k] if const_expr(self.check_hdim_v_oob) else True) and predicate_n cute.copy( gmem_tiled_copy, @@ -586,12 +586,13 @@ def __call__( # (assigning it to softcap_val) and pre-multiply softcap_val * log2(e) # (assigning it to softmax_scale_log2). LOG2_E = math.log2(math.e) - if const_expr(softcap is not None): + if const_expr(softcap is None): softmax_scale_log2 = softmax_scale * LOG2_E softcap_val = None else: softmax_scale_log2 = softcap * LOG2_E softcap_val = Float32(softmax_scale / softcap) + self.kernel( mQ, mK, @@ -631,8 +632,8 @@ def kernel( mLSE: Optional[cute.Tensor], softmax_scale_log2: Float32, softcap_val: Optional[Float32], - window_size_left: Int32, - window_size_right: Int32, + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], sQ_layout: cute.ComposedLayout, sK_layout: cute.ComposedLayout, sV_layout: cute.ComposedLayout, @@ -655,7 +656,7 @@ def kernel( window_size_left, window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) - seqlen = SeqlenInfoQK(seqlen_q=mQ.shape[0], seqlen_k=mK.shape[0]) + seqlen = SeqlenInfoQK(seqlen_q_static=mQ.shape[0], seqlen_k_static=mK.shape[0]) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) # TODO: return early if n_block_max == 0 # if self.is_causal: @@ -802,7 +803,7 @@ def preprocess_Q(): preprocess_Q() cute.arch.barrier() # Make sure all threads have read smem_q before loading V - for stage in cutlass.range_constepxr(self.num_stages): + for stage in cutlass.range_constexpr(self.num_stages): if const_expr(not self.Q_in_regs or stage > 0): if stage == 0 or n_block - stage >= 0: load_K(n_block - stage, smem_pipe_write=stage, need_predicates=stage==0) @@ -867,7 +868,7 @@ def preprocess_Q(): # reuse sQ's data iterator sO = cute.make_tensor(sQ.iterator, sO_layout) self.epilogue( - acc_O, softmax.row_sum, mO, mLSE, sO, + acc_O, softmax.row_sum, mO, mLSE, sO, seqlen, gmem_tiled_copy_O, None, tiled_mma_pv, tidx, m_block, num_head, batch_size ) From 6387433156558135a998d5568a9d74c1778666d8 Mon Sep 17 00:00:00 2001 From: Reuben Stern <107093092+reubenconducts@users.noreply.github.com> Date: Tue, 2 Sep 2025 19:25:10 -0400 Subject: [PATCH 092/258] [FIX] Allow m_block_size == 192 and mma_pv_is_rs == False in Sm90 CuTe DSL (#1858) * update num_threads based on num wgs * fix bug when not intra_wg_overlap and not mma_pv_is_rs --- flash_attn/cute/flash_fwd.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 783e76866c5..d1b307acf02 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -951,10 +951,10 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase): arch = 90 - def __init__(self, *args, intra_wg_overlap: bool = True, **kwargs): + def __init__(self, *args, intra_wg_overlap: bool = True, mma_pv_is_rs: bool = True, **kwargs): super().__init__(*args, **kwargs) self.intra_wg_overlap = intra_wg_overlap - self.mma_pv_is_rs = True + self.mma_pv_is_rs = mma_pv_is_rs def _get_smem_layout_atom(self): sQ_layout_atom = warpgroup.make_smem_layout_atom( @@ -1104,11 +1104,18 @@ def __call__( self.num_mma_threads = tiled_mma_qk.size self.num_threads_per_warp_group = 128 self.num_mma_warp_groups = self.num_mma_threads // self.num_threads_per_warp_group + self.num_threads = self.num_threads_per_warp_group * (self.num_mma_warp_groups + 1) self.num_producer_threads = 32 self.num_Q_load_threads = self.num_mma_threads # If not TMA_Q, MMA threads load Q self.num_epilogue_threads = self.num_mma_threads - self.num_mma_regs = 240 - self.num_producer_regs = 24 + self.num_mma_regs = ( + 256 + if self.num_mma_warp_groups == 1 + else (240 if self.num_mma_warp_groups == 2 else 160) + ) + self.num_producer_regs = ( + 56 if self.num_mma_warp_groups == 1 else (24 if self.num_mma_warp_groups == 2 else 32) + ) # self.num_mma_regs = 232 # self.num_producer_regs = 40 self.use_scheduler_barrier = (self.num_mma_warp_groups >= 2 and self.head_dim_padded <= 128) if const_expr(self.intra_wg_overlap) else (self.num_mma_warp_groups == 2) @@ -1794,7 +1801,7 @@ def mma_one_n_block( # tOrP.store(tOrP_acc.load().to(self.dtype)) utils.cvt_f16(tOrP_acc, tOrP) if const_expr(not self.mma_pv_is_rs): - tPrP = smem_copy_params.smem_thr_copy_P.retile(mma_params.tOrP) + tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP) cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) softmax.rescale_O(mma_params.acc_O, row_scale) if const_expr(not self.mma_pv_is_rs): @@ -1894,7 +1901,11 @@ def warp_scheduler_barrier_arrive(self): if const_expr(self.use_scheduler_barrier): assert self.num_mma_warp_groups in [2, 3] cur_wg = utils.canonical_warp_group_idx(sync=False) - 1 - next_wg = 1 - cur_wg if const_expr(self.num_mma_warp_groups == 2) else (cur_wg + 1 if cur_wg < self.num_mma_warp_groups - 1 else 0) + if const_expr(self.num_mma_warp_groups == 2): + next_wg = 1 - cur_wg + else: + t = cur_wg + 1 + next_wg = t % self.num_mma_warp_groups cute.arch.barrier_arrive( barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg, number_of_threads=2 * self.num_threads_per_warp_group, From afc97c60f799e470886c154e3473df938f8fa93d Mon Sep 17 00:00:00 2001 From: Johnny Date: Thu, 4 Sep 2025 23:28:12 +0200 Subject: [PATCH 093/258] make FA3 compatible with CUDA 13 Builds (#1860) Fix CUDA barrier init crash when num_consumers < NumThreadsPerWarpGroup Previously, integer division caused num_consumer_warpgroups_per_cluster to be 0 when params.num_consumers (e.g., 32) was less than NumThreadsPerWarpGroup (128), leading to a compiler failure during barrier initialization. Changed to round-up division to ensure a minimum value of 1. --- hopper/sm90_pipeline_no_cluster.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hopper/sm90_pipeline_no_cluster.hpp b/hopper/sm90_pipeline_no_cluster.hpp index 65a3d1554b3..1fb805aec1f 100644 --- a/hopper/sm90_pipeline_no_cluster.hpp +++ b/hopper/sm90_pipeline_no_cluster.hpp @@ -39,7 +39,7 @@ class PipelineTmaAsyncNoCluster: public Base { if (is_initializing_warp) { // Barrier FULL and EMPTY init constexpr int producer_arv_cnt = 1; - uint32_t const num_consumer_warpgroups_per_cluster = params.num_consumers / NumThreadsPerWarpGroup; + uint32_t const num_consumer_warpgroups_per_cluster = (params.num_consumers + NumThreadsPerWarpGroup - 1) / NumThreadsPerWarpGroup; uint32_t const multicast_consumer_arrival_count = num_consumer_warpgroups_per_cluster; cutlass::arch::detail::initialize_barrier_array_pair_aligned( From dfb664994c1e5056961c90d5e4f70bf7acc8af10 Mon Sep 17 00:00:00 2001 From: Johnny Date: Fri, 5 Sep 2025 17:52:06 +0200 Subject: [PATCH 094/258] [BUILD] SBSA wheels + CUDA 13 Support (#1865) * [BUILD] Update CUDA toolkit and PyTorch versions in CI configuration * [BUILD] Update CUDA toolkit and PyTorch versions in CI configuration * [BUILD] Update CUDA toolkit and PyTorch versions in CI configuration * [BUILD] Update CUDA toolkit and PyTorch versions in CI configuration * [BUILD] Update CUDA toolkit and PyTorch versions in CI configuration * drop 12.4 * drop 12.4 * fix correct name * fix correct name * fix correct name * fix correct name * cibuildwheel.yml --- .github/workflows/_build.yml | 21 +++++++++++++++------ .github/workflows/publish.yml | 5 ++++- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/.github/workflows/_build.yml b/.github/workflows/_build.yml index 47d7bb49055..3bbd5f0a4f5 100644 --- a/.github/workflows/_build.yml +++ b/.github/workflows/_build.yml @@ -77,7 +77,7 @@ jobs: - name: Install CUDA ${{ inputs.cuda-version }} if: ${{ inputs.cuda-version != 'cpu' }} - uses: Jimver/cuda-toolkit@v0.2.26 + uses: Jimver/cuda-toolkit@v0.2.27 id: cuda-toolkit with: cuda: ${{ inputs.cuda-version }} @@ -98,17 +98,26 @@ jobs: # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix # This code is ugly, maybe there's a better way to do this. export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \ - minv = {'2.4': 118, '2.5': 118, '2.6': 118, '2.7': 118, '2.8': 126}[env['MATRIX_TORCH_VERSION']]; \ - maxv = {'2.4': 124, '2.5': 124, '2.6': 126, '2.7': 128, '2.8': 129}[env['MATRIX_TORCH_VERSION']]; \ + minv = {'2.5': 118, '2.6': 118, '2.7': 118, '2.8': 126, '2.9': 126}[env['MATRIX_TORCH_VERSION']]; \ + maxv = {'2.5': 124, '2.6': 126, '2.7': 128, '2.8': 129, '2.9': 130}[env['MATRIX_TORCH_VERSION']]; \ print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \ ) + # detect if we're on ARM + if [ "$(uname -m)" = "aarch64" ] || [ "$(uname -m)" = "arm64" ]; then + PLAT=linux_aarch64 + else + PLAT=manylinux_2_27_x86_64.manylinux_2_28_x86_64 + fi + echo "PLAT=$PLAT" >> $GITHUB_ENV if [[ ${{ inputs.torch-version }} == *"dev"* ]]; then # pip install --no-cache-dir --pre torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} # Can't use --no-deps because we need cudnn etc. - # Hard-coding this version of pytorch-triton for torch 2.6.0.dev20241001 + # Hard-coding this version of pytorch-triton for torch 2.9.0.dev20250904 pip install jinja2 - pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2Bcf34004b8a-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl - pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ inputs.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl + TRITON_URL=https://download.pytorch.org/whl/nightly/pytorch_triton-3.4.0%2Bgitf7888497-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-${PLAT}.whl + TORCH_URL=https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ inputs.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-manylinux_2_28_$(uname -m).whl + pip install --no-cache-dir --pre "${TRITON_URL}" + pip install --no-cache-dir --pre "${TORCH_URL}" else pip install --no-cache-dir torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} fi diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index d11b703ef99..e88090f336d 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -40,7 +40,7 @@ jobs: matrix: # Using ubuntu-22.04 instead of 24.04 for more compatibility (glibc). Ideally we'd use the # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. - os: [ubuntu-22.04] + os: [ubuntu-22.04, ubuntu-22.04-arm] python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] torch-version: ["2.4.0", "2.5.1", "2.6.0", "2.7.1", "2.8.0"] cuda-version: ["12.9.1"] @@ -49,6 +49,9 @@ jobs: # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) # when building without C++11 ABI and using it on nvcr images. cxx11_abi: ["FALSE", "TRUE"] + include: + - torch-version: "2.9.0.dev20250904" + cuda-version: "13.0" exclude: # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix # Pytorch < 2.5 does not support Python 3.13 From e8c7344717861b6ea520de3575770ca9a7fa3877 Mon Sep 17 00:00:00 2001 From: Rajesh Shashi Kumar <35628747+rajesh-s@users.noreply.github.com> Date: Fri, 12 Sep 2025 10:00:26 -0500 Subject: [PATCH 095/258] benchmark: qualify all attention backends by methods list (#1881) --- benchmarks/benchmark_flash_attention.py | 77 +++++++++++++++---------- 1 file changed, 46 insertions(+), 31 deletions(-) diff --git a/benchmarks/benchmark_flash_attention.py b/benchmarks/benchmark_flash_attention.py index 341ae4b2139..9624ba0c334 100644 --- a/benchmarks/benchmark_flash_attention.py +++ b/benchmarks/benchmark_flash_attention.py @@ -54,7 +54,7 @@ def attention_pytorch(qkv, dropout_p=0.0, causal=True): # "triu_tril_cuda_template" not implemented for 'BFloat16' # So we have to construct the mask in float causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) - # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) + # Adding is faster than masked_fill_ scores = scores + causal_mask.to(dtype=scores.dtype) attention = torch.softmax(scores, dim=-1) attention_drop = F.dropout(attention, dropout_p) @@ -88,53 +88,65 @@ def time_fwd_bwd(func, *args, **kwargs): speed_f = {} speed_b = {} speed_f_b = {} + for causal in causal_vals: for headdim in headdim_vals: for batch_size, seqlen in bs_seqlen_vals: config = (causal, headdim, batch_size, seqlen) nheads = dim // headdim - qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, - requires_grad=True) - f, b = time_fwd_bwd( - flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False - ) - time_f[config, "Flash2"] = f - time_b[config, "Flash2"] = b - - try: - qkv = qkv.detach().requires_grad_(True) + + # FlashAttention 2 + if "Flash2" in methods: + qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, + device=device, dtype=dtype, requires_grad=True) f, b = time_fwd_bwd( - attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False + flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, + repeats=repeats, verbose=False ) - except: # Skip if OOM - f, b = float('nan'), float('nan') - time_f[config, "Pytorch"] = f - time_b[config, "Pytorch"] = b - - if attention_triton is not None: - q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype, - requires_grad=True) for _ in range(3)] - # Try both values of sequence_parallel and pick the faster one + time_f[config, "Flash2"] = f + time_b[config, "Flash2"] = b + + # PyTorch baseline + if "Pytorch" in methods: + try: + # fresh tensor avoids grad-history reuse issues + qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, + device=device, dtype=dtype, requires_grad=True) + f, b = time_fwd_bwd( + attention_pytorch, qkv, dropout_p, causal=causal, + repeats=repeats, verbose=False + ) + except Exception: + f, b = float('nan'), float('nan') + time_f[config, "Pytorch"] = f + time_b[config, "Pytorch"] = b + + # Triton + if "Triton" in methods and attention_triton is not None: + q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, + device=device, dtype=dtype, requires_grad=True) for _ in range(3)] + # Try both values of sequence_parallel and pick the faster backward try: f, b = time_fwd_bwd( attention_triton, q, k, v, causal, headdim**(-0.5), False, repeats=repeats, verbose=False ) - except: + except Exception: f, b = float('nan'), float('inf') try: _, b0 = time_fwd_bwd( attention_triton, q, k, v, causal, headdim**(-0.5), True, repeats=repeats, verbose=False ) - except: + except Exception: b0 = float('inf') time_f[config, "Triton"] = f time_b[config, "Triton"] = min(b, b0) if min(b, b0) < float('inf') else float('nan') - if xops is not None: - q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, - requires_grad=True) for _ in range(3)] + # xFormers CUTLASS + if "xformers.c" in methods and xops is not None: + q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, + device=device, dtype=dtype, requires_grad=True) for _ in range(3)] f, b = time_fwd_bwd( xops.memory_efficient_attention, q, k, v, attn_bias=xops.LowerTriangularMask() if causal else None, @@ -143,9 +155,10 @@ def time_fwd_bwd(func, *args, **kwargs): time_f[config, "xformers.c"] = f time_b[config, "xformers.c"] = b - if xops is not None: - q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, - requires_grad=True) for _ in range(3)] + # xFormers Flash + if "xformers.f" in methods and xops is not None: + q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, + device=device, dtype=dtype, requires_grad=True) for _ in range(3)] f, b = time_fwd_bwd( xops.memory_efficient_attention, q, k, v, attn_bias=xops.LowerTriangularMask() if causal else None, @@ -154,8 +167,11 @@ def time_fwd_bwd(func, *args, **kwargs): time_f[config, "xformers.f"] = f time_b[config, "xformers.f"] = b + # Report print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###") for method in methods: + if (config, method) not in time_f or (config, method) not in time_b: + continue time_f_b[config, method] = time_f[config, method] + time_b[config, method] speed_f[config, method] = efficiency( flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"), @@ -175,6 +191,5 @@ def time_fwd_bwd(func, *args, **kwargs): f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s" ) - # with open('flash2_attn_time.plk', 'wb') as fp: -# pickle.dump((speed_f, speed_b, speed_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL) +# pickle.dump((speed_f, speed_b, speed_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL) \ No newline at end of file From b3846b059bf6b143d1cd56879933be30a9f78c81 Mon Sep 17 00:00:00 2001 From: mikaylagawarecki Date: Fri, 12 Sep 2025 15:28:35 -0400 Subject: [PATCH 096/258] ABI stable fa3 (#1791) * squashed * fixes * fixes * Fix narrow * Add TORCH_STABLE_ONLY flag * new_empty + zero_ --> new_zeros * revert flash_api.cpp and add flash_api_stable.cpp * update setup.py * Only pass TORCH_STABLE_ONLY for stable build * Address Jane's comments * > to >= --- hopper/flash_api_stable.cpp | 1973 +++++++++++++++++++++++++++++++++++ hopper/setup.py | 16 +- 2 files changed, 1987 insertions(+), 2 deletions(-) create mode 100644 hopper/flash_api_stable.cpp diff --git a/hopper/flash_api_stable.cpp b/hopper/flash_api_stable.cpp new file mode 100644 index 00000000000..42601e5692d --- /dev/null +++ b/hopper/flash_api_stable.cpp @@ -0,0 +1,1973 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#include + +#include + +#include "flash.h" +#include "static_switch.h" +#include "tile_size.h" +#include "heuristics.h" +#include "cuda_check.h" + +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +using torch::stable::Tensor; + +namespace { +std::deque device_flags; +std::vector device_properties; + +void initVectors() { + static bool init_flag [[maybe_unused]] = []() { + int device_count; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess) { + STD_TORCH_CHECK(false, "cudaGetDeviceProperties failed: " + + std::string(cudaGetErrorString(err))); + } + device_flags.resize(device_count); + device_properties.resize(device_count); + return true; + }(); +} + +void initDeviceProperty(int device_index) { + cudaDeviceProp device_prop{}; + cudaError_t err = cudaGetDeviceProperties(&device_prop, device_index); + if (err != cudaSuccess) { + STD_TORCH_CHECK(false, "cudaGetDeviceProperties failed: " + + std::string(cudaGetErrorString(err))); + } + device_properties[device_index] = device_prop; +} + +// Helper function to get device properties using raw CUDA APIs +cudaDeviceProp* get_device_prop() { + initVectors(); + int device_index; + cudaError_t err = cudaGetDevice(&device_index); + if (err != cudaSuccess) { + STD_TORCH_CHECK(false, "cudaGetDevice failed: " + + std::string(cudaGetErrorString(err))); + } + + std::call_once(device_flags[device_index], initDeviceProperty, device_index); + return &device_properties[device_index]; +} +} // anonymous namespace + + +extern "C" { +/* Creates a dummy empty _C module that can be imported from Python. + The import from Python will load the .so consisting of this file + in this extension, so that the STABLE_TORCH_LIBRARY static initializers + below are run. */ +PyObject* PyInit__C(void) +{ + static struct PyModuleDef module_def = { + PyModuleDef_HEAD_INIT, + "_C", /* name of module */ + NULL, /* module documentation, may be NULL */ + -1, /* size of per-interpreter state of the module, + or -1 if the module keeps state in global variables. */ + NULL, /* methods */ + }; + return PyModule_Create(&module_def); +} +} + +#define CHECK_DEVICE(x) STD_TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) \ + do { \ + auto expected_dims = std::vector{__VA_ARGS__}; \ + STD_TORCH_CHECK(x.dim() == static_cast(expected_dims.size()), #x " must have " + std::to_string(expected_dims.size()) + " dimensions, got " + std::to_string(x.dim())); \ + for (size_t i = 0; i < expected_dims.size(); ++i) { \ + STD_TORCH_CHECK(x.size(i) == expected_dims[i], #x " dimension " + std::to_string(i) + " must have size " + std::to_string(expected_dims[i]) + ", got " + std::to_string(x.size(i))); \ + } \ + } while (0) +#define CHECK_CONTIGUOUS(x) STD_TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +#define PREPARE_VARLEN_MAX_BATCHES_1CTA 992 + +void set_params_fprop(Flash_fwd_params ¶ms, + // sizes + const size_t b, + const size_t seqlen_q, + const size_t seqlen_k, + const size_t seqlen_q_rounded, + const size_t seqlen_k_rounded, + const size_t h, + const size_t h_k, + const size_t d, + const size_t d_rounded, + // device pointers + const Tensor q, + const Tensor k, + const Tensor v, + Tensor out, + void *cu_seqlens_q_d, + void *cu_seqlens_k_d, + void *seqused_q, + void *seqused_k, + void *softmax_lse_d, + float p_dropout, + float softmax_scale, + int window_size_left, + int window_size_right, + int attention_chunk, + const float softcap=0.f, + const int sm_margin=0) { + + // Reset the parameters + params = {}; + + params.is_bf16 = q.scalar_type() == torch::headeronly::ScalarType::BFloat16; + params.is_e4m3 = q.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn; + + // Set the pointers and strides. + params.q_ptr = q.data_ptr(); + params.k_ptr = k.data_ptr(); + params.v_ptr = v.data_ptr(); + // All stride are in elements, not bytes. + params.q_row_stride = q.stride(-3); + params.k_row_stride = k.stride(-3); + params.v_row_stride = v.stride(-3); + params.q_head_stride = q.stride(-2); + params.k_head_stride = k.stride(-2); + params.v_head_stride = v.stride(-2); + params.v_dim_stride = v.stride(-1); + params.o_ptr = out.data_ptr(); + params.o_row_stride = out.stride(-3); + params.o_head_stride = out.stride(-2); + + if (cu_seqlens_q_d == nullptr) { + params.q_batch_stride = q.stride(0); + params.o_batch_stride = out.stride(0); + } + if (cu_seqlens_k_d == nullptr) { + params.k_batch_stride = k.stride(0); + params.v_batch_stride = v.stride(0); + } + + params.cu_seqlens_q = static_cast(cu_seqlens_q_d); + params.cu_seqlens_k = static_cast(cu_seqlens_k_d); + params.seqused_q = static_cast(seqused_q); + params.seqused_k = static_cast(seqused_k); + + // Softmax sum + params.softmax_lse_ptr = softmax_lse_d; + + // Set the dimensions. + params.b = b; + params.h = h; + params.h_k = h_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = d; + params.d_rounded = d_rounded; + + // Set the different scale values. + params.scale_softmax = softmax_scale; + params.softcap = softcap; + + // Set this to probability of keeping an element to simplify things. + params.p_dropout = 1.f - p_dropout; + // Convert p from float to int so we don't have to convert the random uint to float to compare. + // [Minor] We want to round down since when we do the comparison we use <= instead of < + // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0)); + // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); + params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); + params.rp_dropout = 1.f / params.p_dropout; + STD_TORCH_CHECK(p_dropout < 1.f); + #ifdef FLASHATTENTION_DISABLE_DROPOUT + STD_TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout."); + #endif + + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. + params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0; + params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal; + + // TODO: check this + if (window_size_left < 0) { window_size_left = seqlen_k - 1; } + if (window_size_right < 0) { window_size_right = seqlen_q - 1; } + if (attention_chunk > 0) { + window_size_left = std::min(window_size_left, attention_chunk - 1); + window_size_right = std::min(window_size_right, attention_chunk - 1); + } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + params.attention_chunk = attention_chunk; + + auto dprops = get_device_prop(); + params.arch = dprops->major * 10 + dprops->minor; + params.num_sm = dprops->multiProcessorCount - sm_margin; + + #ifdef FLASHATTENTION_DISABLE_LOCAL + STD_TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention."); + #endif +} + +void set_params_dgrad(Flash_bwd_params ¶ms, + // sizes + const size_t b, + const size_t seqlen_q, + const size_t seqlen_k, + const size_t seqlen_q_rounded, + const size_t seqlen_k_rounded, + const size_t h, + const size_t h_k, + const size_t d, + const size_t d_rounded, + // device pointers + const Tensor q, + const Tensor k, + const Tensor v, + const Tensor out, + const Tensor dout, + Tensor dq, + Tensor dk, + Tensor dv, + void *cu_seqlens_q_d, + void *cu_seqlens_k_d, + void *seqused_q, + void *seqused_k, + void *dq_accum_d, + void *dk_accum_d, + void *dv_accum_d, + void *softmax_lse_d, + void *dsoftmax_sum_d, + float p_dropout, + float softmax_scale, + int window_size_left, + int window_size_right, + int attention_chunk, + const float softcap=0.f, + bool deterministic=false, + int const sm_margin=0) { + + set_params_fprop(params, + b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded, + q, k, v, out, + cu_seqlens_q_d, + cu_seqlens_k_d, + seqused_q, + seqused_k, + softmax_lse_d, + p_dropout, + softmax_scale, + window_size_left, + window_size_right, + attention_chunk, + softcap, + sm_margin); + + // Set the pointers and strides. + params.do_ptr = dout.data_ptr(); + params.do_row_stride = dout.stride(-3); + params.do_head_stride = dout.stride(-2); + params.dq_ptr = dq.data_ptr(); + params.dk_ptr = dk.data_ptr(); + params.dv_ptr = dv.data_ptr(); + params.dq_row_stride = dq.stride(-3); + params.dk_row_stride = dk.stride(-3); + params.dv_row_stride = dv.stride(-3); + params.dq_head_stride = dq.stride(-2); + params.dk_head_stride = dk.stride(-2); + params.dv_head_stride = dv.stride(-2); + + if (cu_seqlens_q_d == nullptr) { + params.do_batch_stride = dout.stride(0); + params.dq_batch_stride = dq.stride(0); + params.dk_batch_stride = dk.stride(0); + params.dv_batch_stride = dv.stride(0); + } + + params.dq_accum_ptr = dq_accum_d; + params.dk_accum_ptr = dk_accum_d; + params.dv_accum_ptr = dv_accum_d; + + // Softmax sum + params.dsoftmax_sum = dsoftmax_sum_d; + + params.deterministic = deterministic; +} + +template +void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { + if (!params.is_e4m3) { + if (params.is_bf16) { + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (params.d <= 64) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF64 + if constexpr (Arch == 90) { + if (params.dv > 256) { + return run_mha_fwd_(params, stream); + } else if (params.dv > 64) { + return run_mha_fwd_(params, stream); + } + } + #endif + return run_mha_fwd_(params, stream); + } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (params.d <= 96) { return run_mha_fwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (params.d <= 128) { return run_mha_fwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (params.d <= 192) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 + if constexpr (Arch == 90) { + if (params.dv <= 128) { + return run_mha_fwd_(params, stream); + } + } + #endif + return run_mha_fwd_(params, stream); + } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (params.d <= 256) { return run_mha_fwd_(params, stream); } + #endif + } else { + #ifndef FLASHATTENTION_DISABLE_FP16 + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (params.d <= 64) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF64 + if constexpr (Arch == 90) { + if (params.dv > 256) { + return run_mha_fwd_(params, stream); + } else if (params.dv > 64) { + return run_mha_fwd_(params, stream); + } + } + #endif + return run_mha_fwd_(params, stream); + } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (params.d <= 96) { return run_mha_fwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (params.d <= 128) { return run_mha_fwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (params.d <= 192) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 + if constexpr (Arch == 90) { + if (params.dv <= 128) { + return run_mha_fwd_(params, stream); + } + } + #endif + return run_mha_fwd_(params, stream); + } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (params.d <= 256) { return run_mha_fwd_(params, stream); } + #endif + #else + STD_TORCH_CHECK(false, "This flash attention build does not support FP16."); + #endif + } + } else { + #ifndef FLASHATTENTION_DISABLE_FP8 + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (params.d <= 192) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 + if constexpr (Arch == 90) { + if (params.dv <= 128) { + return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); + } + } + #endif + return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); + } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } + #endif + #else + STD_TORCH_CHECK(false, "This flash attention build does not support FP8."); + #endif + } +} + +void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + // HEADDIM_SWITCH(params.d, [&] { + // run_mha_fwd_(params, stream); + // }); + STD_TORCH_CHECK(params.num_splits >= 1); + ARCH_SWITCH(params.arch, Arch, [&] { + SPLIT_SWITCH(params.num_splits > 1, Split, [&] { + PAGEDKV_SWITCH(params.page_table && !params.pagedkv_tma, PagedKVNonTMA, [&] { + PACKGQA_SWITCH(params.pack_gqa, PackGQA_, [&] { + // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation + static constexpr bool PackGQA = PackGQA_ || Arch < 90 || PagedKVNonTMA || Split; + SOFTCAP_SWITCH(params.softcap > 0.0, Has_softcap, [&] { + run_mha_fwd_constexpr(params, stream); + }); + }); + }); + }); + }); +} + +void run_mha_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl=false) { + #ifndef FLASHATTENTION_DISABLE_SPLIT + // If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively + // so that kBlockM is smaller and we have more parallelism. + if (params.is_fp32) { + if (params.dv <= 64) { + run_mha_fwd_combine_(params, stream, enable_pdl); + } else { + run_mha_fwd_combine_(params, stream, enable_pdl); + } + } else if (params.is_bf16) { + if (params.dv <= 64) { + run_mha_fwd_combine_(params, stream, enable_pdl); + } else { + run_mha_fwd_combine_(params, stream, enable_pdl); + } + } else { + if (params.dv <= 64) { + run_mha_fwd_combine_(params, stream, enable_pdl); + } else { + run_mha_fwd_combine_(params, stream, enable_pdl); + } + } + #else + STD_TORCH_CHECK(false, "This flash attention build does not support combine kernels."); + #endif +} + +inline bool get_pagedkv_tma(Flash_fwd_params const& params) { + if (params.arch < 90 || !params.page_table || params.leftpad_k || params.knew_ptr) { return false; } + // This needs to match the kernel configs + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, false /*paged_kv_non_TMA*/, params.softcap > 0.f); + int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90); + int const kBlockN = std::get<1>(kBlockMN_kernel_args_sm90); + // Heuristic: when seqlen_q <= kBlockM, we're not compute bound, and somehow using TMA is slower, + // at least for MLA. + return params.page_size % kBlockN == 0 && params.seqlen_q * (params.h / params.h_k) > kBlockM; +} + +inline bool get_pack_gqa(Flash_fwd_params const& params) { + // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation and binary size. + // Has little effect on speed. + if (params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1) { return true; } + #ifdef FLASHATTENTION_DISABLE_PACKGQA + return false; + #else + // params.page_table must already be set + if (params.h == params.h_k) { return false; } + // This needs to match the kernel configs + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); + int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90); + return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM); + #endif +} + +inline int get_num_splits(Flash_fwd_params const& params) { + #ifdef FLASHATTENTION_DISABLE_SPLIT + return 1; + #else + // Always enable PackGQA for Split + // params.page_table must already be set + // This needs to match the kernel configs + bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k; + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); + // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits + // has not been set here. It's OK though because we might just underestimate kBlockN a bit + auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr); + int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); + int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x); + int seqlen_q_packgqa = params.seqlen_q * (params.h / params.h_k); + // If is_local, we're not going to load all of seqlen_k + int const seqlen_k_loaded = !params.is_local + ? params.seqlen_k + : std::max(0, std::min(params.seqlen_k, params.window_size_right + params.window_size_left + 1 + kBlockM)); + int const num_n_blocks = (seqlen_k_loaded + kBlockN - 1) / kBlockN; + int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM; + int const size_one_kv_head = params.seqlen_k * (params.d + params.dv) * (params.is_e4m3 ? 1 : 2); + // Always enable PackGQA for Split + // If varlen, we use dynamic split, so this heuristic just needs to get an upper bound on num_splits. + // We assume the case where there's 1 long sequence and the rest are short, i.e. pretending + // that batch = 1. + int total_mblocks = (params.num_splits_dynamic_ptr ? 1 : params.b) * params.h_k * num_m_blocks; + return num_splits_heuristic(total_mblocks, params.num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, params.is_causal || params.is_local, 128); + #endif +} + +inline int get_max_headdim() { + #ifndef FLASHATTENTION_DISABLE_HDIM256 + return 256; + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + return 192; + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + return 128; + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + return 96; + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM64 + return 64; + #endif + return 0; +} + +inline int round_up_headdim(int head_size) { + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (head_size <= 64) { return 64; } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (head_size <= 96) { return 96; } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (head_size <= 128) { return 128; } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (head_size <= 192) { return 192; } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (head_size <= 256) { return 256; } + #endif + return 256; +} + +inline int round_up_headdimv(int head_size) { + if (head_size <= 64) { return 64; } + if (head_size <= 96) { return 96; } + if (head_size <= 128) { return 128; } + if (head_size <= 192) { return 192; } + if (head_size <= 256) { return 256; } + return 512; +} + +// Only applicable to the case where seqused_k (i.e. cache_seqlens) is available +Tensor +mha_fwd_get_scheduler_metadata( + int64_t batch_size, + int64_t max_seqlen_q, + int64_t max_seqlen_k, + int64_t num_heads, + int64_t num_heads_k, + int64_t headdim, + int64_t headdim_v, + torch::headeronly::ScalarType qkv_dtype, + Tensor seqused_k, // b + std::optional cu_seqlens_q_, // b+1 + std::optional cu_seqlens_k_, // b+1 + std::optional cu_seqlens_k_new_, // b+1 + std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + std::optional leftpad_k_, // b + std::optional page_size, + int64_t max_seqlen_k_new, // 0 means we're not appending new KV + bool is_causal, + int64_t window_size_left, + int64_t window_size_right, + int64_t attention_chunk, + bool has_softcap, + int64_t num_splits, + std::optional pack_gqa_, + int64_t sm_margin) { + + STD_TORCH_CHECK(qkv_dtype == torch::headeronly::ScalarType::Half || qkv_dtype == torch::headeronly::ScalarType::BFloat16 || qkv_dtype == torch::headeronly::ScalarType::Float8_e4m3fn, + "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type"); + STD_TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + // Reset the parameters + Flash_fwd_params params{}; + params.is_bf16 = qkv_dtype == torch::headeronly::ScalarType::BFloat16; + params.is_e4m3 = qkv_dtype == torch::headeronly::ScalarType::Float8_e4m3fn; + params.b = batch_size; + params.seqlen_q = max_seqlen_q; + params.seqlen_k = max_seqlen_k; + params.h = num_heads; + params.h_k = num_heads_k; + params.d = headdim; + params.dv = headdim_v; + params.d_rounded = round_up_headdim(headdim); + params.dv_rounded = headdim_v == headdim ? params.d_rounded : round_up_headdimv(headdim_v); + params.seqlen_knew = max_seqlen_k_new; + + bool const is_varlen_q = cu_seqlens_q_.has_value(); + params.cu_seqlens_q = is_varlen_q ? static_cast(cu_seqlens_q_.value().data_ptr()) : nullptr; + bool const is_varlen_k = cu_seqlens_k_.has_value(); + params.cu_seqlens_k = is_varlen_k ? static_cast(cu_seqlens_k_.value().data_ptr()) : nullptr; + params.cu_seqlens_knew = cu_seqlens_k_new_.has_value() ? static_cast(cu_seqlens_k_new_.value().data_ptr()): nullptr; + params.seqused_q = seqused_q_.has_value() ? static_cast(seqused_q_.value().data_ptr()) : nullptr; + params.seqused_k = static_cast(seqused_k.data_ptr()); + params.leftpad_k = leftpad_k_.has_value() ? static_cast(leftpad_k_.value().data_ptr()) : nullptr; + params.knew_ptr = params.seqlen_knew > 0 ? reinterpret_cast(1) : nullptr; + if (window_size_left >= max_seqlen_k - 1) { window_size_left = -1; } + if (window_size_right >= max_seqlen_q - 1) { window_size_right = -1; } + // causal=true is the same as causal=false in this case + if (max_seqlen_q == 1 && window_size_left == -1 && window_size_right == -1 && attention_chunk == 0) { + // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA + if ((headdim <= 64 || headdim > 128) || !page_size.has_value()) { + is_causal = false; + } + } + if (is_causal) { window_size_right = 0; } + + params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0; + params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal; + if (window_size_left < 0) { window_size_left = max_seqlen_k - 1; } + if (window_size_right < 0) { window_size_right = max_seqlen_q - 1; } + if (attention_chunk > 0) { + window_size_left = std::min(window_size_left, attention_chunk - 1); + window_size_right = std::min(window_size_right, attention_chunk - 1); + } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + params.attention_chunk = attention_chunk; + auto dprops = get_device_prop(); + params.arch = dprops->major * 10 + dprops->minor; + params.num_sm = dprops->multiProcessorCount - sm_margin; + params.softcap = has_softcap ? 1.0f : 0.0f; + + params.page_size = page_size.has_value() ? page_size.value() : 1; + params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast(1); + + bool const use_prepare_varlen = true; + params.prepare_varlen_pdl = use_prepare_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES_1CTA; + params.num_splits_dynamic_ptr = !use_prepare_varlen ? nullptr : reinterpret_cast(1); + + params.pagedkv_tma = get_pagedkv_tma(params); + params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; + // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide + params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); + + bool is_varlen = true; + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + torch::stable::accelerator::DeviceGuard device_guard{(char)seqused_k.get_device()}; + + // This needs to be set after get_num_splits + Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic + bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1; + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + params.varlen_sort_batches = !params.is_local; // Use this value for Sort in scheduler template + params.head_swizzle = params.is_causal || params.is_local; // Use this value for LPT in scheduler template + if (scheduler_needs_semaphore || use_prepare_varlen) { + int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers + int num_prepare_batch_vectors = use_prepare_varlen ? 2 : 0; + if(params.varlen_sort_batches) { num_prepare_batch_vectors += 1; } + if(params.head_swizzle) { num_prepare_batch_vectors += 1; } + int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 3 : 2); + int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors; + // printf("(Metadata) num prepare batch vectors = %d.\n", num_prepare_batch_vectors); + tile_count_semaphore = torch::stable::new_empty( + seqused_k, + {int(scheduler_needs_semaphore) + tile_count_semaphore_offset}, + std::make_optional(torch::headeronly::ScalarType::Int)); + // {num_splits_dynamic, num_m_blocks, varlen_batch_idx, num_nheads_in_l2} + params.num_splits_dynamic_ptr = use_prepare_varlen ? static_cast(tile_count_semaphore.data_ptr()) : nullptr; + params.num_m_blocks_ptr = use_prepare_varlen ? static_cast(tile_count_semaphore.data_ptr()) + b_rounded : nullptr; + params.varlen_batch_idx_ptr = use_prepare_varlen && params.varlen_sort_batches ? static_cast(tile_count_semaphore.data_ptr()) + b_rounded * 2 : nullptr; + // params.num_n_blocks_ptr = use_prepare_varlen && params.head_swizzle ? static_cast(tile_count_semaphore.data_ptr()) + head_swizzle_offset : nullptr; + params.num_nheads_in_l2_ptr = use_prepare_varlen && params.head_swizzle ? static_cast(tile_count_semaphore.data_ptr()) + head_swizzle_offset : nullptr; + if (scheduler_needs_semaphore) { + if (!use_prepare_varlen) { torch::stable::zero_(tile_count_semaphore); } // If varlen we'll manually do the zero-ing + params.tile_count_semaphore = static_cast(tile_count_semaphore.data_ptr()) + tile_count_semaphore_offset; + } else { + params.tile_count_semaphore = nullptr; + } + } + + if (use_prepare_varlen) { + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); + auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr); + int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); + int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x); + auto device_idx = torch::stable::accelerator::getCurrentDeviceIndex(); + auto stream = (cudaStream_t)torch::stable::accelerator::getCurrentStream(device_idx).id(); + prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/); + CHECK_CUDA_KERNEL_LAUNCH(); + } + return tile_count_semaphore; +} + +// b: batch_size +// b_k: batch_size_k +// s_q: seqlen_q +// s_k: seqlen_k +// s_k_new: seqlen_k_new +// h: num_heads +// h_k: num_heads_k +// d: head_size +std::tuple +mha_fwd(Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + Tensor k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table. + Tensor v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table. + std::optional k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new + std::optional v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new + std::optional q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q + std::optional out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + std::optional cu_seqlens_q_, // b+1 + std::optional cu_seqlens_k_, // b+1 + std::optional cu_seqlens_k_new_, // b+1 + std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + std::optional seqused_k_, // b. If given, only this many elements of each batch element's keys are used. + std::optional max_seqlen_q_, + // TODO: check if we need max_seqlen_k + std::optional max_seqlen_k_, + std::optional page_table_, // (b_k, max_num_pages_per_seq) + std::optional kv_batch_idx_, // b. indices to index into the KV cache + std::optional leftpad_k_, // b + std::optional rotary_cos_, // seqlen_ro x (rotary_dim / 2) + std::optional rotary_sin_, // seqlen_ro x (rotary_dim / 2) + std::optional seqlens_rotary_, // b + std::optional q_descale_, // (b, h_k), not (b, h) + std::optional k_descale_, // (b, h_k) + std::optional v_descale_, // (b, h_k) + std::optional softmax_scale_, + bool is_causal, + int64_t window_size_left, + int64_t window_size_right, + int64_t attention_chunk, + double softcap, + bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 + std::optional scheduler_metadata_, // (b + 1) + int64_t num_splits, + std::optional pack_gqa_, + int64_t sm_margin + ) { + + auto dprops = get_device_prop(); + bool is_sm8x = dprops->major >= 8; + STD_TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + + auto q_type = q.scalar_type(); + STD_TORCH_CHECK(q_type == torch::headeronly::ScalarType::Half || q_type == torch::headeronly::ScalarType::BFloat16 || q_type == torch::headeronly::ScalarType::Float8_e4m3fn, + "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type"); + if (dprops->major < 9) { + STD_TORCH_CHECK(q_type == torch::headeronly::ScalarType::Half || q_type == torch::headeronly::ScalarType::BFloat16, + "FlashAttention on Ampere/Ada cards only supports fp16 and bf16 data type"); + } + STD_TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype"); + STD_TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + + STD_TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + Tensor page_table; + const bool paged_KV = page_table_.has_value(); + if (paged_KV) { + page_table = page_table_.value(); + CHECK_DEVICE(page_table); + STD_TORCH_CHECK(page_table.scalar_type() == torch::headeronly::ScalarType::Int, "page_table must have dtype torch.int32"); + STD_TORCH_CHECK(page_table.stride(-1) == 1, "page_table must have contiguous last dimension"); + } + + Tensor cu_seqlens_q; + bool const is_varlen_q = cu_seqlens_q_.has_value(); + if (is_varlen_q) { + cu_seqlens_q = cu_seqlens_q_.value(); + CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q); + STD_TORCH_CHECK(cu_seqlens_q.scalar_type() == torch::headeronly::ScalarType::Int, "cu_seqlens_q must have dtype torch.int32"); + STD_TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided"); + } + Tensor cu_seqlens_k; + bool const is_varlen_k = cu_seqlens_k_.has_value(); + if (is_varlen_k) { + cu_seqlens_k = cu_seqlens_k_.value(); + CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k); + STD_TORCH_CHECK(cu_seqlens_k.scalar_type() == torch::headeronly::ScalarType::Int, "cu_seqlens_k must have dtype torch.int32"); + STD_TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided"); + STD_TORCH_CHECK(!paged_KV, "If cu_seqlens_k is passed in, then page table is not supported"); + STD_TORCH_CHECK(!kv_batch_idx_.has_value(), "If cu_seqlens_k is passed in, then page table is not supported"); + } + + const int batch_size = !is_varlen_q ? q.size(0) : cu_seqlens_q.size(0) - 1; + int seqlen_q = !is_varlen_q ? q.size(1) : max_seqlen_q_.value(); + int total_q = !is_varlen_q ? batch_size * q.size(1) : q.size(0); + int num_heads = q.size(-2); + int const head_size = q.size(-1); + int const head_size_v = v.size(-1); + int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.size(1); + int const num_pages = !paged_KV ? 0 : k.size(0); + int const page_size = !paged_KV ? 1 : k.size(1); + int const seqlen_k = !is_varlen_k ? (!paged_KV ? k.size(1) : max_num_pages_per_seq * page_size) : max_seqlen_k_.value(); + int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0); + int const num_heads_k = k.size(-2); + int const batch_size_k = !paged_KV ? (!is_varlen_k ? k.size(0) : cu_seqlens_k.size(0) - 1) : page_table.size(0); + double softmax_scale = 1.0 / sqrt(double(head_size)); + if (softmax_scale_.has_value()) { + softmax_scale = softmax_scale_.value(); + } + if (!kv_batch_idx_.has_value()) { + STD_TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k"); + } + int const max_headdim = get_max_headdim(); + STD_TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim)); + STD_TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + if (head_size_v != head_size) { + STD_TORCH_CHECK((head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128) || + (head_size <= 64 && head_size_v <= 512), + "If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128], " + "or (Q/K <= 64 and V <= 512)."); + STD_TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim"); + if (head_size_v > 256) { + STD_TORCH_CHECK(q_type == torch::headeronly::ScalarType::Half || q_type == torch::headeronly::ScalarType::BFloat16, + "HeaddimV > 256 requires fp16 and bf16 data type"); + } + } + + // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM + // TODO: check this + if (window_size_left >= seqlen_k - 1) { window_size_left = -1; } + if (window_size_right >= seqlen_q - 1) { window_size_right = -1; } + // causal=true is the same as causal=false in this case + if (seqlen_q == 1 && window_size_left == -1 && window_size_right == -1 && attention_chunk == 0) { + // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA + if ((head_size <= 64 || head_size > 128) || !paged_KV) { + is_causal = false; + } + } + if (is_causal) { window_size_right = 0; } + + if (!is_varlen_q) { + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); + } else { + CHECK_SHAPE(q, total_q, num_heads, head_size); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + } + if (!paged_KV) { + if (!is_varlen_k) { + CHECK_SHAPE(k, batch_size_k, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size_v); + } else { + CHECK_SHAPE(k, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_v); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + } + } else { + CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size); + CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v); + CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq); + } + + if (seqused_q_.has_value()){ + auto seqused_q = seqused_q_.value(); + STD_TORCH_CHECK(seqused_q.scalar_type() == torch::headeronly::ScalarType::Int, "seqused_q must have dtype int32"); + CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q); + CHECK_SHAPE(seqused_q, batch_size); + } + if (seqused_k_.has_value()) { + auto seqused_k = seqused_k_.value(); + STD_TORCH_CHECK(seqused_k.scalar_type() == torch::headeronly::ScalarType::Int, "seqused_k must have dtype int32"); + CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k); + CHECK_SHAPE(seqused_k, batch_size); + } + + if (leftpad_k_.has_value()) { + auto leftpad_k = leftpad_k_.value(); + STD_TORCH_CHECK(leftpad_k.scalar_type() == torch::headeronly::ScalarType::Int, "leftpad_k must have dtype int32"); + CHECK_DEVICE(leftpad_k); CHECK_CONTIGUOUS(leftpad_k); + CHECK_SHAPE(leftpad_k, batch_size); + } + + // This is what we will template on + bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value() || leftpad_k_.has_value(); + #ifdef FLASHATTENTION_DISABLE_VARLEN + STD_TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen."); + #endif + + int const alignment = q_type == torch::headeronly::ScalarType::Float8_e4m3fn ? 16 : 8; + STD_TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment)); + STD_TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment)); + + auto out_type = q_type == torch::headeronly::ScalarType::Float8_e4m3fn ? torch::headeronly::ScalarType::BFloat16 : q_type; + Tensor out; + if (out_.has_value()) { + out = out_.value(); + STD_TORCH_CHECK(out.scalar_type() == out_type, "For FP16/BF16 input, output must have the same dtype as inputs. For FP8 input, output must have dtype BF16"); + CHECK_DEVICE(out); + STD_TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + if (!is_varlen_q) { + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v); + } else { + CHECK_SHAPE(out, total_q, num_heads, head_size_v); + } + } else { + out = !is_varlen_q + ? torch::stable::new_empty(q, {batch_size, seqlen_q, num_heads, head_size_v}, std::make_optional(out_type)) + : torch::stable::new_empty(q, {total_q, num_heads, head_size_v}, std::make_optional(out_type)); + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + int const head_size_rounded = round_up_headdim(head_size); + int const head_size_v_rounded = head_size_v == head_size ? head_size_rounded : round_up_headdimv(head_size_v); + int const seqlen_q_rounded = round_multiple(seqlen_q, 128); + int const seqlen_k_rounded = round_multiple(seqlen_k, 128); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + torch::stable::accelerator::DeviceGuard device_guard{(char)q.get_device()}; + + Tensor softmax_lse; + if (!is_varlen_q) { + softmax_lse = torch::stable::new_empty(q, {batch_size, num_heads, seqlen_q}, std::make_optional(torch::headeronly::ScalarType::Float)); + } else { + softmax_lse = torch::stable::new_empty(q, {num_heads, total_q}, std::make_optional(torch::headeronly::ScalarType::Float)); + } + + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, k, v, out, + !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(), + !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(), + seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr, + seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr, + softmax_lse.data_ptr(), + /*p_dropout=*/0.f, + softmax_scale, + window_size_left, + window_size_right, + attention_chunk, + softcap, + sm_margin); + params.total_q = total_q; + params.total_k = total_k; + params.b_k = batch_size_k; + params.dv = head_size_v; + params.dv_rounded = head_size_v_rounded; + if (leftpad_k_.has_value()) { // This needs to be set before get_pagedkv_tma + params.leftpad_k = static_cast(leftpad_k_.value().data_ptr()); + } + if (paged_KV) { + params.page_table = static_cast(page_table.data_ptr()); + params.page_table_batch_stride = page_table.stride(0); + } + params.page_size = page_size; + params.num_pages = num_pages; + + if (k_new_.has_value()) { // This needs to be set before get_pagedkv_tma + Tensor k_new, v_new; + STD_TORCH_CHECK(v_new_.has_value(), "If k_new is supplied, v_new must also be passed in"); + STD_TORCH_CHECK(seqused_k_.has_value(), "If k_new is supplied, seqlens_k must also be passed in"); + STD_TORCH_CHECK(seqlen_q <= seqlen_k, "If k_new is supplied, it must have seqlen <= the seqlen of the KV cache"); + Tensor cu_seqlens_k_new; + bool const is_varlen_k_new = cu_seqlens_k_new_.has_value(); + if (is_varlen_k_new) { + cu_seqlens_k_new = cu_seqlens_k_new_.value(); + CHECK_DEVICE(cu_seqlens_k_new); CHECK_CONTIGUOUS(cu_seqlens_k_new); + STD_TORCH_CHECK(cu_seqlens_k_new.scalar_type() == torch::headeronly::ScalarType::Int, "cu_seqlens_k_new must have dtype torch.int32"); + } + k_new = k_new_.value(); + v_new = v_new_.value(); + STD_TORCH_CHECK(k_new.scalar_type() == q_type, "k_new must have the same dtype as query"); + STD_TORCH_CHECK(v_new.scalar_type() == q_type, "v_new must have the same dtype as query"); + CHECK_DEVICE(k_new); CHECK_DEVICE(v_new); + STD_TORCH_CHECK(k_new.stride(-1) == 1, "k_new tensor must have contiguous last dimension"); + STD_TORCH_CHECK(v_new.stride(-1) == 1, "v_new tensor must have contiguous last dimension"); + // We don't need max_seqlen_k_new, so seqlen_k_new can be whatever when is_varlen_k_new + int seqlen_k_new = !is_varlen_k_new ? k_new.size(1) : 0; + int total_k_new = !is_varlen_k_new ? batch_size * k_new.size(1): k_new.size(0); + if (!is_varlen_k_new) { + CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size); + CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size_v); + } else { + CHECK_SHAPE(k_new, total_k_new, num_heads_k, head_size); + CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size_v); + CHECK_SHAPE(cu_seqlens_k_new, batch_size + 1); + } + params.seqlen_knew = seqlen_k_new; + params.total_knew = total_k_new; + params.knew_ptr = k_new.data_ptr(); + params.vnew_ptr = v_new.data_ptr(); + // All stride are in elements, not bytes. + params.knew_row_stride = k_new.stride(-3); + params.vnew_row_stride = v_new.stride(-3); + params.knew_head_stride = k_new.stride(-2); + params.vnew_head_stride = v_new.stride(-2); + if (!is_varlen_k_new) { + params.knew_batch_stride = k_new.stride(0); + params.vnew_batch_stride = v_new.stride(0); + } + if (is_varlen_k_new) { + params.cu_seqlens_knew = static_cast(cu_seqlens_k_new.data_ptr()); + } + } + + bool const use_prepare_varlen = is_varlen; + params.prepare_varlen_pdl = use_prepare_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES_1CTA; + // Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it + params.num_splits_dynamic_ptr = !use_prepare_varlen ? nullptr : reinterpret_cast(1); + + params.pagedkv_tma = get_pagedkv_tma(params); + params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; + // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide + params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); + + // This needs to be set after get_num_splits + Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic + // We don't use the persistent scheduler if Split and not Varlen + bool const scheduler_needs_semaphore = params.arch >= 90 + ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen) + : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1)); + params.varlen_sort_batches = !params.is_local; // Use this value for Sort in scheduler template + params.head_swizzle = params.is_causal || params.is_local; // Use this value for LPT in scheduler template + if (scheduler_needs_semaphore || use_prepare_varlen) { + int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers + int num_prepare_batch_vectors = use_prepare_varlen ? 2 : 0; + if(params.varlen_sort_batches) { num_prepare_batch_vectors += 1; } + if(params.head_swizzle) { num_prepare_batch_vectors += 1; } + int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 3 : 2); + int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors; + int metadata_size = int(scheduler_needs_semaphore) + tile_count_semaphore_offset; + // printf("Num prepare batch vectors = %d, metadata_size = %d.\n", num_prepare_batch_vectors, metadata_size); + params.skip_scheduler_metadata_computation = scheduler_metadata_.has_value(); + if (scheduler_metadata_.has_value()) { + Tensor scheduler_metadata = scheduler_metadata_.value(); + CHECK_DEVICE(scheduler_metadata); + CHECK_SHAPE(scheduler_metadata, metadata_size); + CHECK_CONTIGUOUS(scheduler_metadata); + STD_TORCH_CHECK(scheduler_metadata.scalar_type() == torch::headeronly::ScalarType::Int, "scheduler_metadata must have dtype int32"); + tile_count_semaphore = scheduler_metadata; + } else { + tile_count_semaphore = torch::stable::new_empty(q, {metadata_size}, torch::headeronly::ScalarType::Int); + } + if (scheduler_needs_semaphore && !use_prepare_varlen) { + torch::stable::zero_(tile_count_semaphore); // If varlen we'll manually do the zero-ing + } + // {num_splits_dynamic, num_m_blocks, varlen_batch_idx, num_nheads_in_l2} + params.num_splits_dynamic_ptr = use_prepare_varlen ? static_cast(tile_count_semaphore.data_ptr()) : nullptr; + params.num_m_blocks_ptr = use_prepare_varlen ? static_cast(tile_count_semaphore.data_ptr()) + b_rounded : nullptr; + params.varlen_batch_idx_ptr = use_prepare_varlen && params.varlen_sort_batches ? static_cast(tile_count_semaphore.data_ptr()) + b_rounded * 2 : nullptr; + // params.num_n_blocks_ptr = use_prepare_varlen && params.head_swizzle ? static_cast(tile_count_semaphore.data_ptr()) + head_swizzle_offset : nullptr; + params.num_nheads_in_l2_ptr = use_prepare_varlen && params.head_swizzle ? static_cast(tile_count_semaphore.data_ptr()) + head_swizzle_offset : nullptr; + params.tile_count_semaphore = scheduler_needs_semaphore ? static_cast(tile_count_semaphore.data_ptr()) + tile_count_semaphore_offset : nullptr; + params.tile_count_semaphore_offset = tile_count_semaphore_offset; // might need to zero out semaphore later + } + + if (q_v_.has_value()) { + STD_TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); + STD_TORCH_CHECK(head_size_v >= 256, "q_v is only supported for hdim_v >= 256."); + STD_TORCH_CHECK(q_type == torch::headeronly::ScalarType::Half || q_type == torch::headeronly::ScalarType::BFloat16, + "q_v is only supported for fp16 and bf16 data type"); + STD_TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs"); + Tensor q_v = q_v_.value(); + STD_TORCH_CHECK(q_v.scalar_type() == q_type, "q_v must have the same dtype as query"); + CHECK_DEVICE(q_v); + STD_TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension"); + if (!is_varlen_q) { + CHECK_SHAPE(q_v, batch_size, seqlen_q, num_heads, head_size_v); + } else { + CHECK_SHAPE(q_v, total_q, num_heads, head_size_v); + } + params.qv_ptr = q_v.data_ptr(); + // All stride are in elements, not bytes. + params.qv_row_stride = q_v.stride(-3); + params.qv_head_stride = q_v.stride(-2); + if (!is_varlen_q) { + params.qv_batch_stride = q_v.stride(0); + } + } + + if (rotary_cos_.has_value()) { + STD_TORCH_CHECK(k_new_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided"); + auto rotary_cos = rotary_cos_.value(); + CHECK_DEVICE(rotary_cos); CHECK_CONTIGUOUS(rotary_cos); + params.rotary_dim = rotary_cos.size(1) * 2; + STD_TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim"); + STD_TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported"); + const int seqlen_ro = rotary_cos.size(0); + if (paged_KV) { + STD_TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache"); + } + CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2); + STD_TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); + + STD_TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided"); + auto rotary_sin = rotary_sin_.value(); + CHECK_DEVICE(rotary_sin); CHECK_CONTIGUOUS(rotary_sin); + CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2); + STD_TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); + params.rotary_cos_ptr = rotary_cos.data_ptr(); + params.rotary_sin_ptr = rotary_sin.data_ptr(); + params.is_rotary_interleaved = is_rotary_interleaved; + if (seqlens_rotary_.has_value()) { + Tensor seqlens_rotary = seqlens_rotary_.value(); + CHECK_DEVICE(seqlens_rotary); CHECK_CONTIGUOUS(seqlens_rotary); + STD_TORCH_CHECK(seqlens_rotary.scalar_type() == torch::headeronly::ScalarType::Int, "seqlens_rotary must have dtype torch.int32"); + CHECK_SHAPE(seqlens_rotary, batch_size); + params.seqlens_rotary = static_cast(seqlens_rotary.data_ptr()); + } + } else { + params.rotary_dim = 0; + } + + if (kv_batch_idx_.has_value()) { + auto kv_batch_idx = kv_batch_idx_.value(); + CHECK_DEVICE(kv_batch_idx); CHECK_CONTIGUOUS(kv_batch_idx); + STD_TORCH_CHECK(kv_batch_idx.scalar_type() == torch::headeronly::ScalarType::Int, "kv_batch_idx must have dtype int32"); + params.kv_batch_idx = reinterpret_cast(kv_batch_idx.data_ptr()); + } + + Tensor out_accum, softmax_lse_accum; + auto outaccum_type = torch::headeronly::ScalarType::Float; + if (params.num_splits > 1) { + STD_TORCH_CHECK(params.num_splits <= 256, "num_splits > 256 not supported"); + if (!is_varlen_q) { + out_accum = torch::stable::new_empty(q, {params.num_splits, batch_size, num_heads, seqlen_q, head_size_v}, std::make_optional(outaccum_type)); + softmax_lse_accum = torch::stable::new_empty(q, {params.num_splits, batch_size, num_heads, seqlen_q}, std::make_optional(torch::headeronly::ScalarType::Float)); + params.oaccum_batch_stride = out_accum.stride(1); + params.lseaccum_batch_stride = softmax_lse_accum.stride(1); + } else { + out_accum = torch::stable::new_empty(q, {params.num_splits, num_heads, total_q, head_size_v}, std::make_optional(outaccum_type)); + softmax_lse_accum = torch::stable::new_empty(q, {params.num_splits, num_heads, total_q}, std::make_optional(torch::headeronly::ScalarType::Float)); + } + params.is_fp32 = false; + params.oaccum_ptr = out_accum.data_ptr(); + params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); + params.oaccum_split_stride = out_accum.stride(0); + params.oaccum_row_stride = out_accum.stride(-2); + params.oaccum_head_stride = out_accum.stride(-3); + params.lseaccum_split_stride = softmax_lse_accum.stride(0); + params.lseaccum_head_stride = softmax_lse_accum.stride(-2); + } + + if (q_type == torch::headeronly::ScalarType::Float8_e4m3fn) { + if (q_descale_.has_value()) { + auto q_descale = q_descale_.value(); + CHECK_DEVICE(q_descale); + CHECK_SHAPE(q_descale, batch_size, num_heads_k); + params.q_descale_ptr = static_cast(q_descale.data_ptr()); + params.q_descale_batch_stride = q_descale.stride(0); + params.q_descale_head_stride = q_descale.stride(1); + } else { + params.q_descale_ptr = nullptr; + } + if (k_descale_.has_value()) { + auto k_descale = k_descale_.value(); + CHECK_DEVICE(k_descale); + CHECK_SHAPE(k_descale, batch_size, num_heads_k); + params.k_descale_ptr = static_cast(k_descale.data_ptr()); + params.k_descale_batch_stride = k_descale.stride(0); + params.k_descale_head_stride = k_descale.stride(1); + } else { + params.k_descale_ptr = nullptr; + } + if (v_descale_.has_value()) { + auto v_descale = v_descale_.value(); + CHECK_DEVICE(v_descale); + CHECK_SHAPE(v_descale, batch_size, num_heads_k); + params.v_descale_ptr = static_cast(v_descale.data_ptr()); + params.v_descale_batch_stride = v_descale.stride(0); + params.v_descale_head_stride = v_descale.stride(1); + } else { + params.v_descale_ptr = nullptr; + } + } + + #ifdef FLASHATTENTION_DISABLE_LOCAL + STD_TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention."); + #endif + #ifdef FLASHATTENTION_DISABLE_SOFTCAP + STD_TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping."); + #endif + #ifdef FLASHATTENTION_DISABLE_SPLIT + STD_TORCH_CHECK(params.num_splits == 1, "This flash attention build does not support splits."); + #endif + #ifdef FLASHATTENTION_DISABLE_PACKGQA + STD_TORCH_CHECK(!params.pack_gqa || params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1, "This flash attention build does not support pack_gqa."); + #endif + #ifdef FLASHATTENTION_DISABLE_PAGEDKV + STD_TORCH_CHECK(!(params.page_table && !params.pagedkv_tma), "This flash attention build does not support paged KV."); + #endif + #ifdef FLASHATTENTION_DISABLE_APPENDKV + STD_TORCH_CHECK(!k_new_.has_value(), "This flash attention build does not support appending KV."); + #endif + + if (total_q > 0 && (total_k + params.total_knew) > 0 && num_heads_k > 0) { + auto device_idx = torch::stable::accelerator::getCurrentDeviceIndex(); + auto stream = (cudaStream_t)torch::stable::accelerator::getCurrentStream(device_idx).id(); + run_mha_fwd(params, stream); + if (params.num_splits > 1) { + if (out_type == torch::headeronly::ScalarType::BFloat16) { + // Since we want output in BF16. Otherwise fwd_combine will output to FP16 + params.is_bf16 = true; + } + // Unless there's seqused_q, for the purpose of attn_combine, we can just treat it as batch=1 + // and seqlen = total_q, and don't need to dispatch to Varlen there. + // However, with dynamic split, each row needs to know which batch it belongs to + // to read the number of splits, so we just use the varlen version of combine kernel. + // if (is_varlen_q && !seqused_q_.has_value()) { + // if (is_varlen_q) { + // params.b = 1; + // params.seqlen_q = total_q; + // } + // This will zero out the semaphore if needed + run_mha_fwd_combine(params, stream, true /*enable_pdl*/); + } else if (scheduler_needs_semaphore && params.skip_scheduler_metadata_computation) { + // need to zero out the semaphore in this case + auto slice = torch::stable::narrow(tile_count_semaphore, 0, params.tile_count_semaphore_offset, 1); + torch::stable::zero_(slice); + } + } else if (total_q > 0 && num_heads_k > 0) { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. + torch::stable::zero_(out); + torch::stable::fill_(softmax_lse, std::numeric_limits::infinity()); + } + + // return {out, softmax_lse}; + return {out, softmax_lse, out_accum, softmax_lse_accum}; +} + +#ifdef FLASHATTENTION_DISABLE_BACKWARD +void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { + STD_TORCH_CHECK(false, "Flash-Attention was built with backward disabled"); +} +#else +template +void run_mha_bwd_constexpr(Flash_bwd_params ¶ms, cudaStream_t stream) { + if (!params.is_bf16) { + #ifndef FLASHATTENTION_DISABLE_FP16 + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (params.d_rounded == 64) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (params.d_rounded == 96) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (params.d_rounded == 128) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (params.d_rounded == 192) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (params.d_rounded == 256) { return run_mha_bwd_(params, stream); } + #endif + #else + STD_TORCH_CHECK(false, "This flash attention build does not support FP16."); + #endif + } else { + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (params.d_rounded == 64) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (params.d_rounded == 96) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (params.d_rounded == 128) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (params.d_rounded == 192) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (params.d_rounded == 256) { return run_mha_bwd_(params, stream); } + #endif + } +} + +void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { + // FP16_SWITCH(!params.is_bf16, [&] { + // HEADDIM_SWITCH(params.d, [&] { + // run_mha_bwd_(params, stream); + // }); + // }); + ARCH_SWITCH(params.arch, Arch, [&] { + SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] { + run_mha_bwd_constexpr(params, stream); + }); + }); +} +#endif + + +// b: batch_size +// s_q: seqlen_q +// s_k: seqlen_k +// h: num_heads +// h_k: num_heads_k +// d: head_size +std::tuple mha_bwd( + Tensor dout, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + Tensor k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k + Tensor v, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k + Tensor out, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + Tensor softmax_lse, // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q + std::optional dq_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + std::optional dk_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k + std::optional dv_, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k + std::optional cu_seqlens_q_, // b+1 + std::optional cu_seqlens_k_, // b+1 + std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + std::optional seqused_k_, // b. If given, only this many elements of each batch element's keys are used. + std::optional max_seqlen_q_, + std::optional max_seqlen_k_, + std::optional softmax_scale_, + bool is_causal, + int64_t window_size_left, + int64_t window_size_right, + double softcap, + bool deterministic, + int64_t sm_margin +) { + + #ifdef FLASHATTENTION_DISABLE_BACKWARD + STD_TORCH_CHECK(false, "This flash attention build does not support backward."); + #endif + + auto dprops = get_device_prop(); + bool is_sm8x = dprops->major >= 8; + STD_TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + + auto q_type = q.scalar_type(); + STD_TORCH_CHECK(q_type == torch::headeronly::ScalarType::Half || q_type == torch::headeronly::ScalarType::BFloat16, + "FlashAttention only support fp16 and bf16 data type"); + STD_TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype"); + STD_TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype"); + STD_TORCH_CHECK(out.scalar_type() == q_type, "query and out must have the same dtype"); + STD_TORCH_CHECK(dout.scalar_type() == q_type, "query and dout must have the same dtype"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); + + STD_TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); + STD_TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + + Tensor cu_seqlens_q; + bool const is_varlen_q = cu_seqlens_q_.has_value(); + if (is_varlen_q) { + cu_seqlens_q = cu_seqlens_q_.value(); + CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q); + STD_TORCH_CHECK(cu_seqlens_q.scalar_type() == torch::headeronly::ScalarType::Int, "cu_seqlens_q must have dtype torch.int32"); + STD_TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided"); + } + Tensor cu_seqlens_k; + bool const is_varlen_k = cu_seqlens_k_.has_value(); + if (is_varlen_k) { + cu_seqlens_k = cu_seqlens_k_.value(); + CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k); + STD_TORCH_CHECK(cu_seqlens_k.scalar_type() == torch::headeronly::ScalarType::Int, "cu_seqlens_k must have dtype torch.int32"); + STD_TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided"); + } + // This is what we will template on + bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value(); + #ifdef FLASHATTENTION_DISABLE_VARLEN + STD_TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen."); + #endif + + // auto const sizes = q.sizes(); + int const batch_size = !is_varlen_q ? q.size(0) : cu_seqlens_q.size(0) - 1; + int const seqlen_q = !is_varlen_q ? q.size(1) : max_seqlen_q_.value(); + int const total_q = !is_varlen_q ? batch_size * q.size(1) : q.size(0); + int const num_heads = q.size(-2); + int const head_size = q.size(-1); + int const head_size_v = v.size(-1); + int const seqlen_k = !is_varlen_k ? k.size(1) : max_seqlen_k_.value(); + int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0); + int const num_heads_k = k.size(-2); + STD_TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); + STD_TORCH_CHECK(head_size_v % 8 == 0, "head_size_v should be a multiple of 8"); + int const max_headdim = get_max_headdim(); + STD_TORCH_CHECK(std::max(head_size, head_size_v) <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim)); + STD_TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + double softmax_scale = 1.0 / sqrt(double(head_size)); + if (softmax_scale_.has_value()) { + softmax_scale = softmax_scale_.value(); + } + + // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM + if (window_size_left >= seqlen_k - 1) { window_size_left = -1; } + if (window_size_right >= seqlen_q - 1) { window_size_right = -1; } + if (is_causal) { window_size_right = 0; } + // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_bprop will set params.is_causal=true. + // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM (and cause IMA). + is_causal = window_size_left < 0 && window_size_right == 0; + + int const arch = dprops->major * 10 + dprops->minor; + int const head_size_rounded = round_up_headdim(std::max(head_size, head_size_v)); + int const head_size_v_rounded = head_size_rounded; + // Very important that these match the kernel configs + bool const is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal; + int const kBlockM_sm90 = head_size_rounded <= 64 ? (is_causal && softcap > 0.0 ? 96 : 128) + : (head_size_rounded <= 96 ? 64 + : (head_size_rounded <= 128 ? (is_causal || is_local || softcap > 0.0 ? 64 : 80) + : 64)); + int const kBlockM_sm80 = head_size_rounded <= 64 ? 128 : 64; + int const kBlockM_sm86 = head_size_rounded <= 192 ? 64 : 32; + int const kBlockM = arch >= 90 ? kBlockM_sm90 : (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80); + int const kBlockN_sm90 = head_size_rounded <= 128 + ? 128 + : (head_size_rounded <= 192 ? 96 : 80); + int const kBlockN_sm80 = head_size_rounded <= 128 + ? 128 + : (head_size_rounded <= 192 ? 80 : 64); + int const kBlockN_sm86 = head_size_rounded <= 64 ? 128 + : (head_size_rounded <= 96 ? 128 + : (head_size_rounded <= 128 ? 96 + : (head_size_rounded <= 192 ? 64 : 64))); + int const kBlockN = arch >= 90 ? kBlockN_sm90 : (arch == 86 || arch == 89 ? kBlockN_sm86 : kBlockN_sm80); + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + int const seqlen_q_rounded = round_multiple(seqlen_q, kBlockM); + int const seqlen_k_rounded = round_multiple(seqlen_k, kBlockN); + int const total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM); + int const total_k_padded_rounded = round_multiple(total_k + batch_size * kBlockN, kBlockN); + + if (!is_varlen_q) { + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v); + CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_v); + } else { + CHECK_SHAPE(q, total_q, num_heads, head_size); + CHECK_SHAPE(out, total_q, num_heads, head_size_v); + CHECK_SHAPE(dout, total_q, num_heads, head_size_v); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + } + if (!is_varlen_k) { + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_v); + } else { + CHECK_SHAPE(k, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_v); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + } + + if (seqused_q_.has_value()){ + auto seqused_q = seqused_q_.value(); + STD_TORCH_CHECK(seqused_q.scalar_type() == torch::headeronly::ScalarType::Int, "seqused_q must have dtype int32"); + CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q); + CHECK_SHAPE(seqused_q, batch_size); + } + if (seqused_k_.has_value()){ + auto seqused_k = seqused_k_.value(); + STD_TORCH_CHECK(seqused_k.scalar_type() == torch::headeronly::ScalarType::Int, "seqused_k must have dtype int32"); + CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k); + CHECK_SHAPE(seqused_k, batch_size); + } + + Tensor dq, dk, dv; + if (dq_.has_value()) { + dq = dq_.value(); + STD_TORCH_CHECK(dq.scalar_type() == q_type, "dq must have the same dtype as q"); + CHECK_DEVICE(dq); + STD_TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); + if (!is_varlen_q) { + CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size); + } else { + CHECK_SHAPE(dq, total_q, num_heads, head_size); + } + } else { + dq = torch::stable::empty_like(q); + } + if (dk_.has_value()) { + dk = dk_.value(); + STD_TORCH_CHECK(dk.scalar_type() == q_type, "dk must have the same dtype as q"); + CHECK_DEVICE(dk); + STD_TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); + if (!is_varlen_k) { + CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size); + } else { + CHECK_SHAPE(dk, total_k, num_heads_k, head_size); + } + } else { + dk = torch::stable::empty_like(k); + } + if (dv_.has_value()) { + dv = dv_.value(); + STD_TORCH_CHECK(dv.scalar_type() == q_type, "dv must have the same dtype as q"); + CHECK_DEVICE(dv); + STD_TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); + if (!is_varlen_k) { + CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size_v); + } else { + CHECK_SHAPE(dv, total_k, num_heads_k, head_size_v); + } + } else { + dv = torch::stable::empty_like(v); + } + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + torch::stable::accelerator::DeviceGuard device_guard{(char)q.get_device()}; + + // auto opts = q.options(); + // Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64 + Tensor softmax_d, softmax_lse_log2; + if (!is_varlen) { + // Need softmax_d to have seqlen_q_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64 + softmax_d = torch::stable::new_empty(q, {batch_size, num_heads, seqlen_q_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); + softmax_lse_log2 = torch::stable::new_empty(q, {batch_size, num_heads, seqlen_q_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); + } else { + softmax_d = torch::stable::new_empty(q, {num_heads, total_q_padded_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); + softmax_lse_log2 = torch::stable::new_empty(q, {num_heads, total_q_padded_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); + } + Tensor dq_accum, dk_accum, dv_accum; + if (!is_varlen) { + dq_accum = torch::stable::new_empty(q, {batch_size, num_heads, seqlen_q_rounded * head_size_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); + } else { + dq_accum = torch::stable::new_empty(q, {num_heads, total_q_padded_rounded * head_size_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); + } + if (num_heads_k != num_heads) { // MQA / GQA + if (!is_varlen) { + dk_accum = torch::stable::new_empty(q, {batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); + dk_accum = torch::stable::fill_(dk_accum, 0.0); + dv_accum = torch::stable::new_empty(q, {batch_size, num_heads_k, seqlen_k_rounded * head_size_v_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); + dv_accum = torch::stable::fill_(dv_accum, 0.0); + } else { + dk_accum = torch::stable::new_empty(q, {num_heads_k, total_k_padded_rounded, head_size_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); + dk_accum = torch::stable::fill_(dk_accum, 0.0); + dv_accum = torch::stable::new_empty(q, {num_heads_k, total_k_padded_rounded, head_size_v_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); + dv_accum = torch::stable::fill_(dv_accum, 0.0); + } + } + + Flash_bwd_params params; + set_params_dgrad(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, k, v, out, + dout, dq, dk, dv, + !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(), + !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(), + seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr, + seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr, + dq_accum.data_ptr(), + num_heads_k != num_heads ? dk_accum.data_ptr() : nullptr, + num_heads_k != num_heads ? dv_accum.data_ptr() : nullptr, + softmax_lse.data_ptr(), + softmax_d.data_ptr(), + /*p_dropout=*/0.f, + softmax_scale, + window_size_left, + window_size_right, + 0, // attention_chunk + softcap, + deterministic, + sm_margin); + params.total_q = total_q; + params.total_k = total_k; + params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr(); + params.dv = head_size_v; + params.dv_rounded = head_size_v_rounded; + + // auto tile_count_semaphore = (params.is_causal || params.is_local) ? torch::zeros({1}, opts.dtype(torch::headeronly::ScalarType::Int)) : torch::empty({1}, opts.dtype(torch::headeronly::ScalarType::Int)); + // params.tile_count_semaphore = static_cast(tile_count_semaphore.data_ptr()); + // Will be zero'ed out in the backward preprocess kernel + Tensor dq_semaphore = torch::stable::new_empty(q, {(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, std::make_optional(torch::headeronly::ScalarType::Int)); + params.dq_semaphore = static_cast(dq_semaphore.data_ptr()); + if (num_heads_k != num_heads && params.deterministic) { + // TODO: maybe also zero'ed out dk_semaphore and dv_semaphore in the backward preprocess kernel + Tensor dk_semaphore = torch::stable::new_zeros(q, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, std::make_optional(torch::headeronly::ScalarType::Int)); + Tensor dv_semaphore = torch::stable::new_zeros(q, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, std::make_optional(torch::headeronly::ScalarType::Int)); + params.dk_semaphore = static_cast(dk_semaphore.data_ptr()); + params.dv_semaphore = static_cast(dv_semaphore.data_ptr()); + } + + #ifdef FLASHATTENTION_DISABLE_LOCAL + STD_TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention."); + #endif + #ifdef FLASHATTENTION_DISABLE_SOFTCAP + STD_TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping."); + #endif + + if (total_q > 0 && total_k > 0 && num_heads_k > 0) { + auto device_idx = torch::stable::accelerator::getCurrentDeviceIndex(); + auto stream = (cudaStream_t)torch::stable::accelerator::getCurrentStream(device_idx).id(); + run_mha_bwd(params, stream); + } else if (total_k > 0 && num_heads_k > 0) { + // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. + torch::stable::zero_(dk); + torch::stable::zero_(dv); + torch::stable::zero_(softmax_d); + } else if (total_q > 0 && num_heads_k > 0) { + torch::stable::zero_(dq); + torch::stable::zero_(softmax_d); + } + + return { dq, dk, dv, softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum }; +} + +std::tuple +mha_combine(Tensor out_partial, // num_splits x batch_size x seqlen x num_heads x head_size + Tensor lse_partial, // num_splits x batch_size x seqlen x num_heads + std::optional out_, // batch_size x seqlen x num_heads x head_size + std::optional out_dtype_ + ) { + + auto dprops = get_device_prop(); + bool is_sm8x = dprops->major >= 8; + STD_TORCH_CHECK(is_sm8x, "Attention combine function only supports Ampere GPUs or newer."); + + auto out_partial_type = out_partial.scalar_type(); + STD_TORCH_CHECK(out_partial_type == torch::headeronly::ScalarType::Float, "Attention combine function only support fp32 data type"); + STD_TORCH_CHECK(lse_partial.scalar_type() == torch::headeronly::ScalarType::Float, "Attention combine function only support fp32 data type"); + + CHECK_DEVICE(out_partial); CHECK_DEVICE(lse_partial); + + STD_TORCH_CHECK(out_partial.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(lse_partial.stride(-2) == 1, "LSE tensor must be contiguous in the seqlen dimension"); + + // const auto sizes = out_partial.sizes(); + + const int num_splits = out_partial.size(0); + const int batch_size = out_partial.size(1); + const int seqlen = out_partial.size(2); + const int num_heads = out_partial.size(3); + const int head_size_og = out_partial.size(4); + STD_TORCH_CHECK(num_splits <= 256, "FlashAttention combine only supports num_splits at most 256"); + + CHECK_SHAPE(out_partial, num_splits, batch_size, seqlen, num_heads, head_size_og); + CHECK_SHAPE(lse_partial, num_splits, batch_size, seqlen, num_heads); + + int const alignment = 4; + Tensor out_partial_padded; + auto pad = [](Tensor x, int alignment) { + return x.size(-1) % alignment == 0 ? x : torch::stable::pad(x, {0, alignment - x.size(-1) % alignment}); + }; + out_partial_padded = pad(out_partial, alignment); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size = round_multiple(head_size_og, alignment); + + // auto opts = out_partial.options(); + torch::headeronly::ScalarType out_type = out_dtype_.value_or(out_partial.scalar_type()); + STD_TORCH_CHECK(out_type == torch::headeronly::ScalarType::Float || out_type == torch::headeronly::ScalarType::BFloat16 || out_type == torch::headeronly::ScalarType::Half, "Output type must be FP32, FP16 or BF16"); + Tensor out; + if (out_.has_value()) { + out = out_.value(); + STD_TORCH_CHECK(out.scalar_type() == out_type); + CHECK_DEVICE(out); + STD_TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, batch_size, seqlen, num_heads, head_size_og); + if (head_size_og % alignment != 0) { + out = torch::stable::new_empty(out_partial, {batch_size, seqlen, num_heads, head_size}, std::make_optional(out_type)); + } + } else { + out = torch::stable::new_empty(out_partial, {batch_size, seqlen, num_heads, head_size}, std::make_optional(out_type)); + } + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + torch::stable::accelerator::DeviceGuard device_guard{(char)out_partial.get_device()}; + + auto softmax_lse = torch::stable::new_empty(out_partial, {batch_size, num_heads, seqlen}, std::make_optional(torch::headeronly::ScalarType::Float)); + softmax_lse = torch::stable::transpose(softmax_lse, 1, 2); + + Flash_fwd_params params {}; // Need to reset the params to set everything to zero + params.is_fp32 = out_type == torch::headeronly::ScalarType::Float; + params.is_bf16 = out_type == torch::headeronly::ScalarType::BFloat16; + params.oaccum_ptr = out_partial_padded.data_ptr(); + params.softmax_lseaccum_ptr = lse_partial.data_ptr(); + params.o_ptr = out.data_ptr(); + params.softmax_lse_ptr = softmax_lse.data_ptr(); + params.b = batch_size; + params.h = num_heads; + params.seqlen_q = seqlen; + params.dv = head_size; + params.num_splits = num_splits; + params.oaccum_split_stride = out_partial_padded.stride(0); + params.oaccum_row_stride = out_partial_padded.stride(2); + params.oaccum_head_stride = out_partial_padded.stride(3); + params.oaccum_batch_stride = out_partial_padded.stride(1); + params.lseaccum_split_stride = lse_partial.stride(0); + params.lseaccum_head_stride = lse_partial.stride(3); + params.lseaccum_batch_stride = lse_partial.stride(1); + params.o_row_stride = out.stride(1); + params.o_head_stride = out.stride(2); + params.o_batch_stride = out.stride(0); + params.arch = dprops->major * 10 + dprops->minor; + + if (seqlen > 0 && batch_size > 0) { + auto device_idx = torch::stable::accelerator::getCurrentDeviceIndex(); + auto stream = (cudaStream_t)torch::stable::accelerator::getCurrentStream(device_idx).id(); + run_mha_fwd_combine(params, stream, false /*enable_pdl*/); + } + + Tensor out_padded = out; + if (head_size_og % alignment != 0) { + out = torch::stable::narrow(out, -1, 0, head_size_og); + // if (out_.has_value()) { out_.value().copy_(out); } + } + + return {out, softmax_lse}; +} + +void boxed_mha_fwd( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs +) { + auto q = to(stack[0]); + auto k = to(stack[1]); + auto v = to(stack[2]); + auto k_new = to>(stack[3]); + auto v_new = to>(stack[4]); + auto q_v = to>(stack[5]); + auto out = to>(stack[6]); + auto cu_seqlens_q = to>(stack[7]); + auto cu_seqlens_k = to>(stack[8]); + auto cu_seqlens_k_new = to>(stack[9]); + auto seqused_q = to>(stack[10]); + auto seqused_k = to>(stack[11]); + auto max_seqlen_q = to>(stack[12]); + auto max_seqlen_k = to>(stack[13]); + auto page_table = to>(stack[14]); + auto kv_batch_idx = to>(stack[15]); + auto leftpad_k = to>(stack[16]); + auto rotary_cos = to>(stack[17]); + auto rotary_sin = to>(stack[18]); + auto seqlens_rotary = to>(stack[19]); + auto q_descale = to>(stack[20]); + auto k_descale = to>(stack[21]); + auto v_descale = to>(stack[22]); + auto softmax_scale = to>(stack[23]); + auto is_causal = to(stack[24]); + auto window_size_left = to(stack[25]); + auto window_size_right = to(stack[26]); + auto attention_chunk = to(stack[27]); + auto softcap = to(stack[28]); + auto is_rotary_interleaved = to(stack[29]); + auto scheduler_metadata = to>(stack[30]); + auto num_splits = to(stack[31]); + auto pack_gqa = to>(stack[32]); + auto sm_margin = to(stack[33]); + + auto [out_, softmax_lse, out_accum, softmax_lse_accum] = mha_fwd(q, k, v, k_new, v_new, q_v, out, cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new, seqused_q, seqused_k, max_seqlen_q, max_seqlen_k, page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, seqlens_rotary, q_descale, k_descale, v_descale, softmax_scale, is_causal, window_size_left, window_size_right, attention_chunk, softcap, is_rotary_interleaved, scheduler_metadata, num_splits, pack_gqa, sm_margin); + + + stack[0] = from(out_); + stack[1] = from(softmax_lse); + stack[2] = from(out_accum); + stack[3] = from(softmax_lse_accum); +} + +void boxed_mha_bwd( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs +) { + auto dout = to(stack[0]); + auto q = to(stack[1]); + auto k = to(stack[2]); + auto v = to(stack[3]); + auto out = to(stack[4]); + auto softmax_lse = to(stack[5]); + auto dq = to>(stack[6]); + auto dk = to>(stack[7]); + auto dv = to>(stack[8]); + auto cu_seqlens_q = to>(stack[9]); + auto cu_seqlens_k = to>(stack[10]); + auto seqused_q = to>(stack[11]); + auto seqused_k = to>(stack[12]); + auto max_seqlen_q = to>(stack[13]); + auto max_seqlen_k = to>(stack[14]); + auto softmax_scale = to>(stack[15]); + auto is_causal = to(stack[16]); + auto window_size_left = to(stack[17]); + auto window_size_right = to(stack[18]); + auto softcap = to(stack[19]); + auto deterministic = to(stack[20]); + auto sm_margin = to(stack[21]); + + auto [dq_, dk_, dv_, softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum] = mha_bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, max_seqlen_q, max_seqlen_k, softmax_scale, is_causal, window_size_left, window_size_right, softcap, deterministic, sm_margin); + + stack[0] = from(dq_); + stack[1] = from(dk_); + stack[2] = from(dv_); + stack[3] = from(softmax_d); + stack[4] = from(softmax_lse_log2); + stack[5] = from(dq_accum); + stack[6] = from(dk_accum); + stack[7] = from(dv_accum); +} + +void boxed_mha_combine( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs +) { + auto out_partial = to(stack[0]); + auto lse_partial = to(stack[1]); + auto out = to>(stack[2]); + auto out_dtype = to>(stack[3]); + + auto [out_, softmax_lse] = mha_combine(out_partial, lse_partial, out, out_dtype); + + stack[0] = from(out_); + stack[1] = from(softmax_lse); +} + +void boxed_mha_fwd_get_scheduler_metadata( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs +) { + auto batch_size = to(stack[0]); + auto max_seqlen_q = to(stack[1]); + auto max_seqlen_k = to(stack[2]); + auto num_heads = to(stack[3]); + auto num_heads_k = to(stack[4]); + auto headdim = to(stack[5]); + auto headdim_v = to(stack[6]); + auto qkv_dtype = to(stack[7]); + auto seqused_k = to(stack[8]); + auto cu_seqlens_q = to>(stack[9]); + auto cu_seqlens_k = to>(stack[10]); + auto cu_seqlens_k_new = to>(stack[11]); + auto seqused_q = to>(stack[12]); + auto leftpad_k = to>(stack[13]); + auto page_size = to>(stack[14]); + auto max_seqlen_k_new = to(stack[15]); + auto is_causal = to(stack[16]); + auto window_size_left = to(stack[17]); + auto window_size_right = to(stack[18]); + auto attention_chunk = to(stack[19]); + auto has_softcap = to(stack[20]); + auto num_splits = to(stack[21]); + auto pack_gqa = to>(stack[22]); + auto sm_margin = to(stack[23]); + + auto scheduler_metadata = mha_fwd_get_scheduler_metadata(batch_size, max_seqlen_q, max_seqlen_k, num_heads, num_heads_k, headdim, headdim_v, qkv_dtype, seqused_k, cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new, seqused_q, leftpad_k, page_size, max_seqlen_k_new, is_causal, window_size_left, window_size_right, attention_chunk, has_softcap, num_splits, pack_gqa, sm_margin); + + stack[0] = from(scheduler_metadata); +} + +STABLE_TORCH_LIBRARY(flash_attn_3, m) { + m.def("fwd(" + "Tensor q," + "Tensor k," + "Tensor v," + "Tensor(k_new!)? k_new = None," + "Tensor(v_new!)? v_new = None," + "Tensor? q_v = None," + "Tensor(out!)? out = None," + "Tensor? cu_seqlens_q = None," + "Tensor? cu_seqlens_k = None," + "Tensor? cu_seqlens_k_new = None," + "Tensor? seqused_q = None," + "Tensor? seqused_k = None," + "int? max_seqlen_q = None," + "int? max_seqlen_k = None," + "Tensor? page_table = None," + "Tensor? kv_batch_idx = None," + "Tensor? leftpad_k = None," + "Tensor? rotary_cos = None," + "Tensor? rotary_sin = None," + "Tensor? seqlens_rotary = None," + "Tensor? q_descale = None," + "Tensor? k_descale = None," + "Tensor? v_descale = None," + "float? softmax_scale = None," + "bool is_causal = False," + "int window_size_left = -1," + "int window_size_right = -1," + "int attention_chunk = 0," + "float softcap = 0.0," + "bool is_rotary_interleaved = False," + "Tensor? scheduler_metadata = None," + "int num_splits = 0," + "bool? pack_gqa = None," + "int sm_margin = 0) -> (Tensor(out!), Tensor, Tensor, Tensor)"); + m.def("bwd(" + "Tensor dout," + "Tensor q," + "Tensor k," + "Tensor v," + "Tensor out," + "Tensor softmax_lse," + "Tensor(dq!)? dq = None," + "Tensor(dk!)? dk = None," + "Tensor(dv!)? dv = None," + "Tensor? cu_seqlens_q = None," + "Tensor? cu_seqlens_k = None," + "Tensor? seqused_q = None," + "Tensor? seqused_k = None," + "int? max_seqlen_q = None," + "int? max_seqlen_k = None," + "float? softmax_scale = None," + "bool is_causal = False," + "int window_size_left = -1," + "int window_size_right = -1," + "float softcap = 0.0," + "bool deterministic = False," + "int sm_margin = 0) -> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)"); + m.def("fwd_combine(" + "Tensor out_partial," + "Tensor lse_partial," + "Tensor(out!)? out = None," + "ScalarType? out_dtype = None) -> (Tensor(out!), Tensor)"); + m.def("get_scheduler_metadata(" + "int batch_size," + "int max_seqlen_q," + "int max_seqlen_k," + "int num_heads," + "int num_heads_k," + "int headdim," + "int headdim_v," + "ScalarType qkv_dtype," + "Tensor seqused_k," + "Tensor? cu_seqlens_q = None," + "Tensor? cu_seqlens_k = None," + "Tensor? cu_seqlens_k_new = None," + "Tensor? seqused_q = None," + "Tensor? leftpad_k = None," + "int? page_size = None," + "int max_seqlen_k_new = 0," + "bool is_causal = False," + "int window_size_left = -1," + "int window_size_right = -1," + "int attention_chunk = 0," + "bool has_softcap = False," + "int num_splits = 0," + "bool? pack_gqa = None," + "int sm_margin = 0) -> Tensor"); +} + +STABLE_TORCH_LIBRARY_IMPL(flash_attn_3, CUDA, m) { + m.impl("fwd", &boxed_mha_fwd); + m.impl("bwd", &boxed_mha_bwd); + m.impl("fwd_combine", &boxed_mha_combine); + m.impl("get_scheduler_metadata", &boxed_mha_fwd_get_scheduler_metadata); +} diff --git a/hopper/setup.py b/hopper/setup.py index 850fb0b520c..74713208aa0 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -526,8 +526,20 @@ def nvcc_threads_args(): if DISABLE_BACKWARD: sources_bwd_sm90 = [] sources_bwd_sm80 = [] + + # Choose between flash_api.cpp and flash_api_stable.cpp based on torch version + torch_version = parse(torch.__version__) + target_version = parse("2.9.0.dev20250830") + stable_args = [] + + if torch_version >= target_version: + flash_api_source = "flash_api_stable.cpp" + stable_args = ["-DTORCH_STABLE_ONLY"] # Checks against including unstable Tensor APIs + else: + flash_api_source = "flash_api.cpp" + sources = ( - ["flash_api.cpp"] + [flash_api_source] + (sources_fwd_sm80 if not DISABLE_SM8x else []) + sources_fwd_sm90 + (sources_bwd_sm80 if not DISABLE_SM8x else []) + sources_bwd_sm90 ) @@ -566,7 +578,7 @@ def nvcc_threads_args(): name=f"{PACKAGE_NAME}._C", sources=sources, extra_compile_args={ - "cxx": ["-O3", "-std=c++17", "-DPy_LIMITED_API=0x03090000"] + feature_args, + "cxx": ["-O3", "-std=c++17", "-DPy_LIMITED_API=0x03090000"] + stable_args + feature_args, "nvcc": nvcc_threads_args() + nvcc_flags + cc_flag + feature_args, }, include_dirs=include_dirs, From 7bdb426659f976fdf269a5255b0a08abd08d62b8 Mon Sep 17 00:00:00 2001 From: Johnny Date: Fri, 12 Sep 2025 21:32:48 +0200 Subject: [PATCH 097/258] [NVIDIA] Enable Blackwell Family Specific (#1882) * fix typo * Update setup.py * Update setup.py * Update setup.py * Update setup.py --- .github/workflows/publish.yml | 2 +- setup.py | 74 ++++++++++++++++++++++++++++------- 2 files changed, 60 insertions(+), 16 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index e88090f336d..26013ad5d67 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -51,7 +51,7 @@ jobs: cxx11_abi: ["FALSE", "TRUE"] include: - torch-version: "2.9.0.dev20250904" - cuda-version: "13.0" + cuda-version: "13.0.0" exclude: # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix # Pytorch < 2.5 does not support Python 3.13 diff --git a/setup.py b/setup.py index a108c412c00..9a406839e7f 100644 --- a/setup.py +++ b/setup.py @@ -67,7 +67,7 @@ @functools.lru_cache(maxsize=None) def cuda_archs() -> str: - return os.getenv("FLASH_ATTN_CUDA_ARCHS", "80;90;100;120").split(";") + return os.getenv("FLASH_ATTN_CUDA_ARCHS", "80;90;100;110;120").split(";") def get_platform(): @@ -94,6 +94,59 @@ def get_cuda_bare_metal_version(cuda_dir): return raw_output, bare_metal_version +def add_cuda_gencodes(cc_flag, archs, bare_metal_version): + """ + Adds -gencode flags based on nvcc capabilities: + - sm_80/90 (regular) + - sm_100/120 on CUDA >= 12.8 + - Use 100f on CUDA >= 12.9 (Blackwell family-specific) + - Map requested 110 -> 101 if CUDA < 13.0 (Thor rename) + - Embed PTX for newest arch for forward compatibility + """ + # Always-regular 80 + if "80" in archs: + cc_flag += ["-gencode", "arch=compute_80,code=sm_80"] + + # Hopper 9.0 needs >= 11.8 + if bare_metal_version >= Version("11.8") and "90" in archs: + cc_flag += ["-gencode", "arch=compute_90,code=sm_90"] + + # Blackwell 10.x requires >= 12.8 + if bare_metal_version >= Version("12.8"): + if "100" in archs: + # CUDA 12.9 introduced "family-specific" for Blackwell (100f) + if bare_metal_version >= Version("12.9"): + cc_flag += ["-gencode", "arch=compute_100f,code=sm_100"] + else: + cc_flag += ["-gencode", "arch=compute_100,code=sm_100"] + + if "120" in archs: + # sm_120 is supported in CUDA 12.8/12.9+ toolkits + if bare_metal_version >= Version("12.9"): + cc_flag += ["-gencode", "arch=compute_120f,code=sm_120"] + else: + cc_flag += ["-gencode", "arch=compute_120,code=sm_120"] + + + # Thor rename: 12.9 uses sm_101; 13.0+ uses sm_110 + if "110" in archs: + if bare_metal_version >= Version("13.0"): + cc_flag += ["-gencode", "arch=compute_110f,code=sm_110"] + else: + # Provide Thor support for CUDA 12.9 via sm_101 + if bare_metal_version >= Version("12.8"): + cc_flag += ["-gencode", "arch=compute_101,code=sm_101"] + # else: no Thor support in older toolkits + + # PTX for newest requested arch (forward-compat) + numeric = [a for a in archs if a.isdigit()] + if numeric: + newest = max(numeric, key=int) + cc_flag += ["-gencode", f"arch=compute_{newest},code=compute_{newest}"] + + return cc_flag + + def get_hip_version(): return parse(torch.version.hip.split()[-1].rstrip('-').replace('-', '+')) @@ -175,20 +228,11 @@ def validate_and_update_archs(archs): "FlashAttention is only supported on CUDA 11.7 and above. " "Note: make sure nvcc has a supported version by running nvcc -V." ) - - if "80" in cuda_archs(): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_80,code=sm_80") - if CUDA_HOME is not None: - if bare_metal_version >= Version("11.8") and "90" in cuda_archs(): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_90,code=sm_90") - if bare_metal_version >= Version("12.8") and "100" in cuda_archs(): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_100,code=sm_100") - if bare_metal_version >= Version("12.8") and "120" in cuda_archs(): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_120,code=sm_120") + # Build -gencode (regular + PTX + family-specific 'f' when available) + add_cuda_gencodes(cc_flag, set(cuda_archs()), bare_metal_version) + else: + # No nvcc present; warnings already emitted above + pass # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as # torch._C._GLIBCXX_USE_CXX11_ABI From e980f0f6e15ae3a7bc2a29e5610e8a9bfe25f7a6 Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Fri, 12 Sep 2025 19:38:04 -0700 Subject: [PATCH 098/258] fix typo in flops calculation for local attention (#1883) --- benchmarks/benchmark_attn.py | 2 +- hopper/benchmark_attn.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/benchmark_attn.py b/benchmarks/benchmark_attn.py index b3902110eea..7830477a68a 100644 --- a/benchmarks/benchmark_attn.py +++ b/benchmarks/benchmark_attn.py @@ -70,7 +70,7 @@ def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, w else: row_idx = torch.arange(seqlen_q, device='cuda') col_left = torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0)) if window_size[0] is not None else torch.zeros_like(row_idx) - col_right = torch.minimum(row_idx + seqlen_k - seqlen_q - window_size[1], torch.tensor(seqlen_k - 1)) if window_size[1] is not None else torch.full_like(row_idx, seqlen_k - 1) + col_right = torch.minimum(row_idx + seqlen_k - seqlen_q + window_size[1], torch.tensor(seqlen_k - 1)) if window_size[1] is not None else torch.full_like(row_idx, seqlen_k - 1) avg_seqlen = (col_right - col_left + 1).float().mean().item() return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v) diff --git a/hopper/benchmark_attn.py b/hopper/benchmark_attn.py index 33e5d282716..e94d325d42d 100644 --- a/hopper/benchmark_attn.py +++ b/hopper/benchmark_attn.py @@ -68,7 +68,7 @@ def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, w else: row_idx = torch.arange(seqlen_q, device='cuda') col_left = torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0)) - col_right = torch.minimum(row_idx + seqlen_k - seqlen_q - window_size[1], torch.tensor(seqlen_k - 1)) + col_right = torch.minimum(row_idx + seqlen_k - seqlen_q + window_size[1], torch.tensor(seqlen_k - 1)) avg_seqlen = (col_right - col_left + 1).float().mean().item() return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v) From 2cc6fd6abbc5f1100e51eab63d92b678fda06c7d Mon Sep 17 00:00:00 2001 From: Ted Zadouri Date: Sat, 13 Sep 2025 14:52:17 -0400 Subject: [PATCH 099/258] flash-attn-cute bwd sm90 (#1868) --- flash_attn/cute/block_info.py | 12 + flash_attn/cute/flash_bwd_postprocess.py | 206 +++- flash_attn/cute/flash_bwd_sm90.py | 1392 ++++++++++++++++++++++ flash_attn/cute/hopper_helpers.py | 23 + flash_attn/cute/named_barrier.py | 13 + 5 files changed, 1644 insertions(+), 2 deletions(-) create mode 100644 flash_attn/cute/flash_bwd_sm90.py diff --git a/flash_attn/cute/block_info.py b/flash_attn/cute/block_info.py index 2914e42e2ab..50e6371dda3 100644 --- a/flash_attn/cute/block_info.py +++ b/flash_attn/cute/block_info.py @@ -42,6 +42,18 @@ def get_n_block_min_max( n_block_min = cutlass.max(n_idx_left // self.n_block_size, 0) return n_block_min, n_block_max + @cute.jit + def get_m_block_min_max( + self, seqlen_info: SeqlenInfoQK, m_block: cutlass.Int32 + ) -> Tuple[cutlass.Int32, cutlass.Int32]: + m_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.m_block_size) + + m_block_min = 0 + + return m_block_min, m_block_max + + + @cute.jit def get_n_block_min_causal_local_mask( self, diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index 6a408906d53..b0fa2704138 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -8,9 +8,9 @@ import cutlass import cutlass.cute as cute -from cutlass.cute.nvgpu import cpasync, warp - +from cutlass.cute.nvgpu import cpasync, warp, warpgroup from flash_attn.cute import ampere_helpers as sm80_utils +import cutlass.utils.hopper_helpers as sm90_utils_basic from flash_attn.cute import utils @@ -304,3 +304,205 @@ def kernel( tdQgdQ[None, rest_m, None], pred=tdQpdQ[None, rest_m, None], ) + + +class FlashAttentionBackwardPostprocess_sm90(FlashAttentionBackwardPostprocess): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.universal_copy_bits = 128 + + def _setup_attributes(self): + self.sdQaccum_layout = cute.make_layout( + shape=(self.m_block_size * self.head_dim_padded, ), + ) + + sdQ_layout_atom = cute.nvgpu.warpgroup.make_smem_layout_atom( + cutlass.utils.hopper_helpers.get_smem_layout_atom( + cutlass.utils.LayoutEnum.ROW_MAJOR, self.dtype, self.head_dim_padded + ), + self.dtype + ) + self.sdQ_layout = cute.tile_to_shape( + sdQ_layout_atom, + (self.m_block_size, self.head_dim_padded), + (0, 1) + ) + # G->S + async_copy_elements = self.universal_copy_bits // cutlass.Float32.width + self.G2S_tiled_copy_dQaccum = cute.make_tiled_copy_tv( + cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Float32, + num_bits_per_copy=self.universal_copy_bits + ), + cute.make_layout(self.tiled_mma.size), + cute.make_layout(async_copy_elements) + ) + + # S->R + self.S2R_tiled_copy_dQaccum = cute.make_tiled_copy_tv( + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=self.universal_copy_bits), + cute.make_layout(self.tiled_mma.size), + cute.make_layout(async_copy_elements) + ) + + @cute.jit + def __call__( + self, + mdQaccum: cute.Tensor, + mdQ: cute.Tensor, + scale: cutlass.Float32, + stream: cuda.CUstream, + ): + + mdQ = cute.make_tensor(mdQ.iterator, cute.select(mdQ.layout, mode=[1,3,2,0])) + mdQaccum = cute.make_tensor(mdQaccum.iterator, cute.select(mdQaccum.layout, mode=[2,1,0])) + + # tiled_mma + tiled_mma = sm90_utils_basic.make_trivial_tiled_mma( + self.dtype, + self.dtype, + warpgroup.OperandMajorMode.K, + warpgroup.OperandMajorMode.MN, + cutlass.Float32, + atom_layout_mnk=(self.m_block_size // 64, 1, 1), + tiler_mn=(64, self.head_dim_padded) + ) + + self.tiled_mma = tiled_mma + self.num_mma_threads = tiled_mma.size + self._setup_attributes() + + + # TMA setup + tma_atom_dQ, mdQ = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + mdQ, + self.sdQ_layout, + (self.m_block_size, self.head_dim_padded), + ) + + seqlen = mdQ.shape[0] + grid_dim = [ + cute.ceil_div(seqlen, self.m_block_size), + cute.size(mdQ.shape[2]), + cute.size(mdQ.shape[3]), + ] + smem_size = max( + cute.size_in_bytes(cutlass.Float32, self.sdQaccum_layout), + cute.size_in_bytes(self.dtype, self.sdQ_layout) + ) + self.kernel( + mdQaccum, + mdQ, + tma_atom_dQ, + tiled_mma, + self.sdQaccum_layout, + self.sdQ_layout, + self.G2S_tiled_copy_dQaccum, + self.S2R_tiled_copy_dQaccum, + scale, + ).launch( + grid=grid_dim, + block=[self.num_mma_threads, 1, 1], + smem=smem_size, + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mdQaccum: cute.Tensor, + mdQ: cute.Tensor, + tma_atom_dQ: cute.CopyAtom, + tiled_mma: cute.TiledMma, + sdQaccum_layout: cute.Layout, + sdQ_layout: cute.ComposedLayout, + g2s_tiled_copy_dQaccum: cute.TiledCopy, + s2r_tiled_copy_dQaccum: cute.TiledCopy, + scale: cutlass.Float32, + ): + # basic setup + tidx = cute.arch.thread_idx()[0] + m_block, head_idx, batch_idx = cute.arch.block_idx() + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + + smem = cutlass.utils.SmemAllocator() + sdQaccum = smem.allocate_tensor(cutlass.Float32, sdQaccum_layout, byte_alignment=128) + sdQ = cute.make_tensor( + cute.recast_ptr(sdQaccum.iterator, sdQ_layout.inner, dtype=self.dtype), + sdQ_layout.outer + ) + + if warp_idx == 0: + cpasync.prefetch_descriptor(tma_atom_dQ) + + # G->S + gdQaccum = cute.local_tile( + mdQaccum[None, head_idx, batch_idx], + (self.m_block_size * self.head_dim_padded, ), + (m_block,) + ) + + gmem_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_slice(tidx) + tdQaccumgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum) + tdQaccumsdQaccum = gmem_thr_copy_dQaccum.partition_D(sdQaccum) + + cute.copy(g2s_tiled_copy_dQaccum, tdQaccumgdQaccum, tdQaccumsdQaccum) + cute.arch.barrier() + + # S->R + acc_dQaccum = cute.make_fragment( + tiled_mma.partition_shape_C((self.m_block_size, self.head_dim_padded)), + cutlass.Float32 + ) + acc_dQaccum.fill(0) + + smem_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_slice(tidx) + tdQaccumsdQaccum = smem_thr_copy_dQaccum.partition_S(sdQaccum) + + + tdQaccumrdQaccum = cute.make_tensor(acc_dQaccum.iterator, cute.make_layout(tdQaccumsdQaccum.shape)) + cute.copy(smem_thr_copy_dQaccum, tdQaccumsdQaccum, tdQaccumrdQaccum) + + + # Scale + FP32->BF16/FP16 + acc_mmaA_view = cute.make_tensor(acc_dQaccum.iterator, utils.convert_layout_acc_frgA(acc_dQaccum.layout)) + rdQ = cute.make_fragment_like(acc_mmaA_view, self.dtype) + + acc_dQaccum.store(acc_dQaccum.load() * scale) + utils.cvt_f16(acc_mmaA_view, rdQ) # BF16/FP16 output + + + # R->S (StMatrix) + smem_copy_atom = cute.make_copy_atom( + cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=False, num_matrices=4), + self.dtype, #BF16/FP16 + ) + + smem_thr_copy = cute.make_tiled_copy_C(smem_copy_atom, tiled_mma).get_slice(tidx) + tdQsdQ = smem_thr_copy.partition_D(sdQ) + tdQrdQ = cute.make_tensor(rdQ.iterator, cute.make_layout(tdQsdQ.shape)) + + cute.copy(smem_thr_copy, tdQrdQ, tdQsdQ) + cute.arch.barrier() + + #S->G (TMA) + gdQ = cute.local_tile( + mdQ[None, None, head_idx, batch_idx], + (self.m_block_size, self.head_dim_padded), + (m_block, 0) + ) + + tdQsdQ, tdQgdQ = cpasync.tma_partition( + tma_atom_dQ, + 0, + cute.make_layout(1), + cute.group_modes(sdQ, 0, 2), + cute.group_modes(gdQ, 0, 2) + ) + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + if warp_idx == 4: # only one warp writes + cute.copy(tma_atom_dQ, tdQsdQ, tdQgdQ) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py new file mode 100644 index 00000000000..8163fb3663c --- /dev/null +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -0,0 +1,1392 @@ +import math +from typing import Callable, Optional, Type +from functools import partial + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync, warpgroup +#import cutlass.pipeline +import cutlass.utils.hopper_helpers as sm90_utils_basic +from cutlass import const_expr + +from flash_attn.cute import hopper_helpers as sm90_utils +from flash_attn.cute import utils +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute import pipeline +from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, ParamsBase +from flash_attn.cute.named_barrier import NamedBarrierFwd, NamedBarrierBwd + +class FlashAttentionBackwardSm90: + arch = 90 + + def __init__( + self, + dtype: Type[cutlass.Numeric], + head_dim: int, + head_dim_v: Optional[int] = None, + qhead_per_kvhead: int = 1, + m_block_size: int = 64, + n_block_size: int = 128, + num_stages: int = 2, + num_threads: int = 384, + Q_in_regs: bool = False, + ): + + self.dtype = dtype + # padding head_dim to a multiple of 16 as k_block_size + hdim_multiple_of = 16 + self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) + head_dim_v = head_dim_v if head_dim_v is not None else head_dim + self.same_hdim_kv = head_dim == head_dim_v + self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) + # Can save registers (and hence be faster) if we don't have to check hdim predication + self.check_hdim_oob = head_dim != self.head_dim_padded + self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded + self.qhead_per_kvhead = qhead_per_kvhead + self.m_block_size = m_block_size + self.n_block_size = n_block_size + self.num_threads = num_threads + self.num_stages = num_stages + self.Q_in_regs = Q_in_regs + + @staticmethod + def can_implement( + dtype, head_dim, head_dim_v, m_block_size, n_block_size, num_stages, num_threads, + Q_in_regs=False + ) -> bool: + + if dtype not in [cutlass.Float16, cutlass.BFloat16]: + return False + if head_dim % 8 != 0: + return False + if head_dim_v % 8 != 0: + return False + if n_block_size % 16 != 0: + return False + if num_threads % 32 != 0: + return False + + if (m_block_size * 2) % num_threads != 0: + return False + return True + + def _check_type( + self, + mQ_type: Type[cutlass.Numeric], + mK_type: Type[cutlass.Numeric], + mV_type: Type[cutlass.Numeric], + mdO_type: Type[cutlass.Numeric], + mLSE_type: Type[cutlass.Numeric], + mdPsum_type: Type[cutlass.Numeric], + mdQaccum_type: Type[cutlass.Numeric], + mdK_type: Type[cutlass.Numeric], + mdV_type: Type[cutlass.Numeric], + ): + # Get the data type and check if it is fp16 or bf16 + if const_expr(not (mQ_type == mK_type == mV_type == mdO_type)): + raise TypeError("All tensors must have the same data type") + if const_expr(mQ_type not in [cutlass.Float16, cutlass.BFloat16]): + raise TypeError("Only Float16 or BFloat16 is supported") + if const_expr(mLSE_type not in [cutlass.Float32]): + raise TypeError("LSE tensor must be Float32") + if const_expr(mdPsum_type not in [cutlass.Float32]): + raise TypeError("dPsum tensor must be Float32") + if const_expr(mdQaccum_type not in [cutlass.Float32]): + raise TypeError("dQaccum tensor must be Float32") + if const_expr(self.qhead_per_kvhead == 1): + if const_expr(not (mdK_type == mdV_type == mQ_type)): + raise TypeError("mdK and mdV tensors must have the same data type as mQ") + else: + if const_expr(not (mdK_type == mdV_type == cutlass.Float32)): + raise TypeError("mdKaccum and mdVaccum tensors must have the data type Float32") + assert mQ_type == self.dtype + + def _get_smem_layout_atom(self): + sQ_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils_basic.get_smem_layout_atom( + cutlass.utils.LayoutEnum.ROW_MAJOR, + self.dtype, + self.head_dim_padded + ), + self.dtype + ) + sK_layout_atom = sQ_layout_atom + + sV_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils_basic.get_smem_layout_atom( + cutlass.utils.LayoutEnum.ROW_MAJOR, + self.dtype, + self.head_dim_v_padded + ), + self.dtype + ) + sPdS_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils_basic.get_smem_layout_atom( + cutlass.utils.LayoutEnum.ROW_MAJOR, + self.dtype, + self.n_block_size + ), + self.dtype + ) + sdO_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils_basic.get_smem_layout_atom( + cutlass.utils.LayoutEnum.ROW_MAJOR, + self.dtype, + self.head_dim_padded + ), + self.dtype + ) + + return sQ_layout_atom, sK_layout_atom, sV_layout_atom, sPdS_layout_atom, sdO_layout_atom + + + def _setup_attributes(self): + sQ_layout_atom, sK_layout_atom, sV_layout_atom, sPdS_layout_atom, sdO_layout_atom = self._get_smem_layout_atom() + + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // self.dtype.width + + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=universal_copy_bits, + ) + + self.sQ_layout = cute.tile_to_shape(sQ_layout_atom, (self.m_block_size, self.head_dim_padded, self.num_stages), (0, 1, 2),) + self.sK_layout = cute.tile_to_shape(sK_layout_atom, (self.n_block_size, self.head_dim_padded), (0, 1),) + self.sV_layout = cute.tile_to_shape(sV_layout_atom, (self.n_block_size, self.head_dim_v_padded), (0, 1),) + self.sdO_layout = cute.tile_to_shape(sdO_layout_atom, (self.m_block_size, self.head_dim_padded, self.num_stages), (0, 1, 2),) + + self.sPdS_layout = cute.tile_to_shape(sPdS_layout_atom, (self.m_block_size, self.n_block_size), (0, 1),) + self.sdQaccum_layout = cute.make_layout(shape=(self.m_block_size * self.head_dim_padded, ),) + + + # dQaccum R->S + self.r2s_tiled_copy_dQaccum = cute.make_tiled_copy_tv( + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=universal_copy_bits), + cute.make_layout(self.num_mma_threads), + cute.make_layout(universal_copy_bits // cutlass.Float32.width) + ) + + # dV: S->G + tV_shape_dim_1 = sV_layout_atom.outer.shape[1] // async_copy_elems + tdV_layout = cute.make_ordered_layout( + (self.num_mma_threads // tV_shape_dim_1, tV_shape_dim_1), + order=(1, 0), + ) + self.gmem_tiled_copy_dV = cute.make_tiled_copy_tv( + atom_universal_copy, + tdV_layout, + cute.make_layout((1, async_copy_elems)) + ) + + # dK: S->G + tK_shape_dim_1 = sK_layout_atom.outer.shape[1] // async_copy_elems + tdK_layout = cute.make_ordered_layout( + (self.num_mma_threads // tK_shape_dim_1, tK_shape_dim_1), + order=(1, 0), + ) + self.gmem_tiled_copy_dK = cute.make_tiled_copy_tv( + atom_universal_copy, + tdK_layout, + cute.make_layout((1, async_copy_elems)) + ) + + def _get_tiled_mma(self): + + # C = A @ B.T + tiled_mma_SdP = sm90_utils_basic.make_trivial_tiled_mma( + self.dtype, + self.dtype, + warpgroup.OperandMajorMode.K, + warpgroup.OperandMajorMode.K, + cutlass.Float32, + atom_layout_mnk=(self.m_block_size // 64, 1, 1), + tiler_mn=(64, self.n_block_size), + ) + # C = A.T @ B + tiled_mma_dKV = sm90_utils_basic.make_trivial_tiled_mma( + self.dtype, + self.dtype, + warpgroup.OperandMajorMode.MN, + warpgroup.OperandMajorMode.MN, + cutlass.Float32, + atom_layout_mnk=(self.n_block_size // 64 , 1, 1), + tiler_mn=(64, self.head_dim_padded), + ) + # C = A @ B + tiled_mma_dQaccum = sm90_utils_basic.make_trivial_tiled_mma( + self.dtype, + self.dtype, + warpgroup.OperandMajorMode.K, + warpgroup.OperandMajorMode.MN, + cutlass.Float32, + atom_layout_mnk=(self.m_block_size // 64, 1, 1), + tiler_mn=(64, self.head_dim_padded), + ) + + return tiled_mma_SdP, tiled_mma_dKV, tiled_mma_dQaccum + + + def _get_shared_storage_cls(self): + sQ_alignment = sK_alignment = sV_alighment = sdQaccum_alignment = sdO_alignment = 128 + + sQ_struct, sK_struct, sV_struct, sdO_struct, sdQaccum_struct = [ + cute.struct.Align[cute.struct.MemRange[type, cute.cosize(layout)], alignment] + for (layout, type, alignment) in [ + (self.sQ_layout, self.dtype, sQ_alignment), + (self.sK_layout, self.dtype, sK_alignment), + (self.sV_layout, self.dtype, sV_alighment), + (self.sdO_layout, self.dtype, sdO_alignment), + (self.sdQaccum_layout, cutlass.Float32, sdQaccum_alignment) + ] + ] + + cosize_sPdS = cute.cosize(self.sPdS_layout) + sPdS_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sPdS], 1024] + sLSE_struct = cute.struct.Align[cute.struct.MemRange[cutlass.Float32, self.m_block_size * self.num_stages], 128] + sdPsum_struct = cute.struct.Align[cute.struct.MemRange[cutlass.Float32, self.m_block_size * self.num_stages], 128] + + mbar_ptr_Q_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] + mbar_ptr_LSE_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] + mbar_ptr_dPsum_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] + mbar_ptr_dO_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] + + mbar_ptr_K_struct = cute.struct.MemRange[cutlass.Int64, 2] + mbar_ptr_V_struct = cute.struct.MemRange[cutlass.Int64, 2] + + + @cute.struct + class SharedStorageQKV: + mbar_ptr_Q: mbar_ptr_Q_struct + mbar_ptr_K: mbar_ptr_K_struct + mbar_ptr_V: mbar_ptr_V_struct + mbar_ptr_lse: mbar_ptr_LSE_struct + mbar_ptr_dpsum: mbar_ptr_dPsum_struct + mbar_ptr_dO: mbar_ptr_dO_struct + + sQ: sQ_struct + sV: sV_struct + sK: sK_struct + sPdS: sPdS_struct + sLSE: sLSE_struct + sdPsum: sdPsum_struct + sdO: sdO_struct + sdQaccum: sdQaccum_struct + + return SharedStorageQKV + + @cute.jit + def __call__(self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + + mdO: cute.Tensor, + mLSE: cute.Tensor, + + mdPsum: cute.Tensor, + mdQaccum: cute.Tensor, + mdK: cute.Tensor, + mdV: cute.Tensor, + + softmax_scale: cutlass.Float32, + stream: cuda.CUstream, + + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + + softcap: cutlass.Float32 | float | None = None, + window_size_left: cutlass.Int32 | int | None = None, + window_size_right: cutlass.Int32 | int | None = None, + ): + + self._check_type( + *(t.element_type if t is not None else None + for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)) + ) + + layout_transpose = [1, 3, 2, 0] # (b, s, n, h) --> (s, h, n, b) + mQ, mK, mV, mdK, mdV, mdO = [ + cute.make_tensor(t.iterator, cute.select(t.layout, mode=layout_transpose)) + for t in (mQ, mK, mV, mdK, mdV, mdO) + ] + + LSE_dPsum_dQaccum_transpose = [2, 1, 0] # (b, n, s) -> (s, n, b) + mLSE, mdPsum, mdQaccum = [ + cute.make_tensor(t.iterator, cute.select(t.layout, mode=LSE_dPsum_dQaccum_transpose)) + for t in (mLSE, mdPsum, mdQaccum) + ] + + + tiled_mma_SdP, tiled_mma_dKV, tiled_mma_dQaccum = self._get_tiled_mma() + + self.tiled_mma_SdP = tiled_mma_SdP + self.tiled_mma_dKV = tiled_mma_dKV + self.tiled_mma_sdQaccum = tiled_mma_dQaccum + + self.num_mma_threads = tiled_mma_SdP.size + + self.num_threads_per_warp_group = 128 + self.num_mma_warp_groups = self.num_mma_threads // self.num_threads_per_warp_group + self.num_producer_threads = 32 + + self.num_mma_regs = 240 + self.num_producer_regs = 24 + + self._setup_attributes() + SharedStorage = self._get_shared_storage_cls() + + + self.tma_copy_q_bytes = cute.size_in_bytes(mQ.element_type, cute.select(self.sQ_layout, mode=[0, 1])) + self.tma_copy_k_bytes = cute.size_in_bytes(mK.element_type, cute.select(self.sK_layout, mode=[0, 1])) + self.tma_copy_v_bytes = cute.size_in_bytes(mV.element_type, cute.select(self.sK_layout, mode=[0, 1])) + + self.tma_copy_do_bytes = cute.size_in_bytes(mdO.element_type, cute.select(self.sdO_layout, mode=[0,1])) + self.tma_copy_lse_bytes = self.m_block_size * 4 + self.tma_copy_dPsum_bytes = self.m_block_size * 4 + + + tma_atom_Q, tma_tensor_Q = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileG2SOp(), + mQ, + cute.select(self.sQ_layout, mode=[0, 1]), + (self.m_block_size, self.head_dim_padded), + ) + tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileG2SOp(), + mK, + cute.select(self.sK_layout, mode=[0, 1]), + (self.n_block_size, self.head_dim_padded), + 1 + ) + tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileG2SOp(), + mV, + cute.select(self.sV_layout, mode=[0,1]), + (self.n_block_size, self.head_dim_v_padded), + 1 + ) + tma_atom_dO, tma_tensor_dO = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileG2SOp(), + mdO, + cute.select(self.sdO_layout, mode=[0,1]), + (self.m_block_size, self.head_dim_padded) + ) + tma_atom_LSE, tma_tensor_LSE = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileG2SOp(), + mLSE, + cute.make_layout(self.m_block_size), (self.m_block_size,), + ) + tma_atom_dPsum, tma_tensor_dPsum = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileG2SOp(), + mdPsum, + cute.make_layout(self.m_block_size), (self.m_block_size, ), + ) + TileScheduler = SingleTileScheduler + tile_sched_args = TileSchedulerArguments( + cute.ceil_div(cute.size(mK.shape[0]), self.n_block_size), + cute.size(mK.shape[2]), + cute.size(mK.shape[3]), + cute.size(mK.shape[0]), + mQ.shape[1], + mV.shape[1], + total_q=cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), + tile_shape_mn=(self.m_block_size, self.n_block_size), + mCuSeqlensQ=None, + mSeqUsedQ=None, + qhead_per_kvhead_packgqa= 1, + element_size=self.dtype.width // 8, + is_persistent=False, + lpt=False, + ) + + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) + + LOG2_E = math.log2(math.e) + softmax_scale_log2 = softmax_scale * LOG2_E + + self.kernel( + tma_tensor_Q, + tma_tensor_K, + tma_tensor_V, + tma_tensor_LSE, + tma_tensor_dPsum, + tma_tensor_dO, + + tma_atom_Q, + tma_atom_K, + tma_atom_V, + tma_atom_LSE, + tma_atom_dPsum, + tma_atom_dO, + + mdK, + mdV, + mdQaccum, + + self.sQ_layout, + self.sK_layout, + self.sV_layout, + self.sPdS_layout, + self.sdO_layout, + self.sdQaccum_layout, + + self.gmem_tiled_copy_dV, + self.gmem_tiled_copy_dK, + self.r2s_tiled_copy_dQaccum, + + tiled_mma_SdP, + tiled_mma_dKV, + tiled_mma_dQaccum, + + softmax_scale_log2, + softmax_scale, + tile_sched_params, + TileScheduler, + SharedStorage, + ).launch( + grid=grid_dim, + block=[self.num_threads, 1, 1], + smem=SharedStorage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + ) + + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mLSE: cute.Tensor, + mdPsum: cute.Tensor, + mdO: cute.Tensor, + + tma_atom_Q: Optional[cute.CopyAtom], + tma_atom_K: Optional[cute.CopyAtom], + tma_atom_V: Optional[cute.CopyAtom], + tma_atom_LSE: Optional[cute.CopyAtom], + tma_atom_dPsum: Optional[cute.CopyAtom], + tma_atom_dO: Optional[cute.CopyAtom], + + mdK: cute.Tensor, + mdV: cute.Tensor, + mdQaccum: cute.Tensor, + + sQ_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sPdS_layout: cute.ComposedLayout, + sdO_layout: cute.ComposedLayout, + sdQaccum_layout: cute.Layout, + + gmem_tiled_copy_dV: cute.TiledCopy, + gmem_tiled_copy_dK: cute.TiledCopy, + r2s_tiled_copy_dQaccum: cute.TiledCopy, + + tiled_mma_SdP: cute.TiledMma, + tiled_mma_dKV: cute.TiledMma, + tiled_mma_dQaccum: cute.TiledMma, + + softmax_scale_log2, + softmax_scale, + tile_sched_params: ParamsBase, + TileScheduler: cutlass.Constexpr[Callable], + SharedStorage: cutlass.Constexpr[Callable], + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx = cute.arch.thread_idx()[0] + + # prefetch TMA descriptors + if warp_idx == 0: + cpasync.prefetch_descriptor(tma_atom_Q) + cpasync.prefetch_descriptor(tma_atom_K) + cpasync.prefetch_descriptor(tma_atom_V) + cpasync.prefetch_descriptor(tma_atom_LSE) + cpasync.prefetch_descriptor(tma_atom_dPsum) + cpasync.prefetch_descriptor(tma_atom_dO) + + + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + + mbar_ptr_K = storage.mbar_ptr_K.data_ptr() + mbar_ptr_V = storage.mbar_ptr_V.data_ptr() + + # mbarrier init + if warp_idx == 1: + cute.arch.mbarrier_init(mbar_ptr_K, 1) + cute.arch.mbarrier_init(mbar_ptr_V, 1) + + pipeline_producer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread) + pipeline_consumer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, self.num_mma_threads // self.num_threads_per_warp_group) + + pipeline_q = pipeline.PipelineTmaAsyncNoCluster.create( + barrier_storage=storage.mbar_ptr_Q.data_ptr(), + num_stages=self.num_stages, + producer_group=pipeline_producer_group, + consumer_group=pipeline_consumer_group, + tx_count=self.tma_copy_q_bytes, + init_wait=False, + ) + pipeline_lse = pipeline.PipelineTmaAsyncNoCluster.create( + barrier_storage=storage.mbar_ptr_lse.data_ptr(), + num_stages=self.num_stages, + producer_group=pipeline_producer_group, + consumer_group=pipeline_consumer_group, + tx_count=self.tma_copy_lse_bytes, + init_wait=False, + ) + pipeline_dpsum = pipeline.PipelineTmaAsyncNoCluster.create( + barrier_storage=storage.mbar_ptr_dpsum.data_ptr(), + num_stages=self.num_stages, + producer_group=pipeline_producer_group, + consumer_group=pipeline_consumer_group, + tx_count=self.tma_copy_dPsum_bytes, + init_wait=False, + ) + pipeline_do = pipeline.PipelineTmaAsyncNoCluster.create( + barrier_storage=storage.mbar_ptr_dO.data_ptr(), + num_stages=self.num_stages, + producer_group=pipeline_producer_group, + consumer_group=pipeline_consumer_group, + tx_count=self.tma_copy_do_bytes, + init_wait=False, + ) + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) + sQt = utils.transpose_view(sQ) + + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) + + sLSE_load = storage.sLSE.get_tensor(cute.make_layout( + (self.m_block_size, self.num_stages), + stride=(1, cute.round_up(self.m_block_size, 64)) + )) + sLSE_mma = storage.sLSE.get_tensor(cute.make_layout( + (self.m_block_size, self.n_block_size, self.num_stages), + stride=(1, 0, cute.round_up(self.m_block_size, 64)) + )) + sdPsum_load = storage.sdPsum.get_tensor(cute.make_layout( + (self.m_block_size, self.num_stages), + stride=(1, cute.round_up(self.m_block_size, 64)) + )) + sdPsum_mma = storage.sdPsum.get_tensor(cute.make_layout( + (self.m_block_size, self.n_block_size, self.num_stages), + stride=(1, 0, cute.round_up(self.m_block_size, 64)) + )) + + sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout) + + + + sP = storage.sPdS.get_tensor(sPdS_layout.outer, swizzle=sPdS_layout.inner) + sPt = utils.transpose_view(sP) + + sdS = storage.sPdS.get_tensor(sPdS_layout.outer, swizzle=sPdS_layout.inner) + sdSt = utils.transpose_view(sdS) + + sdO = storage.sdO.get_tensor(sdO_layout.outer, swizzle=sdO_layout.inner) + sdOt = utils.transpose_view(sdO) + + + block_info = BlockInfo(self.m_block_size, self.n_block_size, False, False,None, None, qhead_per_kvhead_packgqa=1,) + SeqlenInfoCls = partial( + SeqlenInfoQK, seqlen_q_static=mQ.shape[0], + seqlen_k_static=mK.shape[0], + mCuSeqlensQ=None, mCuSeqlensK=None, + mSeqUsedQ=None, mSeqUsedK=None + ) + + TileSchedulerCls = partial(TileScheduler.create, tile_sched_params) + + if warp_idx < 4: + cute.arch.warpgroup_reg_dealloc(self.num_producer_regs) + if warp_idx == 0: + self.load( + mQ, + mK, + mV, + mLSE, + mdPsum, + mdO, + + sQ, + sK, + sV, + sLSE_load, + sdPsum_load, + sdO, + + tma_atom_Q, + tma_atom_K, + tma_atom_V, + tma_atom_LSE, + tma_atom_dPsum, + tma_atom_dO, + + pipeline_q, + pipeline_lse, + pipeline_dpsum, + pipeline_do, + + mbar_ptr_K, + mbar_ptr_V, + + SeqlenInfoCls, + TileSchedulerCls, + ) + if warp_idx == 1: + cute.arch.barrier_arrive(barrier_id=int(NamedBarrierBwd.dQEmpty), number_of_threads=self.num_mma_threads + cute.arch.WARP_SIZE) + self.dQaccum_writer( + mdQaccum, + sdQaccum, + TileSchedulerCls, + SeqlenInfoCls, + ) + else: + cute.arch.warpgroup_reg_alloc(self.num_mma_regs) + tidx, _, _ = cute.arch.thread_idx() + tidx = tidx - 128 + + self.mma( + tiled_mma_SdP, + tiled_mma_dKV, + tiled_mma_dQaccum, + + mdK, + mdV, + mdQaccum, + + sQ, + sQt, + sK, + sV, + + sP, + sPt, + + sdS, + sdSt, + + sdO, + sdOt, + + sLSE_mma, + sdPsum_mma, + + sdQaccum, + + pipeline_q, + pipeline_lse, + pipeline_dpsum, + pipeline_do, + + mbar_ptr_K, + mbar_ptr_V, + tidx, + gmem_tiled_copy_dV, + gmem_tiled_copy_dK, + r2s_tiled_copy_dQaccum, + + softmax_scale_log2, + softmax_scale, + + block_info, + SeqlenInfoCls, + TileSchedulerCls, + ) + + + @cute.jit + def load( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mLSE: cute.Tensor, + mdPsum: cute.Tensor, + mdO: cute.Tensor, + + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + sLSE: cute.Tensor, + sdPsum: cute.Tensor, + sdO: cute.Tensor, + + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + + tma_atom_LSE: cute.CopyAtom, + tma_atom_dPsum: cute.CopyAtom, + tma_atom_dO: cute.CopyAtom, + + pipeline_q: cutlass.pipeline.PipelineAsync, + pipeline_lse: cutlass.pipeline.PipelineAsync, + pipeline_dpsum: cutlass.pipeline.PipelineAsync, + pipeline_dO: cutlass.pipeline.PipelineAsync, + + mbar_ptr_K: cutlass.Pointer, + mbar_ptr_V: cutlass.Pointer, + + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + ): + warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 + + if warp_idx_in_wg == 0: + producer_state = pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, self.num_stages) + + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + + + while work_tile.is_valid_tile: + n_block, head_idx, batch_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + + mK_cur = mK[None, None, head_idx, batch_idx] + gK = cute.local_tile(mK_cur, (self.n_block_size, self.head_dim_padded), (n_block, 0)) + + mV_cur = mV[None, None, head_idx, batch_idx] + gV = cute.local_tile(mV_cur, (self.n_block_size, self.head_dim_padded), (n_block, 0)) + + mQ_cur = mQ[None, None, head_idx, batch_idx] + gQ = cute.local_tile(mQ_cur, (self.m_block_size, self.head_dim_padded), (None, 0)) + + mLSE_cur = mLSE[None, head_idx, batch_idx] + gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (None,)) + + mdPsum_cur = mdPsum[None, head_idx, batch_idx] + gdPsum = cute.local_tile(mdPsum_cur, (self.m_block_size,), (None,)) + + mdO_cur = mdO[None, None, head_idx, batch_idx] + gdO = cute.local_tile(mdO_cur, (self.m_block_size, self.head_dim_padded), (None, 0)) + + tQsQ, tQgQ = cpasync.tma_partition( + tma_atom_Q, + 0, + cute.make_layout(1), + cute.group_modes(sQ, 0, 2), + cute.group_modes(gQ, 0, 2), + ) + tKsK, tKgK = cpasync.tma_partition( + tma_atom_K, + 0, + cute.make_layout(1), + cute.group_modes(sK, 0, 2), + cute.group_modes(gK, 0, 2), + ) + tVsV, tVgV = cpasync.tma_partition( + tma_atom_V, + 0, + cute.make_layout(1), + cute.group_modes(sV, 0, 2), + cute.group_modes(gV, 0, 2), + ) + tLSEsLSE, tLSEgLSE = cpasync.tma_partition( + tma_atom_LSE, + 0, + cute.make_layout(1), + sLSE, + gLSE, + ) + tdPsumsdPsum, tdPsumgdPsum = cpasync.tma_partition( + tma_atom_dPsum, + 0, + cute.make_layout(1), + sdPsum, + gdPsum, + ) + tdOsdO, tdOgdO = cpasync.tma_partition( + tma_atom_dO, + 0, + cute.make_layout(1), + cute.group_modes(sdO, 0, 2), + cute.group_modes(gdO, 0, 2), + ) + + load_Q = partial(self.load_m_tile, tma_atom_Q, tQgQ, tQsQ, pipeline_q) + load_LSE = partial(self.load_m_tile, tma_atom_LSE, tLSEgLSE, tLSEsLSE, pipeline_lse) + load_dPsum = partial(self.load_m_tile, tma_atom_dPsum, tdPsumgdPsum, tdPsumsdPsum, pipeline_dpsum) + load_dO = partial(self.load_m_tile, tma_atom_dO, tdOgdO, tdOsdO, pipeline_dO) + + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx(mbar_ptr_K, self.tma_copy_k_bytes) + cute.arch.mbarrier_arrive_and_expect_tx(mbar_ptr_V, self.tma_copy_v_bytes) + + cute.copy(tma_atom_K, tKgK, tKsK, tma_bar_ptr=mbar_ptr_K) + cute.copy(tma_atom_V, tVgV, tVsV, tma_bar_ptr=mbar_ptr_V) + + m_block_min, m_block_max = 0, cute.ceil_div(seqlen.seqlen_q, self.m_block_size) + + for i in cutlass.range(m_block_max - m_block_min, unroll=2): + m_block = m_block_max - i - 1 + + load_Q(m_block, producer_state=producer_state) + load_LSE(m_block, producer_state=producer_state) + load_dPsum(m_block, producer_state=producer_state) + load_dO(m_block, producer_state=producer_state) + + producer_state.advance() + + tile_scheduler.prefetch_next_work() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + + @cute.jit + def mma( + self, + tiled_mma_SdP: cute.TiledMma, + tiled_mma_dKV: cute.TiledMma, + tiled_mma_dQaccum: cute.TiledMma, + + mdK: cute.Tensor, + mdV: cute.Tensor, + mdQaccum: cute.Tensor, + + sQ: cute.Tensor, + sQt: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + + sP: cute.Tensor, + sPt: cute.Tensor, + + sdS: cute.Tensor, + sdSt: cute.Tensor, + + sdO: cute.Tensor, + sdOt: cute.Tensor, + + sLSE_mma: cute.Tensor, + sdPsum_mma: cute.Tensor, + + sdQaccum: cute.Tensor, + + pipeline_q: cutlass.pipeline.PipelineAsync, + pipeline_lse: cutlass.pipeline.PipelineAsync, + pipeline_dPsum: cutlass.pipeline.PipelineAsync, + pipeline_dO: cutlass.pipeline.PipelineAsync, + + mbar_ptr_K: cutlass.Pointer, + mbar_ptr_V: cutlass.Pointer, + + tidx: cutlass.Int32, + gmem_tiled_copy_dV: cute.TiledCopy, + gmem_tiled_copy_dK: cute.TiledCopy, + r2s_tiled_copy_dQaccum: cute.TiledCopy, + + softmax_scale_log2: cutlass.Float32, + softmax_scale: cutlass.Float32, + + block_info: BlockInfo, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + ): + warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) + warp_group_thread_layout = cute.make_layout(self.num_mma_warp_groups, stride=self.num_threads_per_warp_group) + + wg_mma_SdP = tiled_mma_SdP.get_slice(warp_group_thread_layout(warp_group_idx)) + wg_mma_dKV = tiled_mma_dKV.get_slice(warp_group_thread_layout(warp_group_idx)) + wg_mma_dQaccum = tiled_mma_dQaccum.get_slice(warp_group_thread_layout(warp_group_idx)) + + smem_copy_atom_PdS = utils.get_smem_store_atom(self.arch, self.dtype) + smem_thr_copy_PdS = cute.make_tiled_copy_C(smem_copy_atom_PdS, tiled_mma_SdP).get_slice(tidx) + + # S = Q @ K.T + tSrQ = tiled_mma_SdP.make_fragment_A(wg_mma_SdP.partition_A(sQ)) + tSrK = tiled_mma_SdP.make_fragment_B(wg_mma_SdP.partition_B(sK)) + + # dP = dO @ V.T + tdPrdO = tiled_mma_SdP.make_fragment_A(wg_mma_SdP.partition_A(sdO)) + tdPrV = tiled_mma_SdP.make_fragment_B(wg_mma_SdP.partition_B(sV)) + + # P = exp(S-LSE) + tPsP = smem_thr_copy_PdS.partition_D(sP) + + LSEslice = (None, 0, None) + tLSEsLSE_2D = utils.make_acc_tensor_mn_view(tiled_mma_SdP.get_slice(tidx).partition_C(sLSE_mma))[LSEslice] + + # dS = P*(dP-dPsum) + tdSsdS = smem_thr_copy_PdS.partition_D(sdS) + + dPsumslice = (None, 0, None) + tdPsumsdPsum_2D = utils.make_acc_tensor_mn_view(tiled_mma_SdP.get_slice(tidx).partition_C(sdPsum_mma))[dPsumslice] + + # dV += P.T @ dO + tdVrPt = tiled_mma_dKV.make_fragment_A(wg_mma_dKV.partition_A(sPt)) + tdVrdOt = tiled_mma_dKV.make_fragment_B(wg_mma_dKV.partition_B(sdOt)) + + # dK += dS.T @ Q + tdKrdSt = tiled_mma_dKV.make_fragment_A(wg_mma_dKV.partition_A(sdSt)) + tdKrQt = tiled_mma_dKV.make_fragment_B(wg_mma_dKV.partition_B(sQt)) + + # dQ = dS @ K + sKt = utils.transpose_view(sK) + tdQaccumrdS = tiled_mma_dQaccum.make_fragment_A(wg_mma_dQaccum.partition_A(sdS)) + tdQaccumrK = tiled_mma_dQaccum.make_fragment_B(wg_mma_dQaccum.partition_B(sKt)) + + + smem_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_slice(tidx) + tdQaccumsdQaccum = smem_thr_copy_dQaccum.partition_D(sdQaccum) + + acc_dV = cute.make_fragment( + tiled_mma_dKV.partition_shape_C((self.n_block_size, self.head_dim_padded)), + cutlass.Float32 + ) + acc_dK = cute.make_fragment( + tiled_mma_dKV.partition_shape_C((self.n_block_size, self.head_dim_padded)), + cutlass.Float32 + ) + + acc_dV.fill(0.0) + acc_dK.fill(0.0) + + mma_one_m_block_all = partial(self.mma_one_m_block, + tiled_mma_SdP=tiled_mma_SdP, tiled_mma_dKV=tiled_mma_dKV, tiled_mma_dQaccum=tiled_mma_dQaccum, + pipeline_q=pipeline_q, pipeline_lse=pipeline_lse, + pipeline_dPsum=pipeline_dPsum, pipeline_dO=pipeline_dO, + tLSEsLSE_2D=tLSEsLSE_2D, tdPsumsdPsum_2D=tdPsumsdPsum_2D, sP=sP, sdS=sdS, sdQaccum=sdQaccum, acc_dV=acc_dV, acc_dK=acc_dK, + tSrQ=tSrQ, tSrK=tSrK, + tPsP=tPsP, tdSsdS=tdSsdS, + tdVrPt=tdVrPt, tdVrdOt=tdVrdOt, + tdKrdSt=tdKrdSt, tdKrQt=tdKrQt, + tdPrdO=tdPrdO, tdPrV=tdPrV, + tdQaccumrdS=tdQaccumrdS, tdQaccumrK=tdQaccumrK, tdQaccumsdQaccum=tdQaccumsdQaccum, + smem_thr_copy_PdS=smem_thr_copy_PdS, + smem_thr_copy_dQaccum=smem_thr_copy_dQaccum, + ) + + KV_consumer_phase = cutlass.Int32(0) + consumer_state = pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Consumer, self.num_stages) + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + + while work_tile.is_valid_tile: + n_block, head_idx, batch_idx = work_tile.tile_idx + + seqlen = SeqlenInfoCls(batch_idx) + m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) + + cute.arch.mbarrier_wait(mbar_ptr_K, phase=KV_consumer_phase) + cute.arch.mbarrier_wait(mbar_ptr_V, phase=KV_consumer_phase) + + KV_consumer_phase ^= 1 + + for m_block in cutlass.range(m_block_max - m_block_min, unroll=1): + m_block_idx = m_block_max - 1 - m_block + + consumer_state = mma_one_m_block_all( + warp_group_idx, + n_block, + m_block_idx, + head_idx, + batch_idx, + consumer_state, + softmax_scale_log2=softmax_scale_log2, + ) + + #scale dK + acc_dK.store(acc_dK.load() * softmax_scale) + + self.epilogue_dKV( + acc_dV, mdV, sV, + acc_dK, mdK, sK, + seqlen, + gmem_tiled_copy_dV, gmem_tiled_copy_dK, + tiled_mma_dKV, + tidx, n_block, head_idx, batch_idx, + ) + + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + + @cute.jit + def mma_one_m_block( + self, + warp_group_idx, + n_block: cutlass.Int32, + m_block: cutlass.Int32, + head_idx: cutlass.Int32, + batch_idx: cutlass.Int32, + + smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, + + tiled_mma_SdP: cute.TiledMma, + tiled_mma_dKV: cute.TiledMma, + tiled_mma_dQaccum: cute.TiledMma, + + pipeline_q: cutlass.pipeline.PipelineAsync, + pipeline_lse: cutlass.pipeline.PipelineAsync, + pipeline_dPsum: cutlass.pipeline.PipelineAsync, + pipeline_dO: cutlass.pipeline.PipelineAsync, + + tLSEsLSE_2D: cute.Tensor, + tdPsumsdPsum_2D: cute.Tensor, + sP: Optional[cute.Tensor], + sdS: Optional[cute.Tensor], + sdQaccum: cute.Tensor, + + acc_dV: cute.Tensor, + acc_dK: cute.Tensor, + + + tSrQ: cute.Tensor, + tSrK: cute.Tensor, + + tPsP: Optional[cute.Tensor], + tdSsdS: Optional[cute.Tensor], + + tdVrPt: cute.Tensor, + tdVrdOt: cute.Tensor, + + tdKrdSt: cute.Tensor, + tdKrQt: cute.Tensor, + + tdPrdO: cute.Tensor, + tdPrV: cute.Tensor, + tdQaccumrdS: cute.Tensor, + tdQaccumrK: cute.Tensor, + tdQaccumsdQaccum: cute.Tensor, + + smem_thr_copy_PdS: cute.TiledCopy, + smem_thr_copy_dQaccum: cute.TiledCopy, + softmax_scale_log2: cutlass.Float32 = 1.0, + ): + + + # (1) [GEMM 1] S = Q @ K^T + pipeline_q.consumer_wait(smem_pipe_read, pipeline_q.consumer_try_wait(smem_pipe_read)) + acc_S = cute.make_fragment( + tiled_mma_SdP.partition_shape_C((self.m_block_size, self.n_block_size)), + cutlass.Float32 + ) + + sm90_utils.gemm( + tiled_mma_SdP, acc_S, + tSrQ[None, None, None, smem_pipe_read.index], + tSrK, + zero_init=True, + wg_wait=0 + ) + + # (2) [Pointwise 1] P = exp(S - LSE) + pipeline_lse.consumer_wait(smem_pipe_read, pipeline_lse.consumer_try_wait(smem_pipe_read)) + + tLSErLSE = cute.make_fragment_like(tLSEsLSE_2D[None, 0]) + cute.autovec_copy(tLSEsLSE_2D[None, smem_pipe_read.index], tLSErLSE) + + acc_P_mn = utils.make_acc_tensor_mn_view(acc_S) + for r in cutlass.range_constexpr(cute.size(acc_P_mn, mode=[0])): + acc_P_mn[r, None].store(cute.exp2(acc_P_mn[r, None].load() * softmax_scale_log2 - tLSErLSE[r])) + + # fp32->bf16 + tdVrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) + tdVrP = cute.make_fragment_like(tdVrP_acc, self.dtype) + utils.cvt_f16(tdVrP_acc, tdVrP) + + # cp: rmem->smem + tPrP = smem_thr_copy_PdS.retile(tdVrP) + + cute.arch.barrier(barrier_id=int(NamedBarrierBwd.Epilogue), number_of_threads=self.num_mma_threads) + cute.arch.barrier(barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads) + cute.copy(smem_thr_copy_PdS, tPrP, tPsP) + + + ''' + if warp_group_idx == 0 and cute.arch.thread_idx()[0] == 128 and m_block == 0 and n_block == 0 and head_idx == 0 and batch_idx == 0: + for j in cutlass.range_constexpr(16): + cute.printf("%.15f", tPrP[j].to(cutlass.Float32)) + ''' + + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.barrier(barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads) + + pipeline_lse.consumer_release(smem_pipe_read) + + + # (3) [GEMM 2] dP = dO @ V.T + pipeline_dO.consumer_wait(smem_pipe_read, pipeline_dO.consumer_try_wait(smem_pipe_read)) + acc_dP = cute.make_fragment( + tiled_mma_SdP.partition_shape_C((self.m_block_size, self.n_block_size)), + cutlass.Float32 + ) + + sm90_utils.gemm( + tiled_mma_SdP, acc_dP, + tdPrdO[None, None, None, smem_pipe_read.index], + tdPrV, + zero_init=True, + wg_wait=-0 + ) + + # (4) [GEMM 3] dV += P.T @ dO + sm90_utils.gemm( + tiled_mma_dKV, acc_dV, + tdVrPt, + tdVrdOt[None, None, None, smem_pipe_read.index], + zero_init=False, + wg_wait=0 + ) + + pipeline_dO.consumer_release(smem_pipe_read) + + # (4) [Pointwise 2] dS = P*(dP-dPsum) + pipeline_dPsum.consumer_wait(smem_pipe_read, pipeline_dPsum.consumer_try_wait(smem_pipe_read)) + + # dPsum + tdPsumrdPsum = cute.make_fragment_like(tdPsumsdPsum_2D[None, 0]) + cute.autovec_copy(tdPsumsdPsum_2D[None, smem_pipe_read.index], tdPsumrdPsum) + + acc_dP_mn = utils.make_acc_tensor_mn_view(acc_dP) + for r in cutlass.range_constexpr(cute.size(acc_dP_mn, mode=[0])): + acc_dP_mn[r, None].store( + acc_P_mn[r, None].load() * (acc_dP_mn[r, None].load() - tdPsumrdPsum[r]) + ) + + # fp32->bf16 + tdKrdS_acc = cute.make_tensor(acc_dP.iterator, utils.convert_layout_acc_frgA(acc_dP.layout)) + tdKrdS = cute.make_fragment_like(tdKrdS_acc, self.dtype) + utils.cvt_f16(tdKrdS_acc, tdKrdS) + + tdSrdS = smem_thr_copy_PdS.retile(tdKrdS) + + cute.arch.barrier(barrier_id=int(NamedBarrierBwd.Epilogue), number_of_threads=self.num_mma_threads) + cute.arch.barrier(barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads) + + cute.copy(smem_thr_copy_PdS, tdSrdS, tdSsdS) + + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.barrier(barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads) + + pipeline_dPsum.consumer_release(smem_pipe_read) + + + + # (6) [GEMM 4] dQ = dS @ K + acc_dQ = cute.make_fragment( + tiled_mma_dQaccum.partition_shape_C((self.m_block_size, self.head_dim_padded)), + cutlass.Float32 + ) + cute.arch.barrier(barrier_id=int(NamedBarrierBwd.Epilogue), number_of_threads=self.num_mma_threads) + sm90_utils.gemm( + tiled_mma_dQaccum, acc_dQ, + tdQaccumrdS, + tdQaccumrK, + zero_init=True, + wg_wait=0 + ) + + cute.arch.barrier(barrier_id=int(NamedBarrierBwd.Epilogue), number_of_threads=self.num_mma_threads) + cute.arch.barrier(barrier_id=int(NamedBarrierBwd.dQEmpty), number_of_threads=self.num_mma_threads + cute.arch.WARP_SIZE) + + tdQaccumrdQaccum_tmp = cute.make_tensor(acc_dQ.iterator, cute.make_layout(tdQaccumsdQaccum.shape)) + cute.copy(smem_thr_copy_dQaccum, tdQaccumrdQaccum_tmp, tdQaccumsdQaccum) + + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.barrier_arrive(barrier_id=int(NamedBarrierBwd.dQFull), number_of_threads=self.num_mma_threads + cute.arch.WARP_SIZE) + + # (7) [GEMM 5] dK += dS.T @ Q + sm90_utils.gemm( + tiled_mma_dKV, acc_dK, + tdKrdSt, + tdKrQt[None, None, None, smem_pipe_read.index], + zero_init=False, + wg_wait=0 + ) + pipeline_q.consumer_release(smem_pipe_read) + + smem_pipe_read.advance() + return smem_pipe_read + + + @cute.jit + def epilogue_dKV( + self, + acc_dV: cute.Tensor, + mdV: cute.Tensor, + sV: cute.Tensor, + + acc_dK: cute.Tensor, + mdK: cute.Tensor, + sK: cute.Tensor, + + + seqlen: SeqlenInfoQK, + + gmem_tiled_copy_dV: cute.TiledCopy, + gmem_tiled_copy_dK: cute.TiledCopy, + + tiled_mma_dKV: cute.TiledMma, + + tidx: cutlass.Int32, + n_block: cutlass.Int32, + head_idx: cutlass.Int32, + batch_idx: cutlass.Int32 + ): + + ### RMEM --> SMEM + rdV = cute.make_fragment_like(acc_dV, self.dtype) + rdV.store(acc_dV.load().to(self.dtype)) + + rdK = cute.make_fragment_like(acc_dK, self.dtype) + rdK.store(acc_dK.load().to(self.dtype)) + + cute.arch.barrier(barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads) + + + smem_copy_atom_dKV = cute.make_copy_atom(cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=False, num_matrices=4), self.dtype,) + smem_thr_copy_dKV = cute.make_tiled_copy_C(smem_copy_atom_dKV, tiled_mma_dKV).get_slice(tidx) + + + taccdVrdV = smem_thr_copy_dKV.retile(rdV) + taccdVsdV = smem_thr_copy_dKV.partition_D(sV) # reuse sV SMEM + cute.copy(smem_copy_atom_dKV, taccdVrdV, taccdVsdV) + + taccdKrdK = smem_thr_copy_dKV.retile(rdK) + taccdKsdK = smem_thr_copy_dKV.partition_D(sK) # reuse sK SMEM + cute.copy(smem_copy_atom_dKV, taccdKrdK, taccdKsdK) + + + # SMEM -> GMEM + cdV = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded)) + mdV_cur = mdV[None, None, head_idx, batch_idx] + + cdK = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded)) + mdK_cur = mdK[None, None, head_idx, batch_idx] + + cute.arch.barrier(barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads) + gmem_thr_copy_dV = gmem_tiled_copy_dV.get_slice(tidx) + gmem_thr_copy_dK = gmem_tiled_copy_dK.get_slice(tidx) + + tdVsdV = gmem_thr_copy_dV.partition_S(sV) + tdVrdV = cute.make_fragment_like(tdVsdV, self.dtype) + cute.autovec_copy(tdVsdV, tdVrdV) + + tdKsdK = gmem_thr_copy_dK.partition_S(sK) + tdKrdK = cute.make_fragment_like(tdKsdK, self.dtype) + cute.autovec_copy(tdKsdK, tdKrdK) + + gdV = cute.local_tile(mdV_cur, (self.n_block_size, self.head_dim_padded), (n_block, 0)) + tdVgdV = gmem_thr_copy_dV.partition_D(gdV) + + gdK = cute.local_tile(mdK_cur, (self.n_block_size, self.head_dim_padded), (n_block, 0)) + tdKgdK = gmem_thr_copy_dK.partition_D(gdK) + + tdVcdV = gmem_thr_copy_dV.partition_S(cdV) + t0dVcdV = gmem_tiled_copy_dV.get_slice(0).partition_S(cdV) + tdVpdV = utils.predicate_k(tdVcdV, limit=mdV.shape[1]) + + tdKcdK = gmem_thr_copy_dK.partition_S(cdK) + tdKpdK = utils.predicate_k(tdKcdK, limit=mdK.shape[1]) + + for rest_m in cutlass.range_constexpr(cute.size(tdVrdV.shape[1])): + row_idx = n_block * self.n_block_size + t0dVcdV[0, rest_m, 0][0] + if row_idx < seqlen.seqlen_k: + cute.copy( + gmem_tiled_copy_dV, + tdVrdV[None, rest_m, None], + tdVgdV[None, rest_m, None], + pred=tdVpdV[None, rest_m, None] if cutlass.const_expr(self.check_hdim_v_oob) else None, + ) + cute.copy( + gmem_tiled_copy_dK, + tdKrdK[None, rest_m, None], + tdKgdK[None, rest_m, None], + pred=tdKpdK[None, rest_m, None] if cutlass.const_expr(self.check_hdim_oob) else None, + ) + + + @cute.jit + def dQaccum_writer( + self, + mdQaccum: cute.Tensor, + sdQaccum: cute.Tensor, + TileSchedulerCls: cutlass.Constexpr[Callable], + SeqlenInfoCls: cutlass.Constexpr[Callable], + ): + + tile_elems = cute.cosize(sdQaccum.layout) + tile_bytes = cutlass.Int32(tile_elems * 4) + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + + while work_tile.is_valid_tile: + n_block, head_idx, batch_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + + # GMEM + mdQaccum_cur = mdQaccum[None, head_idx, batch_idx] + + base_flat = cute.domain_offset( + (seqlen.offset_q * self.head_dim_padded, ), + mdQaccum_cur + ) + + m_block_min = cutlass.Int32(0) + m_block_max = cute.ceil_div(seqlen.seqlen_q, self.m_block_size) + + for it_m in cutlass.range(m_block_max - m_block_min, unroll=1): + m_block = m_block_max -1 - it_m + + cute.arch.barrier( + barrier_id=int(NamedBarrierBwd.dQFull), + number_of_threads=self.num_mma_threads + cute.arch.WARP_SIZE + ) + + gdQaccum_block = cute.local_tile( + base_flat, + (tile_elems, ), + (m_block, ) + ) + + with cute.arch.elect_one(): + sm90_utils.tma_reduce_add_bulk_f32( + sdQaccum.iterator, + gdQaccum_block.iterator, + tile_bytes, + ) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) + + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierBwd.dQEmpty), + number_of_threads=self.num_mma_threads + cute.arch.WARP_SIZE + ) + + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + + @cute.jit + def load_m_tile( + self, + tma_atom: cute.CopyAtom, + tXgX: cute.Tensor, + tXsX: cute.Tensor, + pipeline: cutlass.pipeline.PipelineAsync, + block: cutlass.Int32, + producer_state: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, + ): + pipeline.producer_acquire(producer_state) + cute.copy( + tma_atom, + tXgX[None, block], + tXsX[None, producer_state.index], + tma_bar_ptr=pipeline.producer_get_barrier(producer_state) + ) diff --git a/flash_attn/cute/hopper_helpers.py b/flash_attn/cute/hopper_helpers.py index 3a57e43da08..acb0273effd 100644 --- a/flash_attn/cute/hopper_helpers.py +++ b/flash_attn/cute/hopper_helpers.py @@ -3,6 +3,9 @@ import cutlass.cute as cute from cutlass.cute.nvgpu import warpgroup +from cutlass._mlir.dialects import llvm +from cutlass.cutlass_dsl import dsl_user_op + @cute.jit def gemm( @@ -29,3 +32,23 @@ def gemm( warpgroup.commit_group() if cutlass.const_expr(wg_wait >= 0): warpgroup.wait_group(wg_wait) + + +@dsl_user_op +def tma_reduce_add_bulk_f32( + smem_ptr: cute.Pointer, + gmem_ptr: cute.Pointer, + store_bytes: cutlass.Int32, + *, loc=None, ip=None + ): + cute.make_mma_atom + smem_u32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [gmem_ptr.llvm_ptr, smem_u32, store_bytes.ir_value()], + "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;", + "l,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) diff --git a/flash_attn/cute/named_barrier.py b/flash_attn/cute/named_barrier.py index 99a76222bce..5a7f52e7497 100644 --- a/flash_attn/cute/named_barrier.py +++ b/flash_attn/cute/named_barrier.py @@ -10,3 +10,16 @@ class NamedBarrierFwd(enum.IntEnum): WarpSchedulerWG3 = enum.auto() PFull = enum.auto() PEmpty = enum.auto() + + +class NamedBarrierBwd(enum.IntEnum): + Epilogue = enum.auto() + WarpSchedulerWG1 = enum.auto() + WarpSchedulerWG2 = enum.auto() + WarpSchedulerWG3 = enum.auto() + PdS = enum.auto() + #dQEmpty = 9 + #dQEmpty = 9 + + dQFull = enum.auto() + dQEmpty = enum.auto() From 8ecf128f683266735ba68e3c106ff67a2611886e Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Tue, 16 Sep 2025 22:41:30 -0700 Subject: [PATCH 100/258] [Cute] Make testing utils standlone for cute (#1892) --- flash_attn/cute/testing.py | 404 ++++++++++++++++++++++++++++++++++ tests/cute/test_flash_attn.py | 9 +- 2 files changed, 411 insertions(+), 2 deletions(-) create mode 100644 flash_attn/cute/testing.py diff --git a/flash_attn/cute/testing.py b/flash_attn/cute/testing.py new file mode 100644 index 00000000000..690d0145479 --- /dev/null +++ b/flash_attn/cute/testing.py @@ -0,0 +1,404 @@ +import math +from typing import Optional + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + + +class IndexFirstAxis(torch.autograd.Function): + @staticmethod + def forward(ctx, input, indices): + ctx.save_for_backward(indices) + assert input.ndim >= 2 + ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] + second_dim = other_shape.numel() + return torch.gather( + rearrange(input, "b ... -> b (...)"), + 0, + repeat(indices, "z -> z d", d=second_dim), + ).reshape(-1, *other_shape) + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + assert grad_output.ndim >= 2 + other_shape = grad_output.shape[1:] + grad_output = rearrange(grad_output, "b ... -> b (...)") + grad_input = torch.zeros( + [ctx.first_axis_dim, grad_output.shape[1]], + device=grad_output.device, + dtype=grad_output.dtype, + ) + grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output) + return grad_input.reshape(ctx.first_axis_dim, *other_shape), None + + +index_first_axis = IndexFirstAxis.apply + + +class IndexPutFirstAxis(torch.autograd.Function): + @staticmethod + def forward(ctx, values, indices, first_axis_dim): + ctx.save_for_backward(indices) + assert indices.ndim == 1 + assert values.ndim >= 2 + output = torch.zeros( + first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype + ) + output[indices] = values + return output + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + grad_values = grad_output[indices] + return grad_values, None, None + + +index_put_first_axis = IndexPutFirstAxis.apply + + +def unpad_input(hidden_states, attention_mask, unused_mask=None): + all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask + seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), + indices, + cu_seqlens, + max_seqlen_in_batch, + used_seqlens_in_batch, + ) + + +def pad_input(hidden_states, indices, batch, seqlen): + output = index_put_first_axis(hidden_states, indices, batch * seqlen) + return rearrange(output, "(b s) ... -> b s ...", b=batch) + + +def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False): + assert mode in ["full", "random", "third"] + if mode == "full": + lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) + elif mode == "random": + lengths = torch.randint( + max(0 if zero_lengths else 1, max_seqlen - 20), + max_seqlen + 1, + (batch_size, 1), + device=device, + ) + else: + lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device) + + if zero_lengths: + for i in range(batch_size): + if i % 5 == 0: + lengths[i] = 0 + lengths[-1] = 0 + padding_mask = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths + return padding_mask + + +def generate_qkv( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + qv=None, + kvpacked=False, + qkvpacked=False, + query_unused_mask=None, + key_unused_mask=None, +): + assert not (kvpacked and qkvpacked) + batch_size, seqlen_q, nheads, d = q.shape + d_v = v.shape[-1] + _, seqlen_k, nheads_k, _ = k.shape + assert k.shape == (batch_size, seqlen_k, nheads_k, d) + assert v.shape == (batch_size, seqlen_k, nheads_k, d_v) + if query_unused_mask is not None or key_unused_mask is not None: + assert not kvpacked + assert not qkvpacked + + if query_padding_mask is not None: + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input( + q, query_padding_mask, query_unused_mask + ) + output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q) + qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if qv is not None else None + else: + q_unpad = rearrange(q, "b s h d -> (b s) h d") + cu_seqlens_q = torch.arange( + 0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device + ) + seqused_q = None + max_seqlen_q = seqlen_q + output_pad_fn = lambda output_unpad: rearrange(output_unpad, "(b s) h d -> b s h d", b=batch_size) + qv_unpad = rearrange(qv, "b s ... -> (b s) ...") if qv is not None else None + + if key_padding_mask is not None: + k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input( + k, key_padding_mask, key_unused_mask + ) + v_unpad, *_ = unpad_input(v, key_padding_mask, key_unused_mask) + else: + k_unpad = rearrange(k, "b s h d -> (b s) h d") + v_unpad = rearrange(v, "b s h d -> (b s) h d") + cu_seqlens_k = torch.arange( + 0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device + ) + seqused_k = None + max_seqlen_k = seqlen_k + + if qkvpacked: + assert (query_padding_mask == key_padding_mask).all() + assert nheads == nheads_k + qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) + qkv = torch.stack([q, k, v], dim=2) + if query_padding_mask is not None: + dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q) + else: + dqkv_pad_fn = lambda dqkv_unpad: rearrange( + dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size + ) + return ( + qkv_unpad.detach().requires_grad_(), + cu_seqlens_q, + max_seqlen_q, + qkv.detach().requires_grad_(), + output_pad_fn, + dqkv_pad_fn, + ) + elif kvpacked: + kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) + kv = torch.stack([k, v], dim=2) + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k) + else: + dkv_pad_fn = lambda dkv_unpad: rearrange( + dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size + ) + return ( + q_unpad.detach().requires_grad_(), + kv_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + kv.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dkv_pad_fn, + ) + else: + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k) + else: + dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size) + return ( + q_unpad.detach().requires_grad_(), + k_unpad.detach().requires_grad_(), + v_unpad.detach().requires_grad_(), + qv_unpad.detach() if qv is not None else None, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + k.detach().requires_grad_(), + v.detach().requires_grad_(), + qv.detach() if qv is not None else None, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) + + +def construct_local_mask( + seqlen_q, + seqlen_k, + window_size=(None, None), + sink_token_length=0, + query_padding_mask=None, + key_padding_mask=None, + key_leftpad=None, + device=None, +): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + if key_leftpad is not None: + key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") + col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) + col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) + sk = ( + seqlen_k + if key_padding_mask is None + else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sq = ( + seqlen_q + if query_padding_mask is None + else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + ) + if window_size[0] is None: + return col_idx > row_idx + sk - sq + window_size[1] + else: + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + return torch.logical_or( + col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + torch.logical_and(col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length), + ) + + +def construct_chunk_mask( + seqlen_q, + seqlen_k, + attention_chunk, + query_padding_mask=None, + key_padding_mask=None, + key_leftpad=None, + device=None, +): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + if key_leftpad is not None: + key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") + col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) + col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) + sk = ( + seqlen_k + if key_padding_mask is None + else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sq = ( + seqlen_q + if query_padding_mask is None + else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + col_limit_left_chunk = row_idx + sk - sq - (row_idx + sk - sq) % attention_chunk + return torch.logical_or( + col_idx < col_limit_left_chunk, col_idx >= col_limit_left_chunk + attention_chunk + ) + + +def attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + key_leftpad=None, + attn_bias=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + qv=None, + q_descale=None, + k_descale=None, + v_descale=None, + window_size=(None, None), + attention_chunk=0, + sink_token_length=0, + learnable_sink: Optional[torch.Tensor] = None, + softcap=0.0, + upcast=True, + reorder_ops=False, + intermediate_dtype=None, +): + if causal: + window_size = (window_size[0], 0) + dtype_og = q.dtype + if upcast: + q, k, v = q.float(), k.float(), v.float() + qv = qv.float() if qv is not None else None + if q_descale is not None: + q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2]) + q = (q.float() * q_descale).to(q.dtype) + qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None + if k_descale is not None: + k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype) + if v_descale is not None: + v = (v.float() * rearrange(v_descale, "b h -> b 1 h 1")).to(dtype=v.dtype) + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + d = q.shape[-1] + dv = v.shape[-1] + softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv) + if not reorder_ops: + scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k) + else: + scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) + if qv is not None: + scores = scores + torch.einsum("bthd,bshd->bhts", qv * softmax_scale, v) + if softcap > 0: + scores = torch.tanh(scores / softcap) * softcap + if key_padding_mask is not None: + scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + local_mask = None + if window_size[0] is not None or window_size[1] is not None: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + sink_token_length, + query_padding_mask, + key_padding_mask, + key_leftpad=key_leftpad, + device=q.device, + ) + if attention_chunk > 0: + chunk_mask = construct_chunk_mask( + seqlen_q, + seqlen_k, + attention_chunk, + query_padding_mask, + key_padding_mask, + key_leftpad=key_leftpad, + device=q.device, + ) + local_mask = torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask + if local_mask is not None: + scores.masked_fill_(local_mask, float("-inf")) + if attn_bias is not None: + scores = scores + attn_bias + if learnable_sink is None: + attention = torch.softmax(scores, dim=-1).to(v.dtype) + else: + scores_fp32 = scores.to(torch.float32) + logits_max = torch.amax(scores_fp32, dim=-1, keepdim=True) + learnable_sink = rearrange(learnable_sink, "h -> h 1 1") + logits_or_sinks_max = torch.maximum(learnable_sink, logits_max) + unnormalized_scores = torch.exp(scores_fp32 - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp( + learnable_sink - logits_or_sinks_max + ) + attention = (unnormalized_scores / normalizer).to(v.dtype) + if query_padding_mask is not None: + attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + if key_padding_mask is not None: + attention = attention.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0) + if local_mask is not None: + attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) + dropout_scaling = 1.0 / (1 - dropout_p) + if dropout_mask is not None: + attention_drop = attention.masked_fill(~dropout_mask, 0.0) + else: + attention_drop = attention + if intermediate_dtype is not None: + attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype) + output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) + if query_padding_mask is not None: + output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index f3042f07635..a654e90d23e 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -12,8 +12,13 @@ except ImportError: apply_rotary_emb = None -from flash_attn.bert_padding import pad_input, unpad_input -from flash_attn.utils.testing import attention_ref, generate_qkv, generate_random_padding_mask +from flash_attn.cute.testing import ( + attention_ref, + generate_qkv, + generate_random_padding_mask, + pad_input, + unpad_input, +) from flash_attn.cute.interface import flash_attn_func, flash_attn_varlen_func, flash_attn_combine From 589cc20db3a982c8427bb19b42cf146a1a302bc1 Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Tue, 16 Sep 2025 22:43:24 -0700 Subject: [PATCH 101/258] Bump pin for CuTeDSL (#1891) --- flash_attn/cute/interface.py | 2 +- flash_attn/cute/mask.py | 9 +++++++++ flash_attn/cute/pyproject.toml | 4 ++-- flash_attn/cute/softmax.py | 31 ++++++++++++++++++++++++++++++- 4 files changed, 42 insertions(+), 4 deletions(-) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index b02d1e91be6..f25125c2cc3 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -1,5 +1,5 @@ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. -# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'd need to install nvidia-cutlass-dsl==4.1.0. +# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0. # Supported features: # - BF16 & FP16 dtype diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index 28c019db7b3..0f99add2cce 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -76,6 +76,12 @@ def apply_mask( causal_row_offset = ( 1 + self.seqlen_k - n_block * self.n_block_size - self.seqlen_q - thr_col_offset ) + c = 0 + col_limit_transformed = 0 + ncol: cute.Constexpr = 0 + col_limit_right_s = 0 + mask = 0 + in_bound = False if cutlass.const_expr(mask_causal): for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): # get the column index limit based on current row. Only consider the row index, so the column index sets to 0. @@ -113,6 +119,7 @@ def apply_mask( if cutlass.const_expr(self.window_size_left is not None) else None ) + c = 0 for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): if cutlass.const_expr(self.qhead_per_kvhead_packgqa == 1): row_idx = tScS_mn[r, 0][0] + m_block * self.m_block_size @@ -133,6 +140,7 @@ def apply_mask( # traverse column index. for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): col_idx = t0ScS_mn[0, c][1] + acc_S_mn = utils.make_acc_tensor_mn_view(acc_S) # only consider the column index, so the row index sets to 0. if col_idx >= col_limit_right or col_idx < col_limit_left: acc_S_mn[r, c] = -cutlass.Float32.inf @@ -193,6 +201,7 @@ def apply_mask_sm100( row_idx = tScS_t2r[0][0] + m_block * self.m_block_size if cutlass.const_expr(self.qhead_per_kvhead_packgqa != 1): row_idx = row_idx // self.qhead_per_kvhead_packgqa + c = 0 if cutlass.const_expr(mask_causal): col_limit_right = row_idx + causal_row_offset if cutlass.const_expr(mask_seqlen): diff --git a/flash_attn/cute/pyproject.toml b/flash_attn/cute/pyproject.toml index 8c4d89e52e1..f53acf1a3df 100644 --- a/flash_attn/cute/pyproject.toml +++ b/flash_attn/cute/pyproject.toml @@ -20,7 +20,7 @@ classifiers = [ ] dependencies = [ - "nvidia-cutlass-dsl==4.1.0", + "nvidia-cutlass-dsl==4.2.0", "torch", "einops", ] @@ -47,4 +47,4 @@ ignore = [ "E731", # do not assign a lambda expression, use a def "E741", # Do not use variables named 'I', 'O', or 'l' "F841", # local variable is assigned to but never used -] \ No newline at end of file +] diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index 6d8135d6461..2821a8e22f3 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -3,6 +3,7 @@ import math import operator from typing import Tuple +from dataclasses import dataclass import cutlass import cutlass.cute as cute @@ -19,9 +20,32 @@ def __init__( arch: cutlass.Constexpr[int] = 80, ): self.scale_log2 = scale_log2 + self.num_rows = num_rows + self.arch = arch self.row_max = cute.make_fragment(num_rows, Float32) self.row_sum = cute.make_fragment_like(self.row_max) - self.arch = arch + + def __extract_mlir_values__(self): + non_constexpr_fields = [self.scale_log2, self.row_max, self.row_sum] + values, self._values_pos = [], [] + for obj in non_constexpr_fields: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + field_names = ['scale_log2', 'row_max', 'row_sum'] + reconstructed_fields = {} + for name, n_items in zip(field_names, self._values_pos): + original_field = getattr(self, name) + reconstructed_fields[name] = cutlass.new_from_mlir_values(original_field, values[:n_items]) + values = values[n_items:] + + new_obj = self.__class__(reconstructed_fields['scale_log2'], self.num_rows, self.arch) + new_obj.row_max = reconstructed_fields['row_max'] + new_obj.row_sum = reconstructed_fields['row_sum'] + return new_obj def reset(self) -> None: self.row_max.fill(-Float32.inf) @@ -131,6 +155,11 @@ def __init__(self, scale_log2: Float32, rescale_threshold: cutlass.Constexpr[flo super().__init__(scale_log2, num_rows=1, arch=100) self.rescale_threshold = rescale_threshold + def __new_from_mlir_values__(self, values): + new_obj = super().__new_from_mlir_values__(values) + new_obj.rescale_threshold = self.rescale_threshold + return new_obj + @cute.jit def update_row_max(self, acc_S_row: cute.TensorSSA, is_first: int) -> Tuple[Float32, Float32]: if cutlass.const_expr(is_first): From 5c1627a7a1cda9c32cb9b937a053564e663f81bc Mon Sep 17 00:00:00 2001 From: jayhshah Date: Wed, 17 Sep 2025 14:58:45 -0700 Subject: [PATCH 102/258] Improve causal backward determinism perf with SPT schedule (#1893) * add spt scheduler for causal bwd determinism * add new torch check for det hdim 256 to stable api --- hopper/epilogue_bwd.hpp | 11 +- hopper/flash_api.cpp | 1 + hopper/flash_api_stable.cpp | 1 + hopper/flash_bwd_launch_template.h | 14 +- hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp | 14 +- hopper/test_flash_attn_bwd_determinism.py | 706 ++++++++++++++++++++++ hopper/tile_scheduler.hpp | 56 +- 7 files changed, 773 insertions(+), 30 deletions(-) create mode 100644 hopper/test_flash_attn_bwd_determinism.py diff --git a/hopper/epilogue_bwd.hpp b/hopper/epilogue_bwd.hpp index 6d9b5f4f596..fdae7616683 100644 --- a/hopper/epilogue_bwd.hpp +++ b/hopper/epilogue_bwd.hpp @@ -109,6 +109,7 @@ struct CollectiveEpilogueBwd { Element* ptr_dV; ShapedKV const shape_dV; StridedKV const stride_dV; + int const num_batch; int const num_heads_q; int* dk_semaphore; int* dv_semaphore; @@ -369,7 +370,8 @@ struct CollectiveEpilogueBwdGQA { ElementAccum* ptr_dVaccum; ShapedKV const shape_dVaccum; StridedKV const stride_dVaccum; - int num_heads_q; + int const num_batch; + int const num_heads_q; int* dk_semaphore; int* dv_semaphore; int const* cu_seqlens; @@ -387,6 +389,7 @@ struct CollectiveEpilogueBwdGQA { cutlass::FastDivmod qhead_per_khead_divmod; int* dk_semaphore; int* dv_semaphore; + int const num_batch; int const* cu_seqlens = nullptr; int const* seqused = nullptr; }; @@ -400,7 +403,7 @@ struct CollectiveEpilogueBwdGQA { return {args.ptr_dKaccum, args.shape_dKaccum, args.stride_dKaccum, args.ptr_dVaccum, args.shape_dVaccum, args.stride_dVaccum, cutlass::FastDivmod(cute::ceil_div(args.num_heads_q, get<1>(args.shape_dKaccum))), args.dk_semaphore, args.dv_semaphore, - args.cu_seqlens, args.seqused}; + args.num_batch, args.cu_seqlens, args.seqused}; } /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance @@ -449,8 +452,8 @@ struct CollectiveEpilogueBwdGQA { cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdV, tdKVsdKVaccum); } - // int const num_batch = params.num_batch; - int const num_batch = get<2>(params.shape_dKaccum); + int const num_batch = params.num_batch; + // int const num_batch = get<2>(params.shape_dKaccum); // erroneously returns 1 for varlen int const num_head_kv = get<1>(params.shape_dKaccum); int *lock_ptr = !Deterministic ? nullptr : params.dv_semaphore + bidb * num_head_kv + bidh_kv; using Barrier = cutlass::GenericBarrier; diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index adb53fdab6b..8d0b2438acc 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -1361,6 +1361,7 @@ std::tuplemajor * 10 + at::cuda::getCurrentDeviceProperties()->minor; int const head_size_rounded = round_up_headdim(std::max(head_size, head_size_v)); int const head_size_v_rounded = head_size_rounded; + TORCH_CHECK(!deterministic || head_size_rounded < 256, "Deterministic backward not supported for hdim 256."); // Very important that these match the kernel configs bool const is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal; int const kBlockM_sm90 = head_size_rounded <= 64 ? (is_causal && softcap > 0.0 ? 96 : 128) diff --git a/hopper/flash_api_stable.cpp b/hopper/flash_api_stable.cpp index 42601e5692d..32b9d226e2d 100644 --- a/hopper/flash_api_stable.cpp +++ b/hopper/flash_api_stable.cpp @@ -1426,6 +1426,7 @@ std::tuple mha_b int const arch = dprops->major * 10 + dprops->minor; int const head_size_rounded = round_up_headdim(std::max(head_size, head_size_v)); int const head_size_v_rounded = head_size_rounded; + STD_TORCH_CHECK(!deterministic || head_size_rounded < 256, "Deterministic backward not supported for hdim 256."); // Very important that these match the kernel configs bool const is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal; int const kBlockM_sm90 = head_size_rounded <= 64 ? (is_causal && softcap > 0.0 ? 96 : 128) diff --git a/hopper/flash_bwd_launch_template.h b/hopper/flash_bwd_launch_template.h index b6e8810b25f..6df3231cdd4 100644 --- a/hopper/flash_bwd_launch_template.h +++ b/hopper/flash_bwd_launch_template.h @@ -94,8 +94,8 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { flash::CollectiveEpilogueBwdGQA >; using Scheduler = std::conditional_t< - Is_causal && !Varlen, - flash::SingleTileBwdLPTScheduler, + Is_causal, + flash::SingleTileBwdLPTScheduler, flash::SingleTileScheduler >; using AttnKernel = std::conditional_t< @@ -165,6 +165,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { return typename CollectiveEpilogue::StridedKV {_1{}, params.dv_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.dv_rounded * params.seqlen_k_rounded : 0}; // stride_dVaccum } }(), + params.b, params.h, params.dk_semaphore, params.dv_semaphore, @@ -301,10 +302,11 @@ template(params, stream); - run_flash_bwd(params, stream); -// }); + BOOL_SWITCH(params.deterministic, Deterministic_, [&] { + static constexpr bool Deterministic = Deterministic_ && kHeadDim < 256; + // run_flash_bwd(params, stream); + run_flash_bwd(params, stream); + }); }); }); } diff --git a/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp index ec34e20eca1..0232b90e54a 100644 --- a/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp @@ -607,7 +607,8 @@ struct CollectiveMainloopBwdSm90 { seqlen_info, n_block, bidb, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/); // It's possible to have m_block_max <= m_block_min. Exit early - if constexpr (Is_causal || Is_local || Varlen) { + // Though if local and deterministic, still need to increment dq semaphore + if constexpr ((Is_causal || Is_local || Varlen) && !(Is_local && Deterministic)) { if (m_block_max <= m_block_min) { return; } } @@ -626,10 +627,18 @@ struct CollectiveMainloopBwdSm90 { using Barrier = cutlass::GenericBarrier; bool const lane_predicate = cute::elect_one_sync(); int m_block = m_block_min; + constexpr int kBlockM = get<0>(TileShape_MNK{}); + constexpr int kBlockN = get<1>(TileShape_MNK{}); + int n_block_global_max = cute::ceil_div(seqlen_info.seqlen_k, kBlockN); #pragma unroll 2 for (; m_block < m_block_max; ++m_block) { if constexpr (Deterministic) { - Barrier::wait_eq(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head, n_block); + if constexpr(Is_causal) { + int n_block_max_for_m_block = std::min(n_block_global_max, cute::ceil_div((m_block + 1) * kBlockM + seqlen_info.seqlen_k - seqlen_info.seqlen_q, kBlockN)); + Barrier::wait_eq(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head, n_block_max_for_m_block - 1 - n_block); + } else { + Barrier::wait_eq(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head, n_block); + } } #pragma unroll for (int warpgroup_idx = 0; warpgroup_idx < NumMmaWarpGroups; ++warpgroup_idx) { @@ -649,7 +658,6 @@ struct CollectiveMainloopBwdSm90 { } } if constexpr (Is_local && Deterministic) { - constexpr int kBlockM = get<0>(TileShape_MNK{}); int const m_block_global_max = cute::ceil_div(seqlen_info.seqlen_q, kBlockM); #pragma unroll 2 for (; m_block < m_block_global_max; ++m_block) { diff --git a/hopper/test_flash_attn_bwd_determinism.py b/hopper/test_flash_attn_bwd_determinism.py new file mode 100644 index 00000000000..b443c8948d4 --- /dev/null +++ b/hopper/test_flash_attn_bwd_determinism.py @@ -0,0 +1,706 @@ +import os +import math +import itertools + +import pytest +import torch +import torch.nn.functional as F +from torch._C import parse_schema + +from einops import rearrange, repeat +try: + from flash_attn.layers.rotary import apply_rotary_emb +except ImportError: + apply_rotary_emb = None + +from padding import pad_input, unpad_input +from test_util import ( + attention_ref, + generate_qkv, + generate_random_padding_mask, +) + +from flash_attn_interface import flash_attn_func, flash_attn_varlen_func, flash_attn_combine +from flash_attn_interface import flash_attn_with_kvcache, get_scheduler_metadata + +from flash_attn_interface import _flash_attn_backward + + +DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE" +DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE" +DISABLE_PAGEDKV = os.getenv("FLASH_ATTENTION_DISABLE_PAGEDKV", "FALSE") == "TRUE" +DISABLE_APPENDKV = os.getenv("FLASH_ATTENTION_DISABLE_APPENDKV", "FALSE") == "TRUE" +DISABLE_LOCAL = os.getenv("FLASH_ATTENTION_DISABLE_LOCAL", "FALSE") == "TRUE" +DISABLE_SOFTCAP = os.getenv("FLASH_ATTENTION_DISABLE_SOFTCAP", "FALSE") == "TRUE" +DISABLE_PACKGQA = os.getenv("FLASH_ATTENTION_DISABLE_PACKGQA", "FALSE") == "TRUE" +DISABLE_FP16 = os.getenv("FLASH_ATTENTION_DISABLE_FP16", "FALSE") == "TRUE" +DISABLE_FP8 = os.getenv("FLASH_ATTENTION_DISABLE_FP8", "FALSE") == "TRUE" or torch.cuda.get_device_capability("cuda")[0] < 9 +DISABLE_HDIM64 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM64", "FALSE") == "TRUE" +DISABLE_HDIM96 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM96", "FALSE") == "TRUE" +DISABLE_HDIM128 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM128", "FALSE") == "TRUE" +DISABLE_HDIM192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM192", "FALSE") == "TRUE" +DISABLE_HDIM256 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM256", "FALSE") == "TRUE" + +# deterministic mode not supported for hdim 256 +DISABLE_HDIM256 = True + +COMPILED_HDIMS = ( + [] + + ([64] if not DISABLE_HDIM64 else []) + + ([96] if not DISABLE_HDIM96 else []) + + ([128] if not DISABLE_HDIM128 else []) + + ([192] if not DISABLE_HDIM192 else []) + + ([256] if not DISABLE_HDIM256 else []) +) + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +# @pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mqa"]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) +@pytest.mark.parametrize("deterministic", [False, True]) +# @pytest.mark.parametrize("deterministic", [True]) +@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) +# @pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) +# @pytest.mark.parametrize("local", [True]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +# @pytest.mark.parametrize("V_colmajor", [False, True]) +@pytest.mark.parametrize("V_colmajor", [False]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [64, 128, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [64, 96, 128, 192]) +@pytest.mark.parametrize("d", COMPILED_HDIMS) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 1), + (64, 128), + (128, 192), + (256, 256), + (239, 1), + (799, 3), + (113, 203), + (113, 128), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (384, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (4096, 4096), + # (4224, 4224), + # (8192, 8192), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) +def test_flash_attn_output( + seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv, mha_type, dtype +): + if V_colmajor and (seqlen_k % 16 != 0 or dtype != torch.float8_e4m3fn): + pytest.skip("V_colmajor requires seqlen_k to be a multiple of 16 and dtype to be float8_e4m3fn") + if has_qv and (d != 64 or dtype == torch.float8_e4m3fn): + pytest.skip("Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)") + if deterministic and d == 256: + pytest.skip("Deterministic mode not supported for hdim 256") + device = "cuda" + # set seed + torch.random.manual_seed(0) + # batch_size = 40 + # nheads = 16 + batch_size = 9 if seqlen_k <= 2048 else 2 + # batch_size = 1 + nheads = 6 + # nheads = 1 + nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + # if dtype == torch.float8_e4m3fn: + # dv_vals = [d] + # if has_qv: + # dv_vals = [256, 512] + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if not DISABLE_LOCAL else [0] + dv_vals = [d] + attention_chunk_vals = [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + print(f"{dv = }, {attention_chunk = }") + q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. + q_ref = (q_ref * softcap / 4) + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + if has_qv: + qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + # window_size = (-1, -1) if not local else (16, 0) + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None + if V_colmajor: + v = rearrange(rearrange(v.detach(), "b s h d -> b h d s").contiguous(), "b h d s -> b s h d").requires_grad_() + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + + # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float() + # if qv is not None: + # qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float() + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # exp_sum = s_tmp.sum(-1) + # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float()) + # lse_ref = torch.logsumexp(qk, dim=-1) + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] + # num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] + pack_gqa_vals = [False] + num_splits_vals = [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + print(f"{pack_gqa = }, {num_splits = }") + out, softmax_lse = flash_attn_func( + q, + k, + v, + causal=causal, + qv=qv, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + pack_gqa=pack_gqa, + num_splits=num_splits, + return_attn_probs=True, + ) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol + + if ( + not DISABLE_BACKWARD + and dtype != torch.float8_e4m3fn + and not V_colmajor + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + ): + g = torch.randn_like(out) + do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + dq, dk, dv, softmax_d = _flash_attn_backward( + g, + q, + k, + v, + out, + softmax_lse, + None, None, # cu_seqlens_q, cu_seqlens_k, + None, None, # sequed_q, sequed_k, + None, None, # max_seqlen_q, max_seqlen_k, + dq, + dk, + dv, + d ** (-0.5), + causal, + window_size=window_size, + softcap=softcap, + deterministic=deterministic, + ) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + # breakpoint() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol + + if deterministic: + iterations = 1000 + + for i in range(iterations): + dq2 = torch.empty_like(dq) + dk2 = torch.empty_like(dk) + dv2 = torch.empty_like(dv) + dq2, dk2, dv2, softmax_d = _flash_attn_backward( + g, + q, + k, + v, + out, + softmax_lse, + None, None, # cu_seqlens_q, cu_seqlens_k, + None, None, # sequed_q, sequed_k, + None, None, # max_seqlen_q, max_seqlen_k, + dq2, + dk2, + dv2, + d ** (-0.5), + causal, + window_size=window_size, + softcap=softcap, + deterministic=deterministic, + ) + print(f'dq max diff with myself: {(dq2 - dq).abs().max().item()}') + print(f'dk max diff with myself: {(dk2 - dk).abs().max().item()}') + print(f'dv max diff with myself: {(dv2 - dv).abs().max().item()}') + assert torch.equal(dq, dq2), f"dq not deterministic" + assert torch.equal(dk, dk2), f"dk not deterministic" + assert torch.equal(dv, dv2), f"dv not deterministic" + print(f"✅ Iteration {i} passed!") + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +# @pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) +@pytest.mark.parametrize("deterministic", [False, True]) +# @pytest.mark.parametrize("deterministic", [True]) +@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) +# @pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("add_unused_qkv", [False, True]) +# @pytest.mark.parametrize("add_unused_qkv", [True]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [64, 96, 128]) +@pytest.mark.parametrize("d", COMPILED_HDIMS) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 1), + (1, 3), + (2, 1), + (511, 1), + (3, 513), + (64, 128), + (128, 128), + (256, 256), + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (307, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (1024, 1024), + (2048, 2048), + (4096, 4096), + ], +) +def test_flash_attn_varlen_output( + seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype, +): + if has_qv and (d != 64 or dtype == torch.float8_e4m3fn): + pytest.skip("Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)") + if deterministic and d == 256: + pytest.skip("Deterministic mode not supported for hdim 256") + device = "cuda" + # set seed + torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) + # batch_size = 40 + # nheads = 16 + batch_size = 9 if seqlen_q <= 2048 else 2 + # batch_size = 32 + nheads = 6 + nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + # batch_size = 2 + # nheads = 1 + # nheads_kv = nheads + + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + # if dtype == torch.float8_e4m3fn: + # dv_vals = [d] + # if has_qv: + # dv_vals = [256, 512] + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k and not DISABLE_LOCAL else [0] + dv_vals = [d] + attention_chunk_vals = [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + print(f"{dv = }, {attention_chunk = }") + q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. + q_ref = (q_ref * softcap / 4).detach().requires_grad_() + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + if has_qv: + qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach() if has_qv else None + query_padding_mask = generate_random_padding_mask( + seqlen_q, batch_size, device, mode="random", zero_lengths=False + ) + key_padding_mask = generate_random_padding_mask( + seqlen_k, batch_size, device, mode="random", zero_lengths=True + ) + + def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): + if add_unused: + another_mask = generate_random_padding_mask(max_seq_len, bs, device) + attn_mask = torch.logical_and(padding_mask, another_mask) + unused_mask = torch.logical_xor( + torch.logical_or(padding_mask, another_mask), attn_mask + ) + else: + attn_mask = padding_mask + unused_mask = None + return attn_mask, unused_mask + + query_padding_mask, query_unused_mask = _gen_unused_masks( + query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device + ) + key_padding_mask, key_unused_mask = _gen_unused_masks( + key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device + ) + + ( + q_unpad, + k_unpad, + v_unpad, + qv_unpad, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + qv, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, qv=qv, kvpacked=False, + query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask) + q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)] + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + if query_unused_mask is not None: + q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + # pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] + # num_splits_vals = [1, 3, 0] if not DISABLE_SPLIT else [1] + pack_gqa_vals = [False] + num_splits_vals = [1] + print("cu_seqlens_q: ", cu_seqlens_q) + print("cu_seqlens_k: ", cu_seqlens_k) + print("seqused_q: ", seqused_q) + print("seqused_k: ", seqused_k) + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + print(f"{pack_gqa = }, {num_splits = }") + out_unpad, softmax_lse = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + seqused_q=seqused_q, + seqused_k=seqused_k, + causal=causal, + qv=qv_unpad, + q_descale=q_descale, + k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + pack_gqa=pack_gqa, + num_splits=num_splits, + deterministic=deterministic, + return_attn_probs=True, + ) + out = output_pad_fn(out_unpad) + if query_unused_mask is not None: + out.masked_fill_(q_zero_masking, 0.0) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most 3x the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol + + + if ( + not DISABLE_BACKWARD + and dtype != torch.float8_e4m3fn + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + ): + g_unpad = torch.randn_like(out_unpad) + do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) + dq_unpad = torch.empty_like(q_unpad) + dk_unpad = torch.empty_like(k_unpad) + dv_unpad = torch.empty_like(v_unpad) + dq_unpad, dk_unpad, dv_unpad, softmax_d = _flash_attn_backward( + g_unpad, + q_unpad, + k_unpad, + v_unpad, + out_unpad, + softmax_lse, + cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k, + max_seqlen_q, max_seqlen_k, + dq_unpad, + dk_unpad, + dv_unpad, + d ** (-0.5), + causal, + window_size=window_size, + softcap=softcap, + deterministic=deterministic, + ) + dq = dq_pad_fn(dq_unpad) + dk = dk_pad_fn(dk_unpad) + dv = dk_pad_fn(dv_unpad) + if key_unused_mask is not None: + k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") + dk.masked_fill_(k_zero_masking, 0.0) + dv.masked_fill_(k_zero_masking, 0.0) + if query_unused_mask is not None: + dq.masked_fill_(q_zero_masking, 0.0) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + g = output_pad_fn(g_unpad) + + # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float() + # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + + + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + # breakpoint() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol + + print(dq_unpad.shape) + print(dk_unpad.shape) + print(dv_unpad.shape) + + print(dq.shape) + print(dk.shape) + print(dv.shape) + + if deterministic: + iterations = 1000 + + for i in range(iterations): + dq_unpad2 = torch.empty_like(q_unpad) + dk_unpad2 = torch.empty_like(k_unpad) + dv_unpad2 = torch.empty_like(v_unpad) + dq_unpad2, dk_unpad2, dv_unpad2, softmax_d = _flash_attn_backward( + g_unpad, + q_unpad, + k_unpad, + v_unpad, + out_unpad, + softmax_lse, + cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k, + max_seqlen_q, max_seqlen_k, + dq_unpad2, + dk_unpad2, + dv_unpad2, + d ** (-0.5), + causal, + window_size=window_size, + softcap=softcap, + deterministic=deterministic, + ) + + dq2 = dq_pad_fn(dq_unpad2) + dk2 = dk_pad_fn(dk_unpad2) + dv2 = dk_pad_fn(dv_unpad2) + if key_unused_mask is not None: + k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") + dk2.masked_fill_(k_zero_masking, 0.0) + dv2.masked_fill_(k_zero_masking, 0.0) + if query_unused_mask is not None: + dq2.masked_fill_(q_zero_masking, 0.0) + + print(f'dq max diff with myself: {(dq2 - dq).abs().max().item()}') + print(f'dk max diff with myself: {(dk2 - dk).abs().max().item()}') + print(f'dv max diff with myself: {(dv2 - dv).abs().max().item()}') + + assert torch.equal(dq, dq2), f"dq not deterministic" + assert torch.equal(dk, dk2), f"dk not deterministic" + assert torch.equal(dv, dv2), f"dv not deterministic" + + print(f"✅ Iteration {i} passed!") \ No newline at end of file diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index 3c9e42996b0..241eaed40f8 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -364,6 +364,7 @@ class DynamicPersistentTileScheduler { /////////////////////////////////////////////////////////////////////////////// +template class SingleTileBwdLPTScheduler { public: @@ -373,10 +374,13 @@ class SingleTileBwdLPTScheduler { // Device side kernel params struct Params { int const total_blocks; - cutlass::FastDivmod const m_block_divmod, head_divmod; + cutlass::FastDivmod const block_divmod, head_divmod; cutlass::FastDivmod const l2_minor_divmod, l2_major_divmod; cutlass::FastDivmod const l2_minor_residual_divmod; int const num_hb_quotient; + int const seqlen; + int const* const cu_seqlens; + int const* const seqused; }; static Params @@ -401,7 +405,8 @@ class SingleTileBwdLPTScheduler { cutlass::FastDivmod(swizzle), cutlass::FastDivmod(swizzle * args.num_blocks), // don't divide by 0 cutlass::FastDivmod(num_hb_remainder > 0 ? num_hb_remainder : 1), - (args.num_head * args.num_batch) / swizzle}; + (args.num_head * args.num_batch) / swizzle, + args.seqlen, !Varlen ? nullptr : args.cu_seqlens, !Varlen ? nullptr : args.seqused}; } static dim3 @@ -410,28 +415,19 @@ class SingleTileBwdLPTScheduler { } struct WorkTileInfo { - int tile_idx; + int block; + int bidh; + int bidb; CUTLASS_DEVICE bool is_valid(Params const& params) const { - return tile_idx < params.total_blocks; + return bidb >= 0; } CUTLASS_DEVICE cute::tuple get_block_coord(Params const& params) const { - int block, bidh, bidb; - int l2_mod, bidhb, bidhb_residual; - bidhb = params.l2_major_divmod.divmod(l2_mod, tile_idx); - // If we're in the last section (called residual), we don't want to divide by - // swizzle. Instead we want to divide by the remainder. - if (bidhb < params.num_hb_quotient) { - block = params.l2_minor_divmod.divmod(bidhb_residual, l2_mod); - } else { - block = params.l2_minor_residual_divmod.divmod(bidhb_residual, l2_mod); - } - bidb = params.head_divmod.divmod(bidh, bidhb * params.l2_minor_divmod.divisor + bidhb_residual); return {block, bidh, bidb, 0 /*split_idx*/}; } @@ -444,7 +440,33 @@ class SingleTileBwdLPTScheduler { CUTLASS_DEVICE WorkTileInfo get_initial_work(Params const& params) const { - return {int(blockIdx.x)}; + int tile_idx = blockIdx.x; + int block, bidh, bidb; + int l2_mod, bidhb, bidhb_residual; + bidhb = params.l2_major_divmod.divmod(l2_mod, tile_idx); + // If we're in the last section (called residual), we don't want to divide by + // swizzle. Instead we want to divide by the remainder. + if (bidhb < params.num_hb_quotient) { + block = params.l2_minor_divmod.divmod(bidhb_residual, l2_mod); + } else { + block = params.l2_minor_residual_divmod.divmod(bidhb_residual, l2_mod); + } + bidb = params.head_divmod.divmod(bidh, bidhb * params.l2_minor_divmod.divisor + bidhb_residual); + bool is_valid_tile = true; + int num_blocks; + if constexpr (Varlen) { + int seqlen = params.seqused + ? params.seqused[bidb] + : (params.cu_seqlens ? params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb] : params.seqlen); + num_blocks = cute::ceil_div(seqlen, Int{}); + is_valid_tile = block < num_blocks; + } else { + num_blocks = params.block_divmod.divisor; + } + if constexpr (SPT) { + block = num_blocks - block - 1; + } + return {block, bidh, is_valid_tile ? bidb : -1}; } CUTLASS_DEVICE @@ -459,7 +481,7 @@ class SingleTileBwdLPTScheduler { CUTLASS_DEVICE WorkTileInfo get_next_work(Params const& params, WorkTileInfo const& current_work) const { - return {params.total_blocks}; + return {0, 0, -1}; } }; From 1ceaa984b2f348caea18b39a98458d33b4ea7a09 Mon Sep 17 00:00:00 2001 From: Johnny Date: Tue, 23 Sep 2025 22:51:43 +0200 Subject: [PATCH 103/258] Upgrade to cutlass v4.2.1 (#1905) --- flash_attn/cute/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_attn/cute/pyproject.toml b/flash_attn/cute/pyproject.toml index f53acf1a3df..0c34f83f1cf 100644 --- a/flash_attn/cute/pyproject.toml +++ b/flash_attn/cute/pyproject.toml @@ -20,7 +20,7 @@ classifiers = [ ] dependencies = [ - "nvidia-cutlass-dsl==4.2.0", + "nvidia-cutlass-dsl==4.2.1", "torch", "einops", ] From 3b24b08d1af944189e14c2c54816e6f8b78bbbe2 Mon Sep 17 00:00:00 2001 From: brandonsun Date: Thu, 25 Sep 2025 00:09:30 +0800 Subject: [PATCH 104/258] switch to use cutlass.utils.get_smem_capacity_in_bytes instead of deprecated cutlass.utils.ampere_helpers.SMEM_CAPACITY (#1906) --- flash_attn/cute/flash_bwd.py | 4 ++-- flash_attn/cute/flash_fwd.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/flash_attn/cute/flash_bwd.py b/flash_attn/cute/flash_bwd.py index 619e0408cd4..a6d061b19b5 100644 --- a/flash_attn/cute/flash_bwd.py +++ b/flash_attn/cute/flash_bwd.py @@ -11,7 +11,7 @@ import cutlass import cutlass.cute as cute from cutlass.cute.nvgpu import cpasync, warp -import cutlass.utils.ampere_helpers as sm80_utils_basic +import cutlass.utils as utils_basic from flash_attn.cute import ampere_helpers as sm80_utils from flash_attn.cute import utils @@ -125,7 +125,7 @@ def can_implement( smem_usage_V = n_block_size * head_dim_v * 2 smem_usage_QV = (smem_usage_Q + smem_usage_V) if not V_in_regs else max(smem_usage_Q, smem_usage_V) smem_usage = smem_usage_QV + smem_usage_dO + smem_usage_K - smem_capacity = sm80_utils_basic.SMEM_CAPACITY["sm80"] + smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_80") if smem_usage > smem_capacity: return False return True diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index d1b307acf02..b70da9a5264 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -16,7 +16,7 @@ import cutlass.cute as cute from cutlass import Float32, Int32, const_expr from cutlass.cute.nvgpu import cpasync, warp, warpgroup -import cutlass.utils.ampere_helpers as sm80_utils_basic +import cutlass.utils as utils_basic import cutlass.utils.hopper_helpers as sm90_utils_basic from flash_attn.cute import ampere_helpers as sm80_utils @@ -127,7 +127,7 @@ def can_implement( smem_usage_QV = (smem_usage_Q + smem_usage_V) if not Q_in_regs else max(smem_usage_Q, smem_usage_V) smem_usage = smem_usage_QV + smem_usage_K # TODO: sm86 and sm89 - smem_capacity = sm80_utils_basic.SMEM_CAPACITY["sm80"] + smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_80") if smem_usage > smem_capacity: return False # Check if twice the block size is divisible by the number of threads From 0165c96fff7a7cd2e152aa9659f75c972a702f5d Mon Sep 17 00:00:00 2001 From: JackCharlesZhang <113156832+JackCharlesZhang@users.noreply.github.com> Date: Wed, 24 Sep 2025 12:34:03 -0700 Subject: [PATCH 105/258] Add Missing None Gradient in FA3 QKVPacked (#1908) Co-authored-by: Jack Zhang --- hopper/flash_attn_interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index a435e7a627d..1158ee02ad2 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -248,7 +248,7 @@ def backward(ctx, dout, *args): ctx.sm_margin, ) dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None, None, None, None, None, None, None, None + return dqkv, None, None, None, None, None, None, None, None, None, None, None, None class FlashAttnFunc(torch.autograd.Function): From add175637c5d54b74bc25372e49ce282d6f236fc Mon Sep 17 00:00:00 2001 From: Johnny Date: Thu, 25 Sep 2025 10:22:47 +0200 Subject: [PATCH 106/258] C++11 fix warnings (#1904) * errors are with C++11 narrowing warnings (treated as errors in strict builds) when initializing at::cuda::CUDAGuard with a non-constant char cast to c10::DeviceIndex (signed char). * errors are with C++11 narrowing warnings (treated as errors in strict builds) when initializing at::cuda::CUDAGuard with a non-constant char cast to c10::DeviceIndex (signed char). * errors are with C++11 narrowing warnings (treated as errors in strict builds) when initializing at::cuda::CUDAGuard with a non-constant char cast to c10::DeviceIndex (signed char). * errors are with C++11 narrowing warnings (treated as errors in strict builds) when initializing at::cuda::CUDAGuard with a non-constant char cast to c10::DeviceIndex (signed char). * Update flash_api_stable.cpp * upstream cutlass v4.2.1 csrc --- csrc/cutlass | 2 +- hopper/flash_api.cpp | 12 +++++++++--- hopper/flash_api_stable.cpp | 12 ++++++++---- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/csrc/cutlass b/csrc/cutlass index dc4817921ed..c6aeb9179c5 160000 --- a/csrc/cutlass +++ b/csrc/cutlass @@ -1 +1 @@ -Subproject commit dc4817921edda44a549197ff3a9dcf5df0636e7b +Subproject commit c6aeb9179c5f74a0fcdbd28527bf4b6ba8c60752 diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 8d0b2438acc..0233da799f2 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -41,6 +41,12 @@ PyObject* PyInit__C(void) #define PREPARE_VARLEN_MAX_BATCHES_1CTA 992 +namespace { +inline at::cuda::CUDAGuard make_cuda_guard_from_tensor(const at::Tensor& t) { + return at::cuda::CUDAGuard(static_cast(t.get_device())); +} +} // namespace + void set_params_fprop(Flash_fwd_params ¶ms, // sizes const size_t b, @@ -609,7 +615,7 @@ mha_fwd_get_scheduler_metadata( // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)seqused_k.get_device()}; + auto device_guard = make_cuda_guard_from_tensor(seqused_k); auto opts = seqused_k.options(); // This needs to be set after get_num_splits @@ -876,7 +882,7 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + auto device_guard = make_cuda_guard_from_tensor(q); at::Tensor softmax_lse; if (!is_varlen_q) { @@ -1463,7 +1469,7 @@ std::tuple using torch::stable::Tensor; +namespace tsa = torch::stable::accelerator; namespace { +inline tsa::DeviceGuard make_device_guard(const Tensor& t) { + return tsa::DeviceGuard(static_cast(t.get_device())); +} std::deque device_flags; std::vector device_properties; @@ -673,7 +677,7 @@ mha_fwd_get_scheduler_metadata( // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing - torch::stable::accelerator::DeviceGuard device_guard{(char)seqused_k.get_device()}; + auto device_guard = make_device_guard(seqused_k); // This needs to be set after get_num_splits Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic @@ -939,7 +943,7 @@ mha_fwd(Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_ // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing - torch::stable::accelerator::DeviceGuard device_guard{(char)q.get_device()}; + auto device_guard = make_device_guard(q); Tensor softmax_lse; if (!is_varlen_q) { @@ -1528,7 +1532,7 @@ std::tuple mha_b // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing - torch::stable::accelerator::DeviceGuard device_guard{(char)q.get_device()}; + auto device_guard = make_device_guard(q); // auto opts = q.options(); // Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64 @@ -1691,7 +1695,7 @@ mha_combine(Tensor out_partial, // num_splits x batch_size x seqlen x nu // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing - torch::stable::accelerator::DeviceGuard device_guard{(char)out_partial.get_device()}; + auto device_guard = make_device_guard(out_partial); auto softmax_lse = torch::stable::new_empty(out_partial, {batch_size, num_heads, seqlen}, std::make_optional(torch::headeronly::ScalarType::Float)); softmax_lse = torch::stable::transpose(softmax_lse, 1, 2); From cc0a79b87c42dfbb74c23fdc97d87e2ff720f5e1 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 26 Sep 2025 21:16:54 -0400 Subject: [PATCH 107/258] [Cute] Write ex2 emulation in a more readable form --- flash_attn/cute/softmax.py | 11 +-- flash_attn/cute/utils.py | 166 ++++++++++++++++++++++++++++++------- 2 files changed, 143 insertions(+), 34 deletions(-) diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index 2821a8e22f3..3bfa3a3363c 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -198,7 +198,7 @@ def scale_subtract_rowmax( assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" row_max_scaled = row_max * self.scale_log2 for i in cutlass.range(0, cute.size(acc_S_row.shape), 2, unroll_full=True): - acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( + acc_S_row[i], acc_S_row[i + 1] = utils.fma_packed_f32x2( (acc_S_row[i], acc_S_row[i + 1]), (self.scale_log2, self.scale_log2), (-row_max_scaled, -row_max_scaled), @@ -235,7 +235,8 @@ def apply_exp2_convert( acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j]) acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j]) else: - acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.e2e_asm2(acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]) + # acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.e2e_asm2(acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]) + acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.ex2_emulation_2(acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]) acc_S_row_converted_frg[None, j].store( acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type) ) @@ -250,14 +251,14 @@ def scale_apply_exp2_convert( assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" minus_row_max_scaled = -row_max * self.scale_log2 for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2): - acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( + acc_S_row[i], acc_S_row[i + 1] = utils.fma_packed_f32x2( (acc_S_row[i], acc_S_row[i + 1]), (self.scale_log2, self.scale_log2), (minus_row_max_scaled, minus_row_max_scaled), ) # for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2): - # acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( + # acc_S_row[i], acc_S_row[i + 1] = utils.fma_packed_f32x2( # (acc_S_row[i], acc_S_row[i + 1]), # (self.scale_log2, self.scale_log2), # (minus_row_max_scaled, minus_row_max_scaled), @@ -276,7 +277,7 @@ def scale_apply_exp2_convert( for j in cutlass.range_constexpr(frg_cnt): for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2): # acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = ( - # cute.arch.fma_packed_f32x2( + # utils.fma_packed_f32x2( # (acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]), # (self.scale_log2, self.scale_log2), # (minus_row_max_scaled, minus_row_max_scaled), diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 0a26fc9866f..0f3b2bd5533 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -2,16 +2,29 @@ import math from typing import Type, Callable, Optional, Tuple +from functools import partial import cutlass import cutlass.cute as cute -from cutlass import Float32, Int32 +from cutlass import Float32, Int32, const_expr from cutlass.cutlass_dsl import T, dsl_user_op from cutlass._mlir.dialects import nvvm, llvm, arith, vector from cutlass.cute.runtime import from_dlpack +# cute.arch.{fma,mul,add}_packed_f32x2 uses RZ rounding mode by default +fma_packed_f32x2 = partial(cute.arch.fma_packed_f32x2, rnd=nvvm.RoundingModeKind.RN) +mul_packed_f32x2 = partial(cute.arch.mul_packed_f32x2, rnd=nvvm.RoundingModeKind.RN) +add_packed_f32x2 = partial(cute.arch.add_packed_f32x2, rnd=nvvm.RoundingModeKind.RN) +sub_packed_f32x2 = partial( + cute.arch.calc_packed_f32x2_op, + src_c=None, + calc_func=nvvm.sub_packed_f32x2, + rnd=nvvm.RoundingModeKind.RN +) + + def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor: return ( from_dlpack(x, assumed_align=alignment) @@ -25,7 +38,7 @@ def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Te def make_tiled_copy_A( copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False ) -> cute.TiledCopy: - if cutlass.const_expr(swapAB): + if const_expr(swapAB): return cute.make_tiled_copy_B(copy_atom, tiled_mma) else: return cute.make_tiled_copy_A(copy_atom, tiled_mma) @@ -34,7 +47,7 @@ def make_tiled_copy_A( def make_tiled_copy_B( copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False ) -> cute.TiledCopy: - if cutlass.const_expr(swapAB): + if const_expr(swapAB): return cute.make_tiled_copy_A(copy_atom, tiled_mma) else: return cute.make_tiled_copy_B(copy_atom, tiled_mma) @@ -43,7 +56,7 @@ def make_tiled_copy_B( def mma_make_fragment_A( smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False ) -> cute.Tensor: - if cutlass.const_expr(swapAB): + if const_expr(swapAB): return mma_make_fragment_B(smem, thr_mma) else: return thr_mma.make_fragment_A(thr_mma.partition_A(smem)) @@ -52,7 +65,7 @@ def mma_make_fragment_A( def mma_make_fragment_B( smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False ) -> cute.Tensor: - if cutlass.const_expr(swapAB): + if const_expr(swapAB): return mma_make_fragment_A(smem, thr_mma) else: return thr_mma.make_fragment_B(thr_mma.partition_B(smem)) @@ -61,7 +74,7 @@ def mma_make_fragment_B( def get_smem_store_atom( arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric] ) -> cute.CopyAtom: - if cutlass.const_expr(arch < 90): + if const_expr(arch < 90 or element_type.width != 16): return cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), element_type, @@ -80,7 +93,7 @@ def warp_reduce( op: Callable, width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE, ) -> cute.TensorSSA | cute.Numeric: - if cutlass.const_expr(isinstance(val, cute.TensorSSA)): + if const_expr(isinstance(val, cute.TensorSSA)): res = cute.make_fragment(val.shape, val.dtype) res.store(val) for i in cutlass.range_constexpr(cute.size(val.shape)): @@ -131,7 +144,7 @@ def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout: # For Sm80, as the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) # For Sm90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N)) # TODO: Sm90 FP8 - if cutlass.const_expr(cute.rank(acc_layout.shape[0]) == 3): # Sm90 + if const_expr(cute.rank(acc_layout.shape[0]) == 3): # Sm90 l = cute.logical_divide( acc_layout, ((None, None, 2), None, None) ) # ((2, 2, (2, N / 16)), MMA_M, MMA_N) @@ -195,7 +208,7 @@ def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32: :return: exp2 value :rtype: cute.TensorSSA or Float32 """ - if cutlass.const_expr(isinstance(x, cute.TensorSSA)): + if const_expr(isinstance(x, cute.TensorSSA)): res = cute.make_fragment(x.shape, Float32) res.store(x) for i in cutlass.range_constexpr(cute.size(x.shape)): @@ -244,8 +257,8 @@ def fmax( def fmax_reduce( x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80 ) -> Float32: - if cutlass.const_expr(arch < 100 or cute.size(x.shape) % 8 != 0): - # if cutlass.const_expr(init_val is None): + if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0): + # if const_expr(init_val is None): # init_val = -cutlass.Float32.if # return x.reduce(cute.ReductionOp.MAX, init_val, 0) res = cute.make_fragment(x.shape, Float32) @@ -255,7 +268,7 @@ def fmax_reduce( # local_max[0] = fmax(local_max[0], res[i + 0]) # local_max[1] = fmax(local_max[1], res[i + 1]) # local_max[0] = fmax(local_max[0], local_max[1]) - # return local_max[0] if cutlass.const_expr(init_val is None) else fmax(local_max[0], init_val) + # return local_max[0] if const_expr(init_val is None) else fmax(local_max[0], init_val) local_max = [res[0], res[1], res[2], res[3]] for i in cutlass.range_constexpr(4, cute.size(x.shape), 4): local_max[0] = fmax(local_max[0], res[i + 0]) @@ -265,7 +278,7 @@ def fmax_reduce( local_max[0] = fmax(local_max[0], local_max[1]) local_max[2] = fmax(local_max[2], local_max[3]) local_max[0] = fmax(local_max[0], local_max[2]) - return local_max[0] if cutlass.const_expr(init_val is None) else fmax(local_max[0], init_val) + return local_max[0] if const_expr(init_val is None) else fmax(local_max[0], init_val) else: # [2025-06-15] x.reduce only seems to use 50% 3-input max and 50% 2-input max # We instead force the 3-input max. @@ -273,7 +286,7 @@ def fmax_reduce( res.store(x) local_max = [ fmax(init_val, res[0], res[1]) - if cutlass.const_expr(init_val is not None) + if const_expr(init_val is not None) else fmax(res[0], res[1]), fmax(res[2], res[3]), fmax(res[4], res[5]), @@ -292,8 +305,8 @@ def fmax_reduce( def fadd_reduce( x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80 ) -> Float32: - if cutlass.const_expr(arch < 100 or cute.size(x.shape) % 8 != 0): - if cutlass.const_expr(init_val is None): + if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0): + if const_expr(init_val is None): init_val = Float32.zero return x.reduce(cute.ReductionOp.ADD, init_val, 0) # res = cute.make_fragment(x.shape, Float32) @@ -307,25 +320,25 @@ def fadd_reduce( # local_sum[0] += local_sum[1] # local_sum[2] += local_sum[3] # local_sum[0] += local_sum[2] - # return local_sum[0] if cutlass.const_expr(init_val is None) else local_sum[0] + init_val + # return local_sum[0] if const_expr(init_val is None) else local_sum[0] + init_val else: res = cute.make_fragment(x.shape, Float32) res.store(x) local_sum_0 = ( - cute.arch.add_packed_f32x2((init_val, 0.0), (res[0], res[1])) - # cute.arch.add_packed_f32x2((init_val / 2, init_val / 2), (res[0], res[1])) - if cutlass.const_expr(init_val is not None) + add_packed_f32x2((init_val, 0.0), (res[0], res[1])) + # add_packed_f32x2((init_val / 2, init_val / 2), (res[0], res[1])) + if const_expr(init_val is not None) else (res[0], res[1]) ) local_sum = [local_sum_0, (res[2], res[3]), (res[4], res[5]), (res[6], res[7])] for i in cutlass.range_constexpr(8, cute.size(x.shape), 8): - local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], (res[i + 0], res[i + 1])) - local_sum[1] = cute.arch.add_packed_f32x2(local_sum[1], (res[i + 2], res[i + 3])) - local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], (res[i + 4], res[i + 5])) - local_sum[3] = cute.arch.add_packed_f32x2(local_sum[3], (res[i + 6], res[i + 7])) - local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[1]) - local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], local_sum[3]) - local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[2]) + local_sum[0] = add_packed_f32x2(local_sum[0], (res[i + 0], res[i + 1])) + local_sum[1] = add_packed_f32x2(local_sum[1], (res[i + 2], res[i + 3])) + local_sum[2] = add_packed_f32x2(local_sum[2], (res[i + 4], res[i + 5])) + local_sum[3] = add_packed_f32x2(local_sum[3], (res[i + 6], res[i + 7])) + local_sum[0] = add_packed_f32x2(local_sum[0], local_sum[1]) + local_sum[2] = add_packed_f32x2(local_sum[2], local_sum[3]) + local_sum[0] = add_packed_f32x2(local_sum[0], local_sum[2]) return local_sum[0][0] + local_sum[0][1] @@ -395,7 +408,7 @@ def cp_async_mbarrier_arrive_shared( def canonical_warp_group_idx(sync: bool = True) -> cutlass.Int32: warp_group_idx = cute.arch.thread_idx()[0] // 128 - if cutlass.const_expr(sync): + if const_expr(sync): warp_group_idx = cute.arch.make_warp_uniform(warp_group_idx) return warp_group_idx @@ -456,7 +469,7 @@ def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> @cute.jit def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32: - if cutlass.const_expr(lane is None): + if const_expr(lane is None): lane = cute.arch.lane_idx() # if cute.arch.thread_idx()[0] >= 128 and cute.arch.thread_idx()[0] < 128 + 32 and cute.arch.block_idx()[0] == 0: cute.printf("tidx = %d, val = %d", cute.arch.thread_idx()[0] % 32, val) for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))): @@ -497,6 +510,101 @@ def cvt_f16(src: cute.Tensor, dst: cute.Tensor): dst_i32[i] = cvt_f16x2_f32(src[2 * i], src[2 * i + 1], dst.element_type) +@cute.jit +@dsl_user_op +def evaluate_polynomial(x: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None) -> Float32: + deg = len(poly) - 1 + out = poly[deg] + for i in cutlass.range_constexpr(deg - 1, -1, -1): + out = out * x + poly[i] + return out + + +@cute.jit +@dsl_user_op +def evaluate_polynomial_2(x: Float32, y: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None) -> Tuple[Float32, Float32]: + deg = len(poly) - 1 + out = (poly[deg], poly[deg]) + for i in cutlass.range_constexpr(deg - 1, -1, -1): + out = fma_packed_f32x2(out, (x, y), (poly[i], poly[i])) + return out + + +@dsl_user_op +def add_round_down(x: float | Float32, y: float | Float32, *, loc=None, ip=None) -> Float32: + # There's probably a way to call llvm or nvvm to do this instead of ptx + return cutlass.Float32( + llvm.inline_asm( + T.f32(), + [Float32(x).ir_value(loc=loc, ip=ip), Float32(y).ir_value(loc=loc, ip=ip)], + f"add.rm.ftz.f32 $0, $1, $2;", + "=f,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, loc=None, ip=None) -> Float32: + return cutlass.Float32( + llvm.inline_asm( + T.f32(), + [Float32(x_rounded).ir_value(loc=loc, ip=ip), Float32(frac_ex2).ir_value(loc=loc, ip=ip)], + "{\n\t" + ".reg .s32 x_rounded_i, frac_ex_i, x_rounded_e, out_i;\n\t" + "mov.b32 x_rounded_i, $1;\n\t" + "mov.b32 frac_ex_i, $2;\n\t" + "shl.b32 x_rounded_e, x_rounded_i, 23;\n\t" + # add.u32 generates IMAD instruction and add.s32 generates LEA instruction + # IMAD uses the FMA pipeline and LEA uses the ALU pipeline, afaik + "add.s32 out_i, x_rounded_e, frac_ex_i;\n\t" + "mov.b32 $0, out_i;\n\t" + "}\n", + "=f,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def ex2_emulation(x: Float32, *, loc=None, ip=None) -> Float32: + # We assume x <= 127.0 + poly_ex2_deg3 = (1.0, 0.695146143436431884765625, 0.227564394474029541015625, 0.077119089663028717041015625) + fp32_round_int = float(2**23 + 2**22) + x_clamped = cute.arch.fmax(x, Float32(-127.0)) + # We want to round down here, so that the fractional part is in [0, 1) + x_rounded = add_round_down(x_clamped, fp32_round_int, loc=loc, ip=ip) + # The integer floor of x is now in the last 8 bits of x_rounded + # We assume the next 2 ops round to nearest even. The rounding mode is important. + x_rounded_back = x_rounded - fp32_round_int + x_frac = x_clamped - x_rounded_back + x_frac_ex2 = evaluate_polynomial(x_frac, poly_ex2_deg3, loc=loc, ip=ip) + return combine_int_frac_ex2(x_rounded, x_frac_ex2, loc=loc, ip=ip) + + +# TODO: check that the ex2_emulation_2 produces the same SASS as the ptx version +@dsl_user_op +def ex2_emulation_2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]: + # We assume x <= 127.0 and y <= 127.0 + poly_ex2_deg3 = (1.0, 0.695146143436431884765625, 0.227564394474029541015625, 0.077119089663028717041015625) + fp32_round_int = float(2**23 + 2**22) + xy_clamped = (cute.arch.fmax(x, Float32(-127.0)), cute.arch.fmax(y, Float32(-127.0))) + # We want to round down here, so that the fractional part is in [0, 1) + xy_rounded = cute.arch.add_packed_f32x2(xy_clamped, (fp32_round_int, fp32_round_int), rnd=nvvm.RoundingModeKind.RM) + # The integer floor of x & y are now in the last 8 bits of xy_rounded + # We want the next 2 ops to round to nearest even. The rounding mode is important. + xy_rounded_back = sub_packed_f32x2(xy_rounded, (fp32_round_int, fp32_round_int)) + xy_frac = sub_packed_f32x2(xy_clamped, xy_rounded_back) + xy_frac_ex2 = evaluate_polynomial_2(*xy_frac, poly_ex2_deg3, loc=loc, ip=ip) + x_out = combine_int_frac_ex2(xy_rounded[0], xy_frac_ex2[0], loc=loc, ip=ip) + y_out = combine_int_frac_ex2(xy_rounded[1], xy_frac_ex2[1], loc=loc, ip=ip) + return x_out, y_out + + @dsl_user_op def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]: out_f32x2 = llvm.inline_asm( From 5059fd53e602bcc00336bb5cc8a85e50940485cb Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 26 Sep 2025 21:33:29 -0400 Subject: [PATCH 108/258] [Cute] Simplify utils.py a bit --- flash_attn/cute/flash_fwd.py | 2 +- flash_attn/cute/utils.py | 38 +++++------------------------------- 2 files changed, 6 insertions(+), 34 deletions(-) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index b70da9a5264..0cb7cc6b500 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -1632,7 +1632,7 @@ def scoremod_premask_fn(acc_S): # self.load_Q(gmem_thr_copy_Q, gQ, sQ, m_block, seqlen=seqlen.seqlen_q, # headdim=mQ.shape[1]) pack_gqa.load_Q(mQ_cur, sQ, gmem_tiled_copy_Q, tidx, m_block, seqlen.seqlen_q) - utils.cp_async_mbarrier_arrive_shared(mbar_ptr_Q, noinc=True) + cute.arch.cp_async_mbarrier_arrive_noinc(mbar_ptr_Q) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) cute.arch.mbarrier_wait(mbar_ptr_Q, phase=q_consumer_phase) diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 0f3b2bd5533..205ba7de182 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -185,21 +185,6 @@ def transpose_view(a: cute.Tensor) -> cute.Tensor: return cute.composition(a, cute.make_ordered_layout(shape, order=order)) -@dsl_user_op -def exp2f_asm(a: float | Float32, *, loc=None, ip=None) -> Float32: - return Float32( - llvm.inline_asm( - T.f32(), - [Float32(a).ir_value(loc=loc, ip=ip)], - "ex2.approx.ftz.f32 $0, $1;", - "=f,f", - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - ) - - @cute.jit def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32: """exp2f calculation for both vector and scalar. @@ -284,10 +269,9 @@ def fmax_reduce( # We instead force the 3-input max. res = cute.make_fragment(x.shape, Float32) res.store(x) + local_max_0 = fmax(init_val, res[0], res[1]) if const_expr(init_val is not None) else fmax(res[0], res[1]) local_max = [ - fmax(init_val, res[0], res[1]) - if const_expr(init_val is not None) - else fmax(res[0], res[1]), + local_max_0, fmax(res[2], res[3]), fmax(res[4], res[5]), fmax(res[6], res[7]), @@ -375,7 +359,7 @@ def elem_pointer_i64(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> flat_stride = cute.flatten_to_tuple(x.stride) assert len(flat_coord_i64) == len(flat_stride), "Coordinate and stride must have the same length" offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride)) - return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip) + return x.iterator + offset @cute.jit @@ -394,18 +378,6 @@ def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor: return tApA -@dsl_user_op -def cp_async_mbarrier_arrive_shared( - mbar_ptr: cute.Pointer, noinc: bool = False, *, loc=None, ip=None -) -> None: - nvvm.cp_async_mbarrier_arrive_shared( - mbar_ptr.llvm_ptr, - noinc=noinc, - loc=loc, - ip=ip, - ) - - def canonical_warp_group_idx(sync: bool = True) -> cutlass.Int32: warp_group_idx = cute.arch.thread_idx()[0] // 128 if const_expr(sync): @@ -575,7 +547,7 @@ def ex2_emulation(x: Float32, *, loc=None, ip=None) -> Float32: # We assume x <= 127.0 poly_ex2_deg3 = (1.0, 0.695146143436431884765625, 0.227564394474029541015625, 0.077119089663028717041015625) fp32_round_int = float(2**23 + 2**22) - x_clamped = cute.arch.fmax(x, Float32(-127.0)) + x_clamped = cute.arch.fmax(x, -127.0) # We want to round down here, so that the fractional part is in [0, 1) x_rounded = add_round_down(x_clamped, fp32_round_int, loc=loc, ip=ip) # The integer floor of x is now in the last 8 bits of x_rounded @@ -592,7 +564,7 @@ def ex2_emulation_2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float # We assume x <= 127.0 and y <= 127.0 poly_ex2_deg3 = (1.0, 0.695146143436431884765625, 0.227564394474029541015625, 0.077119089663028717041015625) fp32_round_int = float(2**23 + 2**22) - xy_clamped = (cute.arch.fmax(x, Float32(-127.0)), cute.arch.fmax(y, Float32(-127.0))) + xy_clamped = (cute.arch.fmax(x, -127.0), cute.arch.fmax(y, -127.0)) # We want to round down here, so that the fractional part is in [0, 1) xy_rounded = cute.arch.add_packed_f32x2(xy_clamped, (fp32_round_int, fp32_round_int), rnd=nvvm.RoundingModeKind.RM) # The integer floor of x & y are now in the last 8 bits of xy_rounded From c485eeade0c3ec9ce186c3640c52c9f1ce090b81 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 1 Oct 2025 18:26:06 -0400 Subject: [PATCH 109/258] [Cute] Remove arith & vector import in utils.py --- flash_attn/cute/blackwell_helpers.py | 3 ++- flash_attn/cute/flash_fwd_sm100.py | 4 ++-- flash_attn/cute/utils.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/flash_attn/cute/blackwell_helpers.py b/flash_attn/cute/blackwell_helpers.py index ea464168faa..ad5124c04ce 100644 --- a/flash_attn/cute/blackwell_helpers.py +++ b/flash_attn/cute/blackwell_helpers.py @@ -16,10 +16,11 @@ def gemm( tCrA: cute.Tensor, tCrB: cute.Tensor, zero_init: bool | cutlass.Boolean = False, -) -> None: +) -> cute.TiledMma: for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): tiled_mma.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0) cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc) + return tiled_mma def i64_to_i32x2(i: int) -> Tuple[int, int]: diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 186b2190318..348fd39f8dd 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -1024,7 +1024,7 @@ def mma( # are empty. For subsequent iterations, the wait happened at the end # of the while loop. # 3. gemm - # sm100_utils.gemm(tiled_mma_qk, tStSs[stage], tSrQs[stage], tSrKi, zero_init=True) + # tiled_mma_qk = sm100_utils.gemm(tiled_mma_qk, tStSs[stage], tSrQs[stage], tSrKi, zero_init=True) sK_cur = sK[None, None, None, mma_kv_consumer_state.index] if const_expr(self.uneven_kv_smem): sK_cur = self.offset_kv_smem(sK_cur, mma_kv_consumer_state.index, mma_kv_consumer_state.phase) @@ -1085,7 +1085,7 @@ def mma( # Don't need to wait for the softmax warp to have finished reading the previous # Si, since this gemm is scheduled after the PV gemm, which guaranteed that Si # has been read and Pi has been written. - # sm100_utils.gemm(tiled_mma_qk, tStS0, tSrQs[0], tSrK[None, None, None, Ki_index], zero_init=True) + # tiled_mma_qk = sm100_utils.gemm(tiled_mma_qk, tStSs[stage], tSrQs[stage], tSrK[None, None, None, Ki_index], zero_init=True) sK_cur = sK[None, None, None, Ki_index] if const_expr(self.uneven_kv_smem): sK_cur = self.offset_kv_smem(sK_cur, Ki_index, Ki_phase) diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 205ba7de182..c361e347949 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -9,7 +9,7 @@ from cutlass import Float32, Int32, const_expr from cutlass.cutlass_dsl import T, dsl_user_op -from cutlass._mlir.dialects import nvvm, llvm, arith, vector +from cutlass._mlir.dialects import nvvm, llvm from cutlass.cute.runtime import from_dlpack From cbd2490424179d8acb76a6a062d912a5d760a218 Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Tue, 7 Oct 2025 14:24:53 -0700 Subject: [PATCH 110/258] [CuteDSL] Fix test (#1925) --- flash_attn/cute/utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index c361e347949..2c5bc242a43 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -359,7 +359,14 @@ def elem_pointer_i64(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> flat_stride = cute.flatten_to_tuple(x.stride) assert len(flat_coord_i64) == len(flat_stride), "Coordinate and stride must have the same length" offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride)) - return x.iterator + offset + # HACK: we assume that applying the offset does not change the pointer alignment + byte_offset = offset * x.element_type.width // 8 + return cute.make_ptr( + x.element_type, + x.iterator.toint() + byte_offset, + x.memspace, + assumed_align=x.iterator.alignment, + ) @cute.jit From 5183de433587a8aedd2450e9f18166c24521af29 Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Tue, 7 Oct 2025 21:01:04 -0700 Subject: [PATCH 111/258] Refactors to enable FlexAttention (#1840) * Refactors to enable FlexAttention * Thread throught the buffers to the score_mod * add-test * add fastdivmod * comments * comments --- .gitignore | 1 + flash_attn/cute/flash_fwd.py | 234 ++++++++++--- flash_attn/cute/flash_fwd_sm100.py | 130 ++++++- flash_attn/cute/interface.py | 66 +++- flash_attn/cute/softmax.py | 99 +++++- flash_attn/cute/utils.py | 38 +++ tests/cute/test_score_mod.py | 525 +++++++++++++++++++++++++++++ 7 files changed, 1010 insertions(+), 83 deletions(-) create mode 100644 tests/cute/test_score_mod.py diff --git a/.gitignore b/.gitignore index 1f1f8028863..060470d3c6f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ *.ncu-rep .DS_store +.vscode # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 0cb7cc6b500..3d17df958cc 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -7,7 +7,7 @@ import math from types import SimpleNamespace -from typing import Type, Callable, Optional, Tuple +from typing import Type, Callable, Optional from functools import partial import cuda.bindings.driver as cuda @@ -23,14 +23,14 @@ from flash_attn.cute import hopper_helpers as sm90_utils from flash_attn.cute import utils from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.softmax import Softmax +from flash_attn.cute.softmax import Softmax, apply_score_mod_inner from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo from flash_attn.cute import pipeline from flash_attn.cute.pack_gqa import PackGQA from flash_attn.cute.named_barrier import NamedBarrierFwd from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, SingleTileLPTScheduler, SingleTileVarlenScheduler, ParamsBase - +from flash_attn.cute.fast_math import FastDivmod class FlashAttentionForwardBase: @@ -50,6 +50,8 @@ def __init__( num_stages: int = 1, num_threads: int = 128, Q_in_regs: bool = False, + score_mod: cutlass.Constexpr | None = None, + has_buffers: bool = False, ): """Initializes the configuration for a flash attention kernel. @@ -65,6 +67,8 @@ def __init__( :param num_threads: number of threads :type num_threads: int :param is_causal: is causal + :param score_mod: A callable that takes the attention scores and applies a modification. + Callable signature: ``score_mod(scores, batch_idx, head_idx, q_idx, kv_idx, buffers) -> Any`` """ self.dtype = dtype # padding head_dim to a multiple of 16 as k_block_size @@ -85,6 +89,12 @@ def __init__( self.num_threads = num_threads self.num_stages = num_stages self.Q_in_regs = Q_in_regs + self.score_mod = score_mod + self.qk_acc_dtype = Float32 + if cutlass.const_expr(has_buffers): + self.vec_size: cutlass.Constexpr = 1 + else: + self.vec_size: cutlass.Constexpr = 2 @staticmethod def can_implement( @@ -256,7 +266,6 @@ def __call__( mO: cute.Tensor, mLSE: Optional[cute.Tensor], softmax_scale: Float32, - softcap: Float32, stream: cuda.CUstream, ): """Configures and launches the flash attention kernel. @@ -548,10 +557,10 @@ def __call__( mLSE: Optional[cute.Tensor], stream: cuda.CUstream, softmax_scale: Optional[Float32] = None, - softcap: Optional[Float32] = None, window_size_left: Optional[Int32] = None, window_size_right: Optional[Int32] = None, learnable_sink: Optional[cute.Tensor] = None, + buffers=None, ): """Configures and launches the flash attention kernel. @@ -580,19 +589,25 @@ def __call__( cute.size(mQ.shape[2]), cute.size(mQ.shape[3]), ) - # If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. - # Right after this, we multiply by log2(e) before applying exp2. - # To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val - # (assigning it to softcap_val) and pre-multiply softcap_val * log2(e) - # (assigning it to softmax_scale_log2). LOG2_E = math.log2(math.e) - if const_expr(softcap is None): - softmax_scale_log2 = softmax_scale * LOG2_E - softcap_val = None + if const_expr(self.score_mod is None): + softmax_scale_log2 = Float32(softmax_scale * LOG2_E) + softmax_scale = None else: - softmax_scale_log2 = softcap * LOG2_E - softcap_val = Float32(softmax_scale / softcap) - + # NB: If a user passes in a score mod, we want to apply the score-mod in the sm_scaled qk + # But in the original base 10. We hijack softmax_scale_log2 to just be the change of base + # and correctly apply the softmax_scale prior to score_mod in the softmax step + softmax_scale_log2 = Float32(LOG2_E) + softmax_scale = Float32(softmax_scale) + + fastdiv_mods = None + if cutlass.const_expr(buffers is not None): + seqlen_q = cute.size(mQ.shape[0]) + seqlen_k = cute.size(mK.shape[0]) + seqlen_q_divmod = FastDivmod.create(seqlen_q) + seqlen_k_divmod = FastDivmod.create(seqlen_k) + fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) + self.kernel( mQ, mK, @@ -600,7 +615,7 @@ def __call__( mO, mLSE, softmax_scale_log2, - softcap_val, + softmax_scale, window_size_left, window_size_right, self.sQ_layout, @@ -615,6 +630,8 @@ def __call__( tiled_mma_qk, tiled_mma_pv, SharedStorage, + buffers, + fastdiv_mods, ).launch( grid=grid_dim, block=[self.num_threads, 1, 1], @@ -631,7 +648,7 @@ def kernel( mO: cute.Tensor, mLSE: Optional[cute.Tensor], softmax_scale_log2: Float32, - softcap_val: Optional[Float32], + softmax_scale: Optional[Float32], window_size_left: Optional[Int32], window_size_right: Optional[Int32], sQ_layout: cute.ComposedLayout, @@ -646,6 +663,8 @@ def kernel( tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, SharedStorage: cutlass.Constexpr, + buffers=None, + fastdiv_mods=None, ): # Thread index, block index tidx, _, _ = cute.arch.thread_idx() @@ -750,7 +769,7 @@ def kernel( tVpV = utils.predicate_k(tVcV, limit=mV.shape[1]) # shape: (atom_v_m * rest_m) - softmax = Softmax(softmax_scale_log2, num_rows=acc_O.shape[0][0] * acc_O.shape[1]) + softmax = Softmax(softmax_scale_log2, num_rows=acc_O.shape[0][0] * acc_O.shape[1], softmax_scale=softmax_scale) softmax.reset() # group parameters for compute_one_n_block @@ -768,15 +787,12 @@ def kernel( seqlen=seqlen.seqlen_k) load_V = partial(self.load_V, gmem_tiled_copy_V, tVgV, tVsV, tVcV, t0VcV, tVpV, seqlen=seqlen.seqlen_k) - # Softcapping needs to happen before masking since if we apply after masking, softcapping can turn - # -inf to e.g. -50.0, which can affect the attention softmax. - def scoremod_premask_fn(acc_S): - if const_expr(softcap_val is not None): - acc_S.store(cute.math.tanh(acc_S.load() * softcap_val, fastmath=True)) compute_one_n_block = partial( self.compute_one_n_block, mma_params=mma_params, smem_copy_params=smem_copy_params, - softmax=softmax, load_K=load_K, load_V=load_V, scoremod_premask_fn=scoremod_premask_fn, + softmax=softmax, load_K=load_K, load_V=load_V, score_mod=self.score_mod, + batch_idx=batch_size, head_idx=num_head, m_block=m_block, buffers=buffers, + fastdiv_mods=fastdiv_mods, ) # /////////////////////////////////////////////////////////////////////////////// @@ -883,7 +899,12 @@ def compute_one_n_block( softmax: Softmax, load_K: Callable, load_V: Callable, - scoremod_premask_fn: Callable, + score_mod: Callable | None, + batch_idx: cutlass.Int32, + head_idx: cutlass.Int32, + m_block: cutlass.Int32, + buffers=None, + fastdiv_mods=None, mask_fn: Optional[Callable] = None, is_first_n_block: cutlass.Constexpr = False, check_inf: cutlass.Constexpr = True, @@ -917,7 +938,19 @@ def load_V_next(): # hook_fn=load_V_next, A_in_regs=self.Q_in_regs, ) - scoremod_premask_fn(acc_S) + if cutlass.const_expr(score_mod is not None): + self.apply_score_mod( + acc_S, + mma_params.thr_mma_qk, + batch_idx, + head_idx, + m_block, + n_block, + softmax=softmax, + buffers=buffers, + fastdiv_mods=fastdiv_mods, + ) + smem_pipe_write = self.advance_pipeline(smem_pipe_write) def load_K_next(): if n_block - self.num_stages >= 0: @@ -1071,10 +1104,10 @@ def __call__( mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, mPageTable: Optional[cute.Tensor] = None, # (b_k, max_num_pages_per_seq) - softcap: Float32 | float | None = None, window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, learnable_sink: Optional[cute.Tensor] = None, + buffers=None, ): """Configures and launches the flash attention kernel. @@ -1192,22 +1225,29 @@ def __call__( ) tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) grid_dim = TileScheduler.get_grid_shape(tile_sched_params) - # If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. - # Right after this, we multiply by log2(e) before applying exp2. - # To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val - # (assigning it to softcap_val) and pre-multiply softcap_val * log2(e) - # (assigning it to softmax_scale_log2). LOG2_E = math.log2(math.e) - if const_expr(softcap is None): + if const_expr(self.score_mod is None): softmax_scale_log2 = softmax_scale * LOG2_E - softcap_val = None + softmax_scale = None else: - softmax_scale_log2 = softcap * LOG2_E - softcap_val = Float32(softmax_scale / softcap) + # NB: If a user passes in a score mod, we want to apply the score-mod in the sm_scaled qk + # But in the original base 10. We hijack softmax_scale_log2 to just be the change of base + # and correctly apply the softmax_scale prior to score_mod in the softmax step + softmax_scale_log2 = LOG2_E + softmax_scale = softmax_scale if const_expr(window_size_left is not None): window_size_left = Int32(window_size_left) if const_expr(window_size_right is not None): window_size_right = Int32(window_size_right) + + fastdiv_mods = None + if cutlass.const_expr(buffers is not None): + seqlen_q = cute.size(mQ.shape[0]) + seqlen_k = cute.size(mK.shape[0]) + seqlen_q_divmod = FastDivmod.create(seqlen_q) + seqlen_k_divmod = FastDivmod.create(seqlen_k) + fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) + self.kernel( tma_tensor_Q if const_expr(self.use_tma_Q) else mQ, tma_tensor_K, @@ -1223,7 +1263,7 @@ def __call__( tma_atom_V, tma_atom_O, softmax_scale_log2, - softcap_val, + softmax_scale, window_size_left, window_size_right, learnable_sink, @@ -1242,6 +1282,8 @@ def __call__( tile_sched_params, TileScheduler, SharedStorage, + buffers, + fastdiv_mods, ).launch( grid=grid_dim, block=[self.num_threads, 1, 1], @@ -1267,7 +1309,7 @@ def kernel( tma_atom_V: Optional[cute.CopyAtom], tma_atom_O: Optional[cute.CopyAtom], softmax_scale_log2: Float32, - softcap_val: Optional[Float32], + softmax_scale: Optional[Float32], window_size_left: Optional[Int32], window_size_right: Optional[Int32], learnable_sink: Optional[cute.Tensor], @@ -1286,6 +1328,8 @@ def kernel( tile_sched_params: ParamsBase, TileScheduler: cutlass.Constexpr[Callable], SharedStorage: cutlass.Constexpr[Callable], + buffers=None, + fastdiv_mods=None, ): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) # Prefetch tma descriptor @@ -1417,11 +1461,13 @@ def kernel( tma_atom_O, tidx, softmax_scale_log2, - softcap_val, + softmax_scale, block_info, SeqlenInfoCls, AttentionMaskCls, TileSchedulerCls, + buffers, + fastdiv_mods, ) @cute.jit @@ -1538,11 +1584,13 @@ def mma( tma_atom_O: Optional[cute.CopyAtom], tidx: Int32, softmax_scale_log2: Float32, - softcap_val: Float32, + softmax_scale: Optional[Float32], block_info: BlockInfo, SeqlenInfoCls: Callable, AttentionMaskCls: Callable, TileSchedulerCls: Callable, + buffers=None, + fastdiv_mods=None, ): warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) warp_group_thread_layout = cute.make_layout( @@ -1587,6 +1635,7 @@ def mma( tiled_mma_qk=tiled_mma_qk, tiled_mma_pv=tiled_mma_pv, tiled_mma_pv_rs=tiled_mma_pv_rs, pipeline_k=pipeline_k, pipeline_v=pipeline_v, mma_params=mma_params, smem_copy_params=smem_copy_params, + thr_mma_qk=thr_mma_qk, check_inf=True, ) @@ -1599,19 +1648,16 @@ def mma( work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: # if work_tile.is_valid_tile: - # Softcapping needs to happen before masking since if we apply after masking, softcapping can turn - # -inf to e.g. -50.0, which can affect the attention softmax. - def scoremod_premask_fn(acc_S): - if const_expr(softcap_val is not None): - acc_S.store(cute.math.tanh(acc_S.load() * softcap_val, fastmath=True)) # shape: (atom_v_m * rest_m) - softmax = Softmax(softmax_scale_log2, num_rows=acc_O.shape[0][0] * acc_O.shape[1]) + softmax = Softmax(softmax_scale_log2, num_rows=acc_O.shape[0][0] * acc_O.shape[1], softmax_scale=softmax_scale) + m_block, head_idx, batch_idx = work_tile.tile_idx + score_mod = self.score_mod mma_one_n_block = partial( - mma_one_n_block_all, softmax=softmax, scoremod_premask_fn=scoremod_premask_fn + mma_one_n_block_all, softmax=softmax, score_mod=score_mod, + batch_idx=batch_idx, head_idx=head_idx, m_block=m_block, buffers=buffers, + fastdiv_mods=fastdiv_mods ) - - m_block, head_idx, batch_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) mask_fn = partial( @@ -1653,7 +1699,19 @@ def scoremod_premask_fn(acc_S): zero_init=True, wg_wait=0 ) pipeline_k.consumer_release(kv_consumer_state) - scoremod_premask_fn(acc_S) + # Use vectorized score modification + if cutlass.const_expr(score_mod is not None): + self.apply_score_mod( + acc_S, + thr_mma_qk, + batch_idx, + head_idx, + m_block, + n_block_max - 1, + softmax=softmax, + buffers=buffers, + fastdiv_mods=fastdiv_mods, + ) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) mask_fn(acc_S, n_block=n_block_max - 1, mask_seqlen=True) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) @@ -1773,7 +1831,13 @@ def mma_one_n_block( mma_params: SimpleNamespace, smem_copy_params: SimpleNamespace, softmax: Softmax, - scoremod_premask_fn: Callable, + score_mod: Callable, + batch_idx: cutlass.Int32, + head_idx: cutlass.Int32, + m_block: cutlass.Int32, + thr_mma_qk: cute.TiledMma, + buffers=None, + fastdiv_mods=None, mask_fn: Optional[Callable] = None, is_first_n_block: cutlass.Constexpr = False, check_inf: cutlass.Constexpr = True, @@ -1791,7 +1855,18 @@ def mma_one_n_block( self.warp_scheduler_barrier_arrive() warpgroup.wait_group(0) pipeline_k.consumer_release(smem_pipe_read) - scoremod_premask_fn(acc_S) + if cutlass.const_expr(score_mod is not None): + self.apply_score_mod( + acc_S, + thr_mma_qk, + batch_idx, + head_idx, + m_block, + n_block, + softmax=softmax, + buffers=buffers, + fastdiv_mods=fastdiv_mods, + ) if const_expr(mask_fn is not None): mask_fn(acc_S, n_block=n_block) row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf) @@ -1832,7 +1907,13 @@ def mma_one_n_block_intrawg_overlap( mma_params: SimpleNamespace, smem_copy_params: SimpleNamespace, softmax: Softmax, - scoremod_premask_fn: Callable, + score_mod: Callable, + batch_idx: cutlass.Int32, + head_idx: cutlass.Int32, + m_block: cutlass.Int32, + thr_mma_qk: cute.TiledMma, + buffers=None, + fastdiv_mods=None, mask_fn: Optional[Callable] = None, check_inf: cutlass.Constexpr = True, O_should_accumulate: cutlass.Boolean = True, @@ -1858,7 +1939,18 @@ def mma_one_n_block_intrawg_overlap( self.warp_scheduler_barrier_arrive() warpgroup.wait_group(1) pipeline_k.consumer_release(smem_pipe_read) - scoremod_premask_fn(acc_S) + if cutlass.const_expr(score_mod is not None): + self.apply_score_mod( + acc_S, + thr_mma_qk, + batch_idx, + head_idx, + m_block, + n_block, + softmax=softmax, + buffers=buffers, + fastdiv_mods=fastdiv_mods, + ) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) if const_expr(mask_fn is not None): mask_fn(acc_S, n_block=n_block) @@ -1890,6 +1982,38 @@ def mma_init(self): number_of_threads=2 * self.num_threads_per_warp_group, ) + @cute.jit + def apply_score_mod( + self, + acc_S, + thr_mma_qk, + batch_idx, + head_idx, + m_block, + n_block, + softmax, + buffers=None, + fastdiv_mods=None, + ): + # Prepare index tensor + cS = cute.make_identity_tensor((self.m_block_size, self.n_block_size)) + cS = cute.domain_offset((m_block * self.m_block_size, n_block * self.n_block_size), cS) + tScS = thr_mma_qk.partition_C(cS) + + apply_score_mod_inner( + acc_S, + tScS, + self.score_mod, + batch_idx, + head_idx, + softmax.softmax_scale, + self.vec_size, + self.qk_acc_dtype, + buffers, + fastdiv_mods, + constant_q_idx=None + ) + def warp_scheduler_barrier_sync(self): if const_expr(self.use_scheduler_barrier): cute.arch.barrier( diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 348fd39f8dd..7781e6c3364 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -29,7 +29,7 @@ import flash_attn.cute.utils as utils # import flash_attn.cute.pipeline as pipeline from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.softmax import SoftmaxSm100 +from flash_attn.cute.softmax import SoftmaxSm100, apply_score_mod_inner from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo from flash_attn.cute.pack_gqa import PackGQA @@ -64,6 +64,8 @@ def __init__( m_block_size: int = 128, n_block_size: int = 128, is_persistent: bool = True, + score_mod: cutlass.Constexpr | None = None, + has_buffers: cutlass.Constexpr = False, ): # self.dtype = dtype # padding head_dim to a multiple of 16 as k_block_size @@ -94,6 +96,11 @@ def __init__( self.pack_gqa = pack_gqa if pack_gqa: assert m_block_size % self.qhead_per_kvhead == 0, "For PackGQA, m_block_size must be divisible by qhead_per_kvhead" + self.score_mod = score_mod + if cutlass.const_expr(has_buffers): + self.vec_size: cutlass.Constexpr = 1 + else: + self.vec_size: cutlass.Constexpr = 2 # Does S1 need to wait for S0 to finish # self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local) self.s0_s1_barrier = False @@ -195,10 +202,10 @@ def __call__( mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, mPageTable: Optional[cute.Tensor] = None, # (b_k, max_num_pages_per_seq) - softcap: Float32 | float | None = None, window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, learnable_sink: Optional[cute.Tensor] = None, + buffers = None # Not typing for now since conversion behaves a lil funny ): """Execute the Fused Multi-Head Attention operation on the provided tensors. @@ -465,22 +472,30 @@ class SharedStorage: self.shared_storage = SharedStorage - # If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. - # Right after this, we multiply by log2(e) before applying exp2. - # To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val - # (assigning it to softcap_val) and pre-multiply softcap_val * log2(e) - # (assigning it to softmax_scale_log2). LOG2_E = math.log2(math.e) - if const_expr(softcap is None): + if const_expr(self.score_mod is None): softmax_scale_log2 = softmax_scale * LOG2_E - softcap_val = None + softmax_scale = None else: - softmax_scale_log2 = softcap * LOG2_E - softcap_val = Float32(softmax_scale / softcap) + # NB: If a users passes in a score mod, we want to apply the score-mod in the sm_scaled qk + # But in the original base 10. We hijack softmax_scale_log2 to just be the change of base + # and correctly apply the softmax_scale prior to score_mod in the softmax step + softmax_scale_log2 = LOG2_E + softmax_scale = softmax_scale + if const_expr(window_size_left is not None): window_size_left = Int32(window_size_left) if const_expr(window_size_right is not None): window_size_right = Int32(window_size_right) + + fastdiv_mods = None + if cutlass.const_expr(buffers is not None): + seqlen_q = cute.size(mQ.shape[0]) + seqlen_k = cute.size(mK.shape[0]) + seqlen_q_divmod = FastDivmod.create(seqlen_q) + seqlen_k_divmod = FastDivmod.create(seqlen_k) + fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) + # Launch the kernel synchronously self.kernel( tma_tensor_Q, @@ -498,7 +513,7 @@ class SharedStorage: tma_atom_V, tma_atom_O, softmax_scale_log2, - softcap_val, + softmax_scale, window_size_left, window_size_right, learnable_sink, @@ -511,6 +526,8 @@ class SharedStorage: tiled_mma_qk, tiled_mma_pv, tile_sched_params, + buffers, + fastdiv_mods, ).launch( grid=grid_dim, block=[self.threads_per_cta, 1, 1], @@ -539,7 +556,7 @@ def kernel( tma_atom_V: cute.CopyAtom, tma_atom_O: Optional[cute.CopyAtom], softmax_scale_log2: Float32, - softcap_val: Optional[Float32], + softmax_scale: Float32 | None, window_size_left: Optional[Int32], window_size_right: Optional[Int32], learnable_sink: Optional[cute.Tensor], @@ -552,6 +569,8 @@ def kernel( tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, tile_sched_params: ParamsBase, + buffers = None, + fastdiv_mods = (None, None), ): """The device kernel implementation of the Fused Multi-Head Attention. @@ -582,6 +601,7 @@ def kernel( storage = smem.allocate(self.shared_storage) mbar_ptr = storage.mbar_ptr.data_ptr() + # Use the first N warps to initialize barriers if warp_idx == 1: # Init "full" barrier with number of producers, "empty" barrier with number of consumers for i in cutlass.range_constexpr(self.q_stage): @@ -779,6 +799,7 @@ def kernel( softmax_loop = partial( self.softmax_loop, softmax_scale_log2=softmax_scale_log2, + softmax_scale=softmax_scale, thr_mma_qk=thr_mma_qk, sScale=sScale, mLSE=mLSE, @@ -788,13 +809,19 @@ def kernel( SeqlenInfoCls=SeqlenInfoCls, AttentionMaskCls=AttentionMaskCls, TileSchedulerCls=TileSchedulerCls, + buffers=buffers, + fastdiv_mods=fastdiv_mods, ) if const_expr(not self.s0_s1_barrier): stage = Int32(0 if warp_idx < self.softmax1_warp_ids[0] else 1) softmax_loop( stage=stage, - tStSi=cute.make_tensor(tStS.iterator + (self.tmem_s_offset[0] if stage == 0 else self.tmem_s_offset[1]), tStS.layout)) + tStSi=cute.make_tensor( + tStS.iterator + (self.tmem_s_offset[0] if stage == 0 else self.tmem_s_offset[1]), + tStS.layout + ), + ) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) else: # If there's s0_s1_barrier, it's faster to have 2 WGs having different code @@ -1146,6 +1173,7 @@ def softmax_loop( self, stage: int | Int32, softmax_scale_log2: Float32, + softmax_scale: Float32, thr_mma_qk: cute.core.ThrMma, tStSi: cute.Tensor, sScale: cute.Tensor, @@ -1156,6 +1184,8 @@ def softmax_loop( SeqlenInfoCls: Callable, AttentionMaskCls: Callable, TileSchedulerCls: Callable, + buffers = None, + fastdiv_mods = (None, None) ): """Compute softmax on attention scores from QK matrix multiplication. @@ -1224,9 +1254,9 @@ def softmax_loop( n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) mask_fn = partial( - mask.apply_mask_sm100, m_block=m_block * 2 + stage, thr_mma=thr_mma_qk, thr_tmem_load=thr_tmem_load, mask_causal=self.is_causal, mask_local=self.is_local + mask.apply_mask_sm100, m_block=self.q_stage * m_block + stage, thr_mma=thr_mma_qk, thr_tmem_load=thr_tmem_load, mask_causal=self.is_causal, mask_local=self.is_local ) - softmax = SoftmaxSm100(softmax_scale_log2, rescale_threshold=8.0 if const_expr(self.q_dtype.width == 16) else 0.0) + softmax = SoftmaxSm100(softmax_scale_log2, rescale_threshold=8.0 if const_expr(self.q_dtype.width == 16) else 0.0, softmax_scale=softmax_scale) softmax.reset() softmax_step = partial( @@ -1243,6 +1273,12 @@ def softmax_loop( tStP_r2t=tStP_r2t, sScale=sScale, stage=stage, + batch_idx=batch_idx, + head_idx=head_idx, + m_block=self.q_stage * m_block + stage, + seqlen=seqlen, + buffers=buffers, + fastdiv_mods=fastdiv_mods, ) cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase) @@ -1330,6 +1366,12 @@ def softmax_step( tStP_r2t: cute.Tensor, sScale: cute.Tensor, stage: int | Int32, + batch_idx: Int32, + head_idx: Int32, + m_block: Int32, + seqlen, + buffers = None, + fastdiv_mods = (None, None), mask_fn: Optional[Callable] = None, is_first: bool = False, ) -> Tuple[cute.Int32, cute.Int32, cute.Int32]: @@ -1355,12 +1397,27 @@ def softmax_step( tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.m_block_size, tilePlikeFP32))) tScP = cute.make_tensor(tScS.iterator, tScP_layout) + tScS_t2r_shape = thr_tmem_load.partition_D(tScS).shape # Wait for Si cute.arch.mbarrier_wait(mbar_ptr + self.mbar_S_full_offset + stage, mma_si_consumer_phase) tSrS_t2r = cute.make_fragment(tScS_t2r_shape, self.qk_acc_dtype) cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r) + if cutlass.const_expr(self.score_mod is not None): + self.apply_score_mod( + tSrS_t2r, + thr_tmem_load, + thr_mma_qk, + batch_idx, + head_idx, + m_block, + n_block, + softmax, + buffers, + fastdiv_mods + ) + if const_expr(mask_fn is not None): mask_fn(tSrS_t2r, n_block=n_block) row_max, acc_scale = softmax.update_row_max(tSrS_t2r.load(), is_first) @@ -1907,3 +1964,44 @@ def make_and_init_load_kv_pipeline(self, load_kv_mbar_ptr): # cute.arch.barrier_arrive( # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg, number_of_threads=2 * 128, # ) + + @cute.jit + def apply_score_mod( + self, + tSrS_t2r, + thr_tmem_load, + thr_mma_qk, + batch_idx, + head_idx, + m_block, + n_block, + softmax, + buffers=None, + fastdiv_mods=(None, None), + ): + """Apply score modification for SM100 (constant q_idx).""" + # Prepare index tensor with extra partition + cS = cute.make_identity_tensor((self.m_block_size, self.n_block_size)) + cS = cute.domain_offset((m_block * self.m_block_size, n_block * self.n_block_size), cS) + tScS = thr_mma_qk.partition_C(cS) + tScS_t2r = thr_tmem_load.partition_D(tScS) + + # Shared q_idx for all scores + q_idx_wrapped = tScS_t2r[0][0] + if cutlass.const_expr(buffers is not None): + seqlen_q_divmod, _ = fastdiv_mods + _, q_idx_wrapped = seqlen_q_divmod.divmod(tScS_t2r[0][0]) + + apply_score_mod_inner( + tSrS_t2r, + tScS_t2r, + self.score_mod, + batch_idx, + head_idx, + softmax.softmax_scale, + self.vec_size, + self.qk_acc_dtype, + buffers, + fastdiv_mods, + constant_q_idx=q_idx_wrapped + ) \ No newline at end of file diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index f25125c2cc3..fc1c91c0365 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -20,7 +20,7 @@ # - bwd pass optimized for Hopper/Blackwell import math -from typing import Optional, Tuple +from typing import Optional, Tuple, Callable import torch @@ -49,7 +49,6 @@ def maybe_contiguous(x): torch.float32: cutlass.Float32, } - def _flash_attn_fwd( q: torch.Tensor, k: torch.Tensor, @@ -73,7 +72,22 @@ def _flash_attn_fwd( num_threads: int = 384, pack_gqa: Optional[bool] = None, _compute_capability: Optional[int] = None, + score_mod: Callable | None = None, + return_lse: bool = False, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + buffers: Optional[list[torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass for FlashAttention. + + Args: + ... + score_mod: A callable that takes the attention scores and applies a modification. + return_lse: Whether to return the log softmax of the attention scores. If set to True will always calculate + out: Optional pre-allocated output tensor. If None, will be allocated internally. + lse: Optional pre-allocated log-sum-exp tensor. If None, will be allocated when needed. + buffers: Some score_mods will want to read from global buffers. This is how we thread them through to the inner kernel. + """ q, k, v = [maybe_contiguous(t) for t in (q, k, v)] num_head, head_dim = q.shape[-2:] if cu_seqlens_q is None: @@ -137,10 +151,25 @@ def _flash_attn_fwd( out_torch_dtype = q.dtype device = q.device q_batch_seqlen_shape = (batch_size, seqlen_q) if cu_seqlens_q is None else (total_q,) - out = torch.empty(*q_batch_seqlen_shape, num_head, head_dim_v, dtype=out_torch_dtype, device=device) lse_shape = (batch_size, num_head, seqlen_q) if cu_seqlens_q is None else (num_head, total_q) requires_grad = q.requires_grad or k.requires_grad or v.requires_grad - lse = torch.empty(lse_shape, dtype=torch.float32, device=device) if requires_grad else None + + if out is None: + out = torch.empty(*q_batch_seqlen_shape, num_head, head_dim_v, dtype=out_torch_dtype, device=device) + else: + expected_out_shape = (*q_batch_seqlen_shape, num_head, head_dim_v) + assert out.shape == expected_out_shape, f"out tensor shape {out.shape} does not match expected shape {expected_out_shape}" + assert out.dtype == out_torch_dtype, f"out tensor dtype {out.dtype} does not match expected dtype {out_torch_dtype}" + assert out.device == device, f"out tensor device {out.device} does not match input device {device}" + assert out.is_cuda, "out tensor must be on CUDA device" + + if lse is None: + lse = torch.empty(lse_shape, dtype=torch.float32, device=device) if requires_grad or return_lse else None + elif lse is not None: + assert lse.shape == lse_shape, f"lse tensor shape {lse.shape} does not match expected shape {lse_shape}" + assert lse.dtype == torch.float32, f"lse tensor dtype {lse.dtype} does not match expected dtype torch.float32" + assert lse.device == device, f"lse tensor device {lse.device} does not match input device {device}" + assert lse.is_cuda, "lse tensor must be on CUDA device" dtype = torch2cute_dtype_map[q.dtype] q_tensor, k_tensor, v_tensor, o_tensor = [ @@ -173,8 +202,24 @@ def _flash_attn_fwd( if pack_gqa and (128 % qhead_per_kvhead != 0) or (cu_seqlens_q is not None or seqused_q is not None): pack_gqa = False + if softcap is not None: + assert score_mod is None, "softcap and score_mod cannot be used together" + score_mod = utils.create_softcap_scoremod(softcap) + + if score_mod is not None: + is_varlen = cu_seqlens_q is not None or cu_seqlens_k is not None or seqused_q is not None or seqused_k is not None + if is_varlen: + raise NotImplementedError("score_mod with buffers is not yet supported for varlen sequences. This will be fixed in a future PR.") + if pack_gqa: + raise NotImplementedError("score_mod with buffers is not yet supported with pack_gqa=True. This will be fixed in a future PR.") + + cute_buffers = None + if buffers is not None: + cute_buffers = [from_dlpack(buf) for buf in buffers] + compile_key = ( - dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, softcap is not None, + dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, utils.hash_callable(score_mod) if score_mod is not None else None, + buffers is not None, lse is None, cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, seqused_k is None, page_table is not None, window_size_left is not None, window_size_right is not None, @@ -182,6 +227,7 @@ def _flash_attn_fwd( m_block_size, n_block_size, num_threads, pack_gqa, compute_capability, ) + if compile_key not in _flash_attn_fwd.compile_cache: if compute_capability == 9: assert page_table is None, "paged KV not supported on SM 9.0" @@ -200,6 +246,8 @@ def _flash_attn_fwd( num_stages=2, num_threads=num_threads, Q_in_regs=False, + score_mod=score_mod, + has_buffers=buffers is not None, ) elif compute_capability == 10: assert page_size in [None, 128], "Only page_size=128 is supported for paged KV on SM 10.0" @@ -211,28 +259,30 @@ def _flash_attn_fwd( is_local=local, pack_gqa=pack_gqa, is_persistent=not causal and not local and cu_seqlens_q is None and seqused_q is None, + score_mod=score_mod, + has_buffers=buffers is not None, ) else: raise ValueError(f"Unsupported compute capability: {compute_capability}. Supported: 9.x, 10.x") # TODO: check @can_implement + # TODO caching for buffers; cute_buffers _flash_attn_fwd.compile_cache[compile_key] = cute.compile( fa_fwd, q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream, cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, page_table_tensor, - softcap, window_size_left, window_size_right, learnable_sink_tensor, + window_size_left, window_size_right, learnable_sink_tensor, cute_buffers, ) _flash_attn_fwd.compile_cache[compile_key]( q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream, cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, page_table_tensor, - softcap, window_size_left, window_size_right, learnable_sink_tensor, + window_size_left, window_size_right, learnable_sink_tensor, cute_buffers ) return out, lse _flash_attn_fwd.compile_cache = {} - def _flash_attn_bwd( q: torch.Tensor, k: torch.Tensor, diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index 3bfa3a3363c..682265b7cc2 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -18,15 +18,17 @@ def __init__( scale_log2: Float32, num_rows: cutlass.Constexpr[int], arch: cutlass.Constexpr[int] = 80, + softmax_scale: Float32 | None = None ): self.scale_log2 = scale_log2 self.num_rows = num_rows self.arch = arch + self.softmax_scale = softmax_scale self.row_max = cute.make_fragment(num_rows, Float32) self.row_sum = cute.make_fragment_like(self.row_max) def __extract_mlir_values__(self): - non_constexpr_fields = [self.scale_log2, self.row_max, self.row_sum] + non_constexpr_fields = [self.scale_log2, self.row_max, self.row_sum, self.softmax_scale] values, self._values_pos = [], [] for obj in non_constexpr_fields: obj_values = cutlass.extract_mlir_values(obj) @@ -35,7 +37,7 @@ def __extract_mlir_values__(self): return values def __new_from_mlir_values__(self, values): - field_names = ['scale_log2', 'row_max', 'row_sum'] + field_names = ['scale_log2', 'row_max', 'row_sum', 'softmax_scale'] reconstructed_fields = {} for name, n_items in zip(field_names, self._values_pos): original_field = getattr(self, name) @@ -45,6 +47,7 @@ def __new_from_mlir_values__(self, values): new_obj = self.__class__(reconstructed_fields['scale_log2'], self.num_rows, self.arch) new_obj.row_max = reconstructed_fields['row_max'] new_obj.row_sum = reconstructed_fields['row_sum'] + new_obj.softmax_scale = reconstructed_fields['softmax_scale'] return new_obj def reset(self) -> None: @@ -151,8 +154,8 @@ def rescale_O(self, acc_O: cute.Tensor, row_scale: cute.Tensor) -> None: class SoftmaxSm100(Softmax): - def __init__(self, scale_log2: Float32, rescale_threshold: cutlass.Constexpr[float] = 0.0): - super().__init__(scale_log2, num_rows=1, arch=100) + def __init__(self, scale_log2: Float32, rescale_threshold: cutlass.Constexpr[float] = 0.0, softmax_scale: Float32 | None = None): + super().__init__(scale_log2, num_rows=1, arch=100, softmax_scale=softmax_scale) self.rescale_threshold = rescale_threshold def __new_from_mlir_values__(self, values): @@ -290,3 +293,91 @@ def scale_apply_exp2_convert( acc_S_row_converted_frg[None, j].store( acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type) ) + + +@cute.jit +def apply_score_mod_inner( + score_tensor, + index_tensor, + score_mod: cutlass.Constexpr, + batch_idx, + head_idx, + softmax_scale, + vec_size:cutlass.Constexpr, + qk_acc_dtype: cutlass.Constexpr, + buffers, + fastdiv_mods, + constant_q_idx:cutlass.Constexpr, +): + """Shared implementation for applying score modification. + + Args: + score_tensor: The scores to modify (acc_S for flash_fwd, tSrS_t2r for sm100) + index_tensor: Index positions (tScS for flash_fwd, tScS_t2r for sm100) + score_mod: The score modification function to apply + batch_idx: Batch index + head_idx: Head index + softmax_scale: Scale to apply + vec_size: Vector size for processing elements + qk_acc_dtype: Data type for accumulator + buffers: Optional buffers for FlexAttention + fastdiv_mods: Tuple of (seqlen_q_divmod, seqlen_k_divmod) for wrapping + constant_q_idx: If provided, use this constant for all q_idx values + If None, compute q_idx per-element + """ + n_vals = cutlass.const_expr(cute.size(score_tensor.shape)) + score_vec = cute.make_fragment(vec_size, qk_acc_dtype) + kv_idx_vec = cute.make_fragment(vec_size, cutlass.Int32) + + # SSA values for batch and head (constant across all elements) + batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32).broadcast_to((vec_size,)) + head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32).broadcast_to((vec_size,)) + + # Handle q_idx based on whether it's constant + q_idx_vec = cute.make_fragment(vec_size, cutlass.Int32) + for i in cutlass.range(0, n_vals, vec_size, unroll_full=True): + for j in cutlass.range(vec_size, unroll_full=True): + score_vec[j] = score_tensor[i + j] * softmax_scale + + # If we will do loads we mod, in order to not read OOB + if cutlass.const_expr(buffers is not None and fastdiv_mods is not None): + if cutlass.const_expr(constant_q_idx is None): + seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods + _, q_idx_wrapped = seqlen_q_divmod.divmod(index_tensor[i + j][0]) + q_idx_vec[j] = q_idx_wrapped + else: + _, seqlen_k_divmod = fastdiv_mods + + _, kv_idx_wrapped = seqlen_k_divmod.divmod(index_tensor[i + j][1]) + kv_idx_vec[j] = kv_idx_wrapped + else: + # No bounds checking - direct indexing + if constant_q_idx is None: + q_idx_vec[j] = index_tensor[i + j][0] + kv_idx_vec[j] = index_tensor[i + j][1] + + # Convert to SSA for score_mod call + score_ssa = score_vec.load() + kv_idx_ssa = kv_idx_vec.load() + if cutlass.const_expr(constant_q_idx is None): + q_idx_ssa = q_idx_vec.load() + else: + q_idx_ssa = utils.scalar_to_ssa(constant_q_idx, cutlass.Int32).broadcast_to((vec_size,)) + + buffer_args = [] + if cutlass.const_expr(buffers is not None): + buffer_args = buffers + + post_mod_scores = score_mod( + score_ssa, + batch_idx_ssa, + head_idx_ssa, + q_idx=q_idx_ssa, + kv_idx=kv_idx_ssa, + buffers=buffer_args + ) + + # Write back modified scores + score_vec.store(post_mod_scores) + for j in cutlass.range(vec_size, unroll_full=True): + score_tensor[i + j] = score_vec[j] diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 2c5bc242a43..6d48aca644d 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -1,6 +1,8 @@ # Copyright (c) 2025, Tri Dao. import math +import hashlib +import inspect from typing import Type, Callable, Optional, Tuple from functools import partial @@ -24,6 +26,34 @@ rnd=nvvm.RoundingModeKind.RN ) +def hash_callable(func: Callable) -> str: + """Hash a callable based on the source code or bytecode and closure values.""" + try: + data = inspect.getsource(func).encode() + except (OSError, TypeError): + if hasattr(func, "__code__") and func.__code__ is not None: + data = func.__code__.co_code + else: + data = repr(func).encode() + + hasher = hashlib.sha256(data) + + if hasattr(func, "__closure__") and func.__closure__ is not None: + for cell in func.__closure__: + cell_value = cell.cell_contents + hasher.update(repr(cell_value).encode()) + + return hasher.hexdigest() + + +def create_softcap_scoremod(softcap_val): + inv_softcap = 1.0 / softcap_val + + def scoremod_premask_fn(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, buffers): + scores = acc_S_SSA * inv_softcap + return scores * cute.math.tanh(scores, fastmath=True) + + return scoremod_premask_fn def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor: return ( @@ -676,3 +706,11 @@ def coord_offset_i64( ) new_layout = cute.slice_(tensor.layout, (*[None] * dim, 0, *[None] * (cute.rank(tensor) - dim - 1))) return cute.make_tensor(new_ptr, new_layout) + + +@cute.jit +def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA: + """ Convert a scalar to a cute TensorSSA of shape (1,) and given dtype """ + vec = cute.make_fragment(1, dtype) + vec[0] = a + return vec.load() diff --git a/tests/cute/test_score_mod.py b/tests/cute/test_score_mod.py new file mode 100644 index 00000000000..014d7969184 --- /dev/null +++ b/tests/cute/test_score_mod.py @@ -0,0 +1,525 @@ +import pytest +import torch +import cutlass +import cutlass.cute as cute +from cutlass._mlir.dialects import math as mlir_math +import operator +from torch.nn.attention.flex_attention import flex_attention +from flash_attn.cute.interface import _flash_attn_fwd + + +@cute.jit +def score_mod_1(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): + tmp0 = tSrS_ssa + tSrS_ssa = tmp0 + return tSrS_ssa + + +@cute.jit +def score_mod_2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): + tmp0 = q_idx + tmp1 = kv_idx + tmp2 = operator.ge(tmp0, tmp1) + tmp3 = tSrS_ssa + tmp4 = cute.where(tmp2, tmp3, cute.full_like(tmp3, float("-inf"))) + tSrS_ssa = tmp4 + return tSrS_ssa + + +@cute.jit +def score_mod_3(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): + tmp0 = tSrS_ssa + tmp1 = q_idx + tmp2 = kv_idx + tmp3 = tmp1 - tmp2 + tmp4 = cute.TensorSSA(mlir_math.absi(tmp3), tmp3.shape, tmp3.dtype) + tmp5 = tmp4.to(cutlass.Float32) + tmp6 = tmp0 + tmp5 + tSrS_ssa = tmp6 + return tSrS_ssa + + +@cute.jit +def score_mod_4(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): + tmp0 = tSrS_ssa + tmp1 = q_idx + tmp2 = kv_idx + tmp3 = tmp1 - tmp2 + tmp4 = cute.TensorSSA(mlir_math.absi(tmp3), tmp3.shape, tmp3.dtype) + tmp5 = tmp4 * cute.full_like(tmp4, 2) + tmp6 = tmp5.to(cutlass.Float32) + tmp7 = tmp0 + tmp6 + tSrS_ssa = tmp7 + return tSrS_ssa + + +@cute.jit +def score_mod_5(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): + tmp0 = tSrS_ssa + tmp1 = tmp0 * cute.full_like(tmp0, 2) + tSrS_ssa = tmp1 + return tSrS_ssa + + +@cute.jit +def score_mod_6(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): + tmp0 = tSrS_ssa + tmp1 = tmp0.to(cutlass.Float32) + tmp2 = h_idx + tmp3 = tmp2 + cute.full_like(tmp2, 1) + tmp4 = tmp3 * cute.full_like(tmp3, -8) + tmp5 = tmp4.to(cutlass.Float32) + tmp6 = tmp5 * cute.full_like(tmp5, 0.125) + tmp7 = tmp6 * cute.full_like(tmp6, 0.6931471805599453) + tmp8 = cute.math.exp2(tmp7 * 1.4426950408889634) + tmp9 = q_idx + tmp10 = kv_idx + tmp11 = tmp9 - tmp10 + tmp12 = cute.TensorSSA(mlir_math.absi(tmp11), tmp11.shape, tmp11.dtype) + tmp13 = tmp12.to(cutlass.Float32) + tmp14 = tmp8 * tmp13 + tmp15 = tmp1 - tmp14 + tSrS_ssa = tmp15 + return tSrS_ssa + + +@cute.jit +def score_mod_7(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): + tmp0 = q_idx + tmp1 = kv_idx + tmp2 = tmp0 - tmp1 + tmp3 = cute.TensorSSA(mlir_math.absi(tmp2), tmp2.shape, tmp2.dtype) + tmp4 = operator.le(tmp3, cute.full_like(tmp3, 256)) + tmp5 = tSrS_ssa + tmp6 = cute.where(tmp4, tmp5, cute.full_like(tmp5, float("-inf"))) + tSrS_ssa = tmp6 + return tSrS_ssa + + +@cute.jit +def score_mod_8(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): + tmp0 = q_idx + tmp1 = kv_idx + tmp2 = tSrS_ssa + tmp3 = cute.where( + operator.eq(tmp0 // 64, tmp1 // 64), tmp2, cute.full_like(tmp2, float("-inf")) + ) + tSrS_ssa = tmp3 + return tSrS_ssa + + +@cute.jit +def score_mod_9(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): + tmp0 = q_idx + tmp1 = kv_idx + tmp2 = tmp0 - tmp1 + tmp3 = operator.ge(tmp2, cute.full_like(tmp2, 0)) + tmp4 = tSrS_ssa + tmp5 = cute.where(tmp3, tmp4, cute.full_like(tmp4, float("-inf"))) + tSrS_ssa = tmp5 + return tSrS_ssa + + +@cute.jit +def score_mod_10(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): + batch_bias = buffers[0] + + # Detect dtype from buffer element type + dtype = batch_bias.element_type + + b_frag = cute.make_fragment(1, cutlass.Int32) + b_frag.store(b_idx) + bias_frag = cute.make_fragment(1, dtype) + bias_frag[0] = batch_bias[b_frag[0]] + bias_val = (bias_frag.load()).to(cutlass.Float32) + + return tSrS_ssa + bias_val + + +@cute.jit +def score_mod_11(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): + head_bias = buffers[0] + pos_bias = buffers[1] + + # Detect dtype from buffer element type + dtype = head_bias.element_type + + h_frag = cute.make_fragment(1, cutlass.Int32) + h_frag.store(h_idx) + head_val_frag = cute.make_fragment(1, dtype) + head_val_frag[0] = head_bias[h_frag[0]] + head_val = (head_val_frag.load()).to(cutlass.Float32) + + q_frag = cute.make_fragment(1, cutlass.Int32) + q_frag.store(q_idx) + pos_val_frag = cute.make_fragment(1, dtype) + pos_val_frag[0] = pos_bias[q_frag[0]] + pos_val = (pos_val_frag.load()).to(cutlass.Float32) + + return tSrS_ssa + head_val + pos_val + + +# Eager reference functions for comparison +def identity_eager(score, b, h, q_idx, kv_idx): + return score + + +def causal_mask_eager(score, b, h, q_idx, kv_idx): + return torch.where(q_idx >= kv_idx, score, float("-inf")) + + +def relative_bias_eager(score, b, h, q_idx, kv_idx): + return score + torch.abs(q_idx - kv_idx) + + +def relative_bias_v2_eager(score, b, h, q_idx, kv_idx): + return score + 2 * torch.abs(q_idx - kv_idx) + + +def times_two_eager(score, b, h, q_idx, kv_idx): + return score * 2 + + +def alibi_bias_eager(score, b, h, q_idx, kv_idx): + slope = 2 ** (-8 * (h + 1) / 8) + return score - slope * torch.abs(q_idx - kv_idx) + + +def sliding_window_eager(score, b, h, q_idx, kv_idx): + return torch.where(torch.abs(q_idx - kv_idx) <= 256, score, float("-inf")) + + +def block_diagonal_eager(score, b, h, q_idx, kv_idx): + q_block = q_idx // 64 + kv_block = kv_idx // 64 + return torch.where(q_block == kv_block, score, float("-inf")) + + +def causal_mask_v2_eager(score, b, h, q_idx, kv_idx): + return torch.where(q_idx - kv_idx >= 0, score, float("-inf")) + + +def batch_bias(bias_tensor): + """Per-batch bias (tests batch indexing).""" + + def batch_bias_mod(score, b, h, q_idx, kv_idx): + return score + bias_tensor[b] + + return batch_bias_mod + + +def dual_buffer_bias(head_bias, pos_scale): + """Dual buffer loading (tests loading from 2 separate tensors).""" + + def dual_buffer_mod(score, b, h, q_idx, kv_idx): + head_component = head_bias[h] + pos_component = pos_scale[q_idx] + return score + pos_component + head_component + + return dual_buffer_mod + + +# Test pairs: (cute_jit_function, eager_reference_function) +TEST_PAIRS = [ + (score_mod_1, None), + (score_mod_2, causal_mask_eager), + (score_mod_3, relative_bias_eager), + (score_mod_4, relative_bias_v2_eager), + (score_mod_5, times_two_eager), + (score_mod_6, alibi_bias_eager), + (score_mod_7, sliding_window_eager), + (score_mod_8, block_diagonal_eager), + (score_mod_9, causal_mask_v2_eager), +] + +# Test pairs with buffers: (cute_jit_function, eager_reference_function_factory) +TEST_PAIRS_WITH_BUFFERS = [ + (score_mod_10, batch_bias), + (score_mod_11, dual_buffer_bias), +] + + +def create_tensors( + batch_size=2, num_heads=4, seqlen_q=64, seqlen_kv=64, dim=128, dtype=torch.bfloat16 +): + q = torch.randn(batch_size, num_heads, seqlen_q, dim, device="cuda", dtype=dtype) + k = torch.randn(batch_size, num_heads, seqlen_kv, dim, device="cuda", dtype=dtype) + v = torch.randn(batch_size, num_heads, seqlen_kv, dim, device="cuda", dtype=dtype) + return q, k, v + + +def run_cute_flash(q, k, v, cute_score_mod, buffers=None) -> torch.Tensor: + q_transposed, k_transposed, v_transposed = map( + lambda x: x.transpose(1, 2), (q, k, v) + ) + out = torch.empty_like(q_transposed) + _flash_attn_fwd( + q_transposed, + k_transposed, + v_transposed, + return_lse=True, + score_mod=cute_score_mod, + out=out, + lse=None, + buffers=buffers, + ) + return out.transpose(1, 2) + + +def run_flex_reference(q, k, v, eager_score_mod, dtype=None) -> torch.Tensor: + if dtype is not None: + q, k, v = q.to(dtype), k.to(dtype), v.to(dtype) + return flex_attention(q, k, v, score_mod=eager_score_mod, enable_gqa=q.shape[1] != k.shape[1]) + + +@pytest.mark.parametrize( + "seqlen_q,seqlen_kv", + [ + (1, 1), + (64, 128), + (128, 192), + (256, 256), + (239, 1), + (799, 3), + (113, 203), + (113, 128), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (384, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (4096, 4096), + (4224, 4224), + ], +) +@pytest.mark.parametrize("num_heads", [1, 4]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("score_mod_pair", TEST_PAIRS) +def test_cute_vs_flex_attention(seqlen_q, seqlen_kv, num_heads, dtype, score_mod_pair): + torch.random.manual_seed(42) + cute_score_mod, eager_score_mod = score_mod_pair + + q, k, v = create_tensors( + seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=num_heads, dtype=dtype + ) + + out_ref_fp32 = run_flex_reference(q, k, v, eager_score_mod, dtype=torch.float32) + + out_pt = run_flex_reference(q, k, v, eager_score_mod) + out_cute = run_cute_flash(q, k, v, cute_score_mod) + + # Basic shape and NaN checks + assert out_cute.shape == out_ref_fp32.shape == out_pt.shape + assert not torch.isnan(out_cute).any() + assert not torch.isnan(out_ref_fp32).any() + assert not torch.isnan(out_pt).any() + assert torch.isfinite(out_cute).all() + assert torch.isfinite(out_ref_fp32).all() + assert torch.isfinite(out_pt).all() + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() + rtol = 2 + + # Calculate actual errors + pt_error = (out_pt - out_ref_fp32).abs().max().item() + cute_error = (out_cute - out_ref_fp32).abs().max().item() + + print(f"\nNumerical comparison for {cute_score_mod.__name__}:") + print(f" PyTorch vs FP32 ref max error: {pt_error:.2e}") + print(f" CuTE vs FP32 ref max error: {cute_error:.2e}") + print(f" Dynamic absolute tolerance: {fwd_atol:.2e}") + print(f" Error ratio (CuTE/PyTorch): {cute_error / max(pt_error, 1e-10):.2f}") + + # Assert that CuTE's error is at most rtol times PyTorch's error + fwd_atol + assert cute_error <= rtol * pt_error + fwd_atol, ( + f"CuTE error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}" + ) + + +@pytest.mark.parametrize( + "seqlen_q,seqlen_kv", + [ + (1, 1), + (64, 128), + (128, 192), + (256, 256), + (239, 1), + (799, 3), + (113, 203), + (113, 128), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (384, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (4096, 4096), + (4224, 4224), + ], +) +@pytest.mark.parametrize("num_heads", [1, 4]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("score_mod_pair", TEST_PAIRS_WITH_BUFFERS) +def test_cute_vs_flex_attention_with_buffers( + seqlen_q, seqlen_kv, num_heads, dtype, score_mod_pair +): + torch.random.manual_seed(42) + cute_score_mod, eager_score_mod_factory = score_mod_pair + + batch_size = 2 + q, k, v = create_tensors( + batch_size=batch_size, + seqlen_q=seqlen_q, + seqlen_kv=seqlen_kv, + num_heads=num_heads, + dtype=dtype, + ) + + if cute_score_mod == score_mod_10: + buffer = torch.randn(batch_size, device="cuda", dtype=dtype) * 0.1 + buffers = [buffer] + eager_score_mod = eager_score_mod_factory(buffer) + assert buffer.shape == (batch_size,) + elif cute_score_mod == score_mod_11: + head_bias = torch.randn(num_heads, device="cuda", dtype=dtype) * 0.2 + pos_scale = torch.arange(seqlen_q, device="cuda", dtype=dtype) * 0.01 + buffers = [head_bias, pos_scale] + eager_score_mod = eager_score_mod_factory(head_bias, pos_scale) + assert head_bias.shape == (num_heads,) + assert pos_scale.shape == (seqlen_q,) + + out_ref_fp32 = run_flex_reference(q, k, v, eager_score_mod, dtype=torch.float32) + + out_pt = run_flex_reference(q, k, v, eager_score_mod) + out_cute = run_cute_flash(q, k, v, cute_score_mod, buffers=buffers) + + # Basic shape and NaN checks + assert out_cute.shape == out_ref_fp32.shape == out_pt.shape + assert not torch.isnan(out_cute).any() + assert not torch.isnan(out_ref_fp32).any() + assert not torch.isnan(out_pt).any() + assert torch.isfinite(out_cute).all() + assert torch.isfinite(out_ref_fp32).all() + assert torch.isfinite(out_pt).all() + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() + rtol = 2 + + # Calculate actual errors + pt_error = (out_pt - out_ref_fp32).abs().max().item() + cute_error = (out_cute - out_ref_fp32).abs().max().item() + + print(f"\nNumerical comparison for {cute_score_mod.__name__}:") + print(f" PyTorch vs FP32 ref max error: {pt_error:.2e}") + print(f" CuTE vs FP32 ref max error: {cute_error:.2e}") + print(f" Dynamic absolute tolerance: {fwd_atol:.2e}") + print(f" Error ratio (CuTE/PyTorch): {cute_error / max(pt_error, 1e-10):.2f}") + + # Assert that CuTE's error is at most rtol times PyTorch's error + fwd_atol + assert cute_error <= rtol * pt_error + fwd_atol, ( + f"CuTE error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}" + ) + + +@pytest.mark.xfail(raises=NotImplementedError, reason="PackGQA with score_mod not yet supported") +def test_packgqa_with_score_mod(): + """Test that PackGQA works correctly with score_mod index wrapping. + + Without proper index wrapping, q_idx will be in packed space + (0 to qhead_per_kvhead * seqlen_q - 1) instead of logical space (0 to seqlen_q - 1). + This causes causal masking to be incorrect. + """ + torch.random.manual_seed(42) + + batch_size = 2 + seqlen_q = 128 + seqlen_kv = 128 + qhead_per_kvhead = 4 + num_heads_kv = 2 + num_heads = num_heads_kv * qhead_per_kvhead + dtype = torch.bfloat16 + + q = torch.randn(batch_size, num_heads, seqlen_q, 128, device="cuda", dtype=dtype) + k = torch.randn(batch_size, num_heads_kv, seqlen_kv, 128, device="cuda", dtype=dtype) + v = torch.randn(batch_size, num_heads_kv, seqlen_kv, 128, device="cuda", dtype=dtype) + + q_transposed, k_transposed, v_transposed = map( + lambda x: x.transpose(1, 2), (q, k, v) + ) + out_cute = torch.empty_like(q_transposed) + + _flash_attn_fwd( + q_transposed, + k_transposed, + v_transposed, + return_lse=True, + score_mod=score_mod_2, + out=out_cute, + lse=None, + pack_gqa=True, + ) + out_cute = out_cute.transpose(1, 2) + + out_ref_fp32 = run_flex_reference(q, k, v, causal_mask_eager, dtype=torch.float32) + + fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() + cute_error = (out_cute - out_ref_fp32).abs().max().item() + + assert not torch.isnan(out_cute).any(), "Output contains NaN values" + assert torch.isfinite(out_cute).all(), "Output contains infinite values" + assert cute_error <= fwd_atol * 10, ( + f"CuTE error {cute_error:.2e} exceeds tolerance {fwd_atol * 10:.2e}" + ) + + +@pytest.mark.xfail(raises=NotImplementedError, reason="Varlen with score_mod not yet supported") +def test_varlen_with_score_mod(): + """Test that varlen (variable length sequences) works with score_mod. + + For varlen, tokens from different sequences should not attend to each other. + Without proper index mapping, the causal mask will be applied to the global + indices instead of per-sequence logical indices. + """ + torch.random.manual_seed(42) + + seqlens = [64, 56, 128] + total_seq = sum(seqlens) + num_heads = 4 + dtype = torch.bfloat16 + + cu_seqlens = torch.tensor([0] + list(torch.tensor(seqlens).cumsum(0).tolist()), device="cuda", dtype=torch.int32) + q = torch.randn(total_seq, num_heads, 128, device="cuda", dtype=dtype) + k = torch.randn(total_seq, num_heads, 128, device="cuda", dtype=dtype) + v = torch.randn(total_seq, num_heads, 128, device="cuda", dtype=dtype) + + out_cute = torch.empty_like(q) + + _flash_attn_fwd( + q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + return_lse=True, + score_mod=score_mod_2, + out=out_cute, + lse=None, + ) + + assert not torch.isnan(out_cute).any(), "Output contains NaN values" + assert torch.isfinite(out_cute).all(), "Output contains infinite values" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From a38d69d65b12b7ddc98caecc77e86aa46ea1534e Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 11 Oct 2025 19:00:14 -0400 Subject: [PATCH 112/258] [Cute] Fix softmax for cutlass-dsl==4.2.1 --- flash_attn/cute/cute_dsl_utils.py | 124 ++++++++++++++++++++++++++++++ flash_attn/cute/flash_fwd.py | 9 +-- flash_attn/cute/softmax.py | 51 ++++-------- 3 files changed, 145 insertions(+), 39 deletions(-) create mode 100644 flash_attn/cute/cute_dsl_utils.py diff --git a/flash_attn/cute/cute_dsl_utils.py b/flash_attn/cute/cute_dsl_utils.py new file mode 100644 index 00000000000..6deeac30d34 --- /dev/null +++ b/flash_attn/cute/cute_dsl_utils.py @@ -0,0 +1,124 @@ +# Copyright (c) 2025, Tri Dao. + +import os +import pathlib +from typing import Tuple +from functools import partial, lru_cache +from dataclasses import dataclass, fields + +import torch + +try: + from triton.tools.disasm import extract +except ImportError: + extract = None + +import cutlass +import cutlass.cute as cute +from cutlass.base_dsl.typing import JitArgument +from cutlass.cutlass_dsl import NumericMeta + + +StaticTypes = (cutlass.Constexpr, NumericMeta, int, bool, str, float, type(None)) + + +load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data +cute_compile_og = cute.compile + + +torch2cute_dtype_map = { + torch.float16: cutlass.Float16, + torch.bfloat16: cutlass.BFloat16, + torch.float32: cutlass.Float32, +} + + +@lru_cache +def get_max_active_clusters(cluster_size): + return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size) + + +@lru_cache +def get_device_capacity(device: torch.device = None) -> Tuple[int, int]: + return torch.cuda.get_device_capability(device) + + +@dataclass +class ParamsBase: + def __extract_mlir_values__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + values, self._values_pos = [], [] + for obj in non_constexpr_fields: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + all_fields = {field.name: getattr(self, field.name) for field in fields(self)} + constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)} + non_constexpr_fields = { + n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes) + } + for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): + non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) + values = values[n_items:] + return self.__class__(**non_constexpr_fields, **constexpr_fields) + + +@dataclass +class ArgumentsBase(JitArgument): + def __c_pointers__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + c_ptrs = [] + for obj in non_constexpr_fields: + if hasattr(obj, "__c_pointers__"): + c_ptrs.extend(obj.__c_pointers__()) + return c_ptrs + + def __get_mlir_types__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + types, self._values_pos = [], [] + for obj in non_constexpr_fields: + if hasattr(obj, "__get_mlir_types__"): + obj_types = obj.__get_mlir_types__() + types.extend(obj_types) + self._values_pos.append(len(obj_types)) + else: + self._values_pos.append(0) + return types + + def __new_from_mlir_values__(self, values): + all_fields = {field.name: getattr(self, field.name) for field in fields(self)} + constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)} + non_constexpr_fields = { + n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes) + } + for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): + non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) + values = values[n_items:] + return self.__class__(**non_constexpr_fields, **constexpr_fields) + + +def load_cubin_module_data_patched(cubin_data, filepath): + pathlib.Path(filepath).write_bytes(cubin_data) + return load_cubin_module_data_og(cubin_data) + + +def cute_compile_patched(*args, **kwargs): + """A patched version of cute.compile that dump the SASS to a file if CUTE_CUBIN_PATH is set.""" + cubin_path = os.getenv("CUTE_CUBIN_PATH", None) + if cubin_path is not None: + cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial( + load_cubin_module_data_patched, filepath=cubin_path + ) + output = cute_compile_og(*args, **kwargs) + if cubin_path is not None: + cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og + if extract is not None: + sass = extract(cubin_path, None) + pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass) + return output diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 3d17df958cc..ac2a301971b 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -769,7 +769,7 @@ def kernel( tVpV = utils.predicate_k(tVcV, limit=mV.shape[1]) # shape: (atom_v_m * rest_m) - softmax = Softmax(softmax_scale_log2, num_rows=acc_O.shape[0][0] * acc_O.shape[1], softmax_scale=softmax_scale) + softmax = Softmax.create(softmax_scale_log2, num_rows=acc_O.shape[0][0] * acc_O.shape[1], softmax_scale=softmax_scale) softmax.reset() # group parameters for compute_one_n_block @@ -1650,7 +1650,7 @@ def mma( # if work_tile.is_valid_tile: # shape: (atom_v_m * rest_m) - softmax = Softmax(softmax_scale_log2, num_rows=acc_O.shape[0][0] * acc_O.shape[1], softmax_scale=softmax_scale) + softmax = Softmax.create(softmax_scale_log2, num_rows=acc_O.shape[0][0] * acc_O.shape[1], softmax_scale=softmax_scale) m_block, head_idx, batch_idx = work_tile.tile_idx score_mod = self.score_mod mma_one_n_block = partial( @@ -1789,7 +1789,7 @@ def mma( else: self.warp_scheduler_barrier_arrive() - # normalize acc_O by row_sum and calculate the lse + sink_val = None if const_expr(learnable_sink is not None): if const_expr(not self.pack_gqa): sink_val = Float32(learnable_sink[head_idx]) @@ -1801,9 +1801,8 @@ def mma( row = m_block * self.m_block_size + tScS_mn[r][0] q_head_idx = row % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead sink_val[r] = Float32(learnable_sink[q_head_idx]) - else: - sink_val = None + # normalize acc_O by row_sum and calculate the lse row_scale = softmax.finalize(sink_val=sink_val) softmax.rescale_O(acc_O, row_scale) diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index 682265b7cc2..fcd4c32c13c 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -10,45 +10,28 @@ from cutlass import Float32 import flash_attn.cute.utils as utils +from flash_attn.cute.cute_dsl_utils import ParamsBase -class Softmax: - def __init__( - self, +@dataclass +class Softmax(ParamsBase): + scale_log2: Float32 + num_rows: cutlass.Constexpr[int] + row_max: cute.Tensor + row_sum: cute.Tensor + arch: cutlass.Constexpr[int] = 80 + softmax_scale: Float32 | None = None + + @staticmethod + def create( scale_log2: Float32, num_rows: cutlass.Constexpr[int], arch: cutlass.Constexpr[int] = 80, softmax_scale: Float32 | None = None ): - self.scale_log2 = scale_log2 - self.num_rows = num_rows - self.arch = arch - self.softmax_scale = softmax_scale - self.row_max = cute.make_fragment(num_rows, Float32) - self.row_sum = cute.make_fragment_like(self.row_max) - - def __extract_mlir_values__(self): - non_constexpr_fields = [self.scale_log2, self.row_max, self.row_sum, self.softmax_scale] - values, self._values_pos = [], [] - for obj in non_constexpr_fields: - obj_values = cutlass.extract_mlir_values(obj) - values += obj_values - self._values_pos.append(len(obj_values)) - return values - - def __new_from_mlir_values__(self, values): - field_names = ['scale_log2', 'row_max', 'row_sum', 'softmax_scale'] - reconstructed_fields = {} - for name, n_items in zip(field_names, self._values_pos): - original_field = getattr(self, name) - reconstructed_fields[name] = cutlass.new_from_mlir_values(original_field, values[:n_items]) - values = values[n_items:] - - new_obj = self.__class__(reconstructed_fields['scale_log2'], self.num_rows, self.arch) - new_obj.row_max = reconstructed_fields['row_max'] - new_obj.row_sum = reconstructed_fields['row_sum'] - new_obj.softmax_scale = reconstructed_fields['softmax_scale'] - return new_obj + row_max = cute.make_fragment(num_rows, Float32) + row_sum = cute.make_fragment(num_rows, Float32) + return Softmax(scale_log2, num_rows, row_max, row_sum, arch, softmax_scale) def reset(self) -> None: self.row_max.fill(-Float32.inf) @@ -82,7 +65,7 @@ def online_softmax( acc_S_mn = utils.make_acc_tensor_mn_view(acc_S) row_scale = cute.make_fragment_like(self.row_max, Float32) # Each iteration processes one row of acc_S - for r in cutlass.range(cute.size(self.row_max), unroll_full=True): + for r in cutlass.range_constexpr(cute.size(self.row_max)): acc_S_row = acc_S_mn[r, None].load() # (n_block_size) row_max_cur = self._compute_row_max( acc_S_row, @@ -118,7 +101,7 @@ def finalize(self, final_scale: Float32 = 1.0, sink_val: Float32 | cute.Tensor | # quad reduction for row_sum as we didn't do it during each iteration of online softmax self.row_sum.store(utils.warp_reduce(self.row_sum.load(), operator.add, width=4)) row_scale = cute.make_fragment_like(self.row_max, Float32) - for r in cutlass.range(cute.size(self.row_sum), unroll_full=True): + for r in cutlass.range_constexpr(cute.size(self.row_sum)): if cutlass.const_expr(sink_val is not None): sink_val_cur = sink_val if not isinstance(sink_val, cute.Tensor) else sink_val[r] LOG2_E = math.log2(math.e) From 437b35a99b7f5da37646982fb0bed98f0c59d3ad Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 11 Oct 2025 20:38:41 -0400 Subject: [PATCH 113/258] [Cute] Fix softmax for fwd_sm100 --- flash_attn/cute/flash_fwd_sm100.py | 4 ++-- flash_attn/cute/softmax.py | 22 ++++++++++++++-------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 7781e6c3364..cb52f157ad3 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -1256,7 +1256,7 @@ def softmax_loop( mask_fn = partial( mask.apply_mask_sm100, m_block=self.q_stage * m_block + stage, thr_mma=thr_mma_qk, thr_tmem_load=thr_tmem_load, mask_causal=self.is_causal, mask_local=self.is_local ) - softmax = SoftmaxSm100(softmax_scale_log2, rescale_threshold=8.0 if const_expr(self.q_dtype.width == 16) else 0.0, softmax_scale=softmax_scale) + softmax = SoftmaxSm100.create(softmax_scale_log2, rescale_threshold=8.0 if const_expr(self.q_dtype.width == 16) else 0.0, softmax_scale=softmax_scale) softmax.reset() softmax_step = partial( @@ -2004,4 +2004,4 @@ def apply_score_mod( buffers, fastdiv_mods, constant_q_idx=q_idx_wrapped - ) \ No newline at end of file + ) diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index fcd4c32c13c..b283e7c7035 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -136,15 +136,21 @@ def rescale_O(self, acc_O: cute.Tensor, row_scale: cute.Tensor) -> None: acc_O_mn[r, None].store(acc_O_mn[r, None].load() * row_scale[r]) +@dataclass class SoftmaxSm100(Softmax): - def __init__(self, scale_log2: Float32, rescale_threshold: cutlass.Constexpr[float] = 0.0, softmax_scale: Float32 | None = None): - super().__init__(scale_log2, num_rows=1, arch=100, softmax_scale=softmax_scale) - self.rescale_threshold = rescale_threshold - - def __new_from_mlir_values__(self, values): - new_obj = super().__new_from_mlir_values__(values) - new_obj.rescale_threshold = self.rescale_threshold - return new_obj + rescale_threshold: cutlass.Constexpr[float] = 0.0 + + @staticmethod + def create( + scale_log2: Float32, + rescale_threshold: cutlass.Constexpr[float] = 0.0, + softmax_scale: Float32 | None = None, + ): + num_rows = 1 + arch = 100 + row_max = cute.make_fragment(num_rows, Float32) + row_sum = cute.make_fragment(num_rows, Float32) + return SoftmaxSm100(scale_log2, num_rows, row_max, row_sum, arch, softmax_scale, rescale_threshold=rescale_threshold) @cute.jit def update_row_max(self, acc_S_row: cute.TensorSSA, is_first: int) -> Tuple[Float32, Float32]: From ea03e0644c22a282d2ccd2b75844c76e4acb436b Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 11 Oct 2025 21:30:56 -0400 Subject: [PATCH 114/258] [Cute,Bwd] Simplify bwd_preprocessing kernel --- flash_attn/cute/copy_utils.py | 129 +++++++++++++++++++++++ flash_attn/cute/flash_bwd.py | 3 + flash_attn/cute/flash_bwd_postprocess.py | 4 + flash_attn/cute/flash_bwd_preprocess.py | 82 +++++--------- flash_attn/cute/interface.py | 2 +- 5 files changed, 164 insertions(+), 56 deletions(-) create mode 100644 flash_attn/cute/copy_utils.py diff --git a/flash_attn/cute/copy_utils.py b/flash_attn/cute/copy_utils.py new file mode 100644 index 00000000000..9ac20207444 --- /dev/null +++ b/flash_attn/cute/copy_utils.py @@ -0,0 +1,129 @@ +# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. + +import math +from typing import Optional, Type, Tuple, Callable + +import cutlass +import cutlass.cute as cute + +from cutlass import Int32, Boolean, const_expr +from cutlass.cute.nvgpu import cpasync +from cutlass.cutlass_dsl import dsl_user_op +import cutlass.pipeline + + +@dsl_user_op +def cvt_copy( + atom: cute.CopyAtom, + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, + loc=None, + ip=None, + **kwargs, +) -> None: + assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem + if const_expr(src.element_type != dst.element_type): + src_cvt = cute.make_fragment_like(src, dst.element_type) + src_cvt.store(src.load().to(dst.element_type)) + src = src_cvt + cute.copy(atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) + + +@dsl_user_op +def get_copy_atom( + dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None +) -> cute.CopyAtom: + num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width)) + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + + +@dsl_user_op +def copy( + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, + num_copy_elems: int = 1, + is_async: bool = False, + loc=None, + ip=None, + **kwargs, +) -> None: + copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async) + cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) + + +def tiled_copy_1d( + dtype: Type[cutlass.Numeric], num_threads: int, num_copy_elems: int = 1, is_async: bool = False +) -> cute.TiledCopy: + num_copy_bits = num_copy_elems * dtype.width + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + thr_layout = cute.make_layout(num_threads) + val_layout = cute.make_layout(num_copy_elems) + return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + +def tiled_copy_2d( + dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False +) -> cute.TiledCopy: + num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width + copy_elems = num_copy_bits // dtype.width + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + gmem_threads_per_row = major_mode_size // copy_elems + assert num_threads % gmem_threads_per_row == 0 + thr_layout = cute.make_ordered_layout( + (num_threads // gmem_threads_per_row, gmem_threads_per_row), + order=(1, 0), + ) + val_layout = cute.make_layout((1, copy_elems)) + return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + +def tma_get_copy_fn( + atom: cute.CopyAtom, + cta_coord: cute.Coord, + cta_layout: cute.Layout, + src_tensor: cute.Tensor, + dst_tensor: cute.Tensor, + filter_zeros: bool = False, + **kwargs, +) -> Callable: + src_is_smem = const_expr( + isinstance(src_tensor.iterator, cute.Pointer) + and src_tensor.memspace == cute.AddressSpace.smem + ) + smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor) + # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK) + s, g = cpasync.tma_partition( + atom, + cta_coord, + cta_layout, + cute.group_modes(smem_tensor, 0, cute.rank(smem_tensor) - 1), + cute.group_modes(gmem_tensor, 0, cute.rank(gmem_tensor) - 1), + ) + if const_expr(filter_zeros): + s = cute.filter_zeros(s) + g = cute.filter_zeros(g) + src, dst = (s, g) if src_is_smem else (g, s) + + def copy_tma(src_idx, dst_idx, **new_kwargs): + cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs) + + return copy_tma, s, g + + +def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync): + def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwargs): + copy( + src_idx=src_idx, + dst_idx=producer_state.index, + tma_bar_ptr=pipeline.producer_get_barrier(producer_state), + **new_kwargs, + ) + + return copy_fn diff --git a/flash_attn/cute/flash_bwd.py b/flash_attn/cute/flash_bwd.py index a6d061b19b5..de2d4e74ea7 100644 --- a/flash_attn/cute/flash_bwd.py +++ b/flash_attn/cute/flash_bwd.py @@ -347,6 +347,9 @@ def __call__( # Get the data type and check if it is fp16 or bf16 self._check_type(*(t.element_type if t is not None else None for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV))) + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1]) + mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) if t is not None else None for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)] self._setup_attributes() SharedStorage = self._get_shared_storage_cls() tiled_mma_sdp, tiled_mma_dkv, tiled_mma_dq = self._get_tiled_mma() diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index b0fa2704138..ddad08beb5b 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -151,6 +151,10 @@ def __call__( if cutlass.const_expr(not mdQaccum.element_type in [cutlass.Float32]): raise TypeError("dQaccum tensor must be Float32") + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1]) + mdQaccum, mdQ = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) for t in (mdQaccum, mdQ)] + num_mma_warps = self.num_threads // 32 AtomLayoutdQ = ( (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1) diff --git a/flash_attn/cute/flash_bwd_preprocess.py b/flash_attn/cute/flash_bwd_preprocess.py index a5da7b7009e..13080d7c2e4 100644 --- a/flash_attn/cute/flash_bwd_preprocess.py +++ b/flash_attn/cute/flash_bwd_preprocess.py @@ -9,8 +9,10 @@ import cutlass import cutlass.cute as cute +from cutlass import Float32 from flash_attn.cute import utils +from flash_attn.cute import copy_utils class FlashAttentionBackwardPreprocess: @@ -82,44 +84,13 @@ def _setup_attributes(self): else (32 if self.head_dim_padded % 32 == 0 else 16) ) ) + self.gmem_tiled_copy_O = copy_utils.tiled_copy_2d(self.dtype, gmem_k_block_size, self.num_threads) universal_copy_bits = 128 - async_copy_elems = universal_copy_bits // self.dtype.width - # atom_universal_copy: universal copy atom for O & dO load - atom_universal_copy = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), - self.dtype, - num_bits_per_copy=universal_copy_bits, - ) - # tOdO_layout: thread layout for O & dO load - self.gmem_threads_per_row = gmem_k_block_size // async_copy_elems - assert self.num_threads % self.gmem_threads_per_row == 0 - tOdO_layout = cute.make_ordered_layout( - (self.num_threads // self.gmem_threads_per_row, self.gmem_threads_per_row), - order=(1, 0), - ) - # Value layouts for copies - vOdO_layout = cute.make_layout((1, async_copy_elems)) - self.gmem_tiled_copy_O = cute.make_tiled_copy_tv( - atom_universal_copy, tOdO_layout, vOdO_layout - ) - self.gmem_tiled_copy_dO = cute.make_tiled_copy_tv( - atom_universal_copy, tOdO_layout, vOdO_layout - ) - - async_copy_elems_accum = universal_copy_bits // cutlass.Float32.width - atom_universal_copy_accum = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), - cutlass.Float32, - num_bits_per_copy=universal_copy_bits, - ) + num_copy_elems_dQaccum = universal_copy_bits // Float32.width assert ( - self.m_block_size * self.head_dim_padded // async_copy_elems_accum + self.m_block_size * self.head_dim_padded // num_copy_elems_dQaccum ) % self.num_threads == 0 - self.gmem_tiled_copy_dQaccum = cute.make_tiled_copy_tv( - atom_universal_copy_accum, - cute.make_layout(self.num_threads), - cute.make_layout(async_copy_elems_accum), - ) + self.gmem_tiled_copy_dQaccum = copy_utils.tiled_copy_1d(Float32, self.num_threads, num_copy_elems_dQaccum) @cute.jit def __call__( @@ -137,18 +108,22 @@ def __call__( raise TypeError("All tensors must have the same data type") if cutlass.const_expr(not mO.element_type in [cutlass.Float16, cutlass.BFloat16]): raise TypeError("Only Float16 or BFloat16 is supported") - if cutlass.const_expr(not mdPsum.element_type in [cutlass.Float32]): + if cutlass.const_expr(not mdPsum.element_type in [Float32]): raise TypeError("dPsum tensor must be Float32") if cutlass.const_expr(mdQaccum is not None): - if cutlass.const_expr(not mdQaccum.element_type in [cutlass.Float32]): + if cutlass.const_expr(not mdQaccum.element_type in [Float32]): raise TypeError("dQaccum tensor must be Float32") if cutlass.const_expr(mLSE is not None): assert mLSElog2 is not None, "If mLSE is provided, mLSElog2 must also be provided" - if cutlass.const_expr(not mLSE.element_type in [cutlass.Float32]): + if cutlass.const_expr(not mLSE.element_type in [Float32]): raise TypeError("LSE tensor must be Float32") - if cutlass.const_expr(not mLSElog2.element_type in [cutlass.Float32]): + if cutlass.const_expr(not mLSElog2.element_type in [Float32]): raise TypeError("LSElog2 tensor must be Float32") + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1]) + mO, mdO, mdQaccum = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) if t is not None else None for t in (mO, mdO, mdQaccum)] + self._setup_attributes() # grid_dim: (m_block, num_head, batch_size) @@ -165,7 +140,6 @@ def __call__( mLSElog2, mdQaccum, self.gmem_tiled_copy_O, - self.gmem_tiled_copy_dO, self.gmem_tiled_copy_dQaccum, ).launch( grid=grid_dim, @@ -183,7 +157,6 @@ def kernel( mLSElog2: Optional[cute.Tensor], mdQaccum: Optional[cute.Tensor], gmem_tiled_copy_O: cute.TiledCopy, - gmem_tiled_copy_dO: cute.TiledCopy, gmem_tiled_copy_dQaccum: cute.TiledCopy, ): # Thread index, block index @@ -199,23 +172,20 @@ def kernel( gdO = cute.local_tile(mdO[batch_size, None, num_head, None], blkOdO_shape, (m_block, 0)) gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) - gmem_thr_copy_dO = gmem_tiled_copy_dO.get_slice(tidx) # (CPY_Atom, CPY_M, CPY_K) tOgO = gmem_thr_copy_O.partition_S(gO) - tOgdO = gmem_thr_copy_dO.partition_S(gdO) + tOgdO = gmem_thr_copy_O.partition_S(gdO) # /////////////////////////////////////////////////////////////////////////////// # Predicate: Mark indices that need to copy when problem_shape isn't a multiple # of tile_shape # /////////////////////////////////////////////////////////////////////////////// # Construct identity layout for KV - cOdO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) - tOcO = gmem_thr_copy_O.partition_S(cOdO) - t0OcO = gmem_thr_copy_O.get_slice(0).partition_S(cOdO) + cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + tOcO = gmem_thr_copy_O.partition_S(cO) + t0OcO = gmem_thr_copy_O.get_slice(0).partition_S(cO) tOpO = utils.predicate_k(tOcO, limit=mO.shape[3]) - tOcdO = gmem_thr_copy_dO.partition_S(cOdO) - t0OcdO = gmem_thr_copy_dO.get_slice(0).partition_S(cOdO) - tOpdO = utils.predicate_k(tOcdO, limit=mdO.shape[3]) + tOpdO = utils.predicate_k(tOcO, limit=mdO.shape[3]) seqlen_q = mO.shape[1] seqlen_q_rounded = cute.round_up(seqlen_q, self.m_block_size) @@ -224,7 +194,7 @@ def kernel( gLSE = cute.local_tile( mLSE[batch_size, num_head, None], (self.m_block_size,), (m_block,) ) - lse = cutlass.Float32.inf + lse = Float32.inf if tidx < seqlen_q - m_block * self.m_block_size: lse = gLSE[tidx] @@ -244,17 +214,19 @@ def kernel( pred=tOpO[None, m, None] if cutlass.const_expr(self.check_hdim_oob) else None, ) cute.copy( - gmem_thr_copy_dO, + gmem_thr_copy_O, tOgdO[None, m, None], tOrdO[None, m, None], pred=tOpdO[None, m, None] if cutlass.const_expr(self.check_hdim_oob) else None, ) # Sum across the "k" dimension - dpsum = (tOrO.load().to(cutlass.Float32) * tOrdO.load().to(cutlass.Float32)).reduce( + dpsum = (tOrO.load().to(Float32) * tOrdO.load().to(Float32)).reduce( cute.ReductionOp.ADD, init_val=0.0, reduction_profile=(0, None, 1) ) - dpsum = utils.warp_reduce(dpsum, operator.add, width=self.gmem_threads_per_row) - dP_sum = cute.make_fragment(cute.size(tOrO, mode=[1]), cutlass.Float32) + threads_per_row = gmem_tiled_copy_O.layout_src_tv_tiled[0].shape[0] + assert cute.arch.WARP_SIZE % threads_per_row == 0 + dpsum = utils.warp_reduce(dpsum, operator.add, width=threads_per_row) + dP_sum = cute.make_fragment(cute.size(tOrO, mode=[1]), Float32) dP_sum.store(dpsum) # Write dPsum from rmem -> gmem @@ -285,4 +257,4 @@ def kernel( ) LOG2_E = math.log2(math.e) if tidx < seqlen_q_rounded - m_block * self.m_block_size: - gLSElog2[tidx] = lse * LOG2_E if lse != -cutlass.Float32.inf else 0.0 + gLSElog2[tidx] = lse * LOG2_E if lse != -Float32.inf else 0.0 diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index fc1c91c0365..3e5a31311ac 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -515,7 +515,7 @@ def backward(ctx, dout, *args): ctx.causal, ctx.softcap, ) - return dq, dk, dv, *((None,) * 5) + return dq, dk, dv, *((None,) * 10) # Extra Nones is fine class FlashAttnVarlenFunc(torch.autograd.Function): From fbdba01e006f8deab10c240fede3913d34d30464 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 12 Oct 2025 00:22:49 -0400 Subject: [PATCH 115/258] [Cute,Fwd,Sm90] Simplify by passing around functions --- flash_attn/cute/flash_fwd.py | 258 ++++++++++++++---------------- flash_attn/cute/hopper_helpers.py | 36 ++++- flash_attn/cute/seqlen_info.py | 40 +++-- flash_attn/cute/utils.py | 4 + 4 files changed, 184 insertions(+), 154 deletions(-) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index ac2a301971b..ac3656bb807 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -14,14 +14,16 @@ import cutlass import cutlass.cute as cute -from cutlass import Float32, Int32, const_expr +from cutlass import Float32, Int32, Boolean, const_expr from cutlass.cute.nvgpu import cpasync, warp, warpgroup import cutlass.utils as utils_basic +from cutlass.utils import LayoutEnum import cutlass.utils.hopper_helpers as sm90_utils_basic from flash_attn.cute import ampere_helpers as sm80_utils from flash_attn.cute import hopper_helpers as sm90_utils from flash_attn.cute import utils +from flash_attn.cute import copy_utils from flash_attn.cute.mask import AttentionMask from flash_attn.cute.softmax import Softmax, apply_score_mod_inner from flash_attn.cute.seqlen_info import SeqlenInfoQK @@ -32,6 +34,23 @@ from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, SingleTileLPTScheduler, SingleTileVarlenScheduler, ParamsBase from flash_attn.cute.fast_math import FastDivmod + +def mma_qk(tiled_mma_qk: cute.TiledMma, shape: cute.Shape, tSrQ: cute.Tensor, tSrK: cute.Tensor, smem_idx: Int32, wg_wait: int = -1) -> cute.Tensor: + acc_S = cute.make_fragment(tiled_mma_qk.partition_shape_C(shape), Float32) + sm90_utils.gemm( + tiled_mma_qk, acc_S, tSrQ, tSrK[None, None, None, smem_idx], zero_init=True, wg_wait=wg_wait + ) + return acc_S + + +def mma_pv(tiled_mma_pv: cute.TiledMma, acc_O: cute.Tensor, tOrP: cute.Tensor, tOrVt: cute.Tensor, smem_idx: Int32, zero_init: Boolean, wg_wait: int = -1) -> None: + sm90_utils.gemm( + tiled_mma_pv, acc_O, tOrP, + tOrVt[None, None, None, smem_idx], + zero_init=zero_init, wg_wait=wg_wait + ) + + class FlashAttentionForwardBase: arch: int = 80 @@ -992,14 +1011,14 @@ def __init__(self, *args, intra_wg_overlap: bool = True, mma_pv_is_rs: bool = Tr def _get_smem_layout_atom(self): sQ_layout_atom = warpgroup.make_smem_layout_atom( sm90_utils_basic.get_smem_layout_atom( - cutlass.utils.LayoutEnum.ROW_MAJOR, self.dtype, self.head_dim_padded + LayoutEnum.ROW_MAJOR, self.dtype, self.head_dim_padded ), self.dtype ) sK_layout_atom = sQ_layout_atom sV_layout_atom = warpgroup.make_smem_layout_atom( sm90_utils_basic.get_smem_layout_atom( - cutlass.utils.LayoutEnum.ROW_MAJOR, self.dtype, self.head_dim_v_padded + LayoutEnum.ROW_MAJOR, self.dtype, self.head_dim_v_padded ), self.dtype ) @@ -1007,7 +1026,7 @@ def _get_smem_layout_atom(self): if not self.mma_pv_is_rs: sP_layout_atom = warpgroup.make_smem_layout_atom( sm90_utils_basic.get_smem_layout_atom( - cutlass.utils.LayoutEnum.ROW_MAJOR, self.dtype, self.n_block_size + LayoutEnum.ROW_MAJOR, self.dtype, self.n_block_size ), self.dtype ) @@ -1122,17 +1141,12 @@ def __call__( new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1]) mQ, mK, mV, mO = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) for t in (mQ, mK, mV, mO)] QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] - mQ, mO = [ - cute.make_tensor(t.iterator, cute.select(t.layout, mode=QO_layout_transpose)) - for t in (mQ, mO) - ] + mQ, mO = [utils.select(t, QO_layout_transpose) for t in (mQ, mO)] KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1] - mK, mV = [ - cute.make_tensor(t.iterator, cute.select(t.layout, mode=KV_layout_transpose)) - for t in (mK, mV) - ] + mK, mV = [utils.select(t, KV_layout_transpose) for t in (mK, mV)] LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] - mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) if const_expr(mLSE is not None) else None + mLSE = utils.select(mLSE, LSE_layout_transpose) if const_expr(mLSE is not None) else None + tiled_mma_qk, tiled_mma_pv, tiled_mma_pv_rs = self._get_tiled_mma() self.num_mma_threads = tiled_mma_qk.size self.num_threads_per_warp_group = 128 @@ -1156,6 +1170,22 @@ def __call__( self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None and not self.pack_gqa # TODO: rescale_O_before_gemm self._setup_attributes() + # TODO: we prob don't need most of what's in _setup_attributes + self.sQ_layout, self.sK_layout, self.sV_layout, self.sO_layout = [ + sm90_utils.make_smem_layout(mX.element_type, LayoutEnum.ROW_MAJOR, shape, stage) + for mX, shape, stage in [ + (mQ, (self.m_block_size, self.head_dim_padded), None), + (mK, (self.n_block_size, self.head_dim_padded), self.num_stages), + (mV, (self.n_block_size, self.head_dim_v_padded), self.num_stages), + (mO, (self.m_block_size, self.head_dim_v_padded), None), + ] + ] + self.sP_layout = None + if const_expr(not self.mma_pv_is_rs): + self.sP_layout = sm90_utils.make_smem_layout( + mV.dtype, LayoutEnum.ROW_MAJOR, (self.m_block_size, self.n_block_size) + ) + SharedStorage = self._get_shared_storage_cls() if const_expr(self.pack_gqa): @@ -1177,12 +1207,11 @@ def __call__( self.tma_copy_q_bytes = cute.size_in_bytes(mQ.element_type, cute.select(self.sQ_layout, mode=[0, 1])) self.tma_copy_k_bytes = cute.size_in_bytes(mK.element_type, cute.select(self.sK_layout, mode=[0, 1])) self.tma_copy_v_bytes = cute.size_in_bytes(mV.element_type, cute.select(self.sV_layout, mode=[0, 1])) + tma_atom_Q, tma_tensor_Q = None, None if const_expr(self.use_tma_Q): tma_atom_Q, tma_tensor_Q = cpasync.make_tiled_tma_atom( gmem_tiled_copy_Q, mQ, self.sQ_layout, (self.m_block_size, self.head_dim_padded), # No mcast ) - else: - tma_atom_Q, tma_tensor_Q = None, None tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom( gmem_tiled_copy_KV, mK, @@ -1197,12 +1226,11 @@ def __call__( (self.n_block_size, self.head_dim_v_padded), 1 # No mcast for now ) + tma_atom_O, tma_tensor_O = None, None if const_expr(self.use_tma_O): - tma_atom_O, mO = cpasync.make_tiled_tma_atom( + tma_atom_O, tma_tensor_O = cpasync.make_tiled_tma_atom( gmem_tiled_copy_O, mO, self.sO_layout, (self.m_block_size, self.head_dim_v_padded), # No mcast ) - else: - tma_atom_O = None if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): TileScheduler = SingleTileVarlenScheduler else: @@ -1252,7 +1280,7 @@ def __call__( tma_tensor_Q if const_expr(self.use_tma_Q) else mQ, tma_tensor_K, tma_tensor_V, - mO, + tma_tensor_O if const_expr(self.use_tma_O) else mO, mLSE, mCuSeqlensQ, mCuSeqlensK, @@ -1334,12 +1362,9 @@ def kernel( warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) # Prefetch tma descriptor if warp_idx == 0: - if const_expr(tma_atom_Q is not None): - cpasync.prefetch_descriptor(tma_atom_Q) - cpasync.prefetch_descriptor(tma_atom_K) - cpasync.prefetch_descriptor(tma_atom_V) - if const_expr(self.use_tma_O): - cpasync.prefetch_descriptor(tma_atom_O) + for tma_atom in (tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_O): + if const_expr(tma_atom is not None): + cpasync.prefetch_descriptor(tma_atom) smem = cutlass.utils.SmemAllocator() storage = smem.allocate(SharedStorage) @@ -1385,15 +1410,11 @@ def kernel( sV = storage.sQ.get_tensor(sV_layout.outer, swizzle=sV_layout.inner, dtype=mV.element_type) # Transpose view of V to tensor with layout (head_dim_v, n_block_size) for tiled mma sVt = utils.transpose_view(sV) + sP = None if const_expr(sP_layout is not None): - sP_pi = storage.sP.get_tensor(sP_layout) sP = storage.sP.get_tensor(sP_layout.outer, swizzle=sP_layout.inner) - else: - sP, sP_pi = None, None # reuse sQ's data iterator - sO_pi = storage.sQ.get_tensor(sO_layout) - # TODO: idk why not using sO_pi is faster - sO = cute.make_tensor(cute.recast_ptr(sO_pi.iterator, sO_layout.inner, dtype=sO_pi.element_type), sO_layout.outer) + sO = storage.sQ.get_tensor(sO_layout.outer, swizzle=sO_layout.inner, dtype=self.dtype) block_info = BlockInfo( self.m_block_size, self.n_block_size, self.is_causal, self.is_local, @@ -1506,11 +1527,14 @@ def load( else: offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) mQ_cur = cute.domain_offset((offset, 0), mQ[None, None, head_idx]) + # mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] head_idx_kv = head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx if const_expr(not seqlen.has_cu_seqlens_k): mK_cur, mV_cur = [t[None, None, head_idx_kv, batch_idx] for t in (mK, mV)] else: mK_cur, mV_cur = [cute.domain_offset((seqlen.offset_k, 0), t[None, None, head_idx_kv]) for t in (mK, mV)] + # mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv] + # mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[None, None, head_idx_kv] gK = cute.local_tile(mK_cur, (self.n_block_size, self.head_dim_padded), (None, 0)) gV = cute.local_tile(mV_cur, (self.n_block_size, self.head_dim_v_padded), (None, 0)) if const_expr(self.use_tma_Q): @@ -1522,22 +1546,12 @@ def load( cute.group_modes(sQ, 0, 2), cute.group_modes(gQ, 0, 2), ) - tKsK, tKgK = cpasync.tma_partition( - tma_atom_K, - 0, - cute.make_layout(1), - cute.group_modes(sK, 0, 2), - cute.group_modes(gK, 0, 2), - ) - tVsV, tVgV = cpasync.tma_partition( - tma_atom_V, - 0, - cute.make_layout(1), - cute.group_modes(sV, 0, 2), - cute.group_modes(gV, 0, 2), - ) - load_K = partial(self.load_K, tma_atom_K, tKgK, tKsK, pipeline_k) - load_V = partial(self.load_K, tma_atom_V, tVgV, tVsV, pipeline_v) + # TODO: mcast + # TODO check warp_idx if we have 128 producer threads + load_K, _, _ = copy_utils.tma_get_copy_fn(tma_atom_K, 0, cute.make_layout(1), gK, sK) + load_K = copy_utils.tma_producer_copy_fn(load_K, pipeline_k) + load_V, _, _ = copy_utils.tma_get_copy_fn(tma_atom_V, 0, cute.make_layout(1), gV, sV) + load_V = copy_utils.tma_producer_copy_fn(load_V, pipeline_v) # load_Q if const_expr(self.use_tma_Q): # TODO: wait for Q to be empty @@ -1550,8 +1564,10 @@ def load( # cute.printf("m_block = %d, n_block_min: %d, n_block_max: %d", m_block, n_block_min, n_block_max) for i in cutlass.range(n_block_max - n_block_min, unroll=2): n_block = n_block_max - i - 1 - load_K(n_block, producer_state=kv_producer_state) - load_V(n_block, producer_state=kv_producer_state) + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block, producer_state=kv_producer_state) kv_producer_state.advance() tile_scheduler.prefetch_next_work() tile_scheduler.advance_to_next_work() @@ -1626,15 +1642,19 @@ def mma( acc_shape_O = tiled_mma_pv.partition_shape_C((self.m_block_size, self.head_dim_v_padded)) acc_O = cute.make_fragment(acc_shape_O, Float32) - # group parameters for mma_one_n_block - mma_params = SimpleNamespace(tSrQ=tSrQ, tSrK=tSrK, tOrP=tOrP, tOrVt=tOrVt, acc_O=acc_O) smem_copy_params = SimpleNamespace(smem_thr_copy_P=smem_thr_copy_P, tPsP=tPsP) + mma_qk_fn = partial(mma_qk, tiled_mma_qk, (self.m_block_size, self.n_block_size), tSrQ, tSrK) + mma_pv_fn = partial(mma_pv, tiled_mma_pv, acc_O, tOrP, tOrVt) + mma_one_n_block_all = partial( self.mma_one_n_block_intrawg_overlap if const_expr(self.intra_wg_overlap) else self.mma_one_n_block, - tiled_mma_qk=tiled_mma_qk, tiled_mma_pv=tiled_mma_pv, tiled_mma_pv_rs=tiled_mma_pv_rs, + mma_qk_fn=mma_qk_fn, + mma_pv_fn=mma_pv_fn, + tiled_mma_pv_rs=tiled_mma_pv_rs, pipeline_k=pipeline_k, pipeline_v=pipeline_v, - mma_params=mma_params, smem_copy_params=smem_copy_params, + acc_O=acc_O, tOrP=tOrP, + smem_copy_params=smem_copy_params, thr_mma_qk=thr_mma_qk, check_inf=True, ) @@ -1673,6 +1693,7 @@ def mma( else: offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) mQ_cur = cute.domain_offset((offset, 0), mQ[None, None, head_idx]) + # mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] # gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx) # gQ = cute.local_tile(mQ_cur, (self.m_block_size, self.head_dim_padded), (m_block, 0)) # self.load_Q(gmem_thr_copy_Q, gQ, sQ, m_block, seqlen=seqlen.seqlen_q, @@ -1690,14 +1711,8 @@ def mma( O_should_accumulate = False # First iteration with seqlen masking if const_expr(self.intra_wg_overlap): - acc_S = cute.make_fragment( - tiled_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)), Float32 - ) - pipeline_k.consumer_wait(kv_consumer_state) - sm90_utils.gemm( - tiled_mma_qk, acc_S, tSrQ, tSrK[None, None, None, kv_consumer_state.index], - zero_init=True, wg_wait=0 - ) + pipeline_k.consumer_wait(kv_consumer_state, pipeline_k.consumer_try_wait(kv_consumer_state)) + acc_S = mma_qk_fn(kv_consumer_state.index, wg_wait=0) pipeline_k.consumer_release(kv_consumer_state) # Use vectorized score modification if cutlass.const_expr(score_mod is not None): @@ -1717,17 +1732,15 @@ def mma( # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) softmax.online_softmax(acc_S, is_first=True) tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) - tOrP = mma_params.tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) - # tOrP.store(tOrP_acc.load().to(self.dtype)) - # the "to(self.dtype)" conversion fails to vectorize for block sizes other - # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of - # 2 elements. So we just call ptx directly. - utils.cvt_f16(tOrP_acc, tOrP) + tOrP_cur = tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) + tOrP_cur.store(tOrP_acc.load().to(self.dtype)) if const_expr(not self.mma_pv_is_rs): - tPrP = smem_thr_copy_P.retile(tOrP) + tPrP = smem_thr_copy_P.retile(tOrP_cur) cute.copy(smem_thr_copy_P, tPrP, tPsP) # Fence and barrier to make sure smem store is visible to WGMMA - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV # Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter # acc_O.fill(0.0) @@ -1778,12 +1791,7 @@ def mma( # Last "half" iteration if const_expr(self.intra_wg_overlap): pipeline_v.consumer_wait(kv_consumer_state, pipeline_v.consumer_try_wait(kv_consumer_state)) - sm90_utils.gemm( - tiled_mma_pv, mma_params.acc_O, mma_params.tOrP, - mma_params.tOrVt[None, None, None, kv_consumer_state.index], - zero_init=not O_should_accumulate, wg_wait=-1 - ) - warpgroup.wait_group(0) + mma_pv_fn(kv_consumer_state.index, zero_init=not O_should_accumulate, wg_wait=0) pipeline_v.consumer_release(kv_consumer_state) kv_consumer_state.advance() else: @@ -1822,12 +1830,13 @@ def mma_one_n_block( self, n_block: Int32, smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, - tiled_mma_qk: cute.TiledMma, - tiled_mma_pv: cute.TiledMma, + mma_qk_fn: Callable, + mma_pv_fn: Callable, tiled_mma_pv_rs: cute.TiledMma, pipeline_k: cutlass.pipeline.PipelineAsync, pipeline_v: cutlass.pipeline.PipelineAsync, - mma_params: SimpleNamespace, + acc_O: cute.Tensor, + tOrP: cute.Tensor, smem_copy_params: SimpleNamespace, softmax: Softmax, score_mod: Callable, @@ -1840,17 +1849,10 @@ def mma_one_n_block( mask_fn: Optional[Callable] = None, is_first_n_block: cutlass.Constexpr = False, check_inf: cutlass.Constexpr = True, - O_should_accumulate: cutlass.Boolean = True, + O_should_accumulate: Boolean = True, ): - acc_S = cute.make_fragment( - tiled_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)), Float32 - ) pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read)) - sm90_utils.gemm( - tiled_mma_qk, acc_S, mma_params.tSrQ, - mma_params.tSrK[None, None, None, smem_pipe_read.index], - zero_init=True, wg_wait=-1 - ) + acc_S = mma_qk_fn(smem_pipe_read.index, wg_wait=-1) self.warp_scheduler_barrier_arrive() warpgroup.wait_group(0) pipeline_k.consumer_release(smem_pipe_read) @@ -1871,24 +1873,25 @@ def mma_one_n_block( row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf) # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) - tOrP = mma_params.tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) + tOrP_cur = tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) # tOrP.store(tOrP_acc.load().to(self.dtype)) - utils.cvt_f16(tOrP_acc, tOrP) + # the "to(self.dtype)" conversion fails to vectorize for block sizes other + # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of + # 2 elements. So we just call ptx directly. + utils.cvt_f16(tOrP_acc, tOrP_cur) if const_expr(not self.mma_pv_is_rs): - tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP) + tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur) cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) - softmax.rescale_O(mma_params.acc_O, row_scale) + softmax.rescale_O(acc_O, row_scale) if const_expr(not self.mma_pv_is_rs): # Fence and barrier to make sure smem store is visible to WGMMA - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read)) self.warp_scheduler_barrier_sync() - sm90_utils.gemm( - tiled_mma_pv, mma_params.acc_O, mma_params.tOrP, - mma_params.tOrVt[None, None, None, smem_pipe_read.index], - zero_init=not O_should_accumulate, wg_wait=0 - ) + mma_pv_fn(smem_pipe_read.index, zero_init=not O_should_accumulate, wg_wait=0) pipeline_v.consumer_release(smem_pipe_read) smem_pipe_read.advance() return smem_pipe_read @@ -1898,12 +1901,13 @@ def mma_one_n_block_intrawg_overlap( self, n_block: Int32, smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, - tiled_mma_qk: cute.TiledMma, - tiled_mma_pv: cute.TiledMma, + mma_qk_fn: Callable, + mma_pv_fn: Callable, tiled_mma_pv_rs: cute.TiledMma, pipeline_k: cutlass.pipeline.PipelineAsync, pipeline_v: cutlass.pipeline.PipelineAsync, - mma_params: SimpleNamespace, + acc_O: cute.Tensor, + tOrP: cute.Tensor, smem_copy_params: SimpleNamespace, softmax: Softmax, score_mod: Callable, @@ -1915,26 +1919,15 @@ def mma_one_n_block_intrawg_overlap( fastdiv_mods=None, mask_fn: Optional[Callable] = None, check_inf: cutlass.Constexpr = True, - O_should_accumulate: cutlass.Boolean = True, + O_should_accumulate: Boolean = True, ): smem_pipe_read_v = smem_pipe_read.clone() smem_pipe_read.advance() - acc_S = cute.make_fragment( - tiled_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)), Float32 - ) pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read)) self.warp_scheduler_barrier_sync() - sm90_utils.gemm( - tiled_mma_qk, acc_S, mma_params.tSrQ, - mma_params.tSrK[None, None, None, smem_pipe_read.index], - zero_init=True, wg_wait=-1 - ) + acc_S = mma_qk_fn(smem_pipe_read.index, wg_wait=-1) pipeline_v.consumer_wait(smem_pipe_read_v, pipeline_v.consumer_try_wait(smem_pipe_read_v)) - sm90_utils.gemm( - tiled_mma_pv, mma_params.acc_O, mma_params.tOrP, - mma_params.tOrVt[None, None, None, smem_pipe_read_v.index], - zero_init=not O_should_accumulate, wg_wait=-1 - ) + mma_pv_fn(smem_pipe_read_v.index, zero_init=not O_should_accumulate, wg_wait=-1) self.warp_scheduler_barrier_arrive() warpgroup.wait_group(1) pipeline_k.consumer_release(smem_pipe_read) @@ -1958,16 +1951,21 @@ def mma_one_n_block_intrawg_overlap( warpgroup.wait_group(0) pipeline_v.consumer_release(smem_pipe_read_v) tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) - tOrP = mma_params.tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) - # tOrP.store(tOrP_acc.load().to(self.dtype)) - utils.cvt_f16(tOrP_acc, tOrP) + tOrP_cur = tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) + # tOrP_cur.store(tOrP_acc.load().to(self.dtype)) + # the "to(self.dtype)" conversion fails to vectorize for block sizes other + # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of + # 2 elements. So we just call ptx directly. + utils.cvt_f16(tOrP_acc, tOrP_cur) if const_expr(not self.mma_pv_is_rs): - tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP) + tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur) cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) - softmax.rescale_O(mma_params.acc_O, row_scale) + softmax.rescale_O(acc_O, row_scale) if const_expr(not self.mma_pv_is_rs): # Fence and barrier to make sure smem store is visible to WGMMA - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV return smem_pipe_read @@ -2033,23 +2031,3 @@ def warp_scheduler_barrier_arrive(self): barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg, number_of_threads=2 * self.num_threads_per_warp_group, ) - - # @cute.jit - def load_K( - self, - tma_atom: cute.CopyAtom, - tKgK: cute.Tensor, - tKsK: cute.Tensor, - pipeline: cutlass.pipeline.PipelineAsync, - block: Int32, - producer_state: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, - ): - # TODO: mcast - # TODO check warp_idx if we have 128 producer threads - pipeline.producer_acquire(producer_state) - cute.copy( - tma_atom, - tKgK[None, block], - tKsK[None, producer_state.index], - tma_bar_ptr=pipeline.producer_get_barrier(producer_state) - ) diff --git a/flash_attn/cute/hopper_helpers.py b/flash_attn/cute/hopper_helpers.py index acb0273effd..5a46139fb6b 100644 --- a/flash_attn/cute/hopper_helpers.py +++ b/flash_attn/cute/hopper_helpers.py @@ -1,10 +1,13 @@ # Copyright (c) 2025, Tri Dao. +from typing import Type, Union, Optional import cutlass import cutlass.cute as cute +from cutlass import const_expr from cutlass.cute.nvgpu import warpgroup - from cutlass._mlir.dialects import llvm -from cutlass.cutlass_dsl import dsl_user_op +from cutlass.cutlass_dsl import Numeric, dsl_user_op +from cutlass.utils import LayoutEnum +import cutlass.utils.hopper_helpers as sm90_utils_og @cute.jit @@ -18,7 +21,7 @@ def gemm( # A_in_regs: cutlass.Constexpr[bool] = False, swap_AB: cutlass.Constexpr[bool] = False, ) -> None: - if cutlass.const_expr(swap_AB): + if const_expr(swap_AB): gemm(tiled_mma, acc, tCrB, tCrA, zero_init=zero_init, wg_wait=wg_wait, swap_AB=False) else: warpgroup.fence() @@ -30,10 +33,34 @@ def gemm( cute.gemm(mma_atom, acc, tCrA[None, None, k], tCrB[None, None, k], acc) mma_atom.set(warpgroup.Field.ACCUMULATE, True) warpgroup.commit_group() - if cutlass.const_expr(wg_wait >= 0): + if const_expr(wg_wait >= 0): warpgroup.wait_group(wg_wait) +@dsl_user_op +def make_smem_layout( + dtype: Type[Numeric], + layout: LayoutEnum, + shape: cute.Shape, + stage: Optional[int] = None, + *, + loc=None, + ip=None, +) -> Union[cute.Layout, cute.ComposedLayout]: + major_mode_size = shape[1] if layout.is_n_major_c() else shape[0] + smem_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils_og.get_smem_layout_atom(layout, dtype, major_mode_size), + dtype, + ) + order = (1, 0, 2) if const_expr(layout.is_m_major_c()) else (0, 1, 2) + smem_layout_staged = cute.tile_to_shape( + smem_layout_atom, + cute.append(shape, stage) if const_expr(stage is not None) else shape, + order=order if const_expr(stage is not None) else order[:2], + ) + return smem_layout_staged + + @dsl_user_op def tma_reduce_add_bulk_f32( smem_ptr: cute.Pointer, @@ -41,7 +68,6 @@ def tma_reduce_add_bulk_f32( store_bytes: cutlass.Int32, *, loc=None, ip=None ): - cute.make_mma_atom smem_u32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() llvm.inline_asm( None, diff --git a/flash_attn/cute/seqlen_info.py b/flash_attn/cute/seqlen_info.py index dee63db6bf4..792d84e2d64 100644 --- a/flash_attn/cute/seqlen_info.py +++ b/flash_attn/cute/seqlen_info.py @@ -2,6 +2,7 @@ import cutlass import cutlass.cute as cute +from cutlass import Int32, const_expr """ This consolidates all the info related to sequence length. This is so that we can do all @@ -17,10 +18,10 @@ def __init__( cu_seqlens: Optional[cute.Tensor] = None, seqused: Optional[cute.Tensor] = None, ): - self.offset = 0 if cutlass.const_expr(cu_seqlens is None) else cu_seqlens[batch_idx] - if cutlass.const_expr(seqused is not None): + self.offset = 0 if const_expr(cu_seqlens is None) else cu_seqlens[batch_idx] + if const_expr(seqused is not None): self.seqlen = seqused[batch_idx] - elif cutlass.const_expr(cu_seqlens is not None): + elif const_expr(cu_seqlens is not None): self.seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx] else: self.seqlen = seqlen_static @@ -37,23 +38,44 @@ def __init__( mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, ): - self.offset_q = 0 if cutlass.const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx] - self.offset_k = 0 if cutlass.const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx] - if cutlass.const_expr(mSeqUsedQ is not None): + self.offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx] + self.offset_k = 0 if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx] + if const_expr(mSeqUsedQ is not None): self.seqlen_q = mSeqUsedQ[batch_idx] else: self.seqlen_q = ( seqlen_q_static - if cutlass.const_expr(mCuSeqlensQ is None) + if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx + 1] - self.offset_q ) - if cutlass.const_expr(mSeqUsedK is not None): + if const_expr(mSeqUsedK is not None): self.seqlen_k = mSeqUsedK[batch_idx] else: self.seqlen_k = ( seqlen_k_static - if cutlass.const_expr(mCuSeqlensK is None) + if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx + 1] - self.offset_k ) self.has_cu_seqlens_q: int = mCuSeqlensQ is not None self.has_cu_seqlens_k: int = mCuSeqlensK is not None + + def offset_batch_Q(self, mQ: cute.Tensor, batch_idx: Int32, dim: int) -> cute.Tensor: + """Seqlen must be the first dimension of mQ + """ + if const_expr(not self.has_cu_seqlens_q): + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim) + return mQ[idx] + else: + offset = self.offset_q if const_expr(cute.rank(mQ.shape[0]) == 1) else (0, self.offset_q) + idx = (offset,) + (0,) * (cute.rank(mQ) - 1) + return cute.domain_offset(idx, mQ) + + def offset_batch_K(self, mK: cute.Tensor, batch_idx: Int32, dim: int) -> cute.Tensor: + """Seqlen must be the first dimension of mK + """ + if const_expr(not self.has_cu_seqlens_k): + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim) + return mK[idx] + else: + idx = (self.offset_k,) + (0,) * (cute.rank(mK) - 1) + return cute.domain_offset(idx, mK) diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 6d48aca644d..06e7824dc13 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -208,6 +208,10 @@ def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout: return rA_mma_view +def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor: + return cute.make_tensor(a.iterator, cute.select(a.layout, mode)) + + def transpose_view(a: cute.Tensor) -> cute.Tensor: """Transpose the first two dimensions of a tensor on smem.""" shape = (a.shape[1], a.shape[0], *a.shape[2:]) From b528f4b2d29e9521fc858f86e4f075195e097619 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 12 Oct 2025 00:44:21 -0400 Subject: [PATCH 116/258] [Cute,Fwd,Sm90] Simplify score mode by passing around partial fn --- flash_attn/cute/flash_fwd.py | 87 ++++++++++-------------------------- 1 file changed, 24 insertions(+), 63 deletions(-) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index ac3656bb807..33a77aef289 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -959,17 +959,17 @@ def load_V_next(): ) if cutlass.const_expr(score_mod is not None): self.apply_score_mod( - acc_S, mma_params.thr_mma_qk, batch_idx, head_idx, m_block, + acc_S, n_block, - softmax=softmax, + softmax_scale=softmax.softmax_scale, buffers=buffers, fastdiv_mods=fastdiv_mods, ) - + smem_pipe_write = self.advance_pipeline(smem_pipe_write) def load_K_next(): if n_block - self.num_stages >= 0: @@ -1655,7 +1655,6 @@ def mma( pipeline_k=pipeline_k, pipeline_v=pipeline_v, acc_O=acc_O, tOrP=tOrP, smem_copy_params=smem_copy_params, - thr_mma_qk=thr_mma_qk, check_inf=True, ) @@ -1672,18 +1671,22 @@ def mma( # shape: (atom_v_m * rest_m) softmax = Softmax.create(softmax_scale_log2, num_rows=acc_O.shape[0][0] * acc_O.shape[1], softmax_scale=softmax_scale) m_block, head_idx, batch_idx = work_tile.tile_idx - score_mod = self.score_mod - mma_one_n_block = partial( - mma_one_n_block_all, softmax=softmax, score_mod=score_mod, - batch_idx=batch_idx, head_idx=head_idx, m_block=m_block, buffers=buffers, - fastdiv_mods=fastdiv_mods - ) seqlen = SeqlenInfoCls(batch_idx) mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) mask_fn = partial( mask.apply_mask, m_block=m_block, thr_mma=thr_mma_qk, mask_causal=self.is_causal, mask_local=self.is_local, ) + score_mod_fn = None + if const_expr(self.score_mod is not None): + score_mod_fn = partial( + self.apply_score_mod, + thr_mma_qk, batch_idx, head_idx, m_block, + softmax_scale=softmax.softmax_scale, buffers=buffers, fastdiv_mods=fastdiv_mods, + ) + mma_one_n_block = partial( + mma_one_n_block_all, softmax=softmax, score_mod_fn=score_mod_fn + ) softmax.reset() # Load Q if not TMA_Q if const_expr(not self.use_tma_Q): @@ -1715,18 +1718,8 @@ def mma( acc_S = mma_qk_fn(kv_consumer_state.index, wg_wait=0) pipeline_k.consumer_release(kv_consumer_state) # Use vectorized score modification - if cutlass.const_expr(score_mod is not None): - self.apply_score_mod( - acc_S, - thr_mma_qk, - batch_idx, - head_idx, - m_block, - n_block_max - 1, - softmax=softmax, - buffers=buffers, - fastdiv_mods=fastdiv_mods, - ) + if cutlass.const_expr(score_mod_fn is not None): + score_mod_fn(acc_S, n_block=n_block_max - 1) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) mask_fn(acc_S, n_block=n_block_max - 1, mask_seqlen=True) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) @@ -1839,13 +1832,7 @@ def mma_one_n_block( tOrP: cute.Tensor, smem_copy_params: SimpleNamespace, softmax: Softmax, - score_mod: Callable, - batch_idx: cutlass.Int32, - head_idx: cutlass.Int32, - m_block: cutlass.Int32, - thr_mma_qk: cute.TiledMma, - buffers=None, - fastdiv_mods=None, + score_mod_fn: Optional[Callable] = None, mask_fn: Optional[Callable] = None, is_first_n_block: cutlass.Constexpr = False, check_inf: cutlass.Constexpr = True, @@ -1856,18 +1843,8 @@ def mma_one_n_block( self.warp_scheduler_barrier_arrive() warpgroup.wait_group(0) pipeline_k.consumer_release(smem_pipe_read) - if cutlass.const_expr(score_mod is not None): - self.apply_score_mod( - acc_S, - thr_mma_qk, - batch_idx, - head_idx, - m_block, - n_block, - softmax=softmax, - buffers=buffers, - fastdiv_mods=fastdiv_mods, - ) + if const_expr(score_mod_fn is not None): + score_mod_fn(acc_S, n_block=n_block) if const_expr(mask_fn is not None): mask_fn(acc_S, n_block=n_block) row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf) @@ -1910,13 +1887,7 @@ def mma_one_n_block_intrawg_overlap( tOrP: cute.Tensor, smem_copy_params: SimpleNamespace, softmax: Softmax, - score_mod: Callable, - batch_idx: cutlass.Int32, - head_idx: cutlass.Int32, - m_block: cutlass.Int32, - thr_mma_qk: cute.TiledMma, - buffers=None, - fastdiv_mods=None, + score_mod_fn: Optional[Callable] = None, mask_fn: Optional[Callable] = None, check_inf: cutlass.Constexpr = True, O_should_accumulate: Boolean = True, @@ -1931,18 +1902,8 @@ def mma_one_n_block_intrawg_overlap( self.warp_scheduler_barrier_arrive() warpgroup.wait_group(1) pipeline_k.consumer_release(smem_pipe_read) - if cutlass.const_expr(score_mod is not None): - self.apply_score_mod( - acc_S, - thr_mma_qk, - batch_idx, - head_idx, - m_block, - n_block, - softmax=softmax, - buffers=buffers, - fastdiv_mods=fastdiv_mods, - ) + if const_expr(score_mod_fn is not None): + score_mod_fn(acc_S, n_block=n_block) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) if const_expr(mask_fn is not None): mask_fn(acc_S, n_block=n_block) @@ -1982,13 +1943,13 @@ def mma_init(self): @cute.jit def apply_score_mod( self, - acc_S, thr_mma_qk, batch_idx, head_idx, m_block, + acc_S, n_block, - softmax, + softmax_scale, buffers=None, fastdiv_mods=None, ): @@ -2003,7 +1964,7 @@ def apply_score_mod( self.score_mod, batch_idx, head_idx, - softmax.softmax_scale, + softmax_scale, self.vec_size, self.qk_acc_dtype, buffers, From 13f20773c8a2a1b0bb394488e61930ab81ca320e Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 12 Oct 2025 00:52:28 -0400 Subject: [PATCH 117/258] [Cute] Optionally dump cubin and sass --- flash_attn/cute/__init__.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/flash_attn/cute/__init__.py b/flash_attn/cute/__init__.py index f1a4ed2d214..fbbfc14050e 100644 --- a/flash_attn/cute/__init__.py +++ b/flash_attn/cute/__init__.py @@ -1,11 +1,19 @@ """Flash Attention CUTE (CUDA Template Engine) implementation.""" +__version__ = "0.1.0" + +import cutlass.cute as cute + from .interface import ( flash_attn_func, flash_attn_varlen_func, ) -__version__ = "0.1.0" +from flash_attn.cute.cute_dsl_utils import cute_compile_patched + +# Patch cute.compile to optionally dump SASS +cute.compile = cute_compile_patched + __all__ = [ "flash_attn_func", From c172985a41b351f31f8feb21b1ede2946ce56928 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 12 Oct 2025 09:54:28 -0400 Subject: [PATCH 118/258] [Cute,Fwd,Sm90] Rename m_block_size->tile_m, n_block_size->tile_n --- flash_attn/cute/flash_fwd.py | 267 +++++++++++++++++------------------ flash_attn/cute/interface.py | 4 +- flash_attn/cute/mask.py | 28 ++-- 3 files changed, 144 insertions(+), 155 deletions(-) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 33a77aef289..6e56b23d76e 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -64,8 +64,8 @@ def __init__( is_causal: bool = False, is_local: bool = False, pack_gqa: bool = True, - m_block_size: int = 128, - n_block_size: int = 128, + tile_m: int = 128, + tile_n: int = 128, num_stages: int = 1, num_threads: int = 128, Q_in_regs: bool = False, @@ -79,10 +79,10 @@ def __init__( :param head_dim: head dimension :type head_dim: int - :param m_block_size: m block size - :type m_block_size: int - :param n_block_size: n block size - :type n_block_size: int + :param tile_m: m block size + :type tile_m: int + :param tile_n: n block size + :type tile_n: int :param num_threads: number of threads :type num_threads: int :param is_causal: is causal @@ -92,19 +92,19 @@ def __init__( self.dtype = dtype # padding head_dim to a multiple of 16 as k_block_size hdim_multiple_of = 16 - self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) + self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) head_dim_v = head_dim_v if head_dim_v is not None else head_dim self.same_hdim_kv = head_dim == head_dim_v - self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) + self.tile_hdimv = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) # Can save registers (and hence be faster) if we don't have to check hdim predication - self.check_hdim_oob = head_dim != self.head_dim_padded - self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded + self.check_hdim_oob = head_dim != self.tile_hdim + self.check_hdim_v_oob = head_dim_v != self.tile_hdimv self.qhead_per_kvhead = qhead_per_kvhead self.is_causal = is_causal self.is_local = is_local self.pack_gqa = pack_gqa - self.m_block_size = m_block_size - self.n_block_size = n_block_size + self.tile_m = tile_m + self.tile_n = tile_n self.num_threads = num_threads self.num_stages = num_stages self.Q_in_regs = Q_in_regs @@ -117,7 +117,7 @@ def __init__( @staticmethod def can_implement( - dtype, head_dim, head_dim_v, m_block_size, n_block_size, num_stages, num_threads, is_causal, + dtype, head_dim, head_dim_v, tile_m, tile_n, num_stages, num_threads, is_causal, Q_in_regs=False ) -> bool: """Check if the kernel can be implemented with the given parameters. @@ -126,10 +126,10 @@ def can_implement( :type dtype: cutlass.Numeric :param head_dim: head dimension :type head_dim: int - :param m_block_size: m block size - :type m_block_size: int - :param n_block_size: n block size - :type n_block_size: int + :param tile_m: m block size + :type tile_m: int + :param tile_n: n block size + :type tile_n: int :param num_threads: number of threads :type num_threads: int :param is_causal: is causal @@ -144,15 +144,15 @@ def can_implement( return False if head_dim_v % 8 != 0: return False - if n_block_size % 16 != 0: + if tile_n % 16 != 0: return False if num_threads % 32 != 0: return False # Check if block size setting is out of shared memory capacity # Shared memory usage: Q tile + (K tile + V tile) where K and V use the same tile size - smem_usage_Q = m_block_size * head_dim * 2 - smem_usage_K = n_block_size * head_dim * num_stages * 2 - smem_usage_V = n_block_size * head_dim_v * num_stages * 2 + smem_usage_Q = tile_m * head_dim * 2 + smem_usage_K = tile_n * head_dim * num_stages * 2 + smem_usage_V = tile_n * head_dim_v * num_stages * 2 smem_usage_QV = (smem_usage_Q + smem_usage_V) if not Q_in_regs else max(smem_usage_Q, smem_usage_V) smem_usage = smem_usage_QV + smem_usage_K # TODO: sm86 and sm89 @@ -160,7 +160,7 @@ def can_implement( if smem_usage > smem_capacity: return False # Check if twice the block size is divisible by the number of threads - if (m_block_size * 2) % num_threads != 0: + if (tile_m * 2) % num_threads != 0: return False return True @@ -199,20 +199,20 @@ def _setup_attributes(self): # /////////////////////////////////////////////////////////////////////////////// sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom = self._get_smem_layout_atom() self.sQ_layout = cute.tile_to_shape( - sQ_layout_atom, (self.m_block_size, self.head_dim_padded), (0, 1), + sQ_layout_atom, (self.tile_m, self.tile_hdim), (0, 1), ) self.sK_layout = cute.tile_to_shape( - sK_layout_atom, (self.n_block_size, self.head_dim_padded, self.num_stages), (0, 1, 2), + sK_layout_atom, (self.tile_n, self.tile_hdim, self.num_stages), (0, 1, 2), ) self.sV_layout = cute.tile_to_shape( - sV_layout_atom, (self.n_block_size, self.head_dim_v_padded, self.num_stages), (0, 1, 2), + sV_layout_atom, (self.tile_n, self.tile_hdimv, self.num_stages), (0, 1, 2), ) self.sO_layout = cute.tile_to_shape( - sO_layout_atom, (self.m_block_size, self.head_dim_v_padded), (0, 1), + sO_layout_atom, (self.tile_m, self.tile_hdimv), (0, 1), ) if const_expr(sP_layout_atom is not None): self.sP_layout = cute.tile_to_shape( - sP_layout_atom, (self.m_block_size, self.n_block_size), (0, 1), + sP_layout_atom, (self.tile_m, self.tile_n), (0, 1), ) else: self.sP_layout = None @@ -244,7 +244,7 @@ def _setup_attributes(self): (self.num_producer_threads // tQK_shape_dim_1, tQK_shape_dim_1), order=(1, 0), ) # So that we don't have to check if we overshoot kBlockM when we load Q - assert self.m_block_size % tQ_layout.shape[0] == 0 + assert self.tile_m % tQ_layout.shape[0] == 0 tV_shape_dim_1 = sV_layout_atom.outer.shape[1] // async_copy_elems tV_layout = cute.make_ordered_layout( (self.num_producer_threads // tV_shape_dim_1, tV_shape_dim_1), order=(1, 0), @@ -255,7 +255,7 @@ def _setup_attributes(self): (self.num_epilogue_threads // tV_shape_dim_1, tV_shape_dim_1), order=(1, 0), ) # So that we don't have to check if we overshoot kBlockM when we store O - assert self.m_block_size % tO_layout.shape[0] == 0 + assert self.tile_m % tO_layout.shape[0] == 0 # Value layouts for copies vQKV_layout = cute.make_layout((1, async_copy_elems)) @@ -323,8 +323,8 @@ def epilogue( # copy acc O from rmem to smem with the smem copy atom cute.copy(smem_copy_atom_O, taccOrO, taccOsO) - cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded)) - pack_gqa = PackGQA(self.m_block_size, self.head_dim_v_padded, self.check_hdim_v_oob, self.qhead_per_kvhead) + cO = cute.make_identity_tensor((self.tile_m, self.tile_hdimv)) + pack_gqa = PackGQA(self.tile_m, self.tile_hdimv, self.check_hdim_v_oob, self.qhead_per_kvhead) # Write LSE from rmem -> gmem if const_expr(mLSE is not None): @@ -334,9 +334,9 @@ def epilogue( offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx]) if const_expr(not self.pack_gqa): - gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block,)) + gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (m_block,)) gLSE_expanded_layout = cute.append( - gLSE.layout, cute.make_layout((self.head_dim_v_padded,), stride=(0,)) + gLSE.layout, cute.make_layout((self.tile_hdimv,), stride=(0,)) ) gLSE_expanded = cute.make_tensor(gLSE.iterator, gLSE_expanded_layout) thr_mma = tiled_mma.get_slice(tidx) @@ -347,7 +347,7 @@ def epilogue( # Only the thread corresponding to column 0 writes out the lse to gmem if taccOcO[0][1] == 0: for m in cutlass.range_constexpr(cute.size(taccOgLSE.shape[1])): - if t0accOcO[m, 0][0] < seqlen.seqlen_q - m_block * self.m_block_size - taccOcO[0][0]: + if t0accOcO[m, 0][0] < seqlen.seqlen_q - m_block * self.tile_m - taccOcO[0][0]: taccOgLSE[m, 0] = lse[m] else: pack_gqa.store_LSE(mLSE_cur, lse, tiled_mma, tidx, m_block, seqlen.seqlen_q) @@ -365,7 +365,7 @@ def epilogue( # ensure smem writes are visible to TMA cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) cute.arch.barrier_arrive(barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE) - gO = cute.local_tile(mO_cur, (self.m_block_size, self.head_dim_v_padded), (m_block, 0)) + gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0)) tOsO, tOgO = cpasync.tma_partition( tma_atom_O, 0, @@ -387,14 +387,14 @@ def epilogue( # load acc O from smem to rmem for wider vectorization cute.autovec_copy(tOsO, tOrO) if const_expr(not self.pack_gqa): - gO = cute.local_tile(mO_cur, (self.m_block_size, self.head_dim_v_padded), (m_block, 0)) + gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0)) tOgO = gmem_thr_copy_O.partition_D(gO) tOcO = gmem_thr_copy_O.partition_S(cO) t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO) tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) # copy acc O from rmem to gmem for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): - if t0OcO[0, rest_m, 0][0] < seqlen.seqlen_q - m_block * self.m_block_size - tOcO[0][0]: + if t0OcO[0, rest_m, 0][0] < seqlen.seqlen_q - m_block * self.tile_m - tOcO[0][0]: cute.copy( gmem_tiled_copy_O, tOrO[None, rest_m, None], @@ -419,14 +419,14 @@ def load_Q( headdim: Int32, ): tQsQ, tQgQ = gmem_thr_copy.partition_D(sQ), gmem_thr_copy.partition_S(gQ) - cQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + cQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim)) tQcQ = gmem_thr_copy.partition_S(cQ) t0QcQ = gmem_thr_copy.get_slice(0).partition_S(cQ) tQpQ = utils.predicate_k(tQcQ, limit=headdim) for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])): # Instead of using tQcQ, we using t0QcQ and subtract the offset from the limit # (seqlen - block * kBlockM). This is because the entries of t0QcQ are known at compile time. - if t0QcQ[0, m, 0][0] < seqlen - block * self.m_block_size - tQcQ[0][0]: + if t0QcQ[0, m, 0][0] < seqlen - block * self.tile_m - tQcQ[0][0]: cute.copy( gmem_thr_copy, tQgQ[None, m, None], @@ -450,17 +450,17 @@ def load_K( need_predicates: cutlass.Constexpr, ): # Do we need to check if we overshoot kBlockN when we load K? - is_even_n_smem_k = self.n_block_size % gmem_tiled_copy.tiler_mn[0].shape == 0 + is_even_n_smem_k = self.tile_n % gmem_tiled_copy.tiler_mn[0].shape == 0 if const_expr(need_predicates or not is_even_n_smem_k): # Instead of using tKcK, we using t0KcK and subtract the offset from the limit # (seqlen - block * kBlockN). This is because the entries of t0KcK are known at compile time. if const_expr(is_even_n_smem_k): - seqlen_limit = seqlen - block * self.n_block_size + seqlen_limit = seqlen - block * self.tile_n else: if const_expr(not need_predicates): - seqlen_limit = self.n_block_size + seqlen_limit = self.tile_n else: - seqlen_limit = cutlass.min(seqlen - block * self.n_block_size, self.n_block_size) + seqlen_limit = cutlass.min(seqlen - block * self.tile_n, self.tile_n) seqlen_limit -= tKcK[0][0] for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])): if t0KcK[0, n, 0][0] < seqlen_limit: @@ -494,14 +494,14 @@ def load_V( need_predicates: cutlass.Constexpr, ): # Do we need to check if we overshoot kBlockN when we load V? - is_even_n_smem_v = self.n_block_size % gmem_tiled_copy.tiler_mn[0].shape == 0 + is_even_n_smem_v = self.tile_n % gmem_tiled_copy.tiler_mn[0].shape == 0 if const_expr(need_predicates or not is_even_n_smem_v): for n in cutlass.range_constexpr(cute.size(tVsV.shape[1])): # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked - if is_even_n_smem_v or n < cute.size(tVsV.shape[1]) - 1 or tVcV[0, n, 0][0] < self.n_block_size: + if is_even_n_smem_v or n < cute.size(tVsV.shape[1]) - 1 or tVcV[0, n, 0][0] < self.tile_n: predicate = tVpV[None, n, None] if const_expr(self.check_hdim_v_oob) else None if const_expr(need_predicates): - seqlen_limit = seqlen - block * self.n_block_size - tVcV[0][0] + seqlen_limit = seqlen - block * self.tile_n - tVcV[0][0] predicate_n = t0VcV[0, n, 0][0] < seqlen_limit predicate = cute.make_fragment_like(tVpV[None, 0, None]) for k in cutlass.range_constexpr(cute.size(predicate.shape[1])): @@ -525,9 +525,9 @@ def load_V( class FlashAttentionForwardSm80(FlashAttentionForwardBase): def _get_smem_layout_atom(self): - sQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.head_dim_padded) + sQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.tile_hdim) sK_layout_atom = sQ_layout_atom - sV_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.head_dim_v_padded) + sV_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.tile_hdimv) sO_layout_atom = sV_layout_atom sP_layout_atom = None return sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom @@ -604,7 +604,7 @@ def __call__( mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=[2, 1, 0])) # grid_dim: (m_block, num_head, batch_size) grid_dim = ( - cute.ceil_div(mQ.shape[0], self.m_block_size), + cute.ceil_div(mQ.shape[0], self.tile_m), cute.size(mQ.shape[2]), cute.size(mQ.shape[3]), ) @@ -690,7 +690,7 @@ def kernel( m_block, num_head, batch_size = cute.arch.block_idx() block_info = BlockInfo( - self.m_block_size, self.n_block_size, self.is_causal, self.is_local, + self.tile_m, self.tile_n, self.is_causal, self.is_local, window_size_left, window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) @@ -705,9 +705,9 @@ def kernel( # /////////////////////////////////////////////////////////////////////////////// # Get the appropriate tiles for this thread block. # /////////////////////////////////////////////////////////////////////////////// - blkQ_shape = (self.m_block_size, self.head_dim_padded) - blkK_shape = (self.n_block_size, self.head_dim_padded) - blkV_shape = (self.n_block_size, self.head_dim_v_padded) + blkQ_shape = (self.tile_m, self.tile_hdim) + blkK_shape = (self.tile_n, self.tile_hdim) + blkV_shape = (self.tile_n, self.tile_hdimv) gQ = cute.local_tile(mQ[None, None, num_head, batch_size], blkQ_shape, (m_block, 0)) num_head_kv = num_head // self.qhead_per_kvhead gK = cute.local_tile(mK[None, None, num_head_kv, batch_size], blkK_shape, (None, 0)) @@ -724,7 +724,7 @@ def kernel( sV = storage.sV.get_tensor(sV_layout) else: sV = cute.make_tensor(cute.recast_ptr(sQ.iterator, dtype=self.dtype), sV_layout) - # Transpose view of V to tensor with layout (head_dim_v, n_block_size) for tiled mma + # Transpose view of V to tensor with layout (head_dim_v, tile_n) for tiled mma sVt = utils.transpose_view(sV) gmem_thr_copy_K = gmem_tiled_copy_K.get_slice(tidx) @@ -742,7 +742,7 @@ def kernel( tSrQ = thr_mma_qk.make_fragment_A(thr_mma_qk.partition_A(sQ)) tSrK = thr_mma_qk.make_fragment_B(thr_mma_qk.partition_B(sK[None, None, 0])) tOrVt = thr_mma_pv.make_fragment_B(thr_mma_pv.partition_B(sVt[None, None, 0])) - acc_shape_O = thr_mma_pv.partition_shape_C((self.m_block_size, self.head_dim_v_padded)) + acc_shape_O = thr_mma_pv.partition_shape_C((self.tile_m, self.tile_hdimv)) acc_O = cute.make_fragment(acc_shape_O, Float32) acc_O.fill(0.0) @@ -768,14 +768,14 @@ def kernel( # of tile_shape # /////////////////////////////////////////////////////////////////////////////// # Construct identity layout for KV - cK = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded)) + cK = cute.make_identity_tensor((self.tile_n, self.tile_hdim)) tKcK = gmem_thr_copy_K.partition_S(cK) t0KcK = gmem_thr_copy_K.get_slice(0).partition_S(cK) - if const_expr(self.head_dim_padded == self.head_dim_v_padded): + if const_expr(self.tile_hdim == self.tile_hdimv): tVcV = tKcK t0VcV = t0KcK else: - cV = cute.make_identity_tensor((self.n_block_size, self.head_dim_v_padded)) + cV = cute.make_identity_tensor((self.tile_n, self.tile_hdimv)) tVcV = gmem_thr_copy_V.partition_S(cV) t0VcV = gmem_thr_copy_V.get_slice(0).partition_S(cV) # Allocate predicate tensors for m and n, here we only allocate the tile of k, and @@ -856,10 +856,10 @@ def preprocess_Q(): # Start processing of the first n-block. # For performance reason, we separate out two kinds of iterations: # those that need masking on S, and those that don't. - # We need masking on S for the very last block when K and V has length not multiple of n_block_size. + # We need masking on S for the very last block when K and V has length not multiple of tile_n. # We also need masking on S if it's causal, for the last several blocks. mask = AttentionMask( - self.m_block_size, self.n_block_size, seqlen.seqlen_q, seqlen.seqlen_k, + self.tile_m, self.tile_n, seqlen.seqlen_q, seqlen.seqlen_k, window_size_left, window_size_right, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) @@ -937,7 +937,7 @@ def sync(): cute.arch.cp_async_wait_group(self.num_stages * 2 - 2) cute.arch.barrier() - acc_shape_S = mma_params.thr_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)) + acc_shape_S = mma_params.thr_mma_qk.partition_shape_C((self.tile_m, self.tile_n)) acc_S = cute.make_fragment(acc_shape_S, Float32) acc_S.fill(0.0) # wait for smem tile QK before mma calculation for S @@ -1011,14 +1011,14 @@ def __init__(self, *args, intra_wg_overlap: bool = True, mma_pv_is_rs: bool = Tr def _get_smem_layout_atom(self): sQ_layout_atom = warpgroup.make_smem_layout_atom( sm90_utils_basic.get_smem_layout_atom( - LayoutEnum.ROW_MAJOR, self.dtype, self.head_dim_padded + LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdim ), self.dtype ) sK_layout_atom = sQ_layout_atom sV_layout_atom = warpgroup.make_smem_layout_atom( sm90_utils_basic.get_smem_layout_atom( - LayoutEnum.ROW_MAJOR, self.dtype, self.head_dim_v_padded + LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdimv ), self.dtype ) @@ -1026,7 +1026,7 @@ def _get_smem_layout_atom(self): if not self.mma_pv_is_rs: sP_layout_atom = warpgroup.make_smem_layout_atom( sm90_utils_basic.get_smem_layout_atom( - LayoutEnum.ROW_MAJOR, self.dtype, self.n_block_size + LayoutEnum.ROW_MAJOR, self.dtype, self.tile_n ), self.dtype ) @@ -1041,8 +1041,8 @@ def _get_tiled_mma(self): warpgroup.OperandMajorMode.K, warpgroup.OperandMajorMode.K, Float32, - atom_layout_mnk=(self.m_block_size // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 - tiler_mn=(64, self.n_block_size), + atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 + tiler_mn=(64, self.tile_n), ) tiled_mma_pv = sm90_utils_basic.make_trivial_tiled_mma( self.dtype, @@ -1050,8 +1050,8 @@ def _get_tiled_mma(self): warpgroup.OperandMajorMode.K, warpgroup.OperandMajorMode.MN, Float32, - atom_layout_mnk=(self.m_block_size // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 - tiler_mn=(64, self.head_dim_v_padded), + atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 + tiler_mn=(64, self.tile_hdimv), a_source=warpgroup.OperandSource.RMEM if self.mma_pv_is_rs else warpgroup.OperandSource.SMEM, ) tiled_mma_pv_rs = sm90_utils_basic.make_trivial_tiled_mma( @@ -1060,8 +1060,8 @@ def _get_tiled_mma(self): warpgroup.OperandMajorMode.K, warpgroup.OperandMajorMode.MN, Float32, - atom_layout_mnk=(self.m_block_size // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 - tiler_mn=(64, self.head_dim_v_padded), + atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 + tiler_mn=(64, self.tile_hdimv), a_source=warpgroup.OperandSource.RMEM ) return tiled_mma_qk, tiled_mma_pv, tiled_mma_pv_rs @@ -1165,8 +1165,8 @@ def __call__( ) # self.num_mma_regs = 232 # self.num_producer_regs = 40 - self.use_scheduler_barrier = (self.num_mma_warp_groups >= 2 and self.head_dim_padded <= 128) if const_expr(self.intra_wg_overlap) else (self.num_mma_warp_groups == 2) - self.use_tma_Q = self.arch >= 90 and not (self.pack_gqa and self.m_block_size % self.qhead_per_kvhead != 0) + self.use_scheduler_barrier = (self.num_mma_warp_groups >= 2 and self.tile_hdim <= 128) if const_expr(self.intra_wg_overlap) else (self.num_mma_warp_groups == 2) + self.use_tma_Q = self.arch >= 90 and not (self.pack_gqa and self.tile_m % self.qhead_per_kvhead != 0) self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None and not self.pack_gqa # TODO: rescale_O_before_gemm self._setup_attributes() @@ -1174,16 +1174,16 @@ def __call__( self.sQ_layout, self.sK_layout, self.sV_layout, self.sO_layout = [ sm90_utils.make_smem_layout(mX.element_type, LayoutEnum.ROW_MAJOR, shape, stage) for mX, shape, stage in [ - (mQ, (self.m_block_size, self.head_dim_padded), None), - (mK, (self.n_block_size, self.head_dim_padded), self.num_stages), - (mV, (self.n_block_size, self.head_dim_v_padded), self.num_stages), - (mO, (self.m_block_size, self.head_dim_v_padded), None), + (mQ, (self.tile_m, self.tile_hdim), None), + (mK, (self.tile_n, self.tile_hdim), self.num_stages), + (mV, (self.tile_n, self.tile_hdimv), self.num_stages), + (mO, (self.tile_m, self.tile_hdimv), None), ] ] self.sP_layout = None if const_expr(not self.mma_pv_is_rs): self.sP_layout = sm90_utils.make_smem_layout( - mV.dtype, LayoutEnum.ROW_MAJOR, (self.m_block_size, self.n_block_size) + mV.dtype, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_n) ) SharedStorage = self._get_shared_storage_cls() @@ -1210,40 +1210,40 @@ def __call__( tma_atom_Q, tma_tensor_Q = None, None if const_expr(self.use_tma_Q): tma_atom_Q, tma_tensor_Q = cpasync.make_tiled_tma_atom( - gmem_tiled_copy_Q, mQ, self.sQ_layout, (self.m_block_size, self.head_dim_padded), # No mcast + gmem_tiled_copy_Q, mQ, self.sQ_layout, (self.tile_m, self.tile_hdim), # No mcast ) tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom( gmem_tiled_copy_KV, mK, cute.select(self.sK_layout, mode=[0, 1]), - (self.n_block_size, self.head_dim_padded), + (self.tile_n, self.tile_hdim), 1 # No mcast for now ) tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom( gmem_tiled_copy_KV, mV, cute.select(self.sV_layout, mode=[0, 1]), - (self.n_block_size, self.head_dim_v_padded), + (self.tile_n, self.tile_hdimv), 1 # No mcast for now ) tma_atom_O, tma_tensor_O = None, None if const_expr(self.use_tma_O): tma_atom_O, tma_tensor_O = cpasync.make_tiled_tma_atom( - gmem_tiled_copy_O, mO, self.sO_layout, (self.m_block_size, self.head_dim_v_padded), # No mcast + gmem_tiled_copy_O, mO, self.sO_layout, (self.tile_m, self.tile_hdimv), # No mcast ) if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): TileScheduler = SingleTileVarlenScheduler else: TileScheduler = SingleTileScheduler if const_expr(not self.is_causal or self.is_local) else SingleTileLPTScheduler tile_sched_args = TileSchedulerArguments( - cute.ceil_div(cute.size(mQ.shape[0]), self.m_block_size), + cute.ceil_div(cute.size(mQ.shape[0]), self.tile_m), cute.size(mQ.shape[2]), cute.size(mQ.shape[3]) if const_expr(mCuSeqlensQ is None) else cute.size(mCuSeqlensQ.shape[0] - 1), cute.size(mK.shape[0]), mQ.shape[1], mV.shape[1], total_q=cute.size(mQ.shape[0]) if const_expr(mCuSeqlensQ is not None) else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), - tile_shape_mn=(self.m_block_size, self.n_block_size), + tile_shape_mn=(self.tile_m, self.tile_n), mCuSeqlensQ=mCuSeqlensQ, mSeqUsedQ=mSeqUsedQ, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, @@ -1408,7 +1408,7 @@ def kernel( sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) else: sV = storage.sQ.get_tensor(sV_layout.outer, swizzle=sV_layout.inner, dtype=mV.element_type) - # Transpose view of V to tensor with layout (head_dim_v, n_block_size) for tiled mma + # Transpose view of V to tensor with layout (head_dim_v, tile_n) for tiled mma sVt = utils.transpose_view(sV) sP = None if const_expr(sP_layout is not None): @@ -1417,7 +1417,7 @@ def kernel( sO = storage.sQ.get_tensor(sO_layout.outer, swizzle=sO_layout.inner, dtype=self.dtype) block_info = BlockInfo( - self.m_block_size, self.n_block_size, self.is_causal, self.is_local, + self.tile_m, self.tile_n, self.is_causal, self.is_local, window_size_left, window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) @@ -1428,7 +1428,7 @@ def kernel( mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK, ) AttentionMaskCls = partial( - AttentionMask, self.m_block_size, self.n_block_size, + AttentionMask, self.tile_m, self.tile_n, window_size_left=window_size_left, window_size_right=window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) @@ -1522,23 +1522,14 @@ def load( # if work_tile.is_valid_tile: m_block, head_idx, batch_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) - if const_expr(not seqlen.has_cu_seqlens_q): - mQ_cur = mQ[None, None, head_idx, batch_idx] - else: - offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) - mQ_cur = cute.domain_offset((offset, 0), mQ[None, None, head_idx]) - # mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] + mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] head_idx_kv = head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx - if const_expr(not seqlen.has_cu_seqlens_k): - mK_cur, mV_cur = [t[None, None, head_idx_kv, batch_idx] for t in (mK, mV)] - else: - mK_cur, mV_cur = [cute.domain_offset((seqlen.offset_k, 0), t[None, None, head_idx_kv]) for t in (mK, mV)] - # mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv] - # mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[None, None, head_idx_kv] - gK = cute.local_tile(mK_cur, (self.n_block_size, self.head_dim_padded), (None, 0)) - gV = cute.local_tile(mV_cur, (self.n_block_size, self.head_dim_v_padded), (None, 0)) + mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv] + mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[None, None, head_idx_kv] + gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (None, 0)) + gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (None, 0)) if const_expr(self.use_tma_Q): - gQ = cute.local_tile(mQ_cur, (self.m_block_size, self.head_dim_padded), (m_block, 0)) + gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0)) tQsQ, tQgQ = cpasync.tma_partition( tma_atom_Q, 0, @@ -1618,7 +1609,7 @@ def mma( tSrQ = tiled_mma_qk.make_fragment_A(wg_mma_qk.partition_A(sQ)) tSrK = tiled_mma_qk.make_fragment_B(wg_mma_qk.partition_B(sK)) if const_expr(self.mma_pv_is_rs): - acc_S_shape = tiled_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)) + acc_S_shape = tiled_mma_qk.partition_shape_C((self.tile_m, self.tile_n)) tOrP = cute.make_fragment( utils.convert_layout_acc_frgA(cute.make_layout(acc_S_shape)), self.dtype ) @@ -1640,17 +1631,16 @@ def mma( self.mma_init() - acc_shape_O = tiled_mma_pv.partition_shape_C((self.m_block_size, self.head_dim_v_padded)) + acc_shape_O = tiled_mma_pv.partition_shape_C((self.tile_m, self.tile_hdimv)) acc_O = cute.make_fragment(acc_shape_O, Float32) smem_copy_params = SimpleNamespace(smem_thr_copy_P=smem_thr_copy_P, tPsP=tPsP) - mma_qk_fn = partial(mma_qk, tiled_mma_qk, (self.m_block_size, self.n_block_size), tSrQ, tSrK) + mma_qk_fn = partial(mma_qk, tiled_mma_qk, (self.tile_m, self.tile_n), tSrQ, tSrK) mma_pv_fn = partial(mma_pv, tiled_mma_pv, acc_O, tOrP, tOrVt) mma_one_n_block_all = partial( self.mma_one_n_block_intrawg_overlap if const_expr(self.intra_wg_overlap) else self.mma_one_n_block, mma_qk_fn=mma_qk_fn, - mma_pv_fn=mma_pv_fn, tiled_mma_pv_rs=tiled_mma_pv_rs, pipeline_k=pipeline_k, pipeline_v=pipeline_v, acc_O=acc_O, tOrP=tOrP, @@ -1665,11 +1655,11 @@ def mma( tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() + softmax = Softmax.create(softmax_scale_log2, num_rows=acc_O.shape[0][0] * acc_O.shape[1], softmax_scale=softmax_scale) while work_tile.is_valid_tile: # if work_tile.is_valid_tile: # shape: (atom_v_m * rest_m) - softmax = Softmax.create(softmax_scale_log2, num_rows=acc_O.shape[0][0] * acc_O.shape[1], softmax_scale=softmax_scale) m_block, head_idx, batch_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) @@ -1682,23 +1672,17 @@ def mma( score_mod_fn = partial( self.apply_score_mod, thr_mma_qk, batch_idx, head_idx, m_block, - softmax_scale=softmax.softmax_scale, buffers=buffers, fastdiv_mods=fastdiv_mods, + softmax_scale=softmax_scale, buffers=buffers, fastdiv_mods=fastdiv_mods, ) mma_one_n_block = partial( mma_one_n_block_all, softmax=softmax, score_mod_fn=score_mod_fn ) - softmax.reset() # Load Q if not TMA_Q if const_expr(not self.use_tma_Q): - pack_gqa = PackGQA(self.m_block_size, self.head_dim_padded, self.check_hdim_oob, self.qhead_per_kvhead) - if const_expr(not seqlen.has_cu_seqlens_q): - mQ_cur = mQ[None, None, head_idx, batch_idx] - else: - offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) - mQ_cur = cute.domain_offset((offset, 0), mQ[None, None, head_idx]) - # mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] + pack_gqa = PackGQA(self.tile_m, self.tile_hdim, self.check_hdim_oob, self.qhead_per_kvhead) + mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] # gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx) - # gQ = cute.local_tile(mQ_cur, (self.m_block_size, self.head_dim_padded), (m_block, 0)) + # gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0)) # self.load_Q(gmem_thr_copy_Q, gQ, sQ, m_block, seqlen=seqlen.seqlen_q, # headdim=mQ.shape[1]) pack_gqa.load_Q(mQ_cur, sQ, gmem_tiled_copy_Q, tidx, m_block, seqlen.seqlen_q) @@ -1709,8 +1693,9 @@ def mma( q_consumer_phase ^= 1 # For performance reason, we separate out two kinds of iterations: # those that need masking on S, and those that don't. - # We need masking on S for the very last block when K and V has length not multiple of n_block_size. + # We need masking on S for the very last block when K and V has length not multiple of tile_n. # We also need masking on S if it's causal, for the last several blocks. + # softmax.reset() # Don't need reset as we explicitly call softmax w is_first=True O_should_accumulate = False # First iteration with seqlen masking if const_expr(self.intra_wg_overlap): @@ -1740,9 +1725,11 @@ def mma( else: self.warp_scheduler_barrier_sync() kv_consumer_state = mma_one_n_block( - n_block_max - 1, kv_consumer_state, - is_first_n_block=True, mask_fn=partial(mask_fn, mask_seqlen=True), - O_should_accumulate=False + kv_consumer_state, + n_block=n_block_max - 1, + mma_pv_fn=partial(mma_pv_fn, zero_init=True), + is_first_n_block=True, + mask_fn=partial(mask_fn, mask_seqlen=True), ) O_should_accumulate = True # if cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block_max = {}, n_block_min = {}", m_block, n_block_max, n_block_min) @@ -1754,10 +1741,11 @@ def mma( ) # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_causal_local_mask = {}", n_block_min_causal_local_mask) for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1): - n_block = n_block_max - 1 - n_tile kv_consumer_state = mma_one_n_block( - n_block, kv_consumer_state, mask_fn=partial(mask_fn, mask_seqlen=False), - O_should_accumulate=O_should_accumulate + kv_consumer_state, + n_block=n_block_max - 1 - n_tile, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_seqlen=False), ) O_should_accumulate = True n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) @@ -1767,18 +1755,21 @@ def mma( ) # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_before_local_mask = {}, n_block_min = {}", n_block_min_before_local_mask, n_block_min) for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): - n_block = n_block_max - 1 - n_tile - kv_consumer_state = mma_one_n_block(n_block, kv_consumer_state, check_inf=True, O_should_accumulate=O_should_accumulate) + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=n_block_max - 1 - n_tile, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + ) O_should_accumulate = True # Separate iterations with local masking on the left if const_expr(self.is_local and block_info.window_size_left is not None): n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) for n_tile in cutlass.range(n_block_max - n_block_min, unroll=1): - n_block = n_block_max - 1 - n_tile kv_consumer_state = mma_one_n_block( - n_block, kv_consumer_state, - check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=False), - O_should_accumulate=O_should_accumulate + kv_consumer_state, + n_block=n_block_max - 1 - n_tile, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_seqlen=False), ) O_should_accumulate = True # Last "half" iteration @@ -1796,10 +1787,10 @@ def mma( sink_val = Float32(learnable_sink[head_idx]) else: # Each thread might have a different sink value due to different q_head sink_val = cute.make_fragment_like(softmax.row_max, Float32) - cS = cute.make_identity_tensor((self.m_block_size, self.n_block_size)) + cS = cute.make_identity_tensor((self.tile_m, self.tile_n)) tScS_mn = utils.make_acc_tensor_mn_view(thr_mma_qk.partition_C(cS)) for r in cutlass.range(cute.size(sink_val), unroll_full=True): - row = m_block * self.m_block_size + tScS_mn[r][0] + row = m_block * self.tile_m + tScS_mn[r][0] q_head_idx = row % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead sink_val[r] = Float32(learnable_sink[q_head_idx]) @@ -1821,8 +1812,8 @@ def mma( @cute.jit def mma_one_n_block( self, - n_block: Int32, smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, + n_block: Int32, mma_qk_fn: Callable, mma_pv_fn: Callable, tiled_mma_pv_rs: cute.TiledMma, @@ -1836,7 +1827,6 @@ def mma_one_n_block( mask_fn: Optional[Callable] = None, is_first_n_block: cutlass.Constexpr = False, check_inf: cutlass.Constexpr = True, - O_should_accumulate: Boolean = True, ): pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read)) acc_S = mma_qk_fn(smem_pipe_read.index, wg_wait=-1) @@ -1868,7 +1858,7 @@ def mma_one_n_block( cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read)) self.warp_scheduler_barrier_sync() - mma_pv_fn(smem_pipe_read.index, zero_init=not O_should_accumulate, wg_wait=0) + mma_pv_fn(smem_pipe_read.index, wg_wait=0) pipeline_v.consumer_release(smem_pipe_read) smem_pipe_read.advance() return smem_pipe_read @@ -1876,8 +1866,8 @@ def mma_one_n_block( @cute.jit def mma_one_n_block_intrawg_overlap( self, - n_block: Int32, smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, + n_block: Int32, mma_qk_fn: Callable, mma_pv_fn: Callable, tiled_mma_pv_rs: cute.TiledMma, @@ -1890,7 +1880,6 @@ def mma_one_n_block_intrawg_overlap( score_mod_fn: Optional[Callable] = None, mask_fn: Optional[Callable] = None, check_inf: cutlass.Constexpr = True, - O_should_accumulate: Boolean = True, ): smem_pipe_read_v = smem_pipe_read.clone() smem_pipe_read.advance() @@ -1898,7 +1887,7 @@ def mma_one_n_block_intrawg_overlap( self.warp_scheduler_barrier_sync() acc_S = mma_qk_fn(smem_pipe_read.index, wg_wait=-1) pipeline_v.consumer_wait(smem_pipe_read_v, pipeline_v.consumer_try_wait(smem_pipe_read_v)) - mma_pv_fn(smem_pipe_read_v.index, zero_init=not O_should_accumulate, wg_wait=-1) + mma_pv_fn(smem_pipe_read_v.index, wg_wait=-1) self.warp_scheduler_barrier_arrive() warpgroup.wait_group(1) pipeline_k.consumer_release(smem_pipe_read) @@ -1954,8 +1943,8 @@ def apply_score_mod( fastdiv_mods=None, ): # Prepare index tensor - cS = cute.make_identity_tensor((self.m_block_size, self.n_block_size)) - cS = cute.domain_offset((m_block * self.m_block_size, n_block * self.n_block_size), cS) + cS = cute.make_identity_tensor((self.tile_m, self.tile_n)) + cS = cute.domain_offset((m_block * self.tile_m, n_block * self.tile_n), cS) tScS = thr_mma_qk.partition_C(cS) apply_score_mod_inner( diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 3e5a31311ac..b13589c5670 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -240,8 +240,8 @@ def _flash_attn_fwd( is_causal=causal, is_local=local, pack_gqa=pack_gqa, - m_block_size=m_block_size, - n_block_size=n_block_size, + tile_m=m_block_size, + tile_n=n_block_size, # num_stages=1, num_stages=2, num_threads=num_threads, diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index 0f99add2cce..bacb69e9f00 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -11,8 +11,8 @@ @dataclass(frozen=True) class AttentionMask: - m_block_size: cutlass.Constexpr[int] - n_block_size: cutlass.Constexpr[int] + tile_m: cutlass.Constexpr[int] + tile_n: cutlass.Constexpr[int] seqlen_q: cutlass.Int32 seqlen_k: cutlass.Int32 window_size_left: Optional[cutlass.Int32] = None @@ -32,13 +32,13 @@ def apply_mask( ) -> None: assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" acc_S_mn = utils.make_acc_tensor_mn_view(acc_S) - cS = cute.make_identity_tensor((self.m_block_size, self.n_block_size)) + cS = cute.make_identity_tensor((self.tile_m, self.tile_n)) tScS_mn = utils.make_acc_tensor_mn_view(thr_mma.partition_C(cS)) # We use t0ScS as these indices are known at compile time. We then must subtract the # column limit by the thread column offset. t0ScS_mn = utils.make_acc_tensor_mn_view(thr_mma.get_slice(0).partition_C(cS)) thr_col_offset = tScS_mn[0][1] - seqlenk_col_limit = self.seqlen_k - n_block * self.n_block_size - thr_col_offset + seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset if cutlass.const_expr(not mask_causal and not mask_local): if cutlass.const_expr(mask_seqlen): if cutlass.const_expr(False): @@ -71,10 +71,10 @@ def apply_mask( assert cute.size(acc_S_mn.shape[0]) <= threads_per_row tidx = thr_mma.thr_idx mma_m_idx = ( - m_block * self.m_block_size + tScS_mn[tidx % threads_per_row, 0][0] + m_block * self.tile_m + tScS_mn[tidx % threads_per_row, 0][0] ) // self.qhead_per_kvhead_packgqa causal_row_offset = ( - 1 + self.seqlen_k - n_block * self.n_block_size - self.seqlen_q - thr_col_offset + 1 + self.seqlen_k - n_block * self.tile_n - self.seqlen_q - thr_col_offset ) c = 0 col_limit_transformed = 0 @@ -86,7 +86,7 @@ def apply_mask( for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): # get the column index limit based on current row. Only consider the row index, so the column index sets to 0. if cutlass.const_expr(self.qhead_per_kvhead_packgqa == 1): - row_idx = tScS_mn[r, 0][0] + m_block * self.m_block_size + row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m else: row_idx = utils.shuffle_sync( mma_m_idx, r % threads_per_row, width=threads_per_row @@ -122,7 +122,7 @@ def apply_mask( c = 0 for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): if cutlass.const_expr(self.qhead_per_kvhead_packgqa == 1): - row_idx = tScS_mn[r, 0][0] + m_block * self.m_block_size + row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m else: row_idx = utils.shuffle_sync( mma_m_idx, r % threads_per_row, width=threads_per_row @@ -132,7 +132,7 @@ def apply_mask( if cutlass.const_expr(mask_seqlen): col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) else: - col_limit_right = self.n_block_size + col_limit_right = self.tile_n col_limit_left = ( row_idx + local_row_offset_left if cutlass.const_expr(self.window_size_left is not None) else 0 ) @@ -158,10 +158,10 @@ def apply_mask_sm100( mask_local: cutlass.Constexpr, ) -> None: assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" - cS = cute.make_identity_tensor((self.m_block_size, self.n_block_size)) + cS = cute.make_identity_tensor((self.tile_m, self.tile_n)) tScS = thr_mma.partition_C(cS) tScS_t2r = thr_tmem_load.partition_D(tScS) - seqlenk_col_limit = self.seqlen_k - n_block * self.n_block_size + seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n if cutlass.const_expr(not mask_causal and not mask_local): if cutlass.const_expr(mask_seqlen): ncol = cutlass.const_expr(cute.size(tScS_t2r.shape)) @@ -197,8 +197,8 @@ def apply_mask_sm100( # acc_S[s * 24 + i] = acc_S[s * 24 + i] if col_limit_right_s <= i else -cutlass.Float32.inf # if tidx == 0: cute.print_tensor(acc_S) else: # Causal or local - causal_row_offset = 1 + self.seqlen_k - n_block * self.n_block_size - self.seqlen_q - row_idx = tScS_t2r[0][0] + m_block * self.m_block_size + causal_row_offset = 1 + self.seqlen_k - n_block * self.tile_n - self.seqlen_q + row_idx = tScS_t2r[0][0] + m_block * self.tile_m if cutlass.const_expr(self.qhead_per_kvhead_packgqa != 1): row_idx = row_idx // self.qhead_per_kvhead_packgqa c = 0 @@ -243,7 +243,7 @@ def apply_mask_sm100( if cutlass.const_expr(mask_seqlen): col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) else: - col_limit_right = self.n_block_size + col_limit_right = self.tile_n col_limit_left = ( row_idx + local_row_offset_left if cutlass.const_expr(self.window_size_left is not None) else 0 ) From 9eee0898c1feb8a959b707cf61d0f1729c977ea0 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 12 Oct 2025 10:12:12 -0400 Subject: [PATCH 119/258] [Cute,Bwd,Sm90] Format file w ruff --- flash_attn/cute/flash_bwd_sm90.py | 1037 +++++++++++++++-------------- 1 file changed, 539 insertions(+), 498 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index 8163fb3663c..fc6b6c7a414 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -7,7 +7,8 @@ import cutlass import cutlass.cute as cute from cutlass.cute.nvgpu import cpasync, warpgroup -#import cutlass.pipeline + +# import cutlass.pipeline import cutlass.utils.hopper_helpers as sm90_utils_basic from cutlass import const_expr @@ -19,6 +20,7 @@ from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, ParamsBase from flash_attn.cute.named_barrier import NamedBarrierFwd, NamedBarrierBwd + class FlashAttentionBackwardSm90: arch = 90 @@ -34,7 +36,6 @@ def __init__( num_threads: int = 384, Q_in_regs: bool = False, ): - self.dtype = dtype # padding head_dim to a multiple of 16 as k_block_size hdim_multiple_of = 16 @@ -54,10 +55,15 @@ def __init__( @staticmethod def can_implement( - dtype, head_dim, head_dim_v, m_block_size, n_block_size, num_stages, num_threads, - Q_in_regs=False + dtype, + head_dim, + head_dim_v, + m_block_size, + n_block_size, + num_stages, + num_threads, + Q_in_regs=False, ) -> bool: - if dtype not in [cutlass.Float16, cutlass.BFloat16]: return False if head_dim % 8 != 0: @@ -107,44 +113,37 @@ def _check_type( def _get_smem_layout_atom(self): sQ_layout_atom = warpgroup.make_smem_layout_atom( sm90_utils_basic.get_smem_layout_atom( - cutlass.utils.LayoutEnum.ROW_MAJOR, - self.dtype, - self.head_dim_padded + cutlass.utils.LayoutEnum.ROW_MAJOR, self.dtype, self.head_dim_padded ), - self.dtype + self.dtype, ) sK_layout_atom = sQ_layout_atom sV_layout_atom = warpgroup.make_smem_layout_atom( sm90_utils_basic.get_smem_layout_atom( - cutlass.utils.LayoutEnum.ROW_MAJOR, - self.dtype, - self.head_dim_v_padded + cutlass.utils.LayoutEnum.ROW_MAJOR, self.dtype, self.head_dim_v_padded ), - self.dtype + self.dtype, ) sPdS_layout_atom = warpgroup.make_smem_layout_atom( sm90_utils_basic.get_smem_layout_atom( - cutlass.utils.LayoutEnum.ROW_MAJOR, - self.dtype, - self.n_block_size + cutlass.utils.LayoutEnum.ROW_MAJOR, self.dtype, self.n_block_size ), - self.dtype + self.dtype, ) sdO_layout_atom = warpgroup.make_smem_layout_atom( sm90_utils_basic.get_smem_layout_atom( - cutlass.utils.LayoutEnum.ROW_MAJOR, - self.dtype, - self.head_dim_padded + cutlass.utils.LayoutEnum.ROW_MAJOR, self.dtype, self.head_dim_padded ), - self.dtype + self.dtype, ) return sQ_layout_atom, sK_layout_atom, sV_layout_atom, sPdS_layout_atom, sdO_layout_atom - def _setup_attributes(self): - sQ_layout_atom, sK_layout_atom, sV_layout_atom, sPdS_layout_atom, sdO_layout_atom = self._get_smem_layout_atom() + sQ_layout_atom, sK_layout_atom, sV_layout_atom, sPdS_layout_atom, sdO_layout_atom = ( + self._get_smem_layout_atom() + ) universal_copy_bits = 128 async_copy_elems = universal_copy_bits // self.dtype.width @@ -155,20 +154,43 @@ def _setup_attributes(self): num_bits_per_copy=universal_copy_bits, ) - self.sQ_layout = cute.tile_to_shape(sQ_layout_atom, (self.m_block_size, self.head_dim_padded, self.num_stages), (0, 1, 2),) - self.sK_layout = cute.tile_to_shape(sK_layout_atom, (self.n_block_size, self.head_dim_padded), (0, 1),) - self.sV_layout = cute.tile_to_shape(sV_layout_atom, (self.n_block_size, self.head_dim_v_padded), (0, 1),) - self.sdO_layout = cute.tile_to_shape(sdO_layout_atom, (self.m_block_size, self.head_dim_padded, self.num_stages), (0, 1, 2),) - - self.sPdS_layout = cute.tile_to_shape(sPdS_layout_atom, (self.m_block_size, self.n_block_size), (0, 1),) - self.sdQaccum_layout = cute.make_layout(shape=(self.m_block_size * self.head_dim_padded, ),) + self.sQ_layout = cute.tile_to_shape( + sQ_layout_atom, + (self.m_block_size, self.head_dim_padded, self.num_stages), + (0, 1, 2), + ) + self.sK_layout = cute.tile_to_shape( + sK_layout_atom, + (self.n_block_size, self.head_dim_padded), + (0, 1), + ) + self.sV_layout = cute.tile_to_shape( + sV_layout_atom, + (self.n_block_size, self.head_dim_v_padded), + (0, 1), + ) + self.sdO_layout = cute.tile_to_shape( + sdO_layout_atom, + (self.m_block_size, self.head_dim_padded, self.num_stages), + (0, 1, 2), + ) + self.sPdS_layout = cute.tile_to_shape( + sPdS_layout_atom, + (self.m_block_size, self.n_block_size), + (0, 1), + ) + self.sdQaccum_layout = cute.make_layout( + shape=(self.m_block_size * self.head_dim_padded,), + ) # dQaccum R->S self.r2s_tiled_copy_dQaccum = cute.make_tiled_copy_tv( - cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=universal_copy_bits), - cute.make_layout(self.num_mma_threads), - cute.make_layout(universal_copy_bits // cutlass.Float32.width) + cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=universal_copy_bits + ), + cute.make_layout(self.num_mma_threads), + cute.make_layout(universal_copy_bits // cutlass.Float32.width), ) # dV: S->G @@ -178,9 +200,7 @@ def _setup_attributes(self): order=(1, 0), ) self.gmem_tiled_copy_dV = cute.make_tiled_copy_tv( - atom_universal_copy, - tdV_layout, - cute.make_layout((1, async_copy_elems)) + atom_universal_copy, tdV_layout, cute.make_layout((1, async_copy_elems)) ) # dK: S->G @@ -190,13 +210,10 @@ def _setup_attributes(self): order=(1, 0), ) self.gmem_tiled_copy_dK = cute.make_tiled_copy_tv( - atom_universal_copy, - tdK_layout, - cute.make_layout((1, async_copy_elems)) + atom_universal_copy, tdK_layout, cute.make_layout((1, async_copy_elems)) ) def _get_tiled_mma(self): - # C = A @ B.T tiled_mma_SdP = sm90_utils_basic.make_trivial_tiled_mma( self.dtype, @@ -214,7 +231,7 @@ def _get_tiled_mma(self): warpgroup.OperandMajorMode.MN, warpgroup.OperandMajorMode.MN, cutlass.Float32, - atom_layout_mnk=(self.n_block_size // 64 , 1, 1), + atom_layout_mnk=(self.n_block_size // 64, 1, 1), tiler_mn=(64, self.head_dim_padded), ) # C = A @ B @@ -230,104 +247,102 @@ def _get_tiled_mma(self): return tiled_mma_SdP, tiled_mma_dKV, tiled_mma_dQaccum - def _get_shared_storage_cls(self): sQ_alignment = sK_alignment = sV_alighment = sdQaccum_alignment = sdO_alignment = 128 sQ_struct, sK_struct, sV_struct, sdO_struct, sdQaccum_struct = [ cute.struct.Align[cute.struct.MemRange[type, cute.cosize(layout)], alignment] for (layout, type, alignment) in [ - (self.sQ_layout, self.dtype, sQ_alignment), - (self.sK_layout, self.dtype, sK_alignment), - (self.sV_layout, self.dtype, sV_alighment), - (self.sdO_layout, self.dtype, sdO_alignment), - (self.sdQaccum_layout, cutlass.Float32, sdQaccum_alignment) + (self.sQ_layout, self.dtype, sQ_alignment), + (self.sK_layout, self.dtype, sK_alignment), + (self.sV_layout, self.dtype, sV_alighment), + (self.sdO_layout, self.dtype, sdO_alignment), + (self.sdQaccum_layout, cutlass.Float32, sdQaccum_alignment), ] ] - cosize_sPdS = cute.cosize(self.sPdS_layout) - sPdS_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sPdS], 1024] - sLSE_struct = cute.struct.Align[cute.struct.MemRange[cutlass.Float32, self.m_block_size * self.num_stages], 128] - sdPsum_struct = cute.struct.Align[cute.struct.MemRange[cutlass.Float32, self.m_block_size * self.num_stages], 128] + cosize_sPdS = cute.cosize(self.sPdS_layout) + sPdS_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sPdS], 1024] + sLSE_struct = cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, self.m_block_size * self.num_stages], 128 + ] + sdPsum_struct = cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, self.m_block_size * self.num_stages], 128 + ] - mbar_ptr_Q_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] - mbar_ptr_LSE_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] + mbar_ptr_Q_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] + mbar_ptr_LSE_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] mbar_ptr_dPsum_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] - mbar_ptr_dO_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] - - mbar_ptr_K_struct = cute.struct.MemRange[cutlass.Int64, 2] - mbar_ptr_V_struct = cute.struct.MemRange[cutlass.Int64, 2] + mbar_ptr_dO_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] + mbar_ptr_K_struct = cute.struct.MemRange[cutlass.Int64, 2] + mbar_ptr_V_struct = cute.struct.MemRange[cutlass.Int64, 2] @cute.struct class SharedStorageQKV: - mbar_ptr_Q: mbar_ptr_Q_struct - mbar_ptr_K: mbar_ptr_K_struct - mbar_ptr_V: mbar_ptr_V_struct - mbar_ptr_lse: mbar_ptr_LSE_struct + mbar_ptr_Q: mbar_ptr_Q_struct + mbar_ptr_K: mbar_ptr_K_struct + mbar_ptr_V: mbar_ptr_V_struct + mbar_ptr_lse: mbar_ptr_LSE_struct mbar_ptr_dpsum: mbar_ptr_dPsum_struct - mbar_ptr_dO: mbar_ptr_dO_struct - - sQ: sQ_struct - sV: sV_struct - sK: sK_struct - sPdS: sPdS_struct - sLSE: sLSE_struct - sdPsum: sdPsum_struct - sdO: sdO_struct + mbar_ptr_dO: mbar_ptr_dO_struct + + sQ: sQ_struct + sV: sV_struct + sK: sK_struct + sPdS: sPdS_struct + sLSE: sLSE_struct + sdPsum: sdPsum_struct + sdO: sdO_struct sdQaccum: sdQaccum_struct return SharedStorageQKV @cute.jit - def __call__(self, + def __call__( + self, mQ: cute.Tensor, mK: cute.Tensor, mV: cute.Tensor, - - mdO: cute.Tensor, + mdO: cute.Tensor, mLSE: cute.Tensor, - - mdPsum: cute.Tensor, + mdPsum: cute.Tensor, mdQaccum: cute.Tensor, - mdK: cute.Tensor, - mdV: cute.Tensor, - + mdK: cute.Tensor, + mdV: cute.Tensor, softmax_scale: cutlass.Float32, - stream: cuda.CUstream, - + stream: cuda.CUstream, mCuSeqlensQ: Optional[cute.Tensor] = None, mCuSeqlensK: Optional[cute.Tensor] = None, - mSeqUsedQ: Optional[cute.Tensor] = None, - mSeqUsedK: Optional[cute.Tensor] = None, - - softcap: cutlass.Float32 | float | None = None, - window_size_left: cutlass.Int32 | int | None = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + softcap: cutlass.Float32 | float | None = None, + window_size_left: cutlass.Int32 | int | None = None, window_size_right: cutlass.Int32 | int | None = None, ): - self._check_type( - *(t.element_type if t is not None else None - for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)) + *( + t.element_type if t is not None else None + for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV) + ) ) - layout_transpose = [1, 3, 2, 0] # (b, s, n, h) --> (s, h, n, b) + layout_transpose = [1, 3, 2, 0] # (b, s, n, h) --> (s, h, n, b) mQ, mK, mV, mdK, mdV, mdO = [ cute.make_tensor(t.iterator, cute.select(t.layout, mode=layout_transpose)) for t in (mQ, mK, mV, mdK, mdV, mdO) ] - LSE_dPsum_dQaccum_transpose = [2, 1, 0] # (b, n, s) -> (s, n, b) + LSE_dPsum_dQaccum_transpose = [2, 1, 0] # (b, n, s) -> (s, n, b) mLSE, mdPsum, mdQaccum = [ cute.make_tensor(t.iterator, cute.select(t.layout, mode=LSE_dPsum_dQaccum_transpose)) for t in (mLSE, mdPsum, mdQaccum) ] - tiled_mma_SdP, tiled_mma_dKV, tiled_mma_dQaccum = self._get_tiled_mma() - self.tiled_mma_SdP = tiled_mma_SdP - self.tiled_mma_dKV = tiled_mma_dKV + self.tiled_mma_SdP = tiled_mma_SdP + self.tiled_mma_dKV = tiled_mma_dKV self.tiled_mma_sdQaccum = tiled_mma_dQaccum self.num_mma_threads = tiled_mma_SdP.size @@ -342,15 +357,21 @@ def __call__(self, self._setup_attributes() SharedStorage = self._get_shared_storage_cls() + self.tma_copy_q_bytes = cute.size_in_bytes( + mQ.element_type, cute.select(self.sQ_layout, mode=[0, 1]) + ) + self.tma_copy_k_bytes = cute.size_in_bytes( + mK.element_type, cute.select(self.sK_layout, mode=[0, 1]) + ) + self.tma_copy_v_bytes = cute.size_in_bytes( + mV.element_type, cute.select(self.sK_layout, mode=[0, 1]) + ) - self.tma_copy_q_bytes = cute.size_in_bytes(mQ.element_type, cute.select(self.sQ_layout, mode=[0, 1])) - self.tma_copy_k_bytes = cute.size_in_bytes(mK.element_type, cute.select(self.sK_layout, mode=[0, 1])) - self.tma_copy_v_bytes = cute.size_in_bytes(mV.element_type, cute.select(self.sK_layout, mode=[0, 1])) - - self.tma_copy_do_bytes = cute.size_in_bytes(mdO.element_type, cute.select(self.sdO_layout, mode=[0,1])) - self.tma_copy_lse_bytes = self.m_block_size * 4 - self.tma_copy_dPsum_bytes = self.m_block_size * 4 - + self.tma_copy_do_bytes = cute.size_in_bytes( + mdO.element_type, cute.select(self.sdO_layout, mode=[0, 1]) + ) + self.tma_copy_lse_bytes = self.m_block_size * 4 + self.tma_copy_dPsum_bytes = self.m_block_size * 4 tma_atom_Q, tma_tensor_Q = cpasync.make_tiled_tma_atom( cpasync.CopyBulkTensorTileG2SOp(), @@ -363,30 +384,32 @@ def __call__(self, mK, cute.select(self.sK_layout, mode=[0, 1]), (self.n_block_size, self.head_dim_padded), - 1 + 1, ) tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom( cpasync.CopyBulkTensorTileG2SOp(), mV, - cute.select(self.sV_layout, mode=[0,1]), + cute.select(self.sV_layout, mode=[0, 1]), (self.n_block_size, self.head_dim_v_padded), - 1 + 1, ) tma_atom_dO, tma_tensor_dO = cpasync.make_tiled_tma_atom( cpasync.CopyBulkTensorTileG2SOp(), mdO, - cute.select(self.sdO_layout, mode=[0,1]), - (self.m_block_size, self.head_dim_padded) + cute.select(self.sdO_layout, mode=[0, 1]), + (self.m_block_size, self.head_dim_padded), ) tma_atom_LSE, tma_tensor_LSE = cpasync.make_tiled_tma_atom( cpasync.CopyBulkTensorTileG2SOp(), mLSE, - cute.make_layout(self.m_block_size), (self.m_block_size,), + cute.make_layout(self.m_block_size), + (self.m_block_size,), ) tma_atom_dPsum, tma_tensor_dPsum = cpasync.make_tiled_tma_atom( cpasync.CopyBulkTensorTileG2SOp(), mdPsum, - cute.make_layout(self.m_block_size), (self.m_block_size, ), + cute.make_layout(self.m_block_size), + (self.m_block_size,), ) TileScheduler = SingleTileScheduler tile_sched_args = TileSchedulerArguments( @@ -400,7 +423,7 @@ def __call__(self, tile_shape_mn=(self.m_block_size, self.n_block_size), mCuSeqlensQ=None, mSeqUsedQ=None, - qhead_per_kvhead_packgqa= 1, + qhead_per_kvhead_packgqa=1, element_size=self.dtype.width // 8, is_persistent=False, lpt=False, @@ -419,33 +442,27 @@ def __call__(self, tma_tensor_LSE, tma_tensor_dPsum, tma_tensor_dO, - tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_LSE, tma_atom_dPsum, tma_atom_dO, - mdK, mdV, mdQaccum, - self.sQ_layout, self.sK_layout, self.sV_layout, self.sPdS_layout, self.sdO_layout, self.sdQaccum_layout, - self.gmem_tiled_copy_dV, self.gmem_tiled_copy_dK, self.r2s_tiled_copy_dQaccum, - tiled_mma_SdP, tiled_mma_dKV, tiled_mma_dQaccum, - softmax_scale_log2, softmax_scale, tile_sched_params, @@ -462,47 +479,41 @@ def __call__(self, @cute.kernel def kernel( self, - mQ: cute.Tensor, - mK: cute.Tensor, - mV: cute.Tensor, - mLSE: cute.Tensor, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mLSE: cute.Tensor, mdPsum: cute.Tensor, - mdO: cute.Tensor, - - tma_atom_Q: Optional[cute.CopyAtom], - tma_atom_K: Optional[cute.CopyAtom], - tma_atom_V: Optional[cute.CopyAtom], - tma_atom_LSE: Optional[cute.CopyAtom], + mdO: cute.Tensor, + tma_atom_Q: Optional[cute.CopyAtom], + tma_atom_K: Optional[cute.CopyAtom], + tma_atom_V: Optional[cute.CopyAtom], + tma_atom_LSE: Optional[cute.CopyAtom], tma_atom_dPsum: Optional[cute.CopyAtom], - tma_atom_dO: Optional[cute.CopyAtom], - - mdK: cute.Tensor, - mdV: cute.Tensor, + tma_atom_dO: Optional[cute.CopyAtom], + mdK: cute.Tensor, + mdV: cute.Tensor, mdQaccum: cute.Tensor, - - sQ_layout: cute.ComposedLayout, - sK_layout: cute.ComposedLayout, - sV_layout: cute.ComposedLayout, - sPdS_layout: cute.ComposedLayout, - sdO_layout: cute.ComposedLayout, + sQ_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sPdS_layout: cute.ComposedLayout, + sdO_layout: cute.ComposedLayout, sdQaccum_layout: cute.Layout, - - gmem_tiled_copy_dV: cute.TiledCopy, - gmem_tiled_copy_dK: cute.TiledCopy, - r2s_tiled_copy_dQaccum: cute.TiledCopy, - - tiled_mma_SdP: cute.TiledMma, - tiled_mma_dKV: cute.TiledMma, + gmem_tiled_copy_dV: cute.TiledCopy, + gmem_tiled_copy_dK: cute.TiledCopy, + r2s_tiled_copy_dQaccum: cute.TiledCopy, + tiled_mma_SdP: cute.TiledMma, + tiled_mma_dKV: cute.TiledMma, tiled_mma_dQaccum: cute.TiledMma, - softmax_scale_log2, softmax_scale, tile_sched_params: ParamsBase, - TileScheduler: cutlass.Constexpr[Callable], - SharedStorage: cutlass.Constexpr[Callable], + TileScheduler: cutlass.Constexpr[Callable], + SharedStorage: cutlass.Constexpr[Callable], ): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - tidx = cute.arch.thread_idx()[0] + tidx = cute.arch.thread_idx()[0] # prefetch TMA descriptors if warp_idx == 0: @@ -513,7 +524,6 @@ def kernel( cpasync.prefetch_descriptor(tma_atom_dPsum) cpasync.prefetch_descriptor(tma_atom_dO) - smem = cutlass.utils.SmemAllocator() storage = smem.allocate(SharedStorage) @@ -526,7 +536,9 @@ def kernel( cute.arch.mbarrier_init(mbar_ptr_V, 1) pipeline_producer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread) - pipeline_consumer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, self.num_mma_threads // self.num_threads_per_warp_group) + pipeline_consumer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, self.num_mma_threads // self.num_threads_per_warp_group + ) pipeline_q = pipeline.PipelineTmaAsyncNoCluster.create( barrier_storage=storage.mbar_ptr_Q.data_ptr(), @@ -560,32 +572,38 @@ def kernel( tx_count=self.tma_copy_do_bytes, init_wait=False, ) - sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) sQt = utils.transpose_view(sQ) - sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) - sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) - - sLSE_load = storage.sLSE.get_tensor(cute.make_layout( - (self.m_block_size, self.num_stages), - stride=(1, cute.round_up(self.m_block_size, 64)) - )) - sLSE_mma = storage.sLSE.get_tensor(cute.make_layout( - (self.m_block_size, self.n_block_size, self.num_stages), - stride=(1, 0, cute.round_up(self.m_block_size, 64)) - )) - sdPsum_load = storage.sdPsum.get_tensor(cute.make_layout( - (self.m_block_size, self.num_stages), - stride=(1, cute.round_up(self.m_block_size, 64)) - )) - sdPsum_mma = storage.sdPsum.get_tensor(cute.make_layout( - (self.m_block_size, self.n_block_size, self.num_stages), - stride=(1, 0, cute.round_up(self.m_block_size, 64)) - )) - - sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout) + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) + sLSE_load = storage.sLSE.get_tensor( + cute.make_layout( + (self.m_block_size, self.num_stages), + stride=(1, cute.round_up(self.m_block_size, 64)), + ) + ) + sLSE_mma = storage.sLSE.get_tensor( + cute.make_layout( + (self.m_block_size, self.n_block_size, self.num_stages), + stride=(1, 0, cute.round_up(self.m_block_size, 64)), + ) + ) + sdPsum_load = storage.sdPsum.get_tensor( + cute.make_layout( + (self.m_block_size, self.num_stages), + stride=(1, cute.round_up(self.m_block_size, 64)), + ) + ) + sdPsum_mma = storage.sdPsum.get_tensor( + cute.make_layout( + (self.m_block_size, self.n_block_size, self.num_stages), + stride=(1, 0, cute.round_up(self.m_block_size, 64)), + ) + ) + sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout) sP = storage.sPdS.get_tensor(sPdS_layout.outer, swizzle=sPdS_layout.inner) sPt = utils.transpose_view(sP) @@ -593,23 +611,33 @@ def kernel( sdS = storage.sPdS.get_tensor(sPdS_layout.outer, swizzle=sPdS_layout.inner) sdSt = utils.transpose_view(sdS) - sdO = storage.sdO.get_tensor(sdO_layout.outer, swizzle=sdO_layout.inner) + sdO = storage.sdO.get_tensor(sdO_layout.outer, swizzle=sdO_layout.inner) sdOt = utils.transpose_view(sdO) - - block_info = BlockInfo(self.m_block_size, self.n_block_size, False, False,None, None, qhead_per_kvhead_packgqa=1,) + block_info = BlockInfo( + self.m_block_size, + self.n_block_size, + False, + False, + None, + None, + qhead_per_kvhead_packgqa=1, + ) SeqlenInfoCls = partial( - SeqlenInfoQK, seqlen_q_static=mQ.shape[0], + SeqlenInfoQK, + seqlen_q_static=mQ.shape[0], seqlen_k_static=mK.shape[0], - mCuSeqlensQ=None, mCuSeqlensK=None, - mSeqUsedQ=None, mSeqUsedK=None + mCuSeqlensQ=None, + mCuSeqlensK=None, + mSeqUsedQ=None, + mSeqUsedK=None, ) TileSchedulerCls = partial(TileScheduler.create, tile_sched_params) - if warp_idx < 4: + if warp_idx < 4: cute.arch.warpgroup_reg_dealloc(self.num_producer_regs) - if warp_idx == 0: + if warp_idx == 0: self.load( mQ, mK, @@ -617,34 +645,32 @@ def kernel( mLSE, mdPsum, mdO, - sQ, sK, sV, sLSE_load, sdPsum_load, sdO, - tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_LSE, tma_atom_dPsum, tma_atom_dO, - pipeline_q, pipeline_lse, pipeline_dpsum, pipeline_do, - mbar_ptr_K, mbar_ptr_V, - SeqlenInfoCls, TileSchedulerCls, ) if warp_idx == 1: - cute.arch.barrier_arrive(barrier_id=int(NamedBarrierBwd.dQEmpty), number_of_threads=self.num_mma_threads + cute.arch.WARP_SIZE) + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierBwd.dQEmpty), + number_of_threads=self.num_mma_threads + cute.arch.WARP_SIZE, + ) self.dQaccum_writer( mdQaccum, sdQaccum, @@ -654,124 +680,110 @@ def kernel( else: cute.arch.warpgroup_reg_alloc(self.num_mma_regs) tidx, _, _ = cute.arch.thread_idx() - tidx = tidx - 128 + tidx = tidx - 128 self.mma( tiled_mma_SdP, tiled_mma_dKV, tiled_mma_dQaccum, - mdK, mdV, mdQaccum, - sQ, sQt, sK, sV, - sP, sPt, - sdS, sdSt, - sdO, sdOt, - sLSE_mma, sdPsum_mma, - sdQaccum, - pipeline_q, pipeline_lse, pipeline_dpsum, pipeline_do, - mbar_ptr_K, mbar_ptr_V, tidx, gmem_tiled_copy_dV, gmem_tiled_copy_dK, r2s_tiled_copy_dQaccum, - softmax_scale_log2, softmax_scale, - block_info, SeqlenInfoCls, TileSchedulerCls, ) - @cute.jit def load( self, - mQ: cute.Tensor, - mK: cute.Tensor, - mV: cute.Tensor, - mLSE: cute.Tensor, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mLSE: cute.Tensor, mdPsum: cute.Tensor, - mdO: cute.Tensor, - - sQ: cute.Tensor, - sK: cute.Tensor, - sV: cute.Tensor, - sLSE: cute.Tensor, + mdO: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + sLSE: cute.Tensor, sdPsum: cute.Tensor, - sdO: cute.Tensor, - + sdO: cute.Tensor, tma_atom_Q: cute.CopyAtom, tma_atom_K: cute.CopyAtom, tma_atom_V: cute.CopyAtom, - - tma_atom_LSE: cute.CopyAtom, + tma_atom_LSE: cute.CopyAtom, tma_atom_dPsum: cute.CopyAtom, - tma_atom_dO: cute.CopyAtom, - - pipeline_q: cutlass.pipeline.PipelineAsync, - pipeline_lse: cutlass.pipeline.PipelineAsync, + tma_atom_dO: cute.CopyAtom, + pipeline_q: cutlass.pipeline.PipelineAsync, + pipeline_lse: cutlass.pipeline.PipelineAsync, pipeline_dpsum: cutlass.pipeline.PipelineAsync, - pipeline_dO: cutlass.pipeline.PipelineAsync, - + pipeline_dO: cutlass.pipeline.PipelineAsync, mbar_ptr_K: cutlass.Pointer, mbar_ptr_V: cutlass.Pointer, - - SeqlenInfoCls: Callable, + SeqlenInfoCls: Callable, TileSchedulerCls: Callable, ): warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 if warp_idx_in_wg == 0: - producer_state = pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, self.num_stages) - + producer_state = pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.num_stages + ) tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() - while work_tile.is_valid_tile: n_block, head_idx, batch_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) mK_cur = mK[None, None, head_idx, batch_idx] - gK = cute.local_tile(mK_cur, (self.n_block_size, self.head_dim_padded), (n_block, 0)) + gK = cute.local_tile( + mK_cur, (self.n_block_size, self.head_dim_padded), (n_block, 0) + ) mV_cur = mV[None, None, head_idx, batch_idx] - gV = cute.local_tile(mV_cur, (self.n_block_size, self.head_dim_padded), (n_block, 0)) + gV = cute.local_tile( + mV_cur, (self.n_block_size, self.head_dim_padded), (n_block, 0) + ) mQ_cur = mQ[None, None, head_idx, batch_idx] - gQ = cute.local_tile(mQ_cur, (self.m_block_size, self.head_dim_padded), (None, 0)) + gQ = cute.local_tile(mQ_cur, (self.m_block_size, self.head_dim_padded), (None, 0)) mLSE_cur = mLSE[None, head_idx, batch_idx] - gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (None,)) + gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (None,)) mdPsum_cur = mdPsum[None, head_idx, batch_idx] - gdPsum = cute.local_tile(mdPsum_cur, (self.m_block_size,), (None,)) + gdPsum = cute.local_tile(mdPsum_cur, (self.m_block_size,), (None,)) mdO_cur = mdO[None, None, head_idx, batch_idx] - gdO = cute.local_tile(mdO_cur, (self.m_block_size, self.head_dim_padded), (None, 0)) + gdO = cute.local_tile(mdO_cur, (self.m_block_size, self.head_dim_padded), (None, 0)) tQsQ, tQgQ = cpasync.tma_partition( tma_atom_Q, @@ -816,10 +828,12 @@ def load( cute.group_modes(gdO, 0, 2), ) - load_Q = partial(self.load_m_tile, tma_atom_Q, tQgQ, tQsQ, pipeline_q) - load_LSE = partial(self.load_m_tile, tma_atom_LSE, tLSEgLSE, tLSEsLSE, pipeline_lse) - load_dPsum = partial(self.load_m_tile, tma_atom_dPsum, tdPsumgdPsum, tdPsumsdPsum, pipeline_dpsum) - load_dO = partial(self.load_m_tile, tma_atom_dO, tdOgdO, tdOsdO, pipeline_dO) + load_Q = partial(self.load_m_tile, tma_atom_Q, tQgQ, tQsQ, pipeline_q) + load_LSE = partial(self.load_m_tile, tma_atom_LSE, tLSEgLSE, tLSEsLSE, pipeline_lse) + load_dPsum = partial( + self.load_m_tile, tma_atom_dPsum, tdPsumgdPsum, tdPsumsdPsum, pipeline_dpsum + ) + load_dO = partial(self.load_m_tile, tma_atom_dO, tdOgdO, tdOsdO, pipeline_dO) with cute.arch.elect_one(): cute.arch.mbarrier_arrive_and_expect_tx(mbar_ptr_K, self.tma_copy_k_bytes) @@ -833,10 +847,10 @@ def load( for i in cutlass.range(m_block_max - m_block_min, unroll=2): m_block = m_block_max - i - 1 - load_Q(m_block, producer_state=producer_state) - load_LSE(m_block, producer_state=producer_state) + load_Q(m_block, producer_state=producer_state) + load_LSE(m_block, producer_state=producer_state) load_dPsum(m_block, producer_state=producer_state) - load_dO(m_block, producer_state=producer_state) + load_dO(m_block, producer_state=producer_state) producer_state.advance() @@ -844,133 +858,147 @@ def load( tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() - @cute.jit def mma( self, - tiled_mma_SdP: cute.TiledMma, - tiled_mma_dKV: cute.TiledMma, - tiled_mma_dQaccum: cute.TiledMma, - - mdK: cute.Tensor, - mdV: cute.Tensor, + tiled_mma_SdP: cute.TiledMma, + tiled_mma_dKV: cute.TiledMma, + tiled_mma_dQaccum: cute.TiledMma, + mdK: cute.Tensor, + mdV: cute.Tensor, mdQaccum: cute.Tensor, - - sQ: cute.Tensor, - sQt: cute.Tensor, - sK: cute.Tensor, - sV: cute.Tensor, - - sP: cute.Tensor, - sPt: cute.Tensor, - - sdS: cute.Tensor, + sQ: cute.Tensor, + sQt: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + sP: cute.Tensor, + sPt: cute.Tensor, + sdS: cute.Tensor, sdSt: cute.Tensor, - - sdO: cute.Tensor, + sdO: cute.Tensor, sdOt: cute.Tensor, - - sLSE_mma: cute.Tensor, + sLSE_mma: cute.Tensor, sdPsum_mma: cute.Tensor, - - sdQaccum: cute.Tensor, - - pipeline_q: cutlass.pipeline.PipelineAsync, - pipeline_lse: cutlass.pipeline.PipelineAsync, + sdQaccum: cute.Tensor, + pipeline_q: cutlass.pipeline.PipelineAsync, + pipeline_lse: cutlass.pipeline.PipelineAsync, pipeline_dPsum: cutlass.pipeline.PipelineAsync, - pipeline_dO: cutlass.pipeline.PipelineAsync, - + pipeline_dO: cutlass.pipeline.PipelineAsync, mbar_ptr_K: cutlass.Pointer, mbar_ptr_V: cutlass.Pointer, - tidx: cutlass.Int32, - gmem_tiled_copy_dV: cute.TiledCopy, - gmem_tiled_copy_dK: cute.TiledCopy, + gmem_tiled_copy_dV: cute.TiledCopy, + gmem_tiled_copy_dK: cute.TiledCopy, r2s_tiled_copy_dQaccum: cute.TiledCopy, - softmax_scale_log2: cutlass.Float32, - softmax_scale: cutlass.Float32, - + softmax_scale: cutlass.Float32, block_info: BlockInfo, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, ): warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) - warp_group_thread_layout = cute.make_layout(self.num_mma_warp_groups, stride=self.num_threads_per_warp_group) + warp_group_thread_layout = cute.make_layout( + self.num_mma_warp_groups, stride=self.num_threads_per_warp_group + ) - wg_mma_SdP = tiled_mma_SdP.get_slice(warp_group_thread_layout(warp_group_idx)) - wg_mma_dKV = tiled_mma_dKV.get_slice(warp_group_thread_layout(warp_group_idx)) + wg_mma_SdP = tiled_mma_SdP.get_slice(warp_group_thread_layout(warp_group_idx)) + wg_mma_dKV = tiled_mma_dKV.get_slice(warp_group_thread_layout(warp_group_idx)) wg_mma_dQaccum = tiled_mma_dQaccum.get_slice(warp_group_thread_layout(warp_group_idx)) smem_copy_atom_PdS = utils.get_smem_store_atom(self.arch, self.dtype) - smem_thr_copy_PdS = cute.make_tiled_copy_C(smem_copy_atom_PdS, tiled_mma_SdP).get_slice(tidx) + smem_thr_copy_PdS = cute.make_tiled_copy_C(smem_copy_atom_PdS, tiled_mma_SdP).get_slice( + tidx + ) # S = Q @ K.T - tSrQ = tiled_mma_SdP.make_fragment_A(wg_mma_SdP.partition_A(sQ)) - tSrK = tiled_mma_SdP.make_fragment_B(wg_mma_SdP.partition_B(sK)) + tSrQ = tiled_mma_SdP.make_fragment_A(wg_mma_SdP.partition_A(sQ)) + tSrK = tiled_mma_SdP.make_fragment_B(wg_mma_SdP.partition_B(sK)) # dP = dO @ V.T tdPrdO = tiled_mma_SdP.make_fragment_A(wg_mma_SdP.partition_A(sdO)) - tdPrV = tiled_mma_SdP.make_fragment_B(wg_mma_SdP.partition_B(sV)) + tdPrV = tiled_mma_SdP.make_fragment_B(wg_mma_SdP.partition_B(sV)) # P = exp(S-LSE) tPsP = smem_thr_copy_PdS.partition_D(sP) LSEslice = (None, 0, None) - tLSEsLSE_2D = utils.make_acc_tensor_mn_view(tiled_mma_SdP.get_slice(tidx).partition_C(sLSE_mma))[LSEslice] + tLSEsLSE_2D = utils.make_acc_tensor_mn_view( + tiled_mma_SdP.get_slice(tidx).partition_C(sLSE_mma) + )[LSEslice] # dS = P*(dP-dPsum) tdSsdS = smem_thr_copy_PdS.partition_D(sdS) dPsumslice = (None, 0, None) - tdPsumsdPsum_2D = utils.make_acc_tensor_mn_view(tiled_mma_SdP.get_slice(tidx).partition_C(sdPsum_mma))[dPsumslice] + tdPsumsdPsum_2D = utils.make_acc_tensor_mn_view( + tiled_mma_SdP.get_slice(tidx).partition_C(sdPsum_mma) + )[dPsumslice] # dV += P.T @ dO - tdVrPt = tiled_mma_dKV.make_fragment_A(wg_mma_dKV.partition_A(sPt)) + tdVrPt = tiled_mma_dKV.make_fragment_A(wg_mma_dKV.partition_A(sPt)) tdVrdOt = tiled_mma_dKV.make_fragment_B(wg_mma_dKV.partition_B(sdOt)) # dK += dS.T @ Q - tdKrdSt = tiled_mma_dKV.make_fragment_A(wg_mma_dKV.partition_A(sdSt)) - tdKrQt = tiled_mma_dKV.make_fragment_B(wg_mma_dKV.partition_B(sQt)) + tdKrdSt = tiled_mma_dKV.make_fragment_A(wg_mma_dKV.partition_A(sdSt)) + tdKrQt = tiled_mma_dKV.make_fragment_B(wg_mma_dKV.partition_B(sQt)) # dQ = dS @ K sKt = utils.transpose_view(sK) tdQaccumrdS = tiled_mma_dQaccum.make_fragment_A(wg_mma_dQaccum.partition_A(sdS)) - tdQaccumrK = tiled_mma_dQaccum.make_fragment_B(wg_mma_dQaccum.partition_B(sKt)) - + tdQaccumrK = tiled_mma_dQaccum.make_fragment_B(wg_mma_dQaccum.partition_B(sKt)) smem_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_slice(tidx) - tdQaccumsdQaccum = smem_thr_copy_dQaccum.partition_D(sdQaccum) + tdQaccumsdQaccum = smem_thr_copy_dQaccum.partition_D(sdQaccum) acc_dV = cute.make_fragment( tiled_mma_dKV.partition_shape_C((self.n_block_size, self.head_dim_padded)), - cutlass.Float32 + cutlass.Float32, ) acc_dK = cute.make_fragment( tiled_mma_dKV.partition_shape_C((self.n_block_size, self.head_dim_padded)), - cutlass.Float32 + cutlass.Float32, ) acc_dV.fill(0.0) acc_dK.fill(0.0) - mma_one_m_block_all = partial(self.mma_one_m_block, - tiled_mma_SdP=tiled_mma_SdP, tiled_mma_dKV=tiled_mma_dKV, tiled_mma_dQaccum=tiled_mma_dQaccum, - pipeline_q=pipeline_q, pipeline_lse=pipeline_lse, - pipeline_dPsum=pipeline_dPsum, pipeline_dO=pipeline_dO, - tLSEsLSE_2D=tLSEsLSE_2D, tdPsumsdPsum_2D=tdPsumsdPsum_2D, sP=sP, sdS=sdS, sdQaccum=sdQaccum, acc_dV=acc_dV, acc_dK=acc_dK, - tSrQ=tSrQ, tSrK=tSrK, - tPsP=tPsP, tdSsdS=tdSsdS, - tdVrPt=tdVrPt, tdVrdOt=tdVrdOt, - tdKrdSt=tdKrdSt, tdKrQt=tdKrQt, - tdPrdO=tdPrdO, tdPrV=tdPrV, - tdQaccumrdS=tdQaccumrdS, tdQaccumrK=tdQaccumrK, tdQaccumsdQaccum=tdQaccumsdQaccum, - smem_thr_copy_PdS=smem_thr_copy_PdS, - smem_thr_copy_dQaccum=smem_thr_copy_dQaccum, - ) + mma_one_m_block_all = partial( + self.mma_one_m_block, + tiled_mma_SdP=tiled_mma_SdP, + tiled_mma_dKV=tiled_mma_dKV, + tiled_mma_dQaccum=tiled_mma_dQaccum, + pipeline_q=pipeline_q, + pipeline_lse=pipeline_lse, + pipeline_dPsum=pipeline_dPsum, + pipeline_dO=pipeline_dO, + tLSEsLSE_2D=tLSEsLSE_2D, + tdPsumsdPsum_2D=tdPsumsdPsum_2D, + sP=sP, + sdS=sdS, + sdQaccum=sdQaccum, + acc_dV=acc_dV, + acc_dK=acc_dK, + tSrQ=tSrQ, + tSrK=tSrK, + tPsP=tPsP, + tdSsdS=tdSsdS, + tdVrPt=tdVrPt, + tdVrdOt=tdVrdOt, + tdKrdSt=tdKrdSt, + tdKrQt=tdKrQt, + tdPrdO=tdPrdO, + tdPrV=tdPrV, + tdQaccumrdS=tdQaccumrdS, + tdQaccumrK=tdQaccumrK, + tdQaccumsdQaccum=tdQaccumsdQaccum, + smem_thr_copy_PdS=smem_thr_copy_PdS, + smem_thr_copy_dQaccum=smem_thr_copy_dQaccum, + ) KV_consumer_phase = cutlass.Int32(0) - consumer_state = pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Consumer, self.num_stages) + consumer_state = pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.num_stages + ) tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() @@ -999,22 +1027,29 @@ def mma( softmax_scale_log2=softmax_scale_log2, ) - #scale dK + # scale dK acc_dK.store(acc_dK.load() * softmax_scale) self.epilogue_dKV( - acc_dV, mdV, sV, - acc_dK, mdK, sK, + acc_dV, + mdV, + sV, + acc_dK, + mdK, + sK, seqlen, - gmem_tiled_copy_dV, gmem_tiled_copy_dK, + gmem_tiled_copy_dV, + gmem_tiled_copy_dK, tiled_mma_dKV, - tidx, n_block, head_idx, batch_idx, + tidx, + n_block, + head_idx, + batch_idx, ) tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() - @cute.jit def mma_one_m_block( self, @@ -1023,65 +1058,51 @@ def mma_one_m_block( m_block: cutlass.Int32, head_idx: cutlass.Int32, batch_idx: cutlass.Int32, - - smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, - - tiled_mma_SdP: cute.TiledMma, - tiled_mma_dKV: cute.TiledMma, + smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, + tiled_mma_SdP: cute.TiledMma, + tiled_mma_dKV: cute.TiledMma, tiled_mma_dQaccum: cute.TiledMma, - - pipeline_q: cutlass.pipeline.PipelineAsync, - pipeline_lse: cutlass.pipeline.PipelineAsync, + pipeline_q: cutlass.pipeline.PipelineAsync, + pipeline_lse: cutlass.pipeline.PipelineAsync, pipeline_dPsum: cutlass.pipeline.PipelineAsync, - pipeline_dO: cutlass.pipeline.PipelineAsync, - - tLSEsLSE_2D: cute.Tensor, + pipeline_dO: cutlass.pipeline.PipelineAsync, + tLSEsLSE_2D: cute.Tensor, tdPsumsdPsum_2D: cute.Tensor, - sP: Optional[cute.Tensor], - sdS: Optional[cute.Tensor], - sdQaccum: cute.Tensor, - - acc_dV: cute.Tensor, - acc_dK: cute.Tensor, - - + sP: Optional[cute.Tensor], + sdS: Optional[cute.Tensor], + sdQaccum: cute.Tensor, + acc_dV: cute.Tensor, + acc_dK: cute.Tensor, tSrQ: cute.Tensor, tSrK: cute.Tensor, - - tPsP: Optional[cute.Tensor], + tPsP: Optional[cute.Tensor], tdSsdS: Optional[cute.Tensor], - - tdVrPt: cute.Tensor, + tdVrPt: cute.Tensor, tdVrdOt: cute.Tensor, - tdKrdSt: cute.Tensor, - tdKrQt: cute.Tensor, - - tdPrdO: cute.Tensor, - tdPrV: cute.Tensor, + tdKrQt: cute.Tensor, + tdPrdO: cute.Tensor, + tdPrV: cute.Tensor, tdQaccumrdS: cute.Tensor, - tdQaccumrK: cute.Tensor, + tdQaccumrK: cute.Tensor, tdQaccumsdQaccum: cute.Tensor, - - smem_thr_copy_PdS: cute.TiledCopy, + smem_thr_copy_PdS: cute.TiledCopy, smem_thr_copy_dQaccum: cute.TiledCopy, softmax_scale_log2: cutlass.Float32 = 1.0, ): - - # (1) [GEMM 1] S = Q @ K^T pipeline_q.consumer_wait(smem_pipe_read, pipeline_q.consumer_try_wait(smem_pipe_read)) acc_S = cute.make_fragment( - tiled_mma_SdP.partition_shape_C((self.m_block_size, self.n_block_size)), - cutlass.Float32 + tiled_mma_SdP.partition_shape_C((self.m_block_size, self.n_block_size)), cutlass.Float32 ) sm90_utils.gemm( - tiled_mma_SdP, acc_S, + tiled_mma_SdP, + acc_S, tSrQ[None, None, None, smem_pipe_read.index], tSrK, zero_init=True, - wg_wait=0 + wg_wait=0, ) # (2) [Pointwise 1] P = exp(S - LSE) @@ -1092,7 +1113,9 @@ def mma_one_m_block( acc_P_mn = utils.make_acc_tensor_mn_view(acc_S) for r in cutlass.range_constexpr(cute.size(acc_P_mn, mode=[0])): - acc_P_mn[r, None].store(cute.exp2(acc_P_mn[r, None].load() * softmax_scale_log2 - tLSErLSE[r])) + acc_P_mn[r, None].store( + cute.exp2(acc_P_mn[r, None].load() * softmax_scale_log2 - tLSErLSE[r]) + ) # fp32->bf16 tdVrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) @@ -1102,51 +1125,60 @@ def mma_one_m_block( # cp: rmem->smem tPrP = smem_thr_copy_PdS.retile(tdVrP) - cute.arch.barrier(barrier_id=int(NamedBarrierBwd.Epilogue), number_of_threads=self.num_mma_threads) - cute.arch.barrier(barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads) + cute.arch.barrier( + barrier_id=int(NamedBarrierBwd.Epilogue), number_of_threads=self.num_mma_threads + ) + cute.arch.barrier( + barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads + ) cute.copy(smem_thr_copy_PdS, tPrP, tPsP) - - ''' + """ if warp_group_idx == 0 and cute.arch.thread_idx()[0] == 128 and m_block == 0 and n_block == 0 and head_idx == 0 and batch_idx == 0: for j in cutlass.range_constexpr(16): cute.printf("%.15f", tPrP[j].to(cutlass.Float32)) - ''' + """ - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) - cute.arch.barrier(barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads) + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) + cute.arch.barrier( + barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads + ) pipeline_lse.consumer_release(smem_pipe_read) - # (3) [GEMM 2] dP = dO @ V.T pipeline_dO.consumer_wait(smem_pipe_read, pipeline_dO.consumer_try_wait(smem_pipe_read)) acc_dP = cute.make_fragment( - tiled_mma_SdP.partition_shape_C((self.m_block_size, self.n_block_size)), - cutlass.Float32 + tiled_mma_SdP.partition_shape_C((self.m_block_size, self.n_block_size)), cutlass.Float32 ) sm90_utils.gemm( - tiled_mma_SdP, acc_dP, + tiled_mma_SdP, + acc_dP, tdPrdO[None, None, None, smem_pipe_read.index], tdPrV, zero_init=True, - wg_wait=-0 + wg_wait=-0, ) # (4) [GEMM 3] dV += P.T @ dO sm90_utils.gemm( - tiled_mma_dKV, acc_dV, + tiled_mma_dKV, + acc_dV, tdVrPt, tdVrdOt[None, None, None, smem_pipe_read.index], zero_init=False, - wg_wait=0 + wg_wait=0, ) pipeline_dO.consumer_release(smem_pipe_read) # (4) [Pointwise 2] dS = P*(dP-dPsum) - pipeline_dPsum.consumer_wait(smem_pipe_read, pipeline_dPsum.consumer_try_wait(smem_pipe_read)) + pipeline_dPsum.consumer_wait( + smem_pipe_read, pipeline_dPsum.consumer_try_wait(smem_pipe_read) + ) # dPsum tdPsumrdPsum = cute.make_fragment_like(tdPsumsdPsum_2D[None, 0]) @@ -1155,8 +1187,8 @@ def mma_one_m_block( acc_dP_mn = utils.make_acc_tensor_mn_view(acc_dP) for r in cutlass.range_constexpr(cute.size(acc_dP_mn, mode=[0])): acc_dP_mn[r, None].store( - acc_P_mn[r, None].load() * (acc_dP_mn[r, None].load() - tdPsumrdPsum[r]) - ) + acc_P_mn[r, None].load() * (acc_dP_mn[r, None].load() - tdPsumrdPsum[r]) + ) # fp32->bf16 tdKrdS_acc = cute.make_tensor(acc_dP.iterator, utils.convert_layout_acc_frgA(acc_dP.layout)) @@ -1165,151 +1197,169 @@ def mma_one_m_block( tdSrdS = smem_thr_copy_PdS.retile(tdKrdS) - cute.arch.barrier(barrier_id=int(NamedBarrierBwd.Epilogue), number_of_threads=self.num_mma_threads) - cute.arch.barrier(barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads) + cute.arch.barrier( + barrier_id=int(NamedBarrierBwd.Epilogue), number_of_threads=self.num_mma_threads + ) + cute.arch.barrier( + barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads + ) cute.copy(smem_thr_copy_PdS, tdSrdS, tdSsdS) - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) - cute.arch.barrier(barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads) + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) + cute.arch.barrier( + barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads + ) pipeline_dPsum.consumer_release(smem_pipe_read) - - # (6) [GEMM 4] dQ = dS @ K acc_dQ = cute.make_fragment( tiled_mma_dQaccum.partition_shape_C((self.m_block_size, self.head_dim_padded)), - cutlass.Float32 + cutlass.Float32, + ) + cute.arch.barrier( + barrier_id=int(NamedBarrierBwd.Epilogue), number_of_threads=self.num_mma_threads ) - cute.arch.barrier(barrier_id=int(NamedBarrierBwd.Epilogue), number_of_threads=self.num_mma_threads) sm90_utils.gemm( - tiled_mma_dQaccum, acc_dQ, - tdQaccumrdS, - tdQaccumrK, - zero_init=True, - wg_wait=0 + tiled_mma_dQaccum, acc_dQ, tdQaccumrdS, tdQaccumrK, zero_init=True, wg_wait=0 ) - cute.arch.barrier(barrier_id=int(NamedBarrierBwd.Epilogue), number_of_threads=self.num_mma_threads) - cute.arch.barrier(barrier_id=int(NamedBarrierBwd.dQEmpty), number_of_threads=self.num_mma_threads + cute.arch.WARP_SIZE) + cute.arch.barrier( + barrier_id=int(NamedBarrierBwd.Epilogue), number_of_threads=self.num_mma_threads + ) + cute.arch.barrier( + barrier_id=int(NamedBarrierBwd.dQEmpty), + number_of_threads=self.num_mma_threads + cute.arch.WARP_SIZE, + ) - tdQaccumrdQaccum_tmp = cute.make_tensor(acc_dQ.iterator, cute.make_layout(tdQaccumsdQaccum.shape)) + tdQaccumrdQaccum_tmp = cute.make_tensor( + acc_dQ.iterator, cute.make_layout(tdQaccumsdQaccum.shape) + ) cute.copy(smem_thr_copy_dQaccum, tdQaccumrdQaccum_tmp, tdQaccumsdQaccum) - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) - cute.arch.barrier_arrive(barrier_id=int(NamedBarrierBwd.dQFull), number_of_threads=self.num_mma_threads + cute.arch.WARP_SIZE) + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierBwd.dQFull), + number_of_threads=self.num_mma_threads + cute.arch.WARP_SIZE, + ) # (7) [GEMM 5] dK += dS.T @ Q sm90_utils.gemm( - tiled_mma_dKV, acc_dK, + tiled_mma_dKV, + acc_dK, tdKrdSt, tdKrQt[None, None, None, smem_pipe_read.index], zero_init=False, - wg_wait=0 + wg_wait=0, ) pipeline_q.consumer_release(smem_pipe_read) smem_pipe_read.advance() return smem_pipe_read - @cute.jit def epilogue_dKV( - self, - acc_dV: cute.Tensor, - mdV: cute.Tensor, - sV: cute.Tensor, - - acc_dK: cute.Tensor, - mdK: cute.Tensor, - sK: cute.Tensor, - - - seqlen: SeqlenInfoQK, - - gmem_tiled_copy_dV: cute.TiledCopy, - gmem_tiled_copy_dK: cute.TiledCopy, - - tiled_mma_dKV: cute.TiledMma, - - tidx: cutlass.Int32, - n_block: cutlass.Int32, - head_idx: cutlass.Int32, - batch_idx: cutlass.Int32 - ): - - ### RMEM --> SMEM - rdV = cute.make_fragment_like(acc_dV, self.dtype) - rdV.store(acc_dV.load().to(self.dtype)) - - rdK = cute.make_fragment_like(acc_dK, self.dtype) - rdK.store(acc_dK.load().to(self.dtype)) - - cute.arch.barrier(barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads) - - - smem_copy_atom_dKV = cute.make_copy_atom(cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=False, num_matrices=4), self.dtype,) - smem_thr_copy_dKV = cute.make_tiled_copy_C(smem_copy_atom_dKV, tiled_mma_dKV).get_slice(tidx) - - - taccdVrdV = smem_thr_copy_dKV.retile(rdV) - taccdVsdV = smem_thr_copy_dKV.partition_D(sV) # reuse sV SMEM - cute.copy(smem_copy_atom_dKV, taccdVrdV, taccdVsdV) - - taccdKrdK = smem_thr_copy_dKV.retile(rdK) - taccdKsdK = smem_thr_copy_dKV.partition_D(sK) # reuse sK SMEM - cute.copy(smem_copy_atom_dKV, taccdKrdK, taccdKsdK) - - - # SMEM -> GMEM - cdV = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded)) - mdV_cur = mdV[None, None, head_idx, batch_idx] - - cdK = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded)) - mdK_cur = mdK[None, None, head_idx, batch_idx] - - cute.arch.barrier(barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads) - gmem_thr_copy_dV = gmem_tiled_copy_dV.get_slice(tidx) - gmem_thr_copy_dK = gmem_tiled_copy_dK.get_slice(tidx) + self, + acc_dV: cute.Tensor, + mdV: cute.Tensor, + sV: cute.Tensor, + acc_dK: cute.Tensor, + mdK: cute.Tensor, + sK: cute.Tensor, + seqlen: SeqlenInfoQK, + gmem_tiled_copy_dV: cute.TiledCopy, + gmem_tiled_copy_dK: cute.TiledCopy, + tiled_mma_dKV: cute.TiledMma, + tidx: cutlass.Int32, + n_block: cutlass.Int32, + head_idx: cutlass.Int32, + batch_idx: cutlass.Int32, + ): + ### RMEM --> SMEM + rdV = cute.make_fragment_like(acc_dV, self.dtype) + rdV.store(acc_dV.load().to(self.dtype)) - tdVsdV = gmem_thr_copy_dV.partition_S(sV) - tdVrdV = cute.make_fragment_like(tdVsdV, self.dtype) - cute.autovec_copy(tdVsdV, tdVrdV) + rdK = cute.make_fragment_like(acc_dK, self.dtype) + rdK.store(acc_dK.load().to(self.dtype)) - tdKsdK = gmem_thr_copy_dK.partition_S(sK) - tdKrdK = cute.make_fragment_like(tdKsdK, self.dtype) - cute.autovec_copy(tdKsdK, tdKrdK) + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads + ) - gdV = cute.local_tile(mdV_cur, (self.n_block_size, self.head_dim_padded), (n_block, 0)) - tdVgdV = gmem_thr_copy_dV.partition_D(gdV) + smem_copy_atom_dKV = cute.make_copy_atom( + cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=False, num_matrices=4), + self.dtype, + ) + smem_thr_copy_dKV = cute.make_tiled_copy_C(smem_copy_atom_dKV, tiled_mma_dKV).get_slice( + tidx + ) - gdK = cute.local_tile(mdK_cur, (self.n_block_size, self.head_dim_padded), (n_block, 0)) - tdKgdK = gmem_thr_copy_dK.partition_D(gdK) + taccdVrdV = smem_thr_copy_dKV.retile(rdV) + taccdVsdV = smem_thr_copy_dKV.partition_D(sV) # reuse sV SMEM + cute.copy(smem_copy_atom_dKV, taccdVrdV, taccdVsdV) - tdVcdV = gmem_thr_copy_dV.partition_S(cdV) - t0dVcdV = gmem_tiled_copy_dV.get_slice(0).partition_S(cdV) - tdVpdV = utils.predicate_k(tdVcdV, limit=mdV.shape[1]) + taccdKrdK = smem_thr_copy_dKV.retile(rdK) + taccdKsdK = smem_thr_copy_dKV.partition_D(sK) # reuse sK SMEM + cute.copy(smem_copy_atom_dKV, taccdKrdK, taccdKsdK) - tdKcdK = gmem_thr_copy_dK.partition_S(cdK) - tdKpdK = utils.predicate_k(tdKcdK, limit=mdK.shape[1]) + # SMEM -> GMEM + cdV = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded)) + mdV_cur = mdV[None, None, head_idx, batch_idx] - for rest_m in cutlass.range_constexpr(cute.size(tdVrdV.shape[1])): - row_idx = n_block * self.n_block_size + t0dVcdV[0, rest_m, 0][0] - if row_idx < seqlen.seqlen_k: - cute.copy( - gmem_tiled_copy_dV, - tdVrdV[None, rest_m, None], - tdVgdV[None, rest_m, None], - pred=tdVpdV[None, rest_m, None] if cutlass.const_expr(self.check_hdim_v_oob) else None, - ) - cute.copy( - gmem_tiled_copy_dK, - tdKrdK[None, rest_m, None], - tdKgdK[None, rest_m, None], - pred=tdKpdK[None, rest_m, None] if cutlass.const_expr(self.check_hdim_oob) else None, - ) + cdK = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded)) + mdK_cur = mdK[None, None, head_idx, batch_idx] + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads + ) + gmem_thr_copy_dV = gmem_tiled_copy_dV.get_slice(tidx) + gmem_thr_copy_dK = gmem_tiled_copy_dK.get_slice(tidx) + + tdVsdV = gmem_thr_copy_dV.partition_S(sV) + tdVrdV = cute.make_fragment_like(tdVsdV, self.dtype) + cute.autovec_copy(tdVsdV, tdVrdV) + + tdKsdK = gmem_thr_copy_dK.partition_S(sK) + tdKrdK = cute.make_fragment_like(tdKsdK, self.dtype) + cute.autovec_copy(tdKsdK, tdKrdK) + + gdV = cute.local_tile(mdV_cur, (self.n_block_size, self.head_dim_padded), (n_block, 0)) + tdVgdV = gmem_thr_copy_dV.partition_D(gdV) + + gdK = cute.local_tile(mdK_cur, (self.n_block_size, self.head_dim_padded), (n_block, 0)) + tdKgdK = gmem_thr_copy_dK.partition_D(gdK) + + tdVcdV = gmem_thr_copy_dV.partition_S(cdV) + t0dVcdV = gmem_tiled_copy_dV.get_slice(0).partition_S(cdV) + tdVpdV = utils.predicate_k(tdVcdV, limit=mdV.shape[1]) + + tdKcdK = gmem_thr_copy_dK.partition_S(cdK) + tdKpdK = utils.predicate_k(tdKcdK, limit=mdK.shape[1]) + + for rest_m in cutlass.range_constexpr(cute.size(tdVrdV.shape[1])): + row_idx = n_block * self.n_block_size + t0dVcdV[0, rest_m, 0][0] + if row_idx < seqlen.seqlen_k: + cute.copy( + gmem_tiled_copy_dV, + tdVrdV[None, rest_m, None], + tdVgdV[None, rest_m, None], + pred=tdVpdV[None, rest_m, None] + if cutlass.const_expr(self.check_hdim_v_oob) + else None, + ) + cute.copy( + gmem_tiled_copy_dK, + tdKrdK[None, rest_m, None], + tdKgdK[None, rest_m, None], + pred=tdKpdK[None, rest_m, None] + if cutlass.const_expr(self.check_hdim_oob) + else None, + ) @cute.jit def dQaccum_writer( @@ -1317,14 +1367,13 @@ def dQaccum_writer( mdQaccum: cute.Tensor, sdQaccum: cute.Tensor, TileSchedulerCls: cutlass.Constexpr[Callable], - SeqlenInfoCls: cutlass.Constexpr[Callable], + SeqlenInfoCls: cutlass.Constexpr[Callable], ): - tile_elems = cute.cosize(sdQaccum.layout) tile_bytes = cutlass.Int32(tile_elems * 4) tile_scheduler = TileSchedulerCls() - work_tile = tile_scheduler.initial_work_tile_info() + work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: n_block, head_idx, batch_idx = work_tile.tile_idx @@ -1333,60 +1382,52 @@ def dQaccum_writer( # GMEM mdQaccum_cur = mdQaccum[None, head_idx, batch_idx] - base_flat = cute.domain_offset( - (seqlen.offset_q * self.head_dim_padded, ), - mdQaccum_cur - ) + base_flat = cute.domain_offset((seqlen.offset_q * self.head_dim_padded,), mdQaccum_cur) m_block_min = cutlass.Int32(0) m_block_max = cute.ceil_div(seqlen.seqlen_q, self.m_block_size) for it_m in cutlass.range(m_block_max - m_block_min, unroll=1): - m_block = m_block_max -1 - it_m + m_block = m_block_max - 1 - it_m cute.arch.barrier( barrier_id=int(NamedBarrierBwd.dQFull), - number_of_threads=self.num_mma_threads + cute.arch.WARP_SIZE + number_of_threads=self.num_mma_threads + cute.arch.WARP_SIZE, ) - gdQaccum_block = cute.local_tile( - base_flat, - (tile_elems, ), - (m_block, ) - ) + gdQaccum_block = cute.local_tile(base_flat, (tile_elems,), (m_block,)) with cute.arch.elect_one(): sm90_utils.tma_reduce_add_bulk_f32( - sdQaccum.iterator, - gdQaccum_block.iterator, - tile_bytes, - ) + sdQaccum.iterator, + gdQaccum_block.iterator, + tile_bytes, + ) cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=True) cute.arch.barrier_arrive( - barrier_id=int(NamedBarrierBwd.dQEmpty), - number_of_threads=self.num_mma_threads + cute.arch.WARP_SIZE + barrier_id=int(NamedBarrierBwd.dQEmpty), + number_of_threads=self.num_mma_threads + cute.arch.WARP_SIZE, ) tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() - @cute.jit def load_m_tile( - self, - tma_atom: cute.CopyAtom, - tXgX: cute.Tensor, - tXsX: cute.Tensor, - pipeline: cutlass.pipeline.PipelineAsync, - block: cutlass.Int32, - producer_state: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, + self, + tma_atom: cute.CopyAtom, + tXgX: cute.Tensor, + tXsX: cute.Tensor, + pipeline: cutlass.pipeline.PipelineAsync, + block: cutlass.Int32, + producer_state: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, ): pipeline.producer_acquire(producer_state) cute.copy( tma_atom, tXgX[None, block], tXsX[None, producer_state.index], - tma_bar_ptr=pipeline.producer_get_barrier(producer_state) + tma_bar_ptr=pipeline.producer_get_barrier(producer_state), ) From 42e4e3e88ea0846cecf225c4ceb1edaaea621d25 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 12 Oct 2025 20:17:44 -0400 Subject: [PATCH 120/258] [Cute,Bwd,Sm90] Fix bwd dK & dV, more async --- flash_attn/cute/block_info.py | 73 +- flash_attn/cute/copy_utils.py | 12 +- flash_attn/cute/flash_bwd_postprocess.py | 5 +- flash_attn/cute/flash_bwd_sm90.py | 1062 +++++++++------------- flash_attn/cute/flash_fwd.py | 10 +- flash_attn/cute/interface.py | 30 +- hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp | 2 +- 7 files changed, 487 insertions(+), 707 deletions(-) diff --git a/flash_attn/cute/block_info.py b/flash_attn/cute/block_info.py index 50e6371dda3..9e911fdd581 100644 --- a/flash_attn/cute/block_info.py +++ b/flash_attn/cute/block_info.py @@ -4,89 +4,88 @@ import cutlass import cutlass.cute as cute +from cutlass import Int32, const_expr from flash_attn.cute.seqlen_info import SeqlenInfoQK @dataclass(frozen=True) class BlockInfo: - m_block_size: cutlass.Constexpr[int] - n_block_size: cutlass.Constexpr[int] + tile_m: cutlass.Constexpr[int] + tile_n: cutlass.Constexpr[int] is_causal: cutlass.Constexpr[bool] is_local: cutlass.Constexpr[bool] = False - window_size_left: Optional[cutlass.Int32] = None - window_size_right: Optional[cutlass.Int32] = None + window_size_left: Optional[Int32] = None + window_size_right: Optional[Int32] = None qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 @cute.jit def get_n_block_min_max( - self, seqlen_info: SeqlenInfoQK, m_block: cutlass.Int32 - ) -> Tuple[cutlass.Int32, cutlass.Int32]: - n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.n_block_size) - if cutlass.const_expr( + self, seqlen_info: SeqlenInfoQK, m_block: Int32 + ) -> Tuple[Int32, Int32]: + n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n) + if const_expr( self.is_causal or (self.is_local and self.window_size_right is not None) ): - m_idx_max = (m_block + 1) * self.m_block_size - if cutlass.const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx_max = (m_block + 1) * self.tile_m + if const_expr(self.qhead_per_kvhead_packgqa > 1): m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa) n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q - n_idx_right = n_idx if cutlass.const_expr(self.is_causal) else n_idx + self.window_size_right - n_block_max = min(n_block_max, cute.ceil_div(n_idx_right, self.n_block_size)) + n_idx_right = n_idx if const_expr(self.is_causal) else n_idx + self.window_size_right + n_block_max = min(n_block_max, cute.ceil_div(n_idx_right, self.tile_n)) n_block_min = 0 - if cutlass.const_expr(self.is_local and self.window_size_left is not None): - m_idx_min = m_block * self.m_block_size - if cutlass.const_expr(self.qhead_per_kvhead_packgqa > 1): + if const_expr(self.is_local and self.window_size_left is not None): + m_idx_min = m_block * self.tile_m + if const_expr(self.qhead_per_kvhead_packgqa > 1): m_idx_min = m_idx_min // self.qhead_per_kvhead_packgqa n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q n_idx_left = n_idx - self.window_size_left - n_block_min = cutlass.max(n_idx_left // self.n_block_size, 0) + n_block_min = cutlass.max(n_idx_left // self.tile_n, 0) return n_block_min, n_block_max @cute.jit def get_m_block_min_max( - self, seqlen_info: SeqlenInfoQK, m_block: cutlass.Int32 - ) -> Tuple[cutlass.Int32, cutlass.Int32]: - m_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.m_block_size) - + self, seqlen_info: SeqlenInfoQK, n_block: Int32 + ) -> Tuple[Int32, Int32]: + m_block_max = cute.ceil_div(seqlen_info.seqlen_q, self.tile_m) m_block_min = 0 - + if const_expr(self.is_causal): + m_block_min = max(m_block_min, cute.ceil_div(seqlen_info.seqlen_q - seqlen_info.seqlen_k + (n_block + 1) * self.tile_n, self.tile_m)) return m_block_min, m_block_max - - @cute.jit def get_n_block_min_causal_local_mask( self, seqlen_info: SeqlenInfoQK, - m_block: cutlass.Int32, - n_block_min: cutlass.Int32, - ) -> cutlass.Int32: + m_block: Int32, + n_block_min: Int32, + ) -> Int32: """If we have separate iterations with causal or local masking at the start, where do we stop""" - m_idx_min = m_block * self.m_block_size - if cutlass.const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx_min = m_block * self.tile_m + if const_expr(self.qhead_per_kvhead_packgqa > 1): m_idx_min = m_idx_min // self.qhead_per_kvhead_packgqa n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q n_idx_right = ( n_idx - if cutlass.const_expr(not self.is_local or self.window_size_right is None) + if const_expr(not self.is_local or self.window_size_right is None) else n_idx + self.window_size_right ) - return cutlass.max(n_block_min, n_idx_right // self.n_block_size) + return cutlass.max(n_block_min, n_idx_right // self.tile_n) @cute.jit def get_n_block_min_before_local_mask( self, seqlen_info: SeqlenInfoQK, - m_block: cutlass.Int32, - n_block_min: cutlass.Int32, - ) -> cutlass.Int32: + m_block: Int32, + n_block_min: Int32, + ) -> Int32: """If we have separate iterations with local masking at the end, where do we stop the non-masked iterations""" - if cutlass.const_expr(not self.is_local or self.window_size_left is None): + if const_expr(not self.is_local or self.window_size_left is None): return n_block_min else: - m_idx_max = (m_block + 1) * self.m_block_size - if cutlass.const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx_max = (m_block + 1) * self.tile_m + if const_expr(self.qhead_per_kvhead_packgqa > 1): m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa) n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q n_idx_left = n_idx - self.window_size_left - return cutlass.max(n_block_min, cute.ceil_div(n_idx_left, self.n_block_size)) + return cutlass.max(n_block_min, cute.ceil_div(n_idx_left, self.tile_n)) diff --git a/flash_attn/cute/copy_utils.py b/flash_attn/cute/copy_utils.py index 9ac20207444..822cdde2a4f 100644 --- a/flash_attn/cute/copy_utils.py +++ b/flash_attn/cute/copy_utils.py @@ -91,6 +91,7 @@ def tma_get_copy_fn( src_tensor: cute.Tensor, dst_tensor: cute.Tensor, filter_zeros: bool = False, + single_stage: bool = False, **kwargs, ) -> Callable: src_is_smem = const_expr( @@ -98,13 +99,15 @@ def tma_get_copy_fn( and src_tensor.memspace == cute.AddressSpace.smem ) smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor) + group_rank_smem = const_expr(cute.rank(smem_tensor) - (1 if not single_stage else 0)) + group_rank_gmem = const_expr(cute.rank(gmem_tensor) - (1 if not single_stage else 0)) # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK) s, g = cpasync.tma_partition( atom, cta_coord, cta_layout, - cute.group_modes(smem_tensor, 0, cute.rank(smem_tensor) - 1), - cute.group_modes(gmem_tensor, 0, cute.rank(gmem_tensor) - 1), + cute.group_modes(smem_tensor, 0, group_rank_smem), + cute.group_modes(gmem_tensor, 0, group_rank_gmem), ) if const_expr(filter_zeros): s = cute.filter_zeros(s) @@ -114,7 +117,10 @@ def tma_get_copy_fn( def copy_tma(src_idx, dst_idx, **new_kwargs): cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs) - return copy_tma, s, g + def copy_tma_single_stage(**new_kwargs): + cute.copy(atom, src, dst, **new_kwargs, **kwargs) + + return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync): diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index ddad08beb5b..0abe36d39c3 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -358,6 +358,9 @@ def __call__( scale: cutlass.Float32, stream: cuda.CUstream, ): + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1]) + mdQaccum, mdQ = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) for t in (mdQaccum, mdQ)] mdQ = cute.make_tensor(mdQ.iterator, cute.select(mdQ.layout, mode=[1,3,2,0])) mdQaccum = cute.make_tensor(mdQaccum.iterator, cute.select(mdQaccum.layout, mode=[2,1,0])) @@ -369,7 +372,7 @@ def __call__( warpgroup.OperandMajorMode.K, warpgroup.OperandMajorMode.MN, cutlass.Float32, - atom_layout_mnk=(self.m_block_size // 64, 1, 1), + atom_layout_mnk=(self.m_block_size // 64, 2, 1), tiler_mn=(64, self.head_dim_padded) ) diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index fc6b6c7a414..d391f9f4bf9 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -6,14 +6,14 @@ import cutlass import cutlass.cute as cute -from cutlass.cute.nvgpu import cpasync, warpgroup - -# import cutlass.pipeline import cutlass.utils.hopper_helpers as sm90_utils_basic -from cutlass import const_expr +from cutlass.cute.nvgpu import cpasync, warpgroup +from cutlass import Float32, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum from flash_attn.cute import hopper_helpers as sm90_utils from flash_attn.cute import utils +from flash_attn.cute import copy_utils from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo from flash_attn.cute import pipeline @@ -21,6 +21,37 @@ from flash_attn.cute.named_barrier import NamedBarrierFwd, NamedBarrierBwd +def mma_zero_init( + tiled_mma: cute.TiledMma, + shape: cute.Shape, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + A_idx: Optional[Int32] = None, + B_idx: Optional[Int32] = None, + wg_wait: int = -1, +) -> cute.Tensor: + acc = cute.make_fragment(tiled_mma.partition_shape_C(shape), Float32) + rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] + rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] + sm90_utils.gemm(tiled_mma, acc, rA, rB, zero_init=True, wg_wait=wg_wait) + return acc + + +def mma_sm90( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + zero_init: Boolean, + A_idx: Optional[Int32] = None, + B_idx: Optional[Int32] = None, + wg_wait: int = -1, +) -> None: + rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] + rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] + sm90_utils.gemm(tiled_mma, acc, rA, rB, zero_init=zero_init, wg_wait=wg_wait) + + class FlashAttentionBackwardSm90: arch = 90 @@ -30,8 +61,8 @@ def __init__( head_dim: int, head_dim_v: Optional[int] = None, qhead_per_kvhead: int = 1, - m_block_size: int = 64, - n_block_size: int = 128, + tile_m: int = 64, + tile_n: int = 128, num_stages: int = 2, num_threads: int = 384, Q_in_regs: bool = False, @@ -39,18 +70,19 @@ def __init__( self.dtype = dtype # padding head_dim to a multiple of 16 as k_block_size hdim_multiple_of = 16 - self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) + self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) head_dim_v = head_dim_v if head_dim_v is not None else head_dim self.same_hdim_kv = head_dim == head_dim_v - self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) + self.tile_hdimv = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) # Can save registers (and hence be faster) if we don't have to check hdim predication - self.check_hdim_oob = head_dim != self.head_dim_padded - self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded + self.check_hdim_oob = head_dim != self.tile_hdim + self.check_hdim_v_oob = head_dim_v != self.tile_hdimv self.qhead_per_kvhead = qhead_per_kvhead - self.m_block_size = m_block_size - self.n_block_size = n_block_size + self.tile_m = tile_m + self.tile_n = tile_n self.num_threads = num_threads self.num_stages = num_stages + self.dS_stage = 2 self.Q_in_regs = Q_in_regs @staticmethod @@ -58,8 +90,8 @@ def can_implement( dtype, head_dim, head_dim_v, - m_block_size, - n_block_size, + tile_m, + tile_n, num_stages, num_threads, Q_in_regs=False, @@ -70,12 +102,12 @@ def can_implement( return False if head_dim_v % 8 != 0: return False - if n_block_size % 16 != 0: + if tile_n % 16 != 0: return False if num_threads % 32 != 0: return False - if (m_block_size * 2) % num_threads != 0: + if (tile_m * 2) % num_threads != 0: return False return True @@ -96,159 +128,93 @@ def _check_type( raise TypeError("All tensors must have the same data type") if const_expr(mQ_type not in [cutlass.Float16, cutlass.BFloat16]): raise TypeError("Only Float16 or BFloat16 is supported") - if const_expr(mLSE_type not in [cutlass.Float32]): + if const_expr(mLSE_type not in [Float32]): raise TypeError("LSE tensor must be Float32") - if const_expr(mdPsum_type not in [cutlass.Float32]): + if const_expr(mdPsum_type not in [Float32]): raise TypeError("dPsum tensor must be Float32") - if const_expr(mdQaccum_type not in [cutlass.Float32]): + if const_expr(mdQaccum_type not in [Float32]): raise TypeError("dQaccum tensor must be Float32") if const_expr(self.qhead_per_kvhead == 1): if const_expr(not (mdK_type == mdV_type == mQ_type)): raise TypeError("mdK and mdV tensors must have the same data type as mQ") else: - if const_expr(not (mdK_type == mdV_type == cutlass.Float32)): + if const_expr(not (mdK_type == mdV_type == Float32)): raise TypeError("mdKaccum and mdVaccum tensors must have the data type Float32") assert mQ_type == self.dtype - def _get_smem_layout_atom(self): - sQ_layout_atom = warpgroup.make_smem_layout_atom( - sm90_utils_basic.get_smem_layout_atom( - cutlass.utils.LayoutEnum.ROW_MAJOR, self.dtype, self.head_dim_padded - ), - self.dtype, - ) - sK_layout_atom = sQ_layout_atom - - sV_layout_atom = warpgroup.make_smem_layout_atom( - sm90_utils_basic.get_smem_layout_atom( - cutlass.utils.LayoutEnum.ROW_MAJOR, self.dtype, self.head_dim_v_padded - ), - self.dtype, - ) - sPdS_layout_atom = warpgroup.make_smem_layout_atom( - sm90_utils_basic.get_smem_layout_atom( - cutlass.utils.LayoutEnum.ROW_MAJOR, self.dtype, self.n_block_size - ), - self.dtype, - ) - sdO_layout_atom = warpgroup.make_smem_layout_atom( - sm90_utils_basic.get_smem_layout_atom( - cutlass.utils.LayoutEnum.ROW_MAJOR, self.dtype, self.head_dim_padded - ), - self.dtype, - ) - - return sQ_layout_atom, sK_layout_atom, sV_layout_atom, sPdS_layout_atom, sdO_layout_atom - def _setup_attributes(self): - sQ_layout_atom, sK_layout_atom, sV_layout_atom, sPdS_layout_atom, sdO_layout_atom = ( - self._get_smem_layout_atom() - ) - - universal_copy_bits = 128 - async_copy_elems = universal_copy_bits // self.dtype.width - - atom_universal_copy = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), - self.dtype, - num_bits_per_copy=universal_copy_bits, - ) - - self.sQ_layout = cute.tile_to_shape( - sQ_layout_atom, - (self.m_block_size, self.head_dim_padded, self.num_stages), - (0, 1, 2), - ) - self.sK_layout = cute.tile_to_shape( - sK_layout_atom, - (self.n_block_size, self.head_dim_padded), - (0, 1), - ) - self.sV_layout = cute.tile_to_shape( - sV_layout_atom, - (self.n_block_size, self.head_dim_v_padded), - (0, 1), - ) - self.sdO_layout = cute.tile_to_shape( - sdO_layout_atom, - (self.m_block_size, self.head_dim_padded, self.num_stages), - (0, 1, 2), - ) + self.sQ_layout, self.sK_layout, self.sV_layout, self.sdO_layout, self.sPdS_layout = [ + sm90_utils.make_smem_layout(self.dtype, LayoutEnum.ROW_MAJOR, shape, stage) + for shape, stage in [ + ((self.tile_m, self.tile_hdim), self.num_stages), + ((self.tile_n, self.tile_hdim), None), + ((self.tile_n, self.tile_hdimv), None), + ((self.tile_m, self.tile_hdimv), self.num_stages), + ((self.tile_m, self.tile_n), self.dS_stage), + ] + ] - self.sPdS_layout = cute.tile_to_shape( - sPdS_layout_atom, - (self.m_block_size, self.n_block_size), - (0, 1), - ) - self.sdQaccum_layout = cute.make_layout( - shape=(self.m_block_size * self.head_dim_padded,), - ) + self.sdQaccum_layout = cute.make_layout(self.tile_m * self.tile_hdim) # dQaccum R->S - self.r2s_tiled_copy_dQaccum = cute.make_tiled_copy_tv( - cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=universal_copy_bits - ), - cute.make_layout(self.num_mma_threads), - cute.make_layout(universal_copy_bits // cutlass.Float32.width), + self.r2s_tiled_copy_dQaccum = copy_utils.tiled_copy_1d( + Float32, self.num_mma_threads, num_copy_elems=128 // Float32.width ) - # dV: S->G - tV_shape_dim_1 = sV_layout_atom.outer.shape[1] // async_copy_elems - tdV_layout = cute.make_ordered_layout( - (self.num_mma_threads // tV_shape_dim_1, tV_shape_dim_1), - order=(1, 0), - ) - self.gmem_tiled_copy_dV = cute.make_tiled_copy_tv( - atom_universal_copy, tdV_layout, cute.make_layout((1, async_copy_elems)) + tV_shape_dim_1 = self.sV_layout.outer.shape[1][0] + self.gmem_tiled_copy_dV = copy_utils.tiled_copy_2d( + self.dtype, tV_shape_dim_1, self.num_mma_threads ) - # dK: S->G - tK_shape_dim_1 = sK_layout_atom.outer.shape[1] // async_copy_elems - tdK_layout = cute.make_ordered_layout( - (self.num_mma_threads // tK_shape_dim_1, tK_shape_dim_1), - order=(1, 0), - ) - self.gmem_tiled_copy_dK = cute.make_tiled_copy_tv( - atom_universal_copy, tdK_layout, cute.make_layout((1, async_copy_elems)) + tK_shape_dim_1 = self.sK_layout.outer.shape[1][0] + self.gmem_tiled_copy_dK = copy_utils.tiled_copy_2d( + self.dtype, tK_shape_dim_1, self.num_mma_threads ) def _get_tiled_mma(self): - # C = A @ B.T + # S = Q @ K.T, dP = dO @ V.T tiled_mma_SdP = sm90_utils_basic.make_trivial_tiled_mma( self.dtype, self.dtype, warpgroup.OperandMajorMode.K, warpgroup.OperandMajorMode.K, - cutlass.Float32, - atom_layout_mnk=(self.m_block_size // 64, 1, 1), - tiler_mn=(64, self.n_block_size), + Float32, + atom_layout_mnk=(self.tile_m // 64, 2, 1), + tiler_mn=(64, self.tile_n // 2), ) - # C = A.T @ B - tiled_mma_dKV = sm90_utils_basic.make_trivial_tiled_mma( + # dV = P.T @ dO, dK = dS.T @ Q + tiled_mma_dK = sm90_utils_basic.make_trivial_tiled_mma( self.dtype, self.dtype, warpgroup.OperandMajorMode.MN, warpgroup.OperandMajorMode.MN, - cutlass.Float32, - atom_layout_mnk=(self.n_block_size // 64, 1, 1), - tiler_mn=(64, self.head_dim_padded), + Float32, + atom_layout_mnk=(self.tile_n // 64, 1, 1), + tiler_mn=(64, self.tile_hdim), ) - # C = A @ B - tiled_mma_dQaccum = sm90_utils_basic.make_trivial_tiled_mma( + tiled_mma_dV = sm90_utils_basic.make_trivial_tiled_mma( + self.dtype, + self.dtype, + warpgroup.OperandMajorMode.MN, + warpgroup.OperandMajorMode.MN, + Float32, + atom_layout_mnk=(self.tile_n // 64, 1, 1), + tiler_mn=(64, self.tile_hdimv), + ) + # dQ = dS @ K + tiled_mma_dQ = sm90_utils_basic.make_trivial_tiled_mma( self.dtype, self.dtype, warpgroup.OperandMajorMode.K, warpgroup.OperandMajorMode.MN, - cutlass.Float32, - atom_layout_mnk=(self.m_block_size // 64, 1, 1), - tiler_mn=(64, self.head_dim_padded), + Float32, + atom_layout_mnk=(self.tile_m // 64, 2, 1), + tiler_mn=(64, self.tile_hdim // 2), ) - - return tiled_mma_SdP, tiled_mma_dKV, tiled_mma_dQaccum + return tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ def _get_shared_storage_cls(self): - sQ_alignment = sK_alignment = sV_alighment = sdQaccum_alignment = sdO_alignment = 128 + sQ_alignment = sK_alignment = sV_alighment = sdQaccum_alignment = sdO_alignment = 1024 sQ_struct, sK_struct, sV_struct, sdO_struct, sdQaccum_struct = [ cute.struct.Align[cute.struct.MemRange[type, cute.cosize(layout)], alignment] @@ -257,43 +223,35 @@ def _get_shared_storage_cls(self): (self.sK_layout, self.dtype, sK_alignment), (self.sV_layout, self.dtype, sV_alighment), (self.sdO_layout, self.dtype, sdO_alignment), - (self.sdQaccum_layout, cutlass.Float32, sdQaccum_alignment), + (self.sdQaccum_layout, Float32, sdQaccum_alignment), ] ] - cosize_sPdS = cute.cosize(self.sPdS_layout) - sPdS_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sPdS], 1024] + cosize_sdS = cute.cosize(self.sPdS_layout) + cosize_sP = cute.cosize(self.sPdS_layout) # Could be zero sLSE_struct = cute.struct.Align[ - cute.struct.MemRange[cutlass.Float32, self.m_block_size * self.num_stages], 128 + cute.struct.MemRange[Float32, cute.round_up(self.tile_m, 64) * self.num_stages], 128 ] sdPsum_struct = cute.struct.Align[ - cute.struct.MemRange[cutlass.Float32, self.m_block_size * self.num_stages], 128 + cute.struct.MemRange[Float32, cute.round_up(self.tile_m, 64) * self.num_stages], 128 ] - mbar_ptr_Q_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] - mbar_ptr_LSE_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] - mbar_ptr_dPsum_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] - mbar_ptr_dO_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] - - mbar_ptr_K_struct = cute.struct.MemRange[cutlass.Int64, 2] - mbar_ptr_V_struct = cute.struct.MemRange[cutlass.Int64, 2] - @cute.struct class SharedStorageQKV: - mbar_ptr_Q: mbar_ptr_Q_struct - mbar_ptr_K: mbar_ptr_K_struct - mbar_ptr_V: mbar_ptr_V_struct - mbar_ptr_lse: mbar_ptr_LSE_struct - mbar_ptr_dpsum: mbar_ptr_dPsum_struct - mbar_ptr_dO: mbar_ptr_dO_struct - + mbar_ptr_K: cute.struct.MemRange[cutlass.Int64, 2] + mbar_ptr_V: cute.struct.MemRange[cutlass.Int64, 2] + mbar_ptr_Q: cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] + mbar_ptr_dO: cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] + mbar_ptr_LSE: cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] + mbar_ptr_dPsum: cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] + sLSE: sLSE_struct + sdPsum: sdPsum_struct sQ: sQ_struct sV: sV_struct sK: sK_struct - sPdS: sPdS_struct - sLSE: sLSE_struct - sdPsum: sdPsum_struct sdO: sdO_struct + sP: cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sP], 1024] + sdS: cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sdS], 1024] sdQaccum: sdQaccum_struct return SharedStorageQKV @@ -310,15 +268,15 @@ def __call__( mdQaccum: cute.Tensor, mdK: cute.Tensor, mdV: cute.Tensor, - softmax_scale: cutlass.Float32, + softmax_scale: Float32, stream: cuda.CUstream, mCuSeqlensQ: Optional[cute.Tensor] = None, mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, - softcap: cutlass.Float32 | float | None = None, - window_size_left: cutlass.Int32 | int | None = None, - window_size_right: cutlass.Int32 | int | None = None, + softcap: Float32 | float | None = None, + window_size_left: Int32 | int | None = None, + window_size_right: Int32 | int | None = None, ): self._check_type( *( @@ -327,23 +285,28 @@ def __call__( ) ) + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: ( + *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + t.stride[-1], + ) + mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + if t is not None + else None + for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV) + ] + layout_transpose = [1, 3, 2, 0] # (b, s, n, h) --> (s, h, n, b) mQ, mK, mV, mdK, mdV, mdO = [ - cute.make_tensor(t.iterator, cute.select(t.layout, mode=layout_transpose)) - for t in (mQ, mK, mV, mdK, mdV, mdO) + utils.select(t, layout_transpose) for t in (mQ, mK, mV, mdK, mdV, mdO) ] - LSE_dPsum_dQaccum_transpose = [2, 1, 0] # (b, n, s) -> (s, n, b) mLSE, mdPsum, mdQaccum = [ - cute.make_tensor(t.iterator, cute.select(t.layout, mode=LSE_dPsum_dQaccum_transpose)) - for t in (mLSE, mdPsum, mdQaccum) + utils.select(t, LSE_dPsum_dQaccum_transpose) for t in (mLSE, mdPsum, mdQaccum) ] - tiled_mma_SdP, tiled_mma_dKV, tiled_mma_dQaccum = self._get_tiled_mma() - - self.tiled_mma_SdP = tiled_mma_SdP - self.tiled_mma_dKV = tiled_mma_dKV - self.tiled_mma_sdQaccum = tiled_mma_dQaccum + tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ = self._get_tiled_mma() self.num_mma_threads = tiled_mma_SdP.size @@ -357,70 +320,66 @@ def __call__( self._setup_attributes() SharedStorage = self._get_shared_storage_cls() - self.tma_copy_q_bytes = cute.size_in_bytes( - mQ.element_type, cute.select(self.sQ_layout, mode=[0, 1]) - ) - self.tma_copy_k_bytes = cute.size_in_bytes( - mK.element_type, cute.select(self.sK_layout, mode=[0, 1]) - ) - self.tma_copy_v_bytes = cute.size_in_bytes( - mV.element_type, cute.select(self.sK_layout, mode=[0, 1]) - ) - - self.tma_copy_do_bytes = cute.size_in_bytes( - mdO.element_type, cute.select(self.sdO_layout, mode=[0, 1]) - ) - self.tma_copy_lse_bytes = self.m_block_size * 4 - self.tma_copy_dPsum_bytes = self.m_block_size * 4 + self.tma_copy_bytes = { + name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1])) + for name, mX, layout in [ + ("Q", mQ, self.sQ_layout), + ("K", mK, self.sK_layout), + ("V", mV, self.sV_layout), + ("dO", mdO, self.sdO_layout), + ] + } + self.tma_copy_bytes["LSE"] = self.tile_m * Float32.width // 8 + self.tma_copy_bytes["dPsum"] = self.tile_m * Float32.width // 8 tma_atom_Q, tma_tensor_Q = cpasync.make_tiled_tma_atom( cpasync.CopyBulkTensorTileG2SOp(), mQ, cute.select(self.sQ_layout, mode=[0, 1]), - (self.m_block_size, self.head_dim_padded), + (self.tile_m, self.tile_hdim), ) tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom( cpasync.CopyBulkTensorTileG2SOp(), mK, cute.select(self.sK_layout, mode=[0, 1]), - (self.n_block_size, self.head_dim_padded), + (self.tile_n, self.tile_hdim), 1, ) tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom( cpasync.CopyBulkTensorTileG2SOp(), mV, cute.select(self.sV_layout, mode=[0, 1]), - (self.n_block_size, self.head_dim_v_padded), + (self.tile_n, self.tile_hdimv), 1, ) tma_atom_dO, tma_tensor_dO = cpasync.make_tiled_tma_atom( cpasync.CopyBulkTensorTileG2SOp(), mdO, cute.select(self.sdO_layout, mode=[0, 1]), - (self.m_block_size, self.head_dim_padded), + (self.tile_m, self.tile_hdimv), ) tma_atom_LSE, tma_tensor_LSE = cpasync.make_tiled_tma_atom( cpasync.CopyBulkTensorTileG2SOp(), mLSE, - cute.make_layout(self.m_block_size), - (self.m_block_size,), + cute.make_layout(self.tile_m), + (self.tile_m,), ) tma_atom_dPsum, tma_tensor_dPsum = cpasync.make_tiled_tma_atom( cpasync.CopyBulkTensorTileG2SOp(), mdPsum, - cute.make_layout(self.m_block_size), - (self.m_block_size,), + cute.make_layout(self.tile_m), + (self.tile_m,), ) TileScheduler = SingleTileScheduler tile_sched_args = TileSchedulerArguments( - cute.ceil_div(cute.size(mK.shape[0]), self.n_block_size), + cute.ceil_div(cute.size(mK.shape[0]), self.tile_n), cute.size(mK.shape[2]), cute.size(mK.shape[3]), cute.size(mK.shape[0]), mQ.shape[1], mV.shape[1], total_q=cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), - tile_shape_mn=(self.m_block_size, self.n_block_size), + tile_shape_mn=(self.tile_m, self.tile_n), mCuSeqlensQ=None, mSeqUsedQ=None, qhead_per_kvhead_packgqa=1, @@ -461,8 +420,9 @@ def __call__( self.gmem_tiled_copy_dK, self.r2s_tiled_copy_dQaccum, tiled_mma_SdP, - tiled_mma_dKV, - tiled_mma_dQaccum, + tiled_mma_dK, + tiled_mma_dV, + tiled_mma_dQ, softmax_scale_log2, softmax_scale, tile_sched_params, @@ -504,8 +464,9 @@ def kernel( gmem_tiled_copy_dK: cute.TiledCopy, r2s_tiled_copy_dQaccum: cute.TiledCopy, tiled_mma_SdP: cute.TiledMma, - tiled_mma_dKV: cute.TiledMma, - tiled_mma_dQaccum: cute.TiledMma, + tiled_mma_dK: cute.TiledMma, + tiled_mma_dV: cute.TiledMma, + tiled_mma_dQ: cute.TiledMma, softmax_scale_log2, softmax_scale, tile_sched_params: ParamsBase, @@ -513,7 +474,6 @@ def kernel( SharedStorage: cutlass.Constexpr[Callable], ): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - tidx = cute.arch.thread_idx()[0] # prefetch TMA descriptors if warp_idx == 0: @@ -539,29 +499,12 @@ def kernel( pipeline_consumer_group = cutlass.pipeline.CooperativeGroup( cutlass.pipeline.Agent.Thread, self.num_mma_threads // self.num_threads_per_warp_group ) - pipeline_q = pipeline.PipelineTmaAsyncNoCluster.create( barrier_storage=storage.mbar_ptr_Q.data_ptr(), num_stages=self.num_stages, producer_group=pipeline_producer_group, consumer_group=pipeline_consumer_group, - tx_count=self.tma_copy_q_bytes, - init_wait=False, - ) - pipeline_lse = pipeline.PipelineTmaAsyncNoCluster.create( - barrier_storage=storage.mbar_ptr_lse.data_ptr(), - num_stages=self.num_stages, - producer_group=pipeline_producer_group, - consumer_group=pipeline_consumer_group, - tx_count=self.tma_copy_lse_bytes, - init_wait=False, - ) - pipeline_dpsum = pipeline.PipelineTmaAsyncNoCluster.create( - barrier_storage=storage.mbar_ptr_dpsum.data_ptr(), - num_stages=self.num_stages, - producer_group=pipeline_producer_group, - consumer_group=pipeline_consumer_group, - tx_count=self.tma_copy_dPsum_bytes, + tx_count=self.tma_copy_bytes["Q"] + self.tma_copy_bytes["LSE"], init_wait=False, ) pipeline_do = pipeline.PipelineTmaAsyncNoCluster.create( @@ -569,54 +512,34 @@ def kernel( num_stages=self.num_stages, producer_group=pipeline_producer_group, consumer_group=pipeline_consumer_group, - tx_count=self.tma_copy_do_bytes, - init_wait=False, + tx_count=self.tma_copy_bytes["dO"] + self.tma_copy_bytes["dPsum"], + init_wait=True, ) - sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) - sQt = utils.transpose_view(sQ) + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) + sdO = storage.sdO.get_tensor(sdO_layout.outer, swizzle=sdO_layout.inner) sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) + sP = storage.sP.get_tensor(sPdS_layout.outer, swizzle=sPdS_layout.inner) + sdS = storage.sdS.get_tensor(sPdS_layout.outer, swizzle=sPdS_layout.inner) - sLSE_load = storage.sLSE.get_tensor( - cute.make_layout( - (self.m_block_size, self.num_stages), - stride=(1, cute.round_up(self.m_block_size, 64)), - ) - ) - sLSE_mma = storage.sLSE.get_tensor( + sLSE = storage.sLSE.get_tensor( cute.make_layout( - (self.m_block_size, self.n_block_size, self.num_stages), - stride=(1, 0, cute.round_up(self.m_block_size, 64)), + (self.tile_m, self.num_stages), + stride=(1, cute.round_up(self.tile_m, 64)), ) ) - sdPsum_load = storage.sdPsum.get_tensor( + sdPsum = storage.sdPsum.get_tensor( cute.make_layout( - (self.m_block_size, self.num_stages), - stride=(1, cute.round_up(self.m_block_size, 64)), + (self.tile_m, self.num_stages), + stride=(1, cute.round_up(self.tile_m, 64)), ) ) - sdPsum_mma = storage.sdPsum.get_tensor( - cute.make_layout( - (self.m_block_size, self.n_block_size, self.num_stages), - stride=(1, 0, cute.round_up(self.m_block_size, 64)), - ) - ) - sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout) - sP = storage.sPdS.get_tensor(sPdS_layout.outer, swizzle=sPdS_layout.inner) - sPt = utils.transpose_view(sP) - - sdS = storage.sPdS.get_tensor(sPdS_layout.outer, swizzle=sPdS_layout.inner) - sdSt = utils.transpose_view(sdS) - - sdO = storage.sdO.get_tensor(sdO_layout.outer, swizzle=sdO_layout.inner) - sdOt = utils.transpose_view(sdO) - block_info = BlockInfo( - self.m_block_size, - self.n_block_size, + self.tile_m, + self.tile_n, False, False, None, @@ -648,21 +571,20 @@ def kernel( sQ, sK, sV, - sLSE_load, - sdPsum_load, sdO, + sLSE, + sdPsum, tma_atom_Q, tma_atom_K, tma_atom_V, + tma_atom_dO, tma_atom_LSE, tma_atom_dPsum, - tma_atom_dO, pipeline_q, - pipeline_lse, - pipeline_dpsum, pipeline_do, mbar_ptr_K, mbar_ptr_V, + block_info, SeqlenInfoCls, TileSchedulerCls, ) @@ -671,40 +593,29 @@ def kernel( barrier_id=int(NamedBarrierBwd.dQEmpty), number_of_threads=self.num_mma_threads + cute.arch.WARP_SIZE, ) - self.dQaccum_writer( - mdQaccum, - sdQaccum, - TileSchedulerCls, - SeqlenInfoCls, - ) + self.dQaccum_store(mdQaccum, sdQaccum, block_info, TileSchedulerCls, SeqlenInfoCls) else: cute.arch.warpgroup_reg_alloc(self.num_mma_regs) tidx, _, _ = cute.arch.thread_idx() tidx = tidx - 128 - self.mma( tiled_mma_SdP, - tiled_mma_dKV, - tiled_mma_dQaccum, + tiled_mma_dK, + tiled_mma_dV, + tiled_mma_dQ, mdK, mdV, mdQaccum, sQ, - sQt, sK, sV, + sdO, sP, - sPt, sdS, - sdSt, - sdO, - sdOt, - sLSE_mma, - sdPsum_mma, + sLSE, + sdPsum, sdQaccum, pipeline_q, - pipeline_lse, - pipeline_dpsum, pipeline_do, mbar_ptr_K, mbar_ptr_V, @@ -731,21 +642,20 @@ def load( sQ: cute.Tensor, sK: cute.Tensor, sV: cute.Tensor, + sdO: cute.Tensor, sLSE: cute.Tensor, sdPsum: cute.Tensor, - sdO: cute.Tensor, tma_atom_Q: cute.CopyAtom, tma_atom_K: cute.CopyAtom, tma_atom_V: cute.CopyAtom, + tma_atom_dO: cute.CopyAtom, tma_atom_LSE: cute.CopyAtom, tma_atom_dPsum: cute.CopyAtom, - tma_atom_dO: cute.CopyAtom, pipeline_q: cutlass.pipeline.PipelineAsync, - pipeline_lse: cutlass.pipeline.PipelineAsync, - pipeline_dpsum: cutlass.pipeline.PipelineAsync, - pipeline_dO: cutlass.pipeline.PipelineAsync, + pipeline_do: cutlass.pipeline.PipelineAsync, mbar_ptr_K: cutlass.Pointer, mbar_ptr_V: cutlass.Pointer, + block_info: BlockInfo, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, ): @@ -762,96 +672,59 @@ def load( while work_tile.is_valid_tile: n_block, head_idx, batch_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) - mK_cur = mK[None, None, head_idx, batch_idx] - gK = cute.local_tile( - mK_cur, (self.n_block_size, self.head_dim_padded), (n_block, 0) - ) - + gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (n_block, 0)) mV_cur = mV[None, None, head_idx, batch_idx] - gV = cute.local_tile( - mV_cur, (self.n_block_size, self.head_dim_padded), (n_block, 0) - ) + gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0)) mQ_cur = mQ[None, None, head_idx, batch_idx] - gQ = cute.local_tile(mQ_cur, (self.m_block_size, self.head_dim_padded), (None, 0)) - + gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (None, 0)) + mdO_cur = mdO[None, None, head_idx, batch_idx] + gdO = cute.local_tile(mdO_cur, (self.tile_m, self.tile_hdimv), (None, 0)) mLSE_cur = mLSE[None, head_idx, batch_idx] - gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (None,)) - + gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (None,)) mdPsum_cur = mdPsum[None, head_idx, batch_idx] - gdPsum = cute.local_tile(mdPsum_cur, (self.m_block_size,), (None,)) + gdPsum = cute.local_tile(mdPsum_cur, (self.tile_m,), (None,)) - mdO_cur = mdO[None, None, head_idx, batch_idx] - gdO = cute.local_tile(mdO_cur, (self.m_block_size, self.head_dim_padded), (None, 0)) - - tQsQ, tQgQ = cpasync.tma_partition( - tma_atom_Q, - 0, - cute.make_layout(1), - cute.group_modes(sQ, 0, 2), - cute.group_modes(gQ, 0, 2), + load_K, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_K, 0, cute.make_layout(1), gK, sK, single_stage=True ) - tKsK, tKgK = cpasync.tma_partition( - tma_atom_K, - 0, - cute.make_layout(1), - cute.group_modes(sK, 0, 2), - cute.group_modes(gK, 0, 2), + load_V, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_V, 0, cute.make_layout(1), gV, sV, single_stage=True ) - tVsV, tVgV = cpasync.tma_partition( - tma_atom_V, - 0, - cute.make_layout(1), - cute.group_modes(sV, 0, 2), - cute.group_modes(gV, 0, 2), + load_Q, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ, sQ ) - tLSEsLSE, tLSEgLSE = cpasync.tma_partition( - tma_atom_LSE, - 0, - cute.make_layout(1), - sLSE, - gLSE, + load_Q = copy_utils.tma_producer_copy_fn(load_Q, pipeline_q) + load_dO, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_dO, 0, cute.make_layout(1), gdO, sdO ) - tdPsumsdPsum, tdPsumgdPsum = cpasync.tma_partition( - tma_atom_dPsum, - 0, - cute.make_layout(1), - sdPsum, - gdPsum, + load_dO = copy_utils.tma_producer_copy_fn(load_dO, pipeline_do) + load_LSE, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_LSE, 0, cute.make_layout(1), gLSE, sLSE ) - tdOsdO, tdOgdO = cpasync.tma_partition( - tma_atom_dO, - 0, - cute.make_layout(1), - cute.group_modes(sdO, 0, 2), - cute.group_modes(gdO, 0, 2), - ) - - load_Q = partial(self.load_m_tile, tma_atom_Q, tQgQ, tQsQ, pipeline_q) - load_LSE = partial(self.load_m_tile, tma_atom_LSE, tLSEgLSE, tLSEsLSE, pipeline_lse) - load_dPsum = partial( - self.load_m_tile, tma_atom_dPsum, tdPsumgdPsum, tdPsumsdPsum, pipeline_dpsum + load_LSE = copy_utils.tma_producer_copy_fn(load_LSE, pipeline_q) + load_dPsum, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_dPsum, 0, cute.make_layout(1), gdPsum, sdPsum ) - load_dO = partial(self.load_m_tile, tma_atom_dO, tdOgdO, tdOsdO, pipeline_dO) + load_dPsum = copy_utils.tma_producer_copy_fn(load_dPsum, pipeline_do) + # TODO: need to wait if we do persistent kernel with cute.arch.elect_one(): - cute.arch.mbarrier_arrive_and_expect_tx(mbar_ptr_K, self.tma_copy_k_bytes) - cute.arch.mbarrier_arrive_and_expect_tx(mbar_ptr_V, self.tma_copy_v_bytes) - - cute.copy(tma_atom_K, tKgK, tKsK, tma_bar_ptr=mbar_ptr_K) - cute.copy(tma_atom_V, tVgV, tVsV, tma_bar_ptr=mbar_ptr_V) - - m_block_min, m_block_max = 0, cute.ceil_div(seqlen.seqlen_q, self.m_block_size) + cute.arch.mbarrier_arrive_and_expect_tx(mbar_ptr_K, self.tma_copy_bytes["K"]) + cute.arch.mbarrier_arrive_and_expect_tx(mbar_ptr_V, self.tma_copy_bytes["V"]) + load_K(tma_bar_ptr=mbar_ptr_K) + load_V(tma_bar_ptr=mbar_ptr_V) + m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) for i in cutlass.range(m_block_max - m_block_min, unroll=2): m_block = m_block_max - i - 1 - + pipeline_q.producer_acquire(producer_state) load_Q(m_block, producer_state=producer_state) load_LSE(m_block, producer_state=producer_state) - load_dPsum(m_block, producer_state=producer_state) + pipeline_do.producer_acquire(producer_state) load_dO(m_block, producer_state=producer_state) - + load_dPsum(m_block, producer_state=producer_state) producer_state.advance() tile_scheduler.prefetch_next_work() @@ -862,36 +735,31 @@ def load( def mma( self, tiled_mma_SdP: cute.TiledMma, - tiled_mma_dKV: cute.TiledMma, - tiled_mma_dQaccum: cute.TiledMma, + tiled_mma_dK: cute.TiledMma, + tiled_mma_dV: cute.TiledMma, + tiled_mma_dQ: cute.TiledMma, mdK: cute.Tensor, mdV: cute.Tensor, mdQaccum: cute.Tensor, sQ: cute.Tensor, - sQt: cute.Tensor, sK: cute.Tensor, sV: cute.Tensor, - sP: cute.Tensor, - sPt: cute.Tensor, - sdS: cute.Tensor, - sdSt: cute.Tensor, sdO: cute.Tensor, - sdOt: cute.Tensor, - sLSE_mma: cute.Tensor, - sdPsum_mma: cute.Tensor, + sP: Optional[cute.Tensor], + sdS: cute.Tensor, + sLSE: cute.Tensor, + sdPsum: cute.Tensor, sdQaccum: cute.Tensor, pipeline_q: cutlass.pipeline.PipelineAsync, - pipeline_lse: cutlass.pipeline.PipelineAsync, - pipeline_dPsum: cutlass.pipeline.PipelineAsync, - pipeline_dO: cutlass.pipeline.PipelineAsync, + pipeline_do: cutlass.pipeline.PipelineAsync, mbar_ptr_K: cutlass.Pointer, mbar_ptr_V: cutlass.Pointer, - tidx: cutlass.Int32, + tidx: Int32, gmem_tiled_copy_dV: cute.TiledCopy, gmem_tiled_copy_dK: cute.TiledCopy, r2s_tiled_copy_dQaccum: cute.TiledCopy, - softmax_scale_log2: cutlass.Float32, - softmax_scale: cutlass.Float32, + softmax_scale_log2: Float32, + softmax_scale: Float32, block_info: BlockInfo, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, @@ -900,136 +768,123 @@ def mma( warp_group_thread_layout = cute.make_layout( self.num_mma_warp_groups, stride=self.num_threads_per_warp_group ) - + thr_mma_SdP = tiled_mma_SdP.get_slice(tidx) wg_mma_SdP = tiled_mma_SdP.get_slice(warp_group_thread_layout(warp_group_idx)) - wg_mma_dKV = tiled_mma_dKV.get_slice(warp_group_thread_layout(warp_group_idx)) - wg_mma_dQaccum = tiled_mma_dQaccum.get_slice(warp_group_thread_layout(warp_group_idx)) - - smem_copy_atom_PdS = utils.get_smem_store_atom(self.arch, self.dtype) - smem_thr_copy_PdS = cute.make_tiled_copy_C(smem_copy_atom_PdS, tiled_mma_SdP).get_slice( - tidx - ) - + wg_mma_dK = tiled_mma_dK.get_slice(warp_group_thread_layout(warp_group_idx)) + wg_mma_dV = tiled_mma_dV.get_slice(warp_group_thread_layout(warp_group_idx)) + wg_mma_dQ = tiled_mma_dQ.get_slice(warp_group_thread_layout(warp_group_idx)) # S = Q @ K.T tSrQ = tiled_mma_SdP.make_fragment_A(wg_mma_SdP.partition_A(sQ)) tSrK = tiled_mma_SdP.make_fragment_B(wg_mma_SdP.partition_B(sK)) - # dP = dO @ V.T tdPrdO = tiled_mma_SdP.make_fragment_A(wg_mma_SdP.partition_A(sdO)) tdPrV = tiled_mma_SdP.make_fragment_B(wg_mma_SdP.partition_B(sV)) + # dV += P.T @ dO + sPt = utils.transpose_view(sP) + sdOt = utils.transpose_view(sdO) + tdVrPt = tiled_mma_dV.make_fragment_A(wg_mma_dV.partition_A(sPt)) + tdVrdOt = tiled_mma_dV.make_fragment_B(wg_mma_dV.partition_B(sdOt)) + # dK += dS.T @ Q + sdSt = utils.transpose_view(sdS) + sQt = utils.transpose_view(sQ) + tdKrdSt = tiled_mma_dK.make_fragment_A(wg_mma_dK.partition_A(sdSt)) + tdKrQt = tiled_mma_dK.make_fragment_B(wg_mma_dK.partition_B(sQt)) + # dQ = dS @ K + sKt = utils.transpose_view(sK) + tdQrdS = tiled_mma_dQ.make_fragment_A(wg_mma_dQ.partition_A(sdS)) + tdQrKt = tiled_mma_dQ.make_fragment_B(wg_mma_dQ.partition_B(sKt)) - # P = exp(S-LSE) + # Smem copy atom tiling + smem_copy_atom_PdS = utils.get_smem_store_atom(self.arch, self.dtype) + smem_thr_copy_PdS = cute.make_tiled_copy_C(smem_copy_atom_PdS, tiled_mma_SdP).get_slice( + tidx + ) tPsP = smem_thr_copy_PdS.partition_D(sP) - - LSEslice = (None, 0, None) - tLSEsLSE_2D = utils.make_acc_tensor_mn_view( - tiled_mma_SdP.get_slice(tidx).partition_C(sLSE_mma) - )[LSEslice] - - # dS = P*(dP-dPsum) tdSsdS = smem_thr_copy_PdS.partition_D(sdS) - dPsumslice = (None, 0, None) - tdPsumsdPsum_2D = utils.make_acc_tensor_mn_view( - tiled_mma_SdP.get_slice(tidx).partition_C(sdPsum_mma) - )[dPsumslice] - - # dV += P.T @ dO - tdVrPt = tiled_mma_dKV.make_fragment_A(wg_mma_dKV.partition_A(sPt)) - tdVrdOt = tiled_mma_dKV.make_fragment_B(wg_mma_dKV.partition_B(sdOt)) - - # dK += dS.T @ Q - tdKrdSt = tiled_mma_dKV.make_fragment_A(wg_mma_dKV.partition_A(sdSt)) - tdKrQt = tiled_mma_dKV.make_fragment_B(wg_mma_dKV.partition_B(sQt)) - - # dQ = dS @ K - sKt = utils.transpose_view(sK) - tdQaccumrdS = tiled_mma_dQaccum.make_fragment_A(wg_mma_dQaccum.partition_A(sdS)) - tdQaccumrK = tiled_mma_dQaccum.make_fragment_B(wg_mma_dQaccum.partition_B(sKt)) + sLSE_mma = cute.make_tensor( + sLSE.iterator, + cute.make_layout( + (self.tile_m, self.tile_n, self.num_stages), + stride=(1, 0, cute.round_up(self.tile_m, 64)) + ) + ) + sdPsum_mma = cute.make_tensor( + sdPsum.iterator, + cute.make_layout( + (self.tile_m, self.tile_n, self.num_stages), + stride=(1, 0, cute.round_up(self.tile_m, 64)) + ) + ) + LSEslice = (None, 0, None) + tLSEsLSE = utils.make_acc_tensor_mn_view(thr_mma_SdP.partition_C(sLSE_mma))[LSEslice] + tLSEsdPsum = utils.make_acc_tensor_mn_view(thr_mma_SdP.partition_C(sdPsum_mma))[LSEslice] smem_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_slice(tidx) - tdQaccumsdQaccum = smem_thr_copy_dQaccum.partition_D(sdQaccum) + tdQsdQaccum = smem_thr_copy_dQaccum.partition_D(sdQaccum) acc_dV = cute.make_fragment( - tiled_mma_dKV.partition_shape_C((self.n_block_size, self.head_dim_padded)), - cutlass.Float32, + tiled_mma_dV.partition_shape_C((self.tile_n, self.tile_hdimv)), + Float32, ) acc_dK = cute.make_fragment( - tiled_mma_dKV.partition_shape_C((self.n_block_size, self.head_dim_padded)), - cutlass.Float32, + tiled_mma_dK.partition_shape_C((self.tile_n, self.tile_hdim)), + Float32, ) - acc_dV.fill(0.0) - acc_dK.fill(0.0) + mma_qk_fn = partial(mma_zero_init, tiled_mma_SdP, (self.tile_m, self.tile_n), tSrQ, tSrK) + mma_dov_fn = partial(mma_zero_init, tiled_mma_SdP, (self.tile_m, self.tile_n), tdPrdO, tdPrV) + mma_pdo_fn = partial(mma_sm90, tiled_mma_dV, acc_dV, tdVrPt, tdVrdOt) + mma_dsq_fn = partial(mma_sm90, tiled_mma_dK, acc_dK, tdKrdSt, tdKrQt) + mma_dsk_fn = partial(mma_zero_init, tiled_mma_dQ, (self.tile_m, self.tile_hdim), tdQrdS, tdQrKt) mma_one_m_block_all = partial( self.mma_one_m_block, - tiled_mma_SdP=tiled_mma_SdP, - tiled_mma_dKV=tiled_mma_dKV, - tiled_mma_dQaccum=tiled_mma_dQaccum, + mma_qk_fn=mma_qk_fn, + mma_dov_fn=mma_dov_fn, + mma_pdo_fn=mma_pdo_fn, + mma_dsq_fn=mma_dsq_fn, + mma_dsk_fn=mma_dsk_fn, pipeline_q=pipeline_q, - pipeline_lse=pipeline_lse, - pipeline_dPsum=pipeline_dPsum, - pipeline_dO=pipeline_dO, - tLSEsLSE_2D=tLSEsLSE_2D, - tdPsumsdPsum_2D=tdPsumsdPsum_2D, - sP=sP, - sdS=sdS, - sdQaccum=sdQaccum, - acc_dV=acc_dV, - acc_dK=acc_dK, - tSrQ=tSrQ, - tSrK=tSrK, + pipeline_do=pipeline_do, + tLSEsLSE=tLSEsLSE, + tLSEsdPsum=tLSEsdPsum, tPsP=tPsP, tdSsdS=tdSsdS, - tdVrPt=tdVrPt, - tdVrdOt=tdVrdOt, - tdKrdSt=tdKrdSt, - tdKrQt=tdKrQt, - tdPrdO=tdPrdO, - tdPrV=tdPrV, - tdQaccumrdS=tdQaccumrdS, - tdQaccumrK=tdQaccumrK, - tdQaccumsdQaccum=tdQaccumsdQaccum, + tdQsdQaccum=tdQsdQaccum, smem_thr_copy_PdS=smem_thr_copy_PdS, smem_thr_copy_dQaccum=smem_thr_copy_dQaccum, + softmax_scale_log2=softmax_scale_log2, + acc_dV=acc_dV, + acc_dK=acc_dK, ) - KV_consumer_phase = cutlass.Int32(0) + acc_dV.fill(0.0) + acc_dK.fill(0.0) + + kv_consumer_phase = Int32(0) consumer_state = pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.num_stages ) - tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() - while work_tile.is_valid_tile: n_block, head_idx, batch_idx = work_tile.tile_idx - seqlen = SeqlenInfoCls(batch_idx) - m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) - - cute.arch.mbarrier_wait(mbar_ptr_K, phase=KV_consumer_phase) - cute.arch.mbarrier_wait(mbar_ptr_V, phase=KV_consumer_phase) - KV_consumer_phase ^= 1 + m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) + # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("tidx = {}, m_block_min = {}, m_block_max = {}", cute.arch.thread_idx()[0], m_block_min, m_block_max) - for m_block in cutlass.range(m_block_max - m_block_min, unroll=1): - m_block_idx = m_block_max - 1 - m_block + cute.arch.mbarrier_wait(mbar_ptr_K, phase=kv_consumer_phase) + cute.arch.mbarrier_wait(mbar_ptr_V, phase=kv_consumer_phase) + kv_consumer_phase ^= 1 - consumer_state = mma_one_m_block_all( - warp_group_idx, - n_block, - m_block_idx, - head_idx, - batch_idx, - consumer_state, - softmax_scale_log2=softmax_scale_log2, - ) + for m_tile in cutlass.range(m_block_max - m_block_min, unroll=1): + m_block = m_block_max - 1 - m_tile + consumer_state = mma_one_m_block_all(warp_group_idx, m_block, consumer_state) # scale dK acc_dK.store(acc_dK.load() * softmax_scale) - self.epilogue_dKV( acc_dV, mdV, @@ -1040,7 +895,8 @@ def mma( seqlen, gmem_tiled_copy_dV, gmem_tiled_copy_dK, - tiled_mma_dKV, + tiled_mma_dK, + tiled_mma_dV, tidx, n_block, head_idx, @@ -1054,192 +910,120 @@ def mma( def mma_one_m_block( self, warp_group_idx, - n_block: cutlass.Int32, - m_block: cutlass.Int32, - head_idx: cutlass.Int32, - batch_idx: cutlass.Int32, + m_block: Int32, smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, - tiled_mma_SdP: cute.TiledMma, - tiled_mma_dKV: cute.TiledMma, - tiled_mma_dQaccum: cute.TiledMma, + mma_qk_fn: Callable, + mma_dov_fn: Callable, + mma_pdo_fn: Callable, + mma_dsq_fn: Callable, + mma_dsk_fn: Callable, pipeline_q: cutlass.pipeline.PipelineAsync, - pipeline_lse: cutlass.pipeline.PipelineAsync, - pipeline_dPsum: cutlass.pipeline.PipelineAsync, - pipeline_dO: cutlass.pipeline.PipelineAsync, - tLSEsLSE_2D: cute.Tensor, - tdPsumsdPsum_2D: cute.Tensor, - sP: Optional[cute.Tensor], - sdS: Optional[cute.Tensor], - sdQaccum: cute.Tensor, - acc_dV: cute.Tensor, - acc_dK: cute.Tensor, - tSrQ: cute.Tensor, - tSrK: cute.Tensor, + pipeline_do: cutlass.pipeline.PipelineAsync, + tLSEsLSE: cute.Tensor, + tLSEsdPsum: cute.Tensor, tPsP: Optional[cute.Tensor], tdSsdS: Optional[cute.Tensor], - tdVrPt: cute.Tensor, - tdVrdOt: cute.Tensor, - tdKrdSt: cute.Tensor, - tdKrQt: cute.Tensor, - tdPrdO: cute.Tensor, - tdPrV: cute.Tensor, - tdQaccumrdS: cute.Tensor, - tdQaccumrK: cute.Tensor, - tdQaccumsdQaccum: cute.Tensor, + tdQsdQaccum: cute.Tensor, smem_thr_copy_PdS: cute.TiledCopy, smem_thr_copy_dQaccum: cute.TiledCopy, - softmax_scale_log2: cutlass.Float32 = 1.0, + softmax_scale_log2: Float32, + acc_dV, + acc_dK, ): + smem_idx = smem_pipe_read.index # (1) [GEMM 1] S = Q @ K^T pipeline_q.consumer_wait(smem_pipe_read, pipeline_q.consumer_try_wait(smem_pipe_read)) - acc_S = cute.make_fragment( - tiled_mma_SdP.partition_shape_C((self.m_block_size, self.n_block_size)), cutlass.Float32 - ) - - sm90_utils.gemm( - tiled_mma_SdP, - acc_S, - tSrQ[None, None, None, smem_pipe_read.index], - tSrK, - zero_init=True, - wg_wait=0, - ) - - # (2) [Pointwise 1] P = exp(S - LSE) - pipeline_lse.consumer_wait(smem_pipe_read, pipeline_lse.consumer_try_wait(smem_pipe_read)) - - tLSErLSE = cute.make_fragment_like(tLSEsLSE_2D[None, 0]) - cute.autovec_copy(tLSEsLSE_2D[None, smem_pipe_read.index], tLSErLSE) - - acc_P_mn = utils.make_acc_tensor_mn_view(acc_S) - for r in cutlass.range_constexpr(cute.size(acc_P_mn, mode=[0])): - acc_P_mn[r, None].store( - cute.exp2(acc_P_mn[r, None].load() * softmax_scale_log2 - tLSErLSE[r]) + acc_S = mma_qk_fn(A_idx=smem_idx, wg_wait=-1) + # S2R for LSE + tLSErLSE = cute.make_fragment_like(tLSEsLSE[None, 0]) + cute.autovec_copy(tLSEsLSE[None, smem_idx], tLSErLSE) + # (2) [GEMM 2] dP = dO @ V.T + pipeline_do.consumer_wait(smem_pipe_read, pipeline_do.consumer_try_wait(smem_pipe_read)) + acc_dP = mma_dov_fn(A_idx=smem_idx, wg_wait=1) + # (3) [Pointwise 1] P = exp(S - LSE) + acc_S_mn = utils.make_acc_tensor_mn_view(acc_S) + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_S_mn) + for r in cutlass.range_constexpr(cute.size(acc_S_mn, mode=[0])): + acc_S_mn[r, None].store( + cute.math.exp2( + acc_S_mn[r, None].load() * softmax_scale_log2 - tLSErLSE[r], fastmath=True + ) ) - - # fp32->bf16 + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_S_mn) + # Convert P from f32 -> f16 tdVrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) tdVrP = cute.make_fragment_like(tdVrP_acc, self.dtype) utils.cvt_f16(tdVrP_acc, tdVrP) + # S2R for dPsum + tLSErdPsum = cute.make_fragment_like(tLSEsdPsum[None, 0]) + cute.autovec_copy(tLSEsdPsum[None, smem_idx], tLSErdPsum) - # cp: rmem->smem + PdS_smem_idx = smem_idx if const_expr(self.dS_stage > 1) else 0 + # R2S for P tPrP = smem_thr_copy_PdS.retile(tdVrP) - - cute.arch.barrier( - barrier_id=int(NamedBarrierBwd.Epilogue), number_of_threads=self.num_mma_threads - ) - cute.arch.barrier( - barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads - ) - cute.copy(smem_thr_copy_PdS, tPrP, tPsP) - - """ - if warp_group_idx == 0 and cute.arch.thread_idx()[0] == 128 and m_block == 0 and n_block == 0 and head_idx == 0 and batch_idx == 0: - for j in cutlass.range_constexpr(16): - cute.printf("%.15f", tPrP[j].to(cutlass.Float32)) - """ - - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta - ) - cute.arch.barrier( - barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads - ) - - pipeline_lse.consumer_release(smem_pipe_read) - - # (3) [GEMM 2] dP = dO @ V.T - pipeline_dO.consumer_wait(smem_pipe_read, pipeline_dO.consumer_try_wait(smem_pipe_read)) - acc_dP = cute.make_fragment( - tiled_mma_SdP.partition_shape_C((self.m_block_size, self.n_block_size)), cutlass.Float32 - ) - - sm90_utils.gemm( - tiled_mma_SdP, - acc_dP, - tdPrdO[None, None, None, smem_pipe_read.index], - tdPrV, - zero_init=True, - wg_wait=-0, - ) - - # (4) [GEMM 3] dV += P.T @ dO - sm90_utils.gemm( - tiled_mma_dKV, - acc_dV, - tdVrPt, - tdVrdOt[None, None, None, smem_pipe_read.index], - zero_init=False, - wg_wait=0, - ) - - pipeline_dO.consumer_release(smem_pipe_read) + # sync to make sure P has already been used in the previous iteration before writing new vals + if const_expr(self.dS_stage == 1): + cute.arch.barrier( + barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads + ) + cute.copy(smem_thr_copy_PdS, tPrP, tPsP[None, None, None, PdS_smem_idx]) # (4) [Pointwise 2] dS = P*(dP-dPsum) - pipeline_dPsum.consumer_wait( - smem_pipe_read, pipeline_dPsum.consumer_try_wait(smem_pipe_read) - ) - - # dPsum - tdPsumrdPsum = cute.make_fragment_like(tdPsumsdPsum_2D[None, 0]) - cute.autovec_copy(tdPsumsdPsum_2D[None, smem_pipe_read.index], tdPsumrdPsum) - + warpgroup.wait_group(0) acc_dP_mn = utils.make_acc_tensor_mn_view(acc_dP) + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dP_mn) for r in cutlass.range_constexpr(cute.size(acc_dP_mn, mode=[0])): acc_dP_mn[r, None].store( - acc_P_mn[r, None].load() * (acc_dP_mn[r, None].load() - tdPsumrdPsum[r]) + acc_S_mn[r, None].load() * (acc_dP_mn[r, None].load() - tLSErdPsum[r]) ) - - # fp32->bf16 + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dP_mn) + # Convert dS from f32 -> f16 tdKrdS_acc = cute.make_tensor(acc_dP.iterator, utils.convert_layout_acc_frgA(acc_dP.layout)) tdKrdS = cute.make_fragment_like(tdKrdS_acc, self.dtype) utils.cvt_f16(tdKrdS_acc, tdKrdS) - tdSrdS = smem_thr_copy_PdS.retile(tdKrdS) - - cute.arch.barrier( - barrier_id=int(NamedBarrierBwd.Epilogue), number_of_threads=self.num_mma_threads + # If there's double buffering on dS, we don't need to sync here. + # Otherwise we might have WG1 writing to dS before WG2 is done reading from it during MmadQ. + # But because both WGs have to sync at the end of the loop and double buffering, + # this race condition is not possible. + # This sync is to ensure (1) P is written in case of !Mma_dKV_is_RS and + # (2) dS is already read by the Mma in the previous iteration in case of Mma_dKV_is_RS. + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta ) cute.arch.barrier( barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads ) - cute.copy(smem_thr_copy_PdS, tdSrdS, tdSsdS) + # R2S for dS + tdSrdS = smem_thr_copy_PdS.retile(tdKrdS) + cute.copy(smem_thr_copy_PdS, tdSrdS, tdSsdS[None, None, None, PdS_smem_idx]) + + # (4) [GEMM 3] dV += P.T @ dO + mma_pdo_fn(A_idx=PdS_smem_idx, B_idx=smem_idx, zero_init=False, wg_wait=-1) + # smem fence to make sure sdS is written before it's read by WGMMA cute.arch.fence_proxy( cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta ) cute.arch.barrier( barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads ) - - pipeline_dPsum.consumer_release(smem_pipe_read) - # (6) [GEMM 4] dQ = dS @ K - acc_dQ = cute.make_fragment( - tiled_mma_dQaccum.partition_shape_C((self.m_block_size, self.head_dim_padded)), - cutlass.Float32, - ) - cute.arch.barrier( - barrier_id=int(NamedBarrierBwd.Epilogue), number_of_threads=self.num_mma_threads - ) - sm90_utils.gemm( - tiled_mma_dQaccum, acc_dQ, tdQaccumrdS, tdQaccumrK, zero_init=True, wg_wait=0 - ) + acc_dQ = mma_dsk_fn(A_idx=PdS_smem_idx, wg_wait=1) + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dV) + pipeline_do.consumer_release(smem_pipe_read) # release dO as dV mma is done + + # (7) [GEMM 5] dK += dS.T @ Q + mma_dsq_fn(A_idx=PdS_smem_idx, B_idx=smem_idx, zero_init=False, wg_wait=1) + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dQ) - cute.arch.barrier( - barrier_id=int(NamedBarrierBwd.Epilogue), number_of_threads=self.num_mma_threads - ) cute.arch.barrier( barrier_id=int(NamedBarrierBwd.dQEmpty), number_of_threads=self.num_mma_threads + cute.arch.WARP_SIZE, ) - - tdQaccumrdQaccum_tmp = cute.make_tensor( - acc_dQ.iterator, cute.make_layout(tdQaccumsdQaccum.shape) - ) - cute.copy(smem_thr_copy_dQaccum, tdQaccumrdQaccum_tmp, tdQaccumsdQaccum) - + tdQrdQaccum_tmp = cute.make_tensor(acc_dQ.iterator, cute.make_layout(tdQsdQaccum.shape)) + cute.copy(smem_thr_copy_dQaccum, tdQrdQaccum_tmp, tdQsdQaccum) cute.arch.fence_proxy( cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta ) @@ -1248,16 +1032,10 @@ def mma_one_m_block( number_of_threads=self.num_mma_threads + cute.arch.WARP_SIZE, ) - # (7) [GEMM 5] dK += dS.T @ Q - sm90_utils.gemm( - tiled_mma_dKV, - acc_dK, - tdKrdSt, - tdKrQt[None, None, None, smem_pipe_read.index], - zero_init=False, - wg_wait=0, - ) + warpgroup.wait_group(0) + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dK) pipeline_q.consumer_release(smem_pipe_read) + # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("tidx = {}, m_block = {}, after pipeline_q consumer release", cute.arch.thread_idx()[0], m_block) smem_pipe_read.advance() return smem_pipe_read @@ -1274,44 +1052,45 @@ def epilogue_dKV( seqlen: SeqlenInfoQK, gmem_tiled_copy_dV: cute.TiledCopy, gmem_tiled_copy_dK: cute.TiledCopy, - tiled_mma_dKV: cute.TiledMma, - tidx: cutlass.Int32, - n_block: cutlass.Int32, - head_idx: cutlass.Int32, - batch_idx: cutlass.Int32, + tiled_mma_dK: cute.TiledMma, + tiled_mma_dV: cute.TiledMma, + tidx: Int32, + n_block: Int32, + head_idx: Int32, + batch_idx: Int32, ): - ### RMEM --> SMEM rdV = cute.make_fragment_like(acc_dV, self.dtype) rdV.store(acc_dV.load().to(self.dtype)) - rdK = cute.make_fragment_like(acc_dK, self.dtype) rdK.store(acc_dK.load().to(self.dtype)) + # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("tidx = {}, before epilogue sync", cute.arch.thread_idx()[0]) cute.arch.barrier( barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads ) + # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("tidx = {}, after epilogue sync", cute.arch.thread_idx()[0]) smem_copy_atom_dKV = cute.make_copy_atom( cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=False, num_matrices=4), self.dtype, ) - smem_thr_copy_dKV = cute.make_tiled_copy_C(smem_copy_atom_dKV, tiled_mma_dKV).get_slice( - tidx - ) + smem_thr_copy_dK = cute.make_tiled_copy_C(smem_copy_atom_dKV, tiled_mma_dK).get_slice(tidx) + smem_thr_copy_dV = cute.make_tiled_copy_C(smem_copy_atom_dKV, tiled_mma_dV).get_slice(tidx) - taccdVrdV = smem_thr_copy_dKV.retile(rdV) - taccdVsdV = smem_thr_copy_dKV.partition_D(sV) # reuse sV SMEM + # rmem -> smem + taccdVrdV = smem_thr_copy_dV.retile(rdV) + taccdVsdV = smem_thr_copy_dV.partition_D(sV) # reuse sV SMEM cute.copy(smem_copy_atom_dKV, taccdVrdV, taccdVsdV) - taccdKrdK = smem_thr_copy_dKV.retile(rdK) - taccdKsdK = smem_thr_copy_dKV.partition_D(sK) # reuse sK SMEM + taccdKrdK = smem_thr_copy_dK.retile(rdK) + taccdKsdK = smem_thr_copy_dK.partition_D(sK) # reuse sK SMEM cute.copy(smem_copy_atom_dKV, taccdKrdK, taccdKsdK) # SMEM -> GMEM - cdV = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded)) + cdV = cute.make_identity_tensor((self.tile_n, self.tile_hdimv)) mdV_cur = mdV[None, None, head_idx, batch_idx] - cdK = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded)) + cdK = cute.make_identity_tensor((self.tile_n, self.tile_hdim)) mdK_cur = mdK[None, None, head_idx, batch_idx] cute.arch.barrier( @@ -1328,10 +1107,10 @@ def epilogue_dKV( tdKrdK = cute.make_fragment_like(tdKsdK, self.dtype) cute.autovec_copy(tdKsdK, tdKrdK) - gdV = cute.local_tile(mdV_cur, (self.n_block_size, self.head_dim_padded), (n_block, 0)) + gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdim), (n_block, 0)) tdVgdV = gmem_thr_copy_dV.partition_D(gdV) - gdK = cute.local_tile(mdK_cur, (self.n_block_size, self.head_dim_padded), (n_block, 0)) + gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdim), (n_block, 0)) tdKgdK = gmem_thr_copy_dK.partition_D(gdK) tdVcdV = gmem_thr_copy_dV.partition_S(cdV) @@ -1342,7 +1121,7 @@ def epilogue_dKV( tdKpdK = utils.predicate_k(tdKcdK, limit=mdK.shape[1]) for rest_m in cutlass.range_constexpr(cute.size(tdVrdV.shape[1])): - row_idx = n_block * self.n_block_size + t0dVcdV[0, rest_m, 0][0] + row_idx = n_block * self.tile_n + t0dVcdV[0, rest_m, 0][0] if row_idx < seqlen.seqlen_k: cute.copy( gmem_tiled_copy_dV, @@ -1362,50 +1141,39 @@ def epilogue_dKV( ) @cute.jit - def dQaccum_writer( + def dQaccum_store( self, mdQaccum: cute.Tensor, sdQaccum: cute.Tensor, + block_info: BlockInfo, TileSchedulerCls: cutlass.Constexpr[Callable], SeqlenInfoCls: cutlass.Constexpr[Callable], ): tile_elems = cute.cosize(sdQaccum.layout) - tile_bytes = cutlass.Int32(tile_elems * 4) + tile_bytes = Int32(tile_elems * 4) tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() - while work_tile.is_valid_tile: n_block, head_idx, batch_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) - - # GMEM mdQaccum_cur = mdQaccum[None, head_idx, batch_idx] + base_flat = cute.domain_offset((seqlen.offset_q * self.tile_hdim,), mdQaccum_cur) - base_flat = cute.domain_offset((seqlen.offset_q * self.head_dim_padded,), mdQaccum_cur) - - m_block_min = cutlass.Int32(0) - m_block_max = cute.ceil_div(seqlen.seqlen_q, self.m_block_size) - + m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) for it_m in cutlass.range(m_block_max - m_block_min, unroll=1): m_block = m_block_max - 1 - it_m - cute.arch.barrier( barrier_id=int(NamedBarrierBwd.dQFull), number_of_threads=self.num_mma_threads + cute.arch.WARP_SIZE, ) - gdQaccum_block = cute.local_tile(base_flat, (tile_elems,), (m_block,)) - with cute.arch.elect_one(): sm90_utils.tma_reduce_add_bulk_f32( - sdQaccum.iterator, - gdQaccum_block.iterator, - tile_bytes, + sdQaccum.iterator, gdQaccum_block.iterator, tile_bytes ) cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=True) - cute.arch.barrier_arrive( barrier_id=int(NamedBarrierBwd.dQEmpty), number_of_threads=self.num_mma_threads + cute.arch.WARP_SIZE, @@ -1413,21 +1181,3 @@ def dQaccum_writer( tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() - - @cute.jit - def load_m_tile( - self, - tma_atom: cute.CopyAtom, - tXgX: cute.Tensor, - tXsX: cute.Tensor, - pipeline: cutlass.pipeline.PipelineAsync, - block: cutlass.Int32, - producer_state: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, - ): - pipeline.producer_acquire(producer_state) - cute.copy( - tma_atom, - tXgX[None, block], - tXsX[None, producer_state.index], - tma_bar_ptr=pipeline.producer_get_barrier(producer_state), - ) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 6e56b23d76e..885967158a8 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -1530,12 +1530,8 @@ def load( gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (None, 0)) if const_expr(self.use_tma_Q): gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0)) - tQsQ, tQgQ = cpasync.tma_partition( - tma_atom_Q, - 0, - cute.make_layout(1), - cute.group_modes(sQ, 0, 2), - cute.group_modes(gQ, 0, 2), + load_Q, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ, sQ, single_stage=True ) # TODO: mcast # TODO check warp_idx if we have 128 producer threads @@ -1549,7 +1545,7 @@ def load( q_producer_phase ^= 1 with cute.arch.elect_one(): cute.arch.mbarrier_arrive_and_expect_tx(mbar_ptr_Q, self.tma_copy_q_bytes) - cute.copy(tma_atom_Q, tQgQ, tQsQ, tma_bar_ptr=mbar_ptr_Q) + load_Q(tma_bar_ptr=mbar_ptr_Q) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) # if cute.arch.thread_idx()[0] == 0: # cute.printf("m_block = %d, n_block_min: %d, n_block_max: %d", m_block, n_block_min, n_block_max) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index b13589c5670..15c81b8c1db 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -35,7 +35,9 @@ from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 from flash_attn.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80 +from flash_attn.cute.flash_bwd_sm90 import FlashAttentionBackwardSm90 from flash_attn.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess +from flash_attn.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess_sm90 from flash_attn.cute.flash_fwd_combine import FlashAttentionForwardCombine @@ -382,6 +384,8 @@ def _flash_attn_bwd( n_block_size, num_threads, num_stages_Q, num_stages_dO, SdP_swapAB, dKV_swapAB, dQ_swapAB, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs ) + m_block_size = 64 + n_block_size = 128 if compile_key not in _flash_attn_bwd.compile_cache: fa_bwd_sm80 = FlashAttentionBackwardSm80( dtype, @@ -402,9 +406,30 @@ def _flash_attn_bwd( AtomLayoutMdQ, V_in_regs=V_in_regs, ) + fa_bwd_sm90 = FlashAttentionBackwardSm90( + dtype, + head_dim, + head_dim_v, + qhead_per_kvhead, + m_block_size, + n_block_size, + # num_stages_Q, + # num_stages_dO, + # num_threads, + # causal, + # SdP_swapAB, + # dKV_swapAB, + # dQ_swapAB, + # AtomLayoutMSdP, + # AtomLayoutNdKV, + # AtomLayoutMdQ, + # V_in_regs=V_in_regs, + ) # TODO: check @can_implement _flash_attn_bwd.compile_cache[compile_key] = cute.compile( - fa_bwd_sm80, q_tensor, k_tensor, v_tensor, do_tensor, lse_log2_tensor, dpsum_tensor, + # fa_bwd_sm80, + fa_bwd_sm90, + q_tensor, k_tensor, v_tensor, do_tensor, lse_log2_tensor, dpsum_tensor, dq_accum_tensor, dk_tensor if qhead_per_kvhead == 1 else dk_accum_tensor, dv_tensor if qhead_per_kvhead == 1 else dv_accum_tensor, @@ -421,7 +446,8 @@ def _flash_attn_bwd( # Postprocess kernel: convert dq_accum from float32 to dq in bf16/fp16 compile_key_post = (dtype, head_dim, m_block_size, num_threads, AtomLayoutMdQ, dQ_swapAB) if compile_key_post not in _flash_attn_bwd.compile_cache_post: - fa_bwd_post = FlashAttentionBackwardPostprocess( + # fa_bwd_post = FlashAttentionBackwardPostprocess( + fa_bwd_post = FlashAttentionBackwardPostprocess_sm90( dtype, head_dim, m_block_size, num_threads, AtomLayoutMdQ, dQ_swapAB ) # TODO: check @can_implement diff --git a/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp index 0232b90e54a..c67ae17969f 100644 --- a/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp @@ -938,7 +938,7 @@ struct CollectiveMainloopBwdSm90 { Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select(TileShape_MNK{})); Tensor tdQrdS_cur = tdQrdS(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); flash::gemm(tiled_mma_dQ, tdQrdS_cur, tdQrK, tdQrdQ); - pipeline_do.consumer_release(smem_pipe_read_do_cur); // release dQ + pipeline_do.consumer_release(smem_pipe_read_do_cur); // release dO if constexpr (Mma_dKV_is_RS) { Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs(tdPrdP.layout())); From 093b935d9631191b2089dff38050040c7bee7ea8 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 12 Oct 2025 21:34:55 -0400 Subject: [PATCH 121/258] [Cute,Bwd,Sm90] Use cp.async.bulk instead of TMA for LSE & dPsum --- flash_attn/cute/copy_utils.py | 58 +++++++++++++++++++++ flash_attn/cute/flash_bwd_sm90.py | 83 ++++++++++--------------------- flash_attn/cute/flash_fwd.py | 10 ++-- flash_attn/cute/hopper_helpers.py | 18 ++++--- 4 files changed, 98 insertions(+), 71 deletions(-) diff --git a/flash_attn/cute/copy_utils.py b/flash_attn/cute/copy_utils.py index 822cdde2a4f..d69b3e7e0a4 100644 --- a/flash_attn/cute/copy_utils.py +++ b/flash_attn/cute/copy_utils.py @@ -9,6 +9,7 @@ from cutlass import Int32, Boolean, const_expr from cutlass.cute.nvgpu import cpasync from cutlass.cutlass_dsl import dsl_user_op +from cutlass._mlir.dialects import llvm import cutlass.pipeline @@ -84,6 +85,63 @@ def tiled_copy_2d( return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) +@dsl_user_op +def cpasync_bulk_g2s( + gmem_ptr: cute.Pointer, + smem_ptr: cute.Pointer, + tma_bar_ptr: cute.Pointer, + size: int | Int32, + *, + loc=None, + ip=None, +): + gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value() + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + mbar_ptr_i32 = tma_bar_ptr.toint(loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [gmem_ptr_i64, smem_ptr_i32, mbar_ptr_i32, Int32(size).ir_value()], + "cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes [$1], [$0], $3, [$2];", + "l,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +def cpasync_bulk_get_copy_fn( + src_tensor: cute.Tensor, + dst_tensor: cute.Tensor, + single_stage: bool = False, + **kwargs, +) -> Callable: + # src_is_smem = const_expr( + # isinstance(src_tensor.iterator, cute.Pointer) + # and src_tensor.memspace == cute.AddressSpace.smem + # ) + group_rank_src = const_expr(cute.rank(src_tensor) - (1 if not single_stage else 0)) + group_rank_dst = const_expr(cute.rank(dst_tensor) - (1 if not single_stage else 0)) + # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK) + src = cute.group_modes(src_tensor, 0, group_rank_src) + dst = cute.group_modes(dst_tensor, 0, group_rank_dst) + + def copy_bulk(src_idx, dst_idx, **new_kwargs): + size = const_expr(cute.size(src.shape[:-1]) * src.element_type.width // 8) + cpasync_bulk_g2s( + src[None, src_idx].iterator, + dst[None, dst_idx].iterator, + size=size, + **new_kwargs, + **kwargs + ) + + def copy_bulk_single_stage(**new_kwargs): + size = const_expr(cute.size(src.shape) * src.element_type.width // 8) + cpasync_bulk_g2s(src.iterator, dst.iterator, size=size, **new_kwargs, **kwargs) + + return copy_bulk if const_expr(not single_stage) else copy_bulk_single_stage + + def tma_get_copy_fn( atom: cute.CopyAtom, cta_coord: cute.Coord, diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index d391f9f4bf9..7d7ab3d5fde 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -183,24 +183,18 @@ def _get_tiled_mma(self): tiler_mn=(64, self.tile_n // 2), ) # dV = P.T @ dO, dK = dS.T @ Q - tiled_mma_dK = sm90_utils_basic.make_trivial_tiled_mma( - self.dtype, - self.dtype, - warpgroup.OperandMajorMode.MN, - warpgroup.OperandMajorMode.MN, - Float32, - atom_layout_mnk=(self.tile_n // 64, 1, 1), - tiler_mn=(64, self.tile_hdim), - ) - tiled_mma_dV = sm90_utils_basic.make_trivial_tiled_mma( - self.dtype, - self.dtype, - warpgroup.OperandMajorMode.MN, - warpgroup.OperandMajorMode.MN, - Float32, - atom_layout_mnk=(self.tile_n // 64, 1, 1), - tiler_mn=(64, self.tile_hdimv), - ) + tiled_mma_dK, tiled_mma_dV = [ + sm90_utils_basic.make_trivial_tiled_mma( + self.dtype, + self.dtype, + warpgroup.OperandMajorMode.MN, + warpgroup.OperandMajorMode.MN, + Float32, + atom_layout_mnk=(self.tile_n // 64, 1, 1), + tiler_mn=(64, tile_hdim), + ) + for tile_hdim in (self.tile_hdim, self.tile_hdimv) + ] # dQ = dS @ K tiled_mma_dQ = sm90_utils_basic.make_trivial_tiled_mma( self.dtype, @@ -242,8 +236,6 @@ class SharedStorageQKV: mbar_ptr_V: cute.struct.MemRange[cutlass.Int64, 2] mbar_ptr_Q: cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] mbar_ptr_dO: cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] - mbar_ptr_LSE: cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] - mbar_ptr_dPsum: cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] sLSE: sLSE_struct sdPsum: sdPsum_struct sQ: sQ_struct @@ -316,6 +308,8 @@ def __call__( self.num_mma_regs = 240 self.num_producer_regs = 24 + # self.num_mma_regs = 232 + # self.num_producer_regs = 40 self._setup_attributes() SharedStorage = self._get_shared_storage_cls() @@ -358,18 +352,6 @@ def __call__( cute.select(self.sdO_layout, mode=[0, 1]), (self.tile_m, self.tile_hdimv), ) - tma_atom_LSE, tma_tensor_LSE = cpasync.make_tiled_tma_atom( - cpasync.CopyBulkTensorTileG2SOp(), - mLSE, - cute.make_layout(self.tile_m), - (self.tile_m,), - ) - tma_atom_dPsum, tma_tensor_dPsum = cpasync.make_tiled_tma_atom( - cpasync.CopyBulkTensorTileG2SOp(), - mdPsum, - cute.make_layout(self.tile_m), - (self.tile_m,), - ) TileScheduler = SingleTileScheduler tile_sched_args = TileSchedulerArguments( cute.ceil_div(cute.size(mK.shape[0]), self.tile_n), @@ -398,15 +380,13 @@ def __call__( tma_tensor_Q, tma_tensor_K, tma_tensor_V, - tma_tensor_LSE, - tma_tensor_dPsum, tma_tensor_dO, tma_atom_Q, tma_atom_K, tma_atom_V, - tma_atom_LSE, - tma_atom_dPsum, tma_atom_dO, + mLSE, + mdPsum, mdK, mdV, mdQaccum, @@ -442,15 +422,13 @@ def kernel( mQ: cute.Tensor, mK: cute.Tensor, mV: cute.Tensor, - mLSE: cute.Tensor, - mdPsum: cute.Tensor, mdO: cute.Tensor, tma_atom_Q: Optional[cute.CopyAtom], tma_atom_K: Optional[cute.CopyAtom], tma_atom_V: Optional[cute.CopyAtom], - tma_atom_LSE: Optional[cute.CopyAtom], - tma_atom_dPsum: Optional[cute.CopyAtom], tma_atom_dO: Optional[cute.CopyAtom], + mLSE: cute.Tensor, + mdPsum: cute.Tensor, mdK: cute.Tensor, mdV: cute.Tensor, mdQaccum: cute.Tensor, @@ -480,8 +458,6 @@ def kernel( cpasync.prefetch_descriptor(tma_atom_Q) cpasync.prefetch_descriptor(tma_atom_K) cpasync.prefetch_descriptor(tma_atom_V) - cpasync.prefetch_descriptor(tma_atom_LSE) - cpasync.prefetch_descriptor(tma_atom_dPsum) cpasync.prefetch_descriptor(tma_atom_dO) smem = cutlass.utils.SmemAllocator() @@ -565,9 +541,9 @@ def kernel( mQ, mK, mV, + mdO, mLSE, mdPsum, - mdO, sQ, sK, sV, @@ -578,8 +554,6 @@ def kernel( tma_atom_K, tma_atom_V, tma_atom_dO, - tma_atom_LSE, - tma_atom_dPsum, pipeline_q, pipeline_do, mbar_ptr_K, @@ -636,9 +610,9 @@ def load( mQ: cute.Tensor, mK: cute.Tensor, mV: cute.Tensor, + mdO: cute.Tensor, mLSE: cute.Tensor, mdPsum: cute.Tensor, - mdO: cute.Tensor, sQ: cute.Tensor, sK: cute.Tensor, sV: cute.Tensor, @@ -649,8 +623,6 @@ def load( tma_atom_K: cute.CopyAtom, tma_atom_V: cute.CopyAtom, tma_atom_dO: cute.CopyAtom, - tma_atom_LSE: cute.CopyAtom, - tma_atom_dPsum: cute.CopyAtom, pipeline_q: cutlass.pipeline.PipelineAsync, pipeline_do: cutlass.pipeline.PipelineAsync, mbar_ptr_K: cutlass.Pointer, @@ -700,13 +672,9 @@ def load( tma_atom_dO, 0, cute.make_layout(1), gdO, sdO ) load_dO = copy_utils.tma_producer_copy_fn(load_dO, pipeline_do) - load_LSE, _, _ = copy_utils.tma_get_copy_fn( - tma_atom_LSE, 0, cute.make_layout(1), gLSE, sLSE - ) + load_LSE = copy_utils.cpasync_bulk_get_copy_fn(gLSE, sLSE) load_LSE = copy_utils.tma_producer_copy_fn(load_LSE, pipeline_q) - load_dPsum, _, _ = copy_utils.tma_get_copy_fn( - tma_atom_dPsum, 0, cute.make_layout(1), gdPsum, sdPsum - ) + load_dPsum = copy_utils.cpasync_bulk_get_copy_fn(gdPsum, sdPsum) load_dPsum = copy_utils.tma_producer_copy_fn(load_dPsum, pipeline_do) # TODO: need to wait if we do persistent kernel @@ -721,10 +689,13 @@ def load( m_block = m_block_max - i - 1 pipeline_q.producer_acquire(producer_state) load_Q(m_block, producer_state=producer_state) - load_LSE(m_block, producer_state=producer_state) + # cp.async.bulk is using ptx, so we need to elect one thread to do it + with cute.arch.elect_one(): + load_LSE(m_block, producer_state=producer_state) pipeline_do.producer_acquire(producer_state) load_dO(m_block, producer_state=producer_state) - load_dPsum(m_block, producer_state=producer_state) + with cute.arch.elect_one(): + load_dPsum(m_block, producer_state=producer_state) producer_state.advance() tile_scheduler.prefetch_next_work() diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 885967158a8..00721f07362 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -366,17 +366,13 @@ def epilogue( cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) cute.arch.barrier_arrive(barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE) gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0)) - tOsO, tOgO = cpasync.tma_partition( - tma_atom_O, - 0, - cute.make_layout(1), - cute.group_modes(sO, 0, 2), - cute.group_modes(gO, 0, 2), + store_O, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_O, 0, cute.make_layout(1), sO, gO, single_stage=True ) warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) if warp_idx == 4: cute.arch.barrier(barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE) - cute.copy(tma_atom_O, tOsO, tOgO) + store_O() cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=True) else: diff --git a/flash_attn/cute/hopper_helpers.py b/flash_attn/cute/hopper_helpers.py index 5a46139fb6b..56d6a1651e1 100644 --- a/flash_attn/cute/hopper_helpers.py +++ b/flash_attn/cute/hopper_helpers.py @@ -2,7 +2,7 @@ from typing import Type, Union, Optional import cutlass import cutlass.cute as cute -from cutlass import const_expr +from cutlass import Int32, const_expr from cutlass.cute.nvgpu import warpgroup from cutlass._mlir.dialects import llvm from cutlass.cutlass_dsl import Numeric, dsl_user_op @@ -63,15 +63,17 @@ def make_smem_layout( @dsl_user_op def tma_reduce_add_bulk_f32( - smem_ptr: cute.Pointer, - gmem_ptr: cute.Pointer, - store_bytes: cutlass.Int32, - *, loc=None, ip=None - ): - smem_u32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + smem_ptr: cute.Pointer, + gmem_ptr: cute.Pointer, + store_bytes: Int32, + *, + loc=None, + ip=None, +): + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() llvm.inline_asm( None, - [gmem_ptr.llvm_ptr, smem_u32, store_bytes.ir_value()], + [gmem_ptr.llvm_ptr, smem_ptr_i32, store_bytes.ir_value()], "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;", "l,r,r", has_side_effects=True, From 9be4a621877fbcb7e60d147852021266cc34891d Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 12 Oct 2025 21:39:08 -0400 Subject: [PATCH 122/258] [Cute,Bwd,Sm90] Use 1 barrier for loading both K & V --- flash_attn/cute/flash_bwd_sm90.py | 49 +++++++++++++++---------------- 1 file changed, 23 insertions(+), 26 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index 7d7ab3d5fde..e74b6e5421f 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -232,8 +232,7 @@ def _get_shared_storage_cls(self): @cute.struct class SharedStorageQKV: - mbar_ptr_K: cute.struct.MemRange[cutlass.Int64, 2] - mbar_ptr_V: cute.struct.MemRange[cutlass.Int64, 2] + mbar_ptr_KV: cute.struct.MemRange[cutlass.Int64, 2] mbar_ptr_Q: cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] mbar_ptr_dO: cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] sLSE: sLSE_struct @@ -463,13 +462,11 @@ def kernel( smem = cutlass.utils.SmemAllocator() storage = smem.allocate(SharedStorage) - mbar_ptr_K = storage.mbar_ptr_K.data_ptr() - mbar_ptr_V = storage.mbar_ptr_V.data_ptr() + mbar_ptr_KV = storage.mbar_ptr_KV.data_ptr() # mbarrier init if warp_idx == 1: - cute.arch.mbarrier_init(mbar_ptr_K, 1) - cute.arch.mbarrier_init(mbar_ptr_V, 1) + cute.arch.mbarrier_init(mbar_ptr_KV, 1) pipeline_producer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread) pipeline_consumer_group = cutlass.pipeline.CooperativeGroup( @@ -556,8 +553,7 @@ def kernel( tma_atom_dO, pipeline_q, pipeline_do, - mbar_ptr_K, - mbar_ptr_V, + mbar_ptr_KV, block_info, SeqlenInfoCls, TileSchedulerCls, @@ -591,8 +587,7 @@ def kernel( sdQaccum, pipeline_q, pipeline_do, - mbar_ptr_K, - mbar_ptr_V, + mbar_ptr_KV, tidx, gmem_tiled_copy_dV, gmem_tiled_copy_dK, @@ -625,8 +620,7 @@ def load( tma_atom_dO: cute.CopyAtom, pipeline_q: cutlass.pipeline.PipelineAsync, pipeline_do: cutlass.pipeline.PipelineAsync, - mbar_ptr_K: cutlass.Pointer, - mbar_ptr_V: cutlass.Pointer, + mbar_ptr_KV: cutlass.Pointer, block_info: BlockInfo, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, @@ -679,10 +673,11 @@ def load( # TODO: need to wait if we do persistent kernel with cute.arch.elect_one(): - cute.arch.mbarrier_arrive_and_expect_tx(mbar_ptr_K, self.tma_copy_bytes["K"]) - cute.arch.mbarrier_arrive_and_expect_tx(mbar_ptr_V, self.tma_copy_bytes["V"]) - load_K(tma_bar_ptr=mbar_ptr_K) - load_V(tma_bar_ptr=mbar_ptr_V) + cute.arch.mbarrier_arrive_and_expect_tx( + mbar_ptr_KV, self.tma_copy_bytes["K"] + self.tma_copy_bytes["V"] + ) + load_K(tma_bar_ptr=mbar_ptr_KV) + load_V(tma_bar_ptr=mbar_ptr_KV) m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) for i in cutlass.range(m_block_max - m_block_min, unroll=2): @@ -723,8 +718,7 @@ def mma( sdQaccum: cute.Tensor, pipeline_q: cutlass.pipeline.PipelineAsync, pipeline_do: cutlass.pipeline.PipelineAsync, - mbar_ptr_K: cutlass.Pointer, - mbar_ptr_V: cutlass.Pointer, + mbar_ptr_KV: cutlass.Pointer, tidx: Int32, gmem_tiled_copy_dV: cute.TiledCopy, gmem_tiled_copy_dK: cute.TiledCopy, @@ -777,15 +771,15 @@ def mma( sLSE.iterator, cute.make_layout( (self.tile_m, self.tile_n, self.num_stages), - stride=(1, 0, cute.round_up(self.tile_m, 64)) - ) + stride=(1, 0, cute.round_up(self.tile_m, 64)), + ), ) sdPsum_mma = cute.make_tensor( sdPsum.iterator, cute.make_layout( (self.tile_m, self.tile_n, self.num_stages), - stride=(1, 0, cute.round_up(self.tile_m, 64)) - ) + stride=(1, 0, cute.round_up(self.tile_m, 64)), + ), ) LSEslice = (None, 0, None) tLSEsLSE = utils.make_acc_tensor_mn_view(thr_mma_SdP.partition_C(sLSE_mma))[LSEslice] @@ -804,10 +798,14 @@ def mma( ) mma_qk_fn = partial(mma_zero_init, tiled_mma_SdP, (self.tile_m, self.tile_n), tSrQ, tSrK) - mma_dov_fn = partial(mma_zero_init, tiled_mma_SdP, (self.tile_m, self.tile_n), tdPrdO, tdPrV) + mma_dov_fn = partial( + mma_zero_init, tiled_mma_SdP, (self.tile_m, self.tile_n), tdPrdO, tdPrV + ) mma_pdo_fn = partial(mma_sm90, tiled_mma_dV, acc_dV, tdVrPt, tdVrdOt) mma_dsq_fn = partial(mma_sm90, tiled_mma_dK, acc_dK, tdKrdSt, tdKrQt) - mma_dsk_fn = partial(mma_zero_init, tiled_mma_dQ, (self.tile_m, self.tile_hdim), tdQrdS, tdQrKt) + mma_dsk_fn = partial( + mma_zero_init, tiled_mma_dQ, (self.tile_m, self.tile_hdim), tdQrdS, tdQrKt + ) mma_one_m_block_all = partial( self.mma_one_m_block, @@ -846,8 +844,7 @@ def mma( m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("tidx = {}, m_block_min = {}, m_block_max = {}", cute.arch.thread_idx()[0], m_block_min, m_block_max) - cute.arch.mbarrier_wait(mbar_ptr_K, phase=kv_consumer_phase) - cute.arch.mbarrier_wait(mbar_ptr_V, phase=kv_consumer_phase) + cute.arch.mbarrier_wait(mbar_ptr_KV, phase=kv_consumer_phase) kv_consumer_phase ^= 1 for m_tile in cutlass.range(m_block_max - m_block_min, unroll=1): From 557648058c95337d10c43279459a9d729e9251ce Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 12 Oct 2025 21:54:14 -0400 Subject: [PATCH 123/258] [Cute,Bwd,Sm90] Don't clear dK & dV, use zero_init mma flag instead --- flash_attn/cute/flash_bwd_sm90.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index e74b6e5421f..3d58ccd1a4c 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -809,6 +809,7 @@ def mma( mma_one_m_block_all = partial( self.mma_one_m_block, + warp_group_idx=warp_group_idx, mma_qk_fn=mma_qk_fn, mma_dov_fn=mma_dov_fn, mma_pdo_fn=mma_pdo_fn, @@ -824,13 +825,10 @@ def mma( smem_thr_copy_PdS=smem_thr_copy_PdS, smem_thr_copy_dQaccum=smem_thr_copy_dQaccum, softmax_scale_log2=softmax_scale_log2, - acc_dV=acc_dV, - acc_dK=acc_dK, + # acc_dV=acc_dV, + # acc_dK=acc_dK, ) - acc_dV.fill(0.0) - acc_dK.fill(0.0) - kv_consumer_phase = Int32(0) consumer_state = pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.num_stages @@ -847,9 +845,13 @@ def mma( cute.arch.mbarrier_wait(mbar_ptr_KV, phase=kv_consumer_phase) kv_consumer_phase ^= 1 + dKV_should_accumulate = False for m_tile in cutlass.range(m_block_max - m_block_min, unroll=1): m_block = m_block_max - 1 - m_tile - consumer_state = mma_one_m_block_all(warp_group_idx, m_block, consumer_state) + consumer_state = mma_one_m_block_all( + m_block, consumer_state, dKV_should_accumulate=dKV_should_accumulate + ) + dKV_should_accumulate = True # scale dK acc_dK.store(acc_dK.load() * softmax_scale) @@ -877,9 +879,9 @@ def mma( @cute.jit def mma_one_m_block( self, - warp_group_idx, m_block: Int32, smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, + warp_group_idx: Int32, mma_qk_fn: Callable, mma_dov_fn: Callable, mma_pdo_fn: Callable, @@ -895,8 +897,9 @@ def mma_one_m_block( smem_thr_copy_PdS: cute.TiledCopy, smem_thr_copy_dQaccum: cute.TiledCopy, softmax_scale_log2: Float32, - acc_dV, - acc_dK, + # acc_dV, + # acc_dK, + dKV_should_accumulate: Boolean = True, ): smem_idx = smem_pipe_read.index # (1) [GEMM 1] S = Q @ K^T @@ -968,7 +971,7 @@ def mma_one_m_block( cute.copy(smem_thr_copy_PdS, tdSrdS, tdSsdS[None, None, None, PdS_smem_idx]) # (4) [GEMM 3] dV += P.T @ dO - mma_pdo_fn(A_idx=PdS_smem_idx, B_idx=smem_idx, zero_init=False, wg_wait=-1) + mma_pdo_fn(A_idx=PdS_smem_idx, B_idx=smem_idx, zero_init=not dKV_should_accumulate, wg_wait=-1) # smem fence to make sure sdS is written before it's read by WGMMA cute.arch.fence_proxy( @@ -983,7 +986,7 @@ def mma_one_m_block( pipeline_do.consumer_release(smem_pipe_read) # release dO as dV mma is done # (7) [GEMM 5] dK += dS.T @ Q - mma_dsq_fn(A_idx=PdS_smem_idx, B_idx=smem_idx, zero_init=False, wg_wait=1) + mma_dsq_fn(A_idx=PdS_smem_idx, B_idx=smem_idx, zero_init=not dKV_should_accumulate, wg_wait=1) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dQ) cute.arch.barrier( From 5a5a65b48dc99fc7483d2a7d5cfb1d8befa89389 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 12 Oct 2025 22:07:42 -0400 Subject: [PATCH 124/258] [Cute,Bwd,Sm90] Use TMA to store dK & dV --- flash_attn/cute/flash_bwd_sm90.py | 131 +++++++++++------------------- 1 file changed, 48 insertions(+), 83 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index 3d58ccd1a4c..5223cedd032 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -155,21 +155,10 @@ def _setup_attributes(self): ] self.sdQaccum_layout = cute.make_layout(self.tile_m * self.tile_hdim) - # dQaccum R->S self.r2s_tiled_copy_dQaccum = copy_utils.tiled_copy_1d( Float32, self.num_mma_threads, num_copy_elems=128 // Float32.width ) - # dV: S->G - tV_shape_dim_1 = self.sV_layout.outer.shape[1][0] - self.gmem_tiled_copy_dV = copy_utils.tiled_copy_2d( - self.dtype, tV_shape_dim_1, self.num_mma_threads - ) - # dK: S->G - tK_shape_dim_1 = self.sK_layout.outer.shape[1][0] - self.gmem_tiled_copy_dK = copy_utils.tiled_copy_2d( - self.dtype, tK_shape_dim_1, self.num_mma_threads - ) def _get_tiled_mma(self): # S = Q @ K.T, dP = dO @ V.T @@ -336,14 +325,12 @@ def __call__( mK, cute.select(self.sK_layout, mode=[0, 1]), (self.tile_n, self.tile_hdim), - 1, ) tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom( cpasync.CopyBulkTensorTileG2SOp(), mV, cute.select(self.sV_layout, mode=[0, 1]), (self.tile_n, self.tile_hdimv), - 1, ) tma_atom_dO, tma_tensor_dO = cpasync.make_tiled_tma_atom( cpasync.CopyBulkTensorTileG2SOp(), @@ -351,6 +338,19 @@ def __call__( cute.select(self.sdO_layout, mode=[0, 1]), (self.tile_m, self.tile_hdimv), ) + tma_atom_dK, tma_tensor_dK = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + mdK, + cute.select(self.sK_layout, mode=[0, 1]), + (self.tile_n, self.tile_hdim), + ) + tma_atom_dV, tma_tensor_dV = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + mdV, + cute.select(self.sV_layout, mode=[0, 1]), + (self.tile_n, self.tile_hdimv), + ) + TileScheduler = SingleTileScheduler tile_sched_args = TileSchedulerArguments( cute.ceil_div(cute.size(mK.shape[0]), self.tile_n), @@ -380,14 +380,16 @@ def __call__( tma_tensor_K, tma_tensor_V, tma_tensor_dO, + tma_tensor_dK, + tma_tensor_dV, tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_dO, + tma_atom_dK, + tma_atom_dV, mLSE, mdPsum, - mdK, - mdV, mdQaccum, self.sQ_layout, self.sK_layout, @@ -395,8 +397,6 @@ def __call__( self.sPdS_layout, self.sdO_layout, self.sdQaccum_layout, - self.gmem_tiled_copy_dV, - self.gmem_tiled_copy_dK, self.r2s_tiled_copy_dQaccum, tiled_mma_SdP, tiled_mma_dK, @@ -422,14 +422,16 @@ def kernel( mK: cute.Tensor, mV: cute.Tensor, mdO: cute.Tensor, - tma_atom_Q: Optional[cute.CopyAtom], - tma_atom_K: Optional[cute.CopyAtom], - tma_atom_V: Optional[cute.CopyAtom], - tma_atom_dO: Optional[cute.CopyAtom], - mLSE: cute.Tensor, - mdPsum: cute.Tensor, mdK: cute.Tensor, mdV: cute.Tensor, + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + tma_atom_dO: cute.CopyAtom, + tma_atom_dK: cute.CopyAtom, + tma_atom_dV: cute.CopyAtom, + mLSE: cute.Tensor, + mdPsum: cute.Tensor, mdQaccum: cute.Tensor, sQ_layout: cute.ComposedLayout, sK_layout: cute.ComposedLayout, @@ -437,8 +439,6 @@ def kernel( sPdS_layout: cute.ComposedLayout, sdO_layout: cute.ComposedLayout, sdQaccum_layout: cute.Layout, - gmem_tiled_copy_dV: cute.TiledCopy, - gmem_tiled_copy_dK: cute.TiledCopy, r2s_tiled_copy_dQaccum: cute.TiledCopy, tiled_mma_SdP: cute.TiledMma, tiled_mma_dK: cute.TiledMma, @@ -589,8 +589,8 @@ def kernel( pipeline_do, mbar_ptr_KV, tidx, - gmem_tiled_copy_dV, - gmem_tiled_copy_dK, + tma_atom_dK, + tma_atom_dV, r2s_tiled_copy_dQaccum, softmax_scale_log2, softmax_scale, @@ -720,8 +720,8 @@ def mma( pipeline_do: cutlass.pipeline.PipelineAsync, mbar_ptr_KV: cutlass.Pointer, tidx: Int32, - gmem_tiled_copy_dV: cute.TiledCopy, - gmem_tiled_copy_dK: cute.TiledCopy, + tma_atom_dK: cute.CopyAtom, + tma_atom_dV: cute.CopyAtom, r2s_tiled_copy_dQaccum: cute.TiledCopy, softmax_scale_log2: Float32, softmax_scale: Float32, @@ -863,8 +863,8 @@ def mma( mdK, sK, seqlen, - gmem_tiled_copy_dV, - gmem_tiled_copy_dK, + tma_atom_dK, + tma_atom_dV, tiled_mma_dK, tiled_mma_dV, tidx, @@ -1021,8 +1021,8 @@ def epilogue_dKV( mdK: cute.Tensor, sK: cute.Tensor, seqlen: SeqlenInfoQK, - gmem_tiled_copy_dV: cute.TiledCopy, - gmem_tiled_copy_dK: cute.TiledCopy, + tma_atom_dK: cute.CopyAtom, + tma_atom_dV: cute.CopyAtom, tiled_mma_dK: cute.TiledMma, tiled_mma_dV: cute.TiledMma, tidx: Int32, @@ -1035,11 +1035,9 @@ def epilogue_dKV( rdK = cute.make_fragment_like(acc_dK, self.dtype) rdK.store(acc_dK.load().to(self.dtype)) - # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("tidx = {}, before epilogue sync", cute.arch.thread_idx()[0]) cute.arch.barrier( barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads ) - # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("tidx = {}, after epilogue sync", cute.arch.thread_idx()[0]) smem_copy_atom_dKV = cute.make_copy_atom( cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=False, num_matrices=4), @@ -1057,59 +1055,26 @@ def epilogue_dKV( taccdKsdK = smem_thr_copy_dK.partition_D(sK) # reuse sK SMEM cute.copy(smem_copy_atom_dKV, taccdKrdK, taccdKsdK) - # SMEM -> GMEM - cdV = cute.make_identity_tensor((self.tile_n, self.tile_hdimv)) + # smem -> gmem mdV_cur = mdV[None, None, head_idx, batch_idx] - - cdK = cute.make_identity_tensor((self.tile_n, self.tile_hdim)) mdK_cur = mdK[None, None, head_idx, batch_idx] - + gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdim), (n_block, 0)) + gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0)) + store_dK, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_dK, 0, cute.make_layout(1), sK, gdK, single_stage=True + ) + store_dV, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_dV, 0, cute.make_layout(1), sV, gdV, single_stage=True + ) cute.arch.barrier( barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads ) - gmem_thr_copy_dV = gmem_tiled_copy_dV.get_slice(tidx) - gmem_thr_copy_dK = gmem_tiled_copy_dK.get_slice(tidx) - - tdVsdV = gmem_thr_copy_dV.partition_S(sV) - tdVrdV = cute.make_fragment_like(tdVsdV, self.dtype) - cute.autovec_copy(tdVsdV, tdVrdV) - - tdKsdK = gmem_thr_copy_dK.partition_S(sK) - tdKrdK = cute.make_fragment_like(tdKsdK, self.dtype) - cute.autovec_copy(tdKsdK, tdKrdK) - - gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdim), (n_block, 0)) - tdVgdV = gmem_thr_copy_dV.partition_D(gdV) - - gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdim), (n_block, 0)) - tdKgdK = gmem_thr_copy_dK.partition_D(gdK) - - tdVcdV = gmem_thr_copy_dV.partition_S(cdV) - t0dVcdV = gmem_tiled_copy_dV.get_slice(0).partition_S(cdV) - tdVpdV = utils.predicate_k(tdVcdV, limit=mdV.shape[1]) - - tdKcdK = gmem_thr_copy_dK.partition_S(cdK) - tdKpdK = utils.predicate_k(tdKcdK, limit=mdK.shape[1]) - - for rest_m in cutlass.range_constexpr(cute.size(tdVrdV.shape[1])): - row_idx = n_block * self.tile_n + t0dVcdV[0, rest_m, 0][0] - if row_idx < seqlen.seqlen_k: - cute.copy( - gmem_tiled_copy_dV, - tdVrdV[None, rest_m, None], - tdVgdV[None, rest_m, None], - pred=tdVpdV[None, rest_m, None] - if cutlass.const_expr(self.check_hdim_v_oob) - else None, - ) - cute.copy( - gmem_tiled_copy_dK, - tdKrdK[None, rest_m, None], - tdKgdK[None, rest_m, None], - pred=tdKpdK[None, rest_m, None] - if cutlass.const_expr(self.check_hdim_oob) - else None, - ) + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + if warp_idx == 4: + store_dV() + store_dK() + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) @cute.jit def dQaccum_store( From 66fd2a4c10d30a060b2e0e44a817cb32dbe8d23d Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 13 Oct 2025 08:56:05 -0400 Subject: [PATCH 125/258] [Cute,Bwd,Sm90] Load K together w Q & LSE in the first iteration --- flash_attn/cute/copy_utils.py | 21 +++++++ flash_attn/cute/flash_bwd.py | 4 +- flash_attn/cute/flash_bwd_sm90.py | 101 ++++++++++++++---------------- flash_attn/cute/hopper_helpers.py | 19 ------ flash_attn/cute/pipeline.py | 18 ++++-- 5 files changed, 83 insertions(+), 80 deletions(-) diff --git a/flash_attn/cute/copy_utils.py b/flash_attn/cute/copy_utils.py index d69b3e7e0a4..5e4644cccfa 100644 --- a/flash_attn/cute/copy_utils.py +++ b/flash_attn/cute/copy_utils.py @@ -109,6 +109,27 @@ def cpasync_bulk_g2s( ) +@dsl_user_op +def cpasync_reduce_bulk_add_f32( + smem_ptr: cute.Pointer, + gmem_ptr: cute.Pointer, + store_bytes: int | Int32, + *, + loc=None, + ip=None, +): + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value()], + "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;", + "l,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + def cpasync_bulk_get_copy_fn( src_tensor: cute.Tensor, dst_tensor: cute.Tensor, diff --git a/flash_attn/cute/flash_bwd.py b/flash_attn/cute/flash_bwd.py index de2d4e74ea7..404fc4cba38 100644 --- a/flash_attn/cute/flash_bwd.py +++ b/flash_attn/cute/flash_bwd.py @@ -405,7 +405,7 @@ def kernel( mdO: cute.Tensor, mLSE: cute.Tensor, mdPsum: cute.Tensor, - mdQaccu: cute.Tensor, + mdQaccum: cute.Tensor, mdK: cute.Tensor, mdV: cute.Tensor, softmax_scale: cutlass.Float32, @@ -459,7 +459,7 @@ def kernel( gdO = cute.local_tile(mdO[batch_idx, None, head_idx, None], blkdO_shape, (None, 0)) gLSE = cute.local_tile(mLSE[batch_idx, head_idx, None], (self.m_block_size,), (None,)) gdPsum = cute.local_tile(mdPsum[batch_idx, head_idx, None], (self.m_block_size,), (None,)) - gdQaccum = cute.local_tile(mdQaccu[batch_idx, head_idx, None], (self.m_block_size * self.head_dim_padded,), (None,)) + gdQaccum = cute.local_tile(mdQaccum[batch_idx, head_idx, None], (self.m_block_size * self.head_dim_padded,), (None,)) # /////////////////////////////////////////////////////////////////////////////// # Get shared memory buffer diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index 5223cedd032..b910e862248 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -221,7 +221,6 @@ def _get_shared_storage_cls(self): @cute.struct class SharedStorageQKV: - mbar_ptr_KV: cute.struct.MemRange[cutlass.Int64, 2] mbar_ptr_Q: cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] mbar_ptr_dO: cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] sLSE: sLSE_struct @@ -462,12 +461,6 @@ def kernel( smem = cutlass.utils.SmemAllocator() storage = smem.allocate(SharedStorage) - mbar_ptr_KV = storage.mbar_ptr_KV.data_ptr() - - # mbarrier init - if warp_idx == 1: - cute.arch.mbarrier_init(mbar_ptr_KV, 1) - pipeline_producer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread) pipeline_consumer_group = cutlass.pipeline.CooperativeGroup( cutlass.pipeline.Agent.Thread, self.num_mma_threads // self.num_threads_per_warp_group @@ -553,7 +546,6 @@ def kernel( tma_atom_dO, pipeline_q, pipeline_do, - mbar_ptr_KV, block_info, SeqlenInfoCls, TileSchedulerCls, @@ -587,7 +579,6 @@ def kernel( sdQaccum, pipeline_q, pipeline_do, - mbar_ptr_KV, tidx, tma_atom_dK, tma_atom_dV, @@ -620,7 +611,6 @@ def load( tma_atom_dO: cute.CopyAtom, pipeline_q: cutlass.pipeline.PipelineAsync, pipeline_do: cutlass.pipeline.PipelineAsync, - mbar_ptr_KV: cutlass.Pointer, block_info: BlockInfo, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, @@ -671,17 +661,23 @@ def load( load_dPsum = copy_utils.cpasync_bulk_get_copy_fn(gdPsum, sdPsum) load_dPsum = copy_utils.tma_producer_copy_fn(load_dPsum, pipeline_do) - # TODO: need to wait if we do persistent kernel - with cute.arch.elect_one(): - cute.arch.mbarrier_arrive_and_expect_tx( - mbar_ptr_KV, self.tma_copy_bytes["K"] + self.tma_copy_bytes["V"] - ) - load_K(tma_bar_ptr=mbar_ptr_KV) - load_V(tma_bar_ptr=mbar_ptr_KV) - m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) - for i in cutlass.range(m_block_max - m_block_min, unroll=2): - m_block = m_block_max - i - 1 + # First iteration: load K together w Q & LSE, then V together w dO & dPsum + m_block = m_block_min + pipeline_q.producer_acquire(producer_state, extra_tx_count=self.tma_copy_bytes["K"]) + load_K(tma_bar_ptr=pipeline_q.producer_get_barrier(producer_state)) + load_Q(m_block, producer_state=producer_state) + # cp.async.bulk is using ptx, so we need to elect one thread to do it + with cute.arch.elect_one(): + load_LSE(m_block, producer_state=producer_state) + pipeline_do.producer_acquire(producer_state, extra_tx_count=self.tma_copy_bytes["V"]) + load_V(tma_bar_ptr=pipeline_do.producer_get_barrier(producer_state)) + load_dO(m_block, producer_state=producer_state) + with cute.arch.elect_one(): + load_dPsum(m_block, producer_state=producer_state) + producer_state.advance() + # Subsequent iterations: load Q & LSE, then dO & dPsum + for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1): pipeline_q.producer_acquire(producer_state) load_Q(m_block, producer_state=producer_state) # cp.async.bulk is using ptx, so we need to elect one thread to do it @@ -718,7 +714,6 @@ def mma( sdQaccum: cute.Tensor, pipeline_q: cutlass.pipeline.PipelineAsync, pipeline_do: cutlass.pipeline.PipelineAsync, - mbar_ptr_KV: cutlass.Pointer, tidx: Int32, tma_atom_dK: cute.CopyAtom, tma_atom_dV: cute.CopyAtom, @@ -829,7 +824,6 @@ def mma( # acc_dK=acc_dK, ) - kv_consumer_phase = Int32(0) consumer_state = pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.num_stages ) @@ -838,16 +832,10 @@ def mma( while work_tile.is_valid_tile: n_block, head_idx, batch_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) - m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("tidx = {}, m_block_min = {}, m_block_max = {}", cute.arch.thread_idx()[0], m_block_min, m_block_max) - - cute.arch.mbarrier_wait(mbar_ptr_KV, phase=kv_consumer_phase) - kv_consumer_phase ^= 1 - dKV_should_accumulate = False - for m_tile in cutlass.range(m_block_max - m_block_min, unroll=1): - m_block = m_block_max - 1 - m_tile + for m_block in cutlass.range(m_block_min, m_block_max, unroll=1): consumer_state = mma_one_m_block_all( m_block, consumer_state, dKV_should_accumulate=dKV_should_accumulate ) @@ -924,7 +912,8 @@ def mma_one_m_block( # Convert P from f32 -> f16 tdVrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) tdVrP = cute.make_fragment_like(tdVrP_acc, self.dtype) - utils.cvt_f16(tdVrP_acc, tdVrP) + # utils.cvt_f16(tdVrP_acc, tdVrP) + tdVrP.store(tdVrP_acc.load().to(self.dtype)) # S2R for dPsum tLSErdPsum = cute.make_fragment_like(tLSEsdPsum[None, 0]) cute.autovec_copy(tLSEsdPsum[None, smem_idx], tLSErdPsum) @@ -951,7 +940,8 @@ def mma_one_m_block( # Convert dS from f32 -> f16 tdKrdS_acc = cute.make_tensor(acc_dP.iterator, utils.convert_layout_acc_frgA(acc_dP.layout)) tdKrdS = cute.make_fragment_like(tdKrdS_acc, self.dtype) - utils.cvt_f16(tdKrdS_acc, tdKrdS) + # utils.cvt_f16(tdKrdS_acc, tdKrdS) + tdKrdS.store(tdKrdS_acc.load().to(self.dtype)) # If there's double buffering on dS, we don't need to sync here. # Otherwise we might have WG1 writing to dS before WG2 is done reading from it during MmadQ. @@ -1033,7 +1023,8 @@ def epilogue_dKV( rdV = cute.make_fragment_like(acc_dV, self.dtype) rdV.store(acc_dV.load().to(self.dtype)) rdK = cute.make_fragment_like(acc_dK, self.dtype) - rdK.store(acc_dK.load().to(self.dtype)) + # rdK.store(acc_dK.load().to(self.dtype)) + utils.cvt_f16(acc_dK, rdK) cute.arch.barrier( barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads @@ -1045,17 +1036,6 @@ def epilogue_dKV( ) smem_thr_copy_dK = cute.make_tiled_copy_C(smem_copy_atom_dKV, tiled_mma_dK).get_slice(tidx) smem_thr_copy_dV = cute.make_tiled_copy_C(smem_copy_atom_dKV, tiled_mma_dV).get_slice(tidx) - - # rmem -> smem - taccdVrdV = smem_thr_copy_dV.retile(rdV) - taccdVsdV = smem_thr_copy_dV.partition_D(sV) # reuse sV SMEM - cute.copy(smem_copy_atom_dKV, taccdVrdV, taccdVsdV) - - taccdKrdK = smem_thr_copy_dK.retile(rdK) - taccdKsdK = smem_thr_copy_dK.partition_D(sK) # reuse sK SMEM - cute.copy(smem_copy_atom_dKV, taccdKrdK, taccdKsdK) - - # smem -> gmem mdV_cur = mdV[None, None, head_idx, batch_idx] mdK_cur = mdK[None, None, head_idx, batch_idx] gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdim), (n_block, 0)) @@ -1066,12 +1046,29 @@ def epilogue_dKV( store_dV, _, _ = copy_utils.tma_get_copy_fn( tma_atom_dV, 0, cute.make_layout(1), sV, gdV, single_stage=True ) + + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + # rmem -> smem + taccdVrdV = smem_thr_copy_dV.retile(rdV) + taccdVsdV = smem_thr_copy_dV.partition_D(sV) # reuse sV SMEM + cute.copy(smem_copy_atom_dKV, taccdVrdV, taccdVsdV) + # ensure smem writes are visible to TMA + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) cute.arch.barrier( barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads ) - warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) if warp_idx == 4: store_dV() + taccdKrdK = smem_thr_copy_dK.retile(rdK) + taccdKsdK = smem_thr_copy_dK.partition_D(sK) # reuse sK SMEM + cute.copy(smem_copy_atom_dKV, taccdKrdK, taccdKsdK) + # ensure smem writes are visible to TMA + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads + ) + # smem -> gmem + if warp_idx == 4: store_dK() cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=True) @@ -1085,28 +1082,23 @@ def dQaccum_store( TileSchedulerCls: cutlass.Constexpr[Callable], SeqlenInfoCls: cutlass.Constexpr[Callable], ): - tile_elems = cute.cosize(sdQaccum.layout) - tile_bytes = Int32(tile_elems * 4) - + cpasync_bulk_bytes = self.tile_m * self.tile_hdim * Float32.width // 8 tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: n_block, head_idx, batch_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) mdQaccum_cur = mdQaccum[None, head_idx, batch_idx] - base_flat = cute.domain_offset((seqlen.offset_q * self.tile_hdim,), mdQaccum_cur) - + gdQaccum = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (None,)) m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) - for it_m in cutlass.range(m_block_max - m_block_min, unroll=1): - m_block = m_block_max - 1 - it_m + for m_block in cutlass.range(m_block_min, m_block_max, unroll=1): cute.arch.barrier( barrier_id=int(NamedBarrierBwd.dQFull), number_of_threads=self.num_mma_threads + cute.arch.WARP_SIZE, ) - gdQaccum_block = cute.local_tile(base_flat, (tile_elems,), (m_block,)) with cute.arch.elect_one(): - sm90_utils.tma_reduce_add_bulk_f32( - sdQaccum.iterator, gdQaccum_block.iterator, tile_bytes + copy_utils.cpasync_reduce_bulk_add_f32( + sdQaccum.iterator, gdQaccum[None, m_block].iterator, cpasync_bulk_bytes ) cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=True) @@ -1114,6 +1106,5 @@ def dQaccum_store( barrier_id=int(NamedBarrierBwd.dQEmpty), number_of_threads=self.num_mma_threads + cute.arch.WARP_SIZE, ) - tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() diff --git a/flash_attn/cute/hopper_helpers.py b/flash_attn/cute/hopper_helpers.py index 56d6a1651e1..bab56fe8d1e 100644 --- a/flash_attn/cute/hopper_helpers.py +++ b/flash_attn/cute/hopper_helpers.py @@ -61,22 +61,3 @@ def make_smem_layout( return smem_layout_staged -@dsl_user_op -def tma_reduce_add_bulk_f32( - smem_ptr: cute.Pointer, - gmem_ptr: cute.Pointer, - store_bytes: Int32, - *, - loc=None, - ip=None, -): - smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() - llvm.inline_asm( - None, - [gmem_ptr.llvm_ptr, smem_ptr_i32, store_bytes.ir_value()], - "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;", - "l,r,r", - has_side_effects=True, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) diff --git a/flash_attn/cute/pipeline.py b/flash_attn/cute/pipeline.py index 7ea4743c2ed..b1f422068c4 100644 --- a/flash_attn/cute/pipeline.py +++ b/flash_attn/cute/pipeline.py @@ -6,7 +6,8 @@ import cutlass import cutlass.cute as cute -from cutlass.cutlass_dsl import Boolean, Int32, if_generate +from cutlass import Boolean, Int32, const_expr +from cutlass.cutlass_dsl import if_generate from cutlass.pipeline import PipelineAsync, PipelineState, CooperativeGroup, pipeline_init_wait from cutlass.pipeline import PipelineUserType, PipelineOp @@ -134,7 +135,7 @@ def create( ) dst_rank = None producer_mask = None - if cutlass.const_expr(init_wait): + if const_expr(init_wait): pipeline_init_wait() return PipelineTmaAsyncNoCluster( sync_object_full, @@ -144,7 +145,12 @@ def create( dst_rank, ) - def producer_acquire(self, state: PipelineState, try_acquire_token: Optional[Boolean] = None): + def producer_acquire( + self, + state: PipelineState, + try_acquire_token: Optional[Boolean] = None, + extra_tx_count: int = 0, + ): """ TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. """ @@ -152,7 +158,11 @@ def producer_acquire(self, state: PipelineState, try_acquire_token: Optional[Boo try_acquire_token is None or try_acquire_token == 0, lambda: self.sync_object_empty.wait(state.index, state.phase), ) - self.sync_object_full.arrive(state.index, self.producer_mask) + if const_expr(extra_tx_count == 0): + self.sync_object_full.arrive(state.index, self.producer_mask) + else: + tx_count = self.sync_object_full.tx_count + extra_tx_count + self.sync_object_full.arrive_and_expect_tx(state.index, tx_count) def producer_commit(self, state: PipelineState): """ From 35384ecdf5461a79cf39d5c547185a4c89b91b5d Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 13 Oct 2025 14:45:39 -0400 Subject: [PATCH 126/258] [Cute,Sm90] Move gemm helper functions to hopper_helpers.py --- flash_attn/cute/copy_utils.py | 4 +++ flash_attn/cute/flash_bwd_sm90.py | 44 +++++-------------------------- flash_attn/cute/flash_fwd.py | 39 ++++++++++----------------- flash_attn/cute/hopper_helpers.py | 33 ++++++++++++++++++++++- 4 files changed, 57 insertions(+), 63 deletions(-) diff --git a/flash_attn/cute/copy_utils.py b/flash_attn/cute/copy_utils.py index 5e4644cccfa..84b3f4e2956 100644 --- a/flash_attn/cute/copy_utils.py +++ b/flash_attn/cute/copy_utils.py @@ -119,11 +119,15 @@ def cpasync_reduce_bulk_add_f32( ip=None, ): smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + # cache_hint = cutlass.Int64(0x14F0000000000000) # EVICT_LAST llvm.inline_asm( None, [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value()], "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;", "l,r,r", + # [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value(), cache_hint.ir_value()], + # "cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [$0], [$1], $2, $3;", + # "l,r,r,l", has_side_effects=True, is_align_stack=False, asm_dialect=llvm.AsmDialect.AD_ATT, diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index b910e862248..13ccef13962 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -21,37 +21,6 @@ from flash_attn.cute.named_barrier import NamedBarrierFwd, NamedBarrierBwd -def mma_zero_init( - tiled_mma: cute.TiledMma, - shape: cute.Shape, - tCrA: cute.Tensor, - tCrB: cute.Tensor, - A_idx: Optional[Int32] = None, - B_idx: Optional[Int32] = None, - wg_wait: int = -1, -) -> cute.Tensor: - acc = cute.make_fragment(tiled_mma.partition_shape_C(shape), Float32) - rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] - rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] - sm90_utils.gemm(tiled_mma, acc, rA, rB, zero_init=True, wg_wait=wg_wait) - return acc - - -def mma_sm90( - tiled_mma: cute.TiledMma, - acc: cute.Tensor, - tCrA: cute.Tensor, - tCrB: cute.Tensor, - zero_init: Boolean, - A_idx: Optional[Int32] = None, - B_idx: Optional[Int32] = None, - wg_wait: int = -1, -) -> None: - rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] - rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] - sm90_utils.gemm(tiled_mma, acc, rA, rB, zero_init=zero_init, wg_wait=wg_wait) - - class FlashAttentionBackwardSm90: arch = 90 @@ -153,7 +122,6 @@ def _setup_attributes(self): ((self.tile_m, self.tile_n), self.dS_stage), ] ] - self.sdQaccum_layout = cute.make_layout(self.tile_m * self.tile_hdim) # dQaccum R->S self.r2s_tiled_copy_dQaccum = copy_utils.tiled_copy_1d( @@ -792,14 +760,16 @@ def mma( Float32, ) - mma_qk_fn = partial(mma_zero_init, tiled_mma_SdP, (self.tile_m, self.tile_n), tSrQ, tSrK) + mma_qk_fn = partial( + sm90_utils.gemm_zero_init, tiled_mma_SdP, (self.tile_m, self.tile_n), tSrQ, tSrK + ) mma_dov_fn = partial( - mma_zero_init, tiled_mma_SdP, (self.tile_m, self.tile_n), tdPrdO, tdPrV + sm90_utils.gemm_zero_init, tiled_mma_SdP, (self.tile_m, self.tile_n), tdPrdO, tdPrV ) - mma_pdo_fn = partial(mma_sm90, tiled_mma_dV, acc_dV, tdVrPt, tdVrdOt) - mma_dsq_fn = partial(mma_sm90, tiled_mma_dK, acc_dK, tdKrdSt, tdKrQt) + mma_pdo_fn = partial(sm90_utils.gemm_w_idx, tiled_mma_dV, acc_dV, tdVrPt, tdVrdOt) + mma_dsq_fn = partial(sm90_utils.gemm_w_idx, tiled_mma_dK, acc_dK, tdKrdSt, tdKrQt) mma_dsk_fn = partial( - mma_zero_init, tiled_mma_dQ, (self.tile_m, self.tile_hdim), tdQrdS, tdQrKt + sm90_utils.gemm_zero_init, tiled_mma_dQ, (self.tile_m, self.tile_hdim), tdQrdS, tdQrKt ) mma_one_m_block_all = partial( diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 00721f07362..222d0790967 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -35,22 +35,6 @@ from flash_attn.cute.fast_math import FastDivmod -def mma_qk(tiled_mma_qk: cute.TiledMma, shape: cute.Shape, tSrQ: cute.Tensor, tSrK: cute.Tensor, smem_idx: Int32, wg_wait: int = -1) -> cute.Tensor: - acc_S = cute.make_fragment(tiled_mma_qk.partition_shape_C(shape), Float32) - sm90_utils.gemm( - tiled_mma_qk, acc_S, tSrQ, tSrK[None, None, None, smem_idx], zero_init=True, wg_wait=wg_wait - ) - return acc_S - - -def mma_pv(tiled_mma_pv: cute.TiledMma, acc_O: cute.Tensor, tOrP: cute.Tensor, tOrVt: cute.Tensor, smem_idx: Int32, zero_init: Boolean, wg_wait: int = -1) -> None: - sm90_utils.gemm( - tiled_mma_pv, acc_O, tOrP, - tOrVt[None, None, None, smem_idx], - zero_init=zero_init, wg_wait=wg_wait - ) - - class FlashAttentionForwardBase: arch: int = 80 @@ -1557,7 +1541,6 @@ def load( work_tile = tile_scheduler.get_current_work() # End of persistent scheduler loop - @cute.jit def mma( self, @@ -1627,8 +1610,10 @@ def mma( acc_O = cute.make_fragment(acc_shape_O, Float32) smem_copy_params = SimpleNamespace(smem_thr_copy_P=smem_thr_copy_P, tPsP=tPsP) - mma_qk_fn = partial(mma_qk, tiled_mma_qk, (self.tile_m, self.tile_n), tSrQ, tSrK) - mma_pv_fn = partial(mma_pv, tiled_mma_pv, acc_O, tOrP, tOrVt) + mma_qk_fn = partial( + sm90_utils.gemm_zero_init, tiled_mma_qk, (self.tile_m, self.tile_n), tSrQ, tSrK + ) + mma_pv_fn = partial(sm90_utils.gemm_w_idx, tiled_mma_pv, acc_O, tOrP, tOrVt) mma_one_n_block_all = partial( self.mma_one_n_block_intrawg_overlap if const_expr(self.intra_wg_overlap) else self.mma_one_n_block, @@ -1692,7 +1677,7 @@ def mma( # First iteration with seqlen masking if const_expr(self.intra_wg_overlap): pipeline_k.consumer_wait(kv_consumer_state, pipeline_k.consumer_try_wait(kv_consumer_state)) - acc_S = mma_qk_fn(kv_consumer_state.index, wg_wait=0) + acc_S = mma_qk_fn(B_idx=kv_consumer_state.index, wg_wait=0) pipeline_k.consumer_release(kv_consumer_state) # Use vectorized score modification if cutlass.const_expr(score_mod_fn is not None): @@ -1767,7 +1752,7 @@ def mma( # Last "half" iteration if const_expr(self.intra_wg_overlap): pipeline_v.consumer_wait(kv_consumer_state, pipeline_v.consumer_try_wait(kv_consumer_state)) - mma_pv_fn(kv_consumer_state.index, zero_init=not O_should_accumulate, wg_wait=0) + mma_pv_fn(B_idx=kv_consumer_state.index, zero_init=not O_should_accumulate, wg_wait=0) pipeline_v.consumer_release(kv_consumer_state) kv_consumer_state.advance() else: @@ -1821,7 +1806,8 @@ def mma_one_n_block( check_inf: cutlass.Constexpr = True, ): pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read)) - acc_S = mma_qk_fn(smem_pipe_read.index, wg_wait=-1) + # S = Q @ K.T + acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1) self.warp_scheduler_barrier_arrive() warpgroup.wait_group(0) pipeline_k.consumer_release(smem_pipe_read) @@ -1850,7 +1836,8 @@ def mma_one_n_block( cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read)) self.warp_scheduler_barrier_sync() - mma_pv_fn(smem_pipe_read.index, wg_wait=0) + # O += P @ V + mma_pv_fn(B_idx=smem_pipe_read.index, wg_wait=0) pipeline_v.consumer_release(smem_pipe_read) smem_pipe_read.advance() return smem_pipe_read @@ -1877,9 +1864,11 @@ def mma_one_n_block_intrawg_overlap( smem_pipe_read.advance() pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read)) self.warp_scheduler_barrier_sync() - acc_S = mma_qk_fn(smem_pipe_read.index, wg_wait=-1) + # S = Q @ K.T + acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1) pipeline_v.consumer_wait(smem_pipe_read_v, pipeline_v.consumer_try_wait(smem_pipe_read_v)) - mma_pv_fn(smem_pipe_read_v.index, wg_wait=-1) + # O += P @ V + mma_pv_fn(B_idx=smem_pipe_read_v.index, wg_wait=-1) self.warp_scheduler_barrier_arrive() warpgroup.wait_group(1) pipeline_k.consumer_release(smem_pipe_read) diff --git a/flash_attn/cute/hopper_helpers.py b/flash_attn/cute/hopper_helpers.py index bab56fe8d1e..2597cd4a566 100644 --- a/flash_attn/cute/hopper_helpers.py +++ b/flash_attn/cute/hopper_helpers.py @@ -2,7 +2,7 @@ from typing import Type, Union, Optional import cutlass import cutlass.cute as cute -from cutlass import Int32, const_expr +from cutlass import Int32, Float32, Boolean, const_expr from cutlass.cute.nvgpu import warpgroup from cutlass._mlir.dialects import llvm from cutlass.cutlass_dsl import Numeric, dsl_user_op @@ -37,6 +37,37 @@ def gemm( warpgroup.wait_group(wg_wait) +def gemm_zero_init( + tiled_mma: cute.TiledMma, + shape: cute.Shape, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + A_idx: Optional[Int32] = None, + B_idx: Optional[Int32] = None, + wg_wait: int = -1, +) -> cute.Tensor: + acc = cute.make_fragment(tiled_mma.partition_shape_C(shape), Float32) + rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] + rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] + gemm(tiled_mma, acc, rA, rB, zero_init=True, wg_wait=wg_wait) + return acc + + +def gemm_w_idx( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + zero_init: Boolean, + A_idx: Optional[Int32] = None, + B_idx: Optional[Int32] = None, + wg_wait: int = -1, +) -> None: + rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] + rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] + gemm(tiled_mma, acc, rA, rB, zero_init=zero_init, wg_wait=wg_wait) + + @dsl_user_op def make_smem_layout( dtype: Type[Numeric], From 7c0e373ada572362b94bd5eb722f161128f462c9 Mon Sep 17 00:00:00 2001 From: imbr92 <40306754+imbr92@users.noreply.github.com> Date: Mon, 13 Oct 2025 09:43:28 -0400 Subject: [PATCH 127/258] Swap masking to not use R2P --- flash_attn/cute/mask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index bacb69e9f00..9b20323aebe 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -94,7 +94,7 @@ def apply_mask( col_limit_right = row_idx + causal_row_offset if cutlass.const_expr(mask_seqlen): col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) - if cutlass.const_expr(False): + if cutlass.const_expr(True): # traverse column index. for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): acc_S_mn[r, c] = -cutlass.Float32.inf if t0ScS_mn[0, c][1] >= col_limit_right else acc_S_mn[r, c] From 60eb1ea2983d2946a21ef7418222760f9498d42a Mon Sep 17 00:00:00 2001 From: imbr92 <40306754+imbr92@users.noreply.github.com> Date: Mon, 13 Oct 2025 09:46:21 -0400 Subject: [PATCH 128/258] Pre-indent to make commit diffs readable --- flash_attn/cute/flash_bwd.py | 505 ++++++++++++----------- flash_attn/cute/flash_bwd_postprocess.py | 183 ++++---- flash_attn/cute/flash_bwd_preprocess.py | 41 +- 3 files changed, 366 insertions(+), 363 deletions(-) diff --git a/flash_attn/cute/flash_bwd.py b/flash_attn/cute/flash_bwd.py index 404fc4cba38..93a7ec84b12 100644 --- a/flash_attn/cute/flash_bwd.py +++ b/flash_attn/cute/flash_bwd.py @@ -432,14 +432,15 @@ def kernel( tidx, _, _ = cute.arch.thread_idx() n_block, head_idx, batch_idx = cute.arch.block_idx() - m_block_max = cute.ceil_div(mQ.shape[1], self.m_block_size) - m_block_min = 0 - if cutlass.const_expr(self.is_causal): - m_block_min = max( - (n_block * self.n_block_size + mQ.shape[1] - mK.shape[1]) // self.m_block_size, - m_block_min, - ) - # TODO: return early if m_block_max == 0 + if True: + m_block_max = cute.ceil_div(mQ.shape[1], self.m_block_size) + m_block_min = 0 + if cutlass.const_expr(self.is_causal): + m_block_min = max( + (n_block * self.n_block_size + mQ.shape[1] - mK.shape[1]) // self.m_block_size, + m_block_min, + ) + # TODO: return early if m_block_max == 0 # /////////////////////////////////////////////////////////////////////////////// # Get the appropriate tiles for this thread block. @@ -461,267 +462,267 @@ def kernel( gdPsum = cute.local_tile(mdPsum[batch_idx, head_idx, None], (self.m_block_size,), (None,)) gdQaccum = cute.local_tile(mdQaccum[batch_idx, head_idx, None], (self.m_block_size * self.head_dim_padded,), (None,)) - # /////////////////////////////////////////////////////////////////////////////// - # Get shared memory buffer - # /////////////////////////////////////////////////////////////////////////////// - smem = cutlass.utils.SmemAllocator() - storage = smem.allocate(SharedStorage) - sQ = storage.sQ.get_tensor(sQ_layout) - sK = storage.sK.get_tensor(sK_layout) - if cutlass.const_expr(not self.share_QV_smem): - sV = storage.sV.get_tensor(sV_layout) - else: - sV = cute.make_tensor(cute.recast_ptr(sQ.iterator, dtype=self.dtype), sV_layout) - sdO = storage.sdO.get_tensor(sdO_layout) - sP = storage.sP.get_tensor(sPdS_layout) - sdS = storage.sdS.get_tensor(sPdS_layout) - sLSE = storage.sLSE.get_tensor(sLSE_layout) - sdPsum = storage.sdPsum.get_tensor(sLSE_layout) - sLSEMma = storage.sLSE.get_tensor(sLSEMma_layout) - sdPsumMma = storage.sdPsum.get_tensor(sLSEMma_layout) + # /////////////////////////////////////////////////////////////////////////////// + # Get shared memory buffer + # /////////////////////////////////////////////////////////////////////////////// + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sQ = storage.sQ.get_tensor(sQ_layout) + sK = storage.sK.get_tensor(sK_layout) + if cutlass.const_expr(not self.share_QV_smem): + sV = storage.sV.get_tensor(sV_layout) + else: + sV = cute.make_tensor(cute.recast_ptr(sQ.iterator, dtype=self.dtype), sV_layout) + sdO = storage.sdO.get_tensor(sdO_layout) + sP = storage.sP.get_tensor(sPdS_layout) + sdS = storage.sdS.get_tensor(sPdS_layout) + sLSE = storage.sLSE.get_tensor(sLSE_layout) + sdPsum = storage.sdPsum.get_tensor(sLSE_layout) + sLSEMma = storage.sLSE.get_tensor(sLSEMma_layout) + sdPsumMma = storage.sdPsum.get_tensor(sLSEMma_layout) - # Transpose view of tensors for tiled mma - sQt, sdOt, sKt, sPt, sdSt = [utils.transpose_view(t) for t in (sQ, sdO, sK, sP, sdS)] + # Transpose view of tensors for tiled mma + sQt, sdOt, sKt, sPt, sdSt = [utils.transpose_view(t) for t in (sQ, sdO, sK, sP, sdS)] - gmem_thr_copy_QK = gmem_tiled_copy_QK.get_slice(tidx) - gmem_thr_copy_VdO = gmem_tiled_copy_VdO.get_slice(tidx) - gmem_thr_copy_lse = gmem_tiled_copy_LSE.get_slice(tidx) - gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx) - # (CPY_Atom, CPY_M, CPY_K, m_block) - tQgQ = gmem_thr_copy_QK.partition_S(gQ) - tQsQ = gmem_thr_copy_QK.partition_D(sQ) - # (CPY_Atom, CPY_N, CPY_K) - tKgK = gmem_thr_copy_QK.partition_S(gK) - tKsK = gmem_thr_copy_QK.partition_D(sK) - # (CPY_Atom, CPY_N, CPY_K) - tVgV = gmem_thr_copy_VdO.partition_S(gV) - tVsV = gmem_thr_copy_VdO.partition_D(sV) - # (CPY_Atom, CPY_M, CPY_K, m_block) - tdOgdO = gmem_thr_copy_VdO.partition_S(gdO) - tdOsdO = gmem_thr_copy_VdO.partition_D(sdO) - tLSEgLSE = gmem_thr_copy_lse.partition_S(gLSE) - tLSEsLSE = gmem_thr_copy_lse.partition_D(sLSE) - tLSEgdPsum = gmem_thr_copy_lse.partition_S(gdPsum) - tLSEsdPsum = gmem_thr_copy_lse.partition_D(sdPsum) - tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum) + gmem_thr_copy_QK = gmem_tiled_copy_QK.get_slice(tidx) + gmem_thr_copy_VdO = gmem_tiled_copy_VdO.get_slice(tidx) + gmem_thr_copy_lse = gmem_tiled_copy_LSE.get_slice(tidx) + gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx) + # (CPY_Atom, CPY_M, CPY_K, m_block) + tQgQ = gmem_thr_copy_QK.partition_S(gQ) + tQsQ = gmem_thr_copy_QK.partition_D(sQ) + # (CPY_Atom, CPY_N, CPY_K) + tKgK = gmem_thr_copy_QK.partition_S(gK) + tKsK = gmem_thr_copy_QK.partition_D(sK) + # (CPY_Atom, CPY_N, CPY_K) + tVgV = gmem_thr_copy_VdO.partition_S(gV) + tVsV = gmem_thr_copy_VdO.partition_D(sV) + # (CPY_Atom, CPY_M, CPY_K, m_block) + tdOgdO = gmem_thr_copy_VdO.partition_S(gdO) + tdOsdO = gmem_thr_copy_VdO.partition_D(sdO) + tLSEgLSE = gmem_thr_copy_lse.partition_S(gLSE) + tLSEsLSE = gmem_thr_copy_lse.partition_D(sLSE) + tLSEgdPsum = gmem_thr_copy_lse.partition_S(gdPsum) + tLSEsdPsum = gmem_thr_copy_lse.partition_D(sdPsum) + tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum) - # /////////////////////////////////////////////////////////////////////////////// - # Tile MMA compute thread partitions and allocate accumulators - # /////////////////////////////////////////////////////////////////////////////// - thr_mma_sdp = tiled_mma_sdp.get_slice(tidx) - thr_mma_dkv = tiled_mma_dkv.get_slice(tidx) - thr_mma_dq = tiled_mma_dq.get_slice(tidx) - acc_shape_dK = thr_mma_dkv.partition_shape_C((self.n_block_size, self.head_dim_padded)) - acc_shape_dV = thr_mma_dkv.partition_shape_C((self.n_block_size, self.head_dim_v_padded)) - acc_dK = cute.make_fragment(acc_shape_dK, cutlass.Float32) - acc_dV = cute.make_fragment(acc_shape_dV, cutlass.Float32) - acc_dK.fill(0.0) - acc_dV.fill(0.0) + # /////////////////////////////////////////////////////////////////////////////// + # Tile MMA compute thread partitions and allocate accumulators + # /////////////////////////////////////////////////////////////////////////////// + thr_mma_sdp = tiled_mma_sdp.get_slice(tidx) + thr_mma_dkv = tiled_mma_dkv.get_slice(tidx) + thr_mma_dq = tiled_mma_dq.get_slice(tidx) + acc_shape_dK = thr_mma_dkv.partition_shape_C((self.n_block_size, self.head_dim_padded)) + acc_shape_dV = thr_mma_dkv.partition_shape_C((self.n_block_size, self.head_dim_v_padded)) + acc_dK = cute.make_fragment(acc_shape_dK, cutlass.Float32) + acc_dV = cute.make_fragment(acc_shape_dV, cutlass.Float32) + acc_dK.fill(0.0) + acc_dV.fill(0.0) - tSrQ = utils.mma_make_fragment_A(sQ[None, None, 0], thr_mma_sdp, swapAB=self.SdP_swapAB) - tSrK = utils.mma_make_fragment_B(sK, thr_mma_sdp, swapAB=self.SdP_swapAB) - tdPrdO = utils.mma_make_fragment_A(sdO[None, None, 0], thr_mma_sdp, swapAB=self.SdP_swapAB) - tdPrV = utils.mma_make_fragment_B(sV, thr_mma_sdp, swapAB=self.SdP_swapAB) - tdVrP = utils.mma_make_fragment_A(sPt, thr_mma_dkv, swapAB=self.dKV_swapAB) - tdVrdO = utils.mma_make_fragment_B(sdOt[None, None, 0], thr_mma_dkv, swapAB=self.dKV_swapAB) - tdKrdS = utils.mma_make_fragment_A(sdSt, thr_mma_dkv, swapAB=self.dKV_swapAB) - tdKrQ = utils.mma_make_fragment_B(sQt[None, None, 0], thr_mma_dkv, swapAB=self.dKV_swapAB) - tdQrdS = utils.mma_make_fragment_A(sdS, thr_mma_dq, swapAB=self.dQ_swapAB) - tdQrK = utils.mma_make_fragment_B(sKt, thr_mma_dq, swapAB=self.dQ_swapAB) + tSrQ = utils.mma_make_fragment_A(sQ[None, None, 0], thr_mma_sdp, swapAB=self.SdP_swapAB) + tSrK = utils.mma_make_fragment_B(sK, thr_mma_sdp, swapAB=self.SdP_swapAB) + tdPrdO = utils.mma_make_fragment_A(sdO[None, None, 0], thr_mma_sdp, swapAB=self.SdP_swapAB) + tdPrV = utils.mma_make_fragment_B(sV, thr_mma_sdp, swapAB=self.SdP_swapAB) + tdVrP = utils.mma_make_fragment_A(sPt, thr_mma_dkv, swapAB=self.dKV_swapAB) + tdVrdO = utils.mma_make_fragment_B(sdOt[None, None, 0], thr_mma_dkv, swapAB=self.dKV_swapAB) + tdKrdS = utils.mma_make_fragment_A(sdSt, thr_mma_dkv, swapAB=self.dKV_swapAB) + tdKrQ = utils.mma_make_fragment_B(sQt[None, None, 0], thr_mma_dkv, swapAB=self.dKV_swapAB) + tdQrdS = utils.mma_make_fragment_A(sdS, thr_mma_dq, swapAB=self.dQ_swapAB) + tdQrK = utils.mma_make_fragment_B(sKt, thr_mma_dq, swapAB=self.dQ_swapAB) - LSEslice = (None, 0, None) if cutlass.const_expr(not self.SdP_swapAB) else (0, None, None) - tSsLSEMma = utils.make_acc_tensor_mn_view(thr_mma_sdp.partition_C(sLSEMma))[LSEslice] - tSsdPsumMma = utils.make_acc_tensor_mn_view(thr_mma_sdp.partition_C(sdPsumMma))[LSEslice] + LSEslice = (None, 0, None) if cutlass.const_expr(not self.SdP_swapAB) else (0, None, None) + tSsLSEMma = utils.make_acc_tensor_mn_view(thr_mma_sdp.partition_C(sLSEMma))[LSEslice] + tSsdPsumMma = utils.make_acc_tensor_mn_view(thr_mma_sdp.partition_C(sdPsumMma))[LSEslice] - # /////////////////////////////////////////////////////////////////////////////// - # Smem copy atom tiling - # /////////////////////////////////////////////////////////////////////////////// - smem_copy_atom = cute.make_copy_atom( - warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), self.dtype, - ) - smem_copy_atom_transposed = cute.make_copy_atom( - warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4), self.dtype, - ) - smem_thr_copy_QdO = utils.make_tiled_copy_A( - smem_copy_atom, tiled_mma_sdp, swapAB=self.SdP_swapAB - ).get_slice(tidx) - smem_thr_copy_KV = utils.make_tiled_copy_B( - smem_copy_atom, tiled_mma_sdp, swapAB=self.SdP_swapAB - ).get_slice(tidx) - # TODO: should this be smem_copy_atom_transposed? - smem_thr_copy_PdSt = utils.make_tiled_copy_A( - smem_copy_atom_transposed, tiled_mma_dkv, swapAB=self.dKV_swapAB - ).get_slice(tidx) - smem_thr_copy_QdOt = utils.make_tiled_copy_B( - smem_copy_atom_transposed, tiled_mma_dkv, swapAB=self.dKV_swapAB - ).get_slice(tidx) - smem_thr_copy_dS = utils.make_tiled_copy_A( - smem_copy_atom, tiled_mma_dq, swapAB=self.dQ_swapAB - ).get_slice(tidx) - smem_thr_copy_Kt = utils.make_tiled_copy_B( - smem_copy_atom_transposed, tiled_mma_dq, swapAB=self.dQ_swapAB - ).get_slice(tidx) - # TODO: what's the number of bits? What if SdP_swapAB - r2s_thr_copy_PdS = cute.make_tiled_copy_C( - cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=2 * self.dtype.width - ), - tiled_mma_sdp, - ).get_slice(tidx) + # /////////////////////////////////////////////////////////////////////////////// + # Smem copy atom tiling + # /////////////////////////////////////////////////////////////////////////////// + smem_copy_atom = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), self.dtype, + ) + smem_copy_atom_transposed = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4), self.dtype, + ) + smem_thr_copy_QdO = utils.make_tiled_copy_A( + smem_copy_atom, tiled_mma_sdp, swapAB=self.SdP_swapAB + ).get_slice(tidx) + smem_thr_copy_KV = utils.make_tiled_copy_B( + smem_copy_atom, tiled_mma_sdp, swapAB=self.SdP_swapAB + ).get_slice(tidx) + # TODO: should this be smem_copy_atom_transposed? + smem_thr_copy_PdSt = utils.make_tiled_copy_A( + smem_copy_atom_transposed, tiled_mma_dkv, swapAB=self.dKV_swapAB + ).get_slice(tidx) + smem_thr_copy_QdOt = utils.make_tiled_copy_B( + smem_copy_atom_transposed, tiled_mma_dkv, swapAB=self.dKV_swapAB + ).get_slice(tidx) + smem_thr_copy_dS = utils.make_tiled_copy_A( + smem_copy_atom, tiled_mma_dq, swapAB=self.dQ_swapAB + ).get_slice(tidx) + smem_thr_copy_Kt = utils.make_tiled_copy_B( + smem_copy_atom_transposed, tiled_mma_dq, swapAB=self.dQ_swapAB + ).get_slice(tidx) + # TODO: what's the number of bits? What if SdP_swapAB + r2s_thr_copy_PdS = cute.make_tiled_copy_C( + cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=2 * self.dtype.width + ), + tiled_mma_sdp, + ).get_slice(tidx) - tSsQ = smem_thr_copy_QdO.partition_S(sQ) - tdPsdO = smem_thr_copy_QdO.partition_S(sdO) - tSsK = smem_thr_copy_KV.partition_S(sK) - tdPsV = smem_thr_copy_KV.partition_S(sV) - tdVsPt = smem_thr_copy_PdSt.partition_S(sPt) - tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt) - tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt) - tdKsQt = smem_thr_copy_QdOt.partition_S(sQt) - tdQsdS = smem_thr_copy_dS.partition_S(sdS) - tdQsKt = smem_thr_copy_Kt.partition_S(sKt) - tPsP = r2s_thr_copy_PdS.partition_D(sP) - tdSsdS = r2s_thr_copy_PdS.partition_D(sdS) + tSsQ = smem_thr_copy_QdO.partition_S(sQ) + tdPsdO = smem_thr_copy_QdO.partition_S(sdO) + tSsK = smem_thr_copy_KV.partition_S(sK) + tdPsV = smem_thr_copy_KV.partition_S(sV) + tdVsPt = smem_thr_copy_PdSt.partition_S(sPt) + tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt) + tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt) + tdKsQt = smem_thr_copy_QdOt.partition_S(sQt) + tdQsdS = smem_thr_copy_dS.partition_S(sdS) + tdQsKt = smem_thr_copy_Kt.partition_S(sKt) + tPsP = r2s_thr_copy_PdS.partition_D(sP) + tdSsdS = r2s_thr_copy_PdS.partition_D(sdS) - # /////////////////////////////////////////////////////////////////////////////// - # Predicate: Mark indices that need to copy when problem_shape isn't a multiple - # of tile_shape - # /////////////////////////////////////////////////////////////////////////////// - # Construct identity layout for KV - cQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) - tQcQ = gmem_thr_copy_QK.partition_S(cQ) - t0QcQ = gmem_thr_copy_QK.get_slice(0).partition_S(cQ) - if cutlass.const_expr(self.head_dim_padded == self.head_dim_v_padded): - tdOcdO = tQcQ - t0dOcdO = t0QcQ - else: - cdO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded)) - tdOcdO = gmem_thr_copy_VdO.partition_S(cdO) - t0dOcdO = gmem_thr_copy_VdO.get_slice(0).partition_S(cdO) - cLSE = cute.make_identity_tensor((self.m_block_size,)) - tLSEcLSE = gmem_thr_copy_lse.partition_S(cLSE) + # /////////////////////////////////////////////////////////////////////////////// + # Predicate: Mark indices that need to copy when problem_shape isn't a multiple + # of tile_shape + # /////////////////////////////////////////////////////////////////////////////// + # Construct identity layout for KV + cQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + tQcQ = gmem_thr_copy_QK.partition_S(cQ) + t0QcQ = gmem_thr_copy_QK.get_slice(0).partition_S(cQ) + if cutlass.const_expr(self.head_dim_padded == self.head_dim_v_padded): + tdOcdO = tQcQ + t0dOcdO = t0QcQ + else: + cdO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded)) + tdOcdO = gmem_thr_copy_VdO.partition_S(cdO) + t0dOcdO = gmem_thr_copy_VdO.get_slice(0).partition_S(cdO) + cLSE = cute.make_identity_tensor((self.m_block_size,)) + tLSEcLSE = gmem_thr_copy_lse.partition_S(cLSE) - # Allocate predicate tensors for m and n, here we only allocate the tile of k, and - # use "if" on the mn dimension. - # This is to reduce register pressure and gets 2-3% performance gain. - tQpQ = utils.predicate_k(tQcQ, limit=mQ.shape[3]) - if cutlass.const_expr(self.same_hdim_kv): - tdOpdO = tQpQ - else: - tdOpdO = utils.predicate_k(tdOcdO, limit=mdO.shape[3]) + # Allocate predicate tensors for m and n, here we only allocate the tile of k, and + # use "if" on the mn dimension. + # This is to reduce register pressure and gets 2-3% performance gain. + tQpQ = utils.predicate_k(tQcQ, limit=mQ.shape[3]) + if cutlass.const_expr(self.same_hdim_kv): + tdOpdO = tQpQ + else: + tdOpdO = utils.predicate_k(tdOcdO, limit=mdO.shape[3]) - # group parameters for compute_one_m_block - mma_params = SimpleNamespace( - thr_mma_sdp=thr_mma_sdp, thr_mma_dkv=thr_mma_dkv, thr_mma_dq=thr_mma_dq, - tSrQ=tSrQ, tSrK=tSrK, tdPrdO=tdPrdO, tdPrV=tdPrV, - tdVrP=tdVrP, tdVrdO=tdVrdO, tdKrdS=tdKrdS, tdKrQ=tdKrQ, - tdQrdS=tdQrdS, tdQrK=tdQrK, - acc_dK=acc_dK, acc_dV=acc_dV, - ) - smem_copy_params = SimpleNamespace( - smem_thr_copy_QdO=smem_thr_copy_QdO, - smem_thr_copy_KV=smem_thr_copy_KV, - smem_thr_copy_PdSt=smem_thr_copy_PdSt, - smem_thr_copy_QdOt=smem_thr_copy_QdOt, - smem_thr_copy_dS=smem_thr_copy_dS, - smem_thr_copy_Kt=smem_thr_copy_Kt, - r2s_thr_copy_PdS=r2s_thr_copy_PdS, - tSsQ=tSsQ, tSsK=tSsK, tdPsdO=tdPsdO, tdPsV=tdPsV, - tSsLSEMma=tSsLSEMma, tSsdPsumMma=tSsdPsumMma, - tPsP=tPsP, tdSsdS=tdSsdS, - tdVsPt=tdVsPt, tdVsdOt=tdVsdOt, tdKsdSt=tdKsdSt, tdKsQt=tdKsQt, - tdQsdS=tdQsdS, tdQsKt=tdQsKt, - ) - gmem_copy_params = SimpleNamespace( - gmem_thr_copy_dQaccum=gmem_thr_copy_dQaccum, tdQgdQaccum=tdQgdQaccum - ) - seqlen = SeqlenInfoQK(batch_idx, mQ.shape[1], mK.shape[1]) - load_Q_LSE = partial( - self.load_Q_LSE, gmem_tiled_copy_QK, gmem_tiled_copy_LSE, - tQgQ, tQsQ, tQcQ, t0QcQ, tQpQ, - tLSEgLSE, tLSEsLSE, tLSEcLSE, seqlen=seqlen.seqlen_q - ) - load_dO_dPsum = partial( - self.load_dO_dPsum, gmem_tiled_copy_VdO, gmem_tiled_copy_LSE, - tdOgdO, tdOsdO, tdOcdO, t0dOcdO, tdOpdO, - tLSEgdPsum, tLSEsdPsum, tLSEcLSE, seqlen=seqlen.seqlen_q - ) - compute_one_m_block = partial( - self.compute_one_m_block, mma_params=mma_params, - smem_copy_params=smem_copy_params, gmem_copy_params=gmem_copy_params, - load_Q_LSE=load_Q_LSE, load_dO_dPsum=load_dO_dPsum, - m_block_max=m_block_max, - softmax_scale_log2=softmax_scale_log2, - ) + # group parameters for compute_one_m_block + mma_params = SimpleNamespace( + thr_mma_sdp=thr_mma_sdp, thr_mma_dkv=thr_mma_dkv, thr_mma_dq=thr_mma_dq, + tSrQ=tSrQ, tSrK=tSrK, tdPrdO=tdPrdO, tdPrV=tdPrV, + tdVrP=tdVrP, tdVrdO=tdVrdO, tdKrdS=tdKrdS, tdKrQ=tdKrQ, + tdQrdS=tdQrdS, tdQrK=tdQrK, + acc_dK=acc_dK, acc_dV=acc_dV, + ) + smem_copy_params = SimpleNamespace( + smem_thr_copy_QdO=smem_thr_copy_QdO, + smem_thr_copy_KV=smem_thr_copy_KV, + smem_thr_copy_PdSt=smem_thr_copy_PdSt, + smem_thr_copy_QdOt=smem_thr_copy_QdOt, + smem_thr_copy_dS=smem_thr_copy_dS, + smem_thr_copy_Kt=smem_thr_copy_Kt, + r2s_thr_copy_PdS=r2s_thr_copy_PdS, + tSsQ=tSsQ, tSsK=tSsK, tdPsdO=tdPsdO, tdPsV=tdPsV, + tSsLSEMma=tSsLSEMma, tSsdPsumMma=tSsdPsumMma, + tPsP=tPsP, tdSsdS=tdSsdS, + tdVsPt=tdVsPt, tdVsdOt=tdVsdOt, tdKsdSt=tdKsdSt, tdKsQt=tdKsQt, + tdQsdS=tdQsdS, tdQsKt=tdQsKt, + ) + gmem_copy_params = SimpleNamespace( + gmem_thr_copy_dQaccum=gmem_thr_copy_dQaccum, tdQgdQaccum=tdQgdQaccum + ) + seqlen = SeqlenInfoQK(batch_idx, mQ.shape[1], mK.shape[1]) + load_Q_LSE = partial( + self.load_Q_LSE, gmem_tiled_copy_QK, gmem_tiled_copy_LSE, + tQgQ, tQsQ, tQcQ, t0QcQ, tQpQ, + tLSEgLSE, tLSEsLSE, tLSEcLSE, seqlen=seqlen.seqlen_q + ) + load_dO_dPsum = partial( + self.load_dO_dPsum, gmem_tiled_copy_VdO, gmem_tiled_copy_LSE, + tdOgdO, tdOsdO, tdOcdO, t0dOcdO, tdOpdO, + tLSEgdPsum, tLSEsdPsum, tLSEcLSE, seqlen=seqlen.seqlen_q + ) + compute_one_m_block = partial( + self.compute_one_m_block, mma_params=mma_params, + smem_copy_params=smem_copy_params, gmem_copy_params=gmem_copy_params, + load_Q_LSE=load_Q_LSE, load_dO_dPsum=load_dO_dPsum, + m_block_max=m_block_max, + softmax_scale_log2=softmax_scale_log2, + ) - # /////////////////////////////////////////////////////////////////////////////// - # Prologue - # /////////////////////////////////////////////////////////////////////////////// - # Start async loads of the last mn-tile, where we take care of the mn residue - self.load_V(gmem_thr_copy_VdO, tVgV, tVsV, n_block, seqlen=seqlen.seqlen_k, - headdim=mV.shape[3]) - if cutlass.const_expr(self.V_in_regs): + # /////////////////////////////////////////////////////////////////////////////// + # Prologue + # /////////////////////////////////////////////////////////////////////////////// + # Start async loads of the last mn-tile, where we take care of the mn residue + self.load_V(gmem_thr_copy_VdO, tVgV, tVsV, n_block, seqlen=seqlen.seqlen_k, + headdim=mV.shape[3]) + if cutlass.const_expr(self.V_in_regs): + cute.arch.cp_async_commit_group() + self.load_K(gmem_thr_copy_QK, tKgK, tKsK, n_block, seqlen=seqlen.seqlen_k, + headdim=mK.shape[3]) cute.arch.cp_async_commit_group() - self.load_K(gmem_thr_copy_QK, tKgK, tKsK, n_block, seqlen=seqlen.seqlen_k, - headdim=mK.shape[3]) - cute.arch.cp_async_commit_group() - if cutlass.const_expr(self.V_in_regs): - cute.arch.cp_async_wait_group(1) - cute.arch.barrier() - tdPrV_copy_view = smem_thr_copy_KV.retile(tdPrV) - cute.copy(smem_thr_copy_KV, tdPsV, tdPrV_copy_view) - # Sync to avoid loading Q to smem_q, which overlaps with smem_v - cute.arch.barrier() + if cutlass.const_expr(self.V_in_regs): + cute.arch.cp_async_wait_group(1) + cute.arch.barrier() + tdPrV_copy_view = smem_thr_copy_KV.retile(tdPrV) + cute.copy(smem_thr_copy_KV, tdPsV, tdPrV_copy_view) + # Sync to avoid loading Q to smem_q, which overlaps with smem_v + cute.arch.barrier() - m_block = m_block_min - assert self.num_stages_Q >= self.num_stages_dO - for stage in cutlass.range_constexpr(self.num_stages_Q): - if cutlass.const_expr(self.num_stages_Q == 1 or stage < self.num_stages_Q - 1): - if stage == 0 or m_block + stage < m_block_max: - load_Q_LSE(m_block + stage, smem_pipe_write_q=stage) - cute.arch.cp_async_commit_group() - if cutlass.const_expr(stage < self.num_stages_dO): - if stage == 0 or m_block + stage < m_block_max: - load_dO_dPsum(m_block + stage, smem_pipe_write_q=stage) - cute.arch.cp_async_commit_group() + m_block = m_block_min + assert self.num_stages_Q >= self.num_stages_dO + for stage in cutlass.range_constexpr(self.num_stages_Q): + if cutlass.const_expr(self.num_stages_Q == 1 or stage < self.num_stages_Q - 1): + if stage == 0 or m_block + stage < m_block_max: + load_Q_LSE(m_block + stage, smem_pipe_write_q=stage) + cute.arch.cp_async_commit_group() + if cutlass.const_expr(stage < self.num_stages_dO): + if stage == 0 or m_block + stage < m_block_max: + load_dO_dPsum(m_block + stage, smem_pipe_write_q=stage) + cute.arch.cp_async_commit_group() - # /////////////////////////////////////////////////////////////////////////////// - # Mainloop - # /////////////////////////////////////////////////////////////////////////////// - # Start processing of the first n-block. - mask = AttentionMask(self.m_block_size, self.n_block_size, seqlen.seqlen_q, seqlen.seqlen_k) - mask_fn = partial( - mask.apply_mask, n_block=n_block, thr_mma=thr_mma_sdp, - mask_seqlen=True, mask_causal=self.is_causal - ) - smem_pipe_read_q = cutlass.Int32(0) - smem_pipe_read_do = cutlass.Int32(0) - smem_pipe_write_q = cutlass.Int32(self.num_stages_Q - 1) - smem_pipe_write_do = cutlass.Int32(0) - for m_tile in cutlass.range(m_block_min, m_block_max, unroll=1): - compute_one_m_block( - m_tile, smem_pipe_read_q, smem_pipe_read_do, smem_pipe_write_q, smem_pipe_write_do, - mask_fn=mask_fn, + # /////////////////////////////////////////////////////////////////////////////// + # Mainloop + # /////////////////////////////////////////////////////////////////////////////// + # Start processing of the first n-block. + mask = AttentionMask(self.m_block_size, self.n_block_size, seqlen.seqlen_q, seqlen.seqlen_k) + mask_fn = partial( + mask.apply_mask, n_block=n_block, thr_mma=thr_mma_sdp, + mask_seqlen=True, mask_causal=self.is_causal ) - smem_pipe_read_q = self.advance_pipeline(smem_pipe_read_q, self.num_stages_Q) - smem_pipe_read_do = self.advance_pipeline(smem_pipe_read_do, self.num_stages_dO) - smem_pipe_write_q = self.advance_pipeline(smem_pipe_write_q, self.num_stages_Q) - smem_pipe_write_do = self.advance_pipeline(smem_pipe_write_do, self.num_stages_dO) + smem_pipe_read_q = cutlass.Int32(0) + smem_pipe_read_do = cutlass.Int32(0) + smem_pipe_write_q = cutlass.Int32(self.num_stages_Q - 1) + smem_pipe_write_do = cutlass.Int32(0) + for m_tile in cutlass.range(m_block_min, m_block_max, unroll=1): + compute_one_m_block( + m_tile, smem_pipe_read_q, smem_pipe_read_do, smem_pipe_write_q, smem_pipe_write_do, + mask_fn=mask_fn, + ) + smem_pipe_read_q = self.advance_pipeline(smem_pipe_read_q, self.num_stages_Q) + smem_pipe_read_do = self.advance_pipeline(smem_pipe_read_do, self.num_stages_dO) + smem_pipe_write_q = self.advance_pipeline(smem_pipe_write_q, self.num_stages_Q) + smem_pipe_write_do = self.advance_pipeline(smem_pipe_write_do, self.num_stages_dO) - # /////////////////////////////////////////////////////////////////////////////// - # Epilogue - # /////////////////////////////////////////////////////////////////////////////// - # If GQA, we scale dK in the postprocessing kernel instead - if cutlass.const_expr(self.qhead_per_kvhead == 1): - acc_dK.store(acc_dK.load() * softmax_scale) - # reuse sK and sV data iterator - sdK = cute.make_tensor(sK.iterator, sK_layout) - sdV = cute.make_tensor(sV.iterator, sV_layout) - self.epilogue( - acc_dK, acc_dV, mdK, mdV, sdK, sdV, - gmem_tiled_copy_dK, gmem_tiled_copy_dV, tiled_mma_dkv, - tidx, n_block, head_idx, batch_idx - ) + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue + # /////////////////////////////////////////////////////////////////////////////// + # If GQA, we scale dK in the postprocessing kernel instead + if cutlass.const_expr(self.qhead_per_kvhead == 1): + acc_dK.store(acc_dK.load() * softmax_scale) + # reuse sK and sV data iterator + sdK = cute.make_tensor(sK.iterator, sK_layout) + sdV = cute.make_tensor(sV.iterator, sV_layout) + self.epilogue( + acc_dK, acc_dV, mdK, mdV, sdK, sdV, + gmem_tiled_copy_dK, gmem_tiled_copy_dV, tiled_mma_dkv, + tidx, n_block, head_idx, batch_idx + ) @cute.jit def compute_one_m_block( diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index 0abe36d39c3..4dec60a9298 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -217,97 +217,98 @@ def kernel( tidx, _, _ = cute.arch.thread_idx() m_block, num_head, batch_size = cute.arch.block_idx() - # /////////////////////////////////////////////////////////////////////////////// - # Get the appropriate tiles for this thread block. - # /////////////////////////////////////////////////////////////////////////////// - blkdQaccum_shape = (self.m_block_size * self.head_dim_padded,) - gdQaccum = cute.local_tile( - mdQaccum[batch_size, num_head, None], blkdQaccum_shape, (m_block,) - ) - blkdQ_shape = (self.m_block_size, self.head_dim_padded) - gdQ = cute.local_tile(mdQ[batch_size, None, num_head, None], blkdQ_shape, (m_block, 0)) - - # /////////////////////////////////////////////////////////////////////////////// - # Get shared memory buffer - # /////////////////////////////////////////////////////////////////////////////// - smem = cutlass.utils.SmemAllocator() - sdQaccum = smem.allocate_tensor(cutlass.Float32, sdQaccum_layout, byte_alignment=1024) - sdQ = cute.make_tensor(cute.recast_ptr(sdQaccum.iterator, dtype=self.dtype), sdQ_layout) - - seqlen_q = mdQ.shape[1] - seqlen_q_rounded = cute.round_up(seqlen_q, self.m_block_size) - - # Step 1: load dQaccum from gmem to smem - g2s_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_slice(tidx) - tdQgdQaccum = g2s_thr_copy_dQaccum.partition_S(gdQaccum) - tdQsdQaccumg2s = g2s_thr_copy_dQaccum.partition_D(sdQaccum) - # print(tdQgdQaccum) - # print(tdQsdQaccum) - cute.copy(g2s_tiled_copy_dQaccum, tdQgdQaccum, tdQsdQaccumg2s) - cute.arch.cp_async_commit_group() - cute.arch.cp_async_wait_group(0) - cute.arch.barrier() - - # Step 2: load dQ from smem to rmem - s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_slice(tidx) - tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum) - # print(s2r_tiled_copy_dQaccum) - # print(sdQaccum) - # thr_mma = tiled_mma.get_slice(tidx) - # print(tiled_mma) - acc_shape = tiled_mma.partition_shape_C( - (self.m_block_size, self.head_dim_padded) - if cutlass.const_expr(not dQ_swapAB) - else (self.head_dim_padded, self.m_block_size) - ) - acc = cute.make_fragment(acc_shape, cutlass.Float32) - assert cute.size(acc) == cute.size(tdQsdQaccum) - tdQrdQaccum = s2r_thr_copy_dQaccum.retile(acc) - # Somehow even after retiling the layouts of tdQsdQaccum and tdQrdQaccum are different. - # So we have to do a for loop to copy - # cute.copy(s2r_tiled_copy_dQaccum, tdQsdQaccum, tdQrdQaccum) - # print(acc) - # print(tdQsdQaccum) # ((1, 1), 64) - # print(tdQrdQaccum) # ((1, 4), 4, 4) - for i in cutlass.range(cute.size(tdQsdQaccum), unroll_full=True): - tdQrdQaccum[i] = tdQsdQaccum[i] - # Convert tdQrdQaccum from fp32 to fp16/bf16 - rdQ = cute.make_fragment_like(acc, self.dtype) - rdQ.store((acc.load() * scale).to(self.dtype)) - - # Step 3: Copy dQ from register to smem - cute.arch.barrier() # make sure all threads have finished loading dQaccum - smem_copy_atom_dQ = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=cutlass.Float32.width - ) - smem_thr_copy_dQ = cute.make_tiled_copy_C(smem_copy_atom_dQ, tiled_mma).get_slice(tidx) - taccdQrdQ = smem_thr_copy_dQ.retile(rdQ) - taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ) - cute.copy(smem_copy_atom_dQ, taccdQrdQ, taccdQsdQ) - # print(taccdQrdQ) - # print(taccdQsdQ) - - # Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem - gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_slice(tidx) - tdQgdQ = gmem_thr_copy_dQ.partition_S(gdQ) - tdQsdQ = gmem_thr_copy_dQ.partition_D(sdQ) - tdQrdQ = cute.make_fragment_like(tdQsdQ, self.dtype) - cute.arch.barrier() # make sure all smem stores are done - # TODO: check OOB when reading from smem if kBlockM isn't evenly tiled - cute.autovec_copy(tdQsdQ, tdQrdQ) - - # Step 5: Copy dQ from register to gmem - cdQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) - tdQcdQ = gmem_thr_copy_dQ.partition_S(cdQ) - tdQpdQ = utils.predicate_k(tdQcdQ, limit=mdQ.shape[3]) - for rest_m in cutlass.range(cute.size(tdQrdQ.shape[1]), unroll_full=True): - if tdQcdQ[0, rest_m, 0][0] < mdQ.shape[1] - m_block * self.m_block_size: - cute.copy( - gmem_tiled_copy_dQ, - tdQrdQ[None, rest_m, None], - tdQgdQ[None, rest_m, None], - pred=tdQpdQ[None, rest_m, None], - ) + if True: + # /////////////////////////////////////////////////////////////////////////////// + # Get the appropriate tiles for this thread block. + # /////////////////////////////////////////////////////////////////////////////// + blkdQaccum_shape = (self.m_block_size * self.head_dim_padded,) + gdQaccum = cute.local_tile( + mdQaccum[batch_size, num_head, None], blkdQaccum_shape, (m_block,) + ) + blkdQ_shape = (self.m_block_size, self.head_dim_padded) + gdQ = cute.local_tile(mdQ[batch_size, None, num_head, None], blkdQ_shape, (m_block, 0)) + + # /////////////////////////////////////////////////////////////////////////////// + # Get shared memory buffer + # /////////////////////////////////////////////////////////////////////////////// + smem = cutlass.utils.SmemAllocator() + sdQaccum = smem.allocate_tensor(cutlass.Float32, sdQaccum_layout, byte_alignment=1024) + sdQ = cute.make_tensor(cute.recast_ptr(sdQaccum.iterator, dtype=self.dtype), sdQ_layout) + + seqlen_q = mdQ.shape[1] + seqlen_q_rounded = cute.round_up(seqlen_q, self.m_block_size) + + # Step 1: load dQaccum from gmem to smem + g2s_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_slice(tidx) + tdQgdQaccum = g2s_thr_copy_dQaccum.partition_S(gdQaccum) + tdQsdQaccumg2s = g2s_thr_copy_dQaccum.partition_D(sdQaccum) + # print(tdQgdQaccum) + # print(tdQsdQaccum) + cute.copy(g2s_tiled_copy_dQaccum, tdQgdQaccum, tdQsdQaccumg2s) + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + cute.arch.barrier() + + # Step 2: load dQ from smem to rmem + s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_slice(tidx) + tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum) + # print(s2r_tiled_copy_dQaccum) + # print(sdQaccum) + # thr_mma = tiled_mma.get_slice(tidx) + # print(tiled_mma) + acc_shape = tiled_mma.partition_shape_C( + (self.m_block_size, self.head_dim_padded) + if cutlass.const_expr(not dQ_swapAB) + else (self.head_dim_padded, self.m_block_size) + ) + acc = cute.make_fragment(acc_shape, cutlass.Float32) + assert cute.size(acc) == cute.size(tdQsdQaccum) + tdQrdQaccum = s2r_thr_copy_dQaccum.retile(acc) + # Somehow even after retiling the layouts of tdQsdQaccum and tdQrdQaccum are different. + # So we have to do a for loop to copy + # cute.copy(s2r_tiled_copy_dQaccum, tdQsdQaccum, tdQrdQaccum) + # print(acc) + # print(tdQsdQaccum) # ((1, 1), 64) + # print(tdQrdQaccum) # ((1, 4), 4, 4) + for i in cutlass.range(cute.size(tdQsdQaccum), unroll_full=True): + tdQrdQaccum[i] = tdQsdQaccum[i] + # Convert tdQrdQaccum from fp32 to fp16/bf16 + rdQ = cute.make_fragment_like(acc, self.dtype) + rdQ.store((acc.load() * scale).to(self.dtype)) + + # Step 3: Copy dQ from register to smem + cute.arch.barrier() # make sure all threads have finished loading dQaccum + smem_copy_atom_dQ = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=cutlass.Float32.width + ) + smem_thr_copy_dQ = cute.make_tiled_copy_C(smem_copy_atom_dQ, tiled_mma).get_slice(tidx) + taccdQrdQ = smem_thr_copy_dQ.retile(rdQ) + taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ) + cute.copy(smem_copy_atom_dQ, taccdQrdQ, taccdQsdQ) + # print(taccdQrdQ) + # print(taccdQsdQ) + + # Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem + gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_slice(tidx) + tdQgdQ = gmem_thr_copy_dQ.partition_S(gdQ) + tdQsdQ = gmem_thr_copy_dQ.partition_D(sdQ) + tdQrdQ = cute.make_fragment_like(tdQsdQ, self.dtype) + cute.arch.barrier() # make sure all smem stores are done + # TODO: check OOB when reading from smem if kBlockM isn't evenly tiled + cute.autovec_copy(tdQsdQ, tdQrdQ) + + # Step 5: Copy dQ from register to gmem + cdQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + tdQcdQ = gmem_thr_copy_dQ.partition_S(cdQ) + tdQpdQ = utils.predicate_k(tdQcdQ, limit=mdQ.shape[3]) + for rest_m in cutlass.range(cute.size(tdQrdQ.shape[1]), unroll_full=True): + if tdQcdQ[0, rest_m, 0][0] < mdQ.shape[1] - m_block * self.m_block_size: + cute.copy( + gmem_tiled_copy_dQ, + tdQrdQ[None, rest_m, None], + tdQgdQ[None, rest_m, None], + pred=tdQpdQ[None, rest_m, None], + ) class FlashAttentionBackwardPostprocess_sm90(FlashAttentionBackwardPostprocess): diff --git a/flash_attn/cute/flash_bwd_preprocess.py b/flash_attn/cute/flash_bwd_preprocess.py index 13080d7c2e4..e30fc6232a9 100644 --- a/flash_attn/cute/flash_bwd_preprocess.py +++ b/flash_attn/cute/flash_bwd_preprocess.py @@ -163,13 +163,14 @@ def kernel( tidx, _, _ = cute.arch.thread_idx() m_block, num_head, batch_size = cute.arch.block_idx() - # /////////////////////////////////////////////////////////////////////////////// - # Get the appropriate tiles for this thread block. - # /////////////////////////////////////////////////////////////////////////////// - blkOdO_shape = (self.m_block_size, self.head_dim_padded) - # (m_block_size, head_dim) - gO = cute.local_tile(mO[batch_size, None, num_head, None], blkOdO_shape, (m_block, 0)) - gdO = cute.local_tile(mdO[batch_size, None, num_head, None], blkOdO_shape, (m_block, 0)) + if True: + # /////////////////////////////////////////////////////////////////////////////// + # Get the appropriate tiles for this thread block. + # /////////////////////////////////////////////////////////////////////////////// + blkOdO_shape = (self.m_block_size, self.head_dim_padded) + # (m_block_size, head_dim) + gO = cute.local_tile(mO[batch_size, None, num_head, None], blkOdO_shape, (m_block, 0)) + gdO = cute.local_tile(mdO[batch_size, None, num_head, None], blkOdO_shape, (m_block, 0)) gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) # (CPY_Atom, CPY_M, CPY_K) @@ -187,8 +188,8 @@ def kernel( tOpO = utils.predicate_k(tOcO, limit=mO.shape[3]) tOpdO = utils.predicate_k(tOcO, limit=mdO.shape[3]) - seqlen_q = mO.shape[1] - seqlen_q_rounded = cute.round_up(seqlen_q, self.m_block_size) + seqlen_q = mO.shape[1] + seqlen_q_rounded = cute.round_up(seqlen_q, self.m_block_size) if cutlass.const_expr(mLSE is not None): gLSE = cute.local_tile( @@ -239,17 +240,17 @@ def kernel( row = tOcO[0, m, 0][0] gdPsum[row] = dP_sum[m] if row < mO.shape[1] - m_block * self.m_block_size else 0.0 - # Clear dQaccum - if cutlass.const_expr(mdQaccum is not None): - blkdQaccum_shape = (self.m_block_size * self.head_dim_padded,) - gdQaccum = cute.local_tile( - mdQaccum[batch_size, num_head, None], blkdQaccum_shape, (m_block,) - ) - gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx) - tQgQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum) - zero = cute.make_fragment_like(tQgQaccum) - zero.fill(0.0) - cute.copy(gmem_tiled_copy_dQaccum, zero, tQgQaccum) + # Clear dQaccum + if cutlass.const_expr(mdQaccum is not None): + blkdQaccum_shape = (self.m_block_size * self.head_dim_padded,) + gdQaccum = cute.local_tile( + mdQaccum[batch_size, num_head, None], blkdQaccum_shape, (m_block,) + ) + gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx) + tQgQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum) + zero = cute.make_fragment_like(tQgQaccum) + zero.fill(0.0) + cute.copy(gmem_tiled_copy_dQaccum, zero, tQgQaccum) if cutlass.const_expr(mLSE is not None): gLSElog2 = cute.local_tile( From 25f5d092b21d2d6b005ccd34092479a620ae4ceb Mon Sep 17 00:00:00 2001 From: imbr92 <40306754+imbr92@users.noreply.github.com> Date: Mon, 13 Oct 2025 10:11:18 -0400 Subject: [PATCH 129/258] Adding varlen support + tests --- flash_attn/cute/flash_bwd.py | 213 ++++++++++++---- flash_attn/cute/flash_bwd_postprocess.py | 98 +++++++- flash_attn/cute/flash_bwd_preprocess.py | 244 +++++++++++++------ flash_attn/cute/interface.py | 159 +++++++++--- tests/cute/test_flash_attn_varlen.py | 298 +++++++++++++++++++++++ 5 files changed, 834 insertions(+), 178 deletions(-) create mode 100644 tests/cute/test_flash_attn_varlen.py diff --git a/flash_attn/cute/flash_bwd.py b/flash_attn/cute/flash_bwd.py index 93a7ec84b12..4d3bbe7d185 100644 --- a/flash_attn/cute/flash_bwd.py +++ b/flash_attn/cute/flash_bwd.py @@ -17,6 +17,7 @@ from flash_attn.cute import utils from flash_attn.cute.mask import AttentionMask from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.tile_scheduler import ParamsBase, SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments class FlashAttentionBackwardSm80: @@ -31,6 +32,7 @@ def __init__( num_stages_Q: int = 2, num_stages_dO: int = 2, num_threads: int = 256, + pack_gqa: bool = False, is_causal: bool = False, SdP_swapAB: bool = False, dKV_swapAB: bool = False, @@ -69,6 +71,7 @@ def __init__( self.m_block_size = m_block_size self.n_block_size = n_block_size self.num_threads = num_threads + self.pack_gqa = pack_gqa self.is_causal = is_causal self.num_stages_Q = num_stages_Q self.num_stages_dO = num_stages_dO @@ -141,6 +144,10 @@ def _check_type( mdQaccum_type: Type[cutlass.Numeric], mdK_type: Type[cutlass.Numeric], mdV_type: Type[cutlass.Numeric], + mCuSeqlensQ_type: Type[cutlass.Numeric] | None, + mCuSeqlensK_type: Type[cutlass.Numeric] | None, + mSeqUsedQ_type: Type[cutlass.Numeric] | None, + mSeqUsedK_type: Type[cutlass.Numeric] | None, ): if cutlass.const_expr(not (mQ_type == mK_type == mV_type == mdO_type)): raise TypeError("All tensors must have the same data type") @@ -158,6 +165,14 @@ def _check_type( raise TypeError("dPsum tensor must be Float32") if cutlass.const_expr(not mdQaccum_type in [cutlass.Float32]): raise TypeError("dQaccum tensor must be Float32") + if cutlass.const_expr(mCuSeqlensQ_type not in [None, cutlass.Int32]): + raise TypeError("cuSeqlensQ tensor must be Int32") + if cutlass.const_expr(mCuSeqlensK_type not in [None, cutlass.Int32]): + raise TypeError("cuSeqlensK tensor must be Int32") + if cutlass.const_expr(mSeqUsedQ_type not in [None, cutlass.Int32]): + raise TypeError("SeqUsedQ tensor must be Int32") + if cutlass.const_expr(mSeqUsedK_type not in [None, cutlass.Int32]): + raise TypeError("SeqUsedK tensor must be Int32") assert mQ_type == self.dtype def _setup_attributes(self): @@ -245,11 +260,22 @@ def _setup_attributes(self): self.gmem_tiled_copy_dK = cute.make_tiled_copy_tv(atom_universal_copy, tQK_layout, vQKVdO_layout) self.gmem_tiled_copy_dV = cute.make_tiled_copy_tv(atom_universal_copy, tVdO_layout, vQKVdO_layout) async_copy_elems_accum = universal_copy_bits // cutlass.Float32.width - atom_async_copy_accum = cute.make_copy_atom( - cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), - cutlass.Float32, - num_bits_per_copy=universal_copy_bits, - ) + + # I think we wouldn't require this with smarter padding + if cutlass.const_expr(not self.varlen_q): + async_copy_elems_accum = universal_copy_bits // cutlass.Float32.width + atom_async_copy_accum = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + cutlass.Float32, + num_bits_per_copy=universal_copy_bits, + ) + else: + async_copy_elems_accum = 1 + atom_async_copy_accum = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Float32, + num_bits_per_copy=cutlass.Float32.width, + ) self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv( atom_async_copy_accum, cute.make_layout(self.num_threads), @@ -343,22 +369,49 @@ def __call__( mdV: cute.Tensor, softmax_scale: cutlass.Float32, stream: cuda.CUstream, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, ): # Get the data type and check if it is fp16 or bf16 self._check_type(*(t.element_type if t is not None else None - for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV))) + for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK))) # Assume all strides are divisible by 128 bits except the last stride new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1]) mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) if t is not None else None for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)] + self.varlen_q = (mCuSeqlensQ is not None) self._setup_attributes() SharedStorage = self._get_shared_storage_cls() tiled_mma_sdp, tiled_mma_dkv, tiled_mma_dq = self._get_tiled_mma() - # grid_dim: (n_block, num_head, batch_size) - grid_dim = ( - cute.ceil_div(mK.shape[1], self.n_block_size), - cute.size(mQ.shape[2]), - cute.size(mQ.shape[0]), + + num_head = mQ.shape[1] if cutlass.const_expr(mCuSeqlensQ is not None) else mQ.shape[2] + + if cutlass.const_expr(mCuSeqlensK is not None): + TileScheduler = SingleTileVarlenScheduler + num_batch = mCuSeqlensK.shape[0] - 1 + else: + TileScheduler = SingleTileScheduler + num_batch = mK.shape[0] + + # Uses seqlen k, etc. since main bwd kernel's blocks are over n + tile_sched_args = TileSchedulerArguments( + num_block=cute.ceil_div(mK.shape[1], self.n_block_size), + num_head=num_head, + num_batch=num_batch, + seqlen_k=0, + headdim=mK.shape[2], + headdim_v=mV.shape[2], + total_q=mK.shape[0], + tile_shape_mn=(self.n_block_size, self.m_block_size), + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if cutlass.const_expr(self.pack_gqa) else 1, + mCuSeqlensQ=mCuSeqlensK, + mSeqUsedQ=mSeqUsedK, ) + + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) + softmax_scale_log2 = softmax_scale * math.log2(math.e) self.kernel( mQ, @@ -370,6 +423,10 @@ def __call__( mdQaccum, mdK, mdV, + mCuSeqlensQ, + mCuSeqlensK, + mSeqUsedQ, + mSeqUsedK, softmax_scale, softmax_scale_log2, self.sQ_layout, @@ -389,6 +446,8 @@ def __call__( tiled_mma_dkv, tiled_mma_dq, SharedStorage, + tile_sched_params, + TileScheduler, ).launch( grid=grid_dim, block=[self.num_threads, 1, 1], @@ -408,6 +467,10 @@ def kernel( mdQaccum: cute.Tensor, mdK: cute.Tensor, mdV: cute.Tensor, + mCuSeqlensQ: Optional[cute.Tensor], + mCuSeqlensK: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + mSeqUsedK: Optional[cute.Tensor], softmax_scale: cutlass.Float32, softmax_scale_log2: cutlass.Float32, sQ_layout: cute.ComposedLayout, @@ -427,40 +490,68 @@ def kernel( tiled_mma_dkv: cute.TiledMma, tiled_mma_dq: cute.TiledMma, SharedStorage: cutlass.Constexpr, + tile_sched_params: ParamsBase, + TileScheduler: cutlass.Constexpr[Callable], ): # Thread index, block index tidx, _, _ = cute.arch.thread_idx() - n_block, head_idx, batch_idx = cute.arch.block_idx() - if True: - m_block_max = cute.ceil_div(mQ.shape[1], self.m_block_size) + tile_scheduler = TileScheduler.create(tile_sched_params) + work_tile = tile_scheduler.initial_work_tile_info() + + n_block, head_idx, batch_idx = work_tile.tile_idx + + if work_tile.is_valid_tile: + seqlen = SeqlenInfoQK(batch_idx, mQ.shape[1], mK.shape[1], mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK) + + m_block_max = cute.ceil_div(seqlen.seqlen_q, self.m_block_size) m_block_min = 0 if cutlass.const_expr(self.is_causal): m_block_min = max( - (n_block * self.n_block_size + mQ.shape[1] - mK.shape[1]) // self.m_block_size, + (n_block * self.n_block_size + seqlen.seqlen_q - seqlen.seqlen_k) // self.m_block_size, m_block_min, ) # TODO: return early if m_block_max == 0 - # /////////////////////////////////////////////////////////////////////////////// - # Get the appropriate tiles for this thread block. - # /////////////////////////////////////////////////////////////////////////////// - blkQ_shape = (self.m_block_size, self.head_dim_padded) - blkK_shape = (self.n_block_size, self.head_dim_padded) - blkV_shape = (self.n_block_size, self.head_dim_v_padded) - blkdO_shape = (self.m_block_size, self.head_dim_v_padded) - # (m_block_size, head_dim, m_block) - gQ = cute.local_tile(mQ[batch_idx, None, head_idx, None], blkQ_shape, (None, 0)) - # (n_block_size, head_dim) - head_idx_kv = head_idx // self.qhead_per_kvhead - gK = cute.local_tile(mK[batch_idx, None, head_idx_kv, None], blkK_shape, (n_block, 0)) - # (n_block_size, head_dim_v) - gV = cute.local_tile(mV[batch_idx, None, head_idx_kv, None], blkV_shape, (n_block, 0)) - # (m_block_size, head_dim_v, m_block) - gdO = cute.local_tile(mdO[batch_idx, None, head_idx, None], blkdO_shape, (None, 0)) - gLSE = cute.local_tile(mLSE[batch_idx, head_idx, None], (self.m_block_size,), (None,)) - gdPsum = cute.local_tile(mdPsum[batch_idx, head_idx, None], (self.m_block_size,), (None,)) - gdQaccum = cute.local_tile(mdQaccum[batch_idx, head_idx, None], (self.m_block_size * self.head_dim_padded,), (None,)) + # /////////////////////////////////////////////////////////////////////////////// + # Get the appropriate tiles for this thread block. + # /////////////////////////////////////////////////////////////////////////////// + blkQ_shape = (self.m_block_size, self.head_dim_padded) + blkK_shape = (self.n_block_size, self.head_dim_padded) + blkV_shape = (self.n_block_size, self.head_dim_v_padded) + blkdO_shape = (self.m_block_size, self.head_dim_v_padded) + + if cutlass.const_expr(not seqlen.has_cu_seqlens_q): + mQ_cur = mQ[batch_idx, None, head_idx, None] + mLSE_cur = mLSE[batch_idx, head_idx, None] + mdO_cur = mdO[batch_idx, None, head_idx, None] + mdPsum_cur = mdPsum[batch_idx, head_idx, None] + mdQaccum_cur = mdQaccum[batch_idx, head_idx, None] + else: + padded_offset_q = seqlen.offset_q + batch_idx * self.m_block_size + mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, head_idx, None]) + mLSE_cur = cute.domain_offset((padded_offset_q,), mLSE[head_idx, None]) + mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, head_idx, None]) + mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[head_idx, None]) + mdQaccum_cur = cute.domain_offset((padded_offset_q * self.head_dim_padded,), mdQaccum[head_idx, None]) + head_idx_kv = head_idx // self.qhead_per_kvhead if cutlass.const_expr(not self.pack_gqa) else head_idx + + if cutlass.const_expr(not seqlen.has_cu_seqlens_k): + mK_cur, mV_cur = [t[batch_idx, None, head_idx_kv, None] for t in (mK, mV)] + else: + mK_cur, mV_cur = [cute.domain_offset((seqlen.offset_k, 0), t[None, head_idx_kv, None]) for t in (mK, mV)] + + # (m_block_size, head_dim, m_block) + gQ = cute.local_tile(mQ_cur, blkQ_shape, (None, 0)) + # (n_block_size, head_dim) + gK = cute.local_tile(mK_cur, blkK_shape, (n_block, 0)) + # (n_block_size, head_dim_v) + gV = cute.local_tile(mV_cur, blkV_shape, (n_block, 0)) + # (m_block_size, head_dim_v, m_block) + gdO = cute.local_tile(mdO_cur, blkdO_shape, (None, 0)) + gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (None,)) + gdPsum = cute.local_tile(mdPsum_cur, (self.m_block_size,), (None,)) + gdQaccum = cute.local_tile(mdQaccum_cur, (self.m_block_size * self.head_dim_padded,), (None,)) # /////////////////////////////////////////////////////////////////////////////// # Get shared memory buffer @@ -604,11 +695,15 @@ def kernel( # Allocate predicate tensors for m and n, here we only allocate the tile of k, and # use "if" on the mn dimension. # This is to reduce register pressure and gets 2-3% performance gain. - tQpQ = utils.predicate_k(tQcQ, limit=mQ.shape[3]) + + d_head = mQ.shape[cute.rank(mQ) - 1] + d_head_v = mdO.shape[cute.rank(mdO) - 1] + + tQpQ = utils.predicate_k(tQcQ, limit=d_head) if cutlass.const_expr(self.same_hdim_kv): tdOpdO = tQpQ else: - tdOpdO = utils.predicate_k(tdOcdO, limit=mdO.shape[3]) + tdOpdO = utils.predicate_k(tdOcdO, limit=d_head_v) # group parameters for compute_one_m_block mma_params = SimpleNamespace( @@ -635,7 +730,6 @@ def kernel( gmem_copy_params = SimpleNamespace( gmem_thr_copy_dQaccum=gmem_thr_copy_dQaccum, tdQgdQaccum=tdQgdQaccum ) - seqlen = SeqlenInfoQK(batch_idx, mQ.shape[1], mK.shape[1]) load_Q_LSE = partial( self.load_Q_LSE, gmem_tiled_copy_QK, gmem_tiled_copy_LSE, tQgQ, tQsQ, tQcQ, t0QcQ, tQpQ, @@ -659,11 +753,11 @@ def kernel( # /////////////////////////////////////////////////////////////////////////////// # Start async loads of the last mn-tile, where we take care of the mn residue self.load_V(gmem_thr_copy_VdO, tVgV, tVsV, n_block, seqlen=seqlen.seqlen_k, - headdim=mV.shape[3]) + headdim=d_head_v) if cutlass.const_expr(self.V_in_regs): cute.arch.cp_async_commit_group() self.load_K(gmem_thr_copy_QK, tKgK, tKsK, n_block, seqlen=seqlen.seqlen_k, - headdim=mK.shape[3]) + headdim=d_head) cute.arch.cp_async_commit_group() if cutlass.const_expr(self.V_in_regs): @@ -721,7 +815,7 @@ def kernel( self.epilogue( acc_dK, acc_dV, mdK, mdV, sdK, sdV, gmem_tiled_copy_dK, gmem_tiled_copy_dV, tiled_mma_dkv, - tidx, n_block, head_idx, batch_idx + tidx, n_block, head_idx, batch_idx, seqlen, d_head, d_head_v ) @cute.jit @@ -853,7 +947,6 @@ def dQ_mma(hook_fn): acc_dQ_atomic = gmem_copy_params.gmem_thr_copy_dQaccum.retile(acc_dQ) tdQgdQaccum_atomic = gmem_copy_params.tdQgdQaccum[None, None, m_block] assert cute.size(acc_dQ_atomic) == cute.size(tdQgdQaccum_atomic) - # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(acc_dQ) for i in cutlass.range(cute.size(acc_dQ_atomic), unroll_full=True): utils.atomic_add_fp32(acc_dQ_atomic[i], utils.elem_pointer(tdQgdQaccum_atomic, i)) # utils.atomic_add_fp32(acc_dQ[i], tdQgdQaccum_atomic.iterator + i * tdQgdQaccum_atomic.stride[1]) @@ -898,6 +991,9 @@ def epilogue( n_block: cutlass.Int32, num_head: cutlass.Int32, batch_size: cutlass.Int32, + seqlen: SeqlenInfoQK, + d_head: cutlass.Int32, + d_head_v: cutlass.Int32 ): rdV = cute.make_fragment_like(acc_dV, self.dtype) rdV.store(acc_dV.load().to(self.dtype)) @@ -906,6 +1002,9 @@ def epilogue( gmem_thr_copy_dK = gmem_tiled_copy_dK.get_slice(tidx) gmem_thr_copy_dV = gmem_tiled_copy_dV.get_slice(tidx) + batch_idx = batch_size + head_idx_kv = num_head // self.qhead_per_kvhead if cutlass.const_expr(not self.pack_gqa) else num_head + if cutlass.const_expr(self.qhead_per_kvhead == 1): # Make sure all threads have finished reading K and V, otherwise we get racy dQ # because smem_q could be changed. @@ -923,10 +1022,16 @@ def epilogue( cute.copy(smem_copy_atom_dKV, taccdVrdV, taccdVsdV) cute.copy(smem_copy_atom_dKV, taccdKrdK, taccdKsdK) + + if cutlass.const_expr(not seqlen.has_cu_seqlens_k): + mdK_cur, mdV_cur = [t[batch_idx, None, head_idx_kv, None] for t in (mdK, mdV)] + else: + mdK_cur, mdV_cur = [cute.domain_offset((seqlen.offset_k, 0), t[None, head_idx_kv, None]) for t in (mdK, mdV)] + blkdK_shape = (self.n_block_size, self.head_dim_padded) blkdV_shape = (self.n_block_size, self.head_dim_v_padded) - gdK = cute.local_tile(mdK[batch_size, None, num_head, None], blkdK_shape, (n_block, 0)) - gdV = cute.local_tile(mdV[batch_size, None, num_head, None], blkdV_shape, (n_block, 0)) + gdK = cute.local_tile(mdK_cur, blkdK_shape, (n_block, 0)) + gdV = cute.local_tile(mdV_cur, blkdV_shape, (n_block, 0)) tdKsdK = gmem_thr_copy_dK.partition_S(sdK) tdKgdK = gmem_thr_copy_dK.partition_D(gdK) tdVsdV = gmem_thr_copy_dV.partition_S(sdV) @@ -951,14 +1056,14 @@ def epilogue( cdV = cute.make_identity_tensor((self.n_block_size, self.head_dim_v_padded)) tdVcdV = gmem_thr_copy_dV.partition_S(cdV) t0dVcdV = gmem_tiled_copy_dV.get_slice(0).partition_S(cdV) - tdKpdK = utils.predicate_k(tdKcdK, limit=mdK.shape[3]) + tdKpdK = utils.predicate_k(tdKcdK, limit=d_head) if cutlass.const_expr(self.same_hdim_kv): tdVpdV = tdKpdK else: - tdVpdV = utils.predicate_k(tdVcdV, limit=mdV.shape[3]) + tdVpdV = utils.predicate_k(tdVcdV, limit=d_head_v) # copy acc dK and acc_dV from rmem to gmem for rest_m in cutlass.range_constexpr(cute.size(tdKrdK.shape[1])): - if t0dKcdK[0, rest_m, 0][0] < mdK.shape[1] - n_block * self.n_block_size - tdKcdK[0][0]: + if t0dKcdK[0, rest_m, 0][0] < seqlen.seqlen_k - n_block * self.n_block_size - tdKcdK[0][0]: cute.copy( gmem_tiled_copy_dK, tdKrdK[None, rest_m, None], @@ -966,7 +1071,7 @@ def epilogue( pred=tdKpdK[None, rest_m, None] if cutlass.const_expr(self.check_hdim_oob) else None, ) for rest_m in cutlass.range_constexpr(cute.size(tdVrdV.shape[1])): - if t0dVcdV[0, rest_m, 0][0] < mdV.shape[1] - n_block * self.n_block_size - tdVcdV[0][0]: + if t0dVcdV[0, rest_m, 0][0] < seqlen.seqlen_k - n_block * self.n_block_size - tdVcdV[0][0]: cute.copy( gmem_tiled_copy_dV, tdVrdV[None, rest_m, None], @@ -977,9 +1082,17 @@ def epilogue( else: # qhead_per_kvhead > 1, do atomic add # For Sm90, we need to sync to avoid racy writes to smem_q # For Sm80, we don't need to sync since we're not touching smem - num_head_kv = num_head // self.qhead_per_kvhead - gdV = cute.local_tile(mdV[batch_size, num_head_kv, None], (self.n_block_size * self.head_dim_v_padded,), (n_block,)) - gdK = cute.local_tile(mdK[batch_size, num_head_kv, None], (self.n_block_size * self.head_dim_padded,), (n_block,)) + head_idx_kv = num_head // self.qhead_per_kvhead if cutlass.const_expr(not self.pack_gqa) else num_head + + if cutlass.const_expr(not seqlen.has_cu_seqlens_k): + mdK_cur, mdV_cur = [t[batch_idx, head_idx_kv, None] for t in (mdK, mdV)] + else: + padded_offset_k = seqlen.offset_k + batch_idx * self.n_block_size + mdK_cur = cute.domain_offset((padded_offset_k * self.head_dim_padded,), mdK[head_idx_kv, None]) + mdV_cur = cute.domain_offset((padded_offset_k * self.head_dim_v_padded,), mdV[head_idx_kv, None]) + + gdV = cute.local_tile(mdV_cur, (self.n_block_size * self.head_dim_v_padded,), (n_block,)) + gdK = cute.local_tile(mdK_cur, (self.n_block_size * self.head_dim_padded,), (n_block,)) tdVgdVaccum = gmem_thr_copy_dV.partition_S(gdV) tdKgdKaccum = gmem_thr_copy_dK.partition_S(gdK) acc_dV_atomic = gmem_thr_copy_dV.retile(acc_dV) diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index 4dec60a9298..8adb4963815 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -2,7 +2,7 @@ # A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_postprocess_kernel.h # from Cutlass C++ to Cute-DSL. import math -from typing import Type +from typing import Callable, Optional, Type import cuda.bindings.driver as cuda @@ -12,6 +12,13 @@ from flash_attn.cute import ampere_helpers as sm80_utils import cutlass.utils.hopper_helpers as sm90_utils_basic from flash_attn.cute import utils +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.tile_scheduler import ( + ParamsBase, + SingleTileScheduler, + SingleTileVarlenScheduler, + TileSchedulerArguments +) class FlashAttentionBackwardPostprocess: @@ -142,6 +149,8 @@ def __call__( mdQaccum: cute.Tensor, mdQ: cute.Tensor, scale: cutlass.Float32, + mCuSeqlensQ: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], stream: cuda.CUstream, ): # Get the data type and check if it is fp16 or bf16 @@ -175,15 +184,39 @@ def __call__( cute.size_in_bytes(self.dtype, self.sdQ_layout), ) - # grid_dim: (m_block, num_head, batch_size) - grid_dim = ( - cute.ceil_div(mdQ.shape[1], self.m_block_size), - cute.size(mdQ.shape[2]), - cute.size(mdQ.shape[0]), + if cutlass.const_expr(mCuSeqlensQ is not None): + TileScheduler = SingleTileVarlenScheduler + num_head = mdQ.shape[1] + num_batch = mCuSeqlensQ.shape[0] - 1 + else: + TileScheduler = SingleTileScheduler + num_head = mdQ.shape[2] + num_batch = mdQ.shape[0] + + + tile_sched_args = TileSchedulerArguments( + num_block=cute.ceil_div(mdQ.shape[1], self.m_block_size), + num_head=num_head, + num_batch=num_batch, + seqlen_k=0, + headdim=mdQ.shape[2], + headdim_v=0, + total_q=mdQ.shape[0], + tile_shape_mn=(self.m_block_size, 1), + mCuSeqlensQ=mCuSeqlensQ, + mSeqUsedQ=mSeqUsedQ, ) + + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) + + + # grid_dim: (m_block, num_head, batch_size) self.kernel( mdQaccum, mdQ, + mCuSeqlensQ, + mSeqUsedQ, scale, tiled_mma, self.dQ_swapAB, @@ -192,6 +225,8 @@ def __call__( self.g2s_tiled_copy_dQaccum, self.s2r_tiled_copy_dQaccum, self.gmem_tiled_copy_dQ, + tile_sched_params, + TileScheduler, ).launch( grid=grid_dim, block=[tiled_mma.size, 1, 1], @@ -204,6 +239,8 @@ def kernel( self, mdQaccum: cute.Tensor, mdQ: cute.Tensor, + mCuSeqlensQ: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], scale: cutlass.Float32, tiled_mma: cute.TiledMma, dQ_swapAB: cutlass.Constexpr, @@ -212,21 +249,54 @@ def kernel( g2s_tiled_copy_dQaccum: cute.TiledCopy, s2r_tiled_copy_dQaccum: cute.TiledCopy, gmem_tiled_copy_dQ: cute.TiledCopy, + tile_sched_params: ParamsBase, + TileScheduler: cutlass.Constexpr[Callable], ): # Thread index, block index tidx, _, _ = cute.arch.thread_idx() - m_block, num_head, batch_size = cute.arch.block_idx() - if True: + tile_scheduler = TileScheduler.create(tile_sched_params) + work_tile = tile_scheduler.initial_work_tile_info() + + m_block, num_head, batch_size = work_tile.tile_idx + + if work_tile.is_valid_tile: # /////////////////////////////////////////////////////////////////////////////// # Get the appropriate tiles for this thread block. # /////////////////////////////////////////////////////////////////////////////// + + seqlen = SeqlenInfoQK(batch_size, mdQ.shape[1], 0, mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=None, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=None) + if cutlass.const_expr(not seqlen.has_cu_seqlens_q): + mdQ_cur = mdQ[batch_size, None, num_head, None] + mdQaccum_cur = mdQaccum[batch_size, num_head, None] + head_dim = mdQ.shape[3] + else: + padded_offset_q = seqlen.offset_q + batch_size * self.m_block_size + mdQ_cur = cute.domain_offset((seqlen.offset_q, 0), mdQ[None, num_head, None]) + mdQaccum_cur = cute.domain_offset((padded_offset_q * self.head_dim_padded,), mdQaccum[num_head, None]) + head_dim = mdQ.shape[2] + + # HACK: Compiler doesn't seem to recognize that padding + # by padded_offset_q * self.head_dim_padded keeps alignment + # since statically divisible by 4 + + mdQaccum_cur_ptr = cute.make_ptr( + dtype=mdQaccum_cur.element_type, + value=mdQaccum_cur.iterator.toint(), + mem_space=mdQaccum_cur.iterator.memspace, + assumed_align=mdQaccum.iterator.alignment, + ) + mdQaccum_cur = cute.make_tensor( + mdQaccum_cur_ptr, + mdQaccum_cur.layout + ) + blkdQaccum_shape = (self.m_block_size * self.head_dim_padded,) gdQaccum = cute.local_tile( - mdQaccum[batch_size, num_head, None], blkdQaccum_shape, (m_block,) + mdQaccum_cur, blkdQaccum_shape, (m_block,) ) blkdQ_shape = (self.m_block_size, self.head_dim_padded) - gdQ = cute.local_tile(mdQ[batch_size, None, num_head, None], blkdQ_shape, (m_block, 0)) + gdQ = cute.local_tile(mdQ_cur, blkdQ_shape, (m_block, 0)) # /////////////////////////////////////////////////////////////////////////////// # Get shared memory buffer @@ -235,7 +305,7 @@ def kernel( sdQaccum = smem.allocate_tensor(cutlass.Float32, sdQaccum_layout, byte_alignment=1024) sdQ = cute.make_tensor(cute.recast_ptr(sdQaccum.iterator, dtype=self.dtype), sdQ_layout) - seqlen_q = mdQ.shape[1] + seqlen_q = seqlen.seqlen_q seqlen_q_rounded = cute.round_up(seqlen_q, self.m_block_size) # Step 1: load dQaccum from gmem to smem @@ -300,9 +370,9 @@ def kernel( # Step 5: Copy dQ from register to gmem cdQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) tdQcdQ = gmem_thr_copy_dQ.partition_S(cdQ) - tdQpdQ = utils.predicate_k(tdQcdQ, limit=mdQ.shape[3]) + tdQpdQ = utils.predicate_k(tdQcdQ, limit=head_dim) for rest_m in cutlass.range(cute.size(tdQrdQ.shape[1]), unroll_full=True): - if tdQcdQ[0, rest_m, 0][0] < mdQ.shape[1] - m_block * self.m_block_size: + if tdQcdQ[0, rest_m, 0][0] < seqlen_q - m_block * self.m_block_size: cute.copy( gmem_tiled_copy_dQ, tdQrdQ[None, rest_m, None], @@ -357,6 +427,8 @@ def __call__( mdQaccum: cute.Tensor, mdQ: cute.Tensor, scale: cutlass.Float32, + mCuSeqlensQ: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], stream: cuda.CUstream, ): # Assume all strides are divisible by 128 bits except the last stride diff --git a/flash_attn/cute/flash_bwd_preprocess.py b/flash_attn/cute/flash_bwd_preprocess.py index e30fc6232a9..ee6535be527 100644 --- a/flash_attn/cute/flash_bwd_preprocess.py +++ b/flash_attn/cute/flash_bwd_preprocess.py @@ -3,7 +3,7 @@ # from Cutlass C++ to Cute-DSL. import math import operator -from typing import Type, Optional +from typing import Callable, Type, Optional import cuda.bindings.driver as cuda @@ -13,6 +13,8 @@ from flash_attn.cute import utils from flash_attn.cute import copy_utils +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.tile_scheduler import ParamsBase, SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments class FlashAttentionBackwardPreprocess: @@ -101,6 +103,8 @@ def __call__( mLSE: Optional[cute.Tensor], mLSElog2: Optional[cute.Tensor], mdQaccum: Optional[cute.Tensor], + mCuSeqlensQ: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], stream: cuda.CUstream, ): # Get the data type and check if it is fp16 or bf16 @@ -126,12 +130,32 @@ def __call__( self._setup_attributes() - # grid_dim: (m_block, num_head, batch_size) - grid_dim = ( - cute.ceil_div(mO.shape[1], self.m_block_size), - cute.size(mO.shape[2]), - cute.size(mO.shape[0]), + if cutlass.const_expr(mCuSeqlensQ is not None): + TileScheduler = SingleTileVarlenScheduler + num_head = mO.shape[1] + num_batch = mCuSeqlensQ.shape[0] - 1 + else: + TileScheduler = SingleTileScheduler + num_head = mO.shape[2] + num_batch = mO.shape[0] + + + tile_sched_args = TileSchedulerArguments( + num_block=cute.ceil_div(mO.shape[1], self.m_block_size), + num_head=num_head, + num_batch=num_batch, + seqlen_k=0, + headdim=0, + headdim_v=mO.shape[2], + total_q=mO.shape[0], + tile_shape_mn=(self.m_block_size, 1), + mCuSeqlensQ=mCuSeqlensQ, + mSeqUsedQ=mSeqUsedQ, ) + + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) + self.kernel( mO, mdO, @@ -139,8 +163,12 @@ def __call__( mLSE, mLSElog2, mdQaccum, + mCuSeqlensQ, + mSeqUsedQ, self.gmem_tiled_copy_O, self.gmem_tiled_copy_dQaccum, + tile_sched_params, + TileScheduler, ).launch( grid=grid_dim, block=[self.num_threads, 1, 1], @@ -156,95 +184,143 @@ def kernel( mLSE: Optional[cute.Tensor], mLSElog2: Optional[cute.Tensor], mdQaccum: Optional[cute.Tensor], + mCuSeqlensQ: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], gmem_tiled_copy_O: cute.TiledCopy, gmem_tiled_copy_dQaccum: cute.TiledCopy, + tile_sched_params: ParamsBase, + TileScheduler: cutlass.Constexpr[Callable], ): # Thread index, block index tidx, _, _ = cute.arch.thread_idx() - m_block, num_head, batch_size = cute.arch.block_idx() - if True: + tile_scheduler = TileScheduler.create(tile_sched_params) + work_tile = tile_scheduler.initial_work_tile_info() + m_block, num_head, batch_size = work_tile.tile_idx + + if work_tile.is_valid_tile: # /////////////////////////////////////////////////////////////////////////////// # Get the appropriate tiles for this thread block. # /////////////////////////////////////////////////////////////////////////////// + seqlen = SeqlenInfoQK(batch_size, mO.shape[1], 0, mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=None, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=None) + + if cutlass.const_expr(not seqlen.has_cu_seqlens_q): + mO_cur = mO[batch_size, None, num_head, None] + mdO_cur = mdO[batch_size, None, num_head, None] + mdPsum_cur = mdPsum[batch_size, num_head, None] + headdim_v = mO.shape[3] + else: + mO_cur = cute.domain_offset((seqlen.offset_q, 0), mO[None, num_head, None]) + mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, num_head, None]) + + padded_offset_q = seqlen.offset_q + batch_size * self.m_block_size + mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[num_head, None]) + headdim_v = mO.shape[2] + blkOdO_shape = (self.m_block_size, self.head_dim_padded) # (m_block_size, head_dim) - gO = cute.local_tile(mO[batch_size, None, num_head, None], blkOdO_shape, (m_block, 0)) - gdO = cute.local_tile(mdO[batch_size, None, num_head, None], blkOdO_shape, (m_block, 0)) + gO = cute.local_tile(mO_cur, blkOdO_shape, (m_block, 0)) + gdO = cute.local_tile(mdO_cur, blkOdO_shape, (m_block, 0)) - gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) - # (CPY_Atom, CPY_M, CPY_K) - tOgO = gmem_thr_copy_O.partition_S(gO) - tOgdO = gmem_thr_copy_O.partition_S(gdO) + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + # (CPY_Atom, CPY_M, CPY_K) + tOgO = gmem_thr_copy_O.partition_S(gO) + tOgdO = gmem_thr_copy_O.partition_S(gdO) - # /////////////////////////////////////////////////////////////////////////////// - # Predicate: Mark indices that need to copy when problem_shape isn't a multiple - # of tile_shape - # /////////////////////////////////////////////////////////////////////////////// - # Construct identity layout for KV - cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) - tOcO = gmem_thr_copy_O.partition_S(cO) - t0OcO = gmem_thr_copy_O.get_slice(0).partition_S(cO) - tOpO = utils.predicate_k(tOcO, limit=mO.shape[3]) - tOpdO = utils.predicate_k(tOcO, limit=mdO.shape[3]) - - seqlen_q = mO.shape[1] + # /////////////////////////////////////////////////////////////////////////////// + # Predicate: Mark indices that need to copy when problem_shape isn't a multiple + # of tile_shape + # /////////////////////////////////////////////////////////////////////////////// + # Construct identity layout for KV + cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + tOcO = gmem_thr_copy_O.partition_S(cO) + t0OcO = gmem_thr_copy_O.get_slice(0).partition_S(cO) + tOpO = utils.predicate_k(tOcO, limit=headdim_v) + tOpdO = utils.predicate_k(tOcO, limit=headdim_v) + + seqlen_q = seqlen.seqlen_q seqlen_q_rounded = cute.round_up(seqlen_q, self.m_block_size) - if cutlass.const_expr(mLSE is not None): - gLSE = cute.local_tile( - mLSE[batch_size, num_head, None], (self.m_block_size,), (m_block,) - ) - lse = Float32.inf - if tidx < seqlen_q - m_block * self.m_block_size: - lse = gLSE[tidx] - - tOrO = cute.make_fragment_like(tOgO) - tOrdO = cute.make_fragment_like(tOgdO) - assert cute.size(tOgO, mode=[0]) == cute.size(tOgdO, mode=[0]) - assert cute.size(tOgO, mode=[1]) == cute.size(tOgdO, mode=[1]) - assert cute.size(tOgO, mode=[2]) == cute.size(tOgdO, mode=[2]) - for m in cutlass.range(cute.size(tOrO.shape[1]), unroll_full=True): - # Instead of using tOcO, we using t0OcO and subtract the offset from the limit - # (seqlen_q - m_block * kBlockM). This is because the entries of t0OcO are known at compile time. - if t0OcO[0, m, 0][0] < seqlen_q - m_block * self.m_block_size - tOcO[0][0]: - cute.copy( - gmem_thr_copy_O, - tOgO[None, m, None], - tOrO[None, m, None], - pred=tOpO[None, m, None] if cutlass.const_expr(self.check_hdim_oob) else None, - ) - cute.copy( - gmem_thr_copy_O, - tOgdO[None, m, None], - tOrdO[None, m, None], - pred=tOpdO[None, m, None] if cutlass.const_expr(self.check_hdim_oob) else None, + if cutlass.const_expr(mLSE is not None): + if cutlass.const_expr(not seqlen.has_cu_seqlens_q): + mLSE_cur = mLSE[batch_size, num_head, None] + else: + mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[num_head, None]) + + gLSE = cute.local_tile( + mLSE_cur, (self.m_block_size,), (m_block,) ) - # Sum across the "k" dimension - dpsum = (tOrO.load().to(Float32) * tOrdO.load().to(Float32)).reduce( - cute.ReductionOp.ADD, init_val=0.0, reduction_profile=(0, None, 1) - ) - threads_per_row = gmem_tiled_copy_O.layout_src_tv_tiled[0].shape[0] - assert cute.arch.WARP_SIZE % threads_per_row == 0 - dpsum = utils.warp_reduce(dpsum, operator.add, width=threads_per_row) - dP_sum = cute.make_fragment(cute.size(tOrO, mode=[1]), Float32) - dP_sum.store(dpsum) - - # Write dPsum from rmem -> gmem - gdPsum = cute.local_tile( - mdPsum[batch_size, num_head, None], (self.m_block_size,), (m_block,) - ) - # Only the thread corresponding to column 0 writes out the lse to gmem - if tOcO[0, 0, 0][1] == 0: - for m in cutlass.range(cute.size(dP_sum), unroll_full=True): - row = tOcO[0, m, 0][0] - gdPsum[row] = dP_sum[m] if row < mO.shape[1] - m_block * self.m_block_size else 0.0 + lse = Float32.inf + if tidx < seqlen_q - m_block * self.m_block_size: + lse = gLSE[tidx] + + tOrO = cute.make_fragment_like(tOgO) + tOrdO = cute.make_fragment_like(tOgdO) + assert cute.size(tOgO, mode=[0]) == cute.size(tOgdO, mode=[0]) + assert cute.size(tOgO, mode=[1]) == cute.size(tOgdO, mode=[1]) + assert cute.size(tOgO, mode=[2]) == cute.size(tOgdO, mode=[2]) + for m in cutlass.range(cute.size(tOrO.shape[1]), unroll_full=True): + # Instead of using tOcO, we using t0OcO and subtract the offset from the limit + # (seqlen_q - m_block * kBlockM). This is because the entries of t0OcO are known at compile time. + if t0OcO[0, m, 0][0] < seqlen_q - m_block * self.m_block_size - tOcO[0][0]: + cute.copy( + gmem_thr_copy_O, + tOgO[None, m, None], + tOrO[None, m, None], + pred=tOpO[None, m, None] if cutlass.const_expr(self.check_hdim_oob) else None, + ) + cute.copy( + gmem_thr_copy_O, + tOgdO[None, m, None], + tOrdO[None, m, None], + pred=tOpdO[None, m, None] if cutlass.const_expr(self.check_hdim_oob) else None, + ) + # Sum across the "k" dimension + dpsum = (tOrO.load().to(Float32) * tOrdO.load().to(Float32)).reduce( + cute.ReductionOp.ADD, init_val=0.0, reduction_profile=(0, None, 1) + ) + threads_per_row = gmem_tiled_copy_O.layout_src_tv_tiled[0].shape[0] + assert cute.arch.WARP_SIZE % threads_per_row == 0 + dpsum = utils.warp_reduce(dpsum, operator.add, width=threads_per_row) + dP_sum = cute.make_fragment(cute.size(tOrO, mode=[1]), Float32) + dP_sum.store(dpsum) + + # Write dPsum from rmem -> gmem + gdPsum = cute.local_tile( + mdPsum_cur, (self.m_block_size,), (m_block,) + ) + # Only the thread corresponding to column 0 writes out the dPsum to gmem + if tOcO[0, 0, 0][1] == 0: + for m in cutlass.range(cute.size(dP_sum), unroll_full=True): + row = tOcO[0, m, 0][0] + gdPsum[row] = dP_sum[m] if row < seqlen_q - m_block * self.m_block_size else 0.0 # Clear dQaccum if cutlass.const_expr(mdQaccum is not None): + if cutlass.const_expr(not seqlen.has_cu_seqlens_q): + mdQaccum_cur = mdQaccum[batch_size, num_head, None] + else: + padded_offset_q = seqlen.offset_q + batch_size * self.m_block_size + mdQaccum_cur = cute.domain_offset((padded_offset_q * self.head_dim_padded,), mdQaccum[num_head, None]) + + # HACK: Compiler doesn't seem to recognize that padding + # by padded_offset_q * self.head_dim_padded keeps alignment + # since statically divisible by 4 + + mdQaccum_cur_ptr = cute.make_ptr( + dtype=mdQaccum_cur.element_type, + value=mdQaccum_cur.iterator.toint(), + mem_space=mdQaccum_cur.iterator.memspace, + assumed_align=mdQaccum.iterator.alignment, + ) + mdQaccum_cur = cute.make_tensor( + mdQaccum_cur_ptr, + mdQaccum_cur.layout + ) + blkdQaccum_shape = (self.m_block_size * self.head_dim_padded,) gdQaccum = cute.local_tile( - mdQaccum[batch_size, num_head, None], blkdQaccum_shape, (m_block,) + mdQaccum_cur, blkdQaccum_shape, (m_block,) ) gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx) tQgQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum) @@ -252,10 +328,16 @@ def kernel( zero.fill(0.0) cute.copy(gmem_tiled_copy_dQaccum, zero, tQgQaccum) - if cutlass.const_expr(mLSE is not None): - gLSElog2 = cute.local_tile( - mLSElog2[batch_size, num_head, None], (self.m_block_size,), (m_block,) - ) - LOG2_E = math.log2(math.e) - if tidx < seqlen_q_rounded - m_block * self.m_block_size: - gLSElog2[tidx] = lse * LOG2_E if lse != -Float32.inf else 0.0 + if cutlass.const_expr(mLSE is not None): + if cutlass.const_expr(not seqlen.has_cu_seqlens_q): + mLSElog2_cur = mLSElog2[batch_size, num_head, None] + else: + padded_offset_q = seqlen.offset_q + batch_size * self.m_block_size + mLSElog2_cur = cute.domain_offset((padded_offset_q,), mLSElog2[num_head, None]) + + gLSElog2 = cute.local_tile( + mLSElog2_cur, (self.m_block_size,), (m_block,) + ) + LOG2_E = math.log2(math.e) + if tidx < seqlen_q_rounded - m_block * self.m_block_size: + gLSElog2[tidx] = lse * LOG2_E if lse != -Float32.inf else 0.0 diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 15c81b8c1db..a2a5a44a0fb 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -298,6 +298,7 @@ def _flash_attn_bwd( m_block_size: int = 64, n_block_size: int = 128, num_threads: int = 256, + pack_gqa: bool = False, num_stages_Q: int = 2, num_stages_dO: int = 2, SdP_swapAB: bool = False, @@ -307,20 +308,61 @@ def _flash_attn_bwd( AtomLayoutNdKV: int = 2, AtomLayoutMdQ: int = 2, V_in_regs: bool = False, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - q, k, v, out, dout, lse = [maybe_contiguous(t) for t in (q, k, v, out, dout, lse)] - batch_size, seqlen_q, num_head, head_dim = q.shape - _, seqlen_k, num_head_kv, _ = k.shape - _, _, _, head_dim_v = v.shape - assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim) - assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v) - assert out.shape == (batch_size, seqlen_q, num_head, head_dim_v) - assert dout.shape == (batch_size, seqlen_q, num_head, head_dim_v) - assert lse.shape == (batch_size, num_head, seqlen_q), "lse must have shape (batch_size, num_head, seqlen_q)" + q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = [ + maybe_contiguous(t) + for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) + ] + num_head, head_dim = q.shape[-2:] + if cu_seqlens_q is None: + batch_size, seqlen_q = q.shape[:2] + total_q = batch_size * seqlen_q + else: + batch_size = cu_seqlens_q.shape[0] - 1 + seqlen_q = None + total_q = q.shape[0] + + if cu_seqlens_k is None: + batch_size, seqlen_k = k.shape[:2] + total_k = batch_size * seqlen_k + else: + batch_size = cu_seqlens_k.shape[0] - 1 + seqlen_k = None + total_k = k.shape[0] + + num_head_kv = k.shape[-2] + head_dim_v = v.shape[-1] + + if cu_seqlens_k is None: + assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim) + assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v) + else: + assert k.shape == (total_k, num_head_kv, head_dim) + assert v.shape == (total_k, num_head_kv, head_dim_v) + assert cu_seqlens_k.shape == (batch_size + 1,), "cu_seqlens_k must have shape (batch_size + 1,)" + + if cu_seqlens_q is not None: + assert cu_seqlens_q.shape == (batch_size + 1,), "cu_seqlens_q must have shape (batch_size + 1,)" + + assert out.shape == (total_q, num_head, head_dim_v) + assert dout.shape == (total_q, num_head, head_dim_v) + assert lse.shape == (num_head, total_q), "lse must have shape (num_head, total_q)" + else: + assert out.shape == (batch_size, seqlen_q, num_head, head_dim_v) + assert dout.shape == (batch_size, seqlen_q, num_head, head_dim_v) + assert lse.shape == (batch_size, num_head, seqlen_q), "lse must have shape (batch_size, num_head, seqlen_q)" + assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16" assert q.dtype == k.dtype == v.dtype == out.dtype == dout.dtype, "inputs must have the same dtype" + for t in [cu_seqlens_q, cu_seqlens_k]: + if t is not None: + assert t.dtype == torch.int32, "cu_seqlens_q, cu_seqlens_k must be int32" assert lse.dtype == torch.float32, "lse must be float32" - assert all(t.is_cuda for t in (q, k, v, out, dout, lse)), "inputs must be on CUDA device" + assert all(t is None or t.is_cuda for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k)), "inputs must be on CUDA device" assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" assert head_dim <= 256, "head_dim must be less than or equal to 256" alignment = 16 // q.element_size() @@ -329,38 +371,58 @@ def _flash_attn_bwd( if softmax_scale is None: softmax_scale = 1.0 / math.sqrt(head_dim) qhead_per_kvhead = num_head // num_head_kv + if pack_gqa is None: + pack_gqa = qhead_per_kvhead > 1 device = q.device # TODO: check if this is the right rounding - seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size - head_dim_rounded = (head_dim + 32 - 1) // 32 * 32 dq = torch.empty_like(q) dk = torch.empty_like(k) dv = torch.empty_like(v) - dq_accum = torch.empty(batch_size, num_head, seqlen_q_rounded * head_dim_rounded, dtype=torch.float32, device=device) - dpsum = torch.empty(batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device) - lse_log2 = torch.empty(batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device) + + head_dim_rounded = (head_dim + 32 - 1) // 32 * 32 + + if cu_seqlens_q is None: + seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size + dq_accum = torch.empty(batch_size, num_head, seqlen_q_rounded * head_dim_rounded, dtype=torch.float32, device=device) + dpsum = torch.empty(batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device) + lse_log2 = torch.empty(batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device) + else: + total_q_rounded_padded = (total_q + cu_seqlens_q.shape[0] * m_block_size - 1) // m_block_size * m_block_size + dq_accum = torch.empty(num_head, total_q_rounded_padded * head_dim_rounded, dtype=torch.float32, device=device) + dpsum = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device) + lse_log2 = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device) + if qhead_per_kvhead > 1: - seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size head_dim_v_rounded = (head_dim_v + 32 - 1) // 32 * 32 - dk_accum = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded * head_dim_rounded, dtype=torch.float32, device=device) - dv_accum = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded * head_dim_v_rounded, dtype=torch.float32, device=device) + if cu_seqlens_k is None: + seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size + dk_accum = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded * head_dim_rounded, dtype=torch.float32, device=device) + dv_accum = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded * head_dim_v_rounded, dtype=torch.float32, device=device) + else: + total_k_rounded_padded = (total_k + cu_seqlens_k.shape[0] * n_block_size - 1) // n_block_size * n_block_size + dk_accum = torch.zeros(num_head_kv, total_k_rounded_padded * head_dim_rounded, dtype=torch.float32, device=device) + dv_accum = torch.zeros(num_head_kv, total_k_rounded_padded * head_dim_v_rounded, dtype=torch.float32, device=device) dtype = torch2cute_dtype_map[q.dtype] q_tensor, k_tensor, v_tensor, o_tensor, do_tensor, dq_tensor, dk_tensor, dv_tensor = [ from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) for t in (q, k, v, out, dout, dq, dk, dv) ] - lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=2) + lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 1) dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [ - from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=2) + from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) for t in (dq_accum, dpsum, lse_log2) ] if qhead_per_kvhead > 1: dk_accum_tensor, dv_accum_tensor = [ - from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=2) + from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) for t in (dk_accum, dv_accum) ] + cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor = [ + from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=t.ndim-1) if t is not None else None + for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) + ] current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) # Preprocess kernel: compute (o * dout).sum(dim=-1), lse * log2_e, and zero out dq_accum. @@ -372,16 +434,17 @@ def _flash_attn_bwd( # TODO: check @can_implement _flash_attn_bwd.compile_cache_pre[compile_key_pre] = cute.compile( fa_bwd_pre, o_tensor, do_tensor, dpsum_tensor, lse_tensor, lse_log2_tensor, - dq_accum_tensor, current_stream + dq_accum_tensor, cu_seqlens_q_tensor, seqused_q_tensor, current_stream ) _flash_attn_bwd.compile_cache_pre[compile_key_pre]( - o_tensor, do_tensor, dpsum_tensor, lse_tensor, lse_log2_tensor, dq_accum_tensor, current_stream + o_tensor, do_tensor, dpsum_tensor, lse_tensor, lse_log2_tensor, dq_accum_tensor, + cu_seqlens_q_tensor, seqused_q_tensor, current_stream ) # Backward kernel: compute dk, dv, dq_accum. compile_key = ( dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, softcap != 0.0, m_block_size, - n_block_size, num_threads, num_stages_Q, num_stages_dO, SdP_swapAB, dKV_swapAB, dQ_swapAB, + n_block_size, num_threads, pack_gqa, num_stages_Q, num_stages_dO, SdP_swapAB, dKV_swapAB, dQ_swapAB, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs ) m_block_size = 64 @@ -397,6 +460,7 @@ def _flash_attn_bwd( num_stages_Q, num_stages_dO, num_threads, + pack_gqa, causal, SdP_swapAB, dKV_swapAB, @@ -433,14 +497,24 @@ def _flash_attn_bwd( dq_accum_tensor, dk_tensor if qhead_per_kvhead == 1 else dk_accum_tensor, dv_tensor if qhead_per_kvhead == 1 else dv_accum_tensor, - softmax_scale, current_stream + softmax_scale, + current_stream, + cu_seqlens_q_tensor, + cu_seqlens_k_tensor, + seqused_q_tensor, + seqused_k_tensor, ) _flash_attn_bwd.compile_cache[compile_key]( q_tensor, k_tensor, v_tensor, do_tensor, lse_log2_tensor, dpsum_tensor, dq_accum_tensor, dk_tensor if qhead_per_kvhead == 1 else dk_accum_tensor, dv_tensor if qhead_per_kvhead == 1 else dv_accum_tensor, - softmax_scale, current_stream + softmax_scale, + current_stream, + cu_seqlens_q_tensor, + cu_seqlens_k_tensor, + seqused_q_tensor, + seqused_k_tensor, ) # Postprocess kernel: convert dq_accum from float32 to dq in bf16/fp16 @@ -452,10 +526,11 @@ def _flash_attn_bwd( ) # TODO: check @can_implement _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile( - fa_bwd_post, dq_accum_tensor, dq_tensor, softmax_scale, current_stream + fa_bwd_post, dq_accum_tensor, dq_tensor, softmax_scale, cu_seqlens_q_tensor, + seqused_q_tensor, current_stream ) _flash_attn_bwd.compile_cache_post[compile_key_post]( - dq_accum_tensor, dq_tensor, softmax_scale, current_stream + dq_accum_tensor, dq_tensor, softmax_scale, cu_seqlens_q_tensor, seqused_q_tensor, current_stream ) if qhead_per_kvhead > 1: @@ -467,10 +542,10 @@ def _flash_attn_bwd( ) # TODO: check @can_implement _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile( - fa_bwd_post, dk_accum_tensor, dk_tensor, softmax_scale, current_stream + fa_bwd_post, dk_accum_tensor, dk_tensor, softmax_scale, cu_seqlens_k_tensor, seqused_k_tensor, current_stream ) _flash_attn_bwd.compile_cache_post[compile_key_post]( - dk_accum_tensor, dk_tensor, softmax_scale, current_stream + dk_accum_tensor, dk_tensor, softmax_scale, cu_seqlens_k_tensor, seqused_k_tensor, current_stream ) compile_key_post = (dtype, head_dim_v, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB) if compile_key_post not in _flash_attn_bwd.compile_cache_post: @@ -479,10 +554,10 @@ def _flash_attn_bwd( ) # TODO: check @can_implement _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile( - fa_bwd_post, dv_accum_tensor, dv_tensor, cutlass.Float32(1.0), current_stream + fa_bwd_post, dv_accum_tensor, dv_tensor, cutlass.Float32(1.0), cu_seqlens_k_tensor, seqused_k_tensor, current_stream ) _flash_attn_bwd.compile_cache_post[compile_key_post]( - dv_accum_tensor, dv_tensor, cutlass.Float32(1.0), current_stream + dv_accum_tensor, dv_tensor, cutlass.Float32(1.0), cu_seqlens_k_tensor, seqused_k_tensor, current_stream ) return dq, dk, dv @@ -591,10 +666,26 @@ def forward( @staticmethod def backward(ctx, dout, *args): q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors - raise NotImplementedError( - "Backward pass for FlashAttention with variable length sequences is not implemented yet." + assert seqused_q == seqused_k == None + assert ctx.softcap == 0.0 + dq, dk, dv = _flash_attn_bwd( + q, + k, + v, + out, + dout, + lse, + ctx.softmax_scale, + ctx.causal, + ctx.softcap, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_q=seqused_q, + seqused_k=seqused_k, ) + return dq, dk, dv, *((None,) * 11) + def flash_attn_func( q: torch.Tensor, diff --git a/tests/cute/test_flash_attn_varlen.py b/tests/cute/test_flash_attn_varlen.py new file mode 100644 index 00000000000..3a514664449 --- /dev/null +++ b/tests/cute/test_flash_attn_varlen.py @@ -0,0 +1,298 @@ +import itertools +from typing import Optional +from einops import rearrange +import pytest + +import torch +import torch.nn.functional as F +from flash_attn.cute import flash_attn_varlen_func + +@pytest.mark.parametrize("B", [1, 7, 20]) +@pytest.mark.parametrize("H", [1, 4, 6]) +@pytest.mark.parametrize("D", [64, 128]) +@pytest.mark.parametrize("min_seq_len", [1, 32, 128]) +@pytest.mark.parametrize("max_seq_len", [8, 64, 2048]) +@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("softmax_scale", [None, 0.1]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +def test_varlen( + B, + H, + D, + min_seq_len, + max_seq_len, + causal, + softmax_scale, + dtype, + mha_type, +): + if min_seq_len > max_seq_len: + pytest.skip("Skipping min_seq_len > max_seq_len") + + q, k, v, cu_seqlens_q, cu_seqlens_k, total_q, total_k = generate_varlen_args( + batch_size=B, + n_heads=H, + d_head=D, + min_len=min_seq_len, + max_len=max_seq_len, + mha_type=mha_type, + dtype=dtype + ) + + ok = check_backward_vs_torch_flash( + q, k, v, + cu_seqlens_q, cu_seqlens_k, + total_q=total_q, total_k=total_k, + softmax_scale=softmax_scale, + causal=causal, + mha_type=mha_type, + ) + assert ok + +def check_backward_vs_torch_flash( + q, k, v, + cu_seqlens_q=None, + cu_seqlens_k=None, + seqused_q=None, + seqused_k=None, + total_q=None, + total_k=None, + softmax_scale=None, + causal=True, + mha_type='mha', + softcap=0.0, + atol=3e-2, + rtol=3e-2, +): + assert q.requires_grad and k.requires_grad and v.requires_grad, "Set requires_grad=True on inputs" + + def clone_like(t): + c = t.clone().detach().requires_grad_(True) + return c + + q_fa, k_fa, v_fa = map(clone_like, (q, k, v)) + q_t, k_t, v_t = map(clone_like, (q, k, v)) + + if cu_seqlens_q is not None: + cu_seqlens_q_fa = cu_seqlens_q.clone() + cu_seqlens_q_t = cu_seqlens_q.clone() + else: + cu_seqlens_q_fa = None + cu_seqlens_q_t = None + + if cu_seqlens_k is not None: + cu_seqlens_k_fa = cu_seqlens_k.clone() + cu_seqlens_k_t = cu_seqlens_k.clone() + else: + cu_seqlens_k_fa = None + cu_seqlens_k_t = None + + out_fa, lse_fa = flash_attn_varlen_func( + q_fa, k_fa, v_fa, + cu_seqlens_q=cu_seqlens_q_fa, + cu_seqlens_k=cu_seqlens_k_fa, + seqused_q=seqused_q, + seqused_k=seqused_k, + softmax_scale=(1.0 / q.shape[-1]**0.5) if softmax_scale is None else softmax_scale, + causal=causal, + window_size=(None, None), + learnable_sink=None, + softcap=softcap, + pack_gqa=None, + ) + + out_t = torch_flash_ref( + q_t, k_t, v_t, + cu_seqlens_q=cu_seqlens_q_t, + cu_seqlens_k=cu_seqlens_k_t, + seqused_q=seqused_q, + seqused_k=seqused_k, + total_q=total_q, + total_k=total_k, + softmax_scale=softmax_scale, + causal=causal, + mha_type=mha_type, + ) + + # Use the same upstream gradient to compare backward paths + grad_out = torch.randn_like(out_fa) + + grad_fa = clone_like(grad_out) + grad_t = clone_like(grad_out) + + # Cute bwd + out_fa.backward(grad_fa, retain_graph=False) + dq_fa, dk_fa, dv_fa = q_fa.grad, k_fa.grad, v_fa.grad + + # Ref bwd + out_t.backward(grad_t, retain_graph=False) + dq_t, dk_t, dv_t = q_t.grad, k_t.grad, v_t.grad + + # mean_ok_q = _stats("dQ", dq_fa, dq_t, atol=atol, rtol=rtol) + # mean_ok_k = _stats("dK", dk_fa, dk_t, atol=atol, rtol=rtol) + # mean_ok_v = _stats("dV", dv_fa, dv_t, atol=atol, rtol=rtol) + + # return mean_ok_q and mean_ok_k and mean_ok_v + + ok_q = torch.allclose(dq_fa.float(), dq_t.float(), atol=atol, rtol=rtol) + ok_k = torch.allclose(dk_fa.float(), dk_t.float(), atol=atol, rtol=rtol) + ok_v = torch.allclose(dv_fa.float(), dv_t.float(), atol=atol, rtol=rtol) + # print(f"Close? dQ={ok_q}, dK={ok_k}, dV={ok_v}") + return ok_q and ok_k and ok_v + +def generate_varlen_args( + batch_size=8, + n_heads=16, + d_head=128, + min_len=32, + max_len=64, + mha_type="mha", + dtype = torch.bfloat16, +): + + torch.manual_seed(0) + device = "cuda" + + assert mha_type in ["mha", "mqa", "gqa"] + + lens_q = torch.randint(low=min_len, high=max_len + 1, size=(batch_size,)) + lens_k = lens_q.clone() + + cu_seqlens_q = torch.cat([torch.zeros(1, dtype=torch.int32), lens_q.cumsum(0)]) + cu_seqlens_k = torch.cat([torch.zeros(1, dtype=torch.int32), lens_k.cumsum(0)]) + + total_q = cu_seqlens_q[-1] + total_k = cu_seqlens_k[-1] + + cu_seqlens_q = cu_seqlens_q.contiguous().to(dtype=torch.int32, device=device) + cu_seqlens_k = cu_seqlens_k.contiguous().to(dtype=torch.int32, device=device) + + if mha_type == "gqa": + H = 3 * n_heads + H_kv = n_heads + elif mha_type == "mha": + H = H_kv = n_heads + else: # MQA + H = n_heads + H_kv = 1 + + d_head_v = d_head + + q = torch.randn(total_q, H, d_head, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(total_k, H_kv, d_head, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(total_k, H_kv, d_head_v, device=device, dtype=dtype, requires_grad=True) + + return q, k, v, cu_seqlens_q, cu_seqlens_k, total_q, total_k + +# Simple for loop over batch dim implementation +def torch_flash_ref( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor = None, + cu_seqlens_k: torch.Tensor = None, + total_q: int = 0, + total_k: int = 0, + softmax_scale: Optional[float] = None, + causal: bool = False, + **kwargs + ): + + """ + q: (total_q, H, d) if cu_seqlens_q is not None, otherwise (B, L, H, d) + k: (total_k, H_kv, d) if cu_seqlens_k is not None, otherwise (B, L, H_kv, d) + v: (total_k, H_kv, d_v) if cu_seqlens_k is not None, otherwise (B, L, H_kv, d_v) + cu_seqlens_q: (B+1,) int32, cumulative + cu_seqlens_k: (B+1,) int32, cumulative + + seqused_q: (B+1,) int32 + seqused_k: (B+1,) int32 + Returns: + out packed like q: (total_q, H, d_v) + """ + + if cu_seqlens_q is not None: + assert cu_seqlens_q.dim() == 1 + assert total_q == q.shape[0] + assert q.dim() == 3 + H = q.shape[1] + B = cu_seqlens_q.shape[0] - 1 + else: + assert q.dim() == 4 + H = q.shape[2] + B = q.shape[0] + + if cu_seqlens_k is not None: + assert cu_seqlens_k.dim() == 1 + assert total_k == k.shape[0] == v.shape[0] + assert k.dim() == v.dim() == 3 + H_kv = k.shape[1] + B_kv = cu_seqlens_k.shape[0] - 1 + else: + assert k.dim() == v.dim() == 4 + assert k.shape[0] == v.shape[0] + H_kv = k.shape[2] + B_kv = k.shape[0] + + d = q.shape[-1] + d_v = v.shape[-1] + + assert H_kv == v.shape[-2] + assert d == k.shape[-1] + assert B == B_kv + + assert q.device == k.device == v.device + assert q.is_floating_point() and k.is_floating_point() and v.is_floating_point() + + device = q.device + dtype = q.dtype + + hcseq_q = cu_seqlens_q.to(device='cpu') + hcseq_k = cu_seqlens_k.to(device='cpu') + + outs = [] + for b in range(B): + if hcseq_q is not None: + q_start, q_end = int(hcseq_q[b]), int(hcseq_q[b+1]) + qb = q[q_start:q_end] + else: + qb = q[b] + + if hcseq_k is not None: + k_start, k_end = int(hcseq_k[b]), int(hcseq_k[b+1]) + kb = k[k_start:k_end] + vb = v[k_start:k_end] + else: + kb = k[b] + vb = v[b] + + qb = qb.permute(1, 0, 2).unsqueeze(0) + kb = kb.permute(1, 0, 2).unsqueeze(0) + vb = vb.permute(1, 0, 2).unsqueeze(0) + + ob = F.scaled_dot_product_attention( + qb, kb, vb, + attn_mask=None, + dropout_p=0.0, + is_causal=causal, + scale=softmax_scale, + enable_gqa=H_kv!=H + ) + + ob = ob.squeeze(0).permute(1, 0, 2).contiguous() + outs.append(ob) + + if cu_seqlens_q is not None: + out = torch.cat(outs, dim=0).to(device=device, dtype=dtype) + else: + out = torch.stack(outs, dim=0).to(device=device, dtype=dtype) + return out + +@torch.no_grad() +def _stats(name, a, b, atol, rtol): + diff = (a - b).float() + mean_abs = diff.abs().mean().item() + mean_rel = (diff.abs().mean() / b.abs().clamp_min(1e-6).mean().item()) + print(f"{name}: mean_abs={mean_abs:.4e}, mean_rel={mean_rel:.4e}, sum_fa={a.sum()}, sum_ref={b.sum()}") + return mean_abs < atol and mean_rel < rtol \ No newline at end of file From b4e589699c5f2d6070e9517504b635ea3b3c2cf9 Mon Sep 17 00:00:00 2001 From: Kevin Tong Date: Mon, 13 Oct 2025 14:19:46 -0700 Subject: [PATCH 130/258] Remove self refs in softmax for loop (#1924) Co-authored-by: Tri Dao --- flash_attn/cute/softmax.py | 68 +++++++++++++++++++++++++------------- 1 file changed, 45 insertions(+), 23 deletions(-) diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index b283e7c7035..59e5add7abe 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -64,33 +64,49 @@ def online_softmax( # Change acc_S to M,N layout view. acc_S_mn = utils.make_acc_tensor_mn_view(acc_S) row_scale = cute.make_fragment_like(self.row_max, Float32) + + row_max = self.row_max + row_sum = self.row_sum + scale_log2 = self.scale_log2 + arch = self.arch + # Each iteration processes one row of acc_S - for r in cutlass.range_constexpr(cute.size(self.row_max)): + for r in cutlass.range(cute.size(row_max), unroll_full=True): acc_S_row = acc_S_mn[r, None].load() # (n_block_size) - row_max_cur = self._compute_row_max( + + row_max_cur = utils.fmax_reduce( acc_S_row, - init_val=self.row_max[r] if cutlass.const_expr(not is_first) else None, + init_val=row_max[r] if cutlass.const_expr(not is_first) else None, + arch=arch ) + row_max_cur = utils.warp_reduce(row_max_cur, cute.arch.fmax, width=4) if cutlass.const_expr(check_inf): row_max_cur = 0.0 if row_max_cur == -Float32.inf else row_max_cur + if cutlass.const_expr(is_first): - row_max_cur_scaled = row_max_cur * self.scale_log2 - acc_S_row_exp = utils.exp2f(acc_S_row * self.scale_log2 - row_max_cur_scaled) - acc_S_row_sum = self._compute_row_sum(acc_S_row_exp) + row_max_cur_scaled = row_max_cur * scale_log2 + acc_S_row_exp = utils.exp2f(acc_S_row * scale_log2 - row_max_cur_scaled) + + acc_S_row_sum = utils.fadd_reduce(acc_S_row_exp, init_val=None, arch=arch) row_scale[r] = 1.0 else: - row_max_prev = self.row_max[r] - row_max_cur_scaled = row_max_cur * self.scale_log2 - acc_S_row_exp = utils.exp2f(acc_S_row * self.scale_log2 - row_max_cur_scaled) + row_max_prev = row_max[r] + row_max_cur_scaled = row_max_cur * scale_log2 + acc_S_row_exp = utils.exp2f(acc_S_row * scale_log2 - row_max_cur_scaled) # row_scale[r] = utils.exp2f(row_max_prev * self.scale_log2 - row_max_cur_scaled) - row_scale[r] = utils.exp2f((row_max_prev - row_max_cur) * self.scale_log2) - acc_S_row_sum = ( - self._compute_row_sum(acc_S_row_exp, init_val=self.row_sum[r] * row_scale[r]) + row_scale[r] = utils.exp2f((row_max_prev - row_max_cur) * scale_log2) + + acc_S_row_sum = utils.fadd_reduce( + acc_S_row_exp, + init_val=row_sum[r] * row_scale[r], + arch=arch ) - self.row_max[r] = row_max_cur - self.row_sum[r] = acc_S_row_sum + + row_max[r] = row_max_cur + row_sum[r] = acc_S_row_sum acc_S_mn[r, None].store(acc_S_row_exp) + return row_scale @cute.jit @@ -98,25 +114,31 @@ def finalize(self, final_scale: Float32 = 1.0, sink_val: Float32 | cute.Tensor | """Finalize the online softmax by computing the scale and logsumexp.""" if cutlass.const_expr(sink_val is not None and isinstance(sink_val, cute.Tensor)): assert cute.size(sink_val) == cute.size(self.row_sum) + row_sum = self.row_sum + row_max = self.row_max + scale_log2 = self.scale_log2 + # quad reduction for row_sum as we didn't do it during each iteration of online softmax - self.row_sum.store(utils.warp_reduce(self.row_sum.load(), operator.add, width=4)) - row_scale = cute.make_fragment_like(self.row_max, Float32) - for r in cutlass.range_constexpr(cute.size(self.row_sum)): + row_sum.store(utils.warp_reduce(row_sum.load(), operator.add, width=4)) + row_scale = cute.make_fragment_like(row_max, Float32) + + for r in cutlass.range(cute.size(row_sum), unroll_full=True): if cutlass.const_expr(sink_val is not None): sink_val_cur = sink_val if not isinstance(sink_val, cute.Tensor) else sink_val[r] LOG2_E = math.log2(math.e) - self.row_sum[r] += utils.exp2f(sink_val_cur * LOG2_E - self.row_max[r] * self.scale_log2) + row_sum[r] += utils.exp2f(sink_val_cur * LOG2_E - row_max[r] * scale_log2) + # if row_sum is zero or nan, set acc_O_mn_row to 1.0 acc_O_mn_row_is_zero_or_nan = ( - self.row_sum[r] == 0.0 or self.row_sum[r] != self.row_sum[r] + row_sum[r] == 0.0 or row_sum[r] != row_sum[r] ) row_scale[r] = ( - cute.arch.rcp_approx(self.row_sum[r] if not acc_O_mn_row_is_zero_or_nan else 1.0) + cute.arch.rcp_approx(row_sum[r] if not acc_O_mn_row_is_zero_or_nan else 1.0) ) * final_scale - row_sum_cur = self.row_sum[r] + row_sum_cur = row_sum[r] LN2 = math.log(2.0) - self.row_sum[r] = ( - (self.row_max[r] * self.scale_log2 + utils.log2f(row_sum_cur)) * LN2 + row_sum[r] = ( + (row_max[r] * scale_log2 + utils.log2f(row_sum_cur)) * LN2 if not acc_O_mn_row_is_zero_or_nan else -Float32.inf ) From 13afe0d51d4ff24ddbc95938af0d555528660817 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 13 Oct 2025 17:54:47 -0400 Subject: [PATCH 131/258] [Cute,Bwd,Sm90] Make postprocessing kernel work --- flash_attn/cute/flash_bwd_postprocess.py | 401 +++++------------------ flash_attn/cute/flash_bwd_sm90.py | 108 +++--- flash_attn/cute/interface.py | 38 ++- 3 files changed, 180 insertions(+), 367 deletions(-) diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index 8adb4963815..ef1e027a62d 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -2,21 +2,26 @@ # A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_postprocess_kernel.h # from Cutlass C++ to Cute-DSL. import math -from typing import Callable, Optional, Type +from typing import Callable, Optional, Type, Literal import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute -from cutlass.cute.nvgpu import cpasync, warp, warpgroup -from flash_attn.cute import ampere_helpers as sm80_utils import cutlass.utils.hopper_helpers as sm90_utils_basic +from cutlass.cute.nvgpu import cpasync, warp, warpgroup +from cutlass import Int32, Float32, const_expr +from cutlass.utils import LayoutEnum + from flash_attn.cute import utils +from flash_attn.cute import copy_utils +from flash_attn.cute import ampere_helpers as sm80_utils +from flash_attn.cute import hopper_helpers as sm90_utils from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.tile_scheduler import ( - ParamsBase, - SingleTileScheduler, - SingleTileVarlenScheduler, + ParamsBase, + SingleTileScheduler, + SingleTileVarlenScheduler, TileSchedulerArguments ) @@ -25,44 +30,41 @@ class FlashAttentionBackwardPostprocess: def __init__( self, dtype: Type[cutlass.Numeric], - # tiled_mma: cute.TiledMma, head_dim: int, - m_block_size: int = 128, + arch: Literal[80, 90], + tile_m: int = 128, num_threads: int = 256, AtomLayoutMdQ: int = 1, dQ_swapAB: bool = False, ): - """Initializes the configuration for a flash attention v2 kernel. - - All contiguous dimensions must be at least 16 bytes aligned which indicates the head dimension - should be a multiple of 8. - + """ :param head_dim: head dimension :type head_dim: int - :param m_block_size: m block size - :type m_block_size: int + :param tile_m: m block size + :type tile_m: int """ self.dtype = dtype - self.m_block_size = m_block_size + self.tile_m = tile_m + assert arch in [80, 90], "Only Ampere (80) and Hopper (90) are supported" + self.arch = arch # padding head_dim to a multiple of 32 as k_block_size hdim_multiple_of = 32 - self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) - self.check_hdim_oob = head_dim != self.head_dim_padded - # self.tiled_mma = tiled_mma + self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) + self.check_hdim_oob = head_dim != self.tile_hdim self.num_threads = num_threads self.AtomLayoutMdQ = AtomLayoutMdQ self.dQ_swapAB = dQ_swapAB @staticmethod - def can_implement(dtype, head_dim, m_block_size, num_threads) -> bool: + def can_implement(dtype, head_dim, tile_m, num_threads) -> bool: """Check if the kernel can be implemented with the given parameters. :param dtype: data type :type dtype: cutlass.Numeric :param head_dim: head dimension :type head_dim: int - :param m_block_size: m block size - :type m_block_size: int + :param tile_m: m block size + :type tile_m: int :return: True if the kernel can be implemented, False otherwise :rtype: bool @@ -75,73 +77,68 @@ def can_implement(dtype, head_dim, m_block_size, num_threads) -> bool: return False return True + def _get_tiled_mma(self): + if const_expr(self.arch == 80): + num_mma_warps = self.num_threads // 32 + AtomLayoutdQ = ( + (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1) + if const_expr(not self.dQ_swapAB) + else (num_mma_warps // self.AtomLayoutMdQ, self.AtomLayoutMdQ, 1) + ) + tiled_mma = cute.make_tiled_mma( + warp.MmaF16BF16Op(self.dtype, Float32, (16, 8, 16)), + AtomLayoutdQ, + permutation_mnk=(AtomLayoutdQ[0] * 16, AtomLayoutdQ[1] * 16, 16), + ) + else: + tiled_mma = sm90_utils_basic.make_trivial_tiled_mma( + self.dtype, + self.dtype, + warpgroup.OperandMajorMode.K, # These don't matter, we only care about the accum + warpgroup.OperandMajorMode.K, + Float32, + atom_layout_mnk=(self.tile_m // 64, 2, 1), + tiler_mn=(64, self.tile_hdim // 2), + ) + assert self.num_threads == tiled_mma.size + return tiled_mma + def _setup_attributes(self): # /////////////////////////////////////////////////////////////////////////////// # GMEM Tiled copy: # /////////////////////////////////////////////////////////////////////////////// # Thread layouts for copies universal_copy_bits = 128 - async_copy_elems_accum = universal_copy_bits // cutlass.Float32.width + async_copy_elems_accum = universal_copy_bits // Float32.width atom_async_copy_accum = cute.make_copy_atom( cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), - cutlass.Float32, + Float32, num_bits_per_copy=universal_copy_bits, ) # We don't do bound checking for the gmem -> smem load so we just assert here. - assert ( - self.m_block_size * self.head_dim_padded // async_copy_elems_accum - ) % self.tiled_mma.size == 0 + assert (self.tile_m * self.tile_hdim // async_copy_elems_accum) % self.num_threads == 0 self.g2s_tiled_copy_dQaccum = cute.make_tiled_copy_tv( atom_async_copy_accum, - cute.make_layout(self.tiled_mma.size), + cute.make_layout(self.num_threads), cute.make_layout(async_copy_elems_accum), ) - atom_universal_copy_accum = cute.make_copy_atom( - # multiply by 4 for Sm90 - cute.nvgpu.CopyUniversalOp(), - cutlass.Float32, - num_bits_per_copy=cutlass.Float32.width, - ) - self.s2r_tiled_copy_dQaccum = cute.make_tiled_copy_tv( - atom_universal_copy_accum, - cute.make_layout(self.tiled_mma.size), - cute.make_layout(1), # 4 for Sm90 - ) + num_s2r_copy_elems = 1 if const_expr(self.arch == 80) else 4 + self.s2r_tiled_copy_dQaccum = copy_utils.tiled_copy_1d(Float32, self.num_threads, num_s2r_copy_elems) - async_copy_elems = universal_copy_bits // self.dtype.width - # atom_universal_copy: universal copy atom for dQ store - atom_universal_copy = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), - self.dtype, - num_bits_per_copy=universal_copy_bits, - ) - # tdQ_layout: thread layout for dQ store - assert self.head_dim_padded % async_copy_elems == 0 - gmem_threads_per_row = math.gcd( - self.head_dim_padded // async_copy_elems, self.tiled_mma.size - ) - assert self.tiled_mma.size % gmem_threads_per_row == 0 - tdQ_layout = cute.make_ordered_layout( - (self.tiled_mma.size // gmem_threads_per_row, gmem_threads_per_row), - order=(1, 0), - ) - # Value layouts for copies - vdQ_layout = cute.make_layout((1, async_copy_elems)) - self.gmem_tiled_copy_dQ = cute.make_tiled_copy_tv( - atom_universal_copy, tdQ_layout, vdQ_layout - ) + self.gmem_tiled_copy_dQ = copy_utils.tiled_copy_2d(self.dtype, self.tile_hdim, self.num_threads) # /////////////////////////////////////////////////////////////////////////////// # Shared memory layout: dQaccum / dQ # /////////////////////////////////////////////////////////////////////////////// - self.sdQaccum_layout = cute.make_layout(self.m_block_size * self.head_dim_padded) + self.sdQaccum_layout = cute.make_layout(self.tile_m * self.tile_hdim) # We can't just use kHeadDim here. E.g. if MMA shape is 64 x 96 but split across 2 WGs, # then setting kBlockKSmem to 32 will cause "Static shape_div failure". # We want to treat it as 64 x 48, so kBlockKSmem should be 16. mma_shape_n = self.tiled_mma.get_tile_size(1) - sdQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, mma_shape_n) - self.sdQ_layout = cute.tile_to_shape( - sdQ_layout_atom, (self.m_block_size, self.head_dim_padded), (0, 1) - ) + if const_expr(self.arch == 80): + sdQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, mma_shape_n) + self.sdQ_layout = cute.tile_to_shape(sdQ_layout_atom, (self.tile_m, self.tile_hdim), (0, 1)) + else: + self.sdQ_layout = sm90_utils.make_smem_layout(self.dtype, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_hdim)) @cute.jit def __call__( @@ -154,29 +151,17 @@ def __call__( stream: cuda.CUstream, ): # Get the data type and check if it is fp16 or bf16 - if cutlass.const_expr(not mdQ.element_type in [cutlass.Float16, cutlass.BFloat16]): + if const_expr(not mdQ.element_type in [cutlass.Float16, cutlass.BFloat16]): raise TypeError("Only Float16 or BFloat16 is supported") - if cutlass.const_expr(mdQaccum is not None): - if cutlass.const_expr(not mdQaccum.element_type in [cutlass.Float32]): + if const_expr(mdQaccum is not None): + if const_expr(not mdQaccum.element_type in [cutlass.Float32]): raise TypeError("dQaccum tensor must be Float32") # Assume all strides are divisible by 128 bits except the last stride new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1]) mdQaccum, mdQ = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) for t in (mdQaccum, mdQ)] - num_mma_warps = self.num_threads // 32 - AtomLayoutdQ = ( - (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1) - if cutlass.const_expr(not self.dQ_swapAB) - else (num_mma_warps // self.AtomLayoutMdQ, self.AtomLayoutMdQ, 1) - ) - tiled_mma = cute.make_tiled_mma( - warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)), - AtomLayoutdQ, - permutation_mnk=(AtomLayoutdQ[0] * 16, AtomLayoutdQ[1] * 16, 16), - ) - self.tiled_mma = tiled_mma - + self.tiled_mma = self._get_tiled_mma() self._setup_attributes() smem_size = max( @@ -184,7 +169,7 @@ def __call__( cute.size_in_bytes(self.dtype, self.sdQ_layout), ) - if cutlass.const_expr(mCuSeqlensQ is not None): + if const_expr(mCuSeqlensQ is not None): TileScheduler = SingleTileVarlenScheduler num_head = mdQ.shape[1] num_batch = mCuSeqlensQ.shape[0] - 1 @@ -195,14 +180,14 @@ def __call__( tile_sched_args = TileSchedulerArguments( - num_block=cute.ceil_div(mdQ.shape[1], self.m_block_size), + num_block=cute.ceil_div(mdQ.shape[1], self.tile_m), num_head=num_head, num_batch=num_batch, seqlen_k=0, headdim=mdQ.shape[2], headdim_v=0, total_q=mdQ.shape[0], - tile_shape_mn=(self.m_block_size, 1), + tile_shape_mn=(self.tile_m, 1), mCuSeqlensQ=mCuSeqlensQ, mSeqUsedQ=mSeqUsedQ, ) @@ -218,7 +203,7 @@ def __call__( mCuSeqlensQ, mSeqUsedQ, scale, - tiled_mma, + self.tiled_mma, self.dQ_swapAB, self.sdQaccum_layout, self.sdQ_layout, @@ -229,7 +214,7 @@ def __call__( TileScheduler, ).launch( grid=grid_dim, - block=[tiled_mma.size, 1, 1], + block=[self.tiled_mma.size, 1, 1], smem=smem_size, stream=stream, ) @@ -266,18 +251,18 @@ def kernel( # /////////////////////////////////////////////////////////////////////////////// seqlen = SeqlenInfoQK(batch_size, mdQ.shape[1], 0, mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=None, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=None) - if cutlass.const_expr(not seqlen.has_cu_seqlens_q): + if const_expr(not seqlen.has_cu_seqlens_q): mdQ_cur = mdQ[batch_size, None, num_head, None] mdQaccum_cur = mdQaccum[batch_size, num_head, None] head_dim = mdQ.shape[3] else: - padded_offset_q = seqlen.offset_q + batch_size * self.m_block_size + padded_offset_q = seqlen.offset_q + batch_size * self.tile_m mdQ_cur = cute.domain_offset((seqlen.offset_q, 0), mdQ[None, num_head, None]) - mdQaccum_cur = cute.domain_offset((padded_offset_q * self.head_dim_padded,), mdQaccum[num_head, None]) + mdQaccum_cur = cute.domain_offset((padded_offset_q * self.tile_hdim,), mdQaccum[num_head, None]) head_dim = mdQ.shape[2] - # HACK: Compiler doesn't seem to recognize that padding - # by padded_offset_q * self.head_dim_padded keeps alignment + # HACK: Compiler doesn't seem to recognize that padding + # by padded_offset_q * self.tile_hdim keeps alignment # since statically divisible by 4 mdQaccum_cur_ptr = cute.make_ptr( @@ -291,12 +276,9 @@ def kernel( mdQaccum_cur.layout ) - blkdQaccum_shape = (self.m_block_size * self.head_dim_padded,) - gdQaccum = cute.local_tile( - mdQaccum_cur, blkdQaccum_shape, (m_block,) - ) - blkdQ_shape = (self.m_block_size, self.head_dim_padded) - gdQ = cute.local_tile(mdQ_cur, blkdQ_shape, (m_block, 0)) + dQaccum_shape = (self.tile_m * self.tile_hdim,) + gdQaccum = cute.local_tile(mdQaccum_cur, dQaccum_shape, (m_block,)) + gdQ = cute.local_tile(mdQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0)) # /////////////////////////////////////////////////////////////////////////////// # Get shared memory buffer @@ -306,7 +288,7 @@ def kernel( sdQ = cute.make_tensor(cute.recast_ptr(sdQaccum.iterator, dtype=self.dtype), sdQ_layout) seqlen_q = seqlen.seqlen_q - seqlen_q_rounded = cute.round_up(seqlen_q, self.m_block_size) + seqlen_q_rounded = cute.round_up(seqlen_q, self.tile_m) # Step 1: load dQaccum from gmem to smem g2s_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_slice(tidx) @@ -327,9 +309,9 @@ def kernel( # thr_mma = tiled_mma.get_slice(tidx) # print(tiled_mma) acc_shape = tiled_mma.partition_shape_C( - (self.m_block_size, self.head_dim_padded) - if cutlass.const_expr(not dQ_swapAB) - else (self.head_dim_padded, self.m_block_size) + (self.tile_m, self.tile_hdim) + if const_expr(not dQ_swapAB) + else (self.tile_hdim, self.tile_m) ) acc = cute.make_fragment(acc_shape, cutlass.Float32) assert cute.size(acc) == cute.size(tdQsdQaccum) @@ -348,9 +330,7 @@ def kernel( # Step 3: Copy dQ from register to smem cute.arch.barrier() # make sure all threads have finished loading dQaccum - smem_copy_atom_dQ = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=cutlass.Float32.width - ) + smem_copy_atom_dQ = utils.get_smem_store_atom(self.arch, self.dtype) smem_thr_copy_dQ = cute.make_tiled_copy_C(smem_copy_atom_dQ, tiled_mma).get_slice(tidx) taccdQrdQ = smem_thr_copy_dQ.retile(rdQ) taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ) @@ -368,221 +348,14 @@ def kernel( cute.autovec_copy(tdQsdQ, tdQrdQ) # Step 5: Copy dQ from register to gmem - cdQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + cdQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim)) tdQcdQ = gmem_thr_copy_dQ.partition_S(cdQ) tdQpdQ = utils.predicate_k(tdQcdQ, limit=head_dim) for rest_m in cutlass.range(cute.size(tdQrdQ.shape[1]), unroll_full=True): - if tdQcdQ[0, rest_m, 0][0] < seqlen_q - m_block * self.m_block_size: + if tdQcdQ[0, rest_m, 0][0] < seqlen_q - m_block * self.tile_m: cute.copy( gmem_tiled_copy_dQ, tdQrdQ[None, rest_m, None], tdQgdQ[None, rest_m, None], pred=tdQpdQ[None, rest_m, None], ) - - -class FlashAttentionBackwardPostprocess_sm90(FlashAttentionBackwardPostprocess): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.universal_copy_bits = 128 - - def _setup_attributes(self): - self.sdQaccum_layout = cute.make_layout( - shape=(self.m_block_size * self.head_dim_padded, ), - ) - - sdQ_layout_atom = cute.nvgpu.warpgroup.make_smem_layout_atom( - cutlass.utils.hopper_helpers.get_smem_layout_atom( - cutlass.utils.LayoutEnum.ROW_MAJOR, self.dtype, self.head_dim_padded - ), - self.dtype - ) - self.sdQ_layout = cute.tile_to_shape( - sdQ_layout_atom, - (self.m_block_size, self.head_dim_padded), - (0, 1) - ) - # G->S - async_copy_elements = self.universal_copy_bits // cutlass.Float32.width - self.G2S_tiled_copy_dQaccum = cute.make_tiled_copy_tv( - cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), - cutlass.Float32, - num_bits_per_copy=self.universal_copy_bits - ), - cute.make_layout(self.tiled_mma.size), - cute.make_layout(async_copy_elements) - ) - - # S->R - self.S2R_tiled_copy_dQaccum = cute.make_tiled_copy_tv( - cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=self.universal_copy_bits), - cute.make_layout(self.tiled_mma.size), - cute.make_layout(async_copy_elements) - ) - - @cute.jit - def __call__( - self, - mdQaccum: cute.Tensor, - mdQ: cute.Tensor, - scale: cutlass.Float32, - mCuSeqlensQ: Optional[cute.Tensor], - mSeqUsedQ: Optional[cute.Tensor], - stream: cuda.CUstream, - ): - # Assume all strides are divisible by 128 bits except the last stride - new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1]) - mdQaccum, mdQ = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) for t in (mdQaccum, mdQ)] - - mdQ = cute.make_tensor(mdQ.iterator, cute.select(mdQ.layout, mode=[1,3,2,0])) - mdQaccum = cute.make_tensor(mdQaccum.iterator, cute.select(mdQaccum.layout, mode=[2,1,0])) - - # tiled_mma - tiled_mma = sm90_utils_basic.make_trivial_tiled_mma( - self.dtype, - self.dtype, - warpgroup.OperandMajorMode.K, - warpgroup.OperandMajorMode.MN, - cutlass.Float32, - atom_layout_mnk=(self.m_block_size // 64, 2, 1), - tiler_mn=(64, self.head_dim_padded) - ) - - self.tiled_mma = tiled_mma - self.num_mma_threads = tiled_mma.size - self._setup_attributes() - - - # TMA setup - tma_atom_dQ, mdQ = cpasync.make_tiled_tma_atom( - cpasync.CopyBulkTensorTileS2GOp(), - mdQ, - self.sdQ_layout, - (self.m_block_size, self.head_dim_padded), - ) - - seqlen = mdQ.shape[0] - grid_dim = [ - cute.ceil_div(seqlen, self.m_block_size), - cute.size(mdQ.shape[2]), - cute.size(mdQ.shape[3]), - ] - smem_size = max( - cute.size_in_bytes(cutlass.Float32, self.sdQaccum_layout), - cute.size_in_bytes(self.dtype, self.sdQ_layout) - ) - self.kernel( - mdQaccum, - mdQ, - tma_atom_dQ, - tiled_mma, - self.sdQaccum_layout, - self.sdQ_layout, - self.G2S_tiled_copy_dQaccum, - self.S2R_tiled_copy_dQaccum, - scale, - ).launch( - grid=grid_dim, - block=[self.num_mma_threads, 1, 1], - smem=smem_size, - stream=stream, - ) - - @cute.kernel - def kernel( - self, - mdQaccum: cute.Tensor, - mdQ: cute.Tensor, - tma_atom_dQ: cute.CopyAtom, - tiled_mma: cute.TiledMma, - sdQaccum_layout: cute.Layout, - sdQ_layout: cute.ComposedLayout, - g2s_tiled_copy_dQaccum: cute.TiledCopy, - s2r_tiled_copy_dQaccum: cute.TiledCopy, - scale: cutlass.Float32, - ): - # basic setup - tidx = cute.arch.thread_idx()[0] - m_block, head_idx, batch_idx = cute.arch.block_idx() - warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - - smem = cutlass.utils.SmemAllocator() - sdQaccum = smem.allocate_tensor(cutlass.Float32, sdQaccum_layout, byte_alignment=128) - sdQ = cute.make_tensor( - cute.recast_ptr(sdQaccum.iterator, sdQ_layout.inner, dtype=self.dtype), - sdQ_layout.outer - ) - - if warp_idx == 0: - cpasync.prefetch_descriptor(tma_atom_dQ) - - # G->S - gdQaccum = cute.local_tile( - mdQaccum[None, head_idx, batch_idx], - (self.m_block_size * self.head_dim_padded, ), - (m_block,) - ) - - gmem_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_slice(tidx) - tdQaccumgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum) - tdQaccumsdQaccum = gmem_thr_copy_dQaccum.partition_D(sdQaccum) - - cute.copy(g2s_tiled_copy_dQaccum, tdQaccumgdQaccum, tdQaccumsdQaccum) - cute.arch.barrier() - - # S->R - acc_dQaccum = cute.make_fragment( - tiled_mma.partition_shape_C((self.m_block_size, self.head_dim_padded)), - cutlass.Float32 - ) - acc_dQaccum.fill(0) - - smem_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_slice(tidx) - tdQaccumsdQaccum = smem_thr_copy_dQaccum.partition_S(sdQaccum) - - - tdQaccumrdQaccum = cute.make_tensor(acc_dQaccum.iterator, cute.make_layout(tdQaccumsdQaccum.shape)) - cute.copy(smem_thr_copy_dQaccum, tdQaccumsdQaccum, tdQaccumrdQaccum) - - - # Scale + FP32->BF16/FP16 - acc_mmaA_view = cute.make_tensor(acc_dQaccum.iterator, utils.convert_layout_acc_frgA(acc_dQaccum.layout)) - rdQ = cute.make_fragment_like(acc_mmaA_view, self.dtype) - - acc_dQaccum.store(acc_dQaccum.load() * scale) - utils.cvt_f16(acc_mmaA_view, rdQ) # BF16/FP16 output - - - # R->S (StMatrix) - smem_copy_atom = cute.make_copy_atom( - cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=False, num_matrices=4), - self.dtype, #BF16/FP16 - ) - - smem_thr_copy = cute.make_tiled_copy_C(smem_copy_atom, tiled_mma).get_slice(tidx) - tdQsdQ = smem_thr_copy.partition_D(sdQ) - tdQrdQ = cute.make_tensor(rdQ.iterator, cute.make_layout(tdQsdQ.shape)) - - cute.copy(smem_thr_copy, tdQrdQ, tdQsdQ) - cute.arch.barrier() - - #S->G (TMA) - gdQ = cute.local_tile( - mdQ[None, None, head_idx, batch_idx], - (self.m_block_size, self.head_dim_padded), - (m_block, 0) - ) - - tdQsdQ, tdQgdQ = cpasync.tma_partition( - tma_atom_dQ, - 0, - cute.make_layout(1), - cute.group_modes(sdQ, 0, 2), - cute.group_modes(gdQ, 0, 2) - ) - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) - if warp_idx == 4: # only one warp writes - cute.copy(tma_atom_dQ, tdQsdQ, tdQgdQ) - cute.arch.cp_async_bulk_commit_group() - cute.arch.cp_async_bulk_wait_group(0, read=True) diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index 13ccef13962..0284b96905f 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -30,11 +30,20 @@ def __init__( head_dim: int, head_dim_v: Optional[int] = None, qhead_per_kvhead: int = 1, + is_causal: bool = False, tile_m: int = 64, tile_n: int = 128, - num_stages: int = 2, + Q_stage: int = 2, + dO_stage: int = 2, + PdS_stage: int = 2, + SdP_swapAB: bool = False, + dKV_swapAB: bool = False, + dQ_swapAB: bool = False, + AtomLayoutMSdP: int = 1, + AtomLayoutNdKV: int = 2, + AtomLayoutMdQ: int = 1, num_threads: int = 384, - Q_in_regs: bool = False, + V_in_regs: bool = False, ): self.dtype = dtype # padding head_dim to a multiple of 16 as k_block_size @@ -47,12 +56,21 @@ def __init__( self.check_hdim_oob = head_dim != self.tile_hdim self.check_hdim_v_oob = head_dim_v != self.tile_hdimv self.qhead_per_kvhead = qhead_per_kvhead + self.is_causal = is_causal self.tile_m = tile_m self.tile_n = tile_n self.num_threads = num_threads - self.num_stages = num_stages - self.dS_stage = 2 - self.Q_in_regs = Q_in_regs + self.Q_stage = Q_stage + self.dO_stage = dO_stage + self.PdS_stage = PdS_stage + assert self.dO_stage in [1, self.Q_stage] + assert self.PdS_stage in [1, self.Q_stage] + self.AtomLayoutMSdP = AtomLayoutMSdP + self.AtomLayoutNdKV = AtomLayoutNdKV + self.AtomLayoutMdQ = AtomLayoutMdQ + self.num_mma_warp_groups = (self.num_threads // 128) - 1 + self.Mma_dKV_is_RS = AtomLayoutMSdP == 1 and AtomLayoutNdKV == self.num_mma_warp_groups and SdP_swapAB and not dKV_swapAB + self.V_in_regs = V_in_regs @staticmethod def can_implement( @@ -61,9 +79,9 @@ def can_implement( head_dim_v, tile_m, tile_n, - num_stages, + Q_stage, num_threads, - Q_in_regs=False, + V_in_regs=False, ) -> bool: if dtype not in [cutlass.Float16, cutlass.BFloat16]: return False @@ -115,11 +133,11 @@ def _setup_attributes(self): self.sQ_layout, self.sK_layout, self.sV_layout, self.sdO_layout, self.sPdS_layout = [ sm90_utils.make_smem_layout(self.dtype, LayoutEnum.ROW_MAJOR, shape, stage) for shape, stage in [ - ((self.tile_m, self.tile_hdim), self.num_stages), + ((self.tile_m, self.tile_hdim), self.Q_stage), ((self.tile_n, self.tile_hdim), None), ((self.tile_n, self.tile_hdimv), None), - ((self.tile_m, self.tile_hdimv), self.num_stages), - ((self.tile_m, self.tile_n), self.dS_stage), + ((self.tile_m, self.tile_hdimv), self.dO_stage), + ((self.tile_m, self.tile_n), self.PdS_stage), ] ] self.sdQaccum_layout = cute.make_layout(self.tile_m * self.tile_hdim) @@ -130,16 +148,21 @@ def _setup_attributes(self): def _get_tiled_mma(self): # S = Q @ K.T, dP = dO @ V.T + atom_layout_SdP = (self.AtomLayoutMSdP, self.num_mma_warp_groups // self.AtomLayoutMSdP) + tiler_mn_SdP = (self.tile_m // atom_layout_SdP[0], self.tile_n // atom_layout_SdP[1]) tiled_mma_SdP = sm90_utils_basic.make_trivial_tiled_mma( self.dtype, self.dtype, warpgroup.OperandMajorMode.K, warpgroup.OperandMajorMode.K, Float32, - atom_layout_mnk=(self.tile_m // 64, 2, 1), - tiler_mn=(64, self.tile_n // 2), + atom_layout_mnk=atom_layout_SdP + (1,), + tiler_mn=tiler_mn_SdP, ) # dV = P.T @ dO, dK = dS.T @ Q + atom_layout_dKV = (self.AtomLayoutNdKV, self.num_mma_warp_groups // self.AtomLayoutNdKV) + tiler_mn_dK = (self.tile_n // atom_layout_dKV[0], self.tile_hdim // atom_layout_dKV[1]) + tiler_mn_dV = (self.tile_n // atom_layout_dKV[0], self.tile_hdimv // atom_layout_dKV[1]) tiled_mma_dK, tiled_mma_dV = [ sm90_utils_basic.make_trivial_tiled_mma( self.dtype, @@ -147,20 +170,23 @@ def _get_tiled_mma(self): warpgroup.OperandMajorMode.MN, warpgroup.OperandMajorMode.MN, Float32, - atom_layout_mnk=(self.tile_n // 64, 1, 1), - tiler_mn=(64, tile_hdim), + atom_layout_mnk=atom_layout_dKV + (1,), + tiler_mn=tiler_mn_d, + a_source=warpgroup.OperandSource.RMEM if self.Mma_dKV_is_RS else warpgroup.OperandSource.SMEM, ) - for tile_hdim in (self.tile_hdim, self.tile_hdimv) + for tiler_mn_d in (tiler_mn_dK, tiler_mn_dV) ] # dQ = dS @ K + atom_layout_dQ = (self.AtomLayoutMdQ, self.num_mma_warp_groups // self.AtomLayoutMdQ) + tiler_mn_dQ = (self.tile_m // atom_layout_dQ[0], self.tile_hdim // atom_layout_dQ[1]) tiled_mma_dQ = sm90_utils_basic.make_trivial_tiled_mma( self.dtype, self.dtype, warpgroup.OperandMajorMode.K, warpgroup.OperandMajorMode.MN, Float32, - atom_layout_mnk=(self.tile_m // 64, 2, 1), - tiler_mn=(64, self.tile_hdim // 2), + atom_layout_mnk=atom_layout_dQ + (1,), + tiler_mn=tiler_mn_dQ, ) return tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ @@ -179,18 +205,18 @@ def _get_shared_storage_cls(self): ] cosize_sdS = cute.cosize(self.sPdS_layout) - cosize_sP = cute.cosize(self.sPdS_layout) # Could be zero + cosize_sP = cute.cosize(self.sPdS_layout) if const_expr(not self.Mma_dKV_is_RS) else 0 sLSE_struct = cute.struct.Align[ - cute.struct.MemRange[Float32, cute.round_up(self.tile_m, 64) * self.num_stages], 128 + cute.struct.MemRange[Float32, cute.round_up(self.tile_m, 64) * self.Q_stage], 128 ] sdPsum_struct = cute.struct.Align[ - cute.struct.MemRange[Float32, cute.round_up(self.tile_m, 64) * self.num_stages], 128 + cute.struct.MemRange[Float32, cute.round_up(self.tile_m, 64) * self.dO_stage], 128 ] @cute.struct class SharedStorageQKV: - mbar_ptr_Q: cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] - mbar_ptr_dO: cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] + mbar_ptr_Q: cute.struct.MemRange[cutlass.Int64, self.Q_stage * 2] + mbar_ptr_dO: cute.struct.MemRange[cutlass.Int64, self.dO_stage * 2] sLSE: sLSE_struct sdPsum: sdPsum_struct sQ: sQ_struct @@ -256,9 +282,9 @@ def __call__( tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ = self._get_tiled_mma() self.num_mma_threads = tiled_mma_SdP.size + assert self.num_mma_threads + 128 == self.num_threads self.num_threads_per_warp_group = 128 - self.num_mma_warp_groups = self.num_mma_threads // self.num_threads_per_warp_group self.num_producer_threads = 32 self.num_mma_regs = 240 @@ -435,7 +461,7 @@ def kernel( ) pipeline_q = pipeline.PipelineTmaAsyncNoCluster.create( barrier_storage=storage.mbar_ptr_Q.data_ptr(), - num_stages=self.num_stages, + num_stages=self.Q_stage, producer_group=pipeline_producer_group, consumer_group=pipeline_consumer_group, tx_count=self.tma_copy_bytes["Q"] + self.tma_copy_bytes["LSE"], @@ -443,7 +469,7 @@ def kernel( ) pipeline_do = pipeline.PipelineTmaAsyncNoCluster.create( barrier_storage=storage.mbar_ptr_dO.data_ptr(), - num_stages=self.num_stages, + num_stages=self.dO_stage, producer_group=pipeline_producer_group, consumer_group=pipeline_consumer_group, tx_count=self.tma_copy_bytes["dO"] + self.tma_copy_bytes["dPsum"], @@ -454,18 +480,20 @@ def kernel( sdO = storage.sdO.get_tensor(sdO_layout.outer, swizzle=sdO_layout.inner) sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) - sP = storage.sP.get_tensor(sPdS_layout.outer, swizzle=sPdS_layout.inner) + sP = None + if const_expr(not self.Mma_dKV_is_RS): + sP = storage.sP.get_tensor(sPdS_layout.outer, swizzle=sPdS_layout.inner) sdS = storage.sdS.get_tensor(sPdS_layout.outer, swizzle=sPdS_layout.inner) sLSE = storage.sLSE.get_tensor( cute.make_layout( - (self.tile_m, self.num_stages), + (self.tile_m, self.Q_stage), stride=(1, cute.round_up(self.tile_m, 64)), ) ) sdPsum = storage.sdPsum.get_tensor( cute.make_layout( - (self.tile_m, self.num_stages), + (self.tile_m, self.dO_stage), stride=(1, cute.round_up(self.tile_m, 64)), ) ) @@ -587,7 +615,7 @@ def load( if warp_idx_in_wg == 0: producer_state = pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Producer, self.num_stages + cutlass.pipeline.PipelineUserType.Producer, self.Q_stage ) tile_scheduler = TileSchedulerCls() @@ -708,9 +736,11 @@ def mma( tdPrdO = tiled_mma_SdP.make_fragment_A(wg_mma_SdP.partition_A(sdO)) tdPrV = tiled_mma_SdP.make_fragment_B(wg_mma_SdP.partition_B(sV)) # dV += P.T @ dO - sPt = utils.transpose_view(sP) + sPt = utils.transpose_view(sP) if sP is not None else None sdOt = utils.transpose_view(sdO) - tdVrPt = tiled_mma_dV.make_fragment_A(wg_mma_dV.partition_A(sPt)) + tdVrPt = None + if const_expr(sP is not None): + tdVrPt = tiled_mma_dV.make_fragment_A(wg_mma_dV.partition_A(sPt)) tdVrdOt = tiled_mma_dV.make_fragment_B(wg_mma_dV.partition_B(sdOt)) # dK += dS.T @ Q sdSt = utils.transpose_view(sdS) @@ -727,20 +757,22 @@ def mma( smem_thr_copy_PdS = cute.make_tiled_copy_C(smem_copy_atom_PdS, tiled_mma_SdP).get_slice( tidx ) - tPsP = smem_thr_copy_PdS.partition_D(sP) + tPsP = None + if const_expr(sP is not None): + tPsP = smem_thr_copy_PdS.partition_D(sP) tdSsdS = smem_thr_copy_PdS.partition_D(sdS) sLSE_mma = cute.make_tensor( sLSE.iterator, cute.make_layout( - (self.tile_m, self.tile_n, self.num_stages), + (self.tile_m, self.tile_n, self.Q_stage), stride=(1, 0, cute.round_up(self.tile_m, 64)), ), ) sdPsum_mma = cute.make_tensor( sdPsum.iterator, cute.make_layout( - (self.tile_m, self.tile_n, self.num_stages), + (self.tile_m, self.tile_n, self.dO_stage), stride=(1, 0, cute.round_up(self.tile_m, 64)), ), ) @@ -795,7 +827,7 @@ def mma( ) consumer_state = pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Consumer, self.num_stages + cutlass.pipeline.PipelineUserType.Consumer, self.Q_stage ) tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() @@ -888,11 +920,11 @@ def mma_one_m_block( tLSErdPsum = cute.make_fragment_like(tLSEsdPsum[None, 0]) cute.autovec_copy(tLSEsdPsum[None, smem_idx], tLSErdPsum) - PdS_smem_idx = smem_idx if const_expr(self.dS_stage > 1) else 0 + PdS_smem_idx = smem_idx if const_expr(self.PdS_stage > 1) else 0 # R2S for P tPrP = smem_thr_copy_PdS.retile(tdVrP) # sync to make sure P has already been used in the previous iteration before writing new vals - if const_expr(self.dS_stage == 1): + if const_expr(self.PdS_stage == 1): cute.arch.barrier( barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads ) @@ -930,7 +962,7 @@ def mma_one_m_block( tdSrdS = smem_thr_copy_PdS.retile(tdKrdS) cute.copy(smem_thr_copy_PdS, tdSrdS, tdSsdS[None, None, None, PdS_smem_idx]) - # (4) [GEMM 3] dV += P.T @ dO + # (5) [GEMM 3] dV += P.T @ dO mma_pdo_fn(A_idx=PdS_smem_idx, B_idx=smem_idx, zero_init=not dKV_should_accumulate, wg_wait=-1) # smem fence to make sure sdS is written before it's read by WGMMA diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index a2a5a44a0fb..a41bfa0fe3c 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -37,7 +37,6 @@ from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80 from flash_attn.cute.flash_bwd_sm90 import FlashAttentionBackwardSm90 from flash_attn.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess -from flash_attn.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess_sm90 from flash_attn.cute.flash_fwd_combine import FlashAttentionForwardCombine @@ -449,6 +448,13 @@ def _flash_attn_bwd( ) m_block_size = 64 n_block_size = 128 + num_stages_Q = 2 + num_stages_dO = 2 + num_stages_PdS = 2 + AtomLayoutMSdP = 1 + AtomLayoutNdKV = 2 + AtomLayoutMdQ = 1 + num_threads = 384 if compile_key not in _flash_attn_bwd.compile_cache: fa_bwd_sm80 = FlashAttentionBackwardSm80( dtype, @@ -475,19 +481,20 @@ def _flash_attn_bwd( head_dim, head_dim_v, qhead_per_kvhead, + causal, m_block_size, n_block_size, - # num_stages_Q, - # num_stages_dO, - # num_threads, - # causal, - # SdP_swapAB, - # dKV_swapAB, - # dQ_swapAB, - # AtomLayoutMSdP, - # AtomLayoutNdKV, - # AtomLayoutMdQ, - # V_in_regs=V_in_regs, + num_stages_Q, + num_stages_dO, + num_stages_PdS, + SdP_swapAB, + dKV_swapAB, + dQ_swapAB, + AtomLayoutMSdP, + AtomLayoutNdKV, + AtomLayoutMdQ, + num_threads, + V_in_regs=V_in_regs, ) # TODO: check @can_implement _flash_attn_bwd.compile_cache[compile_key] = cute.compile( @@ -517,12 +524,13 @@ def _flash_attn_bwd( seqused_k_tensor, ) + num_threads -= 128 # Postprocess kernel: convert dq_accum from float32 to dq in bf16/fp16 compile_key_post = (dtype, head_dim, m_block_size, num_threads, AtomLayoutMdQ, dQ_swapAB) if compile_key_post not in _flash_attn_bwd.compile_cache_post: - # fa_bwd_post = FlashAttentionBackwardPostprocess( - fa_bwd_post = FlashAttentionBackwardPostprocess_sm90( - dtype, head_dim, m_block_size, num_threads, AtomLayoutMdQ, dQ_swapAB + arch = 90 + fa_bwd_post = FlashAttentionBackwardPostprocess( + dtype, head_dim, arch, m_block_size, num_threads, AtomLayoutMdQ, dQ_swapAB ) # TODO: check @can_implement _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile( From d2c8a6caae73a594dd385d02450a5d81045c0968 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 13 Oct 2025 18:10:15 -0400 Subject: [PATCH 132/258] [Cute] Run ruff format on bwd files --- .pre-commit-config.yaml | 9 +++ flash_attn/cute/flash_bwd_postprocess.py | 55 ++++++++++----- flash_attn/cute/flash_bwd_preprocess.py | 89 ++++++++++++++---------- flash_attn/cute/flash_bwd_sm90.py | 31 +++++++-- flash_attn/cute/softmax.py | 64 ++++++++++------- 5 files changed, 161 insertions(+), 87 deletions(-) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000000..5c63513faf8 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,9 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.11.13 + hooks: + - id: ruff-check + args: [--fix, --exit-non-zero-on-fix] + files: ^flash_attn/cute/(flash_bwd_sm90|flash_bwd_preprocess|flash_bwd_postprocess|softmax)\.py$ + - id: ruff-format + files: ^flash_attn/cute/(flash_bwd_sm90|flash_bwd_preprocess|flash_bwd_postprocess|softmax)\.py$ diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index ef1e027a62d..9ca76e3c9ba 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -10,7 +10,7 @@ import cutlass.cute as cute import cutlass.utils.hopper_helpers as sm90_utils_basic from cutlass.cute.nvgpu import cpasync, warp, warpgroup -from cutlass import Int32, Float32, const_expr +from cutlass import Float32, const_expr from cutlass.utils import LayoutEnum from flash_attn.cute import utils @@ -22,7 +22,7 @@ ParamsBase, SingleTileScheduler, SingleTileVarlenScheduler, - TileSchedulerArguments + TileSchedulerArguments, ) @@ -123,9 +123,13 @@ def _setup_attributes(self): cute.make_layout(async_copy_elems_accum), ) num_s2r_copy_elems = 1 if const_expr(self.arch == 80) else 4 - self.s2r_tiled_copy_dQaccum = copy_utils.tiled_copy_1d(Float32, self.num_threads, num_s2r_copy_elems) + self.s2r_tiled_copy_dQaccum = copy_utils.tiled_copy_1d( + Float32, self.num_threads, num_s2r_copy_elems + ) - self.gmem_tiled_copy_dQ = copy_utils.tiled_copy_2d(self.dtype, self.tile_hdim, self.num_threads) + self.gmem_tiled_copy_dQ = copy_utils.tiled_copy_2d( + self.dtype, self.tile_hdim, self.num_threads + ) # /////////////////////////////////////////////////////////////////////////////// # Shared memory layout: dQaccum / dQ # /////////////////////////////////////////////////////////////////////////////// @@ -136,9 +140,13 @@ def _setup_attributes(self): mma_shape_n = self.tiled_mma.get_tile_size(1) if const_expr(self.arch == 80): sdQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, mma_shape_n) - self.sdQ_layout = cute.tile_to_shape(sdQ_layout_atom, (self.tile_m, self.tile_hdim), (0, 1)) + self.sdQ_layout = cute.tile_to_shape( + sdQ_layout_atom, (self.tile_m, self.tile_hdim), (0, 1) + ) else: - self.sdQ_layout = sm90_utils.make_smem_layout(self.dtype, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_hdim)) + self.sdQ_layout = sm90_utils.make_smem_layout( + self.dtype, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_hdim) + ) @cute.jit def __call__( @@ -151,15 +159,21 @@ def __call__( stream: cuda.CUstream, ): # Get the data type and check if it is fp16 or bf16 - if const_expr(not mdQ.element_type in [cutlass.Float16, cutlass.BFloat16]): + if const_expr(mdQ.element_type not in [cutlass.Float16, cutlass.BFloat16]): raise TypeError("Only Float16 or BFloat16 is supported") if const_expr(mdQaccum is not None): - if const_expr(not mdQaccum.element_type in [cutlass.Float32]): + if const_expr(mdQaccum.element_type not in [cutlass.Float32]): raise TypeError("dQaccum tensor must be Float32") # Assume all strides are divisible by 128 bits except the last stride - new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1]) - mdQaccum, mdQ = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) for t in (mdQaccum, mdQ)] + new_stride = lambda t: ( + *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + t.stride[-1], + ) + mdQaccum, mdQ = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + for t in (mdQaccum, mdQ) + ] self.tiled_mma = self._get_tiled_mma() self._setup_attributes() @@ -178,7 +192,6 @@ def __call__( num_head = mdQ.shape[2] num_batch = mdQ.shape[0] - tile_sched_args = TileSchedulerArguments( num_block=cute.ceil_div(mdQ.shape[1], self.tile_m), num_head=num_head, @@ -195,7 +208,6 @@ def __call__( tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) grid_dim = TileScheduler.get_grid_shape(tile_sched_params) - # grid_dim: (m_block, num_head, batch_size) self.kernel( mdQaccum, @@ -250,7 +262,15 @@ def kernel( # Get the appropriate tiles for this thread block. # /////////////////////////////////////////////////////////////////////////////// - seqlen = SeqlenInfoQK(batch_size, mdQ.shape[1], 0, mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=None, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=None) + seqlen = SeqlenInfoQK( + batch_size, + mdQ.shape[1], + 0, + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=None, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=None, + ) if const_expr(not seqlen.has_cu_seqlens_q): mdQ_cur = mdQ[batch_size, None, num_head, None] mdQaccum_cur = mdQaccum[batch_size, num_head, None] @@ -258,7 +278,9 @@ def kernel( else: padded_offset_q = seqlen.offset_q + batch_size * self.tile_m mdQ_cur = cute.domain_offset((seqlen.offset_q, 0), mdQ[None, num_head, None]) - mdQaccum_cur = cute.domain_offset((padded_offset_q * self.tile_hdim,), mdQaccum[num_head, None]) + mdQaccum_cur = cute.domain_offset( + (padded_offset_q * self.tile_hdim,), mdQaccum[num_head, None] + ) head_dim = mdQ.shape[2] # HACK: Compiler doesn't seem to recognize that padding @@ -271,10 +293,7 @@ def kernel( mem_space=mdQaccum_cur.iterator.memspace, assumed_align=mdQaccum.iterator.alignment, ) - mdQaccum_cur = cute.make_tensor( - mdQaccum_cur_ptr, - mdQaccum_cur.layout - ) + mdQaccum_cur = cute.make_tensor(mdQaccum_cur_ptr, mdQaccum_cur.layout) dQaccum_shape = (self.tile_m * self.tile_hdim,) gdQaccum = cute.local_tile(mdQaccum_cur, dQaccum_shape, (m_block,)) diff --git a/flash_attn/cute/flash_bwd_preprocess.py b/flash_attn/cute/flash_bwd_preprocess.py index ee6535be527..1a900f83a67 100644 --- a/flash_attn/cute/flash_bwd_preprocess.py +++ b/flash_attn/cute/flash_bwd_preprocess.py @@ -14,7 +14,12 @@ from flash_attn.cute import utils from flash_attn.cute import copy_utils from flash_attn.cute.seqlen_info import SeqlenInfoQK -from flash_attn.cute.tile_scheduler import ParamsBase, SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments +from flash_attn.cute.tile_scheduler import ( + ParamsBase, + SingleTileScheduler, + SingleTileVarlenScheduler, + TileSchedulerArguments, +) class FlashAttentionBackwardPreprocess: @@ -86,13 +91,17 @@ def _setup_attributes(self): else (32 if self.head_dim_padded % 32 == 0 else 16) ) ) - self.gmem_tiled_copy_O = copy_utils.tiled_copy_2d(self.dtype, gmem_k_block_size, self.num_threads) + self.gmem_tiled_copy_O = copy_utils.tiled_copy_2d( + self.dtype, gmem_k_block_size, self.num_threads + ) universal_copy_bits = 128 num_copy_elems_dQaccum = universal_copy_bits // Float32.width assert ( self.m_block_size * self.head_dim_padded // num_copy_elems_dQaccum ) % self.num_threads == 0 - self.gmem_tiled_copy_dQaccum = copy_utils.tiled_copy_1d(Float32, self.num_threads, num_copy_elems_dQaccum) + self.gmem_tiled_copy_dQaccum = copy_utils.tiled_copy_1d( + Float32, self.num_threads, num_copy_elems_dQaccum + ) @cute.jit def __call__( @@ -110,23 +119,31 @@ def __call__( # Get the data type and check if it is fp16 or bf16 if cutlass.const_expr(not (mO.element_type == mdO.element_type)): raise TypeError("All tensors must have the same data type") - if cutlass.const_expr(not mO.element_type in [cutlass.Float16, cutlass.BFloat16]): + if cutlass.const_expr(mO.element_type not in [cutlass.Float16, cutlass.BFloat16]): raise TypeError("Only Float16 or BFloat16 is supported") - if cutlass.const_expr(not mdPsum.element_type in [Float32]): + if cutlass.const_expr(mdPsum.element_type not in [Float32]): raise TypeError("dPsum tensor must be Float32") if cutlass.const_expr(mdQaccum is not None): - if cutlass.const_expr(not mdQaccum.element_type in [Float32]): + if cutlass.const_expr(mdQaccum.element_type not in [Float32]): raise TypeError("dQaccum tensor must be Float32") if cutlass.const_expr(mLSE is not None): assert mLSElog2 is not None, "If mLSE is provided, mLSElog2 must also be provided" - if cutlass.const_expr(not mLSE.element_type in [Float32]): + if cutlass.const_expr(mLSE.element_type not in [Float32]): raise TypeError("LSE tensor must be Float32") - if cutlass.const_expr(not mLSElog2.element_type in [Float32]): + if cutlass.const_expr(mLSElog2.element_type not in [Float32]): raise TypeError("LSElog2 tensor must be Float32") # Assume all strides are divisible by 128 bits except the last stride - new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1]) - mO, mdO, mdQaccum = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) if t is not None else None for t in (mO, mdO, mdQaccum)] + new_stride = lambda t: ( + *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + t.stride[-1], + ) + mO, mdO, mdQaccum = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + if t is not None + else None + for t in (mO, mdO, mdQaccum) + ] self._setup_attributes() @@ -139,7 +156,6 @@ def __call__( num_head = mO.shape[2] num_batch = mO.shape[0] - tile_sched_args = TileSchedulerArguments( num_block=cute.ceil_div(mO.shape[1], self.m_block_size), num_head=num_head, @@ -202,7 +218,15 @@ def kernel( # /////////////////////////////////////////////////////////////////////////////// # Get the appropriate tiles for this thread block. # /////////////////////////////////////////////////////////////////////////////// - seqlen = SeqlenInfoQK(batch_size, mO.shape[1], 0, mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=None, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=None) + seqlen = SeqlenInfoQK( + batch_size, + mO.shape[1], + 0, + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=None, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=None, + ) if cutlass.const_expr(not seqlen.has_cu_seqlens_q): mO_cur = mO[batch_size, None, num_head, None] @@ -216,7 +240,7 @@ def kernel( padded_offset_q = seqlen.offset_q + batch_size * self.m_block_size mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[num_head, None]) headdim_v = mO.shape[2] - + blkOdO_shape = (self.m_block_size, self.head_dim_padded) # (m_block_size, head_dim) gO = cute.local_tile(mO_cur, blkOdO_shape, (m_block, 0)) @@ -238,7 +262,7 @@ def kernel( tOpO = utils.predicate_k(tOcO, limit=headdim_v) tOpdO = utils.predicate_k(tOcO, limit=headdim_v) - seqlen_q = seqlen.seqlen_q + seqlen_q = seqlen.seqlen_q seqlen_q_rounded = cute.round_up(seqlen_q, self.m_block_size) if cutlass.const_expr(mLSE is not None): @@ -247,9 +271,7 @@ def kernel( else: mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[num_head, None]) - gLSE = cute.local_tile( - mLSE_cur, (self.m_block_size,), (m_block,) - ) + gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block,)) lse = Float32.inf if tidx < seqlen_q - m_block * self.m_block_size: lse = gLSE[tidx] @@ -267,13 +289,17 @@ def kernel( gmem_thr_copy_O, tOgO[None, m, None], tOrO[None, m, None], - pred=tOpO[None, m, None] if cutlass.const_expr(self.check_hdim_oob) else None, + pred=tOpO[None, m, None] + if cutlass.const_expr(self.check_hdim_oob) + else None, ) cute.copy( gmem_thr_copy_O, tOgdO[None, m, None], tOrdO[None, m, None], - pred=tOpdO[None, m, None] if cutlass.const_expr(self.check_hdim_oob) else None, + pred=tOpdO[None, m, None] + if cutlass.const_expr(self.check_hdim_oob) + else None, ) # Sum across the "k" dimension dpsum = (tOrO.load().to(Float32) * tOrdO.load().to(Float32)).reduce( @@ -286,9 +312,7 @@ def kernel( dP_sum.store(dpsum) # Write dPsum from rmem -> gmem - gdPsum = cute.local_tile( - mdPsum_cur, (self.m_block_size,), (m_block,) - ) + gdPsum = cute.local_tile(mdPsum_cur, (self.m_block_size,), (m_block,)) # Only the thread corresponding to column 0 writes out the dPsum to gmem if tOcO[0, 0, 0][1] == 0: for m in cutlass.range(cute.size(dP_sum), unroll_full=True): @@ -301,10 +325,12 @@ def kernel( mdQaccum_cur = mdQaccum[batch_size, num_head, None] else: padded_offset_q = seqlen.offset_q + batch_size * self.m_block_size - mdQaccum_cur = cute.domain_offset((padded_offset_q * self.head_dim_padded,), mdQaccum[num_head, None]) + mdQaccum_cur = cute.domain_offset( + (padded_offset_q * self.head_dim_padded,), mdQaccum[num_head, None] + ) - # HACK: Compiler doesn't seem to recognize that padding - # by padded_offset_q * self.head_dim_padded keeps alignment + # HACK: Compiler doesn't seem to recognize that padding + # by padded_offset_q * self.head_dim_padded keeps alignment # since statically divisible by 4 mdQaccum_cur_ptr = cute.make_ptr( @@ -313,15 +339,10 @@ def kernel( mem_space=mdQaccum_cur.iterator.memspace, assumed_align=mdQaccum.iterator.alignment, ) - mdQaccum_cur = cute.make_tensor( - mdQaccum_cur_ptr, - mdQaccum_cur.layout - ) + mdQaccum_cur = cute.make_tensor(mdQaccum_cur_ptr, mdQaccum_cur.layout) blkdQaccum_shape = (self.m_block_size * self.head_dim_padded,) - gdQaccum = cute.local_tile( - mdQaccum_cur, blkdQaccum_shape, (m_block,) - ) + gdQaccum = cute.local_tile(mdQaccum_cur, blkdQaccum_shape, (m_block,)) gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx) tQgQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum) zero = cute.make_fragment_like(tQgQaccum) @@ -335,9 +356,7 @@ def kernel( padded_offset_q = seqlen.offset_q + batch_size * self.m_block_size mLSElog2_cur = cute.domain_offset((padded_offset_q,), mLSElog2[num_head, None]) - gLSElog2 = cute.local_tile( - mLSElog2_cur, (self.m_block_size,), (m_block,) - ) + gLSElog2 = cute.local_tile(mLSElog2_cur, (self.m_block_size,), (m_block,)) LOG2_E = math.log2(math.e) if tidx < seqlen_q_rounded - m_block * self.m_block_size: gLSElog2[tidx] = lse * LOG2_E if lse != -Float32.inf else 0.0 diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index 0284b96905f..6021ffa8584 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -69,7 +69,12 @@ def __init__( self.AtomLayoutNdKV = AtomLayoutNdKV self.AtomLayoutMdQ = AtomLayoutMdQ self.num_mma_warp_groups = (self.num_threads // 128) - 1 - self.Mma_dKV_is_RS = AtomLayoutMSdP == 1 and AtomLayoutNdKV == self.num_mma_warp_groups and SdP_swapAB and not dKV_swapAB + self.Mma_dKV_is_RS = ( + AtomLayoutMSdP == 1 + and AtomLayoutNdKV == self.num_mma_warp_groups + and SdP_swapAB + and not dKV_swapAB + ) self.V_in_regs = V_in_regs @staticmethod @@ -172,7 +177,9 @@ def _get_tiled_mma(self): Float32, atom_layout_mnk=atom_layout_dKV + (1,), tiler_mn=tiler_mn_d, - a_source=warpgroup.OperandSource.RMEM if self.Mma_dKV_is_RS else warpgroup.OperandSource.SMEM, + a_source=warpgroup.OperandSource.RMEM + if self.Mma_dKV_is_RS + else warpgroup.OperandSource.SMEM, ) for tiler_mn_d in (tiler_mn_dK, tiler_mn_dV) ] @@ -666,7 +673,9 @@ def load( # cp.async.bulk is using ptx, so we need to elect one thread to do it with cute.arch.elect_one(): load_LSE(m_block, producer_state=producer_state) - pipeline_do.producer_acquire(producer_state, extra_tx_count=self.tma_copy_bytes["V"]) + pipeline_do.producer_acquire( + producer_state, extra_tx_count=self.tma_copy_bytes["V"] + ) load_V(tma_bar_ptr=pipeline_do.producer_get_barrier(producer_state)) load_dO(m_block, producer_state=producer_state) with cute.arch.elect_one(): @@ -963,7 +972,9 @@ def mma_one_m_block( cute.copy(smem_thr_copy_PdS, tdSrdS, tdSsdS[None, None, None, PdS_smem_idx]) # (5) [GEMM 3] dV += P.T @ dO - mma_pdo_fn(A_idx=PdS_smem_idx, B_idx=smem_idx, zero_init=not dKV_should_accumulate, wg_wait=-1) + mma_pdo_fn( + A_idx=PdS_smem_idx, B_idx=smem_idx, zero_init=not dKV_should_accumulate, wg_wait=-1 + ) # smem fence to make sure sdS is written before it's read by WGMMA cute.arch.fence_proxy( @@ -978,7 +989,9 @@ def mma_one_m_block( pipeline_do.consumer_release(smem_pipe_read) # release dO as dV mma is done # (7) [GEMM 5] dK += dS.T @ Q - mma_dsq_fn(A_idx=PdS_smem_idx, B_idx=smem_idx, zero_init=not dKV_should_accumulate, wg_wait=1) + mma_dsq_fn( + A_idx=PdS_smem_idx, B_idx=smem_idx, zero_init=not dKV_should_accumulate, wg_wait=1 + ) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dQ) cute.arch.barrier( @@ -1055,7 +1068,9 @@ def epilogue_dKV( taccdVsdV = smem_thr_copy_dV.partition_D(sV) # reuse sV SMEM cute.copy(smem_copy_atom_dKV, taccdVrdV, taccdVsdV) # ensure smem writes are visible to TMA - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) cute.arch.barrier( barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads ) @@ -1065,7 +1080,9 @@ def epilogue_dKV( taccdKsdK = smem_thr_copy_dK.partition_D(sK) # reuse sK SMEM cute.copy(smem_copy_atom_dKV, taccdKrdK, taccdKsdK) # ensure smem writes are visible to TMA - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) cute.arch.barrier( barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads ) diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index 59e5add7abe..398f9e40c55 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -27,7 +27,7 @@ def create( scale_log2: Float32, num_rows: cutlass.Constexpr[int], arch: cutlass.Constexpr[int] = 80, - softmax_scale: Float32 | None = None + softmax_scale: Float32 | None = None, ): row_max = cute.make_fragment(num_rows, Float32) row_sum = cute.make_fragment(num_rows, Float32) @@ -64,30 +64,30 @@ def online_softmax( # Change acc_S to M,N layout view. acc_S_mn = utils.make_acc_tensor_mn_view(acc_S) row_scale = cute.make_fragment_like(self.row_max, Float32) - + row_max = self.row_max row_sum = self.row_sum scale_log2 = self.scale_log2 arch = self.arch - + # Each iteration processes one row of acc_S for r in cutlass.range(cute.size(row_max), unroll_full=True): acc_S_row = acc_S_mn[r, None].load() # (n_block_size) - + row_max_cur = utils.fmax_reduce( acc_S_row, init_val=row_max[r] if cutlass.const_expr(not is_first) else None, - arch=arch + arch=arch, ) - + row_max_cur = utils.warp_reduce(row_max_cur, cute.arch.fmax, width=4) if cutlass.const_expr(check_inf): row_max_cur = 0.0 if row_max_cur == -Float32.inf else row_max_cur - + if cutlass.const_expr(is_first): row_max_cur_scaled = row_max_cur * scale_log2 acc_S_row_exp = utils.exp2f(acc_S_row * scale_log2 - row_max_cur_scaled) - + acc_S_row_sum = utils.fadd_reduce(acc_S_row_exp, init_val=None, arch=arch) row_scale[r] = 1.0 else: @@ -96,42 +96,40 @@ def online_softmax( acc_S_row_exp = utils.exp2f(acc_S_row * scale_log2 - row_max_cur_scaled) # row_scale[r] = utils.exp2f(row_max_prev * self.scale_log2 - row_max_cur_scaled) row_scale[r] = utils.exp2f((row_max_prev - row_max_cur) * scale_log2) - + acc_S_row_sum = utils.fadd_reduce( - acc_S_row_exp, - init_val=row_sum[r] * row_scale[r], - arch=arch + acc_S_row_exp, init_val=row_sum[r] * row_scale[r], arch=arch ) - + row_max[r] = row_max_cur row_sum[r] = acc_S_row_sum acc_S_mn[r, None].store(acc_S_row_exp) - + return row_scale @cute.jit - def finalize(self, final_scale: Float32 = 1.0, sink_val: Float32 | cute.Tensor | None = None) -> cute.Tensor: + def finalize( + self, final_scale: Float32 = 1.0, sink_val: Float32 | cute.Tensor | None = None + ) -> cute.Tensor: """Finalize the online softmax by computing the scale and logsumexp.""" if cutlass.const_expr(sink_val is not None and isinstance(sink_val, cute.Tensor)): assert cute.size(sink_val) == cute.size(self.row_sum) row_sum = self.row_sum row_max = self.row_max scale_log2 = self.scale_log2 - + # quad reduction for row_sum as we didn't do it during each iteration of online softmax row_sum.store(utils.warp_reduce(row_sum.load(), operator.add, width=4)) row_scale = cute.make_fragment_like(row_max, Float32) - + for r in cutlass.range(cute.size(row_sum), unroll_full=True): if cutlass.const_expr(sink_val is not None): sink_val_cur = sink_val if not isinstance(sink_val, cute.Tensor) else sink_val[r] LOG2_E = math.log2(math.e) row_sum[r] += utils.exp2f(sink_val_cur * LOG2_E - row_max[r] * scale_log2) - + # if row_sum is zero or nan, set acc_O_mn_row to 1.0 - acc_O_mn_row_is_zero_or_nan = ( - row_sum[r] == 0.0 or row_sum[r] != row_sum[r] - ) + acc_O_mn_row_is_zero_or_nan = row_sum[r] == 0.0 or row_sum[r] != row_sum[r] row_scale[r] = ( cute.arch.rcp_approx(row_sum[r] if not acc_O_mn_row_is_zero_or_nan else 1.0) ) * final_scale @@ -172,7 +170,15 @@ def create( arch = 100 row_max = cute.make_fragment(num_rows, Float32) row_sum = cute.make_fragment(num_rows, Float32) - return SoftmaxSm100(scale_log2, num_rows, row_max, row_sum, arch, softmax_scale, rescale_threshold=rescale_threshold) + return SoftmaxSm100( + scale_log2, + num_rows, + row_max, + row_sum, + arch, + softmax_scale, + rescale_threshold=rescale_threshold, + ) @cute.jit def update_row_max(self, acc_S_row: cute.TensorSSA, is_first: int) -> Tuple[Float32, Float32]: @@ -245,12 +251,16 @@ def apply_exp2_convert( acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j]) acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j]) else: - if cutlass.const_expr(k % e2e_freq < e2e_freq - e2e_res or j >= frg_cnt - e2e_frg_limit): + if cutlass.const_expr( + k % e2e_freq < e2e_freq - e2e_res or j >= frg_cnt - e2e_frg_limit + ): acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j]) acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j]) else: # acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.e2e_asm2(acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]) - acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.ex2_emulation_2(acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]) + acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.ex2_emulation_2( + acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] + ) acc_S_row_converted_frg[None, j].store( acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type) ) @@ -314,11 +324,11 @@ def apply_score_mod_inner( batch_idx, head_idx, softmax_scale, - vec_size:cutlass.Constexpr, + vec_size: cutlass.Constexpr, qk_acc_dtype: cutlass.Constexpr, buffers, fastdiv_mods, - constant_q_idx:cutlass.Constexpr, + constant_q_idx: cutlass.Constexpr, ): """Shared implementation for applying score modification. @@ -385,7 +395,7 @@ def apply_score_mod_inner( head_idx_ssa, q_idx=q_idx_ssa, kv_idx=kv_idx_ssa, - buffers=buffer_args + buffers=buffer_args, ) # Write back modified scores From ee3a533becf05e5d761d6c954518e89b7b78cefe Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 13 Oct 2025 18:28:48 -0400 Subject: [PATCH 133/258] [CI] Add pre-commit GH action --- .github/workflows/pre-commit.yaml | 33 +++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 .github/workflows/pre-commit.yaml diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml new file mode 100644 index 00000000000..1613bb365bd --- /dev/null +++ b/.github/workflows/pre-commit.yaml @@ -0,0 +1,33 @@ +name: Lint + +on: + pull_request: + paths: + - 'flash_attn/cute/flash_bwd_sm90.py' + - 'flash_attn/cute/flash_bwd_preprocess.py' + - 'flash_attn/cute/flash_bwd_postprocess.py' + - 'flash_attn/cute/softmax.py' + - '.pre-commit-config.yaml' + push: + branches: + - main + paths: + - 'flash_attn/cute/flash_bwd_sm90.py' + - 'flash_attn/cute/flash_bwd_preprocess.py' + - 'flash_attn/cute/flash_bwd_postprocess.py' + - 'flash_attn/cute/softmax.py' + - '.pre-commit-config.yaml' + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Run pre-commit + uses: pre-commit/action@v3.0.1 From 93e433b6f1977c45a5ac0e7c4186e3a421399f46 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 13 Oct 2025 21:20:26 -0400 Subject: [PATCH 134/258] [Cute,Bwd,Sm90] Try dO_stage=1, PdS_stage=1 --- flash_attn/cute/flash_bwd_sm90.py | 169 +++++++++++++++++------------- flash_attn/cute/interface.py | 4 +- 2 files changed, 99 insertions(+), 74 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index 6021ffa8584..d5db25372a3 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -466,7 +466,7 @@ def kernel( pipeline_consumer_group = cutlass.pipeline.CooperativeGroup( cutlass.pipeline.Agent.Thread, self.num_mma_threads // self.num_threads_per_warp_group ) - pipeline_q = pipeline.PipelineTmaAsyncNoCluster.create( + pipeline_Q = pipeline.PipelineTmaAsyncNoCluster.create( barrier_storage=storage.mbar_ptr_Q.data_ptr(), num_stages=self.Q_stage, producer_group=pipeline_producer_group, @@ -474,7 +474,7 @@ def kernel( tx_count=self.tma_copy_bytes["Q"] + self.tma_copy_bytes["LSE"], init_wait=False, ) - pipeline_do = pipeline.PipelineTmaAsyncNoCluster.create( + pipeline_dO = pipeline.PipelineTmaAsyncNoCluster.create( barrier_storage=storage.mbar_ptr_dO.data_ptr(), num_stages=self.dO_stage, producer_group=pipeline_producer_group, @@ -547,8 +547,8 @@ def kernel( tma_atom_K, tma_atom_V, tma_atom_dO, - pipeline_q, - pipeline_do, + pipeline_Q, + pipeline_dO, block_info, SeqlenInfoCls, TileSchedulerCls, @@ -580,8 +580,8 @@ def kernel( sLSE, sdPsum, sdQaccum, - pipeline_q, - pipeline_do, + pipeline_Q, + pipeline_dO, tidx, tma_atom_dK, tma_atom_dV, @@ -612,8 +612,8 @@ def load( tma_atom_K: cute.CopyAtom, tma_atom_V: cute.CopyAtom, tma_atom_dO: cute.CopyAtom, - pipeline_q: cutlass.pipeline.PipelineAsync, - pipeline_do: cutlass.pipeline.PipelineAsync, + pipeline_Q: cutlass.pipeline.PipelineAsync, + pipeline_dO: cutlass.pipeline.PipelineAsync, block_info: BlockInfo, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, @@ -621,13 +621,16 @@ def load( warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 if warp_idx_in_wg == 0: - producer_state = pipeline.make_pipeline_state( + producer_state_Q = pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Producer, self.Q_stage ) - + producer_state_dO = producer_state_Q + if const_expr(self.dO_stage != self.Q_stage): + producer_state_dO = pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.dO_stage + ) tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() - while work_tile.is_valid_tile: n_block, head_idx, batch_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) @@ -654,45 +657,51 @@ def load( load_Q, _, _ = copy_utils.tma_get_copy_fn( tma_atom_Q, 0, cute.make_layout(1), gQ, sQ ) - load_Q = copy_utils.tma_producer_copy_fn(load_Q, pipeline_q) + load_Q = copy_utils.tma_producer_copy_fn(load_Q, pipeline_Q) load_dO, _, _ = copy_utils.tma_get_copy_fn( tma_atom_dO, 0, cute.make_layout(1), gdO, sdO ) - load_dO = copy_utils.tma_producer_copy_fn(load_dO, pipeline_do) + load_dO = copy_utils.tma_producer_copy_fn(load_dO, pipeline_dO) load_LSE = copy_utils.cpasync_bulk_get_copy_fn(gLSE, sLSE) - load_LSE = copy_utils.tma_producer_copy_fn(load_LSE, pipeline_q) + load_LSE = copy_utils.tma_producer_copy_fn(load_LSE, pipeline_Q) load_dPsum = copy_utils.cpasync_bulk_get_copy_fn(gdPsum, sdPsum) - load_dPsum = copy_utils.tma_producer_copy_fn(load_dPsum, pipeline_do) + load_dPsum = copy_utils.tma_producer_copy_fn(load_dPsum, pipeline_dO) m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) # First iteration: load K together w Q & LSE, then V together w dO & dPsum m_block = m_block_min - pipeline_q.producer_acquire(producer_state, extra_tx_count=self.tma_copy_bytes["K"]) - load_K(tma_bar_ptr=pipeline_q.producer_get_barrier(producer_state)) - load_Q(m_block, producer_state=producer_state) + pipeline_Q.producer_acquire( + producer_state_Q, extra_tx_count=self.tma_copy_bytes["K"] + ) + load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q)) + load_Q(m_block, producer_state=producer_state_Q) # cp.async.bulk is using ptx, so we need to elect one thread to do it with cute.arch.elect_one(): - load_LSE(m_block, producer_state=producer_state) - pipeline_do.producer_acquire( - producer_state, extra_tx_count=self.tma_copy_bytes["V"] + load_LSE(m_block, producer_state=producer_state_Q) + pipeline_dO.producer_acquire( + producer_state_dO, extra_tx_count=self.tma_copy_bytes["V"] ) - load_V(tma_bar_ptr=pipeline_do.producer_get_barrier(producer_state)) - load_dO(m_block, producer_state=producer_state) + load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO)) + load_dO(m_block, producer_state=producer_state_dO) with cute.arch.elect_one(): - load_dPsum(m_block, producer_state=producer_state) - producer_state.advance() + load_dPsum(m_block, producer_state=producer_state_dO) + producer_state_Q.advance() + if const_expr(self.Q_stage != self.dO_stage): + producer_state_dO.advance() # Subsequent iterations: load Q & LSE, then dO & dPsum for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1): - pipeline_q.producer_acquire(producer_state) - load_Q(m_block, producer_state=producer_state) + pipeline_Q.producer_acquire(producer_state_Q) + load_Q(m_block, producer_state=producer_state_Q) # cp.async.bulk is using ptx, so we need to elect one thread to do it with cute.arch.elect_one(): - load_LSE(m_block, producer_state=producer_state) - pipeline_do.producer_acquire(producer_state) - load_dO(m_block, producer_state=producer_state) + load_LSE(m_block, producer_state=producer_state_Q) + pipeline_dO.producer_acquire(producer_state_dO) + load_dO(m_block, producer_state=producer_state_dO) with cute.arch.elect_one(): - load_dPsum(m_block, producer_state=producer_state) - producer_state.advance() + load_dPsum(m_block, producer_state=producer_state_dO) + producer_state_Q.advance() + if const_expr(self.dO_stage != self.Q_stage): + producer_state_dO.advance() tile_scheduler.prefetch_next_work() tile_scheduler.advance_to_next_work() @@ -717,8 +726,8 @@ def mma( sLSE: cute.Tensor, sdPsum: cute.Tensor, sdQaccum: cute.Tensor, - pipeline_q: cutlass.pipeline.PipelineAsync, - pipeline_do: cutlass.pipeline.PipelineAsync, + pipeline_Q: cutlass.pipeline.PipelineAsync, + pipeline_dO: cutlass.pipeline.PipelineAsync, tidx: Int32, tma_atom_dK: cute.CopyAtom, tma_atom_dV: cute.CopyAtom, @@ -821,8 +830,8 @@ def mma( mma_pdo_fn=mma_pdo_fn, mma_dsq_fn=mma_dsq_fn, mma_dsk_fn=mma_dsk_fn, - pipeline_q=pipeline_q, - pipeline_do=pipeline_do, + pipeline_Q=pipeline_Q, + pipeline_dO=pipeline_dO, tLSEsLSE=tLSEsLSE, tLSEsdPsum=tLSEsdPsum, tPsP=tPsP, @@ -835,9 +844,14 @@ def mma( # acc_dK=acc_dK, ) - consumer_state = pipeline.make_pipeline_state( + consumer_state_Q = pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.Q_stage ) + consumer_state_dO = consumer_state_Q + if const_expr(self.dO_stage != self.Q_stage): + consumer_state_dO = pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage + ) tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: @@ -847,8 +861,11 @@ def mma( # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("tidx = {}, m_block_min = {}, m_block_max = {}", cute.arch.thread_idx()[0], m_block_min, m_block_max) dKV_should_accumulate = False for m_block in cutlass.range(m_block_min, m_block_max, unroll=1): - consumer_state = mma_one_m_block_all( - m_block, consumer_state, dKV_should_accumulate=dKV_should_accumulate + consumer_state_Q, consumer_state_dO = mma_one_m_block_all( + m_block, + consumer_state_Q, + consumer_state_dO, + dKV_should_accumulate=dKV_should_accumulate, ) dKV_should_accumulate = True @@ -879,15 +896,16 @@ def mma( def mma_one_m_block( self, m_block: Int32, - smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, + smem_pipe_read_Q: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, + smem_pipe_read_dO: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, warp_group_idx: Int32, mma_qk_fn: Callable, mma_dov_fn: Callable, mma_pdo_fn: Callable, mma_dsq_fn: Callable, mma_dsk_fn: Callable, - pipeline_q: cutlass.pipeline.PipelineAsync, - pipeline_do: cutlass.pipeline.PipelineAsync, + pipeline_Q: cutlass.pipeline.PipelineAsync, + pipeline_dO: cutlass.pipeline.PipelineAsync, tLSEsLSE: cute.Tensor, tLSEsdPsum: cute.Tensor, tPsP: Optional[cute.Tensor], @@ -900,16 +918,20 @@ def mma_one_m_block( # acc_dK, dKV_should_accumulate: Boolean = True, ): - smem_idx = smem_pipe_read.index + smem_idx_Q = smem_pipe_read_Q.index + smem_idx_dO = smem_pipe_read_dO.index + smem_idx_PdS = smem_idx_Q if const_expr(self.PdS_stage > 1) else 0 # (1) [GEMM 1] S = Q @ K^T - pipeline_q.consumer_wait(smem_pipe_read, pipeline_q.consumer_try_wait(smem_pipe_read)) - acc_S = mma_qk_fn(A_idx=smem_idx, wg_wait=-1) + pipeline_Q.consumer_wait(smem_pipe_read_Q, pipeline_Q.consumer_try_wait(smem_pipe_read_Q)) + acc_S = mma_qk_fn(A_idx=smem_idx_Q, wg_wait=-1) # S2R for LSE tLSErLSE = cute.make_fragment_like(tLSEsLSE[None, 0]) - cute.autovec_copy(tLSEsLSE[None, smem_idx], tLSErLSE) + cute.autovec_copy(tLSEsLSE[None, smem_idx_Q], tLSErLSE) # (2) [GEMM 2] dP = dO @ V.T - pipeline_do.consumer_wait(smem_pipe_read, pipeline_do.consumer_try_wait(smem_pipe_read)) - acc_dP = mma_dov_fn(A_idx=smem_idx, wg_wait=1) + pipeline_dO.consumer_wait( + smem_pipe_read_dO, pipeline_dO.consumer_try_wait(smem_pipe_read_dO) + ) + acc_dP = mma_dov_fn(A_idx=smem_idx_Q, wg_wait=1) # (3) [Pointwise 1] P = exp(S - LSE) acc_S_mn = utils.make_acc_tensor_mn_view(acc_S) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_S_mn) @@ -927,17 +949,17 @@ def mma_one_m_block( tdVrP.store(tdVrP_acc.load().to(self.dtype)) # S2R for dPsum tLSErdPsum = cute.make_fragment_like(tLSEsdPsum[None, 0]) - cute.autovec_copy(tLSEsdPsum[None, smem_idx], tLSErdPsum) + cute.autovec_copy(tLSEsdPsum[None, smem_idx_dO], tLSErdPsum) - PdS_smem_idx = smem_idx if const_expr(self.PdS_stage > 1) else 0 # R2S for P - tPrP = smem_thr_copy_PdS.retile(tdVrP) - # sync to make sure P has already been used in the previous iteration before writing new vals - if const_expr(self.PdS_stage == 1): - cute.arch.barrier( - barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads - ) - cute.copy(smem_thr_copy_PdS, tPrP, tPsP[None, None, None, PdS_smem_idx]) + if const_expr(not self.Mma_dKV_is_RS): + # sync to ensure P has already been used in the previous iteration before overwriting + if const_expr(self.PdS_stage == 1): + cute.arch.barrier( + barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads + ) + tPrP = smem_thr_copy_PdS.retile(tdVrP) + cute.copy(smem_thr_copy_PdS, tPrP, tPsP[None, None, None, smem_idx_PdS]) # (4) [Pointwise 2] dS = P*(dP-dPsum) warpgroup.wait_group(0) @@ -960,20 +982,21 @@ def mma_one_m_block( # this race condition is not possible. # This sync is to ensure (1) P is written in case of !Mma_dKV_is_RS and # (2) dS is already read by the Mma in the previous iteration in case of Mma_dKV_is_RS. - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta - ) - cute.arch.barrier( - barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads - ) + if const_expr(not self.Mma_dKV_is_RS or (self.PdS_stage == 1 and self.Mma_dKV_is_RS)): + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) + cute.arch.barrier( + barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads + ) # R2S for dS tdSrdS = smem_thr_copy_PdS.retile(tdKrdS) - cute.copy(smem_thr_copy_PdS, tdSrdS, tdSsdS[None, None, None, PdS_smem_idx]) + cute.copy(smem_thr_copy_PdS, tdSrdS, tdSsdS[None, None, None, smem_idx_PdS]) # (5) [GEMM 3] dV += P.T @ dO mma_pdo_fn( - A_idx=PdS_smem_idx, B_idx=smem_idx, zero_init=not dKV_should_accumulate, wg_wait=-1 + A_idx=smem_idx_PdS, B_idx=smem_idx_dO, zero_init=not dKV_should_accumulate, wg_wait=-1 ) # smem fence to make sure sdS is written before it's read by WGMMA @@ -984,13 +1007,13 @@ def mma_one_m_block( barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads ) # (6) [GEMM 4] dQ = dS @ K - acc_dQ = mma_dsk_fn(A_idx=PdS_smem_idx, wg_wait=1) + acc_dQ = mma_dsk_fn(A_idx=smem_idx_PdS, wg_wait=1) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dV) - pipeline_do.consumer_release(smem_pipe_read) # release dO as dV mma is done + pipeline_dO.consumer_release(smem_pipe_read_dO) # release dO as dV mma is done # (7) [GEMM 5] dK += dS.T @ Q mma_dsq_fn( - A_idx=PdS_smem_idx, B_idx=smem_idx, zero_init=not dKV_should_accumulate, wg_wait=1 + A_idx=smem_idx_PdS, B_idx=smem_idx_Q, zero_init=not dKV_should_accumulate, wg_wait=1 ) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dQ) @@ -1010,11 +1033,13 @@ def mma_one_m_block( warpgroup.wait_group(0) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dK) - pipeline_q.consumer_release(smem_pipe_read) - # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("tidx = {}, m_block = {}, after pipeline_q consumer release", cute.arch.thread_idx()[0], m_block) + pipeline_Q.consumer_release(smem_pipe_read_Q) + # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("tidx = {}, m_block = {}, after pipeline_Q consumer release", cute.arch.thread_idx()[0], m_block) - smem_pipe_read.advance() - return smem_pipe_read + smem_pipe_read_Q.advance() + if const_expr(self.Q_stage != self.dO_stage): + smem_pipe_read_dO.advance() + return smem_pipe_read_Q, smem_pipe_read_dO @cute.jit def epilogue_dKV( diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index a41bfa0fe3c..70cd5a9da1d 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -449,8 +449,8 @@ def _flash_attn_bwd( m_block_size = 64 n_block_size = 128 num_stages_Q = 2 - num_stages_dO = 2 - num_stages_PdS = 2 + num_stages_dO = 1 + num_stages_PdS = 1 AtomLayoutMSdP = 1 AtomLayoutNdKV = 2 AtomLayoutMdQ = 1 From 57d0ce99cba657c565f2112164b170a84d7a94a2 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 13 Oct 2025 22:07:23 -0400 Subject: [PATCH 135/258] [Cute,Bwd,Sm90] Make causal work --- flash_attn/cute/block_info.py | 2 +- flash_attn/cute/flash_bwd_sm90.py | 41 +++++++++++++++++++++++++------ flash_attn/cute/mask.py | 2 +- 3 files changed, 36 insertions(+), 9 deletions(-) diff --git a/flash_attn/cute/block_info.py b/flash_attn/cute/block_info.py index 9e911fdd581..9f50321a28c 100644 --- a/flash_attn/cute/block_info.py +++ b/flash_attn/cute/block_info.py @@ -50,7 +50,7 @@ def get_m_block_min_max( m_block_max = cute.ceil_div(seqlen_info.seqlen_q, self.tile_m) m_block_min = 0 if const_expr(self.is_causal): - m_block_min = max(m_block_min, cute.ceil_div(seqlen_info.seqlen_q - seqlen_info.seqlen_k + (n_block + 1) * self.tile_n, self.tile_m)) + m_block_min = max(m_block_min, (n_block * self.tile_n + seqlen_info.seqlen_q - seqlen_info.seqlen_k) // self.tile_m) return m_block_min, m_block_max @cute.jit diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index d5db25372a3..cff3722e593 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -14,6 +14,7 @@ from flash_attn.cute import hopper_helpers as sm90_utils from flash_attn.cute import utils from flash_attn.cute import copy_utils +from flash_attn.cute.mask import AttentionMask from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo from flash_attn.cute import pipeline @@ -57,6 +58,7 @@ def __init__( self.check_hdim_v_oob = head_dim_v != self.tile_hdimv self.qhead_per_kvhead = qhead_per_kvhead self.is_causal = is_causal + self.is_local = False self.tile_m = tile_m self.tile_n = tile_n self.num_threads = num_threads @@ -509,8 +511,8 @@ def kernel( block_info = BlockInfo( self.tile_m, self.tile_n, - False, - False, + self.is_causal, + self.is_local, None, None, qhead_per_kvhead_packgqa=1, @@ -524,7 +526,13 @@ def kernel( mSeqUsedQ=None, mSeqUsedK=None, ) - + AttentionMaskCls = partial( + AttentionMask, + self.tile_m, + self.tile_n, + window_size_left=None, + window_size_right=None, + ) TileSchedulerCls = partial(TileScheduler.create, tile_sched_params) if warp_idx < 4: @@ -590,6 +598,7 @@ def kernel( softmax_scale, block_info, SeqlenInfoCls, + AttentionMaskCls, TileSchedulerCls, ) @@ -695,6 +704,8 @@ def load( # cp.async.bulk is using ptx, so we need to elect one thread to do it with cute.arch.elect_one(): load_LSE(m_block, producer_state=producer_state_Q) + if const_expr(self.Q_stage == self.dO_stage): + producer_state_dO = producer_state_Q pipeline_dO.producer_acquire(producer_state_dO) load_dO(m_block, producer_state=producer_state_dO) with cute.arch.elect_one(): @@ -736,6 +747,7 @@ def mma( softmax_scale: Float32, block_info: BlockInfo, SeqlenInfoCls: Callable, + AttentionMaskCls: Callable, TileSchedulerCls: Callable, ): warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) @@ -857,6 +869,15 @@ def mma( while work_tile.is_valid_tile: n_block, head_idx, batch_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) + mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) + mask_fn = partial( + mask.apply_mask, + n_block=n_block, + thr_mma=thr_mma_SdP, + mask_seqlen=True, + mask_causal=self.is_causal, + mask_local=self.is_local, + ) m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("tidx = {}, m_block_min = {}, m_block_max = {}", cute.arch.thread_idx()[0], m_block_min, m_block_max) dKV_should_accumulate = False @@ -865,6 +886,7 @@ def mma( m_block, consumer_state_Q, consumer_state_dO, + mask_fn=mask_fn, dKV_should_accumulate=dKV_should_accumulate, ) dKV_should_accumulate = True @@ -914,6 +936,7 @@ def mma_one_m_block( smem_thr_copy_PdS: cute.TiledCopy, smem_thr_copy_dQaccum: cute.TiledCopy, softmax_scale_log2: Float32, + mask_fn: Optional[Callable] = None, # acc_dV, # acc_dK, dKV_should_accumulate: Boolean = True, @@ -933,6 +956,8 @@ def mma_one_m_block( ) acc_dP = mma_dov_fn(A_idx=smem_idx_Q, wg_wait=1) # (3) [Pointwise 1] P = exp(S - LSE) + if cutlass.const_expr(mask_fn is not None): + mask_fn(acc_S, m_block=m_block) acc_S_mn = utils.make_acc_tensor_mn_view(acc_S) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_S_mn) for r in cutlass.range_constexpr(cute.size(acc_S_mn, mode=[0])): @@ -945,8 +970,8 @@ def mma_one_m_block( # Convert P from f32 -> f16 tdVrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) tdVrP = cute.make_fragment_like(tdVrP_acc, self.dtype) - # utils.cvt_f16(tdVrP_acc, tdVrP) - tdVrP.store(tdVrP_acc.load().to(self.dtype)) + utils.cvt_f16(tdVrP_acc, tdVrP) + # tdVrP.store(tdVrP_acc.load().to(self.dtype)) # S2R for dPsum tLSErdPsum = cute.make_fragment_like(tLSEsdPsum[None, 0]) cute.autovec_copy(tLSEsdPsum[None, smem_idx_dO], tLSErdPsum) @@ -973,8 +998,8 @@ def mma_one_m_block( # Convert dS from f32 -> f16 tdKrdS_acc = cute.make_tensor(acc_dP.iterator, utils.convert_layout_acc_frgA(acc_dP.layout)) tdKrdS = cute.make_fragment_like(tdKrdS_acc, self.dtype) - # utils.cvt_f16(tdKrdS_acc, tdKrdS) - tdKrdS.store(tdKrdS_acc.load().to(self.dtype)) + utils.cvt_f16(tdKrdS_acc, tdKrdS) + # tdKrdS.store(tdKrdS_acc.load().to(self.dtype)) # If there's double buffering on dS, we don't need to sync here. # Otherwise we might have WG1 writing to dS before WG2 is done reading from it during MmadQ. @@ -1039,6 +1064,8 @@ def mma_one_m_block( smem_pipe_read_Q.advance() if const_expr(self.Q_stage != self.dO_stage): smem_pipe_read_dO.advance() + else: + smem_pipe_read_dO = smem_pipe_read_Q return smem_pipe_read_Q, smem_pipe_read_dO @cute.jit diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index 9b20323aebe..246271f55f8 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -41,7 +41,7 @@ def apply_mask( seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset if cutlass.const_expr(not mask_causal and not mask_local): if cutlass.const_expr(mask_seqlen): - if cutlass.const_expr(False): + if cutlass.const_expr(True): # traverse column index. for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): oob = t0ScS_mn[0, c][1] >= seqlenk_col_limit From 89b94f84ae2b55dd27ce4af4fa60bbd01708c2ca Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 13 Oct 2025 23:29:32 -0400 Subject: [PATCH 136/258] [Cute,Bwd,Sm90] Implement dQ_swapAB --- flash_attn/cute/flash_bwd_postprocess.py | 28 +++++--- flash_attn/cute/flash_bwd_sm90.py | 84 ++++++++++++++++-------- flash_attn/cute/hopper_helpers.py | 16 +++-- flash_attn/cute/interface.py | 1 + 4 files changed, 86 insertions(+), 43 deletions(-) diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index 9ca76e3c9ba..22b227227b0 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -80,25 +80,29 @@ def can_implement(dtype, head_dim, tile_m, num_threads) -> bool: def _get_tiled_mma(self): if const_expr(self.arch == 80): num_mma_warps = self.num_threads // 32 - AtomLayoutdQ = ( + atom_layout_dQ = ( (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1) if const_expr(not self.dQ_swapAB) else (num_mma_warps // self.AtomLayoutMdQ, self.AtomLayoutMdQ, 1) ) tiled_mma = cute.make_tiled_mma( warp.MmaF16BF16Op(self.dtype, Float32, (16, 8, 16)), - AtomLayoutdQ, - permutation_mnk=(AtomLayoutdQ[0] * 16, AtomLayoutdQ[1] * 16, 16), + atom_layout_dQ, + permutation_mnk=(atom_layout_dQ[0] * 16, atom_layout_dQ[1] * 16, 16), ) else: + num_mma_warp_groups = self.num_threads // 128 + atom_layout_dQ = (self.AtomLayoutMdQ, num_mma_warp_groups // self.AtomLayoutMdQ) + tiler_mn_dQ = (self.tile_m // atom_layout_dQ[0], self.tile_hdim // atom_layout_dQ[1]) tiled_mma = sm90_utils_basic.make_trivial_tiled_mma( self.dtype, self.dtype, warpgroup.OperandMajorMode.K, # These don't matter, we only care about the accum warpgroup.OperandMajorMode.K, Float32, - atom_layout_mnk=(self.tile_m // 64, 2, 1), - tiler_mn=(64, self.tile_hdim // 2), + atom_layout_mnk=(atom_layout_dQ if not self.dQ_swapAB else atom_layout_dQ[::-1]) + + (1,), + tiler_mn=tiler_mn_dQ if not self.dQ_swapAB else tiler_mn_dQ[::-1], ) assert self.num_threads == tiled_mma.size return tiled_mma @@ -305,6 +309,7 @@ def kernel( smem = cutlass.utils.SmemAllocator() sdQaccum = smem.allocate_tensor(cutlass.Float32, sdQaccum_layout, byte_alignment=1024) sdQ = cute.make_tensor(cute.recast_ptr(sdQaccum.iterator, dtype=self.dtype), sdQ_layout) + sdQt = utils.transpose_view(sdQ) seqlen_q = seqlen.seqlen_q seqlen_q_rounded = cute.round_up(seqlen_q, self.tile_m) @@ -327,10 +332,9 @@ def kernel( # print(sdQaccum) # thr_mma = tiled_mma.get_slice(tidx) # print(tiled_mma) + tile_shape = (self.tile_m, self.tile_hdim) acc_shape = tiled_mma.partition_shape_C( - (self.tile_m, self.tile_hdim) - if const_expr(not dQ_swapAB) - else (self.tile_hdim, self.tile_m) + tile_shape if const_expr(not dQ_swapAB) else tile_shape[::-1] ) acc = cute.make_fragment(acc_shape, cutlass.Float32) assert cute.size(acc) == cute.size(tdQsdQaccum) @@ -349,10 +353,14 @@ def kernel( # Step 3: Copy dQ from register to smem cute.arch.barrier() # make sure all threads have finished loading dQaccum - smem_copy_atom_dQ = utils.get_smem_store_atom(self.arch, self.dtype) + smem_copy_atom_dQ = utils.get_smem_store_atom( + self.arch, self.dtype, transpose=self.dQ_swapAB + ) smem_thr_copy_dQ = cute.make_tiled_copy_C(smem_copy_atom_dQ, tiled_mma).get_slice(tidx) taccdQrdQ = smem_thr_copy_dQ.retile(rdQ) - taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ) + taccdQsdQ = smem_thr_copy_dQ.partition_D( + sdQ if const_expr(not self.dQ_swapAB) else sdQt + ) cute.copy(smem_copy_atom_dQ, taccdQrdQ, taccdQsdQ) # print(taccdQrdQ) # print(taccdQsdQ) diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index cff3722e593..a15001225f2 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -14,6 +14,7 @@ from flash_attn.cute import hopper_helpers as sm90_utils from flash_attn.cute import utils from flash_attn.cute import copy_utils +from flash_attn.cute.hopper_helpers import gemm_zero_init, gemm_w_idx from flash_attn.cute.mask import AttentionMask from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo @@ -22,6 +23,21 @@ from flash_attn.cute.named_barrier import NamedBarrierFwd, NamedBarrierBwd +def mma_partition_fragment_AB( + thr_mma: cute.core.ThrMma, sA: Optional[cute.Tensor], sB: Optional[cute.Tensor], swap_AB: bool +): + if const_expr(not swap_AB): + return ( + thr_mma.make_fragment_A(thr_mma.partition_A(sA)) if sA is not None else None, + thr_mma.make_fragment_B(thr_mma.partition_B(sB)) if sB is not None else None, + ) + else: + return ( + thr_mma.make_fragment_B(thr_mma.partition_B(sA)) if sA is not None else None, + thr_mma.make_fragment_A(thr_mma.partition_A(sB)) if sB is not None else None, + ) + + class FlashAttentionBackwardSm90: arch = 90 @@ -67,6 +83,9 @@ def __init__( self.PdS_stage = PdS_stage assert self.dO_stage in [1, self.Q_stage] assert self.PdS_stage in [1, self.Q_stage] + self.SdP_swapAB = SdP_swapAB + self.dKV_swapAB = dKV_swapAB + self.dQ_swapAB = dQ_swapAB self.AtomLayoutMSdP = AtomLayoutMSdP self.AtomLayoutNdKV = AtomLayoutNdKV self.AtomLayoutMdQ = AtomLayoutMdQ @@ -163,8 +182,9 @@ def _get_tiled_mma(self): warpgroup.OperandMajorMode.K, warpgroup.OperandMajorMode.K, Float32, - atom_layout_mnk=atom_layout_SdP + (1,), - tiler_mn=tiler_mn_SdP, + atom_layout_mnk=(atom_layout_SdP if not self.SdP_swapAB else atom_layout_SdP[::-1]) + + (1,), + tiler_mn=tiler_mn_SdP if not self.SdP_swapAB else tiler_mn_SdP[::-1], ) # dV = P.T @ dO, dK = dS.T @ Q atom_layout_dKV = (self.AtomLayoutNdKV, self.num_mma_warp_groups // self.AtomLayoutNdKV) @@ -177,8 +197,9 @@ def _get_tiled_mma(self): warpgroup.OperandMajorMode.MN, warpgroup.OperandMajorMode.MN, Float32, - atom_layout_mnk=atom_layout_dKV + (1,), - tiler_mn=tiler_mn_d, + atom_layout_mnk=(atom_layout_dKV if not self.dKV_swapAB else atom_layout_dKV[::-1]) + + (1,), + tiler_mn=tiler_mn_d if not self.dKV_swapAB else tiler_mn_d[::-1], a_source=warpgroup.OperandSource.RMEM if self.Mma_dKV_is_RS else warpgroup.OperandSource.SMEM, @@ -191,11 +212,11 @@ def _get_tiled_mma(self): tiled_mma_dQ = sm90_utils_basic.make_trivial_tiled_mma( self.dtype, self.dtype, - warpgroup.OperandMajorMode.K, - warpgroup.OperandMajorMode.MN, + warpgroup.OperandMajorMode.K if not self.dQ_swapAB else warpgroup.OperandMajorMode.MN, + warpgroup.OperandMajorMode.MN if not self.dQ_swapAB else warpgroup.OperandMajorMode.K, Float32, - atom_layout_mnk=atom_layout_dQ + (1,), - tiler_mn=tiler_mn_dQ, + atom_layout_mnk=(atom_layout_dQ if not self.dQ_swapAB else atom_layout_dQ[::-1]) + (1,), + tiler_mn=tiler_mn_dQ if not self.dQ_swapAB else tiler_mn_dQ[::-1], ) return tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ @@ -493,7 +514,6 @@ def kernel( if const_expr(not self.Mma_dKV_is_RS): sP = storage.sP.get_tensor(sPdS_layout.outer, swizzle=sPdS_layout.inner) sdS = storage.sdS.get_tensor(sPdS_layout.outer, swizzle=sPdS_layout.inner) - sLSE = storage.sLSE.get_tensor( cute.make_layout( (self.tile_m, self.Q_stage), @@ -760,27 +780,20 @@ def mma( wg_mma_dV = tiled_mma_dV.get_slice(warp_group_thread_layout(warp_group_idx)) wg_mma_dQ = tiled_mma_dQ.get_slice(warp_group_thread_layout(warp_group_idx)) # S = Q @ K.T - tSrQ = tiled_mma_SdP.make_fragment_A(wg_mma_SdP.partition_A(sQ)) - tSrK = tiled_mma_SdP.make_fragment_B(wg_mma_SdP.partition_B(sK)) + tSrQ, tSrK = mma_partition_fragment_AB(wg_mma_SdP, sQ, sK, self.SdP_swapAB) # dP = dO @ V.T - tdPrdO = tiled_mma_SdP.make_fragment_A(wg_mma_SdP.partition_A(sdO)) - tdPrV = tiled_mma_SdP.make_fragment_B(wg_mma_SdP.partition_B(sV)) + tdPrdO, tdPrV = mma_partition_fragment_AB(wg_mma_SdP, sdO, sV, self.SdP_swapAB) # dV += P.T @ dO sPt = utils.transpose_view(sP) if sP is not None else None sdOt = utils.transpose_view(sdO) - tdVrPt = None - if const_expr(sP is not None): - tdVrPt = tiled_mma_dV.make_fragment_A(wg_mma_dV.partition_A(sPt)) - tdVrdOt = tiled_mma_dV.make_fragment_B(wg_mma_dV.partition_B(sdOt)) + tdVrPt, tdVrdOt = mma_partition_fragment_AB(wg_mma_dV, sPt, sdOt, self.dKV_swapAB) # dK += dS.T @ Q sdSt = utils.transpose_view(sdS) sQt = utils.transpose_view(sQ) - tdKrdSt = tiled_mma_dK.make_fragment_A(wg_mma_dK.partition_A(sdSt)) - tdKrQt = tiled_mma_dK.make_fragment_B(wg_mma_dK.partition_B(sQt)) + tdKrdSt, tdKrQt = mma_partition_fragment_AB(wg_mma_dK, sdSt, sQt, self.dKV_swapAB) # dQ = dS @ K sKt = utils.transpose_view(sK) - tdQrdS = tiled_mma_dQ.make_fragment_A(wg_mma_dQ.partition_A(sdS)) - tdQrKt = tiled_mma_dQ.make_fragment_B(wg_mma_dQ.partition_B(sKt)) + tdQrdS, tdQrKt = mma_partition_fragment_AB(wg_mma_dQ, sdS, sKt, self.dQ_swapAB) # Smem copy atom tiling smem_copy_atom_PdS = utils.get_smem_store_atom(self.arch, self.dtype) @@ -823,15 +836,30 @@ def mma( ) mma_qk_fn = partial( - sm90_utils.gemm_zero_init, tiled_mma_SdP, (self.tile_m, self.tile_n), tSrQ, tSrK + gemm_zero_init, + tiled_mma_SdP, + (self.tile_m, self.tile_n), + tSrQ, + tSrK, + swap_AB=self.SdP_swapAB, ) mma_dov_fn = partial( - sm90_utils.gemm_zero_init, tiled_mma_SdP, (self.tile_m, self.tile_n), tdPrdO, tdPrV + gemm_zero_init, + tiled_mma_SdP, + (self.tile_m, self.tile_n), + tdPrdO, + tdPrV, + swap_AB=self.SdP_swapAB, ) - mma_pdo_fn = partial(sm90_utils.gemm_w_idx, tiled_mma_dV, acc_dV, tdVrPt, tdVrdOt) - mma_dsq_fn = partial(sm90_utils.gemm_w_idx, tiled_mma_dK, acc_dK, tdKrdSt, tdKrQt) + mma_pdo_fn = partial(gemm_w_idx, tiled_mma_dV, acc_dV, tdVrPt, tdVrdOt) + mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, acc_dK, tdKrdSt, tdKrQt) mma_dsk_fn = partial( - sm90_utils.gemm_zero_init, tiled_mma_dQ, (self.tile_m, self.tile_hdim), tdQrdS, tdQrKt + gemm_zero_init, + tiled_mma_dQ, + (self.tile_m, self.tile_hdim), + tdQrdS, + tdQrKt, + swap_AB=self.dQ_swapAB, ) mma_one_m_block_all = partial( @@ -1046,8 +1074,8 @@ def mma_one_m_block( barrier_id=int(NamedBarrierBwd.dQEmpty), number_of_threads=self.num_mma_threads + cute.arch.WARP_SIZE, ) - tdQrdQaccum_tmp = cute.make_tensor(acc_dQ.iterator, cute.make_layout(tdQsdQaccum.shape)) - cute.copy(smem_thr_copy_dQaccum, tdQrdQaccum_tmp, tdQsdQaccum) + tdQrdQaccum_flat = cute.make_tensor(acc_dQ.iterator, cute.make_layout(tdQsdQaccum.shape)) + cute.copy(smem_thr_copy_dQaccum, tdQrdQaccum_flat, tdQsdQaccum) cute.arch.fence_proxy( cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta ) diff --git a/flash_attn/cute/hopper_helpers.py b/flash_attn/cute/hopper_helpers.py index 2597cd4a566..14e6bf8ceb0 100644 --- a/flash_attn/cute/hopper_helpers.py +++ b/flash_attn/cute/hopper_helpers.py @@ -45,12 +45,18 @@ def gemm_zero_init( A_idx: Optional[Int32] = None, B_idx: Optional[Int32] = None, wg_wait: int = -1, + swap_AB: bool = False, ) -> cute.Tensor: - acc = cute.make_fragment(tiled_mma.partition_shape_C(shape), Float32) - rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] - rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] - gemm(tiled_mma, acc, rA, rB, zero_init=True, wg_wait=wg_wait) - return acc + if const_expr(swap_AB): + return gemm_zero_init( + tiled_mma, shape[::-1], tCrB, tCrA, B_idx, A_idx, wg_wait, swap_AB=False + ) + else: + acc = cute.make_fragment(tiled_mma.partition_shape_C(shape), Float32) + rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] + rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] + gemm(tiled_mma, acc, rA, rB, zero_init=True, wg_wait=wg_wait) + return acc def gemm_w_idx( diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 70cd5a9da1d..ba5c3526119 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -451,6 +451,7 @@ def _flash_attn_bwd( num_stages_Q = 2 num_stages_dO = 1 num_stages_PdS = 1 + dQ_swapAB = True AtomLayoutMSdP = 1 AtomLayoutNdKV = 2 AtomLayoutMdQ = 1 From 54d8aa6751fc9d5f0357854079261913d5df1f9d Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 14 Oct 2025 00:19:13 -0400 Subject: [PATCH 137/258] [Cute,Bwd,Sm90] Implement SdP_swapAB --- flash_attn/cute/flash_bwd_sm90.py | 28 +++++++++++------- flash_attn/cute/interface.py | 4 ++- flash_attn/cute/mask.py | 1 + flash_attn/cute/utils.py | 48 ++++++++++++++++--------------- 4 files changed, 47 insertions(+), 34 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index a15001225f2..c8a2899c216 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -95,6 +95,7 @@ def __init__( and AtomLayoutNdKV == self.num_mma_warp_groups and SdP_swapAB and not dKV_swapAB + and False # TODO ) self.V_in_regs = V_in_regs @@ -119,7 +120,6 @@ def can_implement( return False if num_threads % 32 != 0: return False - if (tile_m * 2) % num_threads != 0: return False return True @@ -796,14 +796,16 @@ def mma( tdQrdS, tdQrKt = mma_partition_fragment_AB(wg_mma_dQ, sdS, sKt, self.dQ_swapAB) # Smem copy atom tiling - smem_copy_atom_PdS = utils.get_smem_store_atom(self.arch, self.dtype) + smem_copy_atom_PdS = utils.get_smem_store_atom( + self.arch, self.dtype, transpose=self.SdP_swapAB + ) smem_thr_copy_PdS = cute.make_tiled_copy_C(smem_copy_atom_PdS, tiled_mma_SdP).get_slice( tidx ) tPsP = None if const_expr(sP is not None): - tPsP = smem_thr_copy_PdS.partition_D(sP) - tdSsdS = smem_thr_copy_PdS.partition_D(sdS) + tPsP = smem_thr_copy_PdS.partition_D(sP if const_expr(not self.SdP_swapAB) else sPt) + tdSsdS = smem_thr_copy_PdS.partition_D(sdS if const_expr(not self.SdP_swapAB) else sdSt) sLSE_mma = cute.make_tensor( sLSE.iterator, @@ -819,19 +821,24 @@ def mma( stride=(1, 0, cute.round_up(self.tile_m, 64)), ), ) - LSEslice = (None, 0, None) + if const_expr(self.SdP_swapAB): + sLSE_mma = utils.transpose_view(sLSE_mma) + sdPsum_mma = utils.transpose_view(sdPsum_mma) + LSEslice = (None, 0, None) if const_expr(not self.SdP_swapAB) else (0, None, None) tLSEsLSE = utils.make_acc_tensor_mn_view(thr_mma_SdP.partition_C(sLSE_mma))[LSEslice] tLSEsdPsum = utils.make_acc_tensor_mn_view(thr_mma_SdP.partition_C(sdPsum_mma))[LSEslice] smem_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_slice(tidx) tdQsdQaccum = smem_thr_copy_dQaccum.partition_D(sdQaccum) + dV_shape = (self.tile_n, self.tile_hdimv) acc_dV = cute.make_fragment( - tiled_mma_dV.partition_shape_C((self.tile_n, self.tile_hdimv)), + tiled_mma_dV.partition_shape_C(dV_shape if not self.dKV_swapAB else dV_shape[::-1]), Float32, ) + dK_shape = (self.tile_n, self.tile_hdim) acc_dK = cute.make_fragment( - tiled_mma_dK.partition_shape_C((self.tile_n, self.tile_hdim)), + tiled_mma_dK.partition_shape_C(dK_shape if not self.dKV_swapAB else dK_shape[::-1]), Float32, ) @@ -984,9 +991,10 @@ def mma_one_m_block( ) acc_dP = mma_dov_fn(A_idx=smem_idx_Q, wg_wait=1) # (3) [Pointwise 1] P = exp(S - LSE) - if cutlass.const_expr(mask_fn is not None): + # if cutlass.const_expr(mask_fn is not None): + if cutlass.const_expr(mask_fn is not None and not self.SdP_swapAB): # TODO: impl mask mask_fn(acc_S, m_block=m_block) - acc_S_mn = utils.make_acc_tensor_mn_view(acc_S) + acc_S_mn = utils.make_acc_tensor_mn_view(acc_S, transpose=self.SdP_swapAB) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_S_mn) for r in cutlass.range_constexpr(cute.size(acc_S_mn, mode=[0])): acc_S_mn[r, None].store( @@ -1016,7 +1024,7 @@ def mma_one_m_block( # (4) [Pointwise 2] dS = P*(dP-dPsum) warpgroup.wait_group(0) - acc_dP_mn = utils.make_acc_tensor_mn_view(acc_dP) + acc_dP_mn = utils.make_acc_tensor_mn_view(acc_dP, transpose=self.SdP_swapAB) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dP_mn) for r in cutlass.range_constexpr(cute.size(acc_dP_mn, mode=[0])): acc_dP_mn[r, None].store( diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index ba5c3526119..a2b86ebe4ef 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -451,7 +451,9 @@ def _flash_attn_bwd( num_stages_Q = 2 num_stages_dO = 1 num_stages_PdS = 1 - dQ_swapAB = True + SdP_swapAB = False + dKV_swapAB = False + dQ_swapAB = False AtomLayoutMSdP = 1 AtomLayoutNdKV = 2 AtomLayoutMdQ = 1 diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index 246271f55f8..1da693141cf 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -30,6 +30,7 @@ def apply_mask( mask_causal: cutlass.Constexpr[bool], mask_local: cutlass.Constexpr[bool] = False, ) -> None: + # TODO: implement swap_AB assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" acc_S_mn = utils.make_acc_tensor_mn_view(acc_S) cS = cute.make_identity_tensor((self.tile_m, self.tile_n)) diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 06e7824dc13..2851d59c84d 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -102,7 +102,7 @@ def mma_make_fragment_B( def get_smem_store_atom( - arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric] + arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False ) -> cute.CopyAtom: if const_expr(arch < 90 or element_type.width != 16): return cute.make_copy_atom( @@ -112,7 +112,7 @@ def get_smem_store_atom( ) else: return cute.make_copy_atom( - cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=False, num_matrices=4), + cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4), element_type, ) @@ -135,37 +135,39 @@ def warp_reduce( return val -def convert_layout_acc_mn(acc_layout: cute.Layout) -> cute.Layout: +def convert_layout_acc_mn(acc_layout: cute.Layout, transpose: bool = False) -> cute.Layout: """ For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...). For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...). """ acc_layout_col_major = cute.make_layout(acc_layout.shape) - acc_layout_mn = cute.make_layout( + shape = ( + (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]), # MMA_M ( - (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]), # MMA_M - ( - acc_layout_col_major.shape[0][0], - *acc_layout_col_major.shape[0][2:], - acc_layout_col_major.shape[2], - ), # MMA_N - *acc_layout_col_major.shape[3:], - ), - stride=( - (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]), # MMA_M - ( - acc_layout_col_major.stride[0][0], - *acc_layout_col_major.stride[0][2:], - acc_layout_col_major.stride[2], - ), # MMA_N - *acc_layout_col_major.stride[3:], - ), + acc_layout_col_major.shape[0][0], + *acc_layout_col_major.shape[0][2:], + acc_layout_col_major.shape[2], + ), # MMA_N + *acc_layout_col_major.shape[3:], + ) + stride = ( + (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]), # MMA_M + ( + acc_layout_col_major.stride[0][0], + *acc_layout_col_major.stride[0][2:], + acc_layout_col_major.stride[2], + ), # MMA_N + *acc_layout_col_major.stride[3:], ) + if const_expr(transpose): + shape = (shape[1], shape[0], *shape[2:]) + stride = (stride[1], stride[0], *stride[2:]) + acc_layout_mn = cute.make_layout(shape, stride=stride) return cute.composition(acc_layout, acc_layout_mn) -def make_acc_tensor_mn_view(acc: cute.Tensor) -> cute.Tensor: - return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout)) +def make_acc_tensor_mn_view(acc: cute.Tensor, transpose: bool = False) -> cute.Tensor: + return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout, transpose=transpose)) @cute.jit From 72b793ac6ad3209cc8b4361b3d3d55c5c62c951d Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Tue, 14 Oct 2025 09:05:24 -0400 Subject: [PATCH 138/258] [AMD] Torch Compile Issues (#1756) * fix rounding and dropout metdata bug * fix lse shape and bug in interface * return softmax is true --- flash_attn/flash_attn_interface.py | 22 ++++++++++++++----- .../bwd_prefill_split.py | 2 +- .../flash_attn_triton_amd/fwd_prefill.py | 5 +++-- .../flash_attn_triton_amd/interface_fa.py | 22 +++++++------------ flash_attn/flash_attn_triton_amd/utils.py | 9 ++++---- 5 files changed, 32 insertions(+), 28 deletions(-) diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 535bd416745..865f1db5432 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -127,7 +127,10 @@ def _flash_attn_forward_fake( softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device, layout=q.layout) p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout) if return_softmax: - p = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128), round_multiple(seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout) + if torch.cuda.is_available() and torch.version.hip: + p = torch.empty((batch_size, num_heads, seqlen_q, seqlen_k), dtype=q.dtype, device=q.device, layout=q.layout) + else: + p = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128), round_multiple(seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout) rng_state = torch.empty((2,), dtype=torch.int64, device=q.device) return out, softmax_lse, p, rng_state @@ -220,10 +223,11 @@ def _flash_attn_varlen_forward_fake( out = torch.empty_like(q) softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device, layout=q.layout) p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout) - seqlen_q_rounded = round_multiple(max_seqlen_q, 128) - seqlen_k_rounded = round_multiple(max_seqlen_k, 128) if return_softmax: - p = torch.empty((batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), dtype=q.dtype, device=q.device, layout=q.layout) + if torch.cuda.is_available() and torch.version.hip: + p = torch.empty((batch_size, num_heads, max_seqlen_q, max_seqlen_k), dtype=q.dtype, device=q.device, layout=q.layout) + else: + p = torch.empty((batch_size, num_heads, round_multiple(max_seqlen_q, 128), round_multiple(max_seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout) rng_state = torch.empty((2,), dtype=torch.int64, device=q.device) return out, softmax_lse, p, rng_state @@ -315,7 +319,10 @@ def _flash_attn_backward_fake( if dv is None: dv = torch.empty_like(v) batch_size, seqlen_q, num_heads, _ = q.shape - softmax_d = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128)), device=q.device, dtype=torch.float32) + if torch.cuda.is_available() and torch.version.hip: + softmax_d = torch.empty((batch_size, num_heads, seqlen_q), device=q.device, dtype=torch.float32) + else: + softmax_d = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128)), device=q.device, dtype=torch.float32) return softmax_d @@ -426,7 +433,10 @@ def _flash_attn_varlen_backward_fake( dk = torch.empty_like(k) if dv is None: dv = torch.empty_like(v) - softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32) + if torch.cuda.is_available() and torch.version.hip: + softmax_d = torch.empty((num_heads, total_q), device=q.device, dtype=torch.float32) + else: + softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32) return softmax_d diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py index c1e2ff5985f..5cc93edc5e4 100644 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py @@ -1161,7 +1161,7 @@ def attention_prefill_backward_triton_split_impl( delta = torch.zeros_like(softmax_lse) if IS_VARLEN: stride_deltab = 0 - stride_deltam, stride_deltah = delta.stride() + stride_deltah, stride_deltam = delta.stride() else: stride_deltab, stride_deltah, stride_deltam = delta.stride() pre_grid = (triton.cdiv(max_seqlen_q_final, PRE_BLOCK), batch, nheads_q) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index dec5673e3e5..6f69cd02813 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -621,8 +621,9 @@ def attention_prefill_forward_triton_impl( # stores LSE the log of the normalization constant / sum of expoential score(unnormalzied probablities) if is_varlen: - softmax_lse = torch.zeros((q.shape[0], nheads_q), device=q.device, dtype=torch.float32) - stride_lse_m, stride_lse_h = softmax_lse.stride() + total_seqlen_q, _, _ = q.shape + softmax_lse = torch.zeros((nheads_q, total_seqlen_q), device=q.device, dtype=torch.float32) + stride_lse_h, stride_lse_m = softmax_lse.stride() stride_lse_z = 0 else: softmax_lse = torch.zeros((batch, nheads_q, max_seqlens_q), device=q.device, dtype=torch.float32) diff --git a/flash_attn/flash_attn_triton_amd/interface_fa.py b/flash_attn/flash_attn_triton_amd/interface_fa.py index bb6e25b509c..06ab7d24d56 100644 --- a/flash_attn/flash_attn_triton_amd/interface_fa.py +++ b/flash_attn/flash_attn_triton_amd/interface_fa.py @@ -74,11 +74,9 @@ def fwd(q: torch.Tensor, if alibi_slopes is not None: metadata.need_alibi(alibi_slopes, batch, nheads_q) - if dropout_p > 0.0: - metadata.need_dropout(dropout_p) - rng_state = torch.as_tensor([metadata.philox_seed, metadata.philox_offset]) # as_tensors uses the underlying data and doesnot cast - else: - rng_state = None + # store rng state + metadata.need_dropout(dropout_p, return_softmax) + rng_state = torch.as_tensor([metadata.philox_seed, metadata.philox_offset]) # as_tensors uses the underlying data and doesnot cast # check arguments metadata.check_args(q, k, v, out) @@ -212,8 +210,7 @@ def bwd( dk = torch.zeros_like(k) if dk is None else dk.zero_() dv = torch.zeros_like(v) if dv is None else dv.zero_() - if dropout_p > 0.0: - assert rng_state is not None + if rng_state is not None: philox_seed, philox_offset = rng_state[0].item(), rng_state[1].item() else: philox_seed, philox_offset = None, None @@ -423,11 +420,9 @@ def varlen_fwd( if alibi_slopes is not None: metadata.need_alibi(alibi_slopes, batch, nheads_q) - if dropout_p > 0.0: - metadata.need_dropout(dropout_p) - rng_state = torch.as_tensor([metadata.philox_seed, metadata.philox_offset]) # as_tensors uses the underlying data and doesnot cast - else: - rng_state = None + # store rng state + metadata.need_dropout(dropout_p, return_softmax) + rng_state = torch.as_tensor([metadata.philox_seed, metadata.philox_offset]) # as_tensors uses the underlying data and doesnot cast # Check arguments metadata.check_args(q, k, v, out) @@ -563,8 +558,7 @@ def varlen_bwd( dk = torch.zeros_like(k) if dk is None else dk.zero_() dv = torch.zeros_like(v) if dv is None else dv.zero_() - if dropout_p > 0.0: - assert rng_state is not None + if rng_state is not None: philox_seed, philox_offset = rng_state[0].item(), rng_state[1].item() else: philox_seed, philox_offset = None, None diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index 0300e3902a1..5d3bf02e1f8 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -112,11 +112,10 @@ def need_rotary(self, sin, cos, rotary_interleaved, rotary_conjunction=False): self.rotary_interleaved = rotary_interleaved self.rotary_conjunction = rotary_conjunction - def need_dropout(self, dropout_p, return_scores = True): - if dropout_p > 0.0: - self.dropout_p = dropout_p - self.return_scores = return_scores - self.philox_seed, self.philox_offset = 0x1BF58, 0x1D4B49 + def need_dropout(self, dropout_p, return_softmax = True): + self.dropout_p = dropout_p + self.return_softmax = return_softmax + self.philox_seed, self.philox_offset = 0x1BF58, 0x1D4B49 def check_args(self, q, k, v, o): assert q.dim() == k.dim() and q.dim() == v.dim() From 5685ace888875846002f7cb7879aaf08f87b0049 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 14 Oct 2025 12:45:47 -0400 Subject: [PATCH 139/258] [Cute,Bwd,Sm90] Implement mma_dkv_is_rs --- flash_attn/cute/flash_bwd_sm90.py | 77 +++++++++++++++++++------------ flash_attn/cute/hopper_helpers.py | 10 ++-- flash_attn/cute/interface.py | 4 +- 3 files changed, 57 insertions(+), 34 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index c8a2899c216..45aa80f86c2 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -90,12 +90,11 @@ def __init__( self.AtomLayoutNdKV = AtomLayoutNdKV self.AtomLayoutMdQ = AtomLayoutMdQ self.num_mma_warp_groups = (self.num_threads // 128) - 1 - self.Mma_dKV_is_RS = ( + self.mma_dkv_is_rs = ( AtomLayoutMSdP == 1 and AtomLayoutNdKV == self.num_mma_warp_groups and SdP_swapAB and not dKV_swapAB - and False # TODO ) self.V_in_regs = V_in_regs @@ -194,14 +193,16 @@ def _get_tiled_mma(self): sm90_utils_basic.make_trivial_tiled_mma( self.dtype, self.dtype, - warpgroup.OperandMajorMode.MN, + warpgroup.OperandMajorMode.MN + if not self.mma_dkv_is_rs + else warpgroup.OperandMajorMode.K, warpgroup.OperandMajorMode.MN, Float32, atom_layout_mnk=(atom_layout_dKV if not self.dKV_swapAB else atom_layout_dKV[::-1]) + (1,), tiler_mn=tiler_mn_d if not self.dKV_swapAB else tiler_mn_d[::-1], a_source=warpgroup.OperandSource.RMEM - if self.Mma_dKV_is_RS + if self.mma_dkv_is_rs else warpgroup.OperandSource.SMEM, ) for tiler_mn_d in (tiler_mn_dK, tiler_mn_dV) @@ -235,7 +236,7 @@ def _get_shared_storage_cls(self): ] cosize_sdS = cute.cosize(self.sPdS_layout) - cosize_sP = cute.cosize(self.sPdS_layout) if const_expr(not self.Mma_dKV_is_RS) else 0 + cosize_sP = cute.cosize(self.sPdS_layout) if const_expr(not self.mma_dkv_is_rs) else 0 sLSE_struct = cute.struct.Align[ cute.struct.MemRange[Float32, cute.round_up(self.tile_m, 64) * self.Q_stage], 128 ] @@ -511,7 +512,7 @@ def kernel( sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) sP = None - if const_expr(not self.Mma_dKV_is_RS): + if const_expr(not self.mma_dkv_is_rs): sP = storage.sP.get_tensor(sPdS_layout.outer, swizzle=sPdS_layout.inner) sdS = storage.sdS.get_tensor(sPdS_layout.outer, swizzle=sPdS_layout.inner) sLSE = storage.sLSE.get_tensor( @@ -858,8 +859,17 @@ def mma( tdPrV, swap_AB=self.SdP_swapAB, ) - mma_pdo_fn = partial(gemm_w_idx, tiled_mma_dV, acc_dV, tdVrPt, tdVrdOt) - mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, acc_dK, tdKrdSt, tdKrQt) + if const_expr(not self.mma_dkv_is_rs): + mma_pdo_fn = partial( + gemm_w_idx, tiled_mma_dV, acc_dV, tdVrPt, tdVrdOt, swap_AB=self.dKV_swapAB + ) + mma_dsq_fn = partial( + gemm_w_idx, tiled_mma_dK, acc_dK, tdKrdSt, tdKrQt, swap_AB=self.dKV_swapAB + ) + else: + assert not self.dKV_swapAB + mma_pdo_fn = partial(gemm_w_idx, tiled_mma_dV, acc_dV, tCrB=tdVrdOt) + mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, acc_dK, tCrB=tdKrQt) mma_dsk_fn = partial( gemm_zero_init, tiled_mma_dQ, @@ -915,17 +925,18 @@ def mma( ) m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("tidx = {}, m_block_min = {}, m_block_max = {}", cute.arch.thread_idx()[0], m_block_min, m_block_max) - dKV_should_accumulate = False + dKV_accumulate = False for m_block in cutlass.range(m_block_min, m_block_max, unroll=1): consumer_state_Q, consumer_state_dO = mma_one_m_block_all( m_block, consumer_state_Q, consumer_state_dO, mask_fn=mask_fn, - dKV_should_accumulate=dKV_should_accumulate, + dKV_accumulate=dKV_accumulate, ) - dKV_should_accumulate = True + dKV_accumulate = True + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dV) # scale dK acc_dK.store(acc_dK.load() * softmax_scale) self.epilogue_dKV( @@ -974,7 +985,7 @@ def mma_one_m_block( mask_fn: Optional[Callable] = None, # acc_dV, # acc_dK, - dKV_should_accumulate: Boolean = True, + dKV_accumulate: Boolean = True, ): smem_idx_Q = smem_pipe_read_Q.index smem_idx_dO = smem_pipe_read_dO.index @@ -1003,17 +1014,17 @@ def mma_one_m_block( ) ) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_S_mn) + # S2R for dPsum + tLSErdPsum = cute.make_fragment_like(tLSEsdPsum[None, 0]) + cute.autovec_copy(tLSEsdPsum[None, smem_idx_dO], tLSErdPsum) + # Convert P from f32 -> f16 tdVrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) tdVrP = cute.make_fragment_like(tdVrP_acc, self.dtype) utils.cvt_f16(tdVrP_acc, tdVrP) # tdVrP.store(tdVrP_acc.load().to(self.dtype)) - # S2R for dPsum - tLSErdPsum = cute.make_fragment_like(tLSEsdPsum[None, 0]) - cute.autovec_copy(tLSEsdPsum[None, smem_idx_dO], tLSErdPsum) - # R2S for P - if const_expr(not self.Mma_dKV_is_RS): + if const_expr(not self.mma_dkv_is_rs): # sync to ensure P has already been used in the previous iteration before overwriting if const_expr(self.PdS_stage == 1): cute.arch.barrier( @@ -1041,9 +1052,9 @@ def mma_one_m_block( # Otherwise we might have WG1 writing to dS before WG2 is done reading from it during MmadQ. # But because both WGs have to sync at the end of the loop and double buffering, # this race condition is not possible. - # This sync is to ensure (1) P is written in case of !Mma_dKV_is_RS and - # (2) dS is already read by the Mma in the previous iteration in case of Mma_dKV_is_RS. - if const_expr(not self.Mma_dKV_is_RS or (self.PdS_stage == 1 and self.Mma_dKV_is_RS)): + # This sync is to ensure (1) P is written in case of !mma_dkv_is_rs and + # (2) dS is already read by the Mma in the previous iteration in case of mma_dkv_is_rs. + if const_expr(not self.mma_dkv_is_rs or (self.PdS_stage == 1 and self.mma_dkv_is_rs)): cute.arch.fence_proxy( cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta ) @@ -1056,9 +1067,12 @@ def mma_one_m_block( cute.copy(smem_thr_copy_PdS, tdSrdS, tdSsdS[None, None, None, smem_idx_PdS]) # (5) [GEMM 3] dV += P.T @ dO - mma_pdo_fn( - A_idx=smem_idx_PdS, B_idx=smem_idx_dO, zero_init=not dKV_should_accumulate, wg_wait=-1 - ) + if const_expr(not self.mma_dkv_is_rs): + mma_pdo_fn( + A_idx=smem_idx_PdS, B_idx=smem_idx_dO, zero_init=not dKV_accumulate, wg_wait=-1 + ) + else: + mma_pdo_fn(tCrA=tdVrP, B_idx=smem_idx_dO, zero_init=not dKV_accumulate, wg_wait=-1) # smem fence to make sure sdS is written before it's read by WGMMA cute.arch.fence_proxy( @@ -1073,9 +1087,12 @@ def mma_one_m_block( pipeline_dO.consumer_release(smem_pipe_read_dO) # release dO as dV mma is done # (7) [GEMM 5] dK += dS.T @ Q - mma_dsq_fn( - A_idx=smem_idx_PdS, B_idx=smem_idx_Q, zero_init=not dKV_should_accumulate, wg_wait=1 - ) + if const_expr(not self.mma_dkv_is_rs): + mma_dsq_fn( + A_idx=smem_idx_PdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1 + ) + else: + mma_dsq_fn(tCrA=tdKrdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dQ) cute.arch.barrier( @@ -1134,7 +1151,7 @@ def epilogue_dKV( ) smem_copy_atom_dKV = cute.make_copy_atom( - cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=False, num_matrices=4), + cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=self.dKV_swapAB, num_matrices=4), self.dtype, ) smem_thr_copy_dK = cute.make_tiled_copy_C(smem_copy_atom_dKV, tiled_mma_dK).get_slice(tidx) @@ -1153,7 +1170,8 @@ def epilogue_dKV( warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) # rmem -> smem taccdVrdV = smem_thr_copy_dV.retile(rdV) - taccdVsdV = smem_thr_copy_dV.partition_D(sV) # reuse sV SMEM + sdV = sV if const_expr(not self.dKV_swapAB) else utils.transpose_view(sV) # reuse sV SMEM + taccdVsdV = smem_thr_copy_dV.partition_D(sdV) cute.copy(smem_copy_atom_dKV, taccdVrdV, taccdVsdV) # ensure smem writes are visible to TMA cute.arch.fence_proxy( @@ -1165,7 +1183,8 @@ def epilogue_dKV( if warp_idx == 4: store_dV() taccdKrdK = smem_thr_copy_dK.retile(rdK) - taccdKsdK = smem_thr_copy_dK.partition_D(sK) # reuse sK SMEM + sdK = sK if const_expr(not self.dKV_swapAB) else utils.transpose_view(sK) # reuse sK SMEM + taccdKsdK = smem_thr_copy_dK.partition_D(sdK) # reuse sK SMEM cute.copy(smem_copy_atom_dKV, taccdKrdK, taccdKsdK) # ensure smem writes are visible to TMA cute.arch.fence_proxy( diff --git a/flash_attn/cute/hopper_helpers.py b/flash_attn/cute/hopper_helpers.py index 14e6bf8ceb0..1016a4189fe 100644 --- a/flash_attn/cute/hopper_helpers.py +++ b/flash_attn/cute/hopper_helpers.py @@ -68,10 +68,14 @@ def gemm_w_idx( A_idx: Optional[Int32] = None, B_idx: Optional[Int32] = None, wg_wait: int = -1, + swap_AB: bool = False, ) -> None: - rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] - rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] - gemm(tiled_mma, acc, rA, rB, zero_init=zero_init, wg_wait=wg_wait) + if const_expr(swap_AB): + gemm_w_idx(tiled_mma, acc, tCrB, tCrA, zero_init, B_idx, A_idx, wg_wait, swap_AB=False) + else: + rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] + rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] + gemm(tiled_mma, acc, rA, rB, zero_init=zero_init, wg_wait=wg_wait) @dsl_user_op diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index a2b86ebe4ef..47526e6bfef 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -450,8 +450,8 @@ def _flash_attn_bwd( n_block_size = 128 num_stages_Q = 2 num_stages_dO = 1 - num_stages_PdS = 1 - SdP_swapAB = False + num_stages_PdS = 2 + SdP_swapAB = True dKV_swapAB = False dQ_swapAB = False AtomLayoutMSdP = 1 From a76e692a6eb13121c27db6187629acacda6160bc Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 14 Oct 2025 16:55:25 -0400 Subject: [PATCH 140/258] [Cute,Bwd,Sm90] Use block size 80x128 --- flash_attn/cute/flash_bwd_sm90.py | 8 ++++++++ flash_attn/cute/interface.py | 28 ++++++++++++++-------------- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index 45aa80f86c2..3d2ae593160 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -97,6 +97,14 @@ def __init__( and not dKV_swapAB ) self.V_in_regs = V_in_regs + # These are tuned for speed + # Do we keep the LSE and dPsum in each thread, or split them across 8 threads that share + # them and then shuffle to get the value whenever we need? This can reduce register + # pressure when SdP_swapAB, where each thread needs to keep statistics for (kBlockM / 4) + # rows. If !SdP_swapAB, each thread only needs to keep statistics for 2 rows. + # TODO: impl these for hdim 64 + self.shuffle_LSE = self.SdP_swapAB and self.tile_hdim <= 64 + self.shuffle_dPsum = self.SdP_swapAB and self.tile_hdim <= 64 @staticmethod def can_implement( diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 47526e6bfef..507899c6d26 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -312,8 +312,19 @@ def _flash_attn_bwd( seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + m_block_size = 80 if not causal else 64 + n_block_size = 128 + num_stages_Q = 2 + num_stages_dO = 2 + num_stages_PdS = 2 + SdP_swapAB = True + dKV_swapAB = False + dQ_swapAB = not causal + AtomLayoutMSdP = 1 + AtomLayoutNdKV = 2 + AtomLayoutMdQ = 1 q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = [ - maybe_contiguous(t) + maybe_contiguous(t) for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) ] num_head, head_dim = q.shape[-2:] @@ -344,7 +355,7 @@ def _flash_attn_bwd( assert v.shape == (total_k, num_head_kv, head_dim_v) assert cu_seqlens_k.shape == (batch_size + 1,), "cu_seqlens_k must have shape (batch_size + 1,)" - if cu_seqlens_q is not None: + if cu_seqlens_q is not None: assert cu_seqlens_q.shape == (batch_size + 1,), "cu_seqlens_q must have shape (batch_size + 1,)" assert out.shape == (total_q, num_head, head_dim_v) @@ -436,7 +447,7 @@ def _flash_attn_bwd( dq_accum_tensor, cu_seqlens_q_tensor, seqused_q_tensor, current_stream ) _flash_attn_bwd.compile_cache_pre[compile_key_pre]( - o_tensor, do_tensor, dpsum_tensor, lse_tensor, lse_log2_tensor, dq_accum_tensor, + o_tensor, do_tensor, dpsum_tensor, lse_tensor, lse_log2_tensor, dq_accum_tensor, cu_seqlens_q_tensor, seqused_q_tensor, current_stream ) @@ -446,17 +457,6 @@ def _flash_attn_bwd( n_block_size, num_threads, pack_gqa, num_stages_Q, num_stages_dO, SdP_swapAB, dKV_swapAB, dQ_swapAB, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs ) - m_block_size = 64 - n_block_size = 128 - num_stages_Q = 2 - num_stages_dO = 1 - num_stages_PdS = 2 - SdP_swapAB = True - dKV_swapAB = False - dQ_swapAB = False - AtomLayoutMSdP = 1 - AtomLayoutNdKV = 2 - AtomLayoutMdQ = 1 num_threads = 384 if compile_key not in _flash_attn_bwd.compile_cache: fa_bwd_sm80 = FlashAttentionBackwardSm80( From 6bc3d1f59f5c843c9ccbc4f0d14cfe02b5e88ab3 Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Wed, 15 Oct 2025 12:24:05 -0700 Subject: [PATCH 141/258] [CUTE] Enable Pack GQA for score mods (#1937) --- flash_attn/cute/flash_fwd.py | 7 +-- flash_attn/cute/flash_fwd_sm100.py | 18 +++++-- flash_attn/cute/interface.py | 2 - flash_attn/cute/softmax.py | 46 ++++++++++++++-- tests/cute/test_score_mod.py | 84 ++++++++---------------------- 5 files changed, 81 insertions(+), 76 deletions(-) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 222d0790967..75232662d0d 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -601,7 +601,7 @@ def __call__( fastdiv_mods = None if cutlass.const_expr(buffers is not None): - seqlen_q = cute.size(mQ.shape[0]) + seqlen_q = cute.size(mQ.shape[0]) // (self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1) seqlen_k = cute.size(mK.shape[0]) seqlen_q_divmod = FastDivmod.create(seqlen_q) seqlen_k_divmod = FastDivmod.create(seqlen_k) @@ -1250,7 +1250,7 @@ def __call__( fastdiv_mods = None if cutlass.const_expr(buffers is not None): - seqlen_q = cute.size(mQ.shape[0]) + seqlen_q = cute.size(mQ.shape[0]) // (self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1) seqlen_k = cute.size(mK.shape[0]) seqlen_q_divmod = FastDivmod.create(seqlen_q) seqlen_k_divmod = FastDivmod.create(seqlen_k) @@ -1939,7 +1939,8 @@ def apply_score_mod( self.qk_acc_dtype, buffers, fastdiv_mods, - constant_q_idx=None + constant_q_idx=None, + qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) def warp_scheduler_barrier_sync(self): diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index cb52f157ad3..0a93f3d044f 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -490,7 +490,7 @@ class SharedStorage: fastdiv_mods = None if cutlass.const_expr(buffers is not None): - seqlen_q = cute.size(mQ.shape[0]) + seqlen_q = cute.size(mQ.shape[0]) // (self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1) seqlen_k = cute.size(mK.shape[0]) seqlen_q_divmod = FastDivmod.create(seqlen_q) seqlen_k_divmod = FastDivmod.create(seqlen_k) @@ -1987,10 +1987,19 @@ def apply_score_mod( tScS_t2r = thr_tmem_load.partition_D(tScS) # Shared q_idx for all scores - q_idx_wrapped = tScS_t2r[0][0] + q_idx_logical = tScS_t2r[0][0] + + # For Pack-GQA, compute the logical head index for this tile + if cutlass.const_expr(self.pack_gqa): + # Building up the logical q_head idx: final_q_head = kv_head * qhead_per_kvhead + (q_physical % qhead_per_kvhead) + q_physical = q_idx_logical + q_idx_logical = q_physical // self.qhead_per_kvhead + head_offset = q_physical - q_idx_logical * self.qhead_per_kvhead + head_idx = head_idx * self.qhead_per_kvhead + head_offset + if cutlass.const_expr(buffers is not None): seqlen_q_divmod, _ = fastdiv_mods - _, q_idx_wrapped = seqlen_q_divmod.divmod(tScS_t2r[0][0]) + _, q_idx_logical = seqlen_q_divmod.divmod(q_idx_logical) apply_score_mod_inner( tSrS_t2r, @@ -2003,5 +2012,6 @@ def apply_score_mod( self.qk_acc_dtype, buffers, fastdiv_mods, - constant_q_idx=q_idx_wrapped + constant_q_idx=q_idx_logical, + qhead_per_kvhead=self.qhead_per_kvhead if cutlass.const_expr(self.pack_gqa) else 1, ) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 507899c6d26..07a6c48bfbf 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -211,8 +211,6 @@ def _flash_attn_fwd( is_varlen = cu_seqlens_q is not None or cu_seqlens_k is not None or seqused_q is not None or seqused_k is not None if is_varlen: raise NotImplementedError("score_mod with buffers is not yet supported for varlen sequences. This will be fixed in a future PR.") - if pack_gqa: - raise NotImplementedError("score_mod with buffers is not yet supported with pack_gqa=True. This will be fixed in a future PR.") cute_buffers = None if buffers is not None: diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index 398f9e40c55..72de115732a 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -316,6 +316,17 @@ def scale_apply_exp2_convert( ) +@cute.jit +def floor_if_packed( + q_idx, + qhead_per_kvhead: cutlass.Constexpr[int], +) -> cute.Tensor: + """Convert q_idx to packed format for Pack-GQA.""" + if cutlass.const_expr(qhead_per_kvhead == 1): + return q_idx + return q_idx // qhead_per_kvhead + + @cute.jit def apply_score_mod_inner( score_tensor, @@ -329,6 +340,7 @@ def apply_score_mod_inner( buffers, fastdiv_mods, constant_q_idx: cutlass.Constexpr, + qhead_per_kvhead: cutlass.Constexpr[int] = 1, ): """Shared implementation for applying score modification. @@ -345,26 +357,42 @@ def apply_score_mod_inner( fastdiv_mods: Tuple of (seqlen_q_divmod, seqlen_k_divmod) for wrapping constant_q_idx: If provided, use this constant for all q_idx values If None, compute q_idx per-element + qhead_per_kvhead_packgqa: Pack-GQA replication factor. Divide q_idx by this + when greater than 1 so score mods see logical heads. """ n_vals = cutlass.const_expr(cute.size(score_tensor.shape)) score_vec = cute.make_fragment(vec_size, qk_acc_dtype) kv_idx_vec = cute.make_fragment(vec_size, cutlass.Int32) - # SSA values for batch and head (constant across all elements) + # SSA values for batch (constant across all elements) batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32).broadcast_to((vec_size,)) - head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32).broadcast_to((vec_size,)) # Handle q_idx based on whether it's constant q_idx_vec = cute.make_fragment(vec_size, cutlass.Int32) + + # For Pack-GQA with non-constant q_idx, we need per-element head indices + # since a thread my process multiple query head indices + if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None): + head_idx_vec = cute.make_fragment(vec_size, cutlass.Int32) + for i in cutlass.range(0, n_vals, vec_size, unroll_full=True): for j in cutlass.range(vec_size, unroll_full=True): score_vec[j] = score_tensor[i + j] * softmax_scale + # Extract head offset from packed q_idx for Pack-GQA + if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None): + q_idx_packed = index_tensor[i + j][0] + # Building up the logical q_head idx: final_q_head = kv_head * qhead_per_kvhead + (q_physical % qhead_per_kvhead) + q_idx_logical = q_idx_packed // qhead_per_kvhead + head_offset = q_idx_packed - q_idx_logical * qhead_per_kvhead + head_idx_vec[j] = head_idx * qhead_per_kvhead + head_offset + # If we will do loads we mod, in order to not read OOB if cutlass.const_expr(buffers is not None and fastdiv_mods is not None): if cutlass.const_expr(constant_q_idx is None): seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods - _, q_idx_wrapped = seqlen_q_divmod.divmod(index_tensor[i + j][0]) + q_idx_floored = floor_if_packed(index_tensor[i + j][0], qhead_per_kvhead) + _, q_idx_wrapped = seqlen_q_divmod.divmod(q_idx_floored) q_idx_vec[j] = q_idx_wrapped else: _, seqlen_k_divmod = fastdiv_mods @@ -374,7 +402,7 @@ def apply_score_mod_inner( else: # No bounds checking - direct indexing if constant_q_idx is None: - q_idx_vec[j] = index_tensor[i + j][0] + q_idx_vec[j] = floor_if_packed(index_tensor[i + j][0], qhead_per_kvhead) kv_idx_vec[j] = index_tensor[i + j][1] # Convert to SSA for score_mod call @@ -383,7 +411,15 @@ def apply_score_mod_inner( if cutlass.const_expr(constant_q_idx is None): q_idx_ssa = q_idx_vec.load() else: - q_idx_ssa = utils.scalar_to_ssa(constant_q_idx, cutlass.Int32).broadcast_to((vec_size,)) + # NB we do not apply Pack-GQA division here, as constant_q_idx is assumed to already be logical + q_idx_const = constant_q_idx + q_idx_ssa = utils.scalar_to_ssa(q_idx_const, cutlass.Int32).broadcast_to((vec_size,)) + + # Compute head_idx_ssa: per-element for Pack-GQA with non-constant q_idx, constant otherwise + if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None): + head_idx_ssa = head_idx_vec.load() + else: + head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32).broadcast_to((vec_size,)) buffer_args = [] if cutlass.const_expr(buffers is not None): diff --git a/tests/cute/test_score_mod.py b/tests/cute/test_score_mod.py index 014d7969184..0d8b2234467 100644 --- a/tests/cute/test_score_mod.py +++ b/tests/cute/test_score_mod.py @@ -248,7 +248,7 @@ def create_tensors( return q, k, v -def run_cute_flash(q, k, v, cute_score_mod, buffers=None) -> torch.Tensor: +def run_cute_flash(q, k, v, cute_score_mod, buffers=None, pack_gqa=False) -> torch.Tensor: q_transposed, k_transposed, v_transposed = map( lambda x: x.transpose(1, 2), (q, k, v) ) @@ -262,6 +262,7 @@ def run_cute_flash(q, k, v, cute_score_mod, buffers=None) -> torch.Tensor: out=out, lse=None, buffers=buffers, + pack_gqa=pack_gqa, ) return out.transpose(1, 2) @@ -297,21 +298,26 @@ def run_flex_reference(q, k, v, eager_score_mod, dtype=None) -> torch.Tensor: (4224, 4224), ], ) -@pytest.mark.parametrize("num_heads", [1, 4]) +@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 2), (4, 2)]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("score_mod_pair", TEST_PAIRS) -def test_cute_vs_flex_attention(seqlen_q, seqlen_kv, num_heads, dtype, score_mod_pair): +def test_cute_vs_flex_attention(seqlen_q, seqlen_kv, qhead_per_kvhead, num_kv_heads, dtype, score_mod_pair): torch.random.manual_seed(42) cute_score_mod, eager_score_mod = score_mod_pair + num_q_heads = num_kv_heads * qhead_per_kvhead + pack_gqa = qhead_per_kvhead > 1 q, k, v = create_tensors( - seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=num_heads, dtype=dtype + seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=num_q_heads, dtype=dtype ) + if pack_gqa: + k = k[:, :num_kv_heads, :, :].clone() + v = v[:, :num_kv_heads, :, :].clone() out_ref_fp32 = run_flex_reference(q, k, v, eager_score_mod, dtype=torch.float32) out_pt = run_flex_reference(q, k, v, eager_score_mod) - out_cute = run_cute_flash(q, k, v, cute_score_mod) + out_cute = run_cute_flash(q, k, v, cute_score_mod, pack_gqa=pack_gqa) # Basic shape and NaN checks assert out_cute.shape == out_ref_fp32.shape == out_pt.shape @@ -367,23 +373,28 @@ def test_cute_vs_flex_attention(seqlen_q, seqlen_kv, num_heads, dtype, score_mod (4224, 4224), ], ) -@pytest.mark.parametrize("num_heads", [1, 4]) +@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("score_mod_pair", TEST_PAIRS_WITH_BUFFERS) def test_cute_vs_flex_attention_with_buffers( - seqlen_q, seqlen_kv, num_heads, dtype, score_mod_pair + seqlen_q, seqlen_kv, qhead_per_kvhead, num_kv_heads, dtype, score_mod_pair ): torch.random.manual_seed(42) cute_score_mod, eager_score_mod_factory = score_mod_pair batch_size = 2 + num_q_heads = num_kv_heads * qhead_per_kvhead + pack_gqa = qhead_per_kvhead > 1 q, k, v = create_tensors( batch_size=batch_size, seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, - num_heads=num_heads, + num_heads=num_q_heads, dtype=dtype, ) + if pack_gqa: + k = k[:, :num_kv_heads, :, :].clone() + v = v[:, :num_kv_heads, :, :].clone() if cute_score_mod == score_mod_10: buffer = torch.randn(batch_size, device="cuda", dtype=dtype) * 0.1 @@ -391,17 +402,17 @@ def test_cute_vs_flex_attention_with_buffers( eager_score_mod = eager_score_mod_factory(buffer) assert buffer.shape == (batch_size,) elif cute_score_mod == score_mod_11: - head_bias = torch.randn(num_heads, device="cuda", dtype=dtype) * 0.2 + head_bias = torch.randn(num_q_heads, device="cuda", dtype=dtype) * 0.2 pos_scale = torch.arange(seqlen_q, device="cuda", dtype=dtype) * 0.01 buffers = [head_bias, pos_scale] eager_score_mod = eager_score_mod_factory(head_bias, pos_scale) - assert head_bias.shape == (num_heads,) + assert head_bias.shape == (num_q_heads,) assert pos_scale.shape == (seqlen_q,) out_ref_fp32 = run_flex_reference(q, k, v, eager_score_mod, dtype=torch.float32) out_pt = run_flex_reference(q, k, v, eager_score_mod) - out_cute = run_cute_flash(q, k, v, cute_score_mod, buffers=buffers) + out_cute = run_cute_flash(q, k, v, cute_score_mod, buffers=buffers, pack_gqa=pack_gqa) # Basic shape and NaN checks assert out_cute.shape == out_ref_fp32.shape == out_pt.shape @@ -432,57 +443,6 @@ def test_cute_vs_flex_attention_with_buffers( ) -@pytest.mark.xfail(raises=NotImplementedError, reason="PackGQA with score_mod not yet supported") -def test_packgqa_with_score_mod(): - """Test that PackGQA works correctly with score_mod index wrapping. - - Without proper index wrapping, q_idx will be in packed space - (0 to qhead_per_kvhead * seqlen_q - 1) instead of logical space (0 to seqlen_q - 1). - This causes causal masking to be incorrect. - """ - torch.random.manual_seed(42) - - batch_size = 2 - seqlen_q = 128 - seqlen_kv = 128 - qhead_per_kvhead = 4 - num_heads_kv = 2 - num_heads = num_heads_kv * qhead_per_kvhead - dtype = torch.bfloat16 - - q = torch.randn(batch_size, num_heads, seqlen_q, 128, device="cuda", dtype=dtype) - k = torch.randn(batch_size, num_heads_kv, seqlen_kv, 128, device="cuda", dtype=dtype) - v = torch.randn(batch_size, num_heads_kv, seqlen_kv, 128, device="cuda", dtype=dtype) - - q_transposed, k_transposed, v_transposed = map( - lambda x: x.transpose(1, 2), (q, k, v) - ) - out_cute = torch.empty_like(q_transposed) - - _flash_attn_fwd( - q_transposed, - k_transposed, - v_transposed, - return_lse=True, - score_mod=score_mod_2, - out=out_cute, - lse=None, - pack_gqa=True, - ) - out_cute = out_cute.transpose(1, 2) - - out_ref_fp32 = run_flex_reference(q, k, v, causal_mask_eager, dtype=torch.float32) - - fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() - cute_error = (out_cute - out_ref_fp32).abs().max().item() - - assert not torch.isnan(out_cute).any(), "Output contains NaN values" - assert torch.isfinite(out_cute).all(), "Output contains infinite values" - assert cute_error <= fwd_atol * 10, ( - f"CuTE error {cute_error:.2e} exceeds tolerance {fwd_atol * 10:.2e}" - ) - - @pytest.mark.xfail(raises=NotImplementedError, reason="Varlen with score_mod not yet supported") def test_varlen_with_score_mod(): """Test that varlen (variable length sequences) works with score_mod. From 04adaf0e9028d4bec7073f69e4dfa3f6d3357189 Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Wed, 15 Oct 2025 12:24:52 -0700 Subject: [PATCH 142/258] Add precommit list and then uncomment in chunks (#1941) * create list to work through * include ampere --- .pre-commit-config.yaml | 29 +++++++++++++++++++++++++++-- flash_attn/cute/ampere_helpers.py | 17 +++++++++++------ 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5c63513faf8..0e60f835330 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,6 +4,31 @@ repos: hooks: - id: ruff-check args: [--fix, --exit-non-zero-on-fix] - files: ^flash_attn/cute/(flash_bwd_sm90|flash_bwd_preprocess|flash_bwd_postprocess|softmax)\.py$ + files: ^flash_attn/cute/.*\.py$ + exclude: &cute_exclude | + (?x)^flash_attn/cute/( + __init__| + blackwell_helpers| + block_info| + copy_utils| + cute_dsl_utils| + fast_math| + flash_bwd| + flash_fwd| + flash_fwd_combine| + flash_fwd_sm100| + hopper_helpers| + interface| + mask| + mma_sm100_desc| + named_barrier| + pack_gqa| + pipeline| + seqlen_info| + testing| + tile_scheduler| + utils + )\.py$ - id: ruff-format - files: ^flash_attn/cute/(flash_bwd_sm90|flash_bwd_preprocess|flash_bwd_postprocess|softmax)\.py$ + files: ^flash_attn/cute/.*\.py$ + exclude: *cute_exclude diff --git a/flash_attn/cute/ampere_helpers.py b/flash_attn/cute/ampere_helpers.py index 839f407f75c..e3072d8ce85 100644 --- a/flash_attn/cute/ampere_helpers.py +++ b/flash_attn/cute/ampere_helpers.py @@ -8,11 +8,14 @@ def get_smem_layout_atom(dtype: Type[cutlass.Numeric], k_dim: int) -> cute.ComposedLayout: dtype_byte = cutlass.const_expr(dtype.width // 8) bytes_per_row = cutlass.const_expr(k_dim * dtype_byte) - smem_k_block_size = cutlass.const_expr( - 128 - if bytes_per_row % 128 == 0 - else (64 if bytes_per_row % 64 == 0 else (32 if bytes_per_row % 32 == 0 else 16)) - ) // dtype_byte + smem_k_block_size = ( + cutlass.const_expr( + 128 + if bytes_per_row % 128 == 0 + else (64 if bytes_per_row % 64 == 0 else (32 if bytes_per_row % 32 == 0 else 16)) + ) + // dtype_byte + ) swizzle_bits = ( 4 if smem_k_block_size == 128 @@ -22,7 +25,9 @@ def get_smem_layout_atom(dtype: Type[cutlass.Numeric], k_dim: int) -> cute.Compo return cute.make_composed_layout( cute.make_swizzle(swizzle_bits, swizzle_base, swizzle_base), 0, - cute.make_ordered_layout((8 if cutlass.const_expr(k_dim % 32 == 0) else 16, smem_k_block_size), order=(1, 0)), + cute.make_ordered_layout( + (8 if cutlass.const_expr(k_dim % 32 == 0) else 16, smem_k_block_size), order=(1, 0) + ), ) From 48ecd149c030dd250e1334bf59d5fe1591af9432 Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Fri, 17 Oct 2025 19:06:07 -0700 Subject: [PATCH 143/258] [ROCm] prepare CK sources for pytorch hipify v2 APIs (#1944) See https://github.com/pytorch/pytorch/pull/151845. pytorch has removed caffe2, but hipify still contained work-arounds for caffe2 vs torch compatibility. As a result of hipify v2 changes, some torch APIs are changing. --- csrc/flash_attn_ck/mha_bwd.cpp | 6 +++++- csrc/flash_attn_ck/mha_fwd.cpp | 4 ++++ csrc/flash_attn_ck/mha_varlen_fwd.cpp | 4 ++++ setup.py | 22 ++++++++++++++++++++-- 4 files changed, 33 insertions(+), 3 deletions(-) diff --git a/csrc/flash_attn_ck/mha_bwd.cpp b/csrc/flash_attn_ck/mha_bwd.cpp index 1f016a4a4e6..bb879453680 100644 --- a/csrc/flash_attn_ck/mha_bwd.cpp +++ b/csrc/flash_attn_ck/mha_bwd.cpp @@ -220,7 +220,11 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num if (is_causal) { window_size_right = 0; } bool is_dropout = p_dropout > 0.0; +#ifdef HIPIFY_V2 + auto stream = at::cuda::getCurrentCUDAStream().stream(); +#else auto stream = at::cuda::getCurrentHIPStream().stream(); +#endif auto q_dtype = q.dtype(); TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, @@ -399,4 +403,4 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num } return { dq, dk, dv, softmax_d }; -} \ No newline at end of file +} diff --git a/csrc/flash_attn_ck/mha_fwd.cpp b/csrc/flash_attn_ck/mha_fwd.cpp index 68e28355189..4d7d5bd655e 100644 --- a/csrc/flash_attn_ck/mha_fwd.cpp +++ b/csrc/flash_attn_ck/mha_fwd.cpp @@ -272,7 +272,11 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num if (seqlen_k > 0) { auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1); +#ifdef HIPIFY_V2 + auto stream = at::cuda::getCurrentCUDAStream().stream(); +#else auto stream = at::cuda::getCurrentHIPStream().stream(); +#endif ck_tile::stream_config stream_config{stream}; auto traits = diff --git a/csrc/flash_attn_ck/mha_varlen_fwd.cpp b/csrc/flash_attn_ck/mha_varlen_fwd.cpp index 3e4422efecd..07cfa9a8f90 100644 --- a/csrc/flash_attn_ck/mha_varlen_fwd.cpp +++ b/csrc/flash_attn_ck/mha_varlen_fwd.cpp @@ -469,7 +469,11 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si } if (max_seqlen_k > 0) { +#ifdef HIPIFY_V2 + auto stream = at::cuda::getCurrentCUDAStream().stream(); +#else auto stream = at::cuda::getCurrentHIPStream().stream(); +#endif ck_tile::stream_config stream_config{stream}; if (paged_KV) diff --git a/setup.py b/setup.py index 9a406839e7f..f0b476255ba 100644 --- a/setup.py +++ b/setup.py @@ -173,6 +173,18 @@ def check_if_rocm_home_none(global_option: str) -> None: ) +def detect_hipify_v2(): + try: + from torch.utils.hipify import __version__ + from packaging.version import Version + if Version(__version__) >= Version("2.0.0"): + return True + except Exception as e: + print("failed to detect pytorch hipify version, defaulting to version 1.0.0 behavior") + print(e) + return False + + def append_nvcc_threads(nvcc_extra_args): nvcc_threads = os.getenv("NVCC_THREADS") or "4" return nvcc_extra_args + ["--threads", nvcc_threads] @@ -408,6 +420,12 @@ def validate_and_update_archs(archs): f"build/fmha_*wd*.cpp" ) + # Check if torch is using hipify v2. Until CK is updated with HIPIFY_V2 macro, + # we must replace the incorrect APIs. + maybe_hipify_v2_flag = [] + if detect_hipify_v2(): + maybe_hipify_v2_flag = ["-DHIPIFY_V2"] + rename_cpp_to_cu(sources) renamed_sources = ["csrc/flash_attn_ck/flash_api.cu", @@ -450,8 +468,8 @@ def validate_and_update_archs(archs): cc_flag += ["-mllvm", "-amdgpu-coerce-illegal-types=1"] extra_compile_args = { - "cxx": ["-O3", "-std=c++17"] + generator_flag, - "nvcc": cc_flag + generator_flag, + "cxx": ["-O3", "-std=c++17"] + generator_flag + maybe_hipify_v2_flag, + "nvcc": cc_flag + generator_flag + maybe_hipify_v2_flag, } include_dirs = [ From cc843a2b9e685daf20a0394fd626921b4d329b95 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 18 Oct 2025 16:04:41 -0400 Subject: [PATCH 144/258] [Cute] Add flake8 config file --- flash_attn/cute/.flake8 | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 flash_attn/cute/.flake8 diff --git a/flash_attn/cute/.flake8 b/flash_attn/cute/.flake8 new file mode 100644 index 00000000000..bae5b85c002 --- /dev/null +++ b/flash_attn/cute/.flake8 @@ -0,0 +1,4 @@ +[flake8] +max-line-length = 100 +# W503: line break before binary operator +ignore = E731, E741, F841, W503 From c712d43ace03de4ca4cf60a16b4528373e33b358 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 18 Oct 2025 17:18:40 -0400 Subject: [PATCH 145/258] [Cute,Fwd,Sm90] Load Q & K using the same mbarrier --- flash_attn/cute/flash_bwd_sm90.py | 4 +- flash_attn/cute/flash_fwd.py | 70 ++++++++++++++++++------- flash_attn/cute/hopper_helpers.py | 2 - flash_attn/cute/pipeline.py | 86 +++++++++++++++++++++++++------ 4 files changed, 123 insertions(+), 39 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index 3d2ae593160..2ef8df777d8 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -498,7 +498,7 @@ def kernel( pipeline_consumer_group = cutlass.pipeline.CooperativeGroup( cutlass.pipeline.Agent.Thread, self.num_mma_threads // self.num_threads_per_warp_group ) - pipeline_Q = pipeline.PipelineTmaAsyncNoCluster.create( + pipeline_Q = pipeline.PipelineTmaAsync.create( barrier_storage=storage.mbar_ptr_Q.data_ptr(), num_stages=self.Q_stage, producer_group=pipeline_producer_group, @@ -506,7 +506,7 @@ def kernel( tx_count=self.tma_copy_bytes["Q"] + self.tma_copy_bytes["LSE"], init_wait=False, ) - pipeline_dO = pipeline.PipelineTmaAsyncNoCluster.create( + pipeline_dO = pipeline.PipelineTmaAsync.create( barrier_storage=storage.mbar_ptr_dO.data_ptr(), num_stages=self.dO_stage, producer_group=pipeline_producer_group, diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 75232662d0d..e19656664d3 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -1184,9 +1184,14 @@ def __call__( gmem_tiled_copy_Q = cpasync.CopyBulkTensorTileG2SOp() gmem_tiled_copy_KV = cpasync.CopyBulkTensorTileG2SOp() # Might multicast gmem_tiled_copy_O = cpasync.CopyBulkTensorTileS2GOp() - self.tma_copy_q_bytes = cute.size_in_bytes(mQ.element_type, cute.select(self.sQ_layout, mode=[0, 1])) - self.tma_copy_k_bytes = cute.size_in_bytes(mK.element_type, cute.select(self.sK_layout, mode=[0, 1])) - self.tma_copy_v_bytes = cute.size_in_bytes(mV.element_type, cute.select(self.sV_layout, mode=[0, 1])) + self.tma_copy_bytes = { + name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1])) + for name, mX, layout in [ + ("Q", mQ, self.sQ_layout), + ("K", mK, self.sK_layout), + ("V", mV, self.sV_layout), + ] + } tma_atom_Q, tma_tensor_Q = None, None if const_expr(self.use_tma_Q): tma_atom_Q, tma_tensor_Q = cpasync.make_tiled_tma_atom( @@ -1355,27 +1360,28 @@ def kernel( # if tidx < 2: # # barrierO num threads should be self.num_mma_threads # cute.arch.mbarrier_init(mbar_ptr_Q + tidx, 1 if tidx == 0 else self.num_mma_threads) - cute.arch.mbarrier_init(mbar_ptr_Q, 1 if const_expr(self.use_tma_Q) else self.num_Q_load_threads) + if const_expr(not self.use_tma_Q): + cute.arch.mbarrier_init(mbar_ptr_Q, self.num_Q_load_threads) # cute.arch.mbarrier_init(mbar_ptr_Q + 1, self.num_mma_threads) # We rely on pipeline_k and pipeline_v to initialize the mbarrier fence and sync pipeline_kv_producer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread) pipeline_kv_consumer_group = cutlass.pipeline.CooperativeGroup( cutlass.pipeline.Agent.Thread, self.num_mma_threads // self.num_threads_per_warp_group ) - pipeline_k = pipeline.PipelineTmaAsyncNoCluster.create( + pipeline_k = pipeline.PipelineTmaAsync.create( barrier_storage=storage.mbar_ptr_K.data_ptr(), num_stages=self.num_stages, producer_group=pipeline_kv_producer_group, consumer_group=pipeline_kv_consumer_group, - tx_count=self.tma_copy_k_bytes, + tx_count=self.tma_copy_bytes["K"], init_wait=False, ) - pipeline_v = pipeline.PipelineTmaAsyncNoCluster.create( + pipeline_v = pipeline.PipelineTmaAsync.create( barrier_storage=storage.mbar_ptr_V.data_ptr(), num_stages=self.num_stages, producer_group=pipeline_kv_producer_group, consumer_group=pipeline_kv_consumer_group, - tx_count=self.tma_copy_v_bytes, + tx_count=self.tma_copy_bytes["V"], ) # /////////////////////////////////////////////////////////////////////////////// @@ -1519,23 +1525,46 @@ def load( load_K = copy_utils.tma_producer_copy_fn(load_K, pipeline_k) load_V, _, _ = copy_utils.tma_get_copy_fn(tma_atom_V, 0, cute.make_layout(1), gV, sV) load_V = copy_utils.tma_producer_copy_fn(load_V, pipeline_v) - # load_Q - if const_expr(self.use_tma_Q): - # TODO: wait for Q to be empty - q_producer_phase ^= 1 - with cute.arch.elect_one(): - cute.arch.mbarrier_arrive_and_expect_tx(mbar_ptr_Q, self.tma_copy_q_bytes) - load_Q(tma_bar_ptr=mbar_ptr_Q) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) # if cute.arch.thread_idx()[0] == 0: # cute.printf("m_block = %d, n_block_min: %d, n_block_max: %d", m_block, n_block_min, n_block_max) - for i in cutlass.range(n_block_max - n_block_min, unroll=2): - n_block = n_block_max - i - 1 - pipeline_k.producer_acquire(kv_producer_state) - load_K(src_idx=n_block, producer_state=kv_producer_state) + # First iteration: load both Q & K with the same mbarrier + n_block = n_block_max - 1 + pipeline_k.producer_acquire( + kv_producer_state, + extra_tx_count=self.tma_copy_bytes["Q"] if const_expr(self.use_tma_Q) else 0 + ) + if const_expr(self.use_tma_Q): + load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) + load_K(src_idx=n_block, producer_state=kv_producer_state) + + if const_expr(not self.intra_wg_overlap): pipeline_v.producer_acquire(kv_producer_state) load_V(src_idx=n_block, producer_state=kv_producer_state) kv_producer_state.advance() + for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): + n_block = n_block_max - 1 - i - 1 + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block, producer_state=kv_producer_state) + kv_producer_state.advance() + else: + for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): + n_block_prev = n_block_max - i - 1 + n_block = n_block_prev - 1 + kv_producer_state_prev = kv_producer_state.clone() + kv_producer_state.advance() + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state_prev) + load_V(src_idx=n_block_prev, producer_state=kv_producer_state_prev) + n_block = n_block_min + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block, producer_state=kv_producer_state) + kv_producer_state.advance() + tile_scheduler.prefetch_next_work() tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() @@ -1666,7 +1695,8 @@ def mma( cute.arch.cp_async_mbarrier_arrive_noinc(mbar_ptr_Q) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) - cute.arch.mbarrier_wait(mbar_ptr_Q, phase=q_consumer_phase) + if const_expr(not self.use_tma_Q): + cute.arch.mbarrier_wait(mbar_ptr_Q, phase=q_consumer_phase) q_consumer_phase ^= 1 # For performance reason, we separate out two kinds of iterations: # those that need masking on S, and those that don't. diff --git a/flash_attn/cute/hopper_helpers.py b/flash_attn/cute/hopper_helpers.py index 1016a4189fe..c98f85b568e 100644 --- a/flash_attn/cute/hopper_helpers.py +++ b/flash_attn/cute/hopper_helpers.py @@ -100,5 +100,3 @@ def make_smem_layout( order=order if const_expr(stage is not None) else order[:2], ) return smem_layout_staged - - diff --git a/flash_attn/cute/pipeline.py b/flash_attn/cute/pipeline.py index b1f422068c4..89baa4a97be 100644 --- a/flash_attn/cute/pipeline.py +++ b/flash_attn/cute/pipeline.py @@ -8,8 +8,41 @@ import cutlass.cute as cute from cutlass import Boolean, Int32, const_expr from cutlass.cutlass_dsl import if_generate -from cutlass.pipeline import PipelineAsync, PipelineState, CooperativeGroup, pipeline_init_wait +from cutlass.pipeline import PipelineAsync, PipelineState, Agent, CooperativeGroup from cutlass.pipeline import PipelineUserType, PipelineOp +from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg + + +# We deviate from cute-dsl implementation to use cute.arch.cluster_arrive_relaxed +def pipeline_init_wait(cta_layout_vmnk: Optional[cute.Layout] = None): + """ + Fences the mbarrier init and syncs the threadblock or cluster + """ + cute.arch.mbarrier_init_fence() + + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: + # If not using clusters, sync the threadblock + _sync(Agent.ThreadBlock) + else: + # If using clusters, sync the cluster + _sync(Agent.ThreadBlockCluster) + + +def _sync(group: Agent): + """ + Syncs all threads within an agent. + """ + if group is Agent.Thread: + raise NotImplementedError("Error: Not supported.") + elif group is Agent.ThreadBlock: + cute.arch.sync_threads() + elif group is Agent.ThreadBlockCluster: + cute.arch.cluster_arrive_relaxed() + cute.arch.cluster_wait() + else: + assert ( + False + ), "Error: No explicit sync instruction exists. Please use barriers (named / mbarrier) instead." class PipelineStateSimple: @@ -89,7 +122,7 @@ def make_pipeline_state(type: PipelineUserType, stages: int): @dataclass(frozen=True) -class PipelineTmaAsyncNoCluster(PipelineAsync): +class PipelineTmaAsync(PipelineTmaAsyncOg): """ If size(ClusterShape) == 1, PipelineTmaAsync has all threads signaling the barrier during consumer_release. This causes a perf regression in FA3 @@ -103,12 +136,15 @@ class PipelineTmaAsyncNoCluster(PipelineAsync): @staticmethod def create( - barrier_storage: cute.Pointer, - num_stages: Int32, + *, + num_stages: int, producer_group: CooperativeGroup, consumer_group: CooperativeGroup, tx_count: int, - init_wait: cutlass.Constexpr[bool] = True, + barrier_storage: cute.Pointer = None, + cta_layout_vmnk: Optional[cute.Layout] = None, + tidx: Optional[Int32] = None, + init_wait: cutlass.Constexpr[bool] = True ): """ This helper function computes any necessary attributes and returns an instance of PipelineTmaAsync. @@ -116,33 +152,59 @@ def create( :type barrier_storage: cute.Pointer :param num_stages: Number of buffer stages for this pipeline :type num_stages: Int32 - :param producer_group: CooperativeGroup for the producer agent + :param producer_group: `CooperativeGroup` for the producer agent :type producer_group: CooperativeGroup - :param consumer_group: CooperativeGroup for the consumer agent + :param consumer_group: `CooperativeGroup` for the consumer agent :type consumer_group: CooperativeGroup :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage :type tx_count: int + :param cta_layout_vmnk: Layout of the cluster shape + :type cta_layout_vmnk: cute.Layout | None + :param tidx: thread index to consumer async threads + :type tidx: Int32 | None """ + if not isinstance(barrier_storage, cute.Pointer): + raise ValueError( + f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" + ) + producer_type = PipelineOp.TmaLoad consumer_type = PipelineOp.AsyncThread + producer = (producer_type, producer_group) consumer = (consumer_type, consumer_group) + sync_object_full = PipelineAsync._make_sync_object( barrier_storage.align(min_align=8), num_stages, producer, tx_count ) sync_object_empty = PipelineAsync._make_sync_object( barrier_storage.align(min_align=8) + num_stages, num_stages, consumer ) - dst_rank = None + if tidx is None: + tidx, _, _ = cute.arch.thread_idx() + if cta_layout_vmnk is None: + cta_layout_vmnk = cute.make_layout((1, 1, 1, 1)) + if const_expr(cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1): + dst_rank = None + is_signalling_thread = tidx % 128 == 0 + else: + ( + dst_rank, + is_signalling_thread, + ) = PipelineTmaAsync.init_empty_barrier_arrive_signal(cta_layout_vmnk, tidx) + producer_mask = None + if const_expr(init_wait): pipeline_init_wait() - return PipelineTmaAsyncNoCluster( + + return PipelineTmaAsync( sync_object_full, sync_object_empty, num_stages, producer_mask, dst_rank, + is_signalling_thread, ) def producer_acquire( @@ -164,12 +226,6 @@ def producer_acquire( tx_count = self.sync_object_full.tx_count + extra_tx_count self.sync_object_full.arrive_and_expect_tx(state.index, tx_count) - def producer_commit(self, state: PipelineState): - """ - TMA producer commit is a NOP. The transaction barrier signals the commit upon completion of the TMA. - """ - pass - def consumer_release(self, state: PipelineState): """ TMA consumer release conditionally signals the empty buffer to the producer. From 752c2639dc81352815b3117387f401413845eda6 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 18 Oct 2025 17:31:36 -0400 Subject: [PATCH 146/258] [Cute,Bwd,Sm90] Use the same producer states if Q_stage == dO_stage --- flash_attn/cute/flash_bwd_sm90.py | 103 ++++++++++++++---------------- flash_attn/cute/flash_fwd.py | 15 ++--- 2 files changed, 53 insertions(+), 65 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index 2ef8df777d8..ff80d454c30 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -8,6 +8,7 @@ import cutlass.cute as cute import cutlass.utils.hopper_helpers as sm90_utils_basic from cutlass.cute.nvgpu import cpasync, warpgroup +from cutlass.cute.arch import ProxyKind, SharedSpace from cutlass import Float32, Int32, Boolean, const_expr from cutlass.utils import LayoutEnum @@ -659,14 +660,12 @@ def load( warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 if warp_idx_in_wg == 0: - producer_state_Q = pipeline.make_pipeline_state( + producer_state_Q = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Producer, self.Q_stage ) - producer_state_dO = producer_state_Q - if const_expr(self.dO_stage != self.Q_stage): - producer_state_dO = pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Producer, self.dO_stage - ) + producer_state_dO = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.dO_stage + ) tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: @@ -716,16 +715,20 @@ def load( # cp.async.bulk is using ptx, so we need to elect one thread to do it with cute.arch.elect_one(): load_LSE(m_block, producer_state=producer_state_Q) + producer_state_dO_cur = ( + producer_state_dO + if const_expr(self.Q_stage != self.dO_stage) + else producer_state_Q + ) pipeline_dO.producer_acquire( - producer_state_dO, extra_tx_count=self.tma_copy_bytes["V"] + producer_state_dO_cur, extra_tx_count=self.tma_copy_bytes["V"] ) - load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO)) - load_dO(m_block, producer_state=producer_state_dO) + load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_cur)) + load_dO(m_block, producer_state=producer_state_dO_cur) with cute.arch.elect_one(): - load_dPsum(m_block, producer_state=producer_state_dO) + load_dPsum(m_block, producer_state=producer_state_dO_cur) producer_state_Q.advance() - if const_expr(self.Q_stage != self.dO_stage): - producer_state_dO.advance() + producer_state_dO.advance() # Subsequent iterations: load Q & LSE, then dO & dPsum for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1): pipeline_Q.producer_acquire(producer_state_Q) @@ -733,15 +736,17 @@ def load( # cp.async.bulk is using ptx, so we need to elect one thread to do it with cute.arch.elect_one(): load_LSE(m_block, producer_state=producer_state_Q) - if const_expr(self.Q_stage == self.dO_stage): - producer_state_dO = producer_state_Q - pipeline_dO.producer_acquire(producer_state_dO) - load_dO(m_block, producer_state=producer_state_dO) + producer_state_dO_cur = ( + producer_state_dO + if const_expr(self.Q_stage != self.dO_stage) + else producer_state_Q + ) + pipeline_dO.producer_acquire(producer_state_dO_cur) + load_dO(m_block, producer_state=producer_state_dO_cur) with cute.arch.elect_one(): - load_dPsum(m_block, producer_state=producer_state_dO) + load_dPsum(m_block, producer_state=producer_state_dO_cur) producer_state_Q.advance() - if const_expr(self.dO_stage != self.Q_stage): - producer_state_dO.advance() + producer_state_dO.advance() tile_scheduler.prefetch_next_work() tile_scheduler.advance_to_next_work() @@ -909,14 +914,12 @@ def mma( # acc_dK=acc_dK, ) - consumer_state_Q = pipeline.make_pipeline_state( + consumer_state_Q = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.Q_stage ) - consumer_state_dO = consumer_state_Q - if const_expr(self.dO_stage != self.Q_stage): - consumer_state_dO = pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage - ) + consumer_state_dO = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage + ) tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: @@ -972,8 +975,8 @@ def mma( def mma_one_m_block( self, m_block: Int32, - smem_pipe_read_Q: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, - smem_pipe_read_dO: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, + consumer_state_Q: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, + consumer_state_dO: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, warp_group_idx: Int32, mma_qk_fn: Callable, mma_dov_fn: Callable, @@ -995,18 +998,21 @@ def mma_one_m_block( # acc_dK, dKV_accumulate: Boolean = True, ): - smem_idx_Q = smem_pipe_read_Q.index - smem_idx_dO = smem_pipe_read_dO.index + consumer_state_dO_cur = ( + consumer_state_dO if const_expr(self.Q_stage == self.dO_stage) else consumer_state_Q + ) + smem_idx_Q = consumer_state_Q.index + smem_idx_dO = consumer_state_dO_cur.index if const_expr(self.dO_stage > 1) else 0 smem_idx_PdS = smem_idx_Q if const_expr(self.PdS_stage > 1) else 0 # (1) [GEMM 1] S = Q @ K^T - pipeline_Q.consumer_wait(smem_pipe_read_Q, pipeline_Q.consumer_try_wait(smem_pipe_read_Q)) + pipeline_Q.consumer_wait(consumer_state_Q, pipeline_Q.consumer_try_wait(consumer_state_Q)) acc_S = mma_qk_fn(A_idx=smem_idx_Q, wg_wait=-1) # S2R for LSE tLSErLSE = cute.make_fragment_like(tLSEsLSE[None, 0]) cute.autovec_copy(tLSEsLSE[None, smem_idx_Q], tLSErLSE) # (2) [GEMM 2] dP = dO @ V.T pipeline_dO.consumer_wait( - smem_pipe_read_dO, pipeline_dO.consumer_try_wait(smem_pipe_read_dO) + consumer_state_dO_cur, pipeline_dO.consumer_try_wait(consumer_state_dO_cur) ) acc_dP = mma_dov_fn(A_idx=smem_idx_Q, wg_wait=1) # (3) [Pointwise 1] P = exp(S - LSE) @@ -1063,9 +1069,7 @@ def mma_one_m_block( # This sync is to ensure (1) P is written in case of !mma_dkv_is_rs and # (2) dS is already read by the Mma in the previous iteration in case of mma_dkv_is_rs. if const_expr(not self.mma_dkv_is_rs or (self.PdS_stage == 1 and self.mma_dkv_is_rs)): - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta - ) + cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) cute.arch.barrier( barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads ) @@ -1083,16 +1087,14 @@ def mma_one_m_block( mma_pdo_fn(tCrA=tdVrP, B_idx=smem_idx_dO, zero_init=not dKV_accumulate, wg_wait=-1) # smem fence to make sure sdS is written before it's read by WGMMA - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta - ) + cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) cute.arch.barrier( barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads ) # (6) [GEMM 4] dQ = dS @ K acc_dQ = mma_dsk_fn(A_idx=smem_idx_PdS, wg_wait=1) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dV) - pipeline_dO.consumer_release(smem_pipe_read_dO) # release dO as dV mma is done + pipeline_dO.consumer_release(consumer_state_dO_cur) # release dO as dV mma is done # (7) [GEMM 5] dK += dS.T @ Q if const_expr(not self.mma_dkv_is_rs): @@ -1108,10 +1110,8 @@ def mma_one_m_block( number_of_threads=self.num_mma_threads + cute.arch.WARP_SIZE, ) tdQrdQaccum_flat = cute.make_tensor(acc_dQ.iterator, cute.make_layout(tdQsdQaccum.shape)) - cute.copy(smem_thr_copy_dQaccum, tdQrdQaccum_flat, tdQsdQaccum) - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta - ) + cute.autovec_copy(tdQrdQaccum_flat, tdQsdQaccum) + cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) cute.arch.barrier_arrive( barrier_id=int(NamedBarrierBwd.dQFull), number_of_threads=self.num_mma_threads + cute.arch.WARP_SIZE, @@ -1119,15 +1119,12 @@ def mma_one_m_block( warpgroup.wait_group(0) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dK) - pipeline_Q.consumer_release(smem_pipe_read_Q) + pipeline_Q.consumer_release(consumer_state_Q) # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("tidx = {}, m_block = {}, after pipeline_Q consumer release", cute.arch.thread_idx()[0], m_block) - smem_pipe_read_Q.advance() - if const_expr(self.Q_stage != self.dO_stage): - smem_pipe_read_dO.advance() - else: - smem_pipe_read_dO = smem_pipe_read_Q - return smem_pipe_read_Q, smem_pipe_read_dO + consumer_state_Q.advance() + consumer_state_dO.advance() + return consumer_state_Q, consumer_state_dO @cute.jit def epilogue_dKV( @@ -1182,9 +1179,7 @@ def epilogue_dKV( taccdVsdV = smem_thr_copy_dV.partition_D(sdV) cute.copy(smem_copy_atom_dKV, taccdVrdV, taccdVsdV) # ensure smem writes are visible to TMA - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta - ) + cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) cute.arch.barrier( barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads ) @@ -1195,9 +1190,7 @@ def epilogue_dKV( taccdKsdK = smem_thr_copy_dK.partition_D(sdK) # reuse sK SMEM cute.copy(smem_copy_atom_dKV, taccdKrdK, taccdKsdK) # ensure smem writes are visible to TMA - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta - ) + cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) cute.arch.barrier( barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads ) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index e19656664d3..92382ae8b42 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -16,6 +16,7 @@ import cutlass.cute as cute from cutlass import Float32, Int32, Boolean, const_expr from cutlass.cute.nvgpu import cpasync, warp, warpgroup +from cutlass.cute.arch import ProxyKind, SharedSpace import cutlass.utils as utils_basic from cutlass.utils import LayoutEnum import cutlass.utils.hopper_helpers as sm90_utils_basic @@ -347,7 +348,7 @@ def epilogue( # sync to make sure all smem stores are done if const_expr(self.use_tma_O): # ensure smem writes are visible to TMA - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) cute.arch.barrier_arrive(barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE) gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0)) store_O, _, _ = copy_utils.tma_get_copy_fn( @@ -1723,9 +1724,7 @@ def mma( tPrP = smem_thr_copy_P.retile(tOrP_cur) cute.copy(smem_thr_copy_P, tPrP, tPsP) # Fence and barrier to make sure smem store is visible to WGMMA - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta - ) + cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV # Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter # acc_O.fill(0.0) @@ -1860,9 +1859,7 @@ def mma_one_n_block( softmax.rescale_O(acc_O, row_scale) if const_expr(not self.mma_pv_is_rs): # Fence and barrier to make sure smem store is visible to WGMMA - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta - ) + cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read)) self.warp_scheduler_barrier_sync() @@ -1924,9 +1921,7 @@ def mma_one_n_block_intrawg_overlap( softmax.rescale_O(acc_O, row_scale) if const_expr(not self.mma_pv_is_rs): # Fence and barrier to make sure smem store is visible to WGMMA - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta - ) + cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV return smem_pipe_read From 71ec343aa986084cdc780c3fe8c2497e55acb6de Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 18 Oct 2025 18:59:35 -0400 Subject: [PATCH 147/258] [Cute,Bwd,Sm90] Split sdQaccum layout into 2 warp groups --- flash_attn/cute/copy_utils.py | 9 +- flash_attn/cute/flash_bwd_postprocess.py | 64 +++++++------- flash_attn/cute/flash_bwd_sm90.py | 105 ++++++++++++----------- flash_attn/cute/named_barrier.py | 9 +- flash_attn/cute/utils.py | 48 ++++++++--- 5 files changed, 137 insertions(+), 98 deletions(-) diff --git a/flash_attn/cute/copy_utils.py b/flash_attn/cute/copy_utils.py index 84b3f4e2956..25263f2bd1f 100644 --- a/flash_attn/cute/copy_utils.py +++ b/flash_attn/cute/copy_utils.py @@ -26,12 +26,19 @@ def cvt_copy( ) -> None: assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem if const_expr(src.element_type != dst.element_type): - src_cvt = cute.make_fragment_like(src, dst.element_type) + src_cvt = cute.make_fragment_like(src, dst.element_type, loc=loc, ip=ip) src_cvt.store(src.load().to(dst.element_type)) src = src_cvt cute.copy(atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) +@dsl_user_op +def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: + dst = cute.make_fragment_like(src, src.element_type, loc=loc, ip=ip) + cute.autovec_copy(src, dst, loc=loc, ip=ip) + return dst + + @dsl_user_op def get_copy_atom( dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index 22b227227b0..9be406b19bb 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -127,9 +127,18 @@ def _setup_attributes(self): cute.make_layout(async_copy_elems_accum), ) num_s2r_copy_elems = 1 if const_expr(self.arch == 80) else 4 - self.s2r_tiled_copy_dQaccum = copy_utils.tiled_copy_1d( - Float32, self.num_threads, num_s2r_copy_elems - ) + if const_expr(self.arch == 80): + self.s2r_tiled_copy_dQaccum = copy_utils.tiled_copy_1d( + Float32, self.num_threads, num_s2r_copy_elems + ) + else: + num_threads_per_warp_group = 128 + num_mma_warp_groups = self.num_threads // 128 + self.s2r_tiled_copy_dQaccum = cute.make_tiled_copy_tv( + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128), + cute.make_layout((num_threads_per_warp_group, num_mma_warp_groups)), # thr_layout + cute.make_layout(128 // Float32.width), # val_layout + ) self.gmem_tiled_copy_dQ = copy_utils.tiled_copy_2d( self.dtype, self.tile_hdim, self.num_threads @@ -137,7 +146,13 @@ def _setup_attributes(self): # /////////////////////////////////////////////////////////////////////////////// # Shared memory layout: dQaccum / dQ # /////////////////////////////////////////////////////////////////////////////// - self.sdQaccum_layout = cute.make_layout(self.tile_m * self.tile_hdim) + if const_expr(self.arch == 80): + self.sdQaccum_layout = cute.make_layout(self.tile_m * self.tile_hdim) + else: + num_mma_warp_groups = self.num_threads // 128 + self.sdQaccum_layout = cute.make_layout( + (self.tile_m * self.tile_hdim // num_mma_warp_groups, num_mma_warp_groups) + ) # We can't just use kHeadDim here. E.g. if MMA shape is 64 x 96 but split across 2 WGs, # then setting kBlockKSmem to 32 will cause "Static shape_div failure". # We want to treat it as 64 x 48, so kBlockKSmem should be 16. @@ -253,6 +268,15 @@ def kernel( tile_sched_params: ParamsBase, TileScheduler: cutlass.Constexpr[Callable], ): + # /////////////////////////////////////////////////////////////////////////////// + # Get shared memory buffer + # /////////////////////////////////////////////////////////////////////////////// + smem = cutlass.utils.SmemAllocator() + sdQaccum = smem.allocate_tensor(cutlass.Float32, sdQaccum_layout, byte_alignment=1024) + sdQaccum_flat = cute.make_tensor(sdQaccum.iterator, cute.make_layout(cute.size(sdQaccum))) + sdQ = cute.make_tensor(cute.recast_ptr(sdQaccum.iterator, dtype=self.dtype), sdQ_layout) + sdQt = utils.transpose_view(sdQ) + # Thread index, block index tidx, _, _ = cute.arch.thread_idx() @@ -299,27 +323,16 @@ def kernel( ) mdQaccum_cur = cute.make_tensor(mdQaccum_cur_ptr, mdQaccum_cur.layout) - dQaccum_shape = (self.tile_m * self.tile_hdim,) - gdQaccum = cute.local_tile(mdQaccum_cur, dQaccum_shape, (m_block,)) + gdQaccum = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (m_block,)) gdQ = cute.local_tile(mdQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0)) - # /////////////////////////////////////////////////////////////////////////////// - # Get shared memory buffer - # /////////////////////////////////////////////////////////////////////////////// - smem = cutlass.utils.SmemAllocator() - sdQaccum = smem.allocate_tensor(cutlass.Float32, sdQaccum_layout, byte_alignment=1024) - sdQ = cute.make_tensor(cute.recast_ptr(sdQaccum.iterator, dtype=self.dtype), sdQ_layout) - sdQt = utils.transpose_view(sdQ) - seqlen_q = seqlen.seqlen_q seqlen_q_rounded = cute.round_up(seqlen_q, self.tile_m) # Step 1: load dQaccum from gmem to smem g2s_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_slice(tidx) tdQgdQaccum = g2s_thr_copy_dQaccum.partition_S(gdQaccum) - tdQsdQaccumg2s = g2s_thr_copy_dQaccum.partition_D(sdQaccum) - # print(tdQgdQaccum) - # print(tdQsdQaccum) + tdQsdQaccumg2s = g2s_thr_copy_dQaccum.partition_D(sdQaccum_flat) cute.copy(g2s_tiled_copy_dQaccum, tdQgdQaccum, tdQsdQaccumg2s) cute.arch.cp_async_commit_group() cute.arch.cp_async_wait_group(0) @@ -328,25 +341,14 @@ def kernel( # Step 2: load dQ from smem to rmem s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_slice(tidx) tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum) - # print(s2r_tiled_copy_dQaccum) - # print(sdQaccum) - # thr_mma = tiled_mma.get_slice(tidx) - # print(tiled_mma) tile_shape = (self.tile_m, self.tile_hdim) acc_shape = tiled_mma.partition_shape_C( tile_shape if const_expr(not dQ_swapAB) else tile_shape[::-1] ) acc = cute.make_fragment(acc_shape, cutlass.Float32) assert cute.size(acc) == cute.size(tdQsdQaccum) - tdQrdQaccum = s2r_thr_copy_dQaccum.retile(acc) - # Somehow even after retiling the layouts of tdQsdQaccum and tdQrdQaccum are different. - # So we have to do a for loop to copy - # cute.copy(s2r_tiled_copy_dQaccum, tdQsdQaccum, tdQrdQaccum) - # print(acc) - # print(tdQsdQaccum) # ((1, 1), 64) - # print(tdQrdQaccum) # ((1, 4), 4, 4) - for i in cutlass.range(cute.size(tdQsdQaccum), unroll_full=True): - tdQrdQaccum[i] = tdQsdQaccum[i] + tdQrdQaccum = cute.make_tensor(acc.iterator, cute.make_layout(tdQsdQaccum.shape)) + cute.autovec_copy(tdQsdQaccum, tdQrdQaccum) # Convert tdQrdQaccum from fp32 to fp16/bf16 rdQ = cute.make_fragment_like(acc, self.dtype) rdQ.store((acc.load() * scale).to(self.dtype)) @@ -362,8 +364,6 @@ def kernel( sdQ if const_expr(not self.dQ_swapAB) else sdQt ) cute.copy(smem_copy_atom_dQ, taccdQrdQ, taccdQsdQ) - # print(taccdQrdQ) - # print(taccdQsdQ) # Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_slice(tidx) diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index ff80d454c30..9c8928a5b07 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -174,10 +174,15 @@ def _setup_attributes(self): ((self.tile_m, self.tile_n), self.PdS_stage), ] ] - self.sdQaccum_layout = cute.make_layout(self.tile_m * self.tile_hdim) + self.sdQaccum_layout = cute.make_layout( + (self.tile_m * self.tile_hdim // self.num_mma_warp_groups, self.num_mma_warp_groups) + ) # dQaccum R->S - self.r2s_tiled_copy_dQaccum = copy_utils.tiled_copy_1d( - Float32, self.num_mma_threads, num_copy_elems=128 // Float32.width + self.r2s_tiled_copy_dQaccum = cute.make_tiled_copy_tv( + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128), + # thr_layout + cute.make_layout((self.num_threads_per_warp_group, self.num_mma_warp_groups)), + cute.make_layout(128 // Float32.width), # val_layout ) def _get_tiled_mma(self): @@ -346,6 +351,9 @@ def __call__( } self.tma_copy_bytes["LSE"] = self.tile_m * Float32.width // 8 self.tma_copy_bytes["dPsum"] = self.tile_m * Float32.width // 8 + self.tma_copy_bytes["dQ"] = ( + self.tile_m * self.tile_hdim * Float32.width // 8 // self.num_mma_warp_groups + ) tma_atom_Q, tma_tensor_Q = cpasync.make_tiled_tma_atom( cpasync.CopyBulkTensorTileG2SOp(), @@ -592,10 +600,11 @@ def kernel( TileSchedulerCls, ) if warp_idx == 1: - cute.arch.barrier_arrive( - barrier_id=int(NamedBarrierBwd.dQEmpty), - number_of_threads=self.num_mma_threads + cute.arch.WARP_SIZE, - ) + for warp_group_idx in cutlass.range(self.num_mma_warp_groups): + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx, + number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, + ) self.dQaccum_store(mdQaccum, sdQaccum, block_info, TileSchedulerCls, SeqlenInfoCls) else: cute.arch.warpgroup_reg_alloc(self.num_mma_regs) @@ -1007,9 +1016,7 @@ def mma_one_m_block( # (1) [GEMM 1] S = Q @ K^T pipeline_Q.consumer_wait(consumer_state_Q, pipeline_Q.consumer_try_wait(consumer_state_Q)) acc_S = mma_qk_fn(A_idx=smem_idx_Q, wg_wait=-1) - # S2R for LSE - tLSErLSE = cute.make_fragment_like(tLSEsLSE[None, 0]) - cute.autovec_copy(tLSEsLSE[None, smem_idx_Q], tLSErLSE) + tLSErLSE = copy_utils.load_s2r(tLSEsLSE[None, smem_idx_Q]) # (2) [GEMM 2] dP = dO @ V.T pipeline_dO.consumer_wait( consumer_state_dO_cur, pipeline_dO.consumer_try_wait(consumer_state_dO_cur) @@ -1022,21 +1029,15 @@ def mma_one_m_block( acc_S_mn = utils.make_acc_tensor_mn_view(acc_S, transpose=self.SdP_swapAB) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_S_mn) for r in cutlass.range_constexpr(cute.size(acc_S_mn, mode=[0])): - acc_S_mn[r, None].store( - cute.math.exp2( - acc_S_mn[r, None].load() * softmax_scale_log2 - tLSErLSE[r], fastmath=True + for c in cutlass.range(cute.size(acc_S_mn, mode=[1]), unroll_full=True): + acc_S_mn[r, c] = cute.math.exp2( + acc_S_mn[r, c] * softmax_scale_log2 - tLSErLSE[r], fastmath=True ) - ) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_S_mn) - # S2R for dPsum - tLSErdPsum = cute.make_fragment_like(tLSEsdPsum[None, 0]) - cute.autovec_copy(tLSEsdPsum[None, smem_idx_dO], tLSErdPsum) + tLSErdPsum = copy_utils.load_s2r(tLSEsdPsum[None, smem_idx_dO]) # Convert P from f32 -> f16 - tdVrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) - tdVrP = cute.make_fragment_like(tdVrP_acc, self.dtype) - utils.cvt_f16(tdVrP_acc, tdVrP) - # tdVrP.store(tdVrP_acc.load().to(self.dtype)) + tdVrP = utils.cvt_f16(utils.make_acc_tensor_frgA_view(acc_S), self.dtype) # R2S for P if const_expr(not self.mma_dkv_is_rs): # sync to ensure P has already been used in the previous iteration before overwriting @@ -1052,15 +1053,11 @@ def mma_one_m_block( acc_dP_mn = utils.make_acc_tensor_mn_view(acc_dP, transpose=self.SdP_swapAB) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dP_mn) for r in cutlass.range_constexpr(cute.size(acc_dP_mn, mode=[0])): - acc_dP_mn[r, None].store( - acc_S_mn[r, None].load() * (acc_dP_mn[r, None].load() - tLSErdPsum[r]) - ) + for c in cutlass.range(cute.size(acc_dP_mn, mode=[1]), unroll_full=True): + acc_dP_mn[r, c] = acc_S_mn[r, c] * (acc_dP_mn[r, c] - tLSErdPsum[r]) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dP_mn) # Convert dS from f32 -> f16 - tdKrdS_acc = cute.make_tensor(acc_dP.iterator, utils.convert_layout_acc_frgA(acc_dP.layout)) - tdKrdS = cute.make_fragment_like(tdKrdS_acc, self.dtype) - utils.cvt_f16(tdKrdS_acc, tdKrdS) - # tdKrdS.store(tdKrdS_acc.load().to(self.dtype)) + tdKrdS = utils.cvt_f16(utils.make_acc_tensor_frgA_view(acc_dP), self.dtype) # If there's double buffering on dS, we don't need to sync here. # Otherwise we might have WG1 writing to dS before WG2 is done reading from it during MmadQ. @@ -1106,15 +1103,15 @@ def mma_one_m_block( # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dQ) cute.arch.barrier( - barrier_id=int(NamedBarrierBwd.dQEmpty), - number_of_threads=self.num_mma_threads + cute.arch.WARP_SIZE, + barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx, + number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, ) tdQrdQaccum_flat = cute.make_tensor(acc_dQ.iterator, cute.make_layout(tdQsdQaccum.shape)) cute.autovec_copy(tdQrdQaccum_flat, tdQsdQaccum) cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) cute.arch.barrier_arrive( - barrier_id=int(NamedBarrierBwd.dQFull), - number_of_threads=self.num_mma_threads + cute.arch.WARP_SIZE, + barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx, + number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, ) warpgroup.wait_group(0) @@ -1147,9 +1144,7 @@ def epilogue_dKV( ): rdV = cute.make_fragment_like(acc_dV, self.dtype) rdV.store(acc_dV.load().to(self.dtype)) - rdK = cute.make_fragment_like(acc_dK, self.dtype) - # rdK.store(acc_dK.load().to(self.dtype)) - utils.cvt_f16(acc_dK, rdK) + rdK = utils.cvt_f16(acc_dK, self.dtype) cute.arch.barrier( barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads @@ -1209,29 +1204,39 @@ def dQaccum_store( TileSchedulerCls: cutlass.Constexpr[Callable], SeqlenInfoCls: cutlass.Constexpr[Callable], ): - cpasync_bulk_bytes = self.tile_m * self.tile_hdim * Float32.width // 8 tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: n_block, head_idx, batch_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) mdQaccum_cur = mdQaccum[None, head_idx, batch_idx] - gdQaccum = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (None,)) + gdQaccum_ = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (None,)) + # (M * K / WG, WG, _) + gdQaccum = cute.flat_divide( + gdQaccum_, (self.tile_m * self.tile_hdim // self.num_mma_warp_groups,) + ) m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) for m_block in cutlass.range(m_block_min, m_block_max, unroll=1): - cute.arch.barrier( - barrier_id=int(NamedBarrierBwd.dQFull), - number_of_threads=self.num_mma_threads + cute.arch.WARP_SIZE, - ) - with cute.arch.elect_one(): - copy_utils.cpasync_reduce_bulk_add_f32( - sdQaccum.iterator, gdQaccum[None, m_block].iterator, cpasync_bulk_bytes + for warp_group_idx in cutlass.range_constexpr(self.num_mma_warp_groups): + cute.arch.barrier( + barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx, + number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, + ) + with cute.arch.elect_one(): + copy_utils.cpasync_reduce_bulk_add_f32( + sdQaccum[None, warp_group_idx].iterator, + gdQaccum[None, warp_group_idx, m_block].iterator, + self.tma_copy_bytes["dQ"], + ) + cute.arch.cp_async_bulk_commit_group() + for warp_group_idx in cutlass.range_constexpr(self.num_mma_warp_groups): + with cute.arch.elect_one(): + cute.arch.cp_async_bulk_wait_group( + self.num_mma_warp_groups - 1 - warp_group_idx, read=True + ) + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx, + number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, ) - cute.arch.cp_async_bulk_commit_group() - cute.arch.cp_async_bulk_wait_group(0, read=True) - cute.arch.barrier_arrive( - barrier_id=int(NamedBarrierBwd.dQEmpty), - number_of_threads=self.num_mma_threads + cute.arch.WARP_SIZE, - ) tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() diff --git a/flash_attn/cute/named_barrier.py b/flash_attn/cute/named_barrier.py index 5a7f52e7497..1000c0a47bc 100644 --- a/flash_attn/cute/named_barrier.py +++ b/flash_attn/cute/named_barrier.py @@ -18,8 +18,7 @@ class NamedBarrierBwd(enum.IntEnum): WarpSchedulerWG2 = enum.auto() WarpSchedulerWG3 = enum.auto() PdS = enum.auto() - #dQEmpty = 9 - #dQEmpty = 9 - - dQFull = enum.auto() - dQEmpty = enum.auto() + dQFullWG0 = enum.auto() + dQFullWG1 = enum.auto() + dQEmptyWG0 = enum.auto() + dQEmptyWG1 = enum.auto() diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 2851d59c84d..3d4b8d2d316 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -3,7 +3,7 @@ import math import hashlib import inspect -from typing import Type, Callable, Optional, Tuple +from typing import Type, Callable, Optional, Tuple, overload from functools import partial import cutlass @@ -210,6 +210,10 @@ def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout: return rA_mma_view +def make_acc_tensor_frgA_view(acc: cute.Tensor) -> cute.Tensor: + return cute.make_tensor(acc.iterator, convert_layout_acc_frgA(acc.layout)) + + def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor: return cute.make_tensor(a.iterator, cute.select(a.layout, mode)) @@ -513,16 +517,40 @@ def cvt_f16x2_f32(a: float | Float32, b: float | Float32, to_dtype: Type, *, loc ) +@overload +def cvt_f16(src: cute.Tensor, dst: cute.Tensor) -> None: ... + +@overload +def cvt_f16(src: cute.Tensor, dtype: Type[cute.Numeric]) -> cute.Tensor: ... + @cute.jit -def cvt_f16(src: cute.Tensor, dst: cute.Tensor): - assert cute.size(dst.shape) == cute.size(src.shape), "dst and src must have the same size" - assert cute.size(src.shape) % 2 == 0, "src must have an even number of elements" - assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], "dst must be BFloat16 or Float16" - assert src.element_type is Float32, "src must be Float32" - dst_i32 = cute.recast_tensor(dst, cutlass.Int32) - assert cute.size(dst_i32.shape) * 2 == cute.size(src.shape) - for i in cutlass.range_constexpr(cute.size(dst_i32)): - dst_i32[i] = cvt_f16x2_f32(src[2 * i], src[2 * i + 1], dst.element_type) +def cvt_f16(src: cute.Tensor, dst_or_dtype): + """Convert Float32 tensor to Float16/BFloat16. + + Args: + src: Source tensor with Float32 element type + dst_or_dtype: Either a destination tensor or a dtype (Float16/BFloat16) + + Returns: + None if dst is a tensor, or a new tensor if dtype is provided + """ + if const_expr(isinstance(dst_or_dtype, type)): + # dtype variant: create new tensor and call the tensor variant + dtype = dst_or_dtype + dst = cute.make_fragment(src.shape, dtype) + cvt_f16(src, dst) + return dst + else: + # tensor variant: write to dst + dst = dst_or_dtype + assert cute.size(dst.shape) == cute.size(src.shape), "dst and src must have the same size" + assert cute.size(src.shape) % 2 == 0, "src must have an even number of elements" + assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], "dst must be BFloat16 or Float16" + assert src.element_type is Float32, "src must be Float32" + dst_i32 = cute.recast_tensor(dst, cutlass.Int32) + assert cute.size(dst_i32.shape) * 2 == cute.size(src.shape) + for i in cutlass.range_constexpr(cute.size(dst_i32)): + dst_i32[i] = cvt_f16x2_f32(src[2 * i], src[2 * i + 1], dst.element_type) @cute.jit From 7a3a8fe506080ca3effe18d35618962cfbbb547a Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 19 Oct 2025 00:34:37 -0400 Subject: [PATCH 148/258] [Cute,Bwd,Sm90] Implement masking --- .pre-commit-config.yaml | 3 - flash_attn/cute/flash_bwd_sm90.py | 14 +- flash_attn/cute/mask.py | 311 ++++++++++++++++++------------ flash_attn/cute/pipeline.py | 8 +- 4 files changed, 196 insertions(+), 140 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0e60f835330..0cb9effad2e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,11 +19,8 @@ repos: flash_fwd_sm100| hopper_helpers| interface| - mask| mma_sm100_desc| - named_barrier| pack_gqa| - pipeline| seqlen_info| testing| tile_scheduler| diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index 9c8928a5b07..bfb67824be0 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -1023,11 +1023,10 @@ def mma_one_m_block( ) acc_dP = mma_dov_fn(A_idx=smem_idx_Q, wg_wait=1) # (3) [Pointwise 1] P = exp(S - LSE) - # if cutlass.const_expr(mask_fn is not None): - if cutlass.const_expr(mask_fn is not None and not self.SdP_swapAB): # TODO: impl mask + if cutlass.const_expr(mask_fn is not None): mask_fn(acc_S, m_block=m_block) acc_S_mn = utils.make_acc_tensor_mn_view(acc_S, transpose=self.SdP_swapAB) - # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_S_mn) + # if cute.arch.thread_idx()[0] == 256: cute.print_tensor(acc_S_mn) for r in cutlass.range_constexpr(cute.size(acc_S_mn, mode=[0])): for c in cutlass.range(cute.size(acc_S_mn, mode=[1]), unroll_full=True): acc_S_mn[r, c] = cute.math.exp2( @@ -1228,12 +1227,11 @@ def dQaccum_store( gdQaccum[None, warp_group_idx, m_block].iterator, self.tma_copy_bytes["dQ"], ) - cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_commit_group() for warp_group_idx in cutlass.range_constexpr(self.num_mma_warp_groups): - with cute.arch.elect_one(): - cute.arch.cp_async_bulk_wait_group( - self.num_mma_warp_groups - 1 - warp_group_idx, read=True - ) + cute.arch.cp_async_bulk_wait_group( + self.num_mma_warp_groups - 1 - warp_group_idx, read=True + ) cute.arch.barrier_arrive( barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx, number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index 1da693141cf..562f7900096 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -5,153 +5,202 @@ import cutlass import cutlass.cute as cute +from cutlass import Int32, const_expr import flash_attn.cute.utils as utils +@cute.jit +def mask_r2p_sm90(X: cute.Tensor, col_limit: Int32) -> None: + # R2P trick: Instead of comparing limit to 0, 1, 8, 9, 16, 17, ..., + # we compare a transformed version of limit to 0, 1, 2, 3, 4, 5, ... + # This is so that we can use the R2P instruction. + assert cute.rank(X) in [1, 2], "mask_r2p_sm90 only supports rank 1 or 2 tensors" + col_limit_transformed = col_limit // 8 * 2 + min(col_limit % 8, 2) + ncol = const_expr(cute.size(X.shape[cute.rank(X) - 1])) + for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)): + col_limit_right_s = max(col_limit_transformed - s * 24, 0) + mask = (1 << col_limit_right_s) - 1 + for i in cutlass.range_constexpr(min(24, ncol - s * 24)): + in_bound = cutlass.Boolean(mask & (1 << i)) + c = s * 24 + i + if const_expr(cute.rank(X) == 1): + X[c] = X[c] if in_bound else -cutlass.Float32.inf + else: + for r in cutlass.range_constexpr(cute.size(X.shape[0])): + X[r, c] = X[r, c] if in_bound else -cutlass.Float32.inf + + @dataclass(frozen=True) class AttentionMask: tile_m: cutlass.Constexpr[int] tile_n: cutlass.Constexpr[int] - seqlen_q: cutlass.Int32 - seqlen_k: cutlass.Int32 - window_size_left: Optional[cutlass.Int32] = None - window_size_right: Optional[cutlass.Int32] = None + seqlen_q: Int32 + seqlen_k: Int32 + window_size_left: Optional[Int32] = None + window_size_right: Optional[Int32] = None qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 # only pass in if we're doing PackGQA + swap_AB: cutlass.Constexpr[bool] = False @cute.jit def apply_mask( self, acc_S: cute.Tensor, - m_block: cutlass.Int32, - n_block: cutlass.Int32, + m_block: Int32, + n_block: Int32, thr_mma: cute.TiledMma, mask_seqlen: cutlass.Constexpr[bool], mask_causal: cutlass.Constexpr[bool], mask_local: cutlass.Constexpr[bool] = False, ) -> None: - # TODO: implement swap_AB assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" - acc_S_mn = utils.make_acc_tensor_mn_view(acc_S) - cS = cute.make_identity_tensor((self.tile_m, self.tile_n)) - tScS_mn = utils.make_acc_tensor_mn_view(thr_mma.partition_C(cS)) + acc_S_mn = utils.make_acc_tensor_mn_view(acc_S, transpose=self.swap_AB) + acc_shape = (self.tile_m, self.tile_n) + cS = cute.make_identity_tensor(acc_shape if not self.swap_AB else acc_shape[::-1]) + tScS_mn = utils.make_acc_tensor_mn_view(thr_mma.partition_C(cS), transpose=self.swap_AB) # We use t0ScS as these indices are known at compile time. We then must subtract the # column limit by the thread column offset. - t0ScS_mn = utils.make_acc_tensor_mn_view(thr_mma.get_slice(0).partition_C(cS)) - thr_col_offset = tScS_mn[0][1] + t0ScS_mn = utils.make_acc_tensor_mn_view( + thr_mma.get_slice(0).partition_C(cS), transpose=self.swap_AB + ) + ROW = 0 if const_expr(not self.swap_AB) else 1 + COL = 1 if const_expr(not self.swap_AB) else 0 + thr_col_offset = tScS_mn[0][COL] seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset - if cutlass.const_expr(not mask_causal and not mask_local): - if cutlass.const_expr(mask_seqlen): - if cutlass.const_expr(True): - # traverse column index. + if const_expr(not mask_causal and not mask_local): + if const_expr(mask_seqlen): + # The compiler now choses not to use R2P + r2p = const_expr(False and not self.swap_AB) + if const_expr(not r2p): for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): - oob = t0ScS_mn[0, c][1] >= seqlenk_col_limit + oob = t0ScS_mn[0, c][COL] >= seqlenk_col_limit for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): acc_S_mn[r, c] = -cutlass.Float32.inf if oob else acc_S_mn[r, c] - else: # R2P trick, see apply_mask_sm100 - # Instead of comparing limit to 0, 1, 8, 9, 16, 17, ..., - # we compare a transformed version of limit to 0, 1, 2, 3, 4, 5, ... - # This is so that we can use the R2P instruction. - col_limit_transformed = seqlenk_col_limit // 8 * 2 + min(seqlenk_col_limit % 8, 2) - ncol = cutlass.const_expr(cute.size(tScS_mn.shape[1])) - for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)): - col_limit_right_s = max(col_limit_transformed - s * 24, 0) - mask = (1 << col_limit_right_s) - 1 - for i in cutlass.range_constexpr(min(24, ncol - s * 24)): - in_bound = cutlass.Boolean(mask & (1 << i)) - c = s * 24 + i - for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): - acc_S_mn[r, c] = acc_S_mn[r, c] if in_bound else -cutlass.Float32.inf + else: + mask_r2p_sm90(acc_S_mn, seqlenk_col_limit) else: # Causal or local - # If PackGQA, we split the work of compute divmod among threads in the same row - threads_per_row = thr_mma.tv_layout_C.shape[0][0] - if cutlass.const_expr(self.qhead_per_kvhead_packgqa != 1): - assert cute.arch.WARP_SIZE % threads_per_row == 0, ( - "threads_per_row must divide WARP_SIZE" + if const_expr(not self.swap_AB): + # If PackGQA, we split the work of compute divmod among threads in the same row + threads_per_row = thr_mma.tv_layout_C.shape[0][0] + mma_m_idx = None + if const_expr(self.qhead_per_kvhead_packgqa != 1): + assert not self.swap_AB, "swap_AB with PackGQA not supported yet" + assert cute.arch.WARP_SIZE % threads_per_row == 0, ( + "threads_per_row must divide WARP_SIZE" + ) + assert cute.size(acc_S_mn.shape[0]) <= threads_per_row + tidx = thr_mma.thr_idx + mma_m_idx = ( + m_block * self.tile_m + tScS_mn[tidx % threads_per_row, 0][0] + ) // self.qhead_per_kvhead_packgqa + causal_row_offset = ( + 1 + self.seqlen_k - n_block * self.tile_n - self.seqlen_q - thr_col_offset ) - assert cute.size(acc_S_mn.shape[0]) <= threads_per_row - tidx = thr_mma.thr_idx - mma_m_idx = ( - m_block * self.tile_m + tScS_mn[tidx % threads_per_row, 0][0] - ) // self.qhead_per_kvhead_packgqa - causal_row_offset = ( - 1 + self.seqlen_k - n_block * self.tile_n - self.seqlen_q - thr_col_offset - ) - c = 0 - col_limit_transformed = 0 - ncol: cute.Constexpr = 0 - col_limit_right_s = 0 - mask = 0 - in_bound = False - if cutlass.const_expr(mask_causal): - for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): - # get the column index limit based on current row. Only consider the row index, so the column index sets to 0. - if cutlass.const_expr(self.qhead_per_kvhead_packgqa == 1): - row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m - else: - row_idx = utils.shuffle_sync( - mma_m_idx, r % threads_per_row, width=threads_per_row + if const_expr(mask_causal): + r2p = const_expr(not self.swap_AB) # R2P trick, see apply_mask_sm100 + for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): + # get the column index limit based on current row. Only consider the row index, so the column index sets to 0. + if const_expr(self.qhead_per_kvhead_packgqa == 1): + row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m + else: + row_idx = utils.shuffle_sync( + mma_m_idx, r % threads_per_row, width=threads_per_row + ) + col_limit_right = row_idx + causal_row_offset + if const_expr(mask_seqlen): + col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) + if const_expr(not r2p): + # traverse column index. + for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): + acc_S_mn[r, c] = ( + -cutlass.Float32.inf + if t0ScS_mn[0, c][1] >= col_limit_right + else acc_S_mn[r, c] + ) + else: + mask_r2p_sm90(acc_S_mn[r, None], col_limit_right) + else: # Local + local_row_offset_right = ( + causal_row_offset + self.window_size_right + if const_expr(self.window_size_right is not None) + else None + ) + local_row_offset_left = ( + causal_row_offset - 1 - self.window_size_left + if const_expr(self.window_size_left is not None) + else None + ) + for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): + if const_expr(self.qhead_per_kvhead_packgqa == 1): + row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m + else: + row_idx = utils.shuffle_sync( + mma_m_idx, r % threads_per_row, width=threads_per_row + ) + if const_expr(self.window_size_right is not None): + col_limit_right = row_idx + local_row_offset_right + if const_expr(mask_seqlen): + col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) + else: + col_limit_right = self.tile_n + col_limit_left = ( + row_idx + local_row_offset_left + if const_expr(self.window_size_left is not None) + else 0 ) - col_limit_right = row_idx + causal_row_offset - if cutlass.const_expr(mask_seqlen): - col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) - if cutlass.const_expr(True): + # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block = {}, r = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", n_block, r, row_idx, causal_row_offset, col_limit_right, col_limit_left) # traverse column index. for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): - acc_S_mn[r, c] = -cutlass.Float32.inf if t0ScS_mn[0, c][1] >= col_limit_right else acc_S_mn[r, c] - else: # R2P trick, see apply_mask_sm100 - col_limit_transformed = col_limit_right // 8 * 2 + min(col_limit_right % 8, 2) - ncol = cutlass.const_expr(cute.size(tScS_mn.shape[1])) - for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)): - col_limit_right_s = max(col_limit_transformed - s * 24, 0) - mask = (1 << col_limit_right_s) - 1 - for i in cutlass.range_constexpr(min(24, ncol - s * 24)): - in_bound = cutlass.Boolean(mask & (1 << i)) - c = s * 24 + i - acc_S_mn[r, c] = acc_S_mn[r, c] if in_bound else -cutlass.Float32.inf - else: # Local - local_row_offset_right = ( - causal_row_offset + self.window_size_right - if cutlass.const_expr(self.window_size_right is not None) - else None - ) - local_row_offset_left = ( - causal_row_offset - 1 - self.window_size_left - if cutlass.const_expr(self.window_size_left is not None) - else None + col_idx = t0ScS_mn[0, c][1] + # only consider the column index, so the row index sets to 0. + if col_idx >= col_limit_right or col_idx < col_limit_left: + acc_S_mn[r, c] = -cutlass.Float32.inf + else: # swap_AB + assert self.qhead_per_kvhead_packgqa == 1 + thr_row_offset = tScS_mn[0][ROW] + causal_row_offset = ( + seqlenk_col_limit - self.seqlen_q + m_block * self.tile_m + thr_row_offset ) - c = 0 - for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): - if cutlass.const_expr(self.qhead_per_kvhead_packgqa == 1): - row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m - else: - row_idx = utils.shuffle_sync( - mma_m_idx, r % threads_per_row, width=threads_per_row + if const_expr(mask_causal): + for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): + col0 = t0ScS_mn[0, c][COL] + # If col0 is beyond the column limit, we want to mask out the entire + # column, by setting row limit to be self.tile_m. + row_limit_top = ( + self.tile_m if col0 >= seqlenk_col_limit else col0 - causal_row_offset ) - if cutlass.const_expr(self.window_size_right is not None): - col_limit_right = row_idx + local_row_offset_right - if cutlass.const_expr(mask_seqlen): - col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) - else: - col_limit_right = self.tile_n - col_limit_left = ( - row_idx + local_row_offset_left if cutlass.const_expr(self.window_size_left is not None) else 0 - ) - # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block = {}, r = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", n_block, r, row_idx, causal_row_offset, col_limit_right, col_limit_left) - # traverse column index. + for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): + acc_S_mn[r, c] = ( + -cutlass.Float32.inf + if t0ScS_mn[r, 0][ROW] < row_limit_top + else acc_S_mn[r, c] + ) + else: for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): - col_idx = t0ScS_mn[0, c][1] - acc_S_mn = utils.make_acc_tensor_mn_view(acc_S) - # only consider the column index, so the row index sets to 0. - if col_idx >= col_limit_right or col_idx < col_limit_left: - acc_S_mn[r, c] = -cutlass.Float32.inf + col0 = t0ScS_mn[0, c][COL] + # If col0 is beyond the column limit, we want to mask out the entire + # column, by setting row limit to be self.tile_m. + row_limit_top = ( + self.tile_m + if col0 >= seqlenk_col_limit + else col0 - causal_row_offset - self.window_size_right + ) + # TODO: do we need col_limit_sink? + row_limit_bot = col0 - causal_row_offset + self.window_size_left + for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): + row_idx = t0ScS_mn[r, 0][ROW] + acc_S_mn[r, c] = ( + -cutlass.Float32.inf + if row_idx < row_limit_top or row_idx > row_limit_bot + else acc_S_mn[r, c] + ) @cute.jit def apply_mask_sm100( self, acc_S: cute.Tensor, - m_block: cutlass.Int32, - n_block: cutlass.Int32, + m_block: Int32, + n_block: Int32, thr_mma: cute.TiledMma, thr_tmem_load: cute.TiledCopy, mask_seqlen: cutlass.Constexpr, @@ -163,16 +212,18 @@ def apply_mask_sm100( tScS = thr_mma.partition_C(cS) tScS_t2r = thr_tmem_load.partition_D(tScS) seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - if cutlass.const_expr(not mask_causal and not mask_local): - if cutlass.const_expr(mask_seqlen): - ncol = cutlass.const_expr(cute.size(tScS_t2r.shape)) - if cutlass.const_expr(False): + if const_expr(not mask_causal and not mask_local): + if const_expr(mask_seqlen): + ncol = const_expr(cute.size(tScS_t2r.shape)) + if const_expr(False): for i in cutlass.range(ncol, unroll_full=True): # if tScS_t2r[i][1] >= seqlenk_col_limit: # acc_S[i] = -cutlass.Float32.inf # For some reason the 2 lines above generate really bad SASS acc_S[i] = ( - -cutlass.Float32.inf if tScS_t2r[i][1] >= seqlenk_col_limit else acc_S[i] + -cutlass.Float32.inf + if tScS_t2r[i][1] >= seqlenk_col_limit + else acc_S[i] ) else: # Bit manipulation, compiles down to the R2P instruction @@ -193,24 +244,28 @@ def apply_mask_sm100( # the R2P instruction, so it's slower. # Instead we just move by 24 instead of 32. # if tidx == 0: cute.printf("mask_i_bit = %d, after shift = 0x%x, i = %d, s = %d", mask_i_bit, utils.shr_u32(mask, i), i, s) - acc_S[s * 24 + i] = acc_S[s * 24 + i] if cutlass.Boolean(mask & (1 << i)) else -cutlass.Float32.inf + acc_S[s * 24 + i] = ( + acc_S[s * 24 + i] + if cutlass.Boolean(mask & (1 << i)) + else -cutlass.Float32.inf + ) # This is the equivalent of: # acc_S[s * 24 + i] = acc_S[s * 24 + i] if col_limit_right_s <= i else -cutlass.Float32.inf # if tidx == 0: cute.print_tensor(acc_S) else: # Causal or local causal_row_offset = 1 + self.seqlen_k - n_block * self.tile_n - self.seqlen_q row_idx = tScS_t2r[0][0] + m_block * self.tile_m - if cutlass.const_expr(self.qhead_per_kvhead_packgqa != 1): + if const_expr(self.qhead_per_kvhead_packgqa != 1): row_idx = row_idx // self.qhead_per_kvhead_packgqa c = 0 - if cutlass.const_expr(mask_causal): + if const_expr(mask_causal): col_limit_right = row_idx + causal_row_offset - if cutlass.const_expr(mask_seqlen): + if const_expr(mask_seqlen): col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) # if cute.arch.thread_idx()[0] % 32 == 0: # cute.printf("tidx = %d, tidx tmem = %d, row_idx = %d, col_limit_right = %d, causal_row_offset = %d\n", cute.arch.thread_idx()[0], thr_tmem_load.thr_idx, row_idx, col_limit_right, causal_row_offset) - ncol = cutlass.const_expr(cute.size(tScS_t2r.shape)) - if cutlass.const_expr(False): + ncol = const_expr(cute.size(tScS_t2r.shape)) + if const_expr(False): for i in cutlass.range(ncol, unroll_full=True): acc_S[i] = ( -cutlass.Float32.inf if tScS_t2r[i][1] >= col_limit_right else acc_S[i] @@ -225,28 +280,34 @@ def apply_mask_sm100( # This needs to be range_constexpr, otherwise the compiler can't generate # the R2P instruction for i in cutlass.range_constexpr(min(24, ncol - s * 24)): - acc_S[s * 24 + i] = acc_S[s * 24 + i] if cutlass.Boolean(mask & (1 << i)) else -cutlass.Float32.inf + acc_S[s * 24 + i] = ( + acc_S[s * 24 + i] + if cutlass.Boolean(mask & (1 << i)) + else -cutlass.Float32.inf + ) # This is the equivalent of: # acc_S[s * 24 + i] = acc_S[s * 24 + i] if col_limit_right_s <= i else -cutlass.Float32.inf else: local_row_offset_right = ( causal_row_offset + self.window_size_right - if cutlass.const_expr(self.window_size_right is not None) + if const_expr(self.window_size_right is not None) else None ) local_row_offset_left = ( causal_row_offset - 1 - self.window_size_left - if cutlass.const_expr(self.window_size_left is not None) + if const_expr(self.window_size_left is not None) else None ) - if cutlass.const_expr(self.window_size_right is not None): + if const_expr(self.window_size_right is not None): col_limit_right = row_idx + local_row_offset_right - if cutlass.const_expr(mask_seqlen): + if const_expr(mask_seqlen): col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) else: col_limit_right = self.tile_n col_limit_left = ( - row_idx + local_row_offset_left if cutlass.const_expr(self.window_size_left is not None) else 0 + row_idx + local_row_offset_left + if const_expr(self.window_size_left is not None) + else 0 ) # if cute.arch.thread_idx()[0] == 0 or cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", m_block, n_block, row_idx, causal_row_offset, col_limit_right, col_limit_left) for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True): diff --git a/flash_attn/cute/pipeline.py b/flash_attn/cute/pipeline.py index 89baa4a97be..0dbc905b35b 100644 --- a/flash_attn/cute/pipeline.py +++ b/flash_attn/cute/pipeline.py @@ -40,9 +40,9 @@ def _sync(group: Agent): cute.arch.cluster_arrive_relaxed() cute.arch.cluster_wait() else: - assert ( - False - ), "Error: No explicit sync instruction exists. Please use barriers (named / mbarrier) instead." + assert False, ( + "Error: No explicit sync instruction exists. Please use barriers (named / mbarrier) instead." + ) class PipelineStateSimple: @@ -144,7 +144,7 @@ def create( barrier_storage: cute.Pointer = None, cta_layout_vmnk: Optional[cute.Layout] = None, tidx: Optional[Int32] = None, - init_wait: cutlass.Constexpr[bool] = True + init_wait: cutlass.Constexpr[bool] = True, ): """ This helper function computes any necessary attributes and returns an instance of PipelineTmaAsync. From 75fcbf2ac1c4821510ffbf631240bd71adc5d53c Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 19 Oct 2025 09:43:56 -0400 Subject: [PATCH 149/258] [Cute,Fwd,Sm100] Parse swizzle from pointer, don't need to pass in --- flash_attn/cute/blackwell_helpers.py | 24 +++++++++++------------- flash_attn/cute/flash_fwd_sm100.py | 10 ++-------- flash_attn/cute/utils.py | 25 +++++++++++++++++++++++++ 3 files changed, 38 insertions(+), 21 deletions(-) diff --git a/flash_attn/cute/blackwell_helpers.py b/flash_attn/cute/blackwell_helpers.py index ad5124c04ce..0ec5af90826 100644 --- a/flash_attn/cute/blackwell_helpers.py +++ b/flash_attn/cute/blackwell_helpers.py @@ -7,6 +7,7 @@ from cutlass._mlir.dialects import llvm import flash_attn.cute.mma_sm100_desc as sm100_desc +from flash_attn.cute.utils import parse_swizzle_from_pointer @cute.jit @@ -36,18 +37,16 @@ def gemm_ptx( tCrB: cute.Tensor, sA: Optional[cute.Tensor], sB: cute.Tensor, - sA_swizzle: Optional[cute.Swizzle], - sB_swizzle: cute.Swizzle, zero_init: bool | cutlass.Boolean = False, ) -> None: is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM if cutlass.const_expr(not is_ts): assert sA is not None, "sA must be provided when a_src is not TMEM" - assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM" sA_layout = sA.layout if sA is not None else None sB_layout = sB.layout idesc: int = cutlass.const_expr(sm100_desc.mma_op_to_idesc(op)) if cutlass.const_expr(not is_ts): + sA_swizzle = parse_swizzle_from_pointer(sA.iterator) smem_desc_base_a: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), sA_swizzle, @@ -59,6 +58,7 @@ def gemm_ptx( else: smem_desc_base_a = None smem_desc_base_a_lo, smem_desc_a_hi = None, None + sB_swizzle = parse_swizzle_from_pointer(sB.iterator) smem_desc_base_b: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), sB_swizzle, @@ -135,18 +135,16 @@ def gemm_ptx_loop( tCrB: cute.Tensor, sA: Optional[cute.Tensor], sB: cute.Tensor, - sA_swizzle: Optional[cute.Swizzle], - sB_swizzle: cute.Swizzle, zero_init: bool | cutlass.Boolean = False, ) -> None: is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM if cutlass.const_expr(not is_ts): assert sA is not None, "sA must be provided when a_src is not TMEM" - assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM" sA_layout = sA.layout if sA is not None else tCrA.layout sB_layout = sB.layout idesc: int = cutlass.const_expr(sm100_desc.mma_op_to_idesc(op)) if cutlass.const_expr(not is_ts): + sA_swizzle = parse_swizzle_from_pointer(sA.iterator) smem_desc_base_a: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), sA_swizzle, @@ -158,6 +156,7 @@ def gemm_ptx_loop( else: smem_desc_base_a = None smem_desc_base_a_lo, smem_desc_a_hi = None, None + sB_swizzle = parse_swizzle_from_pointer(sB.iterator) smem_desc_base_b: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), sB_swizzle, @@ -277,8 +276,6 @@ def gemm_ptx_partial( tCrB: cute.Tensor, sA: Optional[cute.Tensor], sB: cute.Tensor, - sA_swizzle: Optional[cute.Swizzle], - sB_swizzle: cute.Swizzle, mbar_ptr: Optional[cutlass.Pointer] = None, mbar_phase: Optional[cutlass.Int32] = None, zero_init: bool | cutlass.Boolean = False, @@ -286,11 +283,11 @@ def gemm_ptx_partial( is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM if cutlass.const_expr(not is_ts): assert sA is not None, "sA must be provided when a_src is not TMEM" - assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM" sA_layout = sA.layout if sA is not None else tCrA.layout sB_layout = sB.layout idesc: int = cutlass.const_expr(sm100_desc.mma_op_to_idesc(op)) if cutlass.const_expr(not is_ts): + sA_swizzle = parse_swizzle_from_pointer(sA.iterator) smem_desc_base_a: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), sA_swizzle, @@ -302,6 +299,7 @@ def gemm_ptx_partial( else: smem_desc_base_a = None smem_desc_base_a_lo, smem_desc_a_hi = None, None + sB_swizzle = parse_swizzle_from_pointer(sB.iterator) smem_desc_base_b: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), sB_swizzle, @@ -329,8 +327,8 @@ def gemm_ptx_partial( None, [ # acc.iterator.toint().ir_value(), - cutlass.Int32(smem_desc_start_a_lo).ir_value(), - cutlass.Int32(smem_desc_start_b_lo).ir_value(), + cutlass.Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + cutlass.Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), cutlass.Int32(not zero_init).ir_value(), ], "{\n\t" @@ -370,8 +368,8 @@ def gemm_ptx_partial( ) else: input_args = [ - cutlass.Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), - cutlass.Int32(smem_desc_start_b_lo).ir_value(), + cutlass.Int32(cute.arch.make_warp_uniform(tCrA[None, None, 0].iterator.toint())).ir_value(), + cutlass.Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), cutlass.Int32(not zero_init).ir_value(), ] if cutlass.const_expr(mbar_ptr is not None): diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 0a93f3d044f..86994d27c66 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -757,9 +757,6 @@ def kernel( sQ, sK, sV, - sQ_layout.inner, - sK_layout.inner, - sV_layout.inner, tStSs, tOtOs, tOrPs, @@ -984,9 +981,6 @@ def mma( sQ: cute.Tensor, sK: cute.Tensor, sV: cute.Tensor, - sQ_swizzle: cute.Swizzle, - sK_swizzle: cute.Swizzle, - sV_swizzle: cute.Swizzle, tStSs: Tuple[cute.Tensor, cute.Tensor], tOtOs: tuple[cute.Tensor], tOrPs: Tuple[cute.Tensor, cute.Tensor], @@ -1012,7 +1006,7 @@ def mma( partial( sm100_utils.gemm_ptx_partial, qk_mma_op, self.tmem_s_offset[stage], tSrQs[stage], sA=sQ[None, None, None, stage], - sA_swizzle=sQ_swizzle, sB_swizzle=sK_swizzle, zero_init=True + zero_init=True ) for stage in range(2) ] @@ -1020,7 +1014,7 @@ def mma( partial( sm100_utils.gemm_ptx_partial, pv_mma_op, self.tmem_o_offset[stage if self.q_stage == 2 else 0], tOrPs[stage], - sA=None, sA_swizzle=None, sB_swizzle=sV_swizzle + sA=None ) for stage in range(2) ] diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 3d4b8d2d316..33c71c66ad4 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -3,6 +3,7 @@ import math import hashlib import inspect +import re from typing import Type, Callable, Optional, Tuple, overload from functools import partial @@ -225,6 +226,30 @@ def transpose_view(a: cute.Tensor) -> cute.Tensor: return cute.composition(a, cute.make_ordered_layout(shape, order=order)) +def parse_swizzle_from_pointer(ptr: cute.Pointer) -> cute.Swizzle: + """Extract swizzle parameters from a pointer's swizzle_type. + + The swizzle_type string has the form '!cute.swizzle<"S">' where + b, m, s are the swizzle parameters (bits, base, shift). + + Returns: + A cute.Swizzle object constructed from the extracted parameters + + Raises: + ValueError: If the swizzle_type string cannot be parsed + """ + # Ideally there should be a better API to get swizzle parameters, but we'll just parse + # the string here. + swizzle_str = str(ptr.type.swizzle_type) + # Extract the inner part "S" + match = re.search(r'S<(\d+),(\d+),(\d+)>', swizzle_str) + if match: + b, m, s = int(match.group(1)), int(match.group(2)), int(match.group(3)) + return cute.make_swizzle(b, m, s) + else: + raise ValueError(f"Could not parse swizzle_type: {swizzle_str}") + + @cute.jit def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32: """exp2f calculation for both vector and scalar. From b5e9a71ae423c690ec6e486821e1458ba3d22faa Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 19 Oct 2025 11:53:46 -0400 Subject: [PATCH 150/258] [Cute,Fwd,Sm100] Clean up --- flash_attn/cute/flash_fwd_sm100.py | 232 +++++++++++++---------------- flash_attn/cute/pipeline.py | 113 +++++++++++++- flash_attn/cute/utils.py | 6 +- 3 files changed, 221 insertions(+), 130 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 86994d27c66..7bf1480bbae 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -14,7 +14,7 @@ import enum import math -from typing import Type, Tuple, Callable, Optional +from typing import Type, Tuple, Callable, Optional, Literal from functools import partial import cuda.bindings.driver as cuda @@ -27,7 +27,8 @@ import cutlass.utils.blackwell_helpers as sm100_utils_basic import flash_attn.cute.utils as utils -# import flash_attn.cute.pipeline as pipeline +from flash_attn.cute import copy_utils +import flash_attn.cute.pipeline as pipeline from flash_attn.cute.mask import AttentionMask from flash_attn.cute.softmax import SoftmaxSm100, apply_score_mod_inner from flash_attn.cute.seqlen_info import SeqlenInfoQK @@ -131,8 +132,6 @@ def __init__( ) ) - self.tmem_alloc_sync_bar_id = 1 - self.tmem_s_offset = [0, self.n_block_size] # e.g., 0, 128 self.tmem_o_offset = [self.tmem_s_offset[-1] + self.n_block_size + i * self.head_dim_v_padded for i in range(self.q_stage)] # e.g., 256, 384 self.tmem_total = self.tmem_o_offset[-1] + self.head_dim_v_padded @@ -398,9 +397,14 @@ def __call__( vO_layout = cute.make_layout((1, async_copy_elems)) gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout) - self.tma_copy_q_bytes = cute.size_in_bytes(self.q_dtype, cute.select(sQ_layout, mode=[0, 1, 2])) - self.tma_copy_k_bytes = cute.size_in_bytes(self.k_dtype, cute.select(sK_layout, mode=[0, 1, 2])) - self.tma_copy_v_bytes = cute.size_in_bytes(self.v_dtype, cute.select(sV_layout, mode=[0, 1, 2])) + self.tma_copy_bytes = { + name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2])) + for name, mX, layout in [ + ("Q", mQ, sQ_layout), + ("K", mK, sK_layout), + ("V", mV, sV_layout), + ] + } if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): TileScheduler = SingleTileVarlenScheduler @@ -645,10 +649,8 @@ def kernel( # Generate smem tensor Q/K/V/O # (MMA, MMA_Q, MMA_D, PIPE) sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) - # sQ_pi = storage.sQ.get_tensor(sQ_layout) # (MMA, MMA_K, MMA_D, PIPE) sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) - # sK_pi = storage.sK.get_tensor(sK_layout) # (MMA, MMA_K, MMA_D, PIPE) # Strip swizzle info to reuse smem sV = cute.make_tensor(cute.recast_ptr(sK.iterator, sV_layout.inner), sV_layout.outer) @@ -662,7 +664,7 @@ def kernel( thr_mma_qk = tiled_mma_qk.get_slice(0) # default 1SM thr_mma_pv = tiled_mma_pv.get_slice(0) # default 1SM - qk_acc_shape = thr_mma_qk.partition_shape_C((self.mma_tiler_qk[0], self.mma_tiler_qk[1])) + qk_acc_shape = thr_mma_qk.partition_shape_C(self.mma_tiler_qk[:2]) tStS_fake = thr_mma_qk.make_fragment_C(qk_acc_shape) # This is a fake tensor, by right need to retrieve tmem_ptr. But we know that we always # request 512 columns of tmem, so we know that it starts at 0. @@ -670,7 +672,7 @@ def kernel( assumed_align=16) tStS = cute.make_tensor(tmem_ptr, tStS_fake.layout) - pv_acc_shape = thr_mma_pv.partition_shape_C((self.mma_tiler_pv[0], self.mma_tiler_pv[1])) + pv_acc_shape = thr_mma_pv.partition_shape_C(self.mma_tiler_pv[:2]) tOtO = thr_mma_pv.make_fragment_C(pv_acc_shape) tStSs = tuple(cute.make_tensor(tStS.iterator + self.tmem_s_offset[stage], tStS.layout) @@ -880,17 +882,15 @@ def load( ): q_producer_phase = Int32(1) - kv_producer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, self.kv_stage) + kv_producer_state = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.kv_stage + ) tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) - if const_expr(not seqlen.has_cu_seqlens_q): - mQ_cur = mQ[None, None, head_idx, batch_idx] - else: - offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) - mQ_cur = cute.domain_offset((offset, 0), mQ[None, None, head_idx]) + mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] gQ = cute.local_tile(mQ_cur, cute.select(self.mma_tiler_qk, mode=[0, 2]), (None, 0)) head_idx_kv = head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx @@ -910,12 +910,8 @@ def load( tSgQ = thr_mma_qk.partition_A(gQ) tSgK = thr_mma_qk.partition_B(gK) tOgV = thr_mma_pv.partition_B(gV) - tQsQ, tQgQ = cpasync.tma_partition( - tma_atom_Q, - 0, # no multicast - cute.make_layout(1), - cute.group_modes(sQ, 0, 3), - cute.group_modes(tSgQ, 0, 3), + load_Q_fn, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), tSgQ, sQ ) tKsK, tKgK = cpasync.tma_partition( tma_atom_K, @@ -933,7 +929,7 @@ def load( ) load_Q = partial( - self.load_Q, tma_atom_Q, tQgQ, tQsQ, + self.load_Q, load_Q_fn, mbar_ptr + self.mbar_load_q_full_offset, mbar_ptr + self.mbar_load_q_empty_offset, phase=q_producer_phase, ) @@ -1005,7 +1001,10 @@ def mma( gemm_Si = [ partial( sm100_utils.gemm_ptx_partial, - qk_mma_op, self.tmem_s_offset[stage], tSrQs[stage], sA=sQ[None, None, None, stage], + qk_mma_op, + self.tmem_s_offset[stage], + tSrQs[stage], + sA=sQ[None, None, None, stage], zero_init=True ) for stage in range(2) @@ -1013,8 +1012,10 @@ def mma( gemm_Pi = [ partial( sm100_utils.gemm_ptx_partial, - pv_mma_op, self.tmem_o_offset[stage if self.q_stage == 2 else 0], tOrPs[stage], - sA=None + pv_mma_op, + self.tmem_o_offset[stage if self.q_stage == 2 else 0], + tOrPs[stage], + sA=None, ) for stage in range(2) ] @@ -1075,14 +1076,23 @@ def mma( # For the first iteration in this work tile, waiting for O0/O1_partial # means that the correction warps has finished reading tO during # the last iteration of the previous work tile has finished. - cute.arch.mbarrier_wait(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, P_full_O_rescaled_phase) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, + P_full_O_rescaled_phase + ) # 3. gemm # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True) # gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate) sV_cur = sV[None, None, None, Vi_index] if const_expr(self.uneven_kv_smem): sV_cur = self.offset_kv_smem(sV_cur, Vi_index, Vi_phase) - gemm_Pi[stage](tCrB=tOrVi, sB=sV_cur, zero_init=not O_should_accumulate, mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, mbar_phase= P_full_O_rescaled_phase) + gemm_Pi[stage]( + tCrB=tOrVi, + sB=sV_cur, + zero_init=not O_should_accumulate, + mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, + mbar_phase=P_full_O_rescaled_phase + ) # 4. release accumulated O0_partial / O1_partial # Don't need to signal O_full to the correction warps anymore since the # correction warps wait for the softmax warps anyway. By the time the softmax @@ -1134,14 +1144,23 @@ def mma( tOrVi = tOrV[None, None, None, Vi_index] for stage in cutlass.range_constexpr(2): # 2. acquire corrected Oi_partial and Pi - cute.arch.mbarrier_wait(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, P_full_O_rescaled_phase) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, + P_full_O_rescaled_phase + ) # 3. gemm # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True) # gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate) sV_cur = sV[None, None, None, Vi_index] if const_expr(self.uneven_kv_smem): sV_cur = self.offset_kv_smem(sV_cur, Vi_index, Vi_phase) - gemm_Pi[stage](tCrB=tOrVi, sB=sV_cur, zero_init=not O_should_accumulate, mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, mbar_phase=P_full_O_rescaled_phase) + gemm_Pi[stage]( + tCrB=tOrVi, + sB=sV_cur, + zero_init=not O_should_accumulate, + mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, + mbar_phase=P_full_O_rescaled_phase + ) # 4. release accumulated O0_partial # We do need O_full here since for the last tile, by the time the softmax warp # has signaled to the correction warp, the softmax warp has just finished compute @@ -1199,13 +1218,9 @@ def softmax_loop( ) ) - cS_base = cute.make_identity_tensor((self.mma_tiler_qk[0], self.mma_tiler_qk[1])) - tScS = thr_mma_qk.partition_C(cS_base) - - tStS_scale_layout = cute.composition(tStSi.layout, cute.make_layout((self.m_block_size, 1))) - tStScale = cute.make_tensor(tStSi.iterator, tStS_scale_layout) - tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((self.m_block_size, 1))) - tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) + tStScale = cute.composition(tStSi, cute.make_layout((self.m_block_size, 1))) + tScS = thr_mma_qk.partition_C(cute.make_identity_tensor(self.mma_tiler_qk[:2])) + tScScale = cute.composition(tScS, cute.make_layout((self.m_block_size, 1))) tilePlikeFP32 = self.mma_tiler_qk[1] // 32 * self.v_dtype.width tStP_layout = cute.composition(tStSi.layout, cute.make_layout((self.m_block_size, tilePlikeFP32))) @@ -1223,12 +1238,10 @@ def softmax_loop( thr_tmem_store_scale = tcgen05.make_tmem_copy(tmem_store_scale_atom, tStScale).get_slice(tidx) tStScale_r2t = thr_tmem_store_scale.partition_D(tStScale) - tSrScale_r2t_shape = thr_tmem_store_scale.partition_S(tScS_vec).shape tmem_store_atom = cute.make_copy_atom( tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), Float32, ) - tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP) - thr_tmem_store = tiled_tmem_store.get_slice(tidx) + thr_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP).get_slice(tidx) tStP_r2t = thr_tmem_store.partition_D(tStP) mma_si_consumer_phase = Int32(0) @@ -1248,7 +1261,12 @@ def softmax_loop( n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) mask_fn = partial( - mask.apply_mask_sm100, m_block=self.q_stage * m_block + stage, thr_mma=thr_mma_qk, thr_tmem_load=thr_tmem_load, mask_causal=self.is_causal, mask_local=self.is_local + mask.apply_mask_sm100, + m_block=self.q_stage * m_block + stage, + thr_mma=thr_mma_qk, + thr_tmem_load=thr_tmem_load, + mask_causal=self.is_causal, + mask_local=self.is_local ) softmax = SoftmaxSm100.create(softmax_scale_log2, rescale_threshold=8.0 if const_expr(self.q_dtype.width == 16) else 0.0, softmax_scale=softmax_scale) softmax.reset() @@ -1305,6 +1323,7 @@ def softmax_loop( mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=partial(mask_fn, mask_seqlen=False)) # Now that we no longer already have the 1st iteration, need mask_seqlen=True here + # tSrScale_r2t_shape = thr_tmem_store_scale.partition_S(tScScale).shape # tSrScale_r2t = cute.make_fragment(tSrScale_r2t_shape, Float32) # tSrScale_r2t[0] = softmax.row_sum[0] # cute.copy(thr_tmem_store_scale, tSrScale_r2t, tStScale_r2t) @@ -1385,18 +1404,13 @@ def softmax_step( 6. Coordinating pipeline synchronization between different processing stages """ tilePlikeFP32 = self.mma_tiler_qk[1] // Float32.width * self.v_dtype.width - tScS = thr_mma_qk.partition_C(cute.make_identity_tensor((self.mma_tiler_qk[0], self.mma_tiler_qk[1]))) - tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((self.m_block_size, 1))) - tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) - - tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.m_block_size, tilePlikeFP32))) - tScP = cute.make_tensor(tScS.iterator, tScP_layout) - - tScS_t2r_shape = thr_tmem_load.partition_D(tScS).shape + tScS = thr_mma_qk.partition_C(cute.make_identity_tensor(self.mma_tiler_qk[:2])) + tScScale = cute.composition(tScS, cute.make_layout((self.m_block_size, 1))) + tScP = cute.composition(tScS, cute.make_layout((self.m_block_size, tilePlikeFP32))) # Wait for Si cute.arch.mbarrier_wait(mbar_ptr + self.mbar_S_full_offset + stage, mma_si_consumer_phase) - tSrS_t2r = cute.make_fragment(tScS_t2r_shape, self.qk_acc_dtype) + tSrS_t2r = cute.make_fragment(thr_tmem_load.partition_D(tScS).shape, self.qk_acc_dtype) cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r) if cutlass.const_expr(self.score_mod is not None): self.apply_score_mod( @@ -1417,7 +1431,7 @@ def softmax_step( row_max, acc_scale = softmax.update_row_max(tSrS_t2r.load(), is_first) if const_expr(not is_first): - # tSrScale_r2t = cute.make_fragment(thr_tmem_store_scale.partition_S(tScS_vec).shape, Float32) + # tSrScale_r2t = cute.make_fragment(thr_tmem_store_scale.partition_S(tScScale).shape, Float32) # tSrScale_r2t[0] = acc_scale # cute.copy(thr_tmem_store_scale, tSrScale_r2t, tStScale_r2t) # cute.arch.fence_view_async_tmem_store() @@ -1479,21 +1493,19 @@ def correction_loop( SeqlenInfoCls: Callable, TileSchedulerCls: Callable, ): - tScS = thr_mma_qk.partition_C(cute.make_identity_tensor((self.mma_tiler_qk[0], self.mma_tiler_qk[1]))) - tStS_scale_layout = cute.composition(tStS.layout, cute.make_layout((self.m_block_size, 1))) - tStScales = tuple(cute.make_tensor(tStS.iterator + self.tmem_vec_offset[stage], tStS_scale_layout) + tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.correction_warp_ids)) + tScS = thr_mma_qk.partition_C(cute.make_identity_tensor(self.mma_tiler_qk[:2])) + tStScale_layout = cute.composition(tStS.layout, cute.make_layout((self.m_block_size, 1))) + tStScales = tuple(cute.make_tensor(tStS.iterator + self.tmem_vec_offset[stage], tStScale_layout) for stage in range(2)) - tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((self.m_block_size, 1))) - tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) + tScScale = cute.composition(tScS, cute.make_layout((self.m_block_size, 1))) tmem_load_v_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(1)), self.qk_acc_dtype, ) - tiled_tmem_load_vec = tcgen05.make_tmem_copy(tmem_load_v_atom, tStScales[0]) - tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.correction_warp_ids)) - thr_tmem_load_vec = tiled_tmem_load_vec.get_slice(tidx) + thr_tmem_load_vec = tcgen05.make_tmem_copy(tmem_load_v_atom, tStScales[0]).get_slice(tidx) tStScales_t2r = [thr_tmem_load_vec.partition_S(tStScales[stage]) for stage in range(2)] - tSrScale_t2r_shape = thr_tmem_load_vec.partition_D(tScS_vec).shape + tSrScale_t2r_shape = thr_tmem_load_vec.partition_D(tScScale).shape # First iter: no correction is required cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + 0) @@ -1640,7 +1652,7 @@ def correction_rescale( self, thr_mma: cute.core.ThrMma, tOtO: cute.Tensor, - thread_idx: Int32, + tidx: Int32, scale: Float32, ): """Rescale intermediate attention results based on softmax normalization factor. @@ -1655,9 +1667,7 @@ def correction_rescale( 2. Apply the scaling factor to all elements 3. Store the rescaled results back to tensor memory """ - cO = cute.make_identity_tensor((self.mma_tiler_pv[0], self.mma_tiler_pv[1])) - tOcO = thr_mma.partition_C(cO) - + tOcO = thr_mma.partition_C(cute.make_identity_tensor(self.mma_tiler_pv[:2])) corr_tile_size = 16 # tuneable parameter tmem_load_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), @@ -1667,17 +1677,10 @@ def correction_rescale( tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.pv_acc_dtype, ) - - tOtO_i_layout = cute.composition(tOtO.layout, cute.make_layout((self.m_block_size, corr_tile_size))) - tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((self.m_block_size, corr_tile_size))) - tOtO_i = cute.make_tensor(tOtO.iterator, tOtO_i_layout) - tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout) - - tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tOtO_i) - tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tOtO_i) - thr_tmem_load = tiled_tmem_load.get_slice(thread_idx) - thr_tmem_store = tiled_tmem_store.get_slice(thread_idx) - + tOtO_i = cute.composition(tOtO, cute.make_layout((self.m_block_size, corr_tile_size))) + tOcO_i = cute.composition(tOcO, cute.make_layout((self.m_block_size, corr_tile_size))) + thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tOtO_i).get_slice(tidx) + thr_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tOtO_i).get_slice(tidx) tOtO_t2r = thr_tmem_load.partition_S(tOtO_i) tOrO_t2r_shape = thr_tmem_load.partition_D(tOcO_i).shape tOtO_r2t = thr_tmem_store.partition_D(tOtO_i) @@ -1685,17 +1688,15 @@ def correction_rescale( frg_count = self.head_dim_v_padded // corr_tile_size tOrO_frg = cute.make_fragment((tOrO_t2r_shape, frg_count), self.pv_acc_dtype) for i in cutlass.range_constexpr(frg_count): - tOrO_frg_i = tOrO_frg[None, i] - tTMrO_i_layout = cute.composition(tOrO_frg_i.layout, cute.make_layout(tOrO_frg.shape[0])) - tTMrO_i = cute.make_tensor(tOrO_frg_i.iterator, tTMrO_i_layout) + tOrO_frg = cute.make_fragment(tOrO_t2r_shape, self.pv_acc_dtype) tOtO_t2r_i = cute.make_tensor(tOtO_t2r.iterator + i * corr_tile_size, tOtO_t2r.layout) - cute.copy(tiled_tmem_load, tOtO_t2r_i, tTMrO_i) - for j in cutlass.range_constexpr(0, cute.size(tTMrO_i), 2): - tTMrO_i[j], tTMrO_i[j + 1] = cute.arch.mul_packed_f32x2( - (tTMrO_i[j], tTMrO_i[j + 1]), (scale, scale), + cute.copy(thr_tmem_load, tOtO_t2r_i, tOrO_frg) + for j in cutlass.range(0, cute.size(tOrO_frg), 2, unroll_full=True): + tOrO_frg[j], tOrO_frg[j + 1] = cute.arch.mul_packed_f32x2( + (tOrO_frg[j], tOrO_frg[j + 1]), (scale, scale), ) tOtO_r2t_i = cute.make_tensor(tOtO_r2t.iterator + i * corr_tile_size, tOtO_r2t.layout) - cute.copy(tiled_tmem_store, tTMrO_i, tOtO_r2t_i) + cute.copy(thr_tmem_store, tOrO_frg, tOtO_r2t_i) cute.arch.fence_view_async_tmem_store() @cute.jit @@ -1703,7 +1704,7 @@ def correction_epilogue( self, thr_mma: cute.core.ThrMma, tOtO: cute.Tensor, - thread_idx: Int32, + tidx: Int32, scale: Float32, sO: cute.Tensor, ): @@ -1730,10 +1731,9 @@ def correction_epilogue( :type sO: cute.Tensor """ - cO = cute.make_identity_tensor((self.mma_tiler_pv[0], self.mma_tiler_pv[1])) corr_tile_size = 32 * 8 // self.o_dtype.width tOsO = thr_mma.partition_C(sO) - tOcO = thr_mma.partition_C(cO) + tOcO = thr_mma.partition_C(cute.make_identity_tensor(self.mma_tiler_pv[:2])) tOtO_i = cute.logical_divide(tOtO, cute.make_layout((self.m_block_size, corr_tile_size))) tOcO_i = cute.logical_divide(tOcO, cute.make_layout((self.m_block_size, corr_tile_size))) @@ -1748,23 +1748,16 @@ def correction_epilogue( epi_subtile, use_2cta_instrs=False, ) - - tiled_tmem_load = tcgen05.make_tmem_copy(tmem_copy_atom, tOtO_i[(None, None), 0]) - - thr_tmem_load = tiled_tmem_load.get_slice(thread_idx) + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_copy_atom, tOtO_i[(None, None), 0]).get_slice(tidx) + thr_tmem_load = tiled_tmem_load.get_slice(tidx) smem_copy_atom = sm100_utils_basic.get_smem_store_op( self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load ) - tiled_smem_store = cute.make_tiled_copy( - smem_copy_atom, - layout_tv=tiled_tmem_load.layout_dst_tv_tiled, - tiler_mn=tiled_tmem_load.tiler_mn, - ) + tiled_smem_store = cute.make_tiled_copy_D(smem_copy_atom, tiled_tmem_load) tOtO_t2r = thr_tmem_load.partition_S(tOtO_i[(None, None), None]) tOsO_s2r = thr_tmem_load.partition_D(tOsO_i[(None, None), None]) tOcO_t2r = thr_tmem_load.partition_D(tOcO_i[(None, None), None]) - for i in cutlass.range_constexpr(self.head_dim_v_padded // corr_tile_size): tOtO_t2r_i = tOtO_t2r[None, 0, 0, i] tOsO_r2s_i = tOsO_s2r[None, 0, 0, i] @@ -1774,11 +1767,9 @@ def correction_epilogue( tOrO_frg[j], tOrO_frg[j + 1] = cute.arch.mul_packed_f32x2( (tOrO_frg[j], tOrO_frg[j + 1]), (scale, scale), ) - tSMrO = cute.make_fragment(tOrO_frg.shape, self.o_dtype) - o_vec = tOrO_frg.load() - tSMrO.store(o_vec.to(self.o_dtype)) - cute.copy(tiled_smem_store, tSMrO, tOsO_r2s_i) - + tOrO_frg_cvt = cute.make_fragment(tOrO_frg.shape, self.o_dtype) + tOrO_frg_cvt.store(tOrO_frg.load().to(self.o_dtype)) + cute.copy(tiled_smem_store, tOrO_frg_cvt, tOsO_r2s_i) # fence view async shared cute.arch.fence_proxy( cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta, @@ -1801,26 +1792,20 @@ def epilogue_s2g( while work_tile.is_valid_tile: m_block, head_idx, batch_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) - if const_expr(not seqlen.has_cu_seqlens_q): - mO_cur = mO[None, None, head_idx, batch_idx] - else: - offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) - mO_cur = cute.domain_offset((offset, 0), mO[None, None, head_idx]) + mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx] gO = cute.local_tile(mO_cur, (self.m_block_size, self.head_dim_v_padded), (None, 0)) if const_expr(self.use_tma_O): - tOsO, tOgO = cpasync.tma_partition( - tma_atom_O, - 0, - cute.make_layout(1), - cute.group_modes(sO, 0, 2), - cute.group_modes(gO, 0, 2), + store_O, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_O, 0, cute.make_layout(1), sO, gO ) for stage in cutlass.range_constexpr(self.q_stage): # wait from corr, issue tma store on smem # 1. wait for O0 / O1 final - cute.arch.mbarrier_wait(mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase + ) # 2. copy O0 / O1 to gmem - cute.copy(tma_atom_O, tOsO[None, stage], tOgO[None, self.q_stage * m_block + stage]) + store_O(src_idx=stage, dst_idx=self.q_stage * m_block + stage) cute.arch.cp_async_bulk_commit_group() for stage in cutlass.range_constexpr(self.q_stage): # Ensure O0 / O1 buffer is ready to be released @@ -1867,9 +1852,7 @@ def epilogue_s2g( def load_Q( self, - tma_atom: cute.CopyAtom, - tQgQ: cute.Tensor, - tQsQ: cute.Tensor, + load_Q_fn: Callable, mbar_full_ptr: cute.Pointer, mbar_empty_ptr: cute.Pointer, block: Int32, @@ -1878,10 +1861,8 @@ def load_Q( ): cute.arch.mbarrier_wait(mbar_empty_ptr + stage, phase) with cute.arch.elect_one(): - cute.arch.mbarrier_arrive_and_expect_tx(mbar_full_ptr + stage, self.tma_copy_q_bytes) - cute.copy( - tma_atom, tQgQ[None, block], tQsQ[None, stage], tma_bar_ptr=mbar_full_ptr + stage - ) + cute.arch.mbarrier_arrive_and_expect_tx(mbar_full_ptr + stage, self.tma_copy_bytes["Q"]) + load_Q_fn(src_idx=block, dst_idx=stage, tma_bar_ptr=mbar_full_ptr + stage) @cute.jit def load_KV( @@ -1893,11 +1874,10 @@ def load_KV( mbar_empty_ptr: cute.Pointer, block: Int32, producer_state: cutlass.pipeline.PipelineState, - K_or_V: str, + K_or_V: Literal["K", "V"], page_idx: Optional[Int32] = None, ): assert K_or_V in ("K", "V") - tma_copy_bytes = self.tma_copy_k_bytes if const_expr(K_or_V == "K") else self.tma_copy_v_bytes stage, phase = producer_state.index, producer_state.phase cute.arch.mbarrier_wait(mbar_empty_ptr + stage, phase) if const_expr(K_or_V == "K" and self.uneven_kv_smem): @@ -1906,7 +1886,7 @@ def load_KV( if stage == 0: cute.arch.mbarrier_wait(mbar_empty_ptr + 1, phase) with cute.arch.elect_one(): - cute.arch.mbarrier_arrive_and_expect_tx(mbar_full_ptr + stage, tma_copy_bytes) + cute.arch.mbarrier_arrive_and_expect_tx(mbar_full_ptr + stage, self.tma_copy_bytes[K_or_V]) tXsX_cur = tXsX[None, stage] if const_expr(self.uneven_kv_smem): # Since this is the producer_state, the phase starts at 1, so we have to invert it @@ -1935,7 +1915,7 @@ def make_and_init_load_kv_pipeline(self, load_kv_mbar_ptr): num_stages=self.kv_stage, producer_group=load_kv_producer_group, consumer_group=load_kv_consumer_group, - tx_count=self.tma_copy_k_bytes, + tx_count=self.tma_copy_bytes["K"], ) # @cute.jit diff --git a/flash_attn/cute/pipeline.py b/flash_attn/cute/pipeline.py index 0dbc905b35b..541b0b5bed7 100644 --- a/flash_attn/cute/pipeline.py +++ b/flash_attn/cute/pipeline.py @@ -11,6 +11,7 @@ from cutlass.pipeline import PipelineAsync, PipelineState, Agent, CooperativeGroup from cutlass.pipeline import PipelineUserType, PipelineOp from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg +from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg # We deviate from cute-dsl implementation to use cute.arch.cluster_arrive_relaxed @@ -231,7 +232,115 @@ def consumer_release(self, state: PipelineState): TMA consumer release conditionally signals the empty buffer to the producer. """ # Only 1 thread per warp group signals the empty buffer. + if self.consumer_mask is None: # No cluster, 1 thread per warp group to signal + if_generate( + cute.arch.thread_idx()[0] % 128 == 0, + lambda: self.sync_object_empty.arrive(state.index, self.consumer_mask), + ) + else: + if_generate( + self.is_signalling_thread, + lambda: self.sync_object_empty.arrive(state.index, self.consumer_mask), + ) + + +@dataclass(frozen=True) +class PipelineTmaUmma(PipelineTmaUmmaOg): + @staticmethod + def create( + *, + num_stages: int, + producer_group: CooperativeGroup, + consumer_group: CooperativeGroup, + tx_count: int, + barrier_storage: cute.Pointer = None, + cta_layout_vmnk: Optional[cute.Layout] = None, + init_wait: cutlass.Constexpr[bool] = True, + ): + """ + This helper function computes any necessary attributes and returns an instance of PipelineTmaUmma. + :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers + :type barrier_storage: cute.Pointer + :param num_stages: Number of buffer stages for this pipeline + :type num_stages: Int32 + :param producer_group: `CooperativeGroup` for the producer agent + :type producer_group: CooperativeGroup + :param consumer_group: `CooperativeGroup` for the consumer agent + :type consumer_group: CooperativeGroup + :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage + :type tx_count: int + :param cta_layout_vmnk: Layout of the cluster shape + :type cta_layout_vmnk: cute.Layout | None + """ + if not isinstance(barrier_storage, cute.Pointer): + raise ValueError( + f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" + ) + + producer_type = PipelineOp.TmaLoad + consumer_type = PipelineOp.TCGen05Mma + + producer = (producer_type, producer_group) + consumer = (consumer_type, consumer_group) + + sync_object_full = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8), num_stages, producer, tx_count + ) + sync_object_empty = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8) + num_stages, num_stages, consumer + ) + + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: + # No mcast mask if not using clusters + producer_mask = None + # All threadblocks are leaders if not using clusters + is_leader_cta = True + else: + producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(cta_layout_vmnk) + is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk) + + cta_group = ( + cute.nvgpu.tcgen05.CtaGroup.ONE + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1 + else cute.nvgpu.tcgen05.CtaGroup.TWO + ) + + consumer_mask = producer_mask + + if const_expr(init_wait): + pipeline_init_wait(cta_layout_vmnk) + + return PipelineTmaUmma( + sync_object_full, + sync_object_empty, + num_stages, + producer_mask, + consumer_mask, + is_leader_cta, + cta_group, + ) + + def producer_acquire( + self, + state: PipelineState, + try_acquire_token: Optional[Boolean] = None, + extra_tx_count: int = 0, + ): + """ + TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. + """ if_generate( - cute.arch.thread_idx()[0] % 128 == 0, - lambda: self.sync_object_empty.arrive(state.index, self.consumer_mask), + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(state.index, state.phase), ) + if const_expr(extra_tx_count == 0): + if_generate( + self.is_leader_cta, + lambda: self.sync_object_full.arrive(state.index, self.producer_mask), + ) + else: + tx_count = self.sync_object_full.tx_count + extra_tx_count + if_generate( + self.is_leader_cta, + lambda: self.sync_object_full.arrive_and_expect_tx(state.index, tx_count), + ) diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 33c71c66ad4..4db768e328c 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -222,8 +222,10 @@ def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor: def transpose_view(a: cute.Tensor) -> cute.Tensor: """Transpose the first two dimensions of a tensor on smem.""" shape = (a.shape[1], a.shape[0], *a.shape[2:]) - order = (1, 0, *range(2, cute.rank(a))) - return cute.composition(a, cute.make_ordered_layout(shape, order=order)) + # order = (1, 0, *range(2, cute.rank(a))) + # return cute.composition(a, cute.make_ordered_layout(shape, order=order)) + stride = (a.layout.stride[1], a.layout.stride[0], *a.layout.stride[2:]) + return cute.make_tensor(a.iterator, cute.make_layout(shape, stride=stride)) def parse_swizzle_from_pointer(ptr: cute.Pointer) -> cute.Swizzle: From b4fac7d71bdbccf03dda1c5eddccdffb955ca2fe Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 19 Oct 2025 12:07:16 -0400 Subject: [PATCH 151/258] [Cute,Fwd,Sm100] Clean up mask --- flash_attn/cute/mask.py | 107 +++++++++++++--------------------------- 1 file changed, 35 insertions(+), 72 deletions(-) diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index 562f7900096..83046dec6a4 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -5,30 +5,39 @@ import cutlass import cutlass.cute as cute -from cutlass import Int32, const_expr +from cutlass import Float32, Int32, const_expr import flash_attn.cute.utils as utils @cute.jit -def mask_r2p_sm90(X: cute.Tensor, col_limit: Int32) -> None: - # R2P trick: Instead of comparing limit to 0, 1, 8, 9, 16, 17, ..., +def mask_r2p(X: cute.Tensor, col_limit: Int32, arch: int = 90, rank1: bool = False) -> None: + # Bit manipulation, compiles down to the R2P instruction + # For sm100: we know that tScS_t2r[i][1] == i, for the particular tmem copy atom we're using. + # For sm90: instead of comparing limit to 0, 1, 8, 9, 16, 17, ..., # we compare a transformed version of limit to 0, 1, 2, 3, 4, 5, ... - # This is so that we can use the R2P instruction. - assert cute.rank(X) in [1, 2], "mask_r2p_sm90 only supports rank 1 or 2 tensors" - col_limit_transformed = col_limit // 8 * 2 + min(col_limit % 8, 2) - ncol = const_expr(cute.size(X.shape[cute.rank(X) - 1])) + if const_expr(arch == 90): + col_limit_transformed = col_limit // 8 * 2 + min(col_limit % 8, 2) + else: + col_limit_transformed = col_limit + ncol = const_expr(cute.size(X.shape[cute.rank(X) - 1]) if not rank1 else cute.size(X.shape)) + # Ideally we'd move by 32 instead of 24, but mask >> i isn't correct for i == 31 for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)): + # Don't need to clamp to 32 since the shr.u32 instruction does that already col_limit_right_s = max(col_limit_transformed - s * 24, 0) + # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11 mask = (1 << col_limit_right_s) - 1 + # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction for i in cutlass.range_constexpr(min(24, ncol - s * 24)): in_bound = cutlass.Boolean(mask & (1 << i)) c = s * 24 + i - if const_expr(cute.rank(X) == 1): - X[c] = X[c] if in_bound else -cutlass.Float32.inf + if const_expr(rank1): + X[c] = X[c] if in_bound else -Float32.inf + # This is the equivalent of: + # X[s * 24 + i] = X[s * 24 + i] if col_limit_right_s <= i else -Float32.inf else: for r in cutlass.range_constexpr(cute.size(X.shape[0])): - X[r, c] = X[r, c] if in_bound else -cutlass.Float32.inf + X[r, c] = X[r, c] if in_bound else -Float32.inf @dataclass(frozen=True) @@ -75,9 +84,9 @@ def apply_mask( for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): oob = t0ScS_mn[0, c][COL] >= seqlenk_col_limit for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): - acc_S_mn[r, c] = -cutlass.Float32.inf if oob else acc_S_mn[r, c] + acc_S_mn[r, c] = -Float32.inf if oob else acc_S_mn[r, c] else: - mask_r2p_sm90(acc_S_mn, seqlenk_col_limit) + mask_r2p(acc_S_mn, seqlenk_col_limit, arch=90) else: # Causal or local if const_expr(not self.swap_AB): # If PackGQA, we split the work of compute divmod among threads in the same row @@ -113,12 +122,12 @@ def apply_mask( # traverse column index. for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): acc_S_mn[r, c] = ( - -cutlass.Float32.inf + -Float32.inf if t0ScS_mn[0, c][1] >= col_limit_right else acc_S_mn[r, c] ) else: - mask_r2p_sm90(acc_S_mn[r, None], col_limit_right) + mask_r2p(acc_S_mn[r, None], col_limit_right, arch=90, rank1=True) else: # Local local_row_offset_right = ( causal_row_offset + self.window_size_right @@ -154,7 +163,7 @@ def apply_mask( col_idx = t0ScS_mn[0, c][1] # only consider the column index, so the row index sets to 0. if col_idx >= col_limit_right or col_idx < col_limit_left: - acc_S_mn[r, c] = -cutlass.Float32.inf + acc_S_mn[r, c] = -Float32.inf else: # swap_AB assert self.qhead_per_kvhead_packgqa == 1 thr_row_offset = tScS_mn[0][ROW] @@ -171,7 +180,7 @@ def apply_mask( ) for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): acc_S_mn[r, c] = ( - -cutlass.Float32.inf + -Float32.inf if t0ScS_mn[r, 0][ROW] < row_limit_top else acc_S_mn[r, c] ) @@ -190,7 +199,7 @@ def apply_mask( for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): row_idx = t0ScS_mn[r, 0][ROW] acc_S_mn[r, c] = ( - -cutlass.Float32.inf + -Float32.inf if row_idx < row_limit_top or row_idx > row_limit_bot else acc_S_mn[r, c] ) @@ -212,52 +221,23 @@ def apply_mask_sm100( tScS = thr_mma.partition_C(cS) tScS_t2r = thr_tmem_load.partition_D(tScS) seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n + r2p = True if const_expr(not mask_causal and not mask_local): if const_expr(mask_seqlen): ncol = const_expr(cute.size(tScS_t2r.shape)) - if const_expr(False): + if const_expr(not r2p): for i in cutlass.range(ncol, unroll_full=True): # if tScS_t2r[i][1] >= seqlenk_col_limit: - # acc_S[i] = -cutlass.Float32.inf + # acc_S[i] = -Float32.inf # For some reason the 2 lines above generate really bad SASS - acc_S[i] = ( - -cutlass.Float32.inf - if tScS_t2r[i][1] >= seqlenk_col_limit - else acc_S[i] - ) + acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= seqlenk_col_limit else acc_S[i] else: - # Bit manipulation, compiles down to the R2P instruction - # We know that tScS_t2r[i][1] == i, for the particular tmem copy atom we're using - # Ideally we'd move by 32 instead of 24, but mask >> i isn't correct for i == 31 - # (see below). - for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)): - # Don't need to clamp to 32 since the shr.u32 instruction does that already - col_limit_right_s = max(seqlenk_col_limit - s * 24, 0) - # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11 - mask = (1 << col_limit_right_s) - 1 - # if tidx == 0: cute.printf("mask = 0x%x, col_limit_right_s = %d, col_limit_right_s = %d", mask, col_limit_right_s, col_limit_right_s) - # This needs to be range_constexpr, otherwise the compiler can't generate - # the R2P instruction - for i in cutlass.range_constexpr(min(24, ncol - s * 24)): - # mask >> i does not produce correct result for 0b11..11 >> 31 - # However, if we use utils.shr_u32, the compiler doesn't generate - # the R2P instruction, so it's slower. - # Instead we just move by 24 instead of 32. - # if tidx == 0: cute.printf("mask_i_bit = %d, after shift = 0x%x, i = %d, s = %d", mask_i_bit, utils.shr_u32(mask, i), i, s) - acc_S[s * 24 + i] = ( - acc_S[s * 24 + i] - if cutlass.Boolean(mask & (1 << i)) - else -cutlass.Float32.inf - ) - # This is the equivalent of: - # acc_S[s * 24 + i] = acc_S[s * 24 + i] if col_limit_right_s <= i else -cutlass.Float32.inf - # if tidx == 0: cute.print_tensor(acc_S) + mask_r2p(acc_S, seqlenk_col_limit, arch=100, rank1=True) else: # Causal or local causal_row_offset = 1 + self.seqlen_k - n_block * self.tile_n - self.seqlen_q row_idx = tScS_t2r[0][0] + m_block * self.tile_m if const_expr(self.qhead_per_kvhead_packgqa != 1): row_idx = row_idx // self.qhead_per_kvhead_packgqa - c = 0 if const_expr(mask_causal): col_limit_right = row_idx + causal_row_offset if const_expr(mask_seqlen): @@ -265,28 +245,11 @@ def apply_mask_sm100( # if cute.arch.thread_idx()[0] % 32 == 0: # cute.printf("tidx = %d, tidx tmem = %d, row_idx = %d, col_limit_right = %d, causal_row_offset = %d\n", cute.arch.thread_idx()[0], thr_tmem_load.thr_idx, row_idx, col_limit_right, causal_row_offset) ncol = const_expr(cute.size(tScS_t2r.shape)) - if const_expr(False): + if const_expr(not r2p): for i in cutlass.range(ncol, unroll_full=True): - acc_S[i] = ( - -cutlass.Float32.inf if tScS_t2r[i][1] >= col_limit_right else acc_S[i] - ) + acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= col_limit_right else acc_S[i] else: - # Bit manipulation, compiles down to the R2P instruction - # We know that tScS_t2r[i][1] == i, for the particular tmem copy atom we're using - for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)): - col_limit_right_s = max(col_limit_right - s * 24, 0) - # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11 - mask = (1 << col_limit_right_s) - 1 - # This needs to be range_constexpr, otherwise the compiler can't generate - # the R2P instruction - for i in cutlass.range_constexpr(min(24, ncol - s * 24)): - acc_S[s * 24 + i] = ( - acc_S[s * 24 + i] - if cutlass.Boolean(mask & (1 << i)) - else -cutlass.Float32.inf - ) - # This is the equivalent of: - # acc_S[s * 24 + i] = acc_S[s * 24 + i] if col_limit_right_s <= i else -cutlass.Float32.inf + mask_r2p(acc_S, col_limit_right, arch=100, rank1=True) else: local_row_offset_right = ( causal_row_offset + self.window_size_right @@ -313,7 +276,7 @@ def apply_mask_sm100( for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True): col_idx = tScS_t2r[i][1] acc_S[i] = ( - -cutlass.Float32.inf + -Float32.inf if col_idx >= col_limit_right or col_idx < col_limit_left else acc_S[i] ) From 9c14873cd4b06a4f9788e822fb36b5ee826c69ef Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 19 Oct 2025 12:12:10 -0400 Subject: [PATCH 152/258] [Cute] Reformat blackwell_helpers.py, block_info.py --- .pre-commit-config.yaml | 2 - flash_attn/cute/blackwell_helpers.py | 230 ++++++++++++++++++--------- flash_attn/cute/block_info.py | 18 +-- flash_attn/cute/mask.py | 6 +- 4 files changed, 166 insertions(+), 90 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0cb9effad2e..291258fe1de 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,8 +8,6 @@ repos: exclude: &cute_exclude | (?x)^flash_attn/cute/( __init__| - blackwell_helpers| - block_info| copy_utils| cute_dsl_utils| fast_math| diff --git a/flash_attn/cute/blackwell_helpers.py b/flash_attn/cute/blackwell_helpers.py index 0ec5af90826..4f61a40cdc3 100644 --- a/flash_attn/cute/blackwell_helpers.py +++ b/flash_attn/cute/blackwell_helpers.py @@ -3,7 +3,6 @@ import cutlass import cutlass.cute as cute from cutlass.cute.nvgpu import tcgen05 -from cutlass.cutlass_dsl import T from cutlass._mlir.dialects import llvm import flash_attn.cute.mma_sm100_desc as sm100_desc @@ -47,11 +46,15 @@ def gemm_ptx( idesc: int = cutlass.const_expr(sm100_desc.mma_op_to_idesc(op)) if cutlass.const_expr(not is_ts): sA_swizzle = parse_swizzle_from_pointer(sA.iterator) - smem_desc_base_a: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( - cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), - sA_swizzle, - sm100_desc.Major.K if cutlass.const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN - )) + smem_desc_base_a: int = cutlass.const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K + if cutlass.const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) smem_desc_base_a_lo = cutlass.const_expr(smem_desc_base_a_lo) smem_desc_a_hi = cutlass.const_expr(smem_desc_a_hi) @@ -59,24 +62,36 @@ def gemm_ptx( smem_desc_base_a = None smem_desc_base_a_lo, smem_desc_a_hi = None, None sB_swizzle = parse_swizzle_from_pointer(sB.iterator) - smem_desc_base_b: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( - cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), - sB_swizzle, - sm100_desc.Major.K if cutlass.const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN - )) + smem_desc_base_b: int = cutlass.const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K + if cutlass.const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) smem_desc_base_b_lo = cutlass.const_expr(smem_desc_base_b_lo) smem_desc_b_hi = cutlass.const_expr(smem_desc_b_hi) if cutlass.const_expr(not is_ts): - smem_desc_start_a_lo = cutlass.Int32(smem_desc_base_a_lo) | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator) + smem_desc_start_a_lo = cutlass.Int32( + smem_desc_base_a_lo + ) | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator) else: smem_desc_start_a_lo = None - smem_desc_start_b_lo = cutlass.Int32(smem_desc_base_b_lo) | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator) + smem_desc_start_b_lo = cutlass.Int32( + smem_desc_base_b_lo + ) | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator) for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): if cutlass.const_expr(not is_ts): - smem_desc_a_lo = smem_desc_start_a_lo + ((cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4) - smem_desc_b_lo = smem_desc_start_b_lo + ((cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4) + smem_desc_a_lo = smem_desc_start_a_lo + ( + (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4 + ) + smem_desc_b_lo = smem_desc_start_b_lo + ( + (cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4 + ) # with cute.arch.elect_one(): # cute.printf("smem_desc_a_lo = {}, smem_desc_b_lo = {}", smem_desc_a_lo, smem_desc_b_lo) # cute.printf("smem_desc_a_lo_correct = {}, smem_desc_b_lo_correct = {}", smem_desc_a_lo_correct, smem_desc_b_lo_correct) @@ -127,6 +142,7 @@ def gemm_ptx( asm_dialect=llvm.AsmDialect.AD_ATT, ) + @cute.jit def gemm_ptx_loop( op: cute.nvgpu.tcgen05.mma.MmaOp, @@ -145,11 +161,15 @@ def gemm_ptx_loop( idesc: int = cutlass.const_expr(sm100_desc.mma_op_to_idesc(op)) if cutlass.const_expr(not is_ts): sA_swizzle = parse_swizzle_from_pointer(sA.iterator) - smem_desc_base_a: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( - cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), - sA_swizzle, - sm100_desc.Major.K if cutlass.const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN - )) + smem_desc_base_a: int = cutlass.const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K + if cutlass.const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) smem_desc_base_a_lo = cutlass.const_expr(smem_desc_base_a_lo) smem_desc_a_hi = cutlass.const_expr(smem_desc_a_hi) @@ -157,31 +177,49 @@ def gemm_ptx_loop( smem_desc_base_a = None smem_desc_base_a_lo, smem_desc_a_hi = None, None sB_swizzle = parse_swizzle_from_pointer(sB.iterator) - smem_desc_base_b: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( - cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), - sB_swizzle, - sm100_desc.Major.K if cutlass.const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN - )) + smem_desc_base_b: int = cutlass.const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K + if cutlass.const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) smem_desc_base_b_lo = cutlass.const_expr(smem_desc_base_b_lo) smem_desc_b_hi = cutlass.const_expr(smem_desc_b_hi) if cutlass.const_expr(not is_ts): - offset_a = [(cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4 - for k in cutlass.range_constexpr(cute.size(tCrA.shape[2]))] + offset_a = [ + (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4 + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])) + ] else: - offset_a = [cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32 - for k in cutlass.range_constexpr(cute.size(tCrA.shape[2]))] - offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2]))] - offset_b = [(cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4 - for k in cutlass.range_constexpr(cute.size(tCrB.shape[2]))] - offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrB.shape[2]))] + offset_a = [ + cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32 + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])) + ] + offset_a_diff = [ + offset_a[k] - offset_a[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) + ] + offset_b = [ + (cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4 + for k in cutlass.range_constexpr(cute.size(tCrB.shape[2])) + ] + offset_b_diff = [ + offset_b[k] - offset_b[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrB.shape[2])) + ] if cutlass.const_expr(not is_ts): - smem_desc_start_a_lo = cutlass.Int32(smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)) + smem_desc_start_a_lo = cutlass.Int32( + smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator) + ) else: smem_desc_start_a_lo = None - smem_desc_start_b_lo = cutlass.Int32(smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)) + smem_desc_start_b_lo = cutlass.Int32( + smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator) + ) pred_str = "p" if isinstance(zero_init, cutlass.Boolean) else "0" if zero_init else "1" if cutlass.const_expr(not is_ts): llvm.inline_asm( @@ -288,11 +326,15 @@ def gemm_ptx_partial( idesc: int = cutlass.const_expr(sm100_desc.mma_op_to_idesc(op)) if cutlass.const_expr(not is_ts): sA_swizzle = parse_swizzle_from_pointer(sA.iterator) - smem_desc_base_a: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( - cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), - sA_swizzle, - sm100_desc.Major.K if cutlass.const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN - )) + smem_desc_base_a: int = cutlass.const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K + if cutlass.const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) smem_desc_base_a_lo = cutlass.const_expr(smem_desc_base_a_lo) smem_desc_a_hi = cutlass.const_expr(smem_desc_a_hi) @@ -300,26 +342,38 @@ def gemm_ptx_partial( smem_desc_base_a = None smem_desc_base_a_lo, smem_desc_a_hi = None, None sB_swizzle = parse_swizzle_from_pointer(sB.iterator) - smem_desc_base_b: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( - cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), - sB_swizzle, - sm100_desc.Major.K if cutlass.const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN - )) + smem_desc_base_b: int = cutlass.const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K + if cutlass.const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) smem_desc_base_b_lo = cutlass.const_expr(smem_desc_base_b_lo) smem_desc_b_hi = cutlass.const_expr(smem_desc_b_hi) - tCrA_layout = tCrA.layout if cutlass.const_expr(not is_ts) else cute.recast_layout(32, tCrA.element_type.width, tCrA.layout) + tCrA_layout = ( + tCrA.layout + if cutlass.const_expr(not is_ts) + else cute.recast_layout(32, tCrA.element_type.width, tCrA.layout) + ) offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(cute.size(tCrA.shape[2]))] offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))] offset_b = [cute.crd2idx((0, 0, k), tCrB.layout) for k in range(cute.size(tCrB.shape[2]))] offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))] if cutlass.const_expr(not is_ts): - smem_desc_start_a_lo = cutlass.Int32(smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)) + smem_desc_start_a_lo = cutlass.Int32( + smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator) + ) else: smem_desc_start_a_lo = None - smem_desc_start_b_lo = cutlass.Int32(smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)) + smem_desc_start_b_lo = cutlass.Int32( + smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator) + ) pred_str = "p" if isinstance(zero_init, cutlass.Boolean) else "0" if zero_init else "1" if cutlass.const_expr(not is_ts): assert mbar_ptr is None, "mbar_ptr must be None when a_src is not TMEM" @@ -368,7 +422,9 @@ def gemm_ptx_partial( ) else: input_args = [ - cutlass.Int32(cute.arch.make_warp_uniform(tCrA[None, None, 0].iterator.toint())).ir_value(), + cutlass.Int32( + cute.arch.make_warp_uniform(tCrA[None, None, 0].iterator.toint()) + ).ir_value(), cutlass.Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), cutlass.Int32(not zero_init).ir_value(), ] @@ -421,17 +477,26 @@ def gemm_ptx_partial( # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t" f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" ) - for k in range(1, cute.size(tCrA.shape[2]) if cutlass.const_expr(mbar_ptr is None) else cute.size(tCrA.shape[2]) // 4 * 3) + for k in range( + 1, + cute.size(tCrA.shape[2]) + if cutlass.const_expr(mbar_ptr is None) + else cute.size(tCrA.shape[2]) // 4 * 3, + ) ) + mbar_wait_str - + ("".join( - ( - f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" - f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" - f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + + ( + "".join( + ( + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in range(cute.size(tCrA.shape[2]) // 4 * 3, cute.size(tCrA.shape[2])) ) - for k in range(cute.size(tCrA.shape[2]) // 4 * 3, cute.size(tCrA.shape[2])) - ) if cutlass.const_expr(mbar_ptr is not None) else "") + if cutlass.const_expr(mbar_ptr is not None) + else "" + ) + "}\n", # "r,r,r", "r,r,r" if cutlass.const_expr(mbar_ptr is None) else "r,r,r,r,r", @@ -440,6 +505,7 @@ def gemm_ptx_partial( asm_dialect=llvm.AsmDialect.AD_ATT, ) + @cute.jit def gemm_ptx_partial1( op: cute.nvgpu.tcgen05.mma.MmaOp, @@ -464,36 +530,50 @@ def gemm_ptx_partial1( assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM" idesc: int = cutlass.const_expr(sm100_desc.mma_op_to_idesc(op)) if cutlass.const_expr(not is_ts): - smem_desc_base_a: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( - cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), - sA_swizzle, - sm100_desc.Major.K if cutlass.const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN - )) + smem_desc_base_a: int = cutlass.const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K + if cutlass.const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) smem_desc_base_a_lo = cutlass.const_expr(smem_desc_base_a_lo) smem_desc_a_hi = cutlass.const_expr(smem_desc_a_hi) else: smem_desc_base_a = None smem_desc_base_a_lo, smem_desc_a_hi = None, None - smem_desc_base_b: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( - cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), - sB_swizzle, - sm100_desc.Major.K if cutlass.const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN - )) + smem_desc_base_b: int = cutlass.const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K + if cutlass.const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) smem_desc_base_b_lo = cutlass.const_expr(smem_desc_base_b_lo) smem_desc_b_hi = cutlass.const_expr(smem_desc_b_hi) mask = [cutlass.Int32(0)] * 4 if cutlass.const_expr(not is_ts): - offset_a = [(cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 8) >> 4 - for k in range(cute.size(tCrA.shape[2]))] + offset_a = [ + (cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 8) >> 4 + for k in range(cute.size(tCrA.shape[2])) + ] else: - offset_a = [cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32 - for k in range(cute.size(tCrA.shape[2]))] + offset_a = [ + cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32 + for k in range(cute.size(tCrA.shape[2])) + ] offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))] - offset_b = [(cute.crd2idx((0, 0, k), sB_layout) * op.b_dtype.width // 8) >> 4 - for k in range(cute.size(tCrB.shape[2]))] + offset_b = [ + (cute.crd2idx((0, 0, k), sB_layout) * op.b_dtype.width // 8) >> 4 + for k in range(cute.size(tCrB.shape[2])) + ] offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))] if cutlass.const_expr(not is_ts): @@ -519,7 +599,7 @@ def gemm_ptx_partial1( mask[0].ir_value(), mask[1].ir_value(), mask[2].ir_value(), - mask[3].ir_value() + mask[3].ir_value(), ], "{\n\t" ".reg .pred leader_thread;\n\t" @@ -570,7 +650,7 @@ def gemm_ptx_partial1( mask[0].ir_value(), mask[1].ir_value(), mask[2].ir_value(), - mask[3].ir_value() + mask[3].ir_value(), ], "{\n\t" ".reg .pred leader_thread;\n\t" diff --git a/flash_attn/cute/block_info.py b/flash_attn/cute/block_info.py index 9f50321a28c..6382700bf16 100644 --- a/flash_attn/cute/block_info.py +++ b/flash_attn/cute/block_info.py @@ -20,13 +20,9 @@ class BlockInfo: qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 @cute.jit - def get_n_block_min_max( - self, seqlen_info: SeqlenInfoQK, m_block: Int32 - ) -> Tuple[Int32, Int32]: + def get_n_block_min_max(self, seqlen_info: SeqlenInfoQK, m_block: Int32) -> Tuple[Int32, Int32]: n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n) - if const_expr( - self.is_causal or (self.is_local and self.window_size_right is not None) - ): + if const_expr(self.is_causal or (self.is_local and self.window_size_right is not None)): m_idx_max = (m_block + 1) * self.tile_m if const_expr(self.qhead_per_kvhead_packgqa > 1): m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa) @@ -44,13 +40,15 @@ def get_n_block_min_max( return n_block_min, n_block_max @cute.jit - def get_m_block_min_max( - self, seqlen_info: SeqlenInfoQK, n_block: Int32 - ) -> Tuple[Int32, Int32]: + def get_m_block_min_max(self, seqlen_info: SeqlenInfoQK, n_block: Int32) -> Tuple[Int32, Int32]: m_block_max = cute.ceil_div(seqlen_info.seqlen_q, self.tile_m) m_block_min = 0 if const_expr(self.is_causal): - m_block_min = max(m_block_min, (n_block * self.tile_n + seqlen_info.seqlen_q - seqlen_info.seqlen_k) // self.tile_m) + m_block_min = max( + m_block_min, + (n_block * self.tile_n + seqlen_info.seqlen_q - seqlen_info.seqlen_k) + // self.tile_m, + ) return m_block_min, m_block_max @cute.jit diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index 83046dec6a4..b7e3d7c66ea 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -212,9 +212,9 @@ def apply_mask_sm100( n_block: Int32, thr_mma: cute.TiledMma, thr_tmem_load: cute.TiledCopy, - mask_seqlen: cutlass.Constexpr, - mask_causal: cutlass.Constexpr, - mask_local: cutlass.Constexpr, + mask_seqlen: cutlass.Constexpr[bool], + mask_causal: cutlass.Constexpr[bool], + mask_local: cutlass.Constexpr[bool] = False, ) -> None: assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" cS = cute.make_identity_tensor((self.tile_m, self.tile_n)) From aae355ea3d56a6815a2711f49165f5f275f84c77 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 19 Oct 2025 12:15:32 -0400 Subject: [PATCH 153/258] [Cute] Format mma_sm100_desc.py, seqlen_info.py --- .pre-commit-config.yaml | 2 -- flash_attn/cute/mma_sm100_desc.py | 6 ++++-- flash_attn/cute/seqlen_info.py | 11 ++++++----- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 291258fe1de..0bdc9b1b35b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,9 +17,7 @@ repos: flash_fwd_sm100| hopper_helpers| interface| - mma_sm100_desc| pack_gqa| - seqlen_info| testing| tile_scheduler| utils diff --git a/flash_attn/cute/mma_sm100_desc.py b/flash_attn/cute/mma_sm100_desc.py index 62f1bc742e1..16336c34686 100644 --- a/flash_attn/cute/mma_sm100_desc.py +++ b/flash_attn/cute/mma_sm100_desc.py @@ -138,9 +138,10 @@ def make_instr_desc( if N < 8 or N > 256 or (N & 7): raise ValueError("N must be a multiple of 8 in the range 8…256") - m_dim = M >> 4 # 5-bit field - n_dim = N >> 3 # 6-bit field + m_dim = M >> 4 # 5-bit field + n_dim = N >> 3 # 6-bit field + # fmt: off # --- pack the bit-fields ----------------------------------------------------- desc = 0 desc |= (0 & 0x3) << 0 # sparse_id2 (always 0 here) @@ -156,6 +157,7 @@ def make_instr_desc( desc |= (n_dim & 0x3F) << 17 # n_dim (6 bits) desc |= (m_dim & 0x1F) << 24 # m_dim (5 bits) desc |= (int(max_shift) & 0x3) << 30 # max_shift (2 bits) + # fmt: on return desc & 0xFFFF_FFFF # ensure 32-bit result diff --git a/flash_attn/cute/seqlen_info.py b/flash_attn/cute/seqlen_info.py index 792d84e2d64..792da01bd90 100644 --- a/flash_attn/cute/seqlen_info.py +++ b/flash_attn/cute/seqlen_info.py @@ -10,6 +10,7 @@ to compute various things like n_block_min, n_block_max, etc. """ + class SeqlenInfo: def __init__( self, @@ -60,19 +61,19 @@ def __init__( self.has_cu_seqlens_k: int = mCuSeqlensK is not None def offset_batch_Q(self, mQ: cute.Tensor, batch_idx: Int32, dim: int) -> cute.Tensor: - """Seqlen must be the first dimension of mQ - """ + """Seqlen must be the first dimension of mQ""" if const_expr(not self.has_cu_seqlens_q): idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim) return mQ[idx] else: - offset = self.offset_q if const_expr(cute.rank(mQ.shape[0]) == 1) else (0, self.offset_q) + offset = ( + self.offset_q if const_expr(cute.rank(mQ.shape[0]) == 1) else (0, self.offset_q) + ) idx = (offset,) + (0,) * (cute.rank(mQ) - 1) return cute.domain_offset(idx, mQ) def offset_batch_K(self, mK: cute.Tensor, batch_idx: Int32, dim: int) -> cute.Tensor: - """Seqlen must be the first dimension of mK - """ + """Seqlen must be the first dimension of mK""" if const_expr(not self.has_cu_seqlens_k): idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim) return mK[idx] From 83eb8d6c082a6bd9c6c986a890eddae7ad2a257e Mon Sep 17 00:00:00 2001 From: Ted Zadouri Date: Sun, 19 Oct 2025 13:03:36 -0400 Subject: [PATCH 154/258] sm100 bwd add kernel and update postprocess mask and barriers (#1945) --- flash_attn/cute/flash_bwd_postprocess.py | 233 +++ flash_attn/cute/flash_bwd_sm100.py | 2330 ++++++++++++++++++++++ flash_attn/cute/mask.py | 46 + flash_attn/cute/named_barrier.py | 6 + 4 files changed, 2615 insertions(+) create mode 100644 flash_attn/cute/flash_bwd_sm100.py diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index 9be406b19bb..a2d9e93b547 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -9,6 +9,7 @@ import cutlass import cutlass.cute as cute import cutlass.utils.hopper_helpers as sm90_utils_basic +import cutlass.utils.blackwell_helpers as sm100_utils_basic from cutlass.cute.nvgpu import cpasync, warp, warpgroup from cutlass import Float32, const_expr from cutlass.utils import LayoutEnum @@ -18,6 +19,7 @@ from flash_attn.cute import ampere_helpers as sm80_utils from flash_attn.cute import hopper_helpers as sm90_utils from flash_attn.cute.seqlen_info import SeqlenInfoQK +import cutlass.cute.nvgpu.tcgen05 as tcgen05 from flash_attn.cute.tile_scheduler import ( ParamsBase, SingleTileScheduler, @@ -386,3 +388,234 @@ def kernel( tdQgdQ[None, rest_m, None], pred=tdQpdQ[None, rest_m, None], ) + +class FlashAttentionBackwardPostprocess_sm100(FlashAttentionBackwardPostprocess): + def __init__( + self, + dtype: Type[cutlass.Numeric], + head_dim: int, + m_block_size: int = 128, + num_threads: int = 256, + AtomLayoutMdQ: int = 1, + dQ_swapAB: bool = False, + ): + super().__init__( + dtype=dtype, + head_dim=head_dim, + arch=90, # tmp dummy placement for now + tile_m=m_block_size, + num_threads=num_threads, + AtomLayoutMdQ=AtomLayoutMdQ, + dQ_swapAB=dQ_swapAB, + ) + + def _setup_attributes(self): + self.num_stages = self.tile_hdim // 32 # 2 for D=64, 4 for D=128 + + self.sdQaccum_layout = cute.make_layout(shape=(self.tile_m * 32, 2), stride=(1, self.tile_m * 32)) + self.epi_tile_q = (self.tile_m, self.tile_hdim) + self.sdQ_layout = sm100_utils_basic.make_smem_layout_epi( + self.dtype, + LayoutEnum.ROW_MAJOR, + self.epi_tile_q, + 1, + ) + + @cute.jit + def __call__( + self, + mdQaccum: cute.Tensor, + mdQ: cute.Tensor, + scale: cutlass.Float32, + stream: cuda.CUstream, + ): + # (b, h, s*d) -> (s*d, h, b) + mdQaccum = cute.make_tensor(mdQaccum.iterator, cute.select(mdQaccum.layout, mode=[2, 1, 0])) + # (b, s, h, d) -> (s, d, h, b) + mdQ = cute.make_tensor(mdQ.iterator, cute.select(mdQ.layout, mode=[1, 3, 2, 0])) + + self._setup_attributes() + + grid_dim = [ + cute.ceil_div(mdQ.shape[0], self.tile_m), + cute.size(mdQ.shape[2]), + cute.size(mdQ.shape[3]), + ] + + cta_group = tcgen05.CtaGroup.ONE + self.mma_tiler_dsk = (self.tile_m, self.tile_hdim) + + dS_major_mode = tcgen05.OperandMajorMode.MN + kt_major_mode_dsq = tcgen05.OperandMajorMode.MN + + tiled_mma_dsk = sm100_utils_basic.make_trivial_tiled_mma( + cutlass.BFloat16 , + dS_major_mode, + kt_major_mode_dsq, + cutlass.Float32, + cta_group, + self.mma_tiler_dsk, + ) + + dQ_cta_v_layout = cute.composition(cute.make_identity_layout(mdQ.shape), self.mma_tiler_dsk) + tma_store_op = cpasync.CopyBulkTensorTileS2GOp() + tma_atom_dQ, tma_tensor_dQ = cute.nvgpu.cpasync.make_tiled_tma_atom( + tma_store_op, + mdQ, + cute.select(self.sdQ_layout, mode=[0, 1]), + dQ_cta_v_layout, + ) + + buffer_align_bytes = 1024 + @cute.struct + class SharedStorage: + sdQaccum: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, cute.cosize(self.sdQaccum_layout)], + 128, + ] + + sdQ: cute.struct.Align[ + cute.struct.MemRange[self.dtype, cute.cosize(self.sdQ_layout)], + buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + self.kernel( + mdQaccum, + tma_tensor_dQ, + tma_atom_dQ, + self.sdQaccum_layout, + self.sdQ_layout, + tiled_mma_dsk, + scale, + ).launch( + grid=grid_dim, + block=[self.num_threads, 1, 1], + smem=self.shared_storage.size_in_bytes(), + stream=stream, + ) + @cute.kernel + def kernel( + self, + mdQaccum: cute.Tensor, + mdQ: cute.Tensor, + tma_atom_dQ: cute.CopyAtom, + sdQaccum_layout: cute.Layout, + sdQ_layout: cute.ComposedLayout, + tiled_mma_dsk: cute.TiledMma, + scale: cutlass.Float32, + ): + tidx = cute.arch.thread_idx()[0] + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + m_block, head_idx, batch_idx = cute.arch.block_idx() + + # SMEM + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + swz128 = cute.make_swizzle(3, 4, 3) + sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout, swizzle=swz128) + + sdQ = storage.sdQ.get_tensor(sdQ_layout.outer, swizzle=sdQ_layout.inner) + + mdQaccum_cur = mdQaccum[None, head_idx, batch_idx] + mdQ_cur = mdQ[None, None, head_idx, batch_idx] + + thr_mma_dsk = tiled_mma_dsk.get_slice(tidx) + dQacc_shape = thr_mma_dsk.partition_shape_C(self.mma_tiler_dsk[:2]) + tdQtdQ = thr_mma_dsk.make_fragment_C(dQacc_shape) + tdQtdQ = cute.make_tensor(tdQtdQ.iterator , tdQtdQ.layout) + + tmem_ld_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), cutlass.Float32) + tiled_tmem_ld = tcgen05.make_tmem_copy(tmem_ld_atom, tdQtdQ) + thr_tmem_ld = tiled_tmem_ld.get_slice(tidx) + + cdQ = cute.make_identity_tensor((self.mma_tiler_dsk[0], self.mma_tiler_dsk[1])) + tdQcdQ = thr_mma_dsk.partition_C(cdQ) + tdQcdQ_tensor = cute.make_tensor(tdQcdQ.iterator, tdQcdQ.layout) + tdQrdQ = thr_tmem_ld.partition_D(tdQcdQ_tensor) + + gdQaccum = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim, ) , (m_block, )) + + num_reduce_warps = 4 + num_reduce_threads = cute.arch.WARP_SIZE * num_reduce_warps + + + atom_universal_copy = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=128) + tiler_mn, layout_tv = cute.make_layout_tv(thr_layout=cute.make_layout(shape=num_reduce_threads, stride=1), val_layout=cute.make_layout(shape=4, stride=1)) + G2S_tiled_copy_dQaccum = cute.make_tiled_copy(atom_universal_copy, layout_tv=layout_tv, tiler_mn=tiler_mn) + + smem_thr_copy_g2s = G2S_tiled_copy_dQaccum.get_slice(tidx) + + # S->R + tdQrdQ_t2r = cute.make_fragment(tdQrdQ.shape, cutlass.Float32) + tiled_smem_store_s2r = cute.make_tiled_copy(atom_universal_copy, layout_tv=layout_tv, tiler_mn=tiler_mn) + + s2r_thr_copy_dQaccum = tiled_smem_store_s2r.get_slice(tidx) + tdQsdQ_s2r = s2r_thr_copy_dQaccum.partition_S(sdQaccum) + tdQrdQ_s2r = cute.make_tensor(tdQrdQ_t2r.iterator, tdQrdQ_t2r.shape) + + # R->S + smem_copy_atom = sm100_utils_basic.get_smem_store_op( + LayoutEnum.ROW_MAJOR, self.dtype, cutlass.Float32, tiled_tmem_ld + ) + tiled_smem_store_r2s = cute.make_tiled_copy( + smem_copy_atom, + layout_tv=tiled_tmem_ld.layout_dst_tv_tiled, + tiler_mn=tiled_tmem_ld.tiler_mn, + ) + tdQsdQ_r2s = thr_tmem_ld.partition_D(thr_mma_dsk.partition_C(sdQ)) + tdQrdQ_r2s = cute.make_fragment(tdQsdQ_r2s.shape, self.dtype) + + + num_stages = cute.size(tdQrdQ_t2r, mode=[1]) + for stage in cutlass.range_constexpr(num_stages): + + # G->S + gdQaccum_stage = cute.local_tile(gdQaccum, (self.tile_m * 32, ), (stage, ),) + + gdQaccum_layout_g2s = cute.make_layout(shape=(self.tile_m * 32, 1), stride=(1, 0)) + gdQaccum_stage_g2s = cute.make_tensor(cute.recast_ptr(gdQaccum_stage.iterator, swizzle_=swz128), gdQaccum_layout_g2s) + + tdQgdQ = smem_thr_copy_g2s.partition_S(gdQaccum_stage_g2s) + tdQsdQ = smem_thr_copy_g2s.partition_D(sdQaccum) + + cute.copy(smem_thr_copy_g2s, tdQgdQ[None, None, 0], tdQsdQ[None, None, 0]) + + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.barrier(barrier_id=6, number_of_threads=num_reduce_threads) + + # S -> R + tdQrdQ_s2r_cpy = tdQrdQ_s2r[None, stage, None, None] + tdQsdQ_s2r_p = tdQsdQ_s2r[None, None, 0] + tdQrdQ_r2s_cpy = cute.make_tensor(tdQrdQ_s2r_cpy.iterator, cute.make_layout(tdQsdQ_s2r_p.shape)) + + cute.copy(s2r_thr_copy_dQaccum, tdQsdQ_s2r_p, tdQrdQ_r2s_cpy) + + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.barrier(barrier_id=7, number_of_threads=num_reduce_threads) + + # R->S + tdQrdQ_r2s_cpy = cute.make_tensor(cute.recast_ptr(tdQrdQ_r2s_cpy.iterator), tdQrdQ_r2s[((None, 0), stage, 0, 0, 0)].shape) + dQ_vec = tdQrdQ_r2s_cpy.load() * scale + tdQrdQ_r2s[((None, 0), stage, 0, 0, 0)].store(dQ_vec.to(self.dtype)) + + + cute.copy(tiled_smem_store_r2s, tdQrdQ_r2s[None, None, None, None, 0], tdQsdQ_r2s[None, None, None, None, 0]) + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.barrier(barrier_id=8, number_of_threads=num_reduce_threads) + + + # S-> G + gdQ = cute.local_tile(mdQ_cur, (self.tile_m, self.tile_hdim), (None, 0)) + tdQsdQ, tdQgdQ = cpasync.tma_partition( + tma_atom_dQ, + 0, + cute.make_layout(1), + cute.group_modes(sdQ, 0, 2), + cute.group_modes(gdQ, 0, 2) + ) + + cute.copy(tma_atom_dQ, tdQsdQ[None, 0], tdQgdQ[None, m_block]) + + diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py new file mode 100644 index 00000000000..69ea1f04847 --- /dev/null +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -0,0 +1,2330 @@ +from ctypes import alignment +import enum +import math +from typing import Type, Tuple, Callable, Optional +from functools import partial + +import cuda.bindings.driver as cuda + +import cutlass +from cutlass._mlir.ir import _si1Attr +from cutlass.base_dsl.jit_executor import t +import cutlass.cute as cute +from cutlass import Float32, Int32, const_expr +from cutlass.cute.nvgpu import cpasync +import cutlass.cute.nvgpu.tcgen05 as tcgen05 + +import cutlass.utils.blackwell_helpers as sm100_utils_basic +import flash_attn.cute.utils as utils +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.seqlen_info import SeqlenInfo, SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo + +from flash_attn.cute import blackwell_helpers as sm100_utils +from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, StaticPersistentTileScheduler, ParamsBase +from cutlass.pipeline import PipelineAsync + +from cutlass._mlir.dialects import llvm +from cutlass.cutlass_dsl import dsl_user_op + +from cutlass._mlir.dialects import nvvm + +from flash_attn.cute import barrier +from flash_attn.cute.named_barrier import NamedBarrierBwdSm100 + + +@dsl_user_op +def tma_reduce_add_bulk_f32( + smem_ptr: cute.Pointer, + gmem_ptr: cute.Pointer, + store_bytes: cutlass.Int32, + *, loc=None, ip=None + ): + cute.make_mma_atom + smem_u32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [gmem_ptr.llvm_ptr, smem_u32, store_bytes.ir_value()], + "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;", + "l,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +class FlashAttentionBackwardSm100: + arch = 100 + + def __init__( + self, + head_dim: int, + head_dim_v: Optional[int] = None, + is_causal: bool = False, + is_local: bool = False, + qhead_per_kvhead: cutlass.Constexpr[int] = 1, + m_block_size: int = 128, + n_block_size: int = 128, + is_persistent: bool = False, + deterministic: bool = False, + ): + + # padding head_dim to a multiple of 16 as k_block_size + hdim_multiple_of = 16 + self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) + head_dim_v = head_dim_v if head_dim_v is not None else head_dim + self.same_hdim_kv = head_dim == head_dim_v + assert head_dim == head_dim_v, "head_dim and head_dim_v must be the same for now" + self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) + assert self.head_dim_padded == self.head_dim_v_padded, "head_dim_padded and head_dim_v_padded must be the same for now" + self.check_hdim_oob = head_dim != self.head_dim_padded + self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded + + self.m_block_size = m_block_size + self.n_block_size = n_block_size + # number of tma reduce adds per dQacc mma + self.dQaccum_reduce_stage = self.head_dim_padded // 32 + + # CTA tiler + self.cta_tiler = (m_block_size, n_block_size, self.head_dim_padded) + + # S = K @ Q.T + self.mma_tiler_kq = (n_block_size, m_block_size, self.head_dim_padded) + + # dP = V @ dO.T + self.mma_tiler_vdo = (n_block_size, m_block_size, self.head_dim_v_padded) + + # dV = P.T @ dO + self.mma_tiler_pdo = (n_block_size, self.head_dim_v_padded, m_block_size) + + # dK = dS.T @ Q (N, M) (M, D) + self.mma_tiler_dsq = (n_block_size, self.head_dim_v_padded, m_block_size) + + # dQ = dS @ K + self.mma_tiler_dsk = (m_block_size, self.head_dim_v_padded, n_block_size) + + + self.kq_acc_dtype = self.vdo_acc_dtype = self.pdo_acc_dtype = self.dsq_acc_dtype = self.dsk_acc_dtype = Float32 + + self.cluster_shape_mn = (1, 1) + self.is_persistent = is_persistent + self.is_causal = is_causal + self.is_local = False + self.qhead_per_kvhead = qhead_per_kvhead + self.pack_gqa = False + self.use_tma_store = True + self.deterministic = deterministic + + self.reduce_warp_ids = (0, 1, 2, 3) + self.compute_warp_ids = (4, 5, 6, 7, 8, 9, 10, 11) + self.mma_warp_id = 12 + self.load_warp_id = 13 + self.epi_warp_id = 14 + self.empty_warp_id = 15 + + # 16 warps -> 512 threads + self.threads_per_cta = cute.arch.WARP_SIZE * len( + ( + *self.reduce_warp_ids, + *self.compute_warp_ids, + self.mma_warp_id, + self.load_warp_id, + self.epi_warp_id, + self.empty_warp_id, + ) + ) + + # TMEM setup + SM100_TMEM_CAPACITY_COLUMNS = 512 + self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + + self.tmem_s_offset = 0 + self.tmem_p_offset = 0 # overlap with S + self.tmem_dV_offset = self.tmem_s_offset + self.n_block_size + self.tmem_dP_offset = self.tmem_dV_offset + self.head_dim_v_padded + self.tmem_dQaccum_offset = self.tmem_dP_offset # overlap with dP + self.tmem_dK_offset = self.tmem_dP_offset + self.m_block_size + + self.num_regs_reduce = 144 + self.num_regs_compute = 128 + self.num_regs_load = 96 + self.num_regs_mma = 112 + self.num_regs_empty = 24 + + self.buffer_align_bytes = 1024 + + self.num_compute_threads = cute.arch.WARP_SIZE * len(self.compute_warp_ids) + + def _setup_attributes(self): + + self.q_stage = 2 + self.k_stage = 1 + self.v_stage = 1 + self.do_stage = 1 + self.ds_stage = 1 + self.lse_stage = 1 + self.acc_stage = 1 + self.s_stage = 1 + self.dP_stage = 1 + self.dV_stage = 1 + self.dK_stage = 1 + self.dS_stage = 1 + self.dQaccum_mma_stage = 1 + self.sdQaccum_stage = 2 + self.psum_stage = 1 + self.p_tmem_stage = 1 + self.sdKdVaccum_stage = 2 + + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mdO: cute.Tensor, + mLSE: cute.Tensor, + mPsum: cute.Tensor, + mdQaccum: cute.Tensor, + mdK: cute.Tensor, + mdV: cute.Tensor, + softmax_scale: Float32, + stream: cuda.CUstream, + mdQ_semaphore: Optional[cute.Tensor] = None, + mdK_semaphore: Optional[cute.Tensor] = None, + mdV_semaphore: Optional[cute.Tensor] = None, + ): + self.q_dtype = mQ.element_type + self.k_dtype = mK.element_type + self.v_dtype = mV.element_type + self.do_dtype = mdO.element_type + self.lse_dtype = mLSE.element_type + self.psum_dtype = mPsum.element_type + self.dqaccum_dtype = mdQaccum.element_type + self.dk_dtype = mdK.element_type + self.dv_dtype = mdV.element_type + self.ds_dtype = self.q_dtype + + if const_expr(self.qhead_per_kvhead > 1): + assert self.dk_dtype.width == 32, "Must accumulate dK in float precision for GQA" + assert self.dv_dtype.width == 32, "Must accumulate dV in float precision for GQA" + + QKVdO_layout_transpose = [1, 3, 2, 0] # (b, s, n, h) --> (s, h, n, b) + mQ, mK, mV, mdO, mdK, mdV = [ + cute.make_tensor(t.iterator, cute.select(t.layout, mode=QKVdO_layout_transpose)) + for t in (mQ, mK, mV, mdO, mdK, mdV) + ] + + LSE_Psum_dQaccum_layout_transpose = [2, 1, 0] # (b, n, s) --> (s, n, b) + mLSE, mPsum, mdQaccum = [ + cute.make_tensor(t.iterator, cute.select(t.layout, mode=LSE_Psum_dQaccum_layout_transpose)) + for t in (mLSE, mPsum, mdQaccum) + ] + + dO_transpose = [1, 0, 2, 3] + mdO = cute.make_tensor(mdO.iterator, cute.select(mdO.layout, mode=dO_transpose)) + + semaphore_transpose = [2, 3, 1, 0] # (b, n, block, stage) -> (block, stage, n, b) + if const_expr(self.deterministic): + assert mdQ_semaphore is not None + mdQ_semaphore = cute.make_tensor(mdQ_semaphore.iterator, cute.select(mdQ_semaphore.layout, mode=semaphore_transpose)) + else: + mdQ_semaphore = None + + if const_expr(self.deterministic and self.qhead_per_kvhead > 1): + assert mdK_semaphore is not None + assert mdV_semaphore is not None + mdK_semaphore, mdV_semaphore = [ + cute.make_tensor(t.iterator, cute.select(t.layout, mode=semaphore_transpose)) + for t in (mdK_semaphore, mdV_semaphore) + ] + else: + mdK_semaphore = None + mdV_semaphore = None + + self.q_major_mode = cutlass.utils.LayoutEnum.from_tensor(mQ).mma_major_mode() + self.k_major_mode = cutlass.utils.LayoutEnum.from_tensor(mK).mma_major_mode() + self.v_major_mode = cutlass.utils.LayoutEnum.from_tensor(mV).mma_major_mode() + self.do_major_mode = cutlass.utils.LayoutEnum.from_tensor(mdO).mma_major_mode() + + self._setup_attributes() + cta_group = tcgen05.CtaGroup.ONE + + # S = K @ Q.T + tiled_mma_kq = sm100_utils_basic.make_trivial_tiled_mma( + self.k_dtype, + self.k_major_mode, + self.q_major_mode, + self.kq_acc_dtype, + cta_group, + self.mma_tiler_kq[:2], + ) + + # dV += P @ dO --> (K, MN) major + p_source = tcgen05.OperandSource.TMEM + self.p_major_mode = tcgen05.OperandMajorMode.K + tiled_mma_pdo = sm100_utils_basic.make_trivial_tiled_mma( + self.do_dtype, + self.p_major_mode, + self.do_major_mode, + self.pdo_acc_dtype, + cta_group, + self.mma_tiler_pdo[:2], + p_source, + ) + + # dP = V @ dO.T + self.dot_major_mode = tcgen05.OperandMajorMode.K + tiled_mma_vdo = sm100_utils_basic.make_trivial_tiled_mma( + self.do_dtype, + self.v_major_mode, + self.dot_major_mode, + self.vdo_acc_dtype, + cta_group, + self.mma_tiler_vdo[:2], + ) + + # dK += dS.T @ Q + self.dSt_major_mode = tcgen05.OperandMajorMode.K + self.q_major_mode_dsq = tcgen05.OperandMajorMode.MN + tiled_mma_dsq = sm100_utils_basic.make_trivial_tiled_mma( + self.ds_dtype, + self.dSt_major_mode, + self.q_major_mode_dsq, + self.dsq_acc_dtype, + cta_group, + self.mma_tiler_dsq[:2], + ) + + # dQ = dS @ K + self.dS_major_mode = tcgen05.OperandMajorMode.MN + self.kt_major_mode_dsq = tcgen05.OperandMajorMode.MN + tiled_mma_dsk = sm100_utils_basic.make_trivial_tiled_mma( + self.ds_dtype, + self.dS_major_mode, + self.kt_major_mode_dsq, + self.dsk_acc_dtype, + cta_group, + self.mma_tiler_dsk[:2], + ) + self.cluster_shape_mnk = (*self.cluster_shape_mn, 1) + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (tiled_mma_kq.thr_id.shape,), + ) + + # S = K @ Q.T + sK_layout = sm100_utils_basic.make_smem_layout_a( + tiled_mma_kq, self.mma_tiler_kq, self.k_dtype, self.k_stage, + ) + sQ_layout = sm100_utils_basic.make_smem_layout_b( + tiled_mma_kq, self.mma_tiler_kq, self.q_dtype, self.q_stage, + ) + + # dV += P @ dO + sdO_layout = sm100_utils_basic.make_smem_layout_b( + tiled_mma_pdo, self.mma_tiler_pdo, self.do_dtype, self.do_stage, + ) + + # dP = V @ dO.T + sV_layout = sm100_utils_basic.make_smem_layout_a( + tiled_mma_vdo, self.mma_tiler_vdo, self.v_dtype, self.v_stage, + ) + + sdOt_layout = sm100_utils_basic.make_smem_layout_b( + tiled_mma_vdo, self.mma_tiler_vdo, self.do_dtype, self.do_stage, + ) + + # dK += dS.T @ Q + sdSt_layout = sm100_utils_basic.make_smem_layout_a( + tiled_mma_dsq, self.mma_tiler_dsq, self.ds_dtype, self.ds_stage, + ) + + sQt_layout = sm100_utils_basic.make_smem_layout_b( + tiled_mma_dsq, self.mma_tiler_dsq, self.q_dtype, self.q_stage, + ) + + # dQaccum = dS @ K + sdS_layout = sm100_utils_basic.make_smem_layout_a( + tiled_mma_dsk, self.mma_tiler_dsk, self.q_dtype, self.ds_stage, + ) + sKt_layout = sm100_utils_basic.make_smem_layout_b( + tiled_mma_dsk, self.mma_tiler_dsk, self.k_dtype, self.k_stage, + ) + + sdQaccum_layout = cute.make_layout(shape=(self.m_block_size * 32, self.sdQaccum_stage ),) + sLSE_layout = cute.make_layout(shape=(self.m_block_size, self.lse_stage), stride=(1, cute.round_up(self.m_block_size, 64))) + sPsum_layout = cute.make_layout(shape=(self.m_block_size, self.psum_stage), stride=(1, cute.round_up(self.m_block_size, 64))) + + self.mdK_layout_enum = cutlass.utils.LayoutEnum.from_tensor(mdK) + self.mdV_layout_enum = cutlass.utils.LayoutEnum.from_tensor(mdV) + self.dK_major_mode = self.mdK_layout_enum.mma_major_mode() + self.dV_major_mode = self.mdV_layout_enum.mma_major_mode() + if const_expr(self.dK_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of mdK is wrong") + if const_expr(self.dV_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of mdV is wrong") + self.sdKdV_epi_tile = (self.n_block_size, 128 // (self.dk_dtype.width // 8)) # subtiles mma_tiler_dsq[:2] = mma_tiler_pdo[:2] + sdKdV_layout = sm100_utils_basic.make_smem_layout_epi( + self.dk_dtype, self.mdK_layout_enum, self.sdKdV_epi_tile, self.sdKdVaccum_stage, + ) + + self.tma_copy_dKdV_bytes = cute.size_in_bytes(self.dk_dtype, cute.select(sdKdV_layout, mode=[0,1])) + + if const_expr(self.use_tma_store): + if const_expr(self.dk_dtype.width == 32): + tma_copy_op_dKdV = cpasync.CopyReduceBulkTensorTileS2GOp() + else: + tma_copy_op_dKdV = cpasync.CopyBulkTensorTileS2GOp() + + tma_atom_dK, mdK_tma_tensor = cpasync.make_tiled_tma_atom( + tma_copy_op_dKdV, + mdK, + cute.select(sdKdV_layout, mode=[0, 1]), + self.sdKdV_epi_tile, + 1 # no mcast + ) + tma_atom_dV, mdV_tma_tensor = cpasync.make_tiled_tma_atom( + tma_copy_op_dKdV, + mdV, + cute.select(sdKdV_layout, mode=[0, 1]), + self.sdKdV_epi_tile, + 1 # no mcast + ) + else: + assert self.qhead_per_kvhead == 1, "Must use TMA reduce add for GQA" + mdV_tma_tensor = mdV + mdK_tma_tensor = mdK + tma_atom_dV = None + tma_atom_dK = None + + thr_layout_r2s_dKdV = cute.make_ordered_layout((self.n_block_size, 1), order=(1,0)) # 128 threads + val_layout_r2s_dKdV = cute.make_ordered_layout((1, 128 // self.dk_dtype.width), order=(1,0)) # 4 or 8 vals for 16 byte store + r2s_copy_atom_r2s_dKdV = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.dk_dtype, num_bits_per_copy=128,) + tiled_copy_r2s_dKdV = cute.make_tiled_copy_tv(r2s_copy_atom_r2s_dKdV, thr_layout_r2s_dKdV, val_layout_r2s_dKdV) + + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) + + # S = K @ Q.T + tma_atom_K, tma_tensor_K = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + mK, + cute.select(sK_layout, mode=[0, 1, 2]), + self.mma_tiler_kq, + tiled_mma_kq, + self.cluster_layout_vmnk.shape, + ) + + tma_atom_Q, tma_tensor_Q = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + mQ, + cute.select(sQ_layout, mode=[0, 1, 2]), + self.mma_tiler_kq, + tiled_mma_kq, + self.cluster_layout_vmnk.shape, + ) + + # dV += P @ dO + tma_atom_dO, tma_tensor_dO = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + mdO, + cute.select(sdO_layout, mode=[0, 1, 2]), + self.mma_tiler_pdo, + tiled_mma_pdo, + self.cluster_layout_vmnk.shape, + ) + tma_atom_LSE, tma_tensor_LSE = cute.nvgpu.cpasync.make_tiled_tma_atom( + tma_load_op, + mLSE, + cute.make_layout((self.m_block_size)), + (self.m_block_size, ), + ) + tma_atom_Psum, tma_tensor_Psum = cute.nvgpu.cpasync.make_tiled_tma_atom( + tma_load_op, + mPsum, + cute.make_layout((self.m_block_size)), + (self.m_block_size, ), + ) + + # dP = V @ dO.T + tma_atom_V, tma_tensor_V = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + mV, + cute.select(sV_layout, mode=[0, 1, 2]), + self.mma_tiler_vdo, + tiled_mma_vdo, + self.cluster_layout_vmnk.shape, + ) + + self.tma_copy_q_bytes = cute.size_in_bytes(self.q_dtype, cute.select(sQ_layout, mode=[0, 1, 2])) + self.tma_copy_k_bytes = cute.size_in_bytes(self.k_dtype, cute.select(sK_layout, mode=[0, 1, 2])) + self.tma_copy_v_bytes = cute.size_in_bytes(self.v_dtype, cute.select(sV_layout, mode=[0, 1, 2])) + self.tma_copy_do_bytes = cute.size_in_bytes(self.do_dtype, cute.select(sdO_layout, mode=[0, 1, 2])) + self.tma_copy_lse_bytes = self.m_block_size * 4 + self.tma_copy_psum_bytes = self.m_block_size * 4 + + TileScheduler = SingleTileScheduler + # TODO -- optimizer scheduler for causal + tile_sched_args = TileSchedulerArguments( + cute.ceil_div(cute.size(mK.shape[0]), self.cta_tiler[0]), + cute.size(mQ.shape[2]), # num_heads = num_query_heads + cute.size(mK.shape[3]), + cute.size(mK.shape[0]), + mQ.shape[1], + mV.shape[1], + total_q=cute.size(mQ.shape[0]), + tile_shape_mn=self.cta_tiler[:2], + mCuSeqlensQ=None, + mSeqUsedQ=None, + qhead_per_kvhead_packgqa=1, + element_size=self.k_dtype.width // 8, + is_persistent=self.is_persistent, + lpt=False, + ) + + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + self.tile_scheduler_cls = TileScheduler + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) + # cute.printf("grid_dim = {}", grid_dim) + + @cute.struct + class SharedStorage: + q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.q_stage] + k_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.k_stage] + v_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.v_stage] + lse_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.lse_stage] + do_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.do_stage] + lse_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.k_stage] + lse_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.k_stage] + psum_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.psum_stage] + psum_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.psum_stage] + s_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.s_stage] + dP_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dP_stage] + p_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.s_stage] + dS_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.ds_stage] + dV_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dV_stage] + dK_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dK_stage] + dQaccum_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dQaccum_mma_stage] + dQaccum_reduce_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dQaccum_mma_stage] + + # TMEM + tmem_holding_buf: Int32 + tmem_dealloc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1] + + # Smem tensors + sQ: cute.struct.Align[ + cute.struct.MemRange[self.q_dtype, cute.cosize(sQ_layout)], + self.buffer_align_bytes, + ] + sK: cute.struct.Align[ + cute.struct.MemRange[self.k_dtype, cute.cosize(sK_layout)], + self.buffer_align_bytes, + ] + sV: cute.struct.Align[ + cute.struct.MemRange[self.v_dtype, cute.cosize(sV_layout)], + self.buffer_align_bytes, + ] + sdO: cute.struct.Align[ + cute.struct.MemRange[self.do_dtype, cute.cosize(sdO_layout)], + self.buffer_align_bytes, + ] + sdS: cute.struct.Align[ + cute.struct.MemRange[self.ds_dtype, cute.cosize(sdSt_layout)], + 128, + ] + sLSE: cute.struct.Align[ + cute.struct.MemRange[self.lse_dtype, cute.cosize(sLSE_layout)], + 128, + ] + sPsum: cute.struct.Align[ + cute.struct.MemRange[self.psum_dtype, cute.cosize(sPsum_layout)], + 128, + ] + sdQaccum: cute.struct.Align[ + cute.struct.MemRange[self.dqaccum_dtype, cute.cosize(sdQaccum_layout)], + self.buffer_align_bytes, + ] + self.shared_storage = SharedStorage + + + LOG2_E = math.log2(math.e) + softmax_scale_log2 = softmax_scale * LOG2_E + self.kernel( + tma_tensor_Q, + tma_tensor_K, + tma_tensor_V, + tma_tensor_LSE, + tma_tensor_Psum, + tma_tensor_dO, + mdV, + mdK, + mdQaccum, + mdV_tma_tensor, + mdK_tma_tensor, + mdQ_semaphore, + mdK_semaphore, + mdV_semaphore, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + tma_atom_LSE, + tma_atom_Psum, + tma_atom_dO, + tma_atom_dV, + tma_atom_dK, + sQ_layout, + sQt_layout, + sK_layout, + sV_layout, + sLSE_layout, + sPsum_layout, + sdO_layout, + sdOt_layout, + sdSt_layout, + sdS_layout, + sKt_layout, + sdQaccum_layout, + sdKdV_layout, + tiled_mma_kq, + tiled_mma_pdo, + tiled_mma_vdo, + tiled_mma_dsq, + tiled_mma_dsk, + tiled_copy_r2s_dKdV, + softmax_scale, + softmax_scale_log2, + tile_sched_params, + ).launch( + grid=grid_dim, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + smem=self.shared_storage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + ) + + + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mLSE: cute.Tensor, + mPsum: cute.Tensor, + mdO: cute.Tensor, + mdV: cute.Tensor, + mdK: cute.Tensor, + mdQaccum: cute.Tensor, + mdV_tma_tensor: Optional[cute.Tensor], + mdK_tma_tensor: Optional[cute.Tensor], + mdQ_semaphore: Optional[cute.Tensor], + mdK_semaphore: Optional[cute.Tensor], + mdV_semaphore: Optional[cute.Tensor], + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + tma_atom_LSE: cute.CopyAtom, + tma_atom_Psum: cute.CopyAtom, + tma_atom_dO: cute.CopyAtom, + tma_atom_dV: Optional[cute.CopyAtom], + tma_atom_dK: Optional[cute.CopyAtom], + sQ_layout: cute.ComposedLayout, + sQt_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sLSE_layout: cute.Layout, + sPsum_layout: cute.Layout, + sdO_layout: cute.ComposedLayout, + sdOt_layout: cute.ComposedLayout, + sdSt_layout: cute.ComposedLayout, + sdS_layout: cute.ComposedLayout, + sKt_layout: cute.ComposedLayout, + sdQaccum_layout: cute.Layout, + sdKdV_layout: cute.ComposedLayout, + tiled_mma_kq: cute.TiledMma, + tiled_mma_pdo: cute.TiledMma, + tiled_mma_vdo: cute.TiledMma, + tiled_mma_dsq: cute.TiledMma, + tiled_mma_dsk: cute.TiledMma, + tiled_copy_r2s_dKdV: cute.TiledCopy, + softmax_scale: cutlass.Float32, + softmax_scale_log2: cutlass.Float32, + tile_sched_params: ParamsBase, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + + # Prefetch tma descriptor + if warp_idx == self.load_warp_id: + with cute.arch.elect_one(): + cpasync.prefetch_descriptor(tma_atom_Q) + cpasync.prefetch_descriptor(tma_atom_K) + cpasync.prefetch_descriptor(tma_atom_V) + cpasync.prefetch_descriptor(tma_atom_LSE) + cpasync.prefetch_descriptor(tma_atom_Psum) + cpasync.prefetch_descriptor(tma_atom_dO) + if const_expr(tma_atom_dV is not None): + cpasync.prefetch_descriptor(tma_atom_dV) + if const_expr(tma_atom_dK is not None): + cpasync.prefetch_descriptor(tma_atom_dK) + + # Alloc + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + k_full_mbar_ptr = storage.k_full_mbar_ptr.data_ptr() + v_full_mbar_ptr = storage.v_full_mbar_ptr.data_ptr() + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr.data_ptr() + lse_full_mbar_ptr = storage.lse_full_mbar_ptr.data_ptr() + lse_empty_mbar_ptr = storage.lse_empty_mbar_ptr.data_ptr() + psum_full_mbar_ptr = storage.psum_full_mbar_ptr.data_ptr() + psum_empty_mbar_ptr = storage.psum_empty_mbar_ptr.data_ptr() + dQaccum_reduce_mbar_ptr = storage.dQaccum_reduce_mbar_ptr.data_ptr() + + if warp_idx == self.load_warp_id: + cute.arch.mbarrier_init(k_full_mbar_ptr, len([self.load_warp_id])) + cute.arch.mbarrier_init(v_full_mbar_ptr, len([self.load_warp_id])) + cute.arch.mbarrier_init(tmem_dealloc_mbar_ptr, cute.arch.WARP_SIZE * len(self.compute_warp_ids)) + cute.arch.mbarrier_init(lse_full_mbar_ptr, len([self.compute_warp_ids])) + cute.arch.mbarrier_init(lse_empty_mbar_ptr, len([self.compute_warp_ids])) + cute.arch.mbarrier_init(psum_full_mbar_ptr, len([self.compute_warp_ids])) + cute.arch.mbarrier_init(psum_empty_mbar_ptr, len([self.compute_warp_ids])) + cute.arch.mbarrier_init(dQaccum_reduce_mbar_ptr, 1) + + pipeline_producer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, len([self.load_warp_id])) + pipeline_consumer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, len([self.mma_warp_id])) + + pipeline_q = cutlass.pipeline.PipelineTmaUmma.create( + barrier_storage=storage.q_mbar_ptr.data_ptr(), + num_stages=self.q_stage, + producer_group=pipeline_producer_group, + consumer_group=pipeline_consumer_group, + tx_count=self.tma_copy_q_bytes, + ) + + pipeline_do = cutlass.pipeline.PipelineTmaUmma.create( + barrier_storage=storage.do_mbar_ptr.data_ptr(), + num_stages=self.do_stage, + producer_group=pipeline_producer_group, + consumer_group=pipeline_consumer_group, + tx_count=self.tma_copy_do_bytes, + ) + + # UMMA producers and AsyncThread consumers + pipeline_producer_group_MMA_AsyncThread = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, len([self.mma_warp_id])) + pipeline_consumer_group_MMA_AsyncThread = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, cute.arch.WARP_SIZE * len(self.compute_warp_ids)) + + pipeline_s = cutlass.pipeline.PipelineUmmaAsync.create( + num_stages=self.s_stage, + producer_group=pipeline_producer_group_MMA_AsyncThread, + consumer_group=pipeline_consumer_group_MMA_AsyncThread, + barrier_storage=storage.s_mbar_ptr.data_ptr(), + ) + pipeline_dV = cutlass.pipeline.PipelineUmmaAsync.create( + num_stages=self.dV_stage, + producer_group=pipeline_producer_group_MMA_AsyncThread, + consumer_group=pipeline_consumer_group_MMA_AsyncThread, + barrier_storage=storage.dV_mbar_ptr.data_ptr(), + ) + pipeline_dK = cutlass.pipeline.PipelineUmmaAsync.create( + num_stages=self.dK_stage, + producer_group=pipeline_producer_group_MMA_AsyncThread, + consumer_group=pipeline_consumer_group_MMA_AsyncThread, + barrier_storage=storage.dK_mbar_ptr.data_ptr(), + ) + pipeline_consumer_group_MMA_AsyncThread_dQ = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, cute.arch.WARP_SIZE * len(self.reduce_warp_ids), alignment=128) # Compute + pipeline_dQaccum = cutlass.pipeline.PipelineUmmaAsync.create( + num_stages=self.dQaccum_mma_stage, + producer_group=pipeline_producer_group_MMA_AsyncThread, + consumer_group=pipeline_consumer_group_MMA_AsyncThread_dQ, + barrier_storage=storage.dQaccum_mbar_ptr.data_ptr(), + ) + pipeline_dP = cutlass.pipeline.PipelineUmmaAsync.create( + num_stages=self.dP_stage, + producer_group=pipeline_producer_group_MMA_AsyncThread, + consumer_group=pipeline_consumer_group_MMA_AsyncThread, + barrier_storage=storage.dP_mbar_ptr.data_ptr(), + ) + + # AsyncThread producers and UMMA consumers + pipeline_pdS_producer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, cute.arch.WARP_SIZE * len(self.compute_warp_ids)) # Compute + pipeline_pdS_consumer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, len([self.mma_warp_id])) # MMA + + pipeline_p = cutlass.pipeline.PipelineAsyncUmma.create( + num_stages=self.s_stage, + producer_group=pipeline_pdS_producer_group, + consumer_group=pipeline_pdS_consumer_group, + barrier_storage=storage.p_mbar_ptr.data_ptr(), + ) + + pipeline_dS = cutlass.pipeline.PipelineAsyncUmma.create( + num_stages=self.dS_stage, + producer_group=pipeline_pdS_producer_group, + consumer_group=pipeline_pdS_consumer_group, + barrier_storage=storage.dS_mbar_ptr.data_ptr(), + ) + + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) + sQt = cute.make_tensor(cute.recast_ptr(sQ.iterator, swizzle_=sQt_layout.inner), sQt_layout.outer) + sQ_pi = storage.sQ.get_tensor(sQ_layout) + + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + sKt = cute.make_tensor(cute.recast_ptr(sK.iterator, swizzle_=sKt_layout.inner), sKt_layout.outer) + + sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) + + sdSt = storage.sdS.get_tensor(sdSt_layout.outer, swizzle=sdSt_layout.inner) + sdSt_pi = storage.sdS.get_tensor(sdSt_layout) + + sdS = cute.make_tensor(cute.recast_ptr(sdSt.iterator, swizzle_=sdS_layout.inner), sdS_layout.outer) + + sdO = storage.sdO.get_tensor(sdO_layout.outer, swizzle=sdO_layout.inner) + sdOt = cute.make_tensor(cute.recast_ptr(sdO.iterator, swizzle_=sdOt_layout.inner), sdOt_layout.outer) + + sLSE_load = storage.sLSE.get_tensor(sLSE_layout) + sLSE_mma = storage.sLSE.get_tensor(cute.make_layout( + shape=(self.m_block_size, self.n_block_size, self.lse_stage), + stride=(0, 1, 0) + )) + + + sPsum_load = storage.sPsum.get_tensor(sPsum_layout) + sPsum_mma = storage.sPsum.get_tensor(cute.make_layout( + shape=(self.m_block_size, self.n_block_size, self.psum_stage), + stride=(0, 1, 0) + )) + + sdV = storage.sdO.get_tensor(sdKdV_layout.outer, swizzle=sdKdV_layout.inner, dtype=self.dk_dtype) + sdK = storage.sQ.get_tensor(sdKdV_layout.outer, swizzle=sdKdV_layout.inner, dtype=self.dk_dtype) + + assert cute.cosize(sdV) * self.dv_dtype.width <= cute.cosize(sdO) * self.do_dtype.width, "Not enough space for sdV" + assert cute.cosize(sdK) * self.dk_dtype.width <= cute.cosize(sQ) * self.q_dtype.width, "Not enough space for sdK" + + swz128 = cute.make_swizzle(3, 4, 3) + sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout, swizzle=swz128) + + # TMEM + # S + thr_mma_kq = tiled_mma_kq.get_slice(0) + Sacc_shape = thr_mma_kq.partition_shape_C(self.mma_tiler_kq[:2]) #(M, N) + tStS = thr_mma_kq.make_fragment_C(Sacc_shape) + tStS = cute.make_tensor(tStS.iterator, tStS.layout) + + # dV + thr_mma_pdo = tiled_mma_pdo.get_slice(0) + dvacc_shape = thr_mma_pdo.partition_shape_C(self.mma_tiler_pdo[:2]) + tdVtdV = thr_mma_pdo.make_fragment_C(dvacc_shape) + tdVtdV = cute.make_tensor(tdVtdV.iterator + self.tmem_dV_offset , tdVtdV.layout) + + # dK + thr_mma_dsq = tiled_mma_dsq.get_slice(0) + dkacc_shape = thr_mma_dsq.partition_shape_C(self.mma_tiler_dsq[:2]) + tdKtdK = thr_mma_dsq.make_fragment_C(dkacc_shape) + tdKtdK = cute.make_tensor(tdKtdK.iterator + self.tmem_dK_offset , tdKtdK.layout) + + # dQ + thr_mma_dsk = tiled_mma_dsk.get_slice(0) + dQacc_shape = thr_mma_dsk.partition_shape_C(self.mma_tiler_dsk[:2]) + tdQtdQ = thr_mma_dsk.make_fragment_C(dQacc_shape) + tdQtdQ = cute.make_tensor(tdQtdQ.iterator + self.tmem_dQaccum_offset , tdQtdQ.layout) + + # dP + thr_mma_vdo = tiled_mma_vdo.get_slice(0) + dPacc_shape = thr_mma_vdo.partition_shape_C(self.mma_tiler_vdo[:2]) + tdPtdP = thr_mma_vdo.make_fragment_C(dPacc_shape) + tdPtdP = cute.make_tensor(tdPtdP.iterator + self.tmem_dP_offset , tdPtdP.layout) + + block_info = BlockInfo( + self.m_block_size, + self.n_block_size, + self.is_causal, self.is_local, + None, None, + qhead_per_kvhead_packgqa=1, + ) + SeqlenInfoCls = partial( + SeqlenInfoQK, + seqlen_q_static=mQ.shape[0], + seqlen_k_static=mK.shape[0], + mCuSeqlensQ=None, mCuSeqlensK=None, + mSeqUsedQ=None, mSeqUsedK=None, + ) + TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params) + + # TODO: support local + AttentionMaskCls = partial( + AttentionMask, self.m_block_size, self.n_block_size, + ) + + cute.arch.sync_threads() + + # EMPTY + # (15) + if warp_idx == self.empty_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_empty) + + # EPI + # (14) + if warp_idx == self.epi_warp_id: + # currently no-op, could use for tma store/reduce + cute.arch.warpgroup_reg_dealloc(self.num_regs_empty) + + # LOAD + # (13) + if warp_idx == self.load_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_load) + self.load( + thr_mma_kq, + thr_mma_pdo, + thr_mma_vdo, + mQ, + mK, + mV, + mLSE, + mPsum, + mdO, + sQ, + sK, + sV, + sLSE_load, + sPsum_load, + sdO, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + tma_atom_LSE, + tma_atom_Psum, + tma_atom_dO, + pipeline_q, + lse_full_mbar_ptr, + lse_empty_mbar_ptr, + psum_full_mbar_ptr, + psum_empty_mbar_ptr, + pipeline_do, + k_full_mbar_ptr, + v_full_mbar_ptr, + block_info, + SeqlenInfoCls, + TileSchedulerCls, + ) + + # MMA + # (12) + if warp_idx == self.mma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_mma) + + # Alloc tmem buffer + tmem_alloc_cols = Int32(self.tmem_alloc_cols) + cute.arch.alloc_tmem(tmem_alloc_cols, storage.tmem_holding_buf) + cute.arch.sync_warp() + + self.mma( + tiled_mma_kq, + tiled_mma_pdo, + tiled_mma_vdo, + tiled_mma_dsq, + tiled_mma_dsk, + thr_mma_kq, + thr_mma_pdo, + thr_mma_vdo, + thr_mma_dsq, + thr_mma_dsk, + sQ, + sQt, + sK, + sV, + sdO, + sdOt, + sdSt, + sdS, + sKt, + sK_layout.inner, + sQ_layout.inner, + tStS, + tdVtdV, + tdKtdK, + tdPtdP, + tdQtdQ, + pipeline_q, + pipeline_do, + pipeline_s, + pipeline_p, + pipeline_dS, + pipeline_dV, + pipeline_dK, + pipeline_dP, + pipeline_dQaccum, + k_full_mbar_ptr, + v_full_mbar_ptr, + block_info, + SeqlenInfoCls, + TileSchedulerCls, + ) + cute.arch.relinquish_tmem_alloc_permit() + tmem_ptr = cute.arch.retrieve_tmem_ptr(Float32, alignment=16, ptr_to_buffer_holding_addr=storage.tmem_holding_buf) + + cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) + tmem_alloc_cols = Int32(self.tmem_alloc_cols) + cute.arch.dealloc_tmem(tmem_ptr, tmem_alloc_cols, is_two_cta=False) + + # Compute + # (4, 5, 6, 7, 8, 9, 10, 11) --> 8 warps + if warp_idx >= self.compute_warp_ids[0] and warp_idx <= self.compute_warp_ids[-1]: + cute.arch.warpgroup_reg_dealloc(self.num_regs_compute) # 8 warps + self.compute_loop( + thr_mma_kq, + thr_mma_pdo, + thr_mma_vdo, + thr_mma_dsq, + tStS, + sLSE_mma, + sPsum_mma, + tdVtdV, + tdKtdK, + mdV, + mdK, + sdSt, + sdS, + tdPtdP, + lse_full_mbar_ptr, + lse_empty_mbar_ptr, + psum_full_mbar_ptr, + psum_empty_mbar_ptr, + pipeline_s, + pipeline_p, + pipeline_dS, + pipeline_dV, + pipeline_dK, + pipeline_dP, + softmax_scale, + softmax_scale_log2, + block_info, + SeqlenInfoCls, + AttentionMaskCls, + TileSchedulerCls, + sdV, + sdK, + mdV_tma_tensor, + mdK_tma_tensor, + tma_atom_dV, + tma_atom_dK, + tiled_copy_r2s_dKdV, + mdK_semaphore, + mdV_semaphore, + ) + cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr) + + # Reduce + # (0, 1, 2, 3) - dQ + if warp_idx >= self.reduce_warp_ids[0] and warp_idx <= self.reduce_warp_ids[-1]: + cute.arch.warpgroup_reg_alloc(self.num_regs_reduce) + + self.dQacc_reduce( + mdQaccum, + sdQaccum, + thr_mma_dsk, + tdQtdQ, + pipeline_dQaccum, + dQaccum_reduce_mbar_ptr, + block_info, + SeqlenInfoCls, + TileSchedulerCls, + mdQ_semaphore, + ) + + return + + + @cute.jit + def load( + self, + thr_mma_kq: cute.core.ThrMma, + thr_mma_pdo: cute.core.ThrMma, + thr_mma_vdo: cute.core.ThrMma, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mLSE: cute.Tensor, + mPsum: cute.Tensor, + mdO: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + sLSE: cute.Tensor, + sPsum: cute.Tensor, + sdO: cute.Tensor, + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + tma_atom_LSE: cute.CopyAtom, + tma_atom_Psum: cute.CopyAtom, + tma_atom_dO: cute.CopyAtom, + pipeline_q: PipelineAsync, + lse_full_mbar_ptr: cute.Pointer, + lse_empty_mbar_ptr: cute.Pointer, + psum_full_mbar_ptr: cute.Pointer, + psum_empty_mbar_ptr: cute.Pointer, + pipeline_do: PipelineAsync, + k_full_mbar_ptr: cute.Pointer, + v_full_mbar_ptr: cute.Pointer, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx = cute.arch.thread_idx()[0] + + q_producer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, self.q_stage) + do_producer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, self.do_stage) + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + n_block, head_idx, batch_idx = work_tile.tile_idx + + seqlen = SeqlenInfoCls(batch_idx) + m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) + head_idx_kv = head_idx // self.qhead_per_kvhead + mQ_cur = mQ[None, None, head_idx, batch_idx] + mK_cur = mK[None, None, head_idx_kv, batch_idx] + mV_cur = mV[None, None, head_idx_kv, batch_idx] + mdO_cur = mdO[None, None, head_idx, batch_idx] + mLSE_cur = mLSE[None, head_idx, batch_idx] + mPsum_cur = mPsum[None, head_idx, batch_idx] + + gK = cute.local_tile(mK_cur, cute.select(self.mma_tiler_kq, mode=[0, 2]), (n_block, 0)) + tSgK = thr_mma_kq.partition_A(gK) + + gV = cute.local_tile(mV_cur, cute.select(self.mma_tiler_vdo, mode=[0, 2]), (n_block, 0)) + tdPgV = thr_mma_vdo.partition_A(gV) + + gQ = cute.local_tile(mQ_cur, cute.select(self.mma_tiler_kq, mode=[1, 2]), (None, 0)) + tSgQ = thr_mma_kq.partition_B(gQ) + + gLSE = cute.local_tile(mLSE_cur, (self.n_block_size, ), (None, )) + gPsum = cute.local_tile(mPsum_cur, (self.n_block_size, ), (None, )) + + gdO = cute.local_tile(mdO_cur, cute.select(self.mma_tiler_pdo, mode=[1, 2]), (0, None)) + tdVgdO = thr_mma_pdo.partition_B(gdO) + + tKsK, tKgK = cpasync.tma_partition( + tma_atom_K, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK, 0, 3), + ) + tVsV, tVgV = cpasync.tma_partition( + tma_atom_V, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sV, 0, 3), + cute.group_modes(tdPgV, 0, 3), + ) + tQsQ, tQgQ = cpasync.tma_partition( + tma_atom_Q, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sQ, 0, 3), + cute.group_modes(tSgQ, 0, 3), + ) + tdOsdO, tdOgdO = cpasync.tma_partition( + tma_atom_dO, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sdO, 0, 3), + cute.group_modes(tdVgdO, 0, 3), + ) + tLSEsLSE, tLSEgLSE = cpasync.tma_partition( + tma_atom_LSE, + 0, + cute.make_layout(1), + sLSE, + gLSE, + ) + tPsumsPsum, tPsumgPsum = cpasync.tma_partition( + tma_atom_Psum, + 0, + cute.make_layout(1), + sPsum, + gPsum, + ) + # K + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx(k_full_mbar_ptr, self.tma_copy_k_bytes) + cute.copy(tma_atom_K, tKgK, tKsK[None, 0], tma_bar_ptr=k_full_mbar_ptr) + + ###### Prologue + # Q0 + pipeline_q.producer_acquire(q_producer_state) + cute.copy( + tma_atom_Q, + tQgQ[None, m_block_max - 1], + tQsQ[None, q_producer_state.index], + tma_bar_ptr=pipeline_q.producer_get_barrier(q_producer_state) + ) + pipeline_q.producer_commit(q_producer_state) + q_producer_state.advance() + + # LSE + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx(lse_full_mbar_ptr, self.tma_copy_lse_bytes) + + cute.copy( + tma_atom_LSE, + tLSEgLSE[None, m_block_max - 1], + tLSEsLSE[None, 0], + tma_bar_ptr=lse_full_mbar_ptr, + ) + + # V + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx(v_full_mbar_ptr, self.tma_copy_v_bytes) + cute.copy(tma_atom_V, tVgV, tVsV[None, 0], tma_bar_ptr=v_full_mbar_ptr) + + # dO + pipeline_do.producer_acquire(do_producer_state) + cute.copy( + tma_atom_dO, + tdOgdO[None, m_block_max - 1], + tdOsdO[None, do_producer_state.index], + tma_bar_ptr=pipeline_do.producer_get_barrier(do_producer_state) + ) + pipeline_do.producer_commit(do_producer_state) + do_producer_state.advance() + + # Psum + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx(psum_full_mbar_ptr, self.tma_copy_psum_bytes) + + cute.copy( + tma_atom_Psum, + tPsumgPsum[None, m_block_max - 1], + tPsumsPsum[None, 0], + tma_bar_ptr=psum_full_mbar_ptr, + ) + lse_empty_consumer_phase = cute.Int32(0) + psum_empty_consumer_phase = cute.Int32(0) + + for i in cutlass.range(m_block_max - m_block_min - 1, unroll=1): + m_block = m_block_max - 2 - i + + # Q + self.load_M_tile(tma_atom_Q, tQgQ, tQsQ, pipeline_q, m_block, producer_state=q_producer_state) + pipeline_q.producer_commit(q_producer_state) + q_producer_state.advance() + + # LSE + cute.arch.mbarrier_wait(lse_empty_mbar_ptr, lse_empty_consumer_phase) + lse_empty_consumer_phase ^= 1 + + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx(lse_full_mbar_ptr, self.tma_copy_lse_bytes) + + cute.copy( + tma_atom_LSE, + tLSEgLSE[None, m_block], + tLSEsLSE[None, 0], + tma_bar_ptr=lse_full_mbar_ptr, + ) + + # dO + self.load_M_tile(tma_atom_dO, tdOgdO, tdOsdO, pipeline_do, m_block, producer_state=do_producer_state) + pipeline_do.producer_commit(do_producer_state) + do_producer_state.advance() + + # Psum + cute.arch.mbarrier_wait(psum_empty_mbar_ptr, psum_empty_consumer_phase) + psum_empty_consumer_phase ^= 1 + + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx(psum_full_mbar_ptr, self.tma_copy_psum_bytes) + + cute.copy( + tma_atom_Psum, + tPsumgPsum[None, m_block], + tPsumsPsum[None, 0], + tma_bar_ptr=psum_full_mbar_ptr, + ) + + pipeline_q.producer_tail(q_producer_state) + pipeline_do.producer_tail(do_producer_state) + + tile_scheduler.prefetch_next_work() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + + @cute.jit + def mma( + self, + tiled_mma_kq: cute.core.TiledMma, + tiled_mma_pdo: cute.core.TiledMma, + tiled_mma_vdo: cute.core.TiledMma, + tiled_mma_dsq: cute.core.TiledMma, + tiled_mma_dsk: cute.core.TiledMma, + thr_mma_kq: cute.core.ThrMma, + thr_mma_pdo: cute.core.ThrMma, + thr_mma_vdo: cute.core.ThrMma, + thr_mma_dsq: cute.core.ThrMma, + thr_mma_dsk: cute.core.ThrMma, + sQ: cute.Tensor, + sQt: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + sdO: cute.Tensor, + sdOt: cute.Tensor, + sdSt: cute.Tensor, + sdS: cute.Tensor, + sKt: cute.Tensor, + sK_swizzle: cute.Swizzle, + sQ_swizzle: cute.Swizzle, + tStS: cute.Tensor, + tdVtdV: cute.Tensor, + tdKtdK: cute.Tensor, + tdPtdP: cute.Tensor, + tdQacctdQacc: cute.Tensor, + pipeline_q: PipelineAsync, + pipeline_do: PipelineAsync, + pipeline_s: PipelineAsync, + pipeline_p: PipelineAsync, + pipeline_dS: PipelineAsync, + pipeline_dV: PipelineAsync, + pipeline_dK: PipelineAsync, + pipeline_dP: PipelineAsync, + pipeline_dQaccum: PipelineAsync, + full_key_mbar_ptr: cute.Pointer, + full_value_mbar_ptr: cute.Pointer, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + key_consumer_phase = cutlass.Int32(0) + + q_consumer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Consumer, self.q_stage) + q_dk_consumer_state = q_consumer_state + do_consumer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Consumer, self.do_stage) + + s_producer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, self.s_stage) + dP_producer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, self.dP_stage) + p_consumer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Consumer, self.s_stage) + dS_consumer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Consumer, self.dS_stage) + dV_producer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, self.dV_stage) + dK_producer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, self.dK_stage) + dQaccum_producer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, self.dQaccum_mma_stage) + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + + while work_tile.is_valid_tile: + n_block, head_idx, batch_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) # must be seqlen_k + + m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) + cute.arch.mbarrier_wait(full_key_mbar_ptr, phase=key_consumer_phase) + cute.arch.mbarrier_wait(full_value_mbar_ptr, phase=key_consumer_phase) + + key_consumer_phase ^= 1 + + # S = K @ Q.T sK and sQ + tSrK = thr_mma_kq.make_fragment_A(sK) + tSrQ = thr_mma_kq.make_fragment_B(sQ) + + # dP = V @ dOt + tdPrV = thr_mma_vdo.make_fragment_A(sV) + tdPrdOt = thr_mma_vdo.make_fragment_B(sdOt) + + # dK = dS.T @ Q + tdKrdS = thr_mma_dsq.make_fragment_A(sdSt) + tdKrQ = thr_mma_dsq.make_fragment_B(sQt) + + accumulate_dK = False + + # dV = P @ dO.T + tdVrdO = thr_mma_pdo.make_fragment_B(sdO) + p_tmem_layout = sm100_utils_basic.make_smem_layout_a(tiled_mma_pdo, self.mma_tiler_pdo, self.q_dtype, self.acc_stage,) + + tP = cute.make_tensor(tStS.iterator, p_tmem_layout.outer) + tdVrP = thr_mma_pdo.make_fragment_A(tP)[None, None, None, 0] + tdVrP = cute.make_tensor(tdVrP.iterator, tdVrP.layout) + + # dQ = dS @ K + tdQaccrdS = thr_mma_dsk.make_fragment_A(sdS) + tdQaccrK = thr_mma_dsk.make_fragment_B(sKt) + + + #----------------------------------------------------------- + ###### Prologue + #----------------------------------------------------------- + # 1. S = Q0 @ K.T + # 2. dP = V @ dO.T + # 3. dV = P @ dO + + # 1) S = Q0 @ K.T + pipeline_q.consumer_wait(q_consumer_state) + pipeline_s.producer_acquire(s_producer_state) + + num_k_phases = cute.size(tSrK, mode=[2]) + for kphase_idx in cutlass.range_constexpr(num_k_phases, unroll=1): + tiled_mma_kq.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) + cute.gemm( + tiled_mma_kq, + tStS, + tSrK[(None, None, kphase_idx, 0)], + tSrQ[(None, None, kphase_idx, q_consumer_state.index)], + tStS, + ) + + q_consumer_state.advance() + pipeline_s.producer_commit(s_producer_state) + s_producer_state.advance() + + # 2) dP = V @ dO.T + pipeline_do.consumer_wait(do_consumer_state) + pipeline_dP.producer_acquire(dP_producer_state) + + pipeline_dQaccum.producer_acquire(dQaccum_producer_state) + + for kphase_idx in cutlass.range_constexpr(cute.size(tdPrV, mode=[2]), unroll=1): + tiled_mma_vdo.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) + cute.gemm( + tiled_mma_vdo, + tdPtdP, + tdPrV[(None, None, kphase_idx, 0)], + tdPrdOt[(None, None, kphase_idx, do_consumer_state.index)], + tdPtdP, + ) + pipeline_dP.producer_commit(dP_producer_state); dP_producer_state.advance() + + # 3) dV = P.T @ dO + pipeline_p.consumer_wait(p_consumer_state) + + num_kphases = cute.size(tdVrP, mode=[2]) + for kphase_idx in cutlass.range_constexpr(num_kphases): + tiled_mma_pdo.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) + cute.gemm( + tiled_mma_pdo, + tdVtdV, + tdVrP[(None, None, kphase_idx)], + tdVrdO[(None, None, kphase_idx, do_consumer_state.index)], + tdVtdV, + ) + pipeline_p.consumer_release(p_consumer_state); p_consumer_state.advance() + pipeline_do.consumer_release(do_consumer_state); do_consumer_state.advance() + #----------------------------------------------------------- + ###### MAIN LOOP + #----------------------------------------------------------- + # 1. S = K @ Q.T + # 2. dQ = dS @ K + # 3. dK = dS.T @ Q + # 4. dP = V @ dO.T + # 5. dV = P.T @ dO + + for i in cutlass.range(m_block_max - m_block_min - 1, unroll=1): + # 1) S = K @ Q_i + pipeline_q.consumer_wait(q_consumer_state) + pipeline_s.producer_acquire(s_producer_state) + #''' + for kphase_idx in cutlass.range_constexpr(num_k_phases, unroll=1): + tiled_mma_kq.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) + cute.gemm( + tiled_mma_kq, + tStS, + tSrK[(None, None, kphase_idx, 0)], + tSrQ[(None, None, kphase_idx, q_consumer_state.index)], + tStS, + ) + + pipeline_s.producer_commit(s_producer_state) + s_producer_state.advance() + q_consumer_state.advance() + + # 2) dQ = dS @ K + pipeline_dS.consumer_wait(dS_consumer_state) + pipeline_dP.producer_acquire(dP_producer_state) + + num_kphases = cute.size(tdQaccrdS, mode=[2]) + for kphase_idx in cutlass.range_constexpr(num_kphases): + tiled_mma_dsk.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) + cute.gemm( + tiled_mma_dsk, + tdQacctdQacc, + tdQaccrdS[(None, None, kphase_idx, dS_consumer_state.index)], + tdQaccrK[(None, None, kphase_idx, 0)], + tdQacctdQacc, + ) + pipeline_dQaccum.producer_commit(dQaccum_producer_state) ; dQaccum_producer_state.advance() + + # 3) dK = dS.T @ Q + num_kphases = cute.size(tdKrdS, mode=[2]) + for kphase_idx in cutlass.range_constexpr(num_kphases, unroll=1): + tiled_mma_dsq.set(tcgen05.Field.ACCUMULATE, accumulate_dK) + cute.gemm( + tiled_mma_dsq, + tdKtdK, + tdKrdS[(None, None, kphase_idx, 0)], + tdKrQ[(None, None, kphase_idx, q_dk_consumer_state.index)], + tdKtdK, + ) + accumulate_dK = True + + pipeline_q.consumer_release(q_dk_consumer_state) ; q_dk_consumer_state.advance() + pipeline_dS.consumer_release(dS_consumer_state); dS_consumer_state.advance() + + #4) dP = V @ dO.T + pipeline_do.consumer_wait(do_consumer_state) + + pipeline_dQaccum.producer_acquire(dQaccum_producer_state) + + for kphase_idx in cutlass.range_constexpr(cute.size(tdPrV, mode=[2]), unroll=1): + tiled_mma_vdo.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) + cute.gemm( + tiled_mma_vdo, + tdPtdP, + tdPrV[(None, None, kphase_idx, 0)], + tdPrdOt[(None, None, kphase_idx, do_consumer_state.index)], + tdPtdP, + ) + pipeline_dP.producer_commit(dP_producer_state); dP_producer_state.advance() + + # 5) dV += P @ dO + pipeline_p.consumer_wait(p_consumer_state) + + num_kphases = cute.size(tdVrP, mode=[2]) + for kphase_idx in cutlass.range_constexpr(num_kphases): + tiled_mma_pdo.set(tcgen05.Field.ACCUMULATE, True) + cute.gemm( + tiled_mma_pdo, + tdVtdV, + tdVrP[(None, None, kphase_idx)], + tdVrdO[(None, None, kphase_idx, do_consumer_state.index)], + tdVtdV, + ) + + pipeline_p.consumer_release(p_consumer_state); p_consumer_state.advance() + pipeline_do.consumer_release(do_consumer_state); do_consumer_state.advance() + + pipeline_dV.producer_acquire(dV_producer_state); pipeline_dV.producer_commit(dV_producer_state); dV_producer_state.advance() + + pipeline_s.producer_tail(s_producer_state) + pipeline_dP.producer_tail(dP_producer_state) + pipeline_dV.producer_tail(dV_producer_state) + + #----------------------------------------------------------- + ###### Remaining 2 + #----------------------------------------------------------- + # 1) dK += dS.T @ Q + pipeline_dS.consumer_wait(dS_consumer_state) + + num_kphases = cute.size(tdKrdS, mode=[2]) + for kphase_idx in cutlass.range_constexpr(num_kphases): + tiled_mma_dsq.set(tcgen05.Field.ACCUMULATE, accumulate_dK) + cute.gemm( + tiled_mma_dsq, + tdKtdK, + tdKrdS[(None, None, kphase_idx, dS_consumer_state.index)], + tdKrQ[(None, None, kphase_idx, q_dk_consumer_state.index)], + tdKtdK, + ) + accumulate_dK = True + + pipeline_dK.producer_acquire(dK_producer_state); + pipeline_dK.producer_commit(dK_producer_state); dK_producer_state.advance() + + # 2) dQaccum = dS @ K + num_kphases = cute.size(tdQaccrdS, mode=[2]) + for kphase_idx in cutlass.range_constexpr(num_kphases, unroll=1): + tiled_mma_dsk.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) + cute.gemm( + tiled_mma_dsk, + tdQacctdQacc, + tdQaccrdS[(None, None, kphase_idx, dS_consumer_state.index)], + tdQaccrK[(None, None, kphase_idx, 0)], + tdQacctdQacc, + ) + pipeline_dQaccum.producer_commit(dQaccum_producer_state) ; dQaccum_producer_state.advance() + pipeline_q.consumer_release(q_dk_consumer_state); q_dk_consumer_state.advance() + pipeline_dS.consumer_release(dS_consumer_state); dS_consumer_state.advance() + + pipeline_dK.producer_tail(dK_producer_state) + pipeline_dQaccum.producer_tail(dQaccum_producer_state) + + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + + @cute.jit + def split_wg(self, thr_tensor: cute.Tensor, wg_idx: cutlass.Int32, num_wg: cutlass.Constexpr[cutlass.Int32]): + reduced_shape = cute.product_each(thr_tensor.shape) + rank = len(reduced_shape) + if const_expr(reduced_shape[1] > 1): + assert rank >= 2, "Need rank >= 2 for thr_tensor in split_wg" + t = cute.logical_divide(thr_tensor, (reduced_shape[0], reduced_shape[1] // num_wg)) + coord = (None, (None, wg_idx)) + (None, ) * (rank - 2) + else: + assert rank >= 3, "Need rank >= 3 for thr_tensor in split_wg" + if const_expr(rank == 3): + t = cute.logical_divide( + thr_tensor, (reduced_shape[0], reduced_shape[1], reduced_shape[2] // num_wg)) + coord = (None, None, (None, wg_idx), ) + (None, ) * (rank - 3) + else: + t = cute.logical_divide(thr_tensor, (reduced_shape[0], reduced_shape[1], reduced_shape[2], reduced_shape[3] // num_wg)) + coord = (None, None, None, (None, wg_idx), ) + (None, ) * (rank - 4) + return t[coord] + + + @cute.jit + def compute_loop( + self, + thr_mma_kq: cute.core.ThrMma, + thr_mma_pdo: cute.core.ThrMma, + thr_mma_vdo: cute.core.ThrMma, + thr_mma_dsq: cute.core.ThrMma, + tStS: cute.Tensor, + sLSE_2D: cute.Tensor, + sPsum_2D: cute.Tensor, + tdVtdV: cute.Tensor, + tdKtdK: cute.Tensor, + mdV: cute.Tensor, + mdK: cute.Tensor, + sdSt: cute.Tensor, + sdSt_pi: cute.Tensor, + tdPtdP: cute.Tensor, + lse_full_mbar_ptr: cute.Pointer, + lse_empty_mbar_ptr: cute.Pointer, + psum_full_mbar_ptr: cute.Pointer, + psum_empty_mbar_ptr: cute.Pointer, + pipeline_s: PipelineAsync, + pipeline_p: PipelineAsync, + pipeline_dS: PipelineAsync, + pipeline_dV: PipelineAsync, + pipeline_dK: PipelineAsync, + pipeline_dP: PipelineAsync, + softmax_scale: cutlass.Float32, + softmax_scale_log2: cutlass.Float32, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + AttentionMaskCls: Callable, + TileSchedulerCls: Callable, + sdV: Optional[cute.Tensor], + sdK: Optional[cute.Tensor], + mdV_tma_tensor: Optional[cute.Tensor], + mdK_tma_tensor: Optional[cute.Tensor], + tma_atom_dV: Optional[cute.CopyAtom], + tma_atom_dK: Optional[cute.CopyAtom], + tiled_copy_r2s_dKdV: Optional[cute.TiledCopy], + mdK_semaphore: Optional[cute.Tensor], + mdV_semaphore: Optional[cute.Tensor], + ): + # tix: [128...384] 8 warps + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) # 4-11 + + tidx = cute.arch.thread_idx()[0] % 128 # 0...128 + wg_idx = (cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.compute_warp_ids))) // 128 + num_wg = (cute.arch.WARP_SIZE * len(self.compute_warp_ids) // 128) # 2 + + # wg_idx: + # 0: [256...384] + # 1: [128...256] + + tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32) + tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), Float32) + + s_consumer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Consumer, self.s_stage) + p_producer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, self.s_stage) + dS_producer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, self.ds_stage) + + dP_consumer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Consumer, self.dP_stage) + + lse_consumer_phase = psum_consumer_phase = cute.Int32(0) + + sub_packed_f32x2 = partial(cute.arch.calc_packed_f32x2_op, src_c=None, calc_func=nvvm.sub_packed_f32x2, rnd=nvvm.RoundingModeKind.RN ) + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + n_block, head_idx, batch_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + + m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) + + mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) + # TODO: condition mask_seqlen + mask_fn = partial( + mask.apply_mask_sm100_transposed, + n_block=n_block, mask_seqlen=True, mask_causal=self.is_causal, mask_local=self.is_local + ) + + # Mainloop + for i in cutlass.range(m_block_max - m_block_min, unroll=1): + m_block = m_block_max - 1 - i + + pipeline_s.consumer_wait(s_consumer_state) + pipeline_p.producer_acquire(p_producer_state) + + if warp_idx == self.compute_warp_ids[0]: + cute.arch.mbarrier_wait(lse_full_mbar_ptr, lse_consumer_phase) + lse_consumer_phase ^= 1 + + tiled_tmem_ld = tcgen05.make_tmem_copy(tmem_load_atom, tStS) + thr_tmem_ld = tiled_tmem_ld.get_slice(tidx) + + tileP_f32_like = self.mma_tiler_kq[0] // 32 * self.v_dtype.width # (128, 64) + tStP = cute.make_tensor( + tStS.iterator, + cute.composition(tStS.layout, cute.make_layout((self.m_block_size, tileP_f32_like))), + ) + + tiled_tmem_st = tcgen05.make_tmem_copy(tmem_store_atom, tStP) + thr_tmem_st = tiled_tmem_st.get_slice(tidx) + + #### TMEM + tStS_t2r_p = thr_tmem_ld.partition_S(tStS) + tStS_t2r = self.split_wg(tStS_t2r_p, wg_idx, num_wg) + + #### RMEM + tScS = thr_mma_kq.partition_C(cute.make_identity_tensor((self.mma_tiler_kq[0], self.mma_tiler_kq[1]))) + tScS_tensor = cute.make_tensor(tScS.iterator, tScS.layout) + tScS_t2r_p = thr_tmem_ld.partition_D(tScS_tensor) + tScS_t2r = self.split_wg(tScS_t2r_p, wg_idx, num_wg) + + tSrS_t2r = cute.make_fragment(tScS_t2r.shape, Float32) # 64 + + #### TMEM->RMEM (Load S from TMEM) + cute.copy(tiled_tmem_ld, tStS_t2r, tSrS_t2r) + cute.arch.fence_view_async_tmem_load() + + #### Sync for load fence and LSE + cute.arch.barrier(barrier_id=int(NamedBarrierBwdSm100.Compute), number_of_threads=self.num_compute_threads) + + #### APPLY MASK + if const_expr(self.is_causal or self.is_local): + mask_fn(tSrS_t2r, tScS_t2r, m_block=m_block, ) + + #--------------------------------------------- + #### P = exp(S - LSE) + #--------------------------------------------- + + #### RMEM (coordinates for P) + cP_f32 = cute.make_tensor( + tScS.iterator, + cute.composition(tScS.layout, cute.make_layout((self.m_block_size, tileP_f32_like))) + ) + + tScP_r2t_p = thr_tmem_st.partition_S(cP_f32) + tScP_r2t = self.split_wg(tScP_r2t_p, wg_idx, num_wg) + + tStP_r2t_p = thr_tmem_st.partition_D(tStP) + tStP_r2t = self.split_wg(tStP_r2t_p, wg_idx, num_wg) + + #### Compute P = exp(S * scale - LSE) + tLSE = thr_tmem_ld.partition_D(sLSE_2D) + # split to wg0 & wg1 + tLSErLSE_p = cute.make_tensor(cute.recast_ptr(tLSE.iterator), cute.make_layout((tScS_t2r_p.shape[0], (tScS_t2r_p.shape[1] // num_wg, num_wg), 1, 1))) + tLSErLSE = tLSErLSE_p[None, (None, wg_idx), None, None] + + + WIDTH = cute.arch.WARP_SIZE + CLAMP = WIDTH - 1 + MAC = (0 << 8) | CLAMP + FULL = cute.arch.FULL_MASK + + lidx = cute.arch.lane_idx() + + + tSrP_r2t_f32 = cute.make_fragment(tScP_r2t[None, None, 0].shape, Float32) # 16 + tSrP_r2t = cute.make_tensor(cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), tSrS_t2r[None, 0, None, None].layout) + + for i in cutlass.range_constexpr(cute.size(tStP_r2t, mode=[2]), unroll=1): + + own0 = tLSErLSE[(lidx, 0), i, 0, 0] + own1 = tLSErLSE[(lidx+1, 0), i, 0, 0] + #own1 = cute.arch.shuffle_sync(own0, offset=((lidx + 1) & CLAMP), + # mask=FULL, mask_and_clamp=MAC) + + for j in cutlass.range_constexpr(0, cute.size(tSrP_r2t), 2, unroll=1): + lse_j = cute.arch.shuffle_sync(own0, offset=j, mask=FULL, mask_and_clamp=MAC) + lse_j1 = cute.arch.shuffle_sync(own1, offset=j, mask=FULL, mask_and_clamp=MAC) + + tSrS_t2r[j, i, 0, 0], tSrS_t2r[j+1, i, 0, 0] = cute.arch.fma_packed_f32x2(( + (tSrS_t2r[j, i, 0, 0], tSrS_t2r[j+1, i, 0, 0])), + (softmax_scale_log2, softmax_scale_log2), + (-lse_j, -lse_j1)) + + tSrS_t2r[j, i, 0, 0] = cute.arch.exp2(tSrS_t2r[j, i, 0, 0]) + tSrS_t2r[j+1, i, 0, 0] = cute.arch.exp2(tSrS_t2r[j+1, i, 0, 0]) + + tSrP_r2t[j, 0, 0] = tSrS_t2r[j, i, 0, 0].to(self.q_dtype) + tSrP_r2t[j+1, 0, 0] = tSrS_t2r[j+1, i, 0, 0].to(self.q_dtype) + + cute.copy(thr_tmem_st, tSrP_r2t_f32[None, None], tStP_r2t[None, None, i]) + + cute.arch.fence_view_async_tmem_store() + cute.arch.barrier(barrier_id=int(NamedBarrierBwdSm100.Compute), number_of_threads=self.num_compute_threads) + + pipeline_p.producer_commit(p_producer_state) + p_producer_state.advance() + + pipeline_s.consumer_release(s_consumer_state) + s_consumer_state.advance() + + if warp_idx == self.compute_warp_ids[0]: + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(lse_empty_mbar_ptr) + + #--------------------------------------------- + # dS.T = P.T * (dP.T - D) + #--------------------------------------------- + if warp_idx == self.compute_warp_ids[0]: + cute.arch.mbarrier_wait(psum_full_mbar_ptr, psum_consumer_phase) + psum_consumer_phase ^= 1 + + pipeline_dP.consumer_wait(dP_consumer_state) + pipeline_dS.producer_acquire(dS_producer_state) + + #### TMEM->RMEM (Load dP from TMEM) + tiled_tmem_ld_dP = tcgen05.make_tmem_copy(tmem_load_atom, tdPtdP) + thr_tmem_ld_dP = tiled_tmem_ld_dP.get_slice(tidx) + + tdPtdP_t2r_p = thr_tmem_ld_dP.partition_S(tdPtdP) # + tdPtdP_t2r = self.split_wg(tdPtdP_t2r_p, wg_idx, num_wg) + + #### TMEM->RMEM (Load dP from TMEM) + cdP = cute.make_identity_tensor((self.mma_tiler_vdo[0], self.mma_tiler_vdo[1])) + tdPcdP = thr_mma_vdo.partition_C(cdP) + tdPcdP_tensor = cute.make_tensor(tdPcdP.iterator, tdPcdP.layout) + + tdPcdP_t2r_p = thr_tmem_ld_dP.partition_D(tdPcdP_tensor) + tdPcdP_t2r = self.split_wg(tdPcdP_t2r_p, wg_idx, num_wg) + tdPrdP_t2r = cute.make_fragment(tdPcdP_t2r[(None, 0, None, None)].shape, Float32) # ((32,1),1,1) + + #### Sync for load fence and Psum + cute.arch.barrier(barrier_id=int(NamedBarrierBwdSm100.Compute), number_of_threads=self.num_compute_threads) + + ##### dS.T = P.T * (dP.T - Psum) + sdSt_mn = cute.make_tensor(sdSt_pi.iterator, cute.composition(sdSt_pi.layout, cute.make_layout((self.m_block_size, self.n_block_size)))) + tdKsdS = cute.composition(sdSt_mn[(None, wg_idx), tidx], cute.make_layout(tSrS_t2r.shape)) + + tSrS_t2r_bf16 = cute.make_tensor(cute.recast_ptr(tSrS_t2r.iterator, dtype=self.ds_dtype), tSrS_t2r.shape) + + tPsum = thr_tmem_ld.partition_D(sPsum_2D) + tPsumrPsum_p = cute.make_tensor(cute.recast_ptr(tPsum.iterator), cute.make_layout((tScS_t2r_p.shape[0], (tScS_t2r_p.shape[1] // num_wg, num_wg), 1, 1))) + tPsumrPsum = tPsumrPsum_p[None, (None, wg_idx), None, None] # self.split_wg(tLSErLSE_p, wg_idx, num_wg) + + for i in cutlass.range_constexpr(cute.size(tSrS_t2r, mode=[1]), unroll=1): + cute.copy(thr_tmem_ld_dP, tdPtdP_t2r[None, i, None, None], tdPrdP_t2r) + cute.arch.fence_view_async_tmem_load() + + own0 = tPsumrPsum[(lidx, 0), i, 0, 0] + own1 = tPsumrPsum[(lidx+1, 0), i, 0, 0] + + for j in cutlass.range_constexpr(0, cute.size(tdPrdP_t2r), 2, unroll=1): + + psum_j = cute.arch.shuffle_sync(own0, offset=j, mask=FULL, mask_and_clamp=MAC) + psum_j1 = cute.arch.shuffle_sync(own1, offset=j, mask=FULL, mask_and_clamp=MAC) + + tdPrdP_t2r[j, 0, 0], tdPrdP_t2r[j+1, 0, 0] = sub_packed_f32x2( + (tdPrdP_t2r[j, 0, 0], tdPrdP_t2r[j+1, 0, 0]), + (psum_j, psum_j1) + ) + + tSrS_t2r[j, i, 0, 0], tSrS_t2r[j+1, i, 0, 0] = cute.arch.mul_packed_f32x2( + (tSrS_t2r[j, i, 0, 0], tSrS_t2r[j+1, i, 0, 0]), + (tdPrdP_t2r[j, 0, 0], tdPrdP_t2r[j+1, 0, 0]) + ) + + tSrS_t2r_bf16[j, i, 0, 0] = tSrS_t2r[j, i, 0, 0].to(self.ds_dtype) + tSrS_t2r_bf16[j+1, i, 0, 0] = tSrS_t2r[j+1, i, 0, 0].to(self.ds_dtype) + + cute.autovec_copy(tSrS_t2r_bf16[None, i, 0, 0], tdKsdS[None, i, 0, 0]) + + pipeline_dP.consumer_release(dP_consumer_state) + dP_consumer_state.advance() + + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.barrier(barrier_id=int(NamedBarrierBwdSm100.Compute), number_of_threads=self.num_compute_threads) + + pipeline_dS.producer_commit(dS_producer_state) + dS_producer_state.advance() + + if warp_idx == self.compute_warp_ids[0]: + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(psum_empty_mbar_ptr) + + if const_expr(not self.use_tma_store): + self.epilogue_dKV( + tidx, + warp_idx, + batch_idx, + head_idx, + n_block, + thr_mma_pdo, + thr_mma_dsq, + tdVtdV, + tdKtdK, + mdV, + mdK, + pipeline_dV, + pipeline_dK, + softmax_scale, + ) + else: + thr_copy_r2s_dKdV = tiled_copy_r2s_dKdV.get_slice(tidx) + #### STORE dV + self.epilogue_dK_or_dV_tma( + tidx, + batch_idx, + head_idx, + n_block, + thr_mma_pdo, + tdVtdV, + mdV_tma_tensor, + sdV, + tma_atom_dV, + thr_copy_r2s_dKdV, + pipeline_dV, + softmax_scale, + False, # apply scale + int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id + mdV_semaphore, + ) + #### STORE dK + self.epilogue_dK_or_dV_tma( + tidx, + batch_idx, + head_idx, + n_block, + thr_mma_dsq, + tdKtdK, + mdK_tma_tensor, + sdK, + tma_atom_dK, + thr_copy_r2s_dKdV, + pipeline_dK, + softmax_scale, + True, # apply scale + int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id + mdK_semaphore, + ) + + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + @cute.jit + def dQacc_reduce( + self, + mdQaccum: cute.Tensor, + sdQaccum: cute.Tensor, + thr_mma_dsk: cute.core.ThrMma, + tdQtdQ: cute.Tensor, + pipeline_dQ: PipelineAsync, + dQaccum_reduce_mbar_ptr: cute.Pointer, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + mdQ_semaphore: Optional[cute.Tensor], + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * 4) + + dQ_consumer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Consumer, self.dQaccum_mma_stage) + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + + # TMEM -> RMEM + tmem_ld_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32) + tiled_tmem_ld = tcgen05.make_tmem_copy(tmem_ld_atom, tdQtdQ) + thr_tmem_ld = tiled_tmem_ld.get_slice(tidx) + + tdQtdQ_t2r = thr_tmem_ld.partition_S(tdQtdQ) + + cdQ = cute.make_identity_tensor((self.mma_tiler_dsk[0], self.mma_tiler_dsk[1])) + tdQcdQ = thr_mma_dsk.partition_C(cdQ) + tdQcdQ_tensor = cute.make_tensor(tdQcdQ.iterator, tdQcdQ.layout) + tdQrdQ = thr_tmem_ld.partition_D(tdQcdQ_tensor) + + num_reduce_threads = cute.arch.WARP_SIZE * len(self.reduce_warp_ids) + + atom_universal_copy = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.dqaccum_dtype, num_bits_per_copy=128) + thr_layout = cute.make_layout(shape=128, stride=1) + val_layout = cute.make_layout(shape=4, stride=1) + + tiler_mn, layout_tv = cute.make_layout_tv(thr_layout=thr_layout, val_layout=val_layout) + tiled_smem_store = cute.make_tiled_copy(atom_universal_copy, layout_tv=layout_tv, tiler_mn=tiler_mn) + + + smem_thr_copy_dQaccum = tiled_smem_store.get_slice(tidx) + tdQsdQ = smem_thr_copy_dQaccum.partition_D(sdQaccum) + store_bytes = cutlass.Int32(self.m_block_size * 32 * 4) + + if const_expr(self.deterministic): + read_flag = False + else: + read_flag = True + + reduce_phase = cutlass.Int32(0) + if cute.arch.thread_idx()[0] == 0: + cute.arch.mbarrier_arrive(dQaccum_reduce_mbar_ptr) + + cute.arch.barrier(barrier_id=int(NamedBarrierBwdSm100.dQaccReduce), number_of_threads=num_reduce_threads) + + while work_tile.is_valid_tile: + n_block, head_idx, batch_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) + + mdQaccum_cur = mdQaccum[None, head_idx, batch_idx] + + if const_expr(self.deterministic): + mdQ_semaphore_cur = mdQ_semaphore[None, None, head_idx, batch_idx] + + for i in cutlass.range(m_block_max - m_block_min, unroll=1): + m_block = m_block_max - 1 - i + + pipeline_dQ.consumer_wait(dQ_consumer_state) + + # TMEM -> RMEM + tdQrdQ_t2r = cute.make_fragment(tdQrdQ.shape, Float32) + assert self.dQaccum_reduce_stage == cute.size(tdQrdQ_t2r, mode=[1]), "dQaccum reduce stage mismatch" + + cute.copy(thr_tmem_ld, tdQtdQ_t2r, tdQrdQ_t2r) + cute.arch.fence_view_async_tmem_load() + + pipeline_dQ.consumer_release(dQ_consumer_state); dQ_consumer_state.advance() + + # semaphore acquire + if const_expr(self.deterministic): + barrier.wait_eq(mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, n_block) + cute.arch.barrier(barrier_id=int(NamedBarrierBwdSm100.dQaccReduce), number_of_threads=num_reduce_threads) + + for stage in cutlass.range_constexpr(cute.size(tdQrdQ_t2r, mode=[1])): # 4 + + if stage >= 2 and cute.arch.thread_idx()[0] == 0: + cute.arch.cp_async_bulk_wait_group(1, read=read_flag) + + cute.arch.mbarrier_wait(dQaccum_reduce_mbar_ptr, reduce_phase) + + tdQrdQ_r2s = tdQrdQ_t2r[None, stage, None, None] + tdQsdQ_r2s = tdQsdQ[None, None, reduce_phase] + tdQrdQ_r2s = cute.make_tensor(tdQrdQ_r2s.iterator, cute.make_layout(tdQsdQ_r2s.shape)) + + cute.copy(smem_thr_copy_dQaccum, tdQrdQ_r2s, tdQsdQ_r2s) + + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.barrier(barrier_id=int(NamedBarrierBwdSm100.dQaccReduce), number_of_threads=num_reduce_threads) + + if cute.arch.thread_idx()[0] == 0: + smem_ptr = sdQaccum[None, reduce_phase].iterator + g_stage_index_elems = m_block * (self.m_block_size * self.head_dim_v_padded) + stage * (self.m_block_size * 32) + gmem_row_ptr = cute.domain_offset((g_stage_index_elems,), mdQaccum_cur).iterator + + tma_reduce_add_bulk_f32(smem_ptr, gmem_row_ptr, store_bytes) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(1, read=read_flag) + + cute.arch.mbarrier_arrive(dQaccum_reduce_mbar_ptr) + + reduce_phase ^= 1 + + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.barrier(barrier_id=int(NamedBarrierBwdSm100.dQaccReduce), number_of_threads=num_reduce_threads) + + # semaphore release + # NOTE: arrive_inc calls red_release which issues membar + if const_expr(self.deterministic): + if cute.arch.thread_idx()[0] == 0: + cute.arch.cp_async_bulk_wait_group(0, read=read_flag) + cute.arch.barrier(barrier_id=int(NamedBarrierBwdSm100.dQaccReduce), number_of_threads=num_reduce_threads) + barrier.arrive_inc(mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, 1) + + + if cute.arch.thread_idx()[0] == 0: + cute.arch.cp_async_bulk_wait_group(0, read=read_flag) + + tile_scheduler.prefetch_next_work() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + + @cute.jit + def epilogue_dKV( + self, + tidx: Int32, + warp_idx: Int32, + batch_idx: Int32, + head_idx: Int32, + n_block: Int32, + thr_mma_pdo: cute.core.ThrMma, + thr_mma_dsq: cute.core.ThrMma, + tdVtdV: cute.Tensor, + tdKtdK: cute.Tensor, + mdV: cute.Tensor, + mdK: cute.Tensor, + pipeline_dV: PipelineAsync, + pipeline_dK: PipelineAsync, + softmax_scale: Float32, + ): + + wg_idx = (cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.compute_warp_ids))) // 128 + num_wg = (cute.arch.WARP_SIZE * len(self.compute_warp_ids) // 128) + + dV_consumer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Consumer, self.dV_stage) + dK_consumer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Consumer, self.dK_stage) + + assert self.qhead_per_kvhead == 1, "This epilogue path is only for MHA" + mdV_cur = mdV[None, None, head_idx, batch_idx] + mdK_cur = mdK[None, None, head_idx, batch_idx] + + tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(16)), Float32) + + # dV + pipeline_dV.consumer_wait(dV_consumer_state) + + tiled_tmem_ld_dV = tcgen05.make_tmem_copy(tmem_load_atom, tdVtdV) + thr_tmem_ld_dV = tiled_tmem_ld_dV.get_slice(tidx) + + tdVtdV_t2r_p = thr_tmem_ld_dV.partition_S(tdVtdV) + tdVtdV_t2r = self.split_wg(tdVtdV_t2r_p, wg_idx, num_wg) + + cdV = cute.make_identity_tensor((self.mma_tiler_pdo[0], self.mma_tiler_pdo[1])) + tdVcdV = thr_mma_pdo.partition_C(cdV) + tdVcdV_tensor = cute.make_tensor(tdVcdV.iterator, tdVcdV.layout) + + tdVcdV_t2r_p = thr_tmem_ld_dV.partition_D(tdVcdV_tensor) + tdVcdV_t2r = self.split_wg(tdVcdV_t2r_p, wg_idx, num_wg) + tdVrdV_t2r = cute.make_fragment(tdVcdV_t2r.shape, Float32) + + cute.copy(thr_tmem_ld_dV, tdVtdV_t2r, tdVrdV_t2r) + cute.arch.fence_view_async_tmem_load() + + universal_copy_bits = 128 + atom_universal_copy = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.dv_dtype, num_bits_per_copy=universal_copy_bits,) + tiled_gmem_store_dV = cute.make_tiled_copy(atom_universal_copy, layout_tv=tiled_tmem_ld_dV.layout_dst_tv_tiled, tiler_mn=tiled_tmem_ld_dV.tiler_mn,) + + tdVrdV_r2s = cute.make_fragment(tdVrdV_t2r.shape, self.dv_dtype) + for i in cutlass.range_constexpr(cute.size(tdVrdV_t2r, mode=[1])): + dV_vec = tdVrdV_t2r[(None, i, 0, 0)].load() + tdVrdV_r2s[(None, i, 0, 0)].store(dV_vec.to(self.dv_dtype)) + + gdV = cute.local_tile(mdV_cur, (self.m_block_size, self.head_dim_v_padded), (None, 0)) + gdV_tile = gdV[None, None, n_block] + + tdVgdV = thr_mma_pdo.partition_C(gdV_tile) + tdVgdV_r2g_p = thr_tmem_ld_dV.partition_D(tdVgdV) + tdVgdV_r2g = self.split_wg(tdVgdV_r2g_p, wg_idx, num_wg) + + cute.copy(tiled_gmem_store_dV, tdVrdV_r2s , tdVgdV_r2g) + + pipeline_dV.consumer_release(dV_consumer_state); dV_consumer_state.advance() + + # dK + pipeline_dK.consumer_wait(dK_consumer_state) + + tiled_tmem_ld_dK = tcgen05.make_tmem_copy(tmem_load_atom, tdKtdK) + thr_tmem_ld_dK = tiled_tmem_ld_dK.get_slice(tidx) + + tdKtdK_t2r_p = thr_tmem_ld_dK.partition_S(tdKtdK) + tdKtdK_t2r = self.split_wg(tdKtdK_t2r_p, wg_idx, num_wg) + + cdK = cute.make_identity_tensor((self.mma_tiler_dsq[0], self.mma_tiler_dsq[1])) + tdKcdK = thr_mma_dsq.partition_C(cdK) + tdKcdK_tensor = cute.make_tensor(tdKcdK.iterator, tdKcdK.layout) + + tdKcdK_t2r_p = thr_tmem_ld_dK.partition_D(tdKcdK_tensor) + tdKcdK_t2r = self.split_wg(tdKcdK_t2r_p, wg_idx, num_wg) + tdKrdK_t2r = cute.make_fragment(tdKcdK_t2r.shape, Float32) + + cute.copy(tiled_tmem_ld_dK, tdKtdK_t2r, tdKrdK_t2r) + cute.arch.fence_view_async_tmem_load() + + universal_copy_bits = 128 + atom_universal_copy = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.dk_dtype, num_bits_per_copy=universal_copy_bits,) + + tiled_gmem_store_dK = cute.make_tiled_copy(atom_universal_copy,layout_tv=tiled_tmem_ld_dK.layout_dst_tv_tiled,tiler_mn=tiled_tmem_ld_dK.tiler_mn,) + + tdKrdK_r2s = cute.make_fragment(tdKrdK_t2r.shape, self.dk_dtype) + + + for i in cutlass.range_constexpr(cute.size(tdKrdK_t2r, mode=[1])): + dK_vec = tdKrdK_t2r[(None, i, 0, 0)].load() * softmax_scale + tdKrdK_r2s[(None, i, 0, 0)].store(dK_vec.to(self.dk_dtype)) + + gdK = cute.local_tile(mdK_cur, (self.n_block_size, self.head_dim_v_padded), (None, 0)) + gdK_tile = gdK[None, None, n_block] + + tdKgdK = thr_mma_dsq.partition_C(gdK_tile) + tdKgdK_r2g_p = thr_tmem_ld_dK.partition_D(tdKgdK) + tdKgdK_r2g = self.split_wg(tdKgdK_r2g_p, wg_idx, num_wg) + + cute.copy(tiled_gmem_store_dK, tdKrdK_r2s , tdKgdK_r2g) + + pipeline_dK.consumer_release(dK_consumer_state); dK_consumer_state.advance() + + + @cute.jit + def epilogue_dK_or_dV_tma( + self, + tidx: Int32, + batch_idx: Int32, + head_idx: Int32, + n_block: Int32, + thr_mma: cute.core.ThrMma, + tdKVtdKV: cute.Tensor, + mdKV: cute.Tensor, + sdKV: cute.Tensor, + tma_atom_dKV: cute.CopyAtom, + thr_copy_r2s_dKdV: cute.TiledCopy, + pipeline: PipelineAsync, + softmax_scale : Float32, + do_scale : cutlass.Constexpr[cutlass.Boolean], + barrier_id : Int32, + mdKV_semaphore : Optional[cute.Tensor], + ): + # assumes mma_tiler_pdo = mma_tiler_dsq = (n_block_size, head_dim) + # head_dim = head_dim_v, dk_dtype = dv_dtype + + wg_idx = (cute.arch.thread_idx()[0] % self.num_compute_threads) // 128 + num_wg = (self.num_compute_threads // 128) + leader_warp = (cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4) == 0 + + sdKV = sdKV[None, None, wg_idx] + + head_idx_kv = head_idx // self.qhead_per_kvhead + mdKV_cur = mdKV[None, None, head_idx_kv, batch_idx] + + gdKV_p = cute.local_tile(mdKV_cur, (self.m_block_size, self.head_dim_v_padded), (n_block, 0)) + gdKV = self.split_wg(gdKV_p, wg_idx, num_wg) + gdKV_epi = cute.local_tile(gdKV, self.sdKdV_epi_tile, (0, None)) + + if const_expr(self.deterministic and self.qhead_per_kvhead > 1): + mdKV_semaphore_cur = mdKV_semaphore[n_block, None, head_idx_kv, batch_idx] + + # (TMA) and (TMA, EPI_STAGE) + tdKVsdKV, tdKVgdKV = cpasync.tma_partition( + tma_atom_dKV, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sdKV, 0, 2), + cute.group_modes(gdKV_epi, 0, 2), + ) + + assert len(tdKVsdKV.shape) == 1, "Wrong rank for SMEM fragment tdKVsdKV" + assert len(tdKVgdKV.shape) == 2, "Wrong rank for GMEM fragment tdKVgdKV" + + num_epi_stages = cute.size(tdKVgdKV.shape[1]) + assert num_epi_stages == 1 or num_epi_stages == 2, "Wrong number of epi stages" + + tmem_ld_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32) + + if const_expr(self.deterministic): + read_flag = False + else: + read_flag = True + + # TODO: maybe support more than 1 stage + consumer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Consumer, 1) + pipeline.consumer_wait(consumer_state) + + # semaphore acquire + if const_expr(self.deterministic): + barrier.wait_eq(mdKV_semaphore_cur.iterator, tidx, wg_idx, head_idx % self.qhead_per_kvhead) + cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128) + + for s in cutlass.range_constexpr(num_epi_stages): + + # TMEM -> RMEM -- setup + tiled_tmem_ld = tcgen05.make_tmem_copy(tmem_ld_atom, tdKVtdKV) + thr_tmem_ld = tiled_tmem_ld.get_slice(tidx) + + tdKVtdKV_t2r_p = thr_tmem_ld.partition_S(tdKVtdKV) + tdKVtdKV_t2r = self.split_wg(tdKVtdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0] + if const_expr(num_epi_stages > 1): + tdKVtdKV_t2r = tdKVtdKV_t2r[None, s] + + cdKV = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded)) + tdKVcdKV = thr_mma.partition_C(cdKV) + tdKVcdKV_t2r_p = thr_tmem_ld.partition_D(tdKVcdKV) + tdKVcdKV_t2r = self.split_wg(tdKVcdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0] + if const_expr(num_epi_stages > 1): + tdKVcdKV_t2r = tdKVcdKV_t2r[None, s] + + tdKVrdKV_t2r = cute.make_fragment(tdKVcdKV_t2r.shape, Float32) + + assert cute.size(tdKVrdKV_t2r) == cute.size(tdKVtdKV_t2r) // cute.arch.WARP_SIZE, "RMEM<->TMEM fragment size mismatch" + + # TMEM -> RMEM -- copy and fence + cute.copy(thr_tmem_ld, tdKVtdKV_t2r, tdKVrdKV_t2r) + cute.arch.fence_view_async_tmem_load() + + # RMEM -- scale and convert + tdKVrdKV = cute.make_fragment(tdKVrdKV_t2r.shape, self.dv_dtype) + if const_expr(do_scale): + scale = softmax_scale + else: + scale = Float32(1) + + dKV_vec = tdKVrdKV_t2r.load() * scale + tdKVrdKV.store(dKV_vec.to(self.dv_dtype)) + + # RMEM -> SMEM -- setup + tdKVcdKV_r2s_p = thr_copy_r2s_dKdV.partition_S(cdKV) + tdKVcdKV_r2s = self.split_wg(tdKVcdKV_r2s_p, wg_idx, num_wg) + tdKVcdKV_r2s = cute.logical_divide( + tdKVcdKV_r2s, + (tdKVcdKV_r2s.shape[0], tdKVcdKV_r2s.shape[1], tdKVcdKV_r2s.shape[2] // num_epi_stages) + )[((None, 0), (None, 0), (None, s))] + + tdKVrdKV_r2s = cute.make_tensor(tdKVrdKV.iterator, tdKVcdKV_r2s.shape) + + tdKVsdKV_r2s = thr_copy_r2s_dKdV.partition_D(sdKV) + + assert cute.size(tdKVrdKV_r2s) == cute.size(tdKVsdKV_r2s), "RMEM<->SMEM fragment size mismatch" + + # RMEM -> SMEM -- copy, fence and barrier + cute.copy(thr_copy_r2s_dKdV, tdKVrdKV_r2s, tdKVsdKV_r2s) + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128) + + # SMEM -> GMEM + if leader_warp: + cute.copy(tma_atom_dKV, tdKVsdKV, tdKVgdKV[None, s]) + if s < num_epi_stages - 1: + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=read_flag) + cute.arch.barrier_arrive(barrier_id=barrier_id + wg_idx, number_of_threads=128 + cute.arch.WARP_SIZE) + + # Barrier since all warps need to wait for SMEM to be freed + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128 + cute.arch.WARP_SIZE) + + # semaphore release + # NOTE: arrive_inc calls red_release which issues membar + if const_expr(self.deterministic): + if leader_warp: + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=read_flag) + cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128) + barrier.arrive_inc(mdKV_semaphore_cur.iterator, tidx, wg_idx, 1) + + pipeline.consumer_release(consumer_state) + consumer_state.advance() + + + @cute.jit + def load_M_tile( + self, + tma_atom: cute.CopyAtom, + tQgQ: cute.Tensor, + tQsQ: cute.Tensor, + pipeline: PipelineAsync, + block: cutlass.Int32, + producer_state: cutlass.pipeline.PipelineState, + ): + pipeline.producer_acquire(producer_state) + cute.copy( + tma_atom, + tQgQ[None, block], + tQsQ[None, producer_state.index], + tma_bar_ptr=pipeline.producer_get_barrier(producer_state) + ) diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index b7e3d7c66ea..25c69a69bc0 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -280,3 +280,49 @@ def apply_mask_sm100( if col_idx >= col_limit_right or col_idx < col_limit_left else acc_S[i] ) + + + @cute.jit + def apply_mask_sm100_transposed( + self, + acc_S: cute.Tensor, + tScS_t2r : cute.Tensor, + m_block: cutlass.Int32, + n_block: cutlass.Int32, + wg_idx: cutlass.Int32, + num_wg: cutlass.Constexpr[cutlass.Int32], + mask_seqlen: cutlass.Constexpr, + mask_causal: cutlass.Constexpr, + mask_local: cutlass.Constexpr, + ) -> None: + ''' + Backward pass: mask S = K @ Q.T where n_block tiles seqlen_k and m_block tiles seqlen_q. + ''' + assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" + + tidx = cute.arch.thread_idx()[0] % 128 + + seqlenk_row_limit = self.seqlen_k - n_block * self.tile_n + if cutlass.const_expr(not mask_causal and not mask_local): + if cutlass.const_expr(mask_seqlen): + ncol = cutlass.const_expr(cute.size(tScS_t2r.shape)) + if tScS_t2r[0][0] >= seqlenk_row_limit: + for i in cutlass.range(ncol, unroll_full=True): + acc_S[i] = -cutlass.Float32.inf + else: # Causal or local + causal_row_offset = (self.seqlen_q - self.seqlen_k - 1) - m_block * self.tile_m + row_idx = tScS_t2r[0][0] + n_block * self.tile_n + + if cutlass.const_expr(mask_causal): + col_limit_left = row_idx + causal_row_offset + ncol = cutlass.const_expr(cute.size(tScS_t2r.shape)) + # if tidx == 32 and wg_idx == 1: + # cute.printf("row idx = {}, causal_row_offset = {}, col_limit_left = {}, first column = {}, last column = {} ", row_idx, causal_row_offset, col_limit_left, tScS_t2r[0][1], tScS_t2r[ncol - 1][1]) + if cutlass.const_expr(mask_seqlen): + if tScS_t2r[0][0] >= seqlenk_row_limit: + col_limit_left = self.tile_m + for i in cutlass.range(ncol, unroll_full=True): + acc_S[i] = ( + -cutlass.Float32.inf if tScS_t2r[i][1] <= col_limit_left else acc_S[i] + ) + # TODO: local \ No newline at end of file diff --git a/flash_attn/cute/named_barrier.py b/flash_attn/cute/named_barrier.py index 1000c0a47bc..48229ccd25d 100644 --- a/flash_attn/cute/named_barrier.py +++ b/flash_attn/cute/named_barrier.py @@ -22,3 +22,9 @@ class NamedBarrierBwd(enum.IntEnum): dQFullWG1 = enum.auto() dQEmptyWG0 = enum.auto() dQEmptyWG1 = enum.auto() + +class NamedBarrierBwdSm100(enum.IntEnum): + EpilogueWG1 = enum.auto() + EpilogueWG2 = enum.auto() + Compute = enum.auto() + dQaccReduce = enum.auto() \ No newline at end of file From 5fa6e8d5d6f3d3bd614c1e1132342c52b821981e Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 19 Oct 2025 14:36:41 -0400 Subject: [PATCH 155/258] [Cute,Bwd,Sm100] Format flash_bwd_sm100.py and flash_bwd_postprocess --- flash_attn/cute/flash_bwd_postprocess.py | 134 +- flash_attn/cute/flash_bwd_sm100.py | 1585 +++++++++++++--------- 2 files changed, 1039 insertions(+), 680 deletions(-) diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index a2d9e93b547..8088997fd26 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -389,6 +389,7 @@ def kernel( pred=tdQpdQ[None, rest_m, None], ) + class FlashAttentionBackwardPostprocess_sm100(FlashAttentionBackwardPostprocess): def __init__( self, @@ -402,7 +403,7 @@ def __init__( super().__init__( dtype=dtype, head_dim=head_dim, - arch=90, # tmp dummy placement for now + arch=90, # tmp dummy placement for now tile_m=m_block_size, num_threads=num_threads, AtomLayoutMdQ=AtomLayoutMdQ, @@ -412,7 +413,9 @@ def __init__( def _setup_attributes(self): self.num_stages = self.tile_hdim // 32 # 2 for D=64, 4 for D=128 - self.sdQaccum_layout = cute.make_layout(shape=(self.tile_m * 32, 2), stride=(1, self.tile_m * 32)) + self.sdQaccum_layout = cute.make_layout( + shape=(self.tile_m * 32, 2), stride=(1, self.tile_m * 32) + ) self.epi_tile_q = (self.tile_m, self.tile_hdim) self.sdQ_layout = sm100_utils_basic.make_smem_layout_epi( self.dtype, @@ -425,9 +428,9 @@ def _setup_attributes(self): def __call__( self, mdQaccum: cute.Tensor, - mdQ: cute.Tensor, - scale: cutlass.Float32, - stream: cuda.CUstream, + mdQ: cute.Tensor, + scale: cutlass.Float32, + stream: cuda.CUstream, ): # (b, h, s*d) -> (s*d, h, b) mdQaccum = cute.make_tensor(mdQaccum.iterator, cute.select(mdQaccum.layout, mode=[2, 1, 0])) @@ -445,11 +448,11 @@ def __call__( cta_group = tcgen05.CtaGroup.ONE self.mma_tiler_dsk = (self.tile_m, self.tile_hdim) - dS_major_mode = tcgen05.OperandMajorMode.MN + dS_major_mode = tcgen05.OperandMajorMode.MN kt_major_mode_dsq = tcgen05.OperandMajorMode.MN tiled_mma_dsk = sm100_utils_basic.make_trivial_tiled_mma( - cutlass.BFloat16 , + cutlass.BFloat16, dS_major_mode, kt_major_mode_dsq, cutlass.Float32, @@ -467,16 +470,17 @@ def __call__( ) buffer_align_bytes = 1024 + @cute.struct class SharedStorage: - sdQaccum: cute.struct.Align[ - cute.struct.MemRange[cutlass.Float32, cute.cosize(self.sdQaccum_layout)], - 128, + sdQaccum: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, cute.cosize(self.sdQaccum_layout)], + 128, ] - sdQ: cute.struct.Align[ - cute.struct.MemRange[self.dtype, cute.cosize(self.sdQ_layout)], - buffer_align_bytes, + sdQ: cute.struct.Align[ + cute.struct.MemRange[self.dtype, cute.cosize(self.sdQ_layout)], + buffer_align_bytes, ] self.shared_storage = SharedStorage @@ -495,16 +499,17 @@ class SharedStorage: smem=self.shared_storage.size_in_bytes(), stream=stream, ) + @cute.kernel def kernel( self, - mdQaccum: cute.Tensor, - mdQ: cute.Tensor, - tma_atom_dQ: cute.CopyAtom, - sdQaccum_layout: cute.Layout, - sdQ_layout: cute.ComposedLayout, - tiled_mma_dsk: cute.TiledMma, - scale: cutlass.Float32, + mdQaccum: cute.Tensor, + mdQ: cute.Tensor, + tma_atom_dQ: cute.CopyAtom, + sdQaccum_layout: cute.Layout, + sdQ_layout: cute.ComposedLayout, + tiled_mma_dsk: cute.TiledMma, + scale: cutlass.Float32, ): tidx = cute.arch.thread_idx()[0] warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) @@ -513,43 +518,53 @@ def kernel( # SMEM smem = cutlass.utils.SmemAllocator() storage = smem.allocate(self.shared_storage) - swz128 = cute.make_swizzle(3, 4, 3) + swz128 = cute.make_swizzle(3, 4, 3) sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout, swizzle=swz128) sdQ = storage.sdQ.get_tensor(sdQ_layout.outer, swizzle=sdQ_layout.inner) mdQaccum_cur = mdQaccum[None, head_idx, batch_idx] - mdQ_cur = mdQ[None, None, head_idx, batch_idx] + mdQ_cur = mdQ[None, None, head_idx, batch_idx] thr_mma_dsk = tiled_mma_dsk.get_slice(tidx) dQacc_shape = thr_mma_dsk.partition_shape_C(self.mma_tiler_dsk[:2]) - tdQtdQ = thr_mma_dsk.make_fragment_C(dQacc_shape) - tdQtdQ = cute.make_tensor(tdQtdQ.iterator , tdQtdQ.layout) + tdQtdQ = thr_mma_dsk.make_fragment_C(dQacc_shape) + tdQtdQ = cute.make_tensor(tdQtdQ.iterator, tdQtdQ.layout) - tmem_ld_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), cutlass.Float32) + tmem_ld_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), cutlass.Float32 + ) tiled_tmem_ld = tcgen05.make_tmem_copy(tmem_ld_atom, tdQtdQ) - thr_tmem_ld = tiled_tmem_ld.get_slice(tidx) + thr_tmem_ld = tiled_tmem_ld.get_slice(tidx) - cdQ = cute.make_identity_tensor((self.mma_tiler_dsk[0], self.mma_tiler_dsk[1])) - tdQcdQ = thr_mma_dsk.partition_C(cdQ) + cdQ = cute.make_identity_tensor((self.mma_tiler_dsk[0], self.mma_tiler_dsk[1])) + tdQcdQ = thr_mma_dsk.partition_C(cdQ) tdQcdQ_tensor = cute.make_tensor(tdQcdQ.iterator, tdQcdQ.layout) - tdQrdQ = thr_tmem_ld.partition_D(tdQcdQ_tensor) + tdQrdQ = thr_tmem_ld.partition_D(tdQcdQ_tensor) - gdQaccum = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim, ) , (m_block, )) + gdQaccum = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (m_block,)) num_reduce_warps = 4 num_reduce_threads = cute.arch.WARP_SIZE * num_reduce_warps - - atom_universal_copy = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=128) - tiler_mn, layout_tv = cute.make_layout_tv(thr_layout=cute.make_layout(shape=num_reduce_threads, stride=1), val_layout=cute.make_layout(shape=4, stride=1)) - G2S_tiled_copy_dQaccum = cute.make_tiled_copy(atom_universal_copy, layout_tv=layout_tv, tiler_mn=tiler_mn) + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=128 + ) + tiler_mn, layout_tv = cute.make_layout_tv( + thr_layout=cute.make_layout(shape=num_reduce_threads, stride=1), + val_layout=cute.make_layout(shape=4, stride=1), + ) + G2S_tiled_copy_dQaccum = cute.make_tiled_copy( + atom_universal_copy, layout_tv=layout_tv, tiler_mn=tiler_mn + ) smem_thr_copy_g2s = G2S_tiled_copy_dQaccum.get_slice(tidx) # S->R tdQrdQ_t2r = cute.make_fragment(tdQrdQ.shape, cutlass.Float32) - tiled_smem_store_s2r = cute.make_tiled_copy(atom_universal_copy, layout_tv=layout_tv, tiler_mn=tiler_mn) + tiled_smem_store_s2r = cute.make_tiled_copy( + atom_universal_copy, layout_tv=layout_tv, tiler_mn=tiler_mn + ) s2r_thr_copy_dQaccum = tiled_smem_store_s2r.get_slice(tidx) tdQsdQ_s2r = s2r_thr_copy_dQaccum.partition_S(sdQaccum) @@ -567,45 +582,62 @@ def kernel( tdQsdQ_r2s = thr_tmem_ld.partition_D(thr_mma_dsk.partition_C(sdQ)) tdQrdQ_r2s = cute.make_fragment(tdQsdQ_r2s.shape, self.dtype) - num_stages = cute.size(tdQrdQ_t2r, mode=[1]) for stage in cutlass.range_constexpr(num_stages): - # G->S - gdQaccum_stage = cute.local_tile(gdQaccum, (self.tile_m * 32, ), (stage, ),) + gdQaccum_stage = cute.local_tile( + gdQaccum, + (self.tile_m * 32,), + (stage,), + ) gdQaccum_layout_g2s = cute.make_layout(shape=(self.tile_m * 32, 1), stride=(1, 0)) - gdQaccum_stage_g2s = cute.make_tensor(cute.recast_ptr(gdQaccum_stage.iterator, swizzle_=swz128), gdQaccum_layout_g2s) + gdQaccum_stage_g2s = cute.make_tensor( + cute.recast_ptr(gdQaccum_stage.iterator, swizzle_=swz128), gdQaccum_layout_g2s + ) tdQgdQ = smem_thr_copy_g2s.partition_S(gdQaccum_stage_g2s) tdQsdQ = smem_thr_copy_g2s.partition_D(sdQaccum) cute.copy(smem_thr_copy_g2s, tdQgdQ[None, None, 0], tdQsdQ[None, None, 0]) - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) cute.arch.barrier(barrier_id=6, number_of_threads=num_reduce_threads) # S -> R tdQrdQ_s2r_cpy = tdQrdQ_s2r[None, stage, None, None] - tdQsdQ_s2r_p = tdQsdQ_s2r[None, None, 0] - tdQrdQ_r2s_cpy = cute.make_tensor(tdQrdQ_s2r_cpy.iterator, cute.make_layout(tdQsdQ_s2r_p.shape)) + tdQsdQ_s2r_p = tdQsdQ_s2r[None, None, 0] + tdQrdQ_r2s_cpy = cute.make_tensor( + tdQrdQ_s2r_cpy.iterator, cute.make_layout(tdQsdQ_s2r_p.shape) + ) cute.copy(s2r_thr_copy_dQaccum, tdQsdQ_s2r_p, tdQrdQ_r2s_cpy) - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) cute.arch.barrier(barrier_id=7, number_of_threads=num_reduce_threads) # R->S - tdQrdQ_r2s_cpy = cute.make_tensor(cute.recast_ptr(tdQrdQ_r2s_cpy.iterator), tdQrdQ_r2s[((None, 0), stage, 0, 0, 0)].shape) - dQ_vec = tdQrdQ_r2s_cpy.load() * scale + tdQrdQ_r2s_cpy = cute.make_tensor( + cute.recast_ptr(tdQrdQ_r2s_cpy.iterator), + tdQrdQ_r2s[((None, 0), stage, 0, 0, 0)].shape, + ) + dQ_vec = tdQrdQ_r2s_cpy.load() * scale tdQrdQ_r2s[((None, 0), stage, 0, 0, 0)].store(dQ_vec.to(self.dtype)) - - cute.copy(tiled_smem_store_r2s, tdQrdQ_r2s[None, None, None, None, 0], tdQsdQ_r2s[None, None, None, None, 0]) - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.copy( + tiled_smem_store_r2s, + tdQrdQ_r2s[None, None, None, None, 0], + tdQsdQ_r2s[None, None, None, None, 0], + ) + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) cute.arch.barrier(barrier_id=8, number_of_threads=num_reduce_threads) - # S-> G gdQ = cute.local_tile(mdQ_cur, (self.tile_m, self.tile_hdim), (None, 0)) tdQsdQ, tdQgdQ = cpasync.tma_partition( @@ -613,9 +645,7 @@ def kernel( 0, cute.make_layout(1), cute.group_modes(sdQ, 0, 2), - cute.group_modes(gdQ, 0, 2) + cute.group_modes(gdQ, 0, 2), ) cute.copy(tma_atom_dQ, tdQsdQ[None, 0], tdQgdQ[None, m_block]) - - diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 69ea1f04847..86afbf8f105 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -1,27 +1,25 @@ -from ctypes import alignment -import enum import math -from typing import Type, Tuple, Callable, Optional +from typing import Callable, Optional from functools import partial import cuda.bindings.driver as cuda import cutlass -from cutlass._mlir.ir import _si1Attr -from cutlass.base_dsl.jit_executor import t import cutlass.cute as cute from cutlass import Float32, Int32, const_expr from cutlass.cute.nvgpu import cpasync import cutlass.cute.nvgpu.tcgen05 as tcgen05 import cutlass.utils.blackwell_helpers as sm100_utils_basic -import flash_attn.cute.utils as utils from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.seqlen_info import SeqlenInfo, SeqlenInfoQK +from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo -from flash_attn.cute import blackwell_helpers as sm100_utils -from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, StaticPersistentTileScheduler, ParamsBase +from flash_attn.cute.tile_scheduler import ( + TileSchedulerArguments, + SingleTileScheduler, + ParamsBase, +) from cutlass.pipeline import PipelineAsync from cutlass._mlir.dialects import llvm @@ -35,11 +33,8 @@ @dsl_user_op def tma_reduce_add_bulk_f32( - smem_ptr: cute.Pointer, - gmem_ptr: cute.Pointer, - store_bytes: cutlass.Int32, - *, loc=None, ip=None - ): + smem_ptr: cute.Pointer, gmem_ptr: cute.Pointer, store_bytes: cutlass.Int32, *, loc=None, ip=None +): cute.make_mma_atom smem_u32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() llvm.inline_asm( @@ -68,7 +63,6 @@ def __init__( is_persistent: bool = False, deterministic: bool = False, ): - # padding head_dim to a multiple of 16 as k_block_size hdim_multiple_of = 16 self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) @@ -76,7 +70,9 @@ def __init__( self.same_hdim_kv = head_dim == head_dim_v assert head_dim == head_dim_v, "head_dim and head_dim_v must be the same for now" self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) - assert self.head_dim_padded == self.head_dim_v_padded, "head_dim_padded and head_dim_v_padded must be the same for now" + assert self.head_dim_padded == self.head_dim_v_padded, ( + "head_dim_padded and head_dim_v_padded must be the same for now" + ) self.check_hdim_oob = head_dim != self.head_dim_padded self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded @@ -86,10 +82,10 @@ def __init__( self.dQaccum_reduce_stage = self.head_dim_padded // 32 # CTA tiler - self.cta_tiler = (m_block_size, n_block_size, self.head_dim_padded) + self.cta_tiler = (m_block_size, n_block_size, self.head_dim_padded) # S = K @ Q.T - self.mma_tiler_kq = (n_block_size, m_block_size, self.head_dim_padded) + self.mma_tiler_kq = (n_block_size, m_block_size, self.head_dim_padded) # dP = V @ dO.T self.mma_tiler_vdo = (n_block_size, m_block_size, self.head_dim_v_padded) @@ -103,8 +99,9 @@ def __init__( # dQ = dS @ K self.mma_tiler_dsk = (m_block_size, self.head_dim_v_padded, n_block_size) - - self.kq_acc_dtype = self.vdo_acc_dtype = self.pdo_acc_dtype = self.dsq_acc_dtype = self.dsk_acc_dtype = Float32 + self.kq_acc_dtype = self.vdo_acc_dtype = self.pdo_acc_dtype = self.dsq_acc_dtype = ( + self.dsk_acc_dtype + ) = Float32 self.cluster_shape_mn = (1, 1) self.is_persistent = is_persistent @@ -138,12 +135,12 @@ def __init__( SM100_TMEM_CAPACITY_COLUMNS = 512 self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS - self.tmem_s_offset = 0 - self.tmem_p_offset = 0 # overlap with S - self.tmem_dV_offset = self.tmem_s_offset + self.n_block_size - self.tmem_dP_offset = self.tmem_dV_offset + self.head_dim_v_padded - self.tmem_dQaccum_offset = self.tmem_dP_offset # overlap with dP - self.tmem_dK_offset = self.tmem_dP_offset + self.m_block_size + self.tmem_s_offset = 0 + self.tmem_p_offset = 0 # overlap with S + self.tmem_dV_offset = self.tmem_s_offset + self.n_block_size + self.tmem_dP_offset = self.tmem_dV_offset + self.head_dim_v_padded + self.tmem_dQaccum_offset = self.tmem_dP_offset # overlap with dP + self.tmem_dK_offset = self.tmem_dP_offset + self.m_block_size self.num_regs_reduce = 144 self.num_regs_compute = 128 @@ -156,49 +153,47 @@ def __init__( self.num_compute_threads = cute.arch.WARP_SIZE * len(self.compute_warp_ids) def _setup_attributes(self): - - self.q_stage = 2 - self.k_stage = 1 - self.v_stage = 1 - self.do_stage = 1 - self.ds_stage = 1 - self.lse_stage = 1 - self.acc_stage = 1 - self.s_stage = 1 - self.dP_stage = 1 - self.dV_stage = 1 - self.dK_stage = 1 - self.dS_stage = 1 + self.q_stage = 2 + self.k_stage = 1 + self.v_stage = 1 + self.do_stage = 1 + self.ds_stage = 1 + self.lse_stage = 1 + self.acc_stage = 1 + self.s_stage = 1 + self.dP_stage = 1 + self.dV_stage = 1 + self.dK_stage = 1 + self.dS_stage = 1 self.dQaccum_mma_stage = 1 - self.sdQaccum_stage = 2 - self.psum_stage = 1 - self.p_tmem_stage = 1 + self.sdQaccum_stage = 2 + self.psum_stage = 1 + self.p_tmem_stage = 1 self.sdKdVaccum_stage = 2 - @cute.jit def __call__( self, - mQ: cute.Tensor, - mK: cute.Tensor, - mV: cute.Tensor, - mdO: cute.Tensor, - mLSE: cute.Tensor, - mPsum: cute.Tensor, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mdO: cute.Tensor, + mLSE: cute.Tensor, + mPsum: cute.Tensor, mdQaccum: cute.Tensor, - mdK: cute.Tensor, - mdV: cute.Tensor, + mdK: cute.Tensor, + mdV: cute.Tensor, softmax_scale: Float32, stream: cuda.CUstream, mdQ_semaphore: Optional[cute.Tensor] = None, mdK_semaphore: Optional[cute.Tensor] = None, mdV_semaphore: Optional[cute.Tensor] = None, ): - self.q_dtype = mQ.element_type - self.k_dtype = mK.element_type - self.v_dtype = mV.element_type + self.q_dtype = mQ.element_type + self.k_dtype = mK.element_type + self.v_dtype = mV.element_type self.do_dtype = mdO.element_type - self.lse_dtype = mLSE.element_type + self.lse_dtype = mLSE.element_type self.psum_dtype = mPsum.element_type self.dqaccum_dtype = mdQaccum.element_type self.dk_dtype = mdK.element_type @@ -209,25 +204,29 @@ def __call__( assert self.dk_dtype.width == 32, "Must accumulate dK in float precision for GQA" assert self.dv_dtype.width == 32, "Must accumulate dV in float precision for GQA" - QKVdO_layout_transpose = [1, 3, 2, 0] # (b, s, n, h) --> (s, h, n, b) + QKVdO_layout_transpose = [1, 3, 2, 0] # (b, s, n, h) --> (s, h, n, b) mQ, mK, mV, mdO, mdK, mdV = [ cute.make_tensor(t.iterator, cute.select(t.layout, mode=QKVdO_layout_transpose)) for t in (mQ, mK, mV, mdO, mdK, mdV) ] - LSE_Psum_dQaccum_layout_transpose = [2, 1, 0] # (b, n, s) --> (s, n, b) + LSE_Psum_dQaccum_layout_transpose = [2, 1, 0] # (b, n, s) --> (s, n, b) mLSE, mPsum, mdQaccum = [ - cute.make_tensor(t.iterator, cute.select(t.layout, mode=LSE_Psum_dQaccum_layout_transpose)) + cute.make_tensor( + t.iterator, cute.select(t.layout, mode=LSE_Psum_dQaccum_layout_transpose) + ) for t in (mLSE, mPsum, mdQaccum) ] - dO_transpose = [1, 0, 2, 3] + dO_transpose = [1, 0, 2, 3] mdO = cute.make_tensor(mdO.iterator, cute.select(mdO.layout, mode=dO_transpose)) - semaphore_transpose = [2, 3, 1, 0] # (b, n, block, stage) -> (block, stage, n, b) + semaphore_transpose = [2, 3, 1, 0] # (b, n, block, stage) -> (block, stage, n, b) if const_expr(self.deterministic): assert mdQ_semaphore is not None - mdQ_semaphore = cute.make_tensor(mdQ_semaphore.iterator, cute.select(mdQ_semaphore.layout, mode=semaphore_transpose)) + mdQ_semaphore = cute.make_tensor( + mdQ_semaphore.iterator, cute.select(mdQ_semaphore.layout, mode=semaphore_transpose) + ) else: mdQ_semaphore = None @@ -242,10 +241,10 @@ def __call__( mdK_semaphore = None mdV_semaphore = None - self.q_major_mode = cutlass.utils.LayoutEnum.from_tensor(mQ).mma_major_mode() - self.k_major_mode = cutlass.utils.LayoutEnum.from_tensor(mK).mma_major_mode() - self.v_major_mode = cutlass.utils.LayoutEnum.from_tensor(mV).mma_major_mode() - self.do_major_mode = cutlass.utils.LayoutEnum.from_tensor(mdO).mma_major_mode() + self.q_major_mode = cutlass.utils.LayoutEnum.from_tensor(mQ).mma_major_mode() + self.k_major_mode = cutlass.utils.LayoutEnum.from_tensor(mK).mma_major_mode() + self.v_major_mode = cutlass.utils.LayoutEnum.from_tensor(mV).mma_major_mode() + self.do_major_mode = cutlass.utils.LayoutEnum.from_tensor(mdO).mma_major_mode() self._setup_attributes() cta_group = tcgen05.CtaGroup.ONE @@ -262,7 +261,7 @@ def __call__( # dV += P @ dO --> (K, MN) major p_source = tcgen05.OperandSource.TMEM - self.p_major_mode = tcgen05.OperandMajorMode.K + self.p_major_mode = tcgen05.OperandMajorMode.K tiled_mma_pdo = sm100_utils_basic.make_trivial_tiled_mma( self.do_dtype, self.p_major_mode, @@ -285,8 +284,8 @@ def __call__( ) # dK += dS.T @ Q - self.dSt_major_mode = tcgen05.OperandMajorMode.K - self.q_major_mode_dsq = tcgen05.OperandMajorMode.MN + self.dSt_major_mode = tcgen05.OperandMajorMode.K + self.q_major_mode_dsq = tcgen05.OperandMajorMode.MN tiled_mma_dsq = sm100_utils_basic.make_trivial_tiled_mma( self.ds_dtype, self.dSt_major_mode, @@ -297,7 +296,7 @@ def __call__( ) # dQ = dS @ K - self.dS_major_mode = tcgen05.OperandMajorMode.MN + self.dS_major_mode = tcgen05.OperandMajorMode.MN self.kt_major_mode_dsq = tcgen05.OperandMajorMode.MN tiled_mma_dsk = sm100_utils_basic.make_trivial_tiled_mma( self.ds_dtype, @@ -315,46 +314,81 @@ def __call__( # S = K @ Q.T sK_layout = sm100_utils_basic.make_smem_layout_a( - tiled_mma_kq, self.mma_tiler_kq, self.k_dtype, self.k_stage, + tiled_mma_kq, + self.mma_tiler_kq, + self.k_dtype, + self.k_stage, ) sQ_layout = sm100_utils_basic.make_smem_layout_b( - tiled_mma_kq, self.mma_tiler_kq, self.q_dtype, self.q_stage, + tiled_mma_kq, + self.mma_tiler_kq, + self.q_dtype, + self.q_stage, ) # dV += P @ dO sdO_layout = sm100_utils_basic.make_smem_layout_b( - tiled_mma_pdo, self.mma_tiler_pdo, self.do_dtype, self.do_stage, + tiled_mma_pdo, + self.mma_tiler_pdo, + self.do_dtype, + self.do_stage, ) # dP = V @ dO.T sV_layout = sm100_utils_basic.make_smem_layout_a( - tiled_mma_vdo, self.mma_tiler_vdo, self.v_dtype, self.v_stage, + tiled_mma_vdo, + self.mma_tiler_vdo, + self.v_dtype, + self.v_stage, ) sdOt_layout = sm100_utils_basic.make_smem_layout_b( - tiled_mma_vdo, self.mma_tiler_vdo, self.do_dtype, self.do_stage, + tiled_mma_vdo, + self.mma_tiler_vdo, + self.do_dtype, + self.do_stage, ) # dK += dS.T @ Q sdSt_layout = sm100_utils_basic.make_smem_layout_a( - tiled_mma_dsq, self.mma_tiler_dsq, self.ds_dtype, self.ds_stage, + tiled_mma_dsq, + self.mma_tiler_dsq, + self.ds_dtype, + self.ds_stage, ) sQt_layout = sm100_utils_basic.make_smem_layout_b( - tiled_mma_dsq, self.mma_tiler_dsq, self.q_dtype, self.q_stage, + tiled_mma_dsq, + self.mma_tiler_dsq, + self.q_dtype, + self.q_stage, ) # dQaccum = dS @ K sdS_layout = sm100_utils_basic.make_smem_layout_a( - tiled_mma_dsk, self.mma_tiler_dsk, self.q_dtype, self.ds_stage, + tiled_mma_dsk, + self.mma_tiler_dsk, + self.q_dtype, + self.ds_stage, ) sKt_layout = sm100_utils_basic.make_smem_layout_b( - tiled_mma_dsk, self.mma_tiler_dsk, self.k_dtype, self.k_stage, + tiled_mma_dsk, + self.mma_tiler_dsk, + self.k_dtype, + self.k_stage, ) - sdQaccum_layout = cute.make_layout(shape=(self.m_block_size * 32, self.sdQaccum_stage ),) - sLSE_layout = cute.make_layout(shape=(self.m_block_size, self.lse_stage), stride=(1, cute.round_up(self.m_block_size, 64))) - sPsum_layout = cute.make_layout(shape=(self.m_block_size, self.psum_stage), stride=(1, cute.round_up(self.m_block_size, 64))) + sdQaccum_layout = cute.make_layout( + shape=(self.m_block_size * 32, self.sdQaccum_stage), + ) + sLSE_layout = cute.make_layout( + shape=(self.m_block_size, self.lse_stage), + stride=(1, cute.round_up(self.m_block_size, 64)), + ) + sPsum_layout = cute.make_layout( + shape=(self.m_block_size, self.psum_stage), + stride=(1, cute.round_up(self.m_block_size, 64)), + ) self.mdK_layout_enum = cutlass.utils.LayoutEnum.from_tensor(mdK) self.mdV_layout_enum = cutlass.utils.LayoutEnum.from_tensor(mdV) @@ -364,12 +398,20 @@ def __call__( raise RuntimeError("The layout of mdK is wrong") if const_expr(self.dV_major_mode != tcgen05.OperandMajorMode.K): raise RuntimeError("The layout of mdV is wrong") - self.sdKdV_epi_tile = (self.n_block_size, 128 // (self.dk_dtype.width // 8)) # subtiles mma_tiler_dsq[:2] = mma_tiler_pdo[:2] + self.sdKdV_epi_tile = ( + self.n_block_size, + 128 // (self.dk_dtype.width // 8), + ) # subtiles mma_tiler_dsq[:2] = mma_tiler_pdo[:2] sdKdV_layout = sm100_utils_basic.make_smem_layout_epi( - self.dk_dtype, self.mdK_layout_enum, self.sdKdV_epi_tile, self.sdKdVaccum_stage, + self.dk_dtype, + self.mdK_layout_enum, + self.sdKdV_epi_tile, + self.sdKdVaccum_stage, ) - self.tma_copy_dKdV_bytes = cute.size_in_bytes(self.dk_dtype, cute.select(sdKdV_layout, mode=[0,1])) + self.tma_copy_dKdV_bytes = cute.size_in_bytes( + self.dk_dtype, cute.select(sdKdV_layout, mode=[0, 1]) + ) if const_expr(self.use_tma_store): if const_expr(self.dk_dtype.width == 32): @@ -382,14 +424,14 @@ def __call__( mdK, cute.select(sdKdV_layout, mode=[0, 1]), self.sdKdV_epi_tile, - 1 # no mcast + 1, # no mcast ) tma_atom_dV, mdV_tma_tensor = cpasync.make_tiled_tma_atom( tma_copy_op_dKdV, mdV, cute.select(sdKdV_layout, mode=[0, 1]), self.sdKdV_epi_tile, - 1 # no mcast + 1, # no mcast ) else: assert self.qhead_per_kvhead == 1, "Must use TMA reduce add for GQA" @@ -398,12 +440,22 @@ def __call__( tma_atom_dV = None tma_atom_dK = None - thr_layout_r2s_dKdV = cute.make_ordered_layout((self.n_block_size, 1), order=(1,0)) # 128 threads - val_layout_r2s_dKdV = cute.make_ordered_layout((1, 128 // self.dk_dtype.width), order=(1,0)) # 4 or 8 vals for 16 byte store - r2s_copy_atom_r2s_dKdV = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.dk_dtype, num_bits_per_copy=128,) - tiled_copy_r2s_dKdV = cute.make_tiled_copy_tv(r2s_copy_atom_r2s_dKdV, thr_layout_r2s_dKdV, val_layout_r2s_dKdV) + thr_layout_r2s_dKdV = cute.make_ordered_layout( + (self.n_block_size, 1), order=(1, 0) + ) # 128 threads + val_layout_r2s_dKdV = cute.make_ordered_layout( + (1, 128 // self.dk_dtype.width), order=(1, 0) + ) # 4 or 8 vals for 16 byte store + r2s_copy_atom_r2s_dKdV = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dk_dtype, + num_bits_per_copy=128, + ) + tiled_copy_r2s_dKdV = cute.make_tiled_copy_tv( + r2s_copy_atom_r2s_dKdV, thr_layout_r2s_dKdV, val_layout_r2s_dKdV + ) - tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) # S = K @ Q.T tma_atom_K, tma_tensor_K = cute.nvgpu.make_tiled_tma_atom_A( @@ -437,13 +489,13 @@ def __call__( tma_load_op, mLSE, cute.make_layout((self.m_block_size)), - (self.m_block_size, ), + (self.m_block_size,), ) tma_atom_Psum, tma_tensor_Psum = cute.nvgpu.cpasync.make_tiled_tma_atom( tma_load_op, mPsum, cute.make_layout((self.m_block_size)), - (self.m_block_size, ), + (self.m_block_size,), ) # dP = V @ dO.T @@ -456,18 +508,26 @@ def __call__( self.cluster_layout_vmnk.shape, ) - self.tma_copy_q_bytes = cute.size_in_bytes(self.q_dtype, cute.select(sQ_layout, mode=[0, 1, 2])) - self.tma_copy_k_bytes = cute.size_in_bytes(self.k_dtype, cute.select(sK_layout, mode=[0, 1, 2])) - self.tma_copy_v_bytes = cute.size_in_bytes(self.v_dtype, cute.select(sV_layout, mode=[0, 1, 2])) - self.tma_copy_do_bytes = cute.size_in_bytes(self.do_dtype, cute.select(sdO_layout, mode=[0, 1, 2])) - self.tma_copy_lse_bytes = self.m_block_size * 4 + self.tma_copy_q_bytes = cute.size_in_bytes( + self.q_dtype, cute.select(sQ_layout, mode=[0, 1, 2]) + ) + self.tma_copy_k_bytes = cute.size_in_bytes( + self.k_dtype, cute.select(sK_layout, mode=[0, 1, 2]) + ) + self.tma_copy_v_bytes = cute.size_in_bytes( + self.v_dtype, cute.select(sV_layout, mode=[0, 1, 2]) + ) + self.tma_copy_do_bytes = cute.size_in_bytes( + self.do_dtype, cute.select(sdO_layout, mode=[0, 1, 2]) + ) + self.tma_copy_lse_bytes = self.m_block_size * 4 self.tma_copy_psum_bytes = self.m_block_size * 4 TileScheduler = SingleTileScheduler # TODO -- optimizer scheduler for causal tile_sched_args = TileSchedulerArguments( cute.ceil_div(cute.size(mK.shape[0]), self.cta_tiler[0]), - cute.size(mQ.shape[2]), # num_heads = num_query_heads + cute.size(mQ.shape[2]), # num_heads = num_query_heads cute.size(mK.shape[3]), cute.size(mK.shape[0]), mQ.shape[1], @@ -489,63 +549,63 @@ def __call__( @cute.struct class SharedStorage: - q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.q_stage] - k_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.k_stage] - v_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.v_stage] - lse_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.lse_stage] - do_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.do_stage] - lse_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.k_stage] - lse_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.k_stage] - psum_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.psum_stage] - psum_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.psum_stage] - s_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.s_stage] - dP_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dP_stage] - p_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.s_stage] - dS_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.ds_stage] - dV_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dV_stage] - dK_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dK_stage] - dQaccum_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dQaccum_mma_stage] - dQaccum_reduce_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dQaccum_mma_stage] + q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.q_stage] + k_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.k_stage] + v_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.v_stage] + lse_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.lse_stage] + do_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.do_stage] + lse_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.k_stage] + lse_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.k_stage] + psum_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.psum_stage] + psum_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.psum_stage] + s_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.s_stage] + dP_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dP_stage] + p_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.s_stage] + dS_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.ds_stage] + dV_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dV_stage] + dK_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dK_stage] + dQaccum_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dQaccum_mma_stage] + dQaccum_reduce_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dQaccum_mma_stage] # TMEM tmem_holding_buf: Int32 - tmem_dealloc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1] + tmem_dealloc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1] # Smem tensors - sQ: cute.struct.Align[ - cute.struct.MemRange[self.q_dtype, cute.cosize(sQ_layout)], - self.buffer_align_bytes, + sQ: cute.struct.Align[ + cute.struct.MemRange[self.q_dtype, cute.cosize(sQ_layout)], + self.buffer_align_bytes, ] - sK: cute.struct.Align[ - cute.struct.MemRange[self.k_dtype, cute.cosize(sK_layout)], - self.buffer_align_bytes, + sK: cute.struct.Align[ + cute.struct.MemRange[self.k_dtype, cute.cosize(sK_layout)], + self.buffer_align_bytes, ] - sV: cute.struct.Align[ - cute.struct.MemRange[self.v_dtype, cute.cosize(sV_layout)], - self.buffer_align_bytes, + sV: cute.struct.Align[ + cute.struct.MemRange[self.v_dtype, cute.cosize(sV_layout)], + self.buffer_align_bytes, ] - sdO: cute.struct.Align[ - cute.struct.MemRange[self.do_dtype, cute.cosize(sdO_layout)], - self.buffer_align_bytes, + sdO: cute.struct.Align[ + cute.struct.MemRange[self.do_dtype, cute.cosize(sdO_layout)], + self.buffer_align_bytes, ] - sdS: cute.struct.Align[ - cute.struct.MemRange[self.ds_dtype, cute.cosize(sdSt_layout)], - 128, + sdS: cute.struct.Align[ + cute.struct.MemRange[self.ds_dtype, cute.cosize(sdSt_layout)], + 128, ] sLSE: cute.struct.Align[ - cute.struct.MemRange[self.lse_dtype, cute.cosize(sLSE_layout)], - 128, + cute.struct.MemRange[self.lse_dtype, cute.cosize(sLSE_layout)], + 128, ] sPsum: cute.struct.Align[ - cute.struct.MemRange[self.psum_dtype, cute.cosize(sPsum_layout)], - 128, + cute.struct.MemRange[self.psum_dtype, cute.cosize(sPsum_layout)], + 128, ] sdQaccum: cute.struct.Align[ - cute.struct.MemRange[self.dqaccum_dtype, cute.cosize(sdQaccum_layout)], - self.buffer_align_bytes, + cute.struct.MemRange[self.dqaccum_dtype, cute.cosize(sdQaccum_layout)], + self.buffer_align_bytes, ] - self.shared_storage = SharedStorage + self.shared_storage = SharedStorage LOG2_E = math.log2(math.e) softmax_scale_log2 = softmax_scale * LOG2_E @@ -603,52 +663,51 @@ class SharedStorage: min_blocks_per_mp=1, ) - @cute.kernel def kernel( self, - mQ: cute.Tensor, - mK: cute.Tensor, - mV: cute.Tensor, - mLSE: cute.Tensor, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mLSE: cute.Tensor, mPsum: cute.Tensor, - mdO: cute.Tensor, - mdV: cute.Tensor, - mdK: cute.Tensor, + mdO: cute.Tensor, + mdV: cute.Tensor, + mdK: cute.Tensor, mdQaccum: cute.Tensor, mdV_tma_tensor: Optional[cute.Tensor], mdK_tma_tensor: Optional[cute.Tensor], mdQ_semaphore: Optional[cute.Tensor], mdK_semaphore: Optional[cute.Tensor], mdV_semaphore: Optional[cute.Tensor], - tma_atom_Q: cute.CopyAtom, - tma_atom_K: cute.CopyAtom, - tma_atom_V: cute.CopyAtom, - tma_atom_LSE: cute.CopyAtom, + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + tma_atom_LSE: cute.CopyAtom, tma_atom_Psum: cute.CopyAtom, - tma_atom_dO: cute.CopyAtom, - tma_atom_dV: Optional[cute.CopyAtom], - tma_atom_dK: Optional[cute.CopyAtom], - sQ_layout: cute.ComposedLayout, - sQt_layout: cute.ComposedLayout, - sK_layout: cute.ComposedLayout, - sV_layout: cute.ComposedLayout, - sLSE_layout: cute.Layout, - sPsum_layout: cute.Layout, - sdO_layout: cute.ComposedLayout, - sdOt_layout: cute.ComposedLayout, - sdSt_layout: cute.ComposedLayout, - sdS_layout: cute.ComposedLayout, - sKt_layout: cute.ComposedLayout, + tma_atom_dO: cute.CopyAtom, + tma_atom_dV: Optional[cute.CopyAtom], + tma_atom_dK: Optional[cute.CopyAtom], + sQ_layout: cute.ComposedLayout, + sQt_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sLSE_layout: cute.Layout, + sPsum_layout: cute.Layout, + sdO_layout: cute.ComposedLayout, + sdOt_layout: cute.ComposedLayout, + sdSt_layout: cute.ComposedLayout, + sdS_layout: cute.ComposedLayout, + sKt_layout: cute.ComposedLayout, sdQaccum_layout: cute.Layout, - sdKdV_layout: cute.ComposedLayout, - tiled_mma_kq: cute.TiledMma, - tiled_mma_pdo: cute.TiledMma, - tiled_mma_vdo: cute.TiledMma, - tiled_mma_dsq: cute.TiledMma, - tiled_mma_dsk: cute.TiledMma, + sdKdV_layout: cute.ComposedLayout, + tiled_mma_kq: cute.TiledMma, + tiled_mma_pdo: cute.TiledMma, + tiled_mma_vdo: cute.TiledMma, + tiled_mma_dsq: cute.TiledMma, + tiled_mma_dsk: cute.TiledMma, tiled_copy_r2s_dKdV: cute.TiledCopy, - softmax_scale: cutlass.Float32, + softmax_scale: cutlass.Float32, softmax_scale_log2: cutlass.Float32, tile_sched_params: ParamsBase, ): @@ -669,30 +728,36 @@ def kernel( cpasync.prefetch_descriptor(tma_atom_dK) # Alloc - smem = cutlass.utils.SmemAllocator() + smem = cutlass.utils.SmemAllocator() storage = smem.allocate(self.shared_storage) - k_full_mbar_ptr = storage.k_full_mbar_ptr.data_ptr() - v_full_mbar_ptr = storage.v_full_mbar_ptr.data_ptr() + k_full_mbar_ptr = storage.k_full_mbar_ptr.data_ptr() + v_full_mbar_ptr = storage.v_full_mbar_ptr.data_ptr() tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr.data_ptr() - lse_full_mbar_ptr = storage.lse_full_mbar_ptr.data_ptr() - lse_empty_mbar_ptr = storage.lse_empty_mbar_ptr.data_ptr() - psum_full_mbar_ptr = storage.psum_full_mbar_ptr.data_ptr() - psum_empty_mbar_ptr = storage.psum_empty_mbar_ptr.data_ptr() - dQaccum_reduce_mbar_ptr = storage.dQaccum_reduce_mbar_ptr.data_ptr() + lse_full_mbar_ptr = storage.lse_full_mbar_ptr.data_ptr() + lse_empty_mbar_ptr = storage.lse_empty_mbar_ptr.data_ptr() + psum_full_mbar_ptr = storage.psum_full_mbar_ptr.data_ptr() + psum_empty_mbar_ptr = storage.psum_empty_mbar_ptr.data_ptr() + dQaccum_reduce_mbar_ptr = storage.dQaccum_reduce_mbar_ptr.data_ptr() if warp_idx == self.load_warp_id: - cute.arch.mbarrier_init(k_full_mbar_ptr, len([self.load_warp_id])) - cute.arch.mbarrier_init(v_full_mbar_ptr, len([self.load_warp_id])) - cute.arch.mbarrier_init(tmem_dealloc_mbar_ptr, cute.arch.WARP_SIZE * len(self.compute_warp_ids)) - cute.arch.mbarrier_init(lse_full_mbar_ptr, len([self.compute_warp_ids])) - cute.arch.mbarrier_init(lse_empty_mbar_ptr, len([self.compute_warp_ids])) - cute.arch.mbarrier_init(psum_full_mbar_ptr, len([self.compute_warp_ids])) - cute.arch.mbarrier_init(psum_empty_mbar_ptr, len([self.compute_warp_ids])) + cute.arch.mbarrier_init(k_full_mbar_ptr, len([self.load_warp_id])) + cute.arch.mbarrier_init(v_full_mbar_ptr, len([self.load_warp_id])) + cute.arch.mbarrier_init( + tmem_dealloc_mbar_ptr, cute.arch.WARP_SIZE * len(self.compute_warp_ids) + ) + cute.arch.mbarrier_init(lse_full_mbar_ptr, len([self.compute_warp_ids])) + cute.arch.mbarrier_init(lse_empty_mbar_ptr, len([self.compute_warp_ids])) + cute.arch.mbarrier_init(psum_full_mbar_ptr, len([self.compute_warp_ids])) + cute.arch.mbarrier_init(psum_empty_mbar_ptr, len([self.compute_warp_ids])) cute.arch.mbarrier_init(dQaccum_reduce_mbar_ptr, 1) - pipeline_producer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, len([self.load_warp_id])) - pipeline_consumer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, len([self.mma_warp_id])) + pipeline_producer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len([self.load_warp_id]) + ) + pipeline_consumer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) + ) pipeline_q = cutlass.pipeline.PipelineTmaUmma.create( barrier_storage=storage.q_mbar_ptr.data_ptr(), @@ -711,8 +776,12 @@ def kernel( ) # UMMA producers and AsyncThread consumers - pipeline_producer_group_MMA_AsyncThread = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, len([self.mma_warp_id])) - pipeline_consumer_group_MMA_AsyncThread = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, cute.arch.WARP_SIZE * len(self.compute_warp_ids)) + pipeline_producer_group_MMA_AsyncThread = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + pipeline_consumer_group_MMA_AsyncThread = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, cute.arch.WARP_SIZE * len(self.compute_warp_ids) + ) pipeline_s = cutlass.pipeline.PipelineUmmaAsync.create( num_stages=self.s_stage, @@ -732,7 +801,11 @@ def kernel( consumer_group=pipeline_consumer_group_MMA_AsyncThread, barrier_storage=storage.dK_mbar_ptr.data_ptr(), ) - pipeline_consumer_group_MMA_AsyncThread_dQ = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, cute.arch.WARP_SIZE * len(self.reduce_warp_ids), alignment=128) # Compute + pipeline_consumer_group_MMA_AsyncThread_dQ = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, + cute.arch.WARP_SIZE * len(self.reduce_warp_ids), + alignment=128, + ) # Compute pipeline_dQaccum = cutlass.pipeline.PipelineUmmaAsync.create( num_stages=self.dQaccum_mma_stage, producer_group=pipeline_producer_group_MMA_AsyncThread, @@ -747,8 +820,12 @@ def kernel( ) # AsyncThread producers and UMMA consumers - pipeline_pdS_producer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, cute.arch.WARP_SIZE * len(self.compute_warp_ids)) # Compute - pipeline_pdS_consumer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, len([self.mma_warp_id])) # MMA + pipeline_pdS_producer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, cute.arch.WARP_SIZE * len(self.compute_warp_ids) + ) # Compute + pipeline_pdS_consumer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) + ) # MMA pipeline_p = cutlass.pipeline.PipelineAsyncUmma.create( num_stages=self.s_stage, @@ -764,95 +841,118 @@ def kernel( barrier_storage=storage.dS_mbar_ptr.data_ptr(), ) - sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) - sQt = cute.make_tensor(cute.recast_ptr(sQ.iterator, swizzle_=sQt_layout.inner), sQt_layout.outer) + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) + sQt = cute.make_tensor( + cute.recast_ptr(sQ.iterator, swizzle_=sQt_layout.inner), sQt_layout.outer + ) sQ_pi = storage.sQ.get_tensor(sQ_layout) - sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) - sKt = cute.make_tensor(cute.recast_ptr(sK.iterator, swizzle_=sKt_layout.inner), sKt_layout.outer) + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + sKt = cute.make_tensor( + cute.recast_ptr(sK.iterator, swizzle_=sKt_layout.inner), sKt_layout.outer + ) - sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) + sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) - sdSt = storage.sdS.get_tensor(sdSt_layout.outer, swizzle=sdSt_layout.inner) + sdSt = storage.sdS.get_tensor(sdSt_layout.outer, swizzle=sdSt_layout.inner) sdSt_pi = storage.sdS.get_tensor(sdSt_layout) - sdS = cute.make_tensor(cute.recast_ptr(sdSt.iterator, swizzle_=sdS_layout.inner), sdS_layout.outer) + sdS = cute.make_tensor( + cute.recast_ptr(sdSt.iterator, swizzle_=sdS_layout.inner), sdS_layout.outer + ) - sdO = storage.sdO.get_tensor(sdO_layout.outer, swizzle=sdO_layout.inner) - sdOt = cute.make_tensor(cute.recast_ptr(sdO.iterator, swizzle_=sdOt_layout.inner), sdOt_layout.outer) + sdO = storage.sdO.get_tensor(sdO_layout.outer, swizzle=sdO_layout.inner) + sdOt = cute.make_tensor( + cute.recast_ptr(sdO.iterator, swizzle_=sdOt_layout.inner), sdOt_layout.outer + ) sLSE_load = storage.sLSE.get_tensor(sLSE_layout) - sLSE_mma = storage.sLSE.get_tensor(cute.make_layout( - shape=(self.m_block_size, self.n_block_size, self.lse_stage), - stride=(0, 1, 0) - )) - + sLSE_mma = storage.sLSE.get_tensor( + cute.make_layout( + shape=(self.m_block_size, self.n_block_size, self.lse_stage), stride=(0, 1, 0) + ) + ) sPsum_load = storage.sPsum.get_tensor(sPsum_layout) - sPsum_mma = storage.sPsum.get_tensor(cute.make_layout( - shape=(self.m_block_size, self.n_block_size, self.psum_stage), - stride=(0, 1, 0) - )) + sPsum_mma = storage.sPsum.get_tensor( + cute.make_layout( + shape=(self.m_block_size, self.n_block_size, self.psum_stage), stride=(0, 1, 0) + ) + ) - sdV = storage.sdO.get_tensor(sdKdV_layout.outer, swizzle=sdKdV_layout.inner, dtype=self.dk_dtype) - sdK = storage.sQ.get_tensor(sdKdV_layout.outer, swizzle=sdKdV_layout.inner, dtype=self.dk_dtype) + sdV = storage.sdO.get_tensor( + sdKdV_layout.outer, swizzle=sdKdV_layout.inner, dtype=self.dk_dtype + ) + sdK = storage.sQ.get_tensor( + sdKdV_layout.outer, swizzle=sdKdV_layout.inner, dtype=self.dk_dtype + ) - assert cute.cosize(sdV) * self.dv_dtype.width <= cute.cosize(sdO) * self.do_dtype.width, "Not enough space for sdV" - assert cute.cosize(sdK) * self.dk_dtype.width <= cute.cosize(sQ) * self.q_dtype.width, "Not enough space for sdK" + assert cute.cosize(sdV) * self.dv_dtype.width <= cute.cosize(sdO) * self.do_dtype.width, ( + "Not enough space for sdV" + ) + assert cute.cosize(sdK) * self.dk_dtype.width <= cute.cosize(sQ) * self.q_dtype.width, ( + "Not enough space for sdK" + ) swz128 = cute.make_swizzle(3, 4, 3) sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout, swizzle=swz128) # TMEM # S - thr_mma_kq = tiled_mma_kq.get_slice(0) - Sacc_shape = thr_mma_kq.partition_shape_C(self.mma_tiler_kq[:2]) #(M, N) - tStS = thr_mma_kq.make_fragment_C(Sacc_shape) - tStS = cute.make_tensor(tStS.iterator, tStS.layout) + thr_mma_kq = tiled_mma_kq.get_slice(0) + Sacc_shape = thr_mma_kq.partition_shape_C(self.mma_tiler_kq[:2]) # (M, N) + tStS = thr_mma_kq.make_fragment_C(Sacc_shape) + tStS = cute.make_tensor(tStS.iterator, tStS.layout) # dV thr_mma_pdo = tiled_mma_pdo.get_slice(0) dvacc_shape = thr_mma_pdo.partition_shape_C(self.mma_tiler_pdo[:2]) - tdVtdV = thr_mma_pdo.make_fragment_C(dvacc_shape) - tdVtdV = cute.make_tensor(tdVtdV.iterator + self.tmem_dV_offset , tdVtdV.layout) + tdVtdV = thr_mma_pdo.make_fragment_C(dvacc_shape) + tdVtdV = cute.make_tensor(tdVtdV.iterator + self.tmem_dV_offset, tdVtdV.layout) # dK thr_mma_dsq = tiled_mma_dsq.get_slice(0) dkacc_shape = thr_mma_dsq.partition_shape_C(self.mma_tiler_dsq[:2]) - tdKtdK = thr_mma_dsq.make_fragment_C(dkacc_shape) - tdKtdK = cute.make_tensor(tdKtdK.iterator + self.tmem_dK_offset , tdKtdK.layout) + tdKtdK = thr_mma_dsq.make_fragment_C(dkacc_shape) + tdKtdK = cute.make_tensor(tdKtdK.iterator + self.tmem_dK_offset, tdKtdK.layout) # dQ thr_mma_dsk = tiled_mma_dsk.get_slice(0) dQacc_shape = thr_mma_dsk.partition_shape_C(self.mma_tiler_dsk[:2]) - tdQtdQ = thr_mma_dsk.make_fragment_C(dQacc_shape) - tdQtdQ = cute.make_tensor(tdQtdQ.iterator + self.tmem_dQaccum_offset , tdQtdQ.layout) + tdQtdQ = thr_mma_dsk.make_fragment_C(dQacc_shape) + tdQtdQ = cute.make_tensor(tdQtdQ.iterator + self.tmem_dQaccum_offset, tdQtdQ.layout) # dP thr_mma_vdo = tiled_mma_vdo.get_slice(0) dPacc_shape = thr_mma_vdo.partition_shape_C(self.mma_tiler_vdo[:2]) - tdPtdP = thr_mma_vdo.make_fragment_C(dPacc_shape) - tdPtdP = cute.make_tensor(tdPtdP.iterator + self.tmem_dP_offset , tdPtdP.layout) + tdPtdP = thr_mma_vdo.make_fragment_C(dPacc_shape) + tdPtdP = cute.make_tensor(tdPtdP.iterator + self.tmem_dP_offset, tdPtdP.layout) block_info = BlockInfo( self.m_block_size, self.n_block_size, - self.is_causal, self.is_local, - None, None, + self.is_causal, + self.is_local, + None, + None, qhead_per_kvhead_packgqa=1, ) SeqlenInfoCls = partial( SeqlenInfoQK, seqlen_q_static=mQ.shape[0], seqlen_k_static=mK.shape[0], - mCuSeqlensQ=None, mCuSeqlensK=None, - mSeqUsedQ=None, mSeqUsedK=None, + mCuSeqlensQ=None, + mCuSeqlensK=None, + mSeqUsedQ=None, + mSeqUsedK=None, ) TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params) # TODO: support local AttentionMaskCls = partial( - AttentionMask, self.m_block_size, self.n_block_size, + AttentionMask, + self.m_block_size, + self.n_block_size, ) cute.arch.sync_threads() @@ -960,7 +1060,9 @@ def kernel( TileSchedulerCls, ) cute.arch.relinquish_tmem_alloc_permit() - tmem_ptr = cute.arch.retrieve_tmem_ptr(Float32, alignment=16, ptr_to_buffer_holding_addr=storage.tmem_holding_buf) + tmem_ptr = cute.arch.retrieve_tmem_ptr( + Float32, alignment=16, ptr_to_buffer_holding_addr=storage.tmem_holding_buf + ) cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) tmem_alloc_cols = Int32(self.tmem_alloc_cols) @@ -969,7 +1071,7 @@ def kernel( # Compute # (4, 5, 6, 7, 8, 9, 10, 11) --> 8 warps if warp_idx >= self.compute_warp_ids[0] and warp_idx <= self.compute_warp_ids[-1]: - cute.arch.warpgroup_reg_dealloc(self.num_regs_compute) # 8 warps + cute.arch.warpgroup_reg_dealloc(self.num_regs_compute) # 8 warps self.compute_loop( thr_mma_kq, thr_mma_pdo, @@ -1033,37 +1135,36 @@ def kernel( return - @cute.jit def load( self, - thr_mma_kq: cute.core.ThrMma, + thr_mma_kq: cute.core.ThrMma, thr_mma_pdo: cute.core.ThrMma, thr_mma_vdo: cute.core.ThrMma, - mQ: cute.Tensor, - mK: cute.Tensor, - mV: cute.Tensor, - mLSE: cute.Tensor, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mLSE: cute.Tensor, mPsum: cute.Tensor, - mdO: cute.Tensor, - sQ: cute.Tensor, - sK: cute.Tensor, - sV: cute.Tensor, - sLSE: cute.Tensor, + mdO: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + sLSE: cute.Tensor, sPsum: cute.Tensor, - sdO: cute.Tensor, - tma_atom_Q: cute.CopyAtom, - tma_atom_K: cute.CopyAtom, - tma_atom_V: cute.CopyAtom, - tma_atom_LSE: cute.CopyAtom, + sdO: cute.Tensor, + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + tma_atom_LSE: cute.CopyAtom, tma_atom_Psum: cute.CopyAtom, - tma_atom_dO: cute.CopyAtom, - pipeline_q: PipelineAsync, - lse_full_mbar_ptr: cute.Pointer, - lse_empty_mbar_ptr: cute.Pointer, - psum_full_mbar_ptr: cute.Pointer, + tma_atom_dO: cute.CopyAtom, + pipeline_q: PipelineAsync, + lse_full_mbar_ptr: cute.Pointer, + lse_empty_mbar_ptr: cute.Pointer, + psum_full_mbar_ptr: cute.Pointer, psum_empty_mbar_ptr: cute.Pointer, - pipeline_do: PipelineAsync, + pipeline_do: PipelineAsync, k_full_mbar_ptr: cute.Pointer, v_full_mbar_ptr: cute.Pointer, block_info: BlockInfo, @@ -1073,8 +1174,12 @@ def load( warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) tidx = cute.arch.thread_idx()[0] - q_producer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, self.q_stage) - do_producer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, self.do_stage) + q_producer_state = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.q_stage + ) + do_producer_state = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.do_stage + ) tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() @@ -1084,11 +1189,11 @@ def load( seqlen = SeqlenInfoCls(batch_idx) m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) head_idx_kv = head_idx // self.qhead_per_kvhead - mQ_cur = mQ[None, None, head_idx, batch_idx] - mK_cur = mK[None, None, head_idx_kv, batch_idx] - mV_cur = mV[None, None, head_idx_kv, batch_idx] - mdO_cur = mdO[None, None, head_idx, batch_idx] - mLSE_cur = mLSE[None, head_idx, batch_idx] + mQ_cur = mQ[None, None, head_idx, batch_idx] + mK_cur = mK[None, None, head_idx_kv, batch_idx] + mV_cur = mV[None, None, head_idx_kv, batch_idx] + mdO_cur = mdO[None, None, head_idx, batch_idx] + mLSE_cur = mLSE[None, head_idx, batch_idx] mPsum_cur = mPsum[None, head_idx, batch_idx] gK = cute.local_tile(mK_cur, cute.select(self.mma_tiler_kq, mode=[0, 2]), (n_block, 0)) @@ -1100,10 +1205,10 @@ def load( gQ = cute.local_tile(mQ_cur, cute.select(self.mma_tiler_kq, mode=[1, 2]), (None, 0)) tSgQ = thr_mma_kq.partition_B(gQ) - gLSE = cute.local_tile(mLSE_cur, (self.n_block_size, ), (None, )) - gPsum = cute.local_tile(mPsum_cur, (self.n_block_size, ), (None, )) + gLSE = cute.local_tile(mLSE_cur, (self.n_block_size,), (None,)) + gPsum = cute.local_tile(mPsum_cur, (self.n_block_size,), (None,)) - gdO = cute.local_tile(mdO_cur, cute.select(self.mma_tiler_pdo, mode=[1, 2]), (0, None)) + gdO = cute.local_tile(mdO_cur, cute.select(self.mma_tiler_pdo, mode=[1, 2]), (0, None)) tdVgdO = thr_mma_pdo.partition_B(gdO) tKsK, tKgK = cpasync.tma_partition( @@ -1157,10 +1262,10 @@ def load( # Q0 pipeline_q.producer_acquire(q_producer_state) cute.copy( - tma_atom_Q, - tQgQ[None, m_block_max - 1], - tQsQ[None, q_producer_state.index], - tma_bar_ptr=pipeline_q.producer_get_barrier(q_producer_state) + tma_atom_Q, + tQgQ[None, m_block_max - 1], + tQsQ[None, q_producer_state.index], + tma_bar_ptr=pipeline_q.producer_get_barrier(q_producer_state), ) pipeline_q.producer_commit(q_producer_state) q_producer_state.advance() @@ -1187,14 +1292,16 @@ def load( tma_atom_dO, tdOgdO[None, m_block_max - 1], tdOsdO[None, do_producer_state.index], - tma_bar_ptr=pipeline_do.producer_get_barrier(do_producer_state) + tma_bar_ptr=pipeline_do.producer_get_barrier(do_producer_state), ) pipeline_do.producer_commit(do_producer_state) do_producer_state.advance() # Psum with cute.arch.elect_one(): - cute.arch.mbarrier_arrive_and_expect_tx(psum_full_mbar_ptr, self.tma_copy_psum_bytes) + cute.arch.mbarrier_arrive_and_expect_tx( + psum_full_mbar_ptr, self.tma_copy_psum_bytes + ) cute.copy( tma_atom_Psum, @@ -1209,7 +1316,9 @@ def load( m_block = m_block_max - 2 - i # Q - self.load_M_tile(tma_atom_Q, tQgQ, tQsQ, pipeline_q, m_block, producer_state=q_producer_state) + self.load_M_tile( + tma_atom_Q, tQgQ, tQsQ, pipeline_q, m_block, producer_state=q_producer_state + ) pipeline_q.producer_commit(q_producer_state) q_producer_state.advance() @@ -1218,7 +1327,9 @@ def load( lse_empty_consumer_phase ^= 1 with cute.arch.elect_one(): - cute.arch.mbarrier_arrive_and_expect_tx(lse_full_mbar_ptr, self.tma_copy_lse_bytes) + cute.arch.mbarrier_arrive_and_expect_tx( + lse_full_mbar_ptr, self.tma_copy_lse_bytes + ) cute.copy( tma_atom_LSE, @@ -1228,7 +1339,14 @@ def load( ) # dO - self.load_M_tile(tma_atom_dO, tdOgdO, tdOsdO, pipeline_do, m_block, producer_state=do_producer_state) + self.load_M_tile( + tma_atom_dO, + tdOgdO, + tdOsdO, + pipeline_do, + m_block, + producer_state=do_producer_state, + ) pipeline_do.producer_commit(do_producer_state) do_producer_state.advance() @@ -1237,7 +1355,9 @@ def load( psum_empty_consumer_phase ^= 1 with cute.arch.elect_one(): - cute.arch.mbarrier_arrive_and_expect_tx(psum_full_mbar_ptr, self.tma_copy_psum_bytes) + cute.arch.mbarrier_arrive_and_expect_tx( + psum_full_mbar_ptr, self.tma_copy_psum_bytes + ) cute.copy( tma_atom_Psum, @@ -1253,46 +1373,45 @@ def load( tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() - @cute.jit def mma( self, - tiled_mma_kq: cute.core.TiledMma, + tiled_mma_kq: cute.core.TiledMma, tiled_mma_pdo: cute.core.TiledMma, tiled_mma_vdo: cute.core.TiledMma, tiled_mma_dsq: cute.core.TiledMma, tiled_mma_dsk: cute.core.TiledMma, - thr_mma_kq: cute.core.ThrMma, - thr_mma_pdo: cute.core.ThrMma, - thr_mma_vdo: cute.core.ThrMma, - thr_mma_dsq: cute.core.ThrMma, - thr_mma_dsk: cute.core.ThrMma, - sQ: cute.Tensor, - sQt: cute.Tensor, - sK: cute.Tensor, - sV: cute.Tensor, - sdO: cute.Tensor, + thr_mma_kq: cute.core.ThrMma, + thr_mma_pdo: cute.core.ThrMma, + thr_mma_vdo: cute.core.ThrMma, + thr_mma_dsq: cute.core.ThrMma, + thr_mma_dsk: cute.core.ThrMma, + sQ: cute.Tensor, + sQt: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + sdO: cute.Tensor, sdOt: cute.Tensor, sdSt: cute.Tensor, - sdS: cute.Tensor, - sKt: cute.Tensor, + sdS: cute.Tensor, + sKt: cute.Tensor, sK_swizzle: cute.Swizzle, sQ_swizzle: cute.Swizzle, tStS: cute.Tensor, - tdVtdV: cute.Tensor, - tdKtdK: cute.Tensor, - tdPtdP: cute.Tensor, + tdVtdV: cute.Tensor, + tdKtdK: cute.Tensor, + tdPtdP: cute.Tensor, tdQacctdQacc: cute.Tensor, - pipeline_q: PipelineAsync, + pipeline_q: PipelineAsync, pipeline_do: PipelineAsync, - pipeline_s: PipelineAsync, - pipeline_p: PipelineAsync, + pipeline_s: PipelineAsync, + pipeline_p: PipelineAsync, pipeline_dS: PipelineAsync, pipeline_dV: PipelineAsync, pipeline_dK: PipelineAsync, pipeline_dP: PipelineAsync, pipeline_dQaccum: PipelineAsync, - full_key_mbar_ptr: cute.Pointer, + full_key_mbar_ptr: cute.Pointer, full_value_mbar_ptr: cute.Pointer, block_info: BlockInfo, SeqlenInfoCls: Callable, @@ -1301,28 +1420,46 @@ def mma( warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) key_consumer_phase = cutlass.Int32(0) - q_consumer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Consumer, self.q_stage) - q_dk_consumer_state = q_consumer_state - do_consumer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Consumer, self.do_stage) + q_consumer_state = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.q_stage + ) + q_dk_consumer_state = q_consumer_state + do_consumer_state = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.do_stage + ) - s_producer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, self.s_stage) - dP_producer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, self.dP_stage) - p_consumer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Consumer, self.s_stage) - dS_consumer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Consumer, self.dS_stage) - dV_producer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, self.dV_stage) - dK_producer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, self.dK_stage) - dQaccum_producer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, self.dQaccum_mma_stage) + s_producer_state = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.s_stage + ) + dP_producer_state = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.dP_stage + ) + p_consumer_state = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.s_stage + ) + dS_consumer_state = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.dS_stage + ) + dV_producer_state = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.dV_stage + ) + dK_producer_state = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.dK_stage + ) + dQaccum_producer_state = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.dQaccum_mma_stage + ) tile_scheduler = TileSchedulerCls() - work_tile = tile_scheduler.initial_work_tile_info() + work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: n_block, head_idx, batch_idx = work_tile.tile_idx - seqlen = SeqlenInfoCls(batch_idx) # must be seqlen_k + seqlen = SeqlenInfoCls(batch_idx) # must be seqlen_k m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) - cute.arch.mbarrier_wait(full_key_mbar_ptr, phase=key_consumer_phase) - cute.arch.mbarrier_wait(full_value_mbar_ptr, phase=key_consumer_phase) + cute.arch.mbarrier_wait(full_key_mbar_ptr, phase=key_consumer_phase) + cute.arch.mbarrier_wait(full_value_mbar_ptr, phase=key_consumer_phase) key_consumer_phase ^= 1 @@ -1331,31 +1468,35 @@ def mma( tSrQ = thr_mma_kq.make_fragment_B(sQ) # dP = V @ dOt - tdPrV = thr_mma_vdo.make_fragment_A(sV) + tdPrV = thr_mma_vdo.make_fragment_A(sV) tdPrdOt = thr_mma_vdo.make_fragment_B(sdOt) # dK = dS.T @ Q tdKrdS = thr_mma_dsq.make_fragment_A(sdSt) - tdKrQ = thr_mma_dsq.make_fragment_B(sQt) + tdKrQ = thr_mma_dsq.make_fragment_B(sQt) accumulate_dK = False # dV = P @ dO.T tdVrdO = thr_mma_pdo.make_fragment_B(sdO) - p_tmem_layout = sm100_utils_basic.make_smem_layout_a(tiled_mma_pdo, self.mma_tiler_pdo, self.q_dtype, self.acc_stage,) + p_tmem_layout = sm100_utils_basic.make_smem_layout_a( + tiled_mma_pdo, + self.mma_tiler_pdo, + self.q_dtype, + self.acc_stage, + ) - tP = cute.make_tensor(tStS.iterator, p_tmem_layout.outer) + tP = cute.make_tensor(tStS.iterator, p_tmem_layout.outer) tdVrP = thr_mma_pdo.make_fragment_A(tP)[None, None, None, 0] tdVrP = cute.make_tensor(tdVrP.iterator, tdVrP.layout) # dQ = dS @ K tdQaccrdS = thr_mma_dsk.make_fragment_A(sdS) - tdQaccrK = thr_mma_dsk.make_fragment_B(sKt) - + tdQaccrK = thr_mma_dsk.make_fragment_B(sKt) - #----------------------------------------------------------- + # ----------------------------------------------------------- ###### Prologue - #----------------------------------------------------------- + # ----------------------------------------------------------- # 1. S = Q0 @ K.T # 2. dP = V @ dO.T # 3. dV = P @ dO @@ -1386,15 +1527,16 @@ def mma( pipeline_dQaccum.producer_acquire(dQaccum_producer_state) for kphase_idx in cutlass.range_constexpr(cute.size(tdPrV, mode=[2]), unroll=1): - tiled_mma_vdo.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) - cute.gemm( - tiled_mma_vdo, - tdPtdP, - tdPrV[(None, None, kphase_idx, 0)], - tdPrdOt[(None, None, kphase_idx, do_consumer_state.index)], - tdPtdP, - ) - pipeline_dP.producer_commit(dP_producer_state); dP_producer_state.advance() + tiled_mma_vdo.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) + cute.gemm( + tiled_mma_vdo, + tdPtdP, + tdPrV[(None, None, kphase_idx, 0)], + tdPrdOt[(None, None, kphase_idx, do_consumer_state.index)], + tdPtdP, + ) + pipeline_dP.producer_commit(dP_producer_state) + dP_producer_state.advance() # 3) dV = P.T @ dO pipeline_p.consumer_wait(p_consumer_state) @@ -1405,15 +1547,17 @@ def mma( cute.gemm( tiled_mma_pdo, tdVtdV, - tdVrP[(None, None, kphase_idx)], + tdVrP[(None, None, kphase_idx)], tdVrdO[(None, None, kphase_idx, do_consumer_state.index)], tdVtdV, ) - pipeline_p.consumer_release(p_consumer_state); p_consumer_state.advance() - pipeline_do.consumer_release(do_consumer_state); do_consumer_state.advance() - #----------------------------------------------------------- + pipeline_p.consumer_release(p_consumer_state) + p_consumer_state.advance() + pipeline_do.consumer_release(do_consumer_state) + do_consumer_state.advance() + # ----------------------------------------------------------- ###### MAIN LOOP - #----------------------------------------------------------- + # ----------------------------------------------------------- # 1. S = K @ Q.T # 2. dQ = dS @ K # 3. dK = dS.T @ Q @@ -1449,11 +1593,12 @@ def mma( cute.gemm( tiled_mma_dsk, tdQacctdQacc, - tdQaccrdS[(None, None, kphase_idx, dS_consumer_state.index)], - tdQaccrK[(None, None, kphase_idx, 0)], + tdQaccrdS[(None, None, kphase_idx, dS_consumer_state.index)], + tdQaccrK[(None, None, kphase_idx, 0)], tdQacctdQacc, ) - pipeline_dQaccum.producer_commit(dQaccum_producer_state) ; dQaccum_producer_state.advance() + pipeline_dQaccum.producer_commit(dQaccum_producer_state) + dQaccum_producer_state.advance() # 3) dK = dS.T @ Q num_kphases = cute.size(tdKrdS, mode=[2]) @@ -1462,30 +1607,33 @@ def mma( cute.gemm( tiled_mma_dsq, tdKtdK, - tdKrdS[(None, None, kphase_idx, 0)], - tdKrQ[(None, None, kphase_idx, q_dk_consumer_state.index)], + tdKrdS[(None, None, kphase_idx, 0)], + tdKrQ[(None, None, kphase_idx, q_dk_consumer_state.index)], tdKtdK, ) accumulate_dK = True - pipeline_q.consumer_release(q_dk_consumer_state) ; q_dk_consumer_state.advance() - pipeline_dS.consumer_release(dS_consumer_state); dS_consumer_state.advance() + pipeline_q.consumer_release(q_dk_consumer_state) + q_dk_consumer_state.advance() + pipeline_dS.consumer_release(dS_consumer_state) + dS_consumer_state.advance() - #4) dP = V @ dO.T + # 4) dP = V @ dO.T pipeline_do.consumer_wait(do_consumer_state) pipeline_dQaccum.producer_acquire(dQaccum_producer_state) for kphase_idx in cutlass.range_constexpr(cute.size(tdPrV, mode=[2]), unroll=1): - tiled_mma_vdo.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) - cute.gemm( - tiled_mma_vdo, - tdPtdP, - tdPrV[(None, None, kphase_idx, 0)], - tdPrdOt[(None, None, kphase_idx, do_consumer_state.index)], - tdPtdP, - ) - pipeline_dP.producer_commit(dP_producer_state); dP_producer_state.advance() + tiled_mma_vdo.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) + cute.gemm( + tiled_mma_vdo, + tdPtdP, + tdPrV[(None, None, kphase_idx, 0)], + tdPrdOt[(None, None, kphase_idx, do_consumer_state.index)], + tdPtdP, + ) + pipeline_dP.producer_commit(dP_producer_state) + dP_producer_state.advance() # 5) dV += P @ dO pipeline_p.consumer_wait(p_consumer_state) @@ -1496,23 +1644,27 @@ def mma( cute.gemm( tiled_mma_pdo, tdVtdV, - tdVrP[(None, None, kphase_idx)], + tdVrP[(None, None, kphase_idx)], tdVrdO[(None, None, kphase_idx, do_consumer_state.index)], tdVtdV, ) - pipeline_p.consumer_release(p_consumer_state); p_consumer_state.advance() - pipeline_do.consumer_release(do_consumer_state); do_consumer_state.advance() + pipeline_p.consumer_release(p_consumer_state) + p_consumer_state.advance() + pipeline_do.consumer_release(do_consumer_state) + do_consumer_state.advance() - pipeline_dV.producer_acquire(dV_producer_state); pipeline_dV.producer_commit(dV_producer_state); dV_producer_state.advance() + pipeline_dV.producer_acquire(dV_producer_state) + pipeline_dV.producer_commit(dV_producer_state) + dV_producer_state.advance() pipeline_s.producer_tail(s_producer_state) pipeline_dP.producer_tail(dP_producer_state) pipeline_dV.producer_tail(dV_producer_state) - #----------------------------------------------------------- + # ----------------------------------------------------------- ###### Remaining 2 - #----------------------------------------------------------- + # ----------------------------------------------------------- # 1) dK += dS.T @ Q pipeline_dS.consumer_wait(dS_consumer_state) @@ -1522,14 +1674,15 @@ def mma( cute.gemm( tiled_mma_dsq, tdKtdK, - tdKrdS[(None, None, kphase_idx, dS_consumer_state.index)], - tdKrQ[(None, None, kphase_idx, q_dk_consumer_state.index)], + tdKrdS[(None, None, kphase_idx, dS_consumer_state.index)], + tdKrQ[(None, None, kphase_idx, q_dk_consumer_state.index)], tdKtdK, ) accumulate_dK = True - pipeline_dK.producer_acquire(dK_producer_state); - pipeline_dK.producer_commit(dK_producer_state); dK_producer_state.advance() + pipeline_dK.producer_acquire(dK_producer_state) + pipeline_dK.producer_commit(dK_producer_state) + dK_producer_state.advance() # 2) dQaccum = dS @ K num_kphases = cute.size(tdQaccrdS, mode=[2]) @@ -1538,13 +1691,16 @@ def mma( cute.gemm( tiled_mma_dsk, tdQacctdQacc, - tdQaccrdS[(None, None, kphase_idx, dS_consumer_state.index)], - tdQaccrK[(None, None, kphase_idx, 0)], + tdQaccrdS[(None, None, kphase_idx, dS_consumer_state.index)], + tdQaccrK[(None, None, kphase_idx, 0)], tdQacctdQacc, ) - pipeline_dQaccum.producer_commit(dQaccum_producer_state) ; dQaccum_producer_state.advance() - pipeline_q.consumer_release(q_dk_consumer_state); q_dk_consumer_state.advance() - pipeline_dS.consumer_release(dS_consumer_state); dS_consumer_state.advance() + pipeline_dQaccum.producer_commit(dQaccum_producer_state) + dQaccum_producer_state.advance() + pipeline_q.consumer_release(q_dk_consumer_state) + q_dk_consumer_state.advance() + pipeline_dS.consumer_release(dS_consumer_state) + dS_consumer_state.advance() pipeline_dK.producer_tail(dK_producer_state) pipeline_dQaccum.producer_tail(dQaccum_producer_state) @@ -1552,93 +1708,133 @@ def mma( tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() - @cute.jit - def split_wg(self, thr_tensor: cute.Tensor, wg_idx: cutlass.Int32, num_wg: cutlass.Constexpr[cutlass.Int32]): + def split_wg( + self, + thr_tensor: cute.Tensor, + wg_idx: cutlass.Int32, + num_wg: cutlass.Constexpr[cutlass.Int32], + ): reduced_shape = cute.product_each(thr_tensor.shape) rank = len(reduced_shape) if const_expr(reduced_shape[1] > 1): assert rank >= 2, "Need rank >= 2 for thr_tensor in split_wg" t = cute.logical_divide(thr_tensor, (reduced_shape[0], reduced_shape[1] // num_wg)) - coord = (None, (None, wg_idx)) + (None, ) * (rank - 2) + coord = (None, (None, wg_idx)) + (None,) * (rank - 2) else: assert rank >= 3, "Need rank >= 3 for thr_tensor in split_wg" if const_expr(rank == 3): t = cute.logical_divide( - thr_tensor, (reduced_shape[0], reduced_shape[1], reduced_shape[2] // num_wg)) - coord = (None, None, (None, wg_idx), ) + (None, ) * (rank - 3) + thr_tensor, (reduced_shape[0], reduced_shape[1], reduced_shape[2] // num_wg) + ) + coord = ( + None, + None, + (None, wg_idx), + ) + (None,) * (rank - 3) else: - t = cute.logical_divide(thr_tensor, (reduced_shape[0], reduced_shape[1], reduced_shape[2], reduced_shape[3] // num_wg)) - coord = (None, None, None, (None, wg_idx), ) + (None, ) * (rank - 4) + t = cute.logical_divide( + thr_tensor, + ( + reduced_shape[0], + reduced_shape[1], + reduced_shape[2], + reduced_shape[3] // num_wg, + ), + ) + coord = ( + None, + None, + None, + (None, wg_idx), + ) + (None,) * (rank - 4) return t[coord] - @cute.jit def compute_loop( self, - thr_mma_kq: cute.core.ThrMma, - thr_mma_pdo: cute.core.ThrMma, - thr_mma_vdo: cute.core.ThrMma, - thr_mma_dsq: cute.core.ThrMma, - tStS: cute.Tensor, - sLSE_2D: cute.Tensor, - sPsum_2D: cute.Tensor, - tdVtdV: cute.Tensor, - tdKtdK: cute.Tensor, - mdV: cute.Tensor, - mdK: cute.Tensor, - sdSt: cute.Tensor, - sdSt_pi: cute.Tensor, - tdPtdP: cute.Tensor, - lse_full_mbar_ptr: cute.Pointer, - lse_empty_mbar_ptr: cute.Pointer, - psum_full_mbar_ptr: cute.Pointer, - psum_empty_mbar_ptr: cute.Pointer, - pipeline_s: PipelineAsync, - pipeline_p: PipelineAsync, - pipeline_dS: PipelineAsync, - pipeline_dV: PipelineAsync, - pipeline_dK: PipelineAsync, - pipeline_dP: PipelineAsync, - softmax_scale: cutlass.Float32, - softmax_scale_log2: cutlass.Float32, - block_info: BlockInfo, - SeqlenInfoCls: Callable, - AttentionMaskCls: Callable, - TileSchedulerCls: Callable, - sdV: Optional[cute.Tensor], - sdK: Optional[cute.Tensor], - mdV_tma_tensor: Optional[cute.Tensor], - mdK_tma_tensor: Optional[cute.Tensor], - tma_atom_dV: Optional[cute.CopyAtom], - tma_atom_dK: Optional[cute.CopyAtom], - tiled_copy_r2s_dKdV: Optional[cute.TiledCopy], - mdK_semaphore: Optional[cute.Tensor], - mdV_semaphore: Optional[cute.Tensor], + thr_mma_kq: cute.core.ThrMma, + thr_mma_pdo: cute.core.ThrMma, + thr_mma_vdo: cute.core.ThrMma, + thr_mma_dsq: cute.core.ThrMma, + tStS: cute.Tensor, + sLSE_2D: cute.Tensor, + sPsum_2D: cute.Tensor, + tdVtdV: cute.Tensor, + tdKtdK: cute.Tensor, + mdV: cute.Tensor, + mdK: cute.Tensor, + sdSt: cute.Tensor, + sdSt_pi: cute.Tensor, + tdPtdP: cute.Tensor, + lse_full_mbar_ptr: cute.Pointer, + lse_empty_mbar_ptr: cute.Pointer, + psum_full_mbar_ptr: cute.Pointer, + psum_empty_mbar_ptr: cute.Pointer, + pipeline_s: PipelineAsync, + pipeline_p: PipelineAsync, + pipeline_dS: PipelineAsync, + pipeline_dV: PipelineAsync, + pipeline_dK: PipelineAsync, + pipeline_dP: PipelineAsync, + softmax_scale: cutlass.Float32, + softmax_scale_log2: cutlass.Float32, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + AttentionMaskCls: Callable, + TileSchedulerCls: Callable, + sdV: Optional[cute.Tensor], + sdK: Optional[cute.Tensor], + mdV_tma_tensor: Optional[cute.Tensor], + mdK_tma_tensor: Optional[cute.Tensor], + tma_atom_dV: Optional[cute.CopyAtom], + tma_atom_dK: Optional[cute.CopyAtom], + tiled_copy_r2s_dKdV: Optional[cute.TiledCopy], + mdK_semaphore: Optional[cute.Tensor], + mdV_semaphore: Optional[cute.Tensor], ): # tix: [128...384] 8 warps - warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) # 4-11 + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) # 4-11 - tidx = cute.arch.thread_idx()[0] % 128 # 0...128 - wg_idx = (cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.compute_warp_ids))) // 128 - num_wg = (cute.arch.WARP_SIZE * len(self.compute_warp_ids) // 128) # 2 + tidx = cute.arch.thread_idx()[0] % 128 # 0...128 + wg_idx = ( + cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.compute_warp_ids)) + ) // 128 + num_wg = cute.arch.WARP_SIZE * len(self.compute_warp_ids) // 128 # 2 # wg_idx: # 0: [256...384] # 1: [128...256] - tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32) - tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), Float32) + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32 + ) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), Float32 + ) - s_consumer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Consumer, self.s_stage) - p_producer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, self.s_stage) - dS_producer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, self.ds_stage) + s_consumer_state = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.s_stage + ) + p_producer_state = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.s_stage + ) + dS_producer_state = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.ds_stage + ) - dP_consumer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Consumer, self.dP_stage) + dP_consumer_state = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.dP_stage + ) - lse_consumer_phase = psum_consumer_phase = cute.Int32(0) + lse_consumer_phase = psum_consumer_phase = cute.Int32(0) - sub_packed_f32x2 = partial(cute.arch.calc_packed_f32x2_op, src_c=None, calc_func=nvvm.sub_packed_f32x2, rnd=nvvm.RoundingModeKind.RN ) + sub_packed_f32x2 = partial( + cute.arch.calc_packed_f32x2_op, + src_c=None, + calc_func=nvvm.sub_packed_f32x2, + rnd=nvvm.RoundingModeKind.RN, + ) tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() @@ -1652,7 +1848,10 @@ def compute_loop( # TODO: condition mask_seqlen mask_fn = partial( mask.apply_mask_sm100_transposed, - n_block=n_block, mask_seqlen=True, mask_causal=self.is_causal, mask_local=self.is_local + n_block=n_block, + mask_seqlen=True, + mask_causal=self.is_causal, + mask_local=self.is_local, ) # Mainloop @@ -1666,101 +1865,127 @@ def compute_loop( cute.arch.mbarrier_wait(lse_full_mbar_ptr, lse_consumer_phase) lse_consumer_phase ^= 1 - tiled_tmem_ld = tcgen05.make_tmem_copy(tmem_load_atom, tStS) - thr_tmem_ld = tiled_tmem_ld.get_slice(tidx) + tiled_tmem_ld = tcgen05.make_tmem_copy(tmem_load_atom, tStS) + thr_tmem_ld = tiled_tmem_ld.get_slice(tidx) - tileP_f32_like = self.mma_tiler_kq[0] // 32 * self.v_dtype.width # (128, 64) - tStP = cute.make_tensor( - tStS.iterator, - cute.composition(tStS.layout, cute.make_layout((self.m_block_size, tileP_f32_like))), - ) + tileP_f32_like = self.mma_tiler_kq[0] // 32 * self.v_dtype.width # (128, 64) + tStP = cute.make_tensor( + tStS.iterator, + cute.composition( + tStS.layout, cute.make_layout((self.m_block_size, tileP_f32_like)) + ), + ) tiled_tmem_st = tcgen05.make_tmem_copy(tmem_store_atom, tStP) - thr_tmem_st = tiled_tmem_st.get_slice(tidx) + thr_tmem_st = tiled_tmem_st.get_slice(tidx) #### TMEM tStS_t2r_p = thr_tmem_ld.partition_S(tStS) - tStS_t2r = self.split_wg(tStS_t2r_p, wg_idx, num_wg) + tStS_t2r = self.split_wg(tStS_t2r_p, wg_idx, num_wg) #### RMEM - tScS = thr_mma_kq.partition_C(cute.make_identity_tensor((self.mma_tiler_kq[0], self.mma_tiler_kq[1]))) + tScS = thr_mma_kq.partition_C( + cute.make_identity_tensor((self.mma_tiler_kq[0], self.mma_tiler_kq[1])) + ) tScS_tensor = cute.make_tensor(tScS.iterator, tScS.layout) - tScS_t2r_p = thr_tmem_ld.partition_D(tScS_tensor) - tScS_t2r = self.split_wg(tScS_t2r_p, wg_idx, num_wg) + tScS_t2r_p = thr_tmem_ld.partition_D(tScS_tensor) + tScS_t2r = self.split_wg(tScS_t2r_p, wg_idx, num_wg) - tSrS_t2r = cute.make_fragment(tScS_t2r.shape, Float32) # 64 + tSrS_t2r = cute.make_fragment(tScS_t2r.shape, Float32) # 64 #### TMEM->RMEM (Load S from TMEM) cute.copy(tiled_tmem_ld, tStS_t2r, tSrS_t2r) cute.arch.fence_view_async_tmem_load() #### Sync for load fence and LSE - cute.arch.barrier(barrier_id=int(NamedBarrierBwdSm100.Compute), number_of_threads=self.num_compute_threads) + cute.arch.barrier( + barrier_id=int(NamedBarrierBwdSm100.Compute), + number_of_threads=self.num_compute_threads, + ) #### APPLY MASK if const_expr(self.is_causal or self.is_local): - mask_fn(tSrS_t2r, tScS_t2r, m_block=m_block, ) + mask_fn( + tSrS_t2r, + tScS_t2r, + m_block=m_block, + ) - #--------------------------------------------- + # --------------------------------------------- #### P = exp(S - LSE) - #--------------------------------------------- + # --------------------------------------------- #### RMEM (coordinates for P) - cP_f32 = cute.make_tensor( - tScS.iterator, - cute.composition(tScS.layout, cute.make_layout((self.m_block_size, tileP_f32_like))) - ) + cP_f32 = cute.make_tensor( + tScS.iterator, + cute.composition( + tScS.layout, cute.make_layout((self.m_block_size, tileP_f32_like)) + ), + ) tScP_r2t_p = thr_tmem_st.partition_S(cP_f32) - tScP_r2t = self.split_wg(tScP_r2t_p, wg_idx, num_wg) + tScP_r2t = self.split_wg(tScP_r2t_p, wg_idx, num_wg) tStP_r2t_p = thr_tmem_st.partition_D(tStP) - tStP_r2t = self.split_wg(tStP_r2t_p, wg_idx, num_wg) + tStP_r2t = self.split_wg(tStP_r2t_p, wg_idx, num_wg) #### Compute P = exp(S * scale - LSE) tLSE = thr_tmem_ld.partition_D(sLSE_2D) # split to wg0 & wg1 - tLSErLSE_p = cute.make_tensor(cute.recast_ptr(tLSE.iterator), cute.make_layout((tScS_t2r_p.shape[0], (tScS_t2r_p.shape[1] // num_wg, num_wg), 1, 1))) - tLSErLSE = tLSErLSE_p[None, (None, wg_idx), None, None] - + tLSErLSE_p = cute.make_tensor( + cute.recast_ptr(tLSE.iterator), + cute.make_layout( + (tScS_t2r_p.shape[0], (tScS_t2r_p.shape[1] // num_wg, num_wg), 1, 1) + ), + ) + tLSErLSE = tLSErLSE_p[None, (None, wg_idx), None, None] - WIDTH = cute.arch.WARP_SIZE - CLAMP = WIDTH - 1 - MAC = (0 << 8) | CLAMP - FULL = cute.arch.FULL_MASK + WIDTH = cute.arch.WARP_SIZE + CLAMP = WIDTH - 1 + MAC = (0 << 8) | CLAMP + FULL = cute.arch.FULL_MASK lidx = cute.arch.lane_idx() - tSrP_r2t_f32 = cute.make_fragment(tScP_r2t[None, None, 0].shape, Float32) # 16 - tSrP_r2t = cute.make_tensor(cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), tSrS_t2r[None, 0, None, None].layout) + tSrP_r2t = cute.make_tensor( + cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), + tSrS_t2r[None, 0, None, None].layout, + ) for i in cutlass.range_constexpr(cute.size(tStP_r2t, mode=[2]), unroll=1): - own0 = tLSErLSE[(lidx, 0), i, 0, 0] - own1 = tLSErLSE[(lidx+1, 0), i, 0, 0] - #own1 = cute.arch.shuffle_sync(own0, offset=((lidx + 1) & CLAMP), + own1 = tLSErLSE[(lidx + 1, 0), i, 0, 0] + # own1 = cute.arch.shuffle_sync(own0, offset=((lidx + 1) & CLAMP), # mask=FULL, mask_and_clamp=MAC) for j in cutlass.range_constexpr(0, cute.size(tSrP_r2t), 2, unroll=1): - lse_j = cute.arch.shuffle_sync(own0, offset=j, mask=FULL, mask_and_clamp=MAC) - lse_j1 = cute.arch.shuffle_sync(own1, offset=j, mask=FULL, mask_and_clamp=MAC) + lse_j = cute.arch.shuffle_sync( + own0, offset=j, mask=FULL, mask_and_clamp=MAC + ) + lse_j1 = cute.arch.shuffle_sync( + own1, offset=j, mask=FULL, mask_and_clamp=MAC + ) - tSrS_t2r[j, i, 0, 0], tSrS_t2r[j+1, i, 0, 0] = cute.arch.fma_packed_f32x2(( - (tSrS_t2r[j, i, 0, 0], tSrS_t2r[j+1, i, 0, 0])), - (softmax_scale_log2, softmax_scale_log2), - (-lse_j, -lse_j1)) + tSrS_t2r[j, i, 0, 0], tSrS_t2r[j + 1, i, 0, 0] = cute.arch.fma_packed_f32x2( + ((tSrS_t2r[j, i, 0, 0], tSrS_t2r[j + 1, i, 0, 0])), + (softmax_scale_log2, softmax_scale_log2), + (-lse_j, -lse_j1), + ) - tSrS_t2r[j, i, 0, 0] = cute.arch.exp2(tSrS_t2r[j, i, 0, 0]) - tSrS_t2r[j+1, i, 0, 0] = cute.arch.exp2(tSrS_t2r[j+1, i, 0, 0]) + tSrS_t2r[j, i, 0, 0] = cute.arch.exp2(tSrS_t2r[j, i, 0, 0]) + tSrS_t2r[j + 1, i, 0, 0] = cute.arch.exp2(tSrS_t2r[j + 1, i, 0, 0]) - tSrP_r2t[j, 0, 0] = tSrS_t2r[j, i, 0, 0].to(self.q_dtype) - tSrP_r2t[j+1, 0, 0] = tSrS_t2r[j+1, i, 0, 0].to(self.q_dtype) + tSrP_r2t[j, 0, 0] = tSrS_t2r[j, i, 0, 0].to(self.q_dtype) + tSrP_r2t[j + 1, 0, 0] = tSrS_t2r[j + 1, i, 0, 0].to(self.q_dtype) cute.copy(thr_tmem_st, tSrP_r2t_f32[None, None], tStP_r2t[None, None, i]) cute.arch.fence_view_async_tmem_store() - cute.arch.barrier(barrier_id=int(NamedBarrierBwdSm100.Compute), number_of_threads=self.num_compute_threads) + cute.arch.barrier( + barrier_id=int(NamedBarrierBwdSm100.Compute), + number_of_threads=self.num_compute_threads, + ) pipeline_p.producer_commit(p_producer_state) p_producer_state.advance() @@ -1772,9 +1997,9 @@ def compute_loop( with cute.arch.elect_one(): cute.arch.mbarrier_arrive(lse_empty_mbar_ptr) - #--------------------------------------------- + # --------------------------------------------- # dS.T = P.T * (dP.T - D) - #--------------------------------------------- + # --------------------------------------------- if warp_idx == self.compute_warp_ids[0]: cute.arch.mbarrier_wait(psum_full_mbar_ptr, psum_consumer_phase) psum_consumer_phase ^= 1 @@ -1784,65 +2009,93 @@ def compute_loop( #### TMEM->RMEM (Load dP from TMEM) tiled_tmem_ld_dP = tcgen05.make_tmem_copy(tmem_load_atom, tdPtdP) - thr_tmem_ld_dP = tiled_tmem_ld_dP.get_slice(tidx) + thr_tmem_ld_dP = tiled_tmem_ld_dP.get_slice(tidx) - tdPtdP_t2r_p = thr_tmem_ld_dP.partition_S(tdPtdP) # - tdPtdP_t2r = self.split_wg(tdPtdP_t2r_p, wg_idx, num_wg) + tdPtdP_t2r_p = thr_tmem_ld_dP.partition_S(tdPtdP) # + tdPtdP_t2r = self.split_wg(tdPtdP_t2r_p, wg_idx, num_wg) #### TMEM->RMEM (Load dP from TMEM) - cdP = cute.make_identity_tensor((self.mma_tiler_vdo[0], self.mma_tiler_vdo[1])) - tdPcdP = thr_mma_vdo.partition_C(cdP) + cdP = cute.make_identity_tensor((self.mma_tiler_vdo[0], self.mma_tiler_vdo[1])) + tdPcdP = thr_mma_vdo.partition_C(cdP) tdPcdP_tensor = cute.make_tensor(tdPcdP.iterator, tdPcdP.layout) tdPcdP_t2r_p = thr_tmem_ld_dP.partition_D(tdPcdP_tensor) - tdPcdP_t2r = self.split_wg(tdPcdP_t2r_p, wg_idx, num_wg) - tdPrdP_t2r = cute.make_fragment(tdPcdP_t2r[(None, 0, None, None)].shape, Float32) # ((32,1),1,1) + tdPcdP_t2r = self.split_wg(tdPcdP_t2r_p, wg_idx, num_wg) + tdPrdP_t2r = cute.make_fragment( + tdPcdP_t2r[(None, 0, None, None)].shape, Float32 + ) # ((32,1),1,1) #### Sync for load fence and Psum - cute.arch.barrier(barrier_id=int(NamedBarrierBwdSm100.Compute), number_of_threads=self.num_compute_threads) + cute.arch.barrier( + barrier_id=int(NamedBarrierBwdSm100.Compute), + number_of_threads=self.num_compute_threads, + ) ##### dS.T = P.T * (dP.T - Psum) - sdSt_mn = cute.make_tensor(sdSt_pi.iterator, cute.composition(sdSt_pi.layout, cute.make_layout((self.m_block_size, self.n_block_size)))) - tdKsdS = cute.composition(sdSt_mn[(None, wg_idx), tidx], cute.make_layout(tSrS_t2r.shape)) + sdSt_mn = cute.make_tensor( + sdSt_pi.iterator, + cute.composition( + sdSt_pi.layout, cute.make_layout((self.m_block_size, self.n_block_size)) + ), + ) + tdKsdS = cute.composition( + sdSt_mn[(None, wg_idx), tidx], cute.make_layout(tSrS_t2r.shape) + ) - tSrS_t2r_bf16 = cute.make_tensor(cute.recast_ptr(tSrS_t2r.iterator, dtype=self.ds_dtype), tSrS_t2r.shape) + tSrS_t2r_bf16 = cute.make_tensor( + cute.recast_ptr(tSrS_t2r.iterator, dtype=self.ds_dtype), tSrS_t2r.shape + ) tPsum = thr_tmem_ld.partition_D(sPsum_2D) - tPsumrPsum_p = cute.make_tensor(cute.recast_ptr(tPsum.iterator), cute.make_layout((tScS_t2r_p.shape[0], (tScS_t2r_p.shape[1] // num_wg, num_wg), 1, 1))) - tPsumrPsum = tPsumrPsum_p[None, (None, wg_idx), None, None] # self.split_wg(tLSErLSE_p, wg_idx, num_wg) + tPsumrPsum_p = cute.make_tensor( + cute.recast_ptr(tPsum.iterator), + cute.make_layout( + (tScS_t2r_p.shape[0], (tScS_t2r_p.shape[1] // num_wg, num_wg), 1, 1) + ), + ) + tPsumrPsum = tPsumrPsum_p[ + None, (None, wg_idx), None, None + ] # self.split_wg(tLSErLSE_p, wg_idx, num_wg) for i in cutlass.range_constexpr(cute.size(tSrS_t2r, mode=[1]), unroll=1): cute.copy(thr_tmem_ld_dP, tdPtdP_t2r[None, i, None, None], tdPrdP_t2r) cute.arch.fence_view_async_tmem_load() own0 = tPsumrPsum[(lidx, 0), i, 0, 0] - own1 = tPsumrPsum[(lidx+1, 0), i, 0, 0] + own1 = tPsumrPsum[(lidx + 1, 0), i, 0, 0] for j in cutlass.range_constexpr(0, cute.size(tdPrdP_t2r), 2, unroll=1): + psum_j = cute.arch.shuffle_sync( + own0, offset=j, mask=FULL, mask_and_clamp=MAC + ) + psum_j1 = cute.arch.shuffle_sync( + own1, offset=j, mask=FULL, mask_and_clamp=MAC + ) - psum_j = cute.arch.shuffle_sync(own0, offset=j, mask=FULL, mask_and_clamp=MAC) - psum_j1 = cute.arch.shuffle_sync(own1, offset=j, mask=FULL, mask_and_clamp=MAC) + tdPrdP_t2r[j, 0, 0], tdPrdP_t2r[j + 1, 0, 0] = sub_packed_f32x2( + (tdPrdP_t2r[j, 0, 0], tdPrdP_t2r[j + 1, 0, 0]), (psum_j, psum_j1) + ) - tdPrdP_t2r[j, 0, 0], tdPrdP_t2r[j+1, 0, 0] = sub_packed_f32x2( - (tdPrdP_t2r[j, 0, 0], tdPrdP_t2r[j+1, 0, 0]), - (psum_j, psum_j1) - ) + tSrS_t2r[j, i, 0, 0], tSrS_t2r[j + 1, i, 0, 0] = cute.arch.mul_packed_f32x2( + (tSrS_t2r[j, i, 0, 0], tSrS_t2r[j + 1, i, 0, 0]), + (tdPrdP_t2r[j, 0, 0], tdPrdP_t2r[j + 1, 0, 0]), + ) - tSrS_t2r[j, i, 0, 0], tSrS_t2r[j+1, i, 0, 0] = cute.arch.mul_packed_f32x2( - (tSrS_t2r[j, i, 0, 0], tSrS_t2r[j+1, i, 0, 0]), - (tdPrdP_t2r[j, 0, 0], tdPrdP_t2r[j+1, 0, 0]) - ) - - tSrS_t2r_bf16[j, i, 0, 0] = tSrS_t2r[j, i, 0, 0].to(self.ds_dtype) - tSrS_t2r_bf16[j+1, i, 0, 0] = tSrS_t2r[j+1, i, 0, 0].to(self.ds_dtype) + tSrS_t2r_bf16[j, i, 0, 0] = tSrS_t2r[j, i, 0, 0].to(self.ds_dtype) + tSrS_t2r_bf16[j + 1, i, 0, 0] = tSrS_t2r[j + 1, i, 0, 0].to(self.ds_dtype) cute.autovec_copy(tSrS_t2r_bf16[None, i, 0, 0], tdKsdS[None, i, 0, 0]) pipeline_dP.consumer_release(dP_consumer_state) dP_consumer_state.advance() - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) - cute.arch.barrier(barrier_id=int(NamedBarrierBwdSm100.Compute), number_of_threads=self.num_compute_threads) + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) + cute.arch.barrier( + barrier_id=int(NamedBarrierBwdSm100.Compute), + number_of_threads=self.num_compute_threads, + ) pipeline_dS.producer_commit(dS_producer_state) dS_producer_state.advance() @@ -1884,8 +2137,8 @@ def compute_loop( thr_copy_r2s_dKdV, pipeline_dV, softmax_scale, - False, # apply scale - int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id + False, # apply scale + int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id mdV_semaphore, ) #### STORE dK @@ -1902,8 +2155,8 @@ def compute_loop( thr_copy_r2s_dKdV, pipeline_dK, softmax_scale, - True, # apply scale - int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id + True, # apply scale + int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id mdK_semaphore, ) @@ -1913,46 +2166,53 @@ def compute_loop( @cute.jit def dQacc_reduce( self, - mdQaccum: cute.Tensor, - sdQaccum: cute.Tensor, - thr_mma_dsk: cute.core.ThrMma, - tdQtdQ: cute.Tensor, - pipeline_dQ: PipelineAsync, + mdQaccum: cute.Tensor, + sdQaccum: cute.Tensor, + thr_mma_dsk: cute.core.ThrMma, + tdQtdQ: cute.Tensor, + pipeline_dQ: PipelineAsync, dQaccum_reduce_mbar_ptr: cute.Pointer, - block_info: BlockInfo, - SeqlenInfoCls: Callable, - TileSchedulerCls: Callable, - mdQ_semaphore: Optional[cute.Tensor], + block_info: BlockInfo, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + mdQ_semaphore: Optional[cute.Tensor], ): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * 4) + tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * 4) - dQ_consumer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Consumer, self.dQaccum_mma_stage) + dQ_consumer_state = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.dQaccum_mma_stage + ) tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() # TMEM -> RMEM - tmem_ld_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32) + tmem_ld_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32 + ) tiled_tmem_ld = tcgen05.make_tmem_copy(tmem_ld_atom, tdQtdQ) - thr_tmem_ld = tiled_tmem_ld.get_slice(tidx) + thr_tmem_ld = tiled_tmem_ld.get_slice(tidx) - tdQtdQ_t2r = thr_tmem_ld.partition_S(tdQtdQ) + tdQtdQ_t2r = thr_tmem_ld.partition_S(tdQtdQ) - cdQ = cute.make_identity_tensor((self.mma_tiler_dsk[0], self.mma_tiler_dsk[1])) - tdQcdQ = thr_mma_dsk.partition_C(cdQ) + cdQ = cute.make_identity_tensor((self.mma_tiler_dsk[0], self.mma_tiler_dsk[1])) + tdQcdQ = thr_mma_dsk.partition_C(cdQ) tdQcdQ_tensor = cute.make_tensor(tdQcdQ.iterator, tdQcdQ.layout) - tdQrdQ = thr_tmem_ld.partition_D(tdQcdQ_tensor) + tdQrdQ = thr_tmem_ld.partition_D(tdQcdQ_tensor) num_reduce_threads = cute.arch.WARP_SIZE * len(self.reduce_warp_ids) - atom_universal_copy = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.dqaccum_dtype, num_bits_per_copy=128) + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), self.dqaccum_dtype, num_bits_per_copy=128 + ) thr_layout = cute.make_layout(shape=128, stride=1) - val_layout = cute.make_layout(shape=4, stride=1) + val_layout = cute.make_layout(shape=4, stride=1) tiler_mn, layout_tv = cute.make_layout_tv(thr_layout=thr_layout, val_layout=val_layout) - tiled_smem_store = cute.make_tiled_copy(atom_universal_copy, layout_tv=layout_tv, tiler_mn=tiler_mn) - + tiled_smem_store = cute.make_tiled_copy( + atom_universal_copy, layout_tv=layout_tv, tiler_mn=tiler_mn + ) smem_thr_copy_dQaccum = tiled_smem_store.get_slice(tidx) tdQsdQ = smem_thr_copy_dQaccum.partition_D(sdQaccum) @@ -1967,7 +2227,9 @@ def dQacc_reduce( if cute.arch.thread_idx()[0] == 0: cute.arch.mbarrier_arrive(dQaccum_reduce_mbar_ptr) - cute.arch.barrier(barrier_id=int(NamedBarrierBwdSm100.dQaccReduce), number_of_threads=num_reduce_threads) + cute.arch.barrier( + barrier_id=int(NamedBarrierBwdSm100.dQaccReduce), number_of_threads=num_reduce_threads + ) while work_tile.is_valid_tile: n_block, head_idx, batch_idx = work_tile.tile_idx @@ -1986,20 +2248,25 @@ def dQacc_reduce( # TMEM -> RMEM tdQrdQ_t2r = cute.make_fragment(tdQrdQ.shape, Float32) - assert self.dQaccum_reduce_stage == cute.size(tdQrdQ_t2r, mode=[1]), "dQaccum reduce stage mismatch" + assert self.dQaccum_reduce_stage == cute.size(tdQrdQ_t2r, mode=[1]), ( + "dQaccum reduce stage mismatch" + ) cute.copy(thr_tmem_ld, tdQtdQ_t2r, tdQrdQ_t2r) cute.arch.fence_view_async_tmem_load() - pipeline_dQ.consumer_release(dQ_consumer_state); dQ_consumer_state.advance() + pipeline_dQ.consumer_release(dQ_consumer_state) + dQ_consumer_state.advance() # semaphore acquire if const_expr(self.deterministic): barrier.wait_eq(mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, n_block) - cute.arch.barrier(barrier_id=int(NamedBarrierBwdSm100.dQaccReduce), number_of_threads=num_reduce_threads) - - for stage in cutlass.range_constexpr(cute.size(tdQrdQ_t2r, mode=[1])): # 4 + cute.arch.barrier( + barrier_id=int(NamedBarrierBwdSm100.dQaccReduce), + number_of_threads=num_reduce_threads, + ) + for stage in cutlass.range_constexpr(cute.size(tdQrdQ_t2r, mode=[1])): # 4 if stage >= 2 and cute.arch.thread_idx()[0] == 0: cute.arch.cp_async_bulk_wait_group(1, read=read_flag) @@ -2007,17 +2274,28 @@ def dQacc_reduce( tdQrdQ_r2s = tdQrdQ_t2r[None, stage, None, None] tdQsdQ_r2s = tdQsdQ[None, None, reduce_phase] - tdQrdQ_r2s = cute.make_tensor(tdQrdQ_r2s.iterator, cute.make_layout(tdQsdQ_r2s.shape)) + tdQrdQ_r2s = cute.make_tensor( + tdQrdQ_r2s.iterator, cute.make_layout(tdQsdQ_r2s.shape) + ) cute.copy(smem_thr_copy_dQaccum, tdQrdQ_r2s, tdQsdQ_r2s) - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) - cute.arch.barrier(barrier_id=int(NamedBarrierBwdSm100.dQaccReduce), number_of_threads=num_reduce_threads) + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) + cute.arch.barrier( + barrier_id=int(NamedBarrierBwdSm100.dQaccReduce), + number_of_threads=num_reduce_threads, + ) if cute.arch.thread_idx()[0] == 0: smem_ptr = sdQaccum[None, reduce_phase].iterator - g_stage_index_elems = m_block * (self.m_block_size * self.head_dim_v_padded) + stage * (self.m_block_size * 32) - gmem_row_ptr = cute.domain_offset((g_stage_index_elems,), mdQaccum_cur).iterator + g_stage_index_elems = m_block * ( + self.m_block_size * self.head_dim_v_padded + ) + stage * (self.m_block_size * 32) + gmem_row_ptr = cute.domain_offset( + (g_stage_index_elems,), mdQaccum_cur + ).iterator tma_reduce_add_bulk_f32(smem_ptr, gmem_row_ptr, store_bytes) cute.arch.cp_async_bulk_commit_group() @@ -2027,18 +2305,25 @@ def dQacc_reduce( reduce_phase ^= 1 - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) - cute.arch.barrier(barrier_id=int(NamedBarrierBwdSm100.dQaccReduce), number_of_threads=num_reduce_threads) + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) + cute.arch.barrier( + barrier_id=int(NamedBarrierBwdSm100.dQaccReduce), + number_of_threads=num_reduce_threads, + ) # semaphore release # NOTE: arrive_inc calls red_release which issues membar if const_expr(self.deterministic): if cute.arch.thread_idx()[0] == 0: cute.arch.cp_async_bulk_wait_group(0, read=read_flag) - cute.arch.barrier(barrier_id=int(NamedBarrierBwdSm100.dQaccReduce), number_of_threads=num_reduce_threads) + cute.arch.barrier( + barrier_id=int(NamedBarrierBwdSm100.dQaccReduce), + number_of_threads=num_reduce_threads, + ) barrier.arrive_inc(mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, 1) - if cute.arch.thread_idx()[0] == 0: cute.arch.cp_async_bulk_wait_group(0, read=read_flag) @@ -2046,63 +2331,77 @@ def dQacc_reduce( tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() - @cute.jit def epilogue_dKV( self, - tidx: Int32, - warp_idx: Int32, - batch_idx: Int32, - head_idx: Int32, - n_block: Int32, - thr_mma_pdo: cute.core.ThrMma, - thr_mma_dsq: cute.core.ThrMma, - tdVtdV: cute.Tensor, - tdKtdK: cute.Tensor, - mdV: cute.Tensor, - mdK: cute.Tensor, - pipeline_dV: PipelineAsync, - pipeline_dK: PipelineAsync, + tidx: Int32, + warp_idx: Int32, + batch_idx: Int32, + head_idx: Int32, + n_block: Int32, + thr_mma_pdo: cute.core.ThrMma, + thr_mma_dsq: cute.core.ThrMma, + tdVtdV: cute.Tensor, + tdKtdK: cute.Tensor, + mdV: cute.Tensor, + mdK: cute.Tensor, + pipeline_dV: PipelineAsync, + pipeline_dK: PipelineAsync, softmax_scale: Float32, ): + wg_idx = ( + cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.compute_warp_ids)) + ) // 128 + num_wg = cute.arch.WARP_SIZE * len(self.compute_warp_ids) // 128 - wg_idx = (cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.compute_warp_ids))) // 128 - num_wg = (cute.arch.WARP_SIZE * len(self.compute_warp_ids) // 128) - - dV_consumer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Consumer, self.dV_stage) - dK_consumer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Consumer, self.dK_stage) + dV_consumer_state = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.dV_stage + ) + dK_consumer_state = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.dK_stage + ) assert self.qhead_per_kvhead == 1, "This epilogue path is only for MHA" mdV_cur = mdV[None, None, head_idx, batch_idx] mdK_cur = mdK[None, None, head_idx, batch_idx] - tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(16)), Float32) + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(16)), Float32 + ) # dV pipeline_dV.consumer_wait(dV_consumer_state) tiled_tmem_ld_dV = tcgen05.make_tmem_copy(tmem_load_atom, tdVtdV) - thr_tmem_ld_dV = tiled_tmem_ld_dV.get_slice(tidx) + thr_tmem_ld_dV = tiled_tmem_ld_dV.get_slice(tidx) tdVtdV_t2r_p = thr_tmem_ld_dV.partition_S(tdVtdV) - tdVtdV_t2r = self.split_wg(tdVtdV_t2r_p, wg_idx, num_wg) + tdVtdV_t2r = self.split_wg(tdVtdV_t2r_p, wg_idx, num_wg) - cdV = cute.make_identity_tensor((self.mma_tiler_pdo[0], self.mma_tiler_pdo[1])) - tdVcdV = thr_mma_pdo.partition_C(cdV) + cdV = cute.make_identity_tensor((self.mma_tiler_pdo[0], self.mma_tiler_pdo[1])) + tdVcdV = thr_mma_pdo.partition_C(cdV) tdVcdV_tensor = cute.make_tensor(tdVcdV.iterator, tdVcdV.layout) tdVcdV_t2r_p = thr_tmem_ld_dV.partition_D(tdVcdV_tensor) - tdVcdV_t2r = self.split_wg(tdVcdV_t2r_p, wg_idx, num_wg) - tdVrdV_t2r = cute.make_fragment(tdVcdV_t2r.shape, Float32) + tdVcdV_t2r = self.split_wg(tdVcdV_t2r_p, wg_idx, num_wg) + tdVrdV_t2r = cute.make_fragment(tdVcdV_t2r.shape, Float32) cute.copy(thr_tmem_ld_dV, tdVtdV_t2r, tdVrdV_t2r) cute.arch.fence_view_async_tmem_load() universal_copy_bits = 128 - atom_universal_copy = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.dv_dtype, num_bits_per_copy=universal_copy_bits,) - tiled_gmem_store_dV = cute.make_tiled_copy(atom_universal_copy, layout_tv=tiled_tmem_ld_dV.layout_dst_tv_tiled, tiler_mn=tiled_tmem_ld_dV.tiler_mn,) + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dv_dtype, + num_bits_per_copy=universal_copy_bits, + ) + tiled_gmem_store_dV = cute.make_tiled_copy( + atom_universal_copy, + layout_tv=tiled_tmem_ld_dV.layout_dst_tv_tiled, + tiler_mn=tiled_tmem_ld_dV.tiler_mn, + ) - tdVrdV_r2s = cute.make_fragment(tdVrdV_t2r.shape, self.dv_dtype) + tdVrdV_r2s = cute.make_fragment(tdVrdV_t2r.shape, self.dv_dtype) for i in cutlass.range_constexpr(cute.size(tdVrdV_t2r, mode=[1])): dV_vec = tdVrdV_t2r[(None, i, 0, 0)].load() tdVrdV_r2s[(None, i, 0, 0)].store(dV_vec.to(self.dv_dtype)) @@ -2110,41 +2409,49 @@ def epilogue_dKV( gdV = cute.local_tile(mdV_cur, (self.m_block_size, self.head_dim_v_padded), (None, 0)) gdV_tile = gdV[None, None, n_block] - tdVgdV = thr_mma_pdo.partition_C(gdV_tile) + tdVgdV = thr_mma_pdo.partition_C(gdV_tile) tdVgdV_r2g_p = thr_tmem_ld_dV.partition_D(tdVgdV) - tdVgdV_r2g = self.split_wg(tdVgdV_r2g_p, wg_idx, num_wg) + tdVgdV_r2g = self.split_wg(tdVgdV_r2g_p, wg_idx, num_wg) - cute.copy(tiled_gmem_store_dV, tdVrdV_r2s , tdVgdV_r2g) + cute.copy(tiled_gmem_store_dV, tdVrdV_r2s, tdVgdV_r2g) - pipeline_dV.consumer_release(dV_consumer_state); dV_consumer_state.advance() + pipeline_dV.consumer_release(dV_consumer_state) + dV_consumer_state.advance() # dK pipeline_dK.consumer_wait(dK_consumer_state) tiled_tmem_ld_dK = tcgen05.make_tmem_copy(tmem_load_atom, tdKtdK) - thr_tmem_ld_dK = tiled_tmem_ld_dK.get_slice(tidx) + thr_tmem_ld_dK = tiled_tmem_ld_dK.get_slice(tidx) tdKtdK_t2r_p = thr_tmem_ld_dK.partition_S(tdKtdK) - tdKtdK_t2r = self.split_wg(tdKtdK_t2r_p, wg_idx, num_wg) + tdKtdK_t2r = self.split_wg(tdKtdK_t2r_p, wg_idx, num_wg) - cdK = cute.make_identity_tensor((self.mma_tiler_dsq[0], self.mma_tiler_dsq[1])) - tdKcdK = thr_mma_dsq.partition_C(cdK) - tdKcdK_tensor = cute.make_tensor(tdKcdK.iterator, tdKcdK.layout) + cdK = cute.make_identity_tensor((self.mma_tiler_dsq[0], self.mma_tiler_dsq[1])) + tdKcdK = thr_mma_dsq.partition_C(cdK) + tdKcdK_tensor = cute.make_tensor(tdKcdK.iterator, tdKcdK.layout) tdKcdK_t2r_p = thr_tmem_ld_dK.partition_D(tdKcdK_tensor) - tdKcdK_t2r = self.split_wg(tdKcdK_t2r_p, wg_idx, num_wg) - tdKrdK_t2r = cute.make_fragment(tdKcdK_t2r.shape, Float32) + tdKcdK_t2r = self.split_wg(tdKcdK_t2r_p, wg_idx, num_wg) + tdKrdK_t2r = cute.make_fragment(tdKcdK_t2r.shape, Float32) cute.copy(tiled_tmem_ld_dK, tdKtdK_t2r, tdKrdK_t2r) cute.arch.fence_view_async_tmem_load() universal_copy_bits = 128 - atom_universal_copy = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.dk_dtype, num_bits_per_copy=universal_copy_bits,) - - tiled_gmem_store_dK = cute.make_tiled_copy(atom_universal_copy,layout_tv=tiled_tmem_ld_dK.layout_dst_tv_tiled,tiler_mn=tiled_tmem_ld_dK.tiler_mn,) + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dk_dtype, + num_bits_per_copy=universal_copy_bits, + ) - tdKrdK_r2s = cute.make_fragment(tdKrdK_t2r.shape, self.dk_dtype) + tiled_gmem_store_dK = cute.make_tiled_copy( + atom_universal_copy, + layout_tv=tiled_tmem_ld_dK.layout_dst_tv_tiled, + tiler_mn=tiled_tmem_ld_dK.tiler_mn, + ) + tdKrdK_r2s = cute.make_fragment(tdKrdK_t2r.shape, self.dk_dtype) for i in cutlass.range_constexpr(cute.size(tdKrdK_t2r, mode=[1])): dK_vec = tdKrdK_t2r[(None, i, 0, 0)].load() * softmax_scale @@ -2153,39 +2460,39 @@ def epilogue_dKV( gdK = cute.local_tile(mdK_cur, (self.n_block_size, self.head_dim_v_padded), (None, 0)) gdK_tile = gdK[None, None, n_block] - tdKgdK = thr_mma_dsq.partition_C(gdK_tile) + tdKgdK = thr_mma_dsq.partition_C(gdK_tile) tdKgdK_r2g_p = thr_tmem_ld_dK.partition_D(tdKgdK) - tdKgdK_r2g = self.split_wg(tdKgdK_r2g_p, wg_idx, num_wg) - - cute.copy(tiled_gmem_store_dK, tdKrdK_r2s , tdKgdK_r2g) + tdKgdK_r2g = self.split_wg(tdKgdK_r2g_p, wg_idx, num_wg) - pipeline_dK.consumer_release(dK_consumer_state); dK_consumer_state.advance() + cute.copy(tiled_gmem_store_dK, tdKrdK_r2s, tdKgdK_r2g) + pipeline_dK.consumer_release(dK_consumer_state) + dK_consumer_state.advance() @cute.jit def epilogue_dK_or_dV_tma( self, - tidx: Int32, - batch_idx: Int32, - head_idx: Int32, - n_block: Int32, - thr_mma: cute.core.ThrMma, - tdKVtdKV: cute.Tensor, - mdKV: cute.Tensor, - sdKV: cute.Tensor, + tidx: Int32, + batch_idx: Int32, + head_idx: Int32, + n_block: Int32, + thr_mma: cute.core.ThrMma, + tdKVtdKV: cute.Tensor, + mdKV: cute.Tensor, + sdKV: cute.Tensor, tma_atom_dKV: cute.CopyAtom, thr_copy_r2s_dKdV: cute.TiledCopy, - pipeline: PipelineAsync, - softmax_scale : Float32, - do_scale : cutlass.Constexpr[cutlass.Boolean], - barrier_id : Int32, - mdKV_semaphore : Optional[cute.Tensor], + pipeline: PipelineAsync, + softmax_scale: Float32, + do_scale: cutlass.Constexpr[cutlass.Boolean], + barrier_id: Int32, + mdKV_semaphore: Optional[cute.Tensor], ): # assumes mma_tiler_pdo = mma_tiler_dsq = (n_block_size, head_dim) # head_dim = head_dim_v, dk_dtype = dv_dtype wg_idx = (cute.arch.thread_idx()[0] % self.num_compute_threads) // 128 - num_wg = (self.num_compute_threads // 128) + num_wg = self.num_compute_threads // 128 leader_warp = (cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4) == 0 sdKV = sdKV[None, None, wg_idx] @@ -2193,7 +2500,9 @@ def epilogue_dK_or_dV_tma( head_idx_kv = head_idx // self.qhead_per_kvhead mdKV_cur = mdKV[None, None, head_idx_kv, batch_idx] - gdKV_p = cute.local_tile(mdKV_cur, (self.m_block_size, self.head_dim_v_padded), (n_block, 0)) + gdKV_p = cute.local_tile( + mdKV_cur, (self.m_block_size, self.head_dim_v_padded), (n_block, 0) + ) gdKV = self.split_wg(gdKV_p, wg_idx, num_wg) gdKV_epi = cute.local_tile(gdKV, self.sdKdV_epi_tile, (0, None)) @@ -2203,7 +2512,7 @@ def epilogue_dK_or_dV_tma( # (TMA) and (TMA, EPI_STAGE) tdKVsdKV, tdKVgdKV = cpasync.tma_partition( tma_atom_dKV, - 0, # no multicast + 0, # no multicast cute.make_layout(1), cute.group_modes(sdKV, 0, 2), cute.group_modes(gdKV_epi, 0, 2), @@ -2215,7 +2524,9 @@ def epilogue_dK_or_dV_tma( num_epi_stages = cute.size(tdKVgdKV.shape[1]) assert num_epi_stages == 1 or num_epi_stages == 2, "Wrong number of epi stages" - tmem_ld_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32) + tmem_ld_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32 + ) if const_expr(self.deterministic): read_flag = False @@ -2223,42 +2534,47 @@ def epilogue_dK_or_dV_tma( read_flag = True # TODO: maybe support more than 1 stage - consumer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Consumer, 1) + consumer_state = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, 1 + ) pipeline.consumer_wait(consumer_state) # semaphore acquire if const_expr(self.deterministic): - barrier.wait_eq(mdKV_semaphore_cur.iterator, tidx, wg_idx, head_idx % self.qhead_per_kvhead) + barrier.wait_eq( + mdKV_semaphore_cur.iterator, tidx, wg_idx, head_idx % self.qhead_per_kvhead + ) cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128) for s in cutlass.range_constexpr(num_epi_stages): - # TMEM -> RMEM -- setup tiled_tmem_ld = tcgen05.make_tmem_copy(tmem_ld_atom, tdKVtdKV) - thr_tmem_ld = tiled_tmem_ld.get_slice(tidx) + thr_tmem_ld = tiled_tmem_ld.get_slice(tidx) tdKVtdKV_t2r_p = thr_tmem_ld.partition_S(tdKVtdKV) - tdKVtdKV_t2r = self.split_wg(tdKVtdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0] + tdKVtdKV_t2r = self.split_wg(tdKVtdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0] if const_expr(num_epi_stages > 1): tdKVtdKV_t2r = tdKVtdKV_t2r[None, s] - cdKV = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded)) - tdKVcdKV = thr_mma.partition_C(cdKV) + cdKV = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded)) + tdKVcdKV = thr_mma.partition_C(cdKV) tdKVcdKV_t2r_p = thr_tmem_ld.partition_D(tdKVcdKV) - tdKVcdKV_t2r = self.split_wg(tdKVcdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0] + tdKVcdKV_t2r = self.split_wg(tdKVcdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0] if const_expr(num_epi_stages > 1): tdKVcdKV_t2r = tdKVcdKV_t2r[None, s] - tdKVrdKV_t2r = cute.make_fragment(tdKVcdKV_t2r.shape, Float32) + tdKVrdKV_t2r = cute.make_fragment(tdKVcdKV_t2r.shape, Float32) - assert cute.size(tdKVrdKV_t2r) == cute.size(tdKVtdKV_t2r) // cute.arch.WARP_SIZE, "RMEM<->TMEM fragment size mismatch" + assert cute.size(tdKVrdKV_t2r) == cute.size(tdKVtdKV_t2r) // cute.arch.WARP_SIZE, ( + "RMEM<->TMEM fragment size mismatch" + ) # TMEM -> RMEM -- copy and fence cute.copy(thr_tmem_ld, tdKVtdKV_t2r, tdKVrdKV_t2r) cute.arch.fence_view_async_tmem_load() # RMEM -- scale and convert - tdKVrdKV = cute.make_fragment(tdKVrdKV_t2r.shape, self.dv_dtype) + tdKVrdKV = cute.make_fragment(tdKVrdKV_t2r.shape, self.dv_dtype) if const_expr(do_scale): scale = softmax_scale else: @@ -2272,18 +2588,26 @@ def epilogue_dK_or_dV_tma( tdKVcdKV_r2s = self.split_wg(tdKVcdKV_r2s_p, wg_idx, num_wg) tdKVcdKV_r2s = cute.logical_divide( tdKVcdKV_r2s, - (tdKVcdKV_r2s.shape[0], tdKVcdKV_r2s.shape[1], tdKVcdKV_r2s.shape[2] // num_epi_stages) + ( + tdKVcdKV_r2s.shape[0], + tdKVcdKV_r2s.shape[1], + tdKVcdKV_r2s.shape[2] // num_epi_stages, + ), )[((None, 0), (None, 0), (None, s))] tdKVrdKV_r2s = cute.make_tensor(tdKVrdKV.iterator, tdKVcdKV_r2s.shape) tdKVsdKV_r2s = thr_copy_r2s_dKdV.partition_D(sdKV) - assert cute.size(tdKVrdKV_r2s) == cute.size(tdKVsdKV_r2s), "RMEM<->SMEM fragment size mismatch" + assert cute.size(tdKVrdKV_r2s) == cute.size(tdKVsdKV_r2s), ( + "RMEM<->SMEM fragment size mismatch" + ) # RMEM -> SMEM -- copy, fence and barrier cute.copy(thr_copy_r2s_dKdV, tdKVrdKV_r2s, tdKVsdKV_r2s) - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128) # SMEM -> GMEM @@ -2292,11 +2616,17 @@ def epilogue_dK_or_dV_tma( if s < num_epi_stages - 1: cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=read_flag) - cute.arch.barrier_arrive(barrier_id=barrier_id + wg_idx, number_of_threads=128 + cute.arch.WARP_SIZE) + cute.arch.barrier_arrive( + barrier_id=barrier_id + wg_idx, number_of_threads=128 + cute.arch.WARP_SIZE + ) # Barrier since all warps need to wait for SMEM to be freed - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) - cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128 + cute.arch.WARP_SIZE) + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) + cute.arch.barrier( + barrier_id=barrier_id + wg_idx, number_of_threads=128 + cute.arch.WARP_SIZE + ) # semaphore release # NOTE: arrive_inc calls red_release which issues membar @@ -2310,7 +2640,6 @@ def epilogue_dK_or_dV_tma( pipeline.consumer_release(consumer_state) consumer_state.advance() - @cute.jit def load_M_tile( self, @@ -2326,5 +2655,5 @@ def load_M_tile( tma_atom, tQgQ[None, block], tQsQ[None, producer_state.index], - tma_bar_ptr=pipeline.producer_get_barrier(producer_state) + tma_bar_ptr=pipeline.producer_get_barrier(producer_state), ) From 498bfe677cc9ff2b9f4f35b1a1395a5f9715871d Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 19 Oct 2025 14:38:43 -0400 Subject: [PATCH 156/258] [Cute,Bwd,Sm100] Rename var {m,n}_block_size->tile_{m,n} --- flash_attn/cute/flash_bwd_postprocess.py | 4 +- flash_attn/cute/flash_bwd_sm100.py | 118 ++++++++++------------- 2 files changed, 55 insertions(+), 67 deletions(-) diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index 8088997fd26..e57f28c0d66 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -395,7 +395,7 @@ def __init__( self, dtype: Type[cutlass.Numeric], head_dim: int, - m_block_size: int = 128, + tile_m: int = 128, num_threads: int = 256, AtomLayoutMdQ: int = 1, dQ_swapAB: bool = False, @@ -404,7 +404,7 @@ def __init__( dtype=dtype, head_dim=head_dim, arch=90, # tmp dummy placement for now - tile_m=m_block_size, + tile_m=tile_m, num_threads=num_threads, AtomLayoutMdQ=AtomLayoutMdQ, dQ_swapAB=dQ_swapAB, diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 86afbf8f105..7ebcf7638f7 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -58,46 +58,46 @@ def __init__( is_causal: bool = False, is_local: bool = False, qhead_per_kvhead: cutlass.Constexpr[int] = 1, - m_block_size: int = 128, - n_block_size: int = 128, + tile_m: int = 128, + tile_n: int = 128, is_persistent: bool = False, deterministic: bool = False, ): # padding head_dim to a multiple of 16 as k_block_size hdim_multiple_of = 16 - self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) + self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) head_dim_v = head_dim_v if head_dim_v is not None else head_dim self.same_hdim_kv = head_dim == head_dim_v assert head_dim == head_dim_v, "head_dim and head_dim_v must be the same for now" - self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) - assert self.head_dim_padded == self.head_dim_v_padded, ( - "head_dim_padded and head_dim_v_padded must be the same for now" + self.tile_hdimv = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) + assert self.tile_hdim == self.tile_hdimv, ( + "tile_hdim and tile_hdimv must be the same for now" ) - self.check_hdim_oob = head_dim != self.head_dim_padded - self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded + self.check_hdim_oob = head_dim != self.tile_hdim + self.check_hdim_v_oob = head_dim_v != self.tile_hdimv - self.m_block_size = m_block_size - self.n_block_size = n_block_size + self.tile_m = tile_m + self.tile_n = tile_n # number of tma reduce adds per dQacc mma - self.dQaccum_reduce_stage = self.head_dim_padded // 32 + self.dQaccum_reduce_stage = self.tile_hdim // 32 # CTA tiler - self.cta_tiler = (m_block_size, n_block_size, self.head_dim_padded) + self.cta_tiler = (tile_m, tile_n, self.tile_hdim) # S = K @ Q.T - self.mma_tiler_kq = (n_block_size, m_block_size, self.head_dim_padded) + self.mma_tiler_kq = (tile_n, tile_m, self.tile_hdim) # dP = V @ dO.T - self.mma_tiler_vdo = (n_block_size, m_block_size, self.head_dim_v_padded) + self.mma_tiler_vdo = (tile_n, tile_m, self.tile_hdimv) # dV = P.T @ dO - self.mma_tiler_pdo = (n_block_size, self.head_dim_v_padded, m_block_size) + self.mma_tiler_pdo = (tile_n, self.tile_hdimv, tile_m) # dK = dS.T @ Q (N, M) (M, D) - self.mma_tiler_dsq = (n_block_size, self.head_dim_v_padded, m_block_size) + self.mma_tiler_dsq = (tile_n, self.tile_hdimv, tile_m) # dQ = dS @ K - self.mma_tiler_dsk = (m_block_size, self.head_dim_v_padded, n_block_size) + self.mma_tiler_dsk = (tile_m, self.tile_hdimv, tile_n) self.kq_acc_dtype = self.vdo_acc_dtype = self.pdo_acc_dtype = self.dsq_acc_dtype = ( self.dsk_acc_dtype @@ -137,10 +137,10 @@ def __init__( self.tmem_s_offset = 0 self.tmem_p_offset = 0 # overlap with S - self.tmem_dV_offset = self.tmem_s_offset + self.n_block_size - self.tmem_dP_offset = self.tmem_dV_offset + self.head_dim_v_padded + self.tmem_dV_offset = self.tmem_s_offset + self.tile_n + self.tmem_dP_offset = self.tmem_dV_offset + self.tile_hdimv self.tmem_dQaccum_offset = self.tmem_dP_offset # overlap with dP - self.tmem_dK_offset = self.tmem_dP_offset + self.m_block_size + self.tmem_dK_offset = self.tmem_dP_offset + self.tile_m self.num_regs_reduce = 144 self.num_regs_compute = 128 @@ -379,15 +379,15 @@ def __call__( ) sdQaccum_layout = cute.make_layout( - shape=(self.m_block_size * 32, self.sdQaccum_stage), + shape=(self.tile_m * 32, self.sdQaccum_stage), ) sLSE_layout = cute.make_layout( - shape=(self.m_block_size, self.lse_stage), - stride=(1, cute.round_up(self.m_block_size, 64)), + shape=(self.tile_m, self.lse_stage), + stride=(1, cute.round_up(self.tile_m, 64)), ) sPsum_layout = cute.make_layout( - shape=(self.m_block_size, self.psum_stage), - stride=(1, cute.round_up(self.m_block_size, 64)), + shape=(self.tile_m, self.psum_stage), + stride=(1, cute.round_up(self.tile_m, 64)), ) self.mdK_layout_enum = cutlass.utils.LayoutEnum.from_tensor(mdK) @@ -399,7 +399,7 @@ def __call__( if const_expr(self.dV_major_mode != tcgen05.OperandMajorMode.K): raise RuntimeError("The layout of mdV is wrong") self.sdKdV_epi_tile = ( - self.n_block_size, + self.tile_n, 128 // (self.dk_dtype.width // 8), ) # subtiles mma_tiler_dsq[:2] = mma_tiler_pdo[:2] sdKdV_layout = sm100_utils_basic.make_smem_layout_epi( @@ -441,7 +441,7 @@ def __call__( tma_atom_dK = None thr_layout_r2s_dKdV = cute.make_ordered_layout( - (self.n_block_size, 1), order=(1, 0) + (self.tile_n, 1), order=(1, 0) ) # 128 threads val_layout_r2s_dKdV = cute.make_ordered_layout( (1, 128 // self.dk_dtype.width), order=(1, 0) @@ -488,14 +488,14 @@ def __call__( tma_atom_LSE, tma_tensor_LSE = cute.nvgpu.cpasync.make_tiled_tma_atom( tma_load_op, mLSE, - cute.make_layout((self.m_block_size)), - (self.m_block_size,), + cute.make_layout((self.tile_m)), + (self.tile_m,), ) tma_atom_Psum, tma_tensor_Psum = cute.nvgpu.cpasync.make_tiled_tma_atom( tma_load_op, mPsum, - cute.make_layout((self.m_block_size)), - (self.m_block_size,), + cute.make_layout((self.tile_m)), + (self.tile_m,), ) # dP = V @ dO.T @@ -520,8 +520,8 @@ def __call__( self.tma_copy_do_bytes = cute.size_in_bytes( self.do_dtype, cute.select(sdO_layout, mode=[0, 1, 2]) ) - self.tma_copy_lse_bytes = self.m_block_size * 4 - self.tma_copy_psum_bytes = self.m_block_size * 4 + self.tma_copy_lse_bytes = self.tile_m * 4 + self.tma_copy_psum_bytes = self.tile_m * 4 TileScheduler = SingleTileScheduler # TODO -- optimizer scheduler for causal @@ -868,16 +868,12 @@ def kernel( sLSE_load = storage.sLSE.get_tensor(sLSE_layout) sLSE_mma = storage.sLSE.get_tensor( - cute.make_layout( - shape=(self.m_block_size, self.n_block_size, self.lse_stage), stride=(0, 1, 0) - ) + cute.make_layout(shape=(self.tile_m, self.tile_n, self.lse_stage), stride=(0, 1, 0)) ) sPsum_load = storage.sPsum.get_tensor(sPsum_layout) sPsum_mma = storage.sPsum.get_tensor( - cute.make_layout( - shape=(self.m_block_size, self.n_block_size, self.psum_stage), stride=(0, 1, 0) - ) + cute.make_layout(shape=(self.tile_m, self.tile_n, self.psum_stage), stride=(0, 1, 0)) ) sdV = storage.sdO.get_tensor( @@ -929,8 +925,8 @@ def kernel( tdPtdP = cute.make_tensor(tdPtdP.iterator + self.tmem_dP_offset, tdPtdP.layout) block_info = BlockInfo( - self.m_block_size, - self.n_block_size, + self.tile_m, + self.tile_n, self.is_causal, self.is_local, None, @@ -951,8 +947,8 @@ def kernel( # TODO: support local AttentionMaskCls = partial( AttentionMask, - self.m_block_size, - self.n_block_size, + self.tile_m, + self.tile_n, ) cute.arch.sync_threads() @@ -1205,8 +1201,8 @@ def load( gQ = cute.local_tile(mQ_cur, cute.select(self.mma_tiler_kq, mode=[1, 2]), (None, 0)) tSgQ = thr_mma_kq.partition_B(gQ) - gLSE = cute.local_tile(mLSE_cur, (self.n_block_size,), (None,)) - gPsum = cute.local_tile(mPsum_cur, (self.n_block_size,), (None,)) + gLSE = cute.local_tile(mLSE_cur, (self.tile_n,), (None,)) + gPsum = cute.local_tile(mPsum_cur, (self.tile_n,), (None,)) gdO = cute.local_tile(mdO_cur, cute.select(self.mma_tiler_pdo, mode=[1, 2]), (0, None)) tdVgdO = thr_mma_pdo.partition_B(gdO) @@ -1871,9 +1867,7 @@ def compute_loop( tileP_f32_like = self.mma_tiler_kq[0] // 32 * self.v_dtype.width # (128, 64) tStP = cute.make_tensor( tStS.iterator, - cute.composition( - tStS.layout, cute.make_layout((self.m_block_size, tileP_f32_like)) - ), + cute.composition(tStS.layout, cute.make_layout((self.tile_m, tileP_f32_like))), ) tiled_tmem_st = tcgen05.make_tmem_copy(tmem_store_atom, tStP) @@ -1918,9 +1912,7 @@ def compute_loop( #### RMEM (coordinates for P) cP_f32 = cute.make_tensor( tScS.iterator, - cute.composition( - tScS.layout, cute.make_layout((self.m_block_size, tileP_f32_like)) - ), + cute.composition(tScS.layout, cute.make_layout((self.tile_m, tileP_f32_like))), ) tScP_r2t_p = thr_tmem_st.partition_S(cP_f32) @@ -2034,9 +2026,7 @@ def compute_loop( ##### dS.T = P.T * (dP.T - Psum) sdSt_mn = cute.make_tensor( sdSt_pi.iterator, - cute.composition( - sdSt_pi.layout, cute.make_layout((self.m_block_size, self.n_block_size)) - ), + cute.composition(sdSt_pi.layout, cute.make_layout((self.tile_m, self.tile_n))), ) tdKsdS = cute.composition( sdSt_mn[(None, wg_idx), tidx], cute.make_layout(tSrS_t2r.shape) @@ -2216,7 +2206,7 @@ def dQacc_reduce( smem_thr_copy_dQaccum = tiled_smem_store.get_slice(tidx) tdQsdQ = smem_thr_copy_dQaccum.partition_D(sdQaccum) - store_bytes = cutlass.Int32(self.m_block_size * 32 * 4) + store_bytes = cutlass.Int32(self.tile_m * 32 * 4) if const_expr(self.deterministic): read_flag = False @@ -2290,9 +2280,9 @@ def dQacc_reduce( if cute.arch.thread_idx()[0] == 0: smem_ptr = sdQaccum[None, reduce_phase].iterator - g_stage_index_elems = m_block * ( - self.m_block_size * self.head_dim_v_padded - ) + stage * (self.m_block_size * 32) + g_stage_index_elems = m_block * (self.tile_m * self.tile_hdimv) + stage * ( + self.tile_m * 32 + ) gmem_row_ptr = cute.domain_offset( (g_stage_index_elems,), mdQaccum_cur ).iterator @@ -2406,7 +2396,7 @@ def epilogue_dKV( dV_vec = tdVrdV_t2r[(None, i, 0, 0)].load() tdVrdV_r2s[(None, i, 0, 0)].store(dV_vec.to(self.dv_dtype)) - gdV = cute.local_tile(mdV_cur, (self.m_block_size, self.head_dim_v_padded), (None, 0)) + gdV = cute.local_tile(mdV_cur, (self.tile_m, self.tile_hdimv), (None, 0)) gdV_tile = gdV[None, None, n_block] tdVgdV = thr_mma_pdo.partition_C(gdV_tile) @@ -2457,7 +2447,7 @@ def epilogue_dKV( dK_vec = tdKrdK_t2r[(None, i, 0, 0)].load() * softmax_scale tdKrdK_r2s[(None, i, 0, 0)].store(dK_vec.to(self.dk_dtype)) - gdK = cute.local_tile(mdK_cur, (self.n_block_size, self.head_dim_v_padded), (None, 0)) + gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdimv), (None, 0)) gdK_tile = gdK[None, None, n_block] tdKgdK = thr_mma_dsq.partition_C(gdK_tile) @@ -2488,7 +2478,7 @@ def epilogue_dK_or_dV_tma( barrier_id: Int32, mdKV_semaphore: Optional[cute.Tensor], ): - # assumes mma_tiler_pdo = mma_tiler_dsq = (n_block_size, head_dim) + # assumes mma_tiler_pdo = mma_tiler_dsq = (tile_n, head_dim) # head_dim = head_dim_v, dk_dtype = dv_dtype wg_idx = (cute.arch.thread_idx()[0] % self.num_compute_threads) // 128 @@ -2500,9 +2490,7 @@ def epilogue_dK_or_dV_tma( head_idx_kv = head_idx // self.qhead_per_kvhead mdKV_cur = mdKV[None, None, head_idx_kv, batch_idx] - gdKV_p = cute.local_tile( - mdKV_cur, (self.m_block_size, self.head_dim_v_padded), (n_block, 0) - ) + gdKV_p = cute.local_tile(mdKV_cur, (self.tile_m, self.tile_hdimv), (n_block, 0)) gdKV = self.split_wg(gdKV_p, wg_idx, num_wg) gdKV_epi = cute.local_tile(gdKV, self.sdKdV_epi_tile, (0, None)) @@ -2556,7 +2544,7 @@ def epilogue_dK_or_dV_tma( if const_expr(num_epi_stages > 1): tdKVtdKV_t2r = tdKVtdKV_t2r[None, s] - cdKV = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded)) + cdKV = cute.make_identity_tensor((self.tile_n, self.tile_hdim)) tdKVcdKV = thr_mma.partition_C(cdKV) tdKVcdKV_t2r_p = thr_tmem_ld.partition_D(tdKVcdKV) tdKVcdKV_t2r = self.split_wg(tdKVcdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0] From 94f50b02d24cd63e2c77274265b664517dd08c98 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 19 Oct 2025 15:10:21 -0400 Subject: [PATCH 157/258] [Cute,Bwd,Sm100] Clean up a bit --- flash_attn/cute/flash_bwd_postprocess.py | 9 ++++ flash_attn/cute/flash_bwd_sm100.py | 60 +++++++----------------- 2 files changed, 26 insertions(+), 43 deletions(-) diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index e57f28c0d66..9aa7979adf6 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -432,6 +432,15 @@ def __call__( scale: cutlass.Float32, stream: cuda.CUstream, ): + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: ( + *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + t.stride[-1], + ) + mdQaccum, mdQ = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + for t in (mdQaccum, mdQ) + ] # (b, h, s*d) -> (s*d, h, b) mdQaccum = cute.make_tensor(mdQaccum.iterator, cute.select(mdQaccum.layout, mode=[2, 1, 0])) # (b, s, h, d) -> (s, d, h, b) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 7ebcf7638f7..f93b30d67bd 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -1,3 +1,4 @@ +# Copyright (c) 2025, Ted Zadouri, Markus Hoehnerbach, Jay Shah, Tri Dao. import math from typing import Callable, Optional from functools import partial @@ -7,47 +8,27 @@ import cutlass import cutlass.cute as cute from cutlass import Float32, Int32, const_expr -from cutlass.cute.nvgpu import cpasync -import cutlass.cute.nvgpu.tcgen05 as tcgen05 - +from cutlass.utils import LayoutEnum +from cutlass.cute.nvgpu import cpasync, tcgen05 import cutlass.utils.blackwell_helpers as sm100_utils_basic +from cutlass.pipeline import PipelineAsync + +from flash_attn.cute import utils +from flash_attn.cute import copy_utils from flash_attn.cute.mask import AttentionMask from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo - from flash_attn.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, ParamsBase, ) -from cutlass.pipeline import PipelineAsync - -from cutlass._mlir.dialects import llvm -from cutlass.cutlass_dsl import dsl_user_op -from cutlass._mlir.dialects import nvvm - -from flash_attn.cute import barrier +# from flash_attn.cute import barrier +from flash_attn.cute import named_barrier as barrier # TODO: temp, to make linter pass from flash_attn.cute.named_barrier import NamedBarrierBwdSm100 -@dsl_user_op -def tma_reduce_add_bulk_f32( - smem_ptr: cute.Pointer, gmem_ptr: cute.Pointer, store_bytes: cutlass.Int32, *, loc=None, ip=None -): - cute.make_mma_atom - smem_u32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() - llvm.inline_asm( - None, - [gmem_ptr.llvm_ptr, smem_u32, store_bytes.ir_value()], - "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;", - "l,r,r", - has_side_effects=True, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - - class FlashAttentionBackwardSm100: arch = 100 @@ -241,10 +222,10 @@ def __call__( mdK_semaphore = None mdV_semaphore = None - self.q_major_mode = cutlass.utils.LayoutEnum.from_tensor(mQ).mma_major_mode() - self.k_major_mode = cutlass.utils.LayoutEnum.from_tensor(mK).mma_major_mode() - self.v_major_mode = cutlass.utils.LayoutEnum.from_tensor(mV).mma_major_mode() - self.do_major_mode = cutlass.utils.LayoutEnum.from_tensor(mdO).mma_major_mode() + self.q_major_mode = LayoutEnum.from_tensor(mQ).mma_major_mode() + self.k_major_mode = LayoutEnum.from_tensor(mK).mma_major_mode() + self.v_major_mode = LayoutEnum.from_tensor(mV).mma_major_mode() + self.do_major_mode = LayoutEnum.from_tensor(mdO).mma_major_mode() self._setup_attributes() cta_group = tcgen05.CtaGroup.ONE @@ -390,8 +371,8 @@ def __call__( stride=(1, cute.round_up(self.tile_m, 64)), ) - self.mdK_layout_enum = cutlass.utils.LayoutEnum.from_tensor(mdK) - self.mdV_layout_enum = cutlass.utils.LayoutEnum.from_tensor(mdV) + self.mdK_layout_enum = LayoutEnum.from_tensor(mdK) + self.mdV_layout_enum = LayoutEnum.from_tensor(mdV) self.dK_major_mode = self.mdK_layout_enum.mma_major_mode() self.dV_major_mode = self.mdV_layout_enum.mma_major_mode() if const_expr(self.dK_major_mode != tcgen05.OperandMajorMode.K): @@ -1825,13 +1806,6 @@ def compute_loop( lse_consumer_phase = psum_consumer_phase = cute.Int32(0) - sub_packed_f32x2 = partial( - cute.arch.calc_packed_f32x2_op, - src_c=None, - calc_func=nvvm.sub_packed_f32x2, - rnd=nvvm.RoundingModeKind.RN, - ) - tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: @@ -2062,7 +2036,7 @@ def compute_loop( own1, offset=j, mask=FULL, mask_and_clamp=MAC ) - tdPrdP_t2r[j, 0, 0], tdPrdP_t2r[j + 1, 0, 0] = sub_packed_f32x2( + tdPrdP_t2r[j, 0, 0], tdPrdP_t2r[j + 1, 0, 0] = utils.sub_packed_f32x2( (tdPrdP_t2r[j, 0, 0], tdPrdP_t2r[j + 1, 0, 0]), (psum_j, psum_j1) ) @@ -2287,7 +2261,7 @@ def dQacc_reduce( (g_stage_index_elems,), mdQaccum_cur ).iterator - tma_reduce_add_bulk_f32(smem_ptr, gmem_row_ptr, store_bytes) + copy_utils.cpasync_reduce_bulk_add_f32(smem_ptr, gmem_row_ptr, store_bytes) cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(1, read=read_flag) From e925d10c8bb619bfd68e37b1610e31670187b119 Mon Sep 17 00:00:00 2001 From: Ted Zadouri Date: Sun, 19 Oct 2025 15:33:03 -0400 Subject: [PATCH 158/258] add barrier module (#1946) --- flash_attn/cute/barrier.py | 70 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 flash_attn/cute/barrier.py diff --git a/flash_attn/cute/barrier.py b/flash_attn/cute/barrier.py new file mode 100644 index 00000000000..744e3a56507 --- /dev/null +++ b/flash_attn/cute/barrier.py @@ -0,0 +1,70 @@ +import cutlass +import cutlass.cute as cute +from cutlass import Int32 +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm + +@dsl_user_op +def ld_acquire(lock_ptr : cute.Pointer, *, loc=None, ip=None) -> cutlass.Int32: + lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value() + state = llvm.inline_asm( + T.i32(), + [lock_ptr_i64], + "ld.global.acquire.gpu.b32 $0, [$1];", + "=r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + return cutlass.Int32(state) + +@dsl_user_op +def red_relaxed(lock_ptr : cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None) -> None: + lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [lock_ptr_i64, Int32(val).ir_value(loc=loc, ip=ip)], + "red.relaxed.gpu.global.add.s32 [$0], $1;", + "l,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + +@dsl_user_op +def red_release(lock_ptr : cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None) -> None: + lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [lock_ptr_i64, Int32(val).ir_value(loc=loc, ip=ip)], + "red.release.gpu.global.add.s32 [$0], $1;", + "l,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + +@cute.jit +def wait_eq( + lock_ptr : cute.Pointer, + thread_idx : int | Int32, + flag_offset : int, + val : Int32 +) -> None: + flag_ptr = lock_ptr + flag_offset + if thread_idx == 0: + read_val = Int32(0) + while read_val != val: + read_val = ld_acquire(flag_ptr) + +@cute.jit +def arrive_inc( + lock_ptr : cute.Pointer, + thread_idx : int | Int32, + flag_offset : int, + val : cutlass.Constexpr[Int32] +) -> None: + flag_ptr = lock_ptr + flag_offset + if thread_idx == 0: + red_release(flag_ptr, val) + # red_relaxed(flag_ptr, val) \ No newline at end of file From d0d8adb06b25002ae4232470724e2aed62e1c2cb Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 19 Oct 2025 15:53:47 -0400 Subject: [PATCH 159/258] [Cute,Bwd,Sm100] Have a separate function to set up the mma --- flash_attn/cute/flash_bwd_sm100.py | 437 +++++++++++++---------------- flash_attn/cute/named_barrier.py | 5 +- 2 files changed, 200 insertions(+), 242 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index f93b30d67bd..2d0d36d588f 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -64,19 +64,14 @@ def __init__( # CTA tiler self.cta_tiler = (tile_m, tile_n, self.tile_hdim) - # S = K @ Q.T self.mma_tiler_kq = (tile_n, tile_m, self.tile_hdim) - # dP = V @ dO.T self.mma_tiler_vdo = (tile_n, tile_m, self.tile_hdimv) - # dV = P.T @ dO self.mma_tiler_pdo = (tile_n, self.tile_hdimv, tile_m) - # dK = dS.T @ Q (N, M) (M, D) self.mma_tiler_dsq = (tile_n, self.tile_hdimv, tile_m) - # dQ = dS @ K self.mma_tiler_dsk = (tile_m, self.tile_hdimv, tile_n) @@ -135,8 +130,7 @@ def __init__( def _setup_attributes(self): self.q_stage = 2 - self.k_stage = 1 - self.v_stage = 1 + self.k_stage = self.v_stage = 1 self.do_stage = 1 self.ds_stage = 1 self.lse_stage = 1 @@ -152,232 +146,200 @@ def _setup_attributes(self): self.p_tmem_stage = 1 self.sdKdVaccum_stage = 2 - @cute.jit - def __call__( - self, - mQ: cute.Tensor, - mK: cute.Tensor, - mV: cute.Tensor, - mdO: cute.Tensor, - mLSE: cute.Tensor, - mPsum: cute.Tensor, - mdQaccum: cute.Tensor, - mdK: cute.Tensor, - mdV: cute.Tensor, - softmax_scale: Float32, - stream: cuda.CUstream, - mdQ_semaphore: Optional[cute.Tensor] = None, - mdK_semaphore: Optional[cute.Tensor] = None, - mdV_semaphore: Optional[cute.Tensor] = None, - ): - self.q_dtype = mQ.element_type - self.k_dtype = mK.element_type - self.v_dtype = mV.element_type - self.do_dtype = mdO.element_type - self.lse_dtype = mLSE.element_type - self.psum_dtype = mPsum.element_type - self.dqaccum_dtype = mdQaccum.element_type - self.dk_dtype = mdK.element_type - self.dv_dtype = mdV.element_type - self.ds_dtype = self.q_dtype - - if const_expr(self.qhead_per_kvhead > 1): - assert self.dk_dtype.width == 32, "Must accumulate dK in float precision for GQA" - assert self.dv_dtype.width == 32, "Must accumulate dV in float precision for GQA" - - QKVdO_layout_transpose = [1, 3, 2, 0] # (b, s, n, h) --> (s, h, n, b) - mQ, mK, mV, mdO, mdK, mdV = [ - cute.make_tensor(t.iterator, cute.select(t.layout, mode=QKVdO_layout_transpose)) - for t in (mQ, mK, mV, mdO, mdK, mdV) - ] - - LSE_Psum_dQaccum_layout_transpose = [2, 1, 0] # (b, n, s) --> (s, n, b) - mLSE, mPsum, mdQaccum = [ - cute.make_tensor( - t.iterator, cute.select(t.layout, mode=LSE_Psum_dQaccum_layout_transpose) - ) - for t in (mLSE, mPsum, mdQaccum) - ] - - dO_transpose = [1, 0, 2, 3] - mdO = cute.make_tensor(mdO.iterator, cute.select(mdO.layout, mode=dO_transpose)) - - semaphore_transpose = [2, 3, 1, 0] # (b, n, block, stage) -> (block, stage, n, b) - if const_expr(self.deterministic): - assert mdQ_semaphore is not None - mdQ_semaphore = cute.make_tensor( - mdQ_semaphore.iterator, cute.select(mdQ_semaphore.layout, mode=semaphore_transpose) - ) - else: - mdQ_semaphore = None - - if const_expr(self.deterministic and self.qhead_per_kvhead > 1): - assert mdK_semaphore is not None - assert mdV_semaphore is not None - mdK_semaphore, mdV_semaphore = [ - cute.make_tensor(t.iterator, cute.select(t.layout, mode=semaphore_transpose)) - for t in (mdK_semaphore, mdV_semaphore) - ] - else: - mdK_semaphore = None - mdV_semaphore = None - - self.q_major_mode = LayoutEnum.from_tensor(mQ).mma_major_mode() - self.k_major_mode = LayoutEnum.from_tensor(mK).mma_major_mode() - self.v_major_mode = LayoutEnum.from_tensor(mV).mma_major_mode() - self.do_major_mode = LayoutEnum.from_tensor(mdO).mma_major_mode() - - self._setup_attributes() + def _get_tiled_mma(self): cta_group = tcgen05.CtaGroup.ONE - - # S = K @ Q.T - tiled_mma_kq = sm100_utils_basic.make_trivial_tiled_mma( - self.k_dtype, - self.k_major_mode, - self.q_major_mode, + # S = K @ Q.T, dP = V @ dO.T + tiled_mma_SdP = sm100_utils_basic.make_trivial_tiled_mma( + self.q_dtype, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.K, self.kq_acc_dtype, cta_group, self.mma_tiler_kq[:2], ) - # dV += P @ dO --> (K, MN) major - p_source = tcgen05.OperandSource.TMEM - self.p_major_mode = tcgen05.OperandMajorMode.K - tiled_mma_pdo = sm100_utils_basic.make_trivial_tiled_mma( + tiled_mma_dV = sm100_utils_basic.make_trivial_tiled_mma( self.do_dtype, - self.p_major_mode, - self.do_major_mode, + tcgen05.OperandMajorMode.K, # P_major_mode + tcgen05.OperandMajorMode.MN, # dO_major_mode self.pdo_acc_dtype, cta_group, self.mma_tiler_pdo[:2], - p_source, + a_source=tcgen05.OperandSource.TMEM, ) - - # dP = V @ dO.T - self.dot_major_mode = tcgen05.OperandMajorMode.K - tiled_mma_vdo = sm100_utils_basic.make_trivial_tiled_mma( - self.do_dtype, - self.v_major_mode, - self.dot_major_mode, - self.vdo_acc_dtype, - cta_group, - self.mma_tiler_vdo[:2], - ) - # dK += dS.T @ Q - self.dSt_major_mode = tcgen05.OperandMajorMode.K - self.q_major_mode_dsq = tcgen05.OperandMajorMode.MN - tiled_mma_dsq = sm100_utils_basic.make_trivial_tiled_mma( - self.ds_dtype, - self.dSt_major_mode, - self.q_major_mode_dsq, - self.dsq_acc_dtype, + tiled_mma_dK = sm100_utils_basic.make_trivial_tiled_mma( + self.do_dtype, + tcgen05.OperandMajorMode.K, # dS_major_mode + tcgen05.OperandMajorMode.MN, # Q_major_mode + self.pdo_acc_dtype, cta_group, self.mma_tiler_dsq[:2], ) - # dQ = dS @ K - self.dS_major_mode = tcgen05.OperandMajorMode.MN - self.kt_major_mode_dsq = tcgen05.OperandMajorMode.MN - tiled_mma_dsk = sm100_utils_basic.make_trivial_tiled_mma( - self.ds_dtype, - self.dS_major_mode, - self.kt_major_mode_dsq, + tiled_mma_dQ = sm100_utils_basic.make_trivial_tiled_mma( + self.k_dtype, + tcgen05.OperandMajorMode.MN, # dS_major_mode + tcgen05.OperandMajorMode.MN, # Kt_major_mode self.dsk_acc_dtype, cta_group, self.mma_tiler_dsk[:2], ) - self.cluster_shape_mnk = (*self.cluster_shape_mn, 1) - self.cluster_layout_vmnk = cute.tiled_divide( - cute.make_layout(self.cluster_shape_mnk), - (tiled_mma_kq.thr_id.shape,), - ) + return tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ + def _setup_smem_layout(self): # S = K @ Q.T - sK_layout = sm100_utils_basic.make_smem_layout_a( - tiled_mma_kq, + self.sK_layout = sm100_utils_basic.make_smem_layout_a( + self.tiled_mma_SdP, self.mma_tiler_kq, self.k_dtype, self.k_stage, ) - sQ_layout = sm100_utils_basic.make_smem_layout_b( - tiled_mma_kq, + self.sQ_layout = sm100_utils_basic.make_smem_layout_b( + self.tiled_mma_SdP, self.mma_tiler_kq, self.q_dtype, self.q_stage, ) - # dV += P @ dO - sdO_layout = sm100_utils_basic.make_smem_layout_b( - tiled_mma_pdo, + self.sdO_layout = sm100_utils_basic.make_smem_layout_b( + self.tiled_mma_dV, self.mma_tiler_pdo, self.do_dtype, self.do_stage, ) - # dP = V @ dO.T - sV_layout = sm100_utils_basic.make_smem_layout_a( - tiled_mma_vdo, + self.sV_layout = sm100_utils_basic.make_smem_layout_a( + self.tiled_mma_SdP, self.mma_tiler_vdo, self.v_dtype, self.v_stage, ) - - sdOt_layout = sm100_utils_basic.make_smem_layout_b( - tiled_mma_vdo, + self.sdOt_layout = sm100_utils_basic.make_smem_layout_b( + self.tiled_mma_SdP, self.mma_tiler_vdo, self.do_dtype, self.do_stage, ) - # dK += dS.T @ Q - sdSt_layout = sm100_utils_basic.make_smem_layout_a( - tiled_mma_dsq, + self.sdSt_layout = sm100_utils_basic.make_smem_layout_a( + self.tiled_mma_dK, self.mma_tiler_dsq, self.ds_dtype, self.ds_stage, ) - - sQt_layout = sm100_utils_basic.make_smem_layout_b( - tiled_mma_dsq, + self.sQt_layout = sm100_utils_basic.make_smem_layout_b( + self.tiled_mma_dK, self.mma_tiler_dsq, self.q_dtype, self.q_stage, ) - # dQaccum = dS @ K - sdS_layout = sm100_utils_basic.make_smem_layout_a( - tiled_mma_dsk, + self.sdS_layout = sm100_utils_basic.make_smem_layout_a( + self.tiled_mma_dQ, self.mma_tiler_dsk, self.q_dtype, self.ds_stage, ) - sKt_layout = sm100_utils_basic.make_smem_layout_b( - tiled_mma_dsk, + self.sKt_layout = sm100_utils_basic.make_smem_layout_b( + self.tiled_mma_dQ, self.mma_tiler_dsk, self.k_dtype, self.k_stage, ) - sdQaccum_layout = cute.make_layout( - shape=(self.tile_m * 32, self.sdQaccum_stage), - ) - sLSE_layout = cute.make_layout( + self.sdQaccum_layout = cute.make_layout((self.tile_m * 32, self.sdQaccum_stage)) + self.sLSE_layout = cute.make_layout( shape=(self.tile_m, self.lse_stage), stride=(1, cute.round_up(self.tile_m, 64)), ) - sPsum_layout = cute.make_layout( + self.sPsum_layout = cute.make_layout( shape=(self.tile_m, self.psum_stage), stride=(1, cute.round_up(self.tile_m, 64)), ) + @cute.jit + def __call__( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mdO: cute.Tensor, + mLSE: cute.Tensor, + mPsum: cute.Tensor, + mdQaccum: cute.Tensor, + mdK: cute.Tensor, + mdV: cute.Tensor, + softmax_scale: Float32, + stream: cuda.CUstream, + mdQ_semaphore: Optional[cute.Tensor] = None, + mdK_semaphore: Optional[cute.Tensor] = None, + mdV_semaphore: Optional[cute.Tensor] = None, + ): + self.q_dtype = mQ.element_type + self.k_dtype = mK.element_type + self.v_dtype = mV.element_type + self.do_dtype = mdO.element_type + self.lse_dtype = mLSE.element_type + self.psum_dtype = mPsum.element_type + self.dqaccum_dtype = mdQaccum.element_type + self.dk_dtype = mdK.element_type + self.dv_dtype = mdV.element_type + self.ds_dtype = self.q_dtype + + if const_expr(self.qhead_per_kvhead > 1): + assert self.dk_dtype.width == 32, "Must accumulate dK in float precision for GQA" + assert self.dv_dtype.width == 32, "Must accumulate dV in float precision for GQA" + + layout_transpose = [1, 3, 2, 0] # (b, s, n, h) --> (s, h, n, b) + mQ, mK, mV, mdO, mdK, mdV = [ + utils.select(t, mode=layout_transpose) for t in (mQ, mK, mV, mdO, mdK, mdV) + ] + LSE_Psum_dQaccum_transpose = [2, 1, 0] # (b, n, s) --> (s, n, b) + mLSE, mPsum, mdQaccum = [ + utils.select(t, mode=LSE_Psum_dQaccum_transpose) for t in (mLSE, mPsum, mdQaccum) + ] + dO_transpose = [1, 0, 2, 3] + mdO = utils.select(mdO, mode=dO_transpose) + + semaphore_transpose = [2, 3, 1, 0] # (b, n, block, stage) -> (block, stage, n, b) + mdQ_semaphore = None + if const_expr(self.deterministic): + assert mdQ_semaphore is not None + mdQ_semaphore = utils.select(mdQ_semaphore.layout, mode=semaphore_transpose) + + if const_expr(self.deterministic and self.qhead_per_kvhead > 1): + assert mdK_semaphore is not None + assert mdV_semaphore is not None + mdK_semaphore, mdV_semaphore = [ + utils.select(t.layout, mode=semaphore_transpose) + for t in (mdK_semaphore, mdV_semaphore) + ] + else: + mdK_semaphore = None + mdV_semaphore = None + + self._setup_attributes() + self.tiled_mma_SdP, self.tiled_mma_dK, self.tiled_mma_dV, self.tiled_mma_dQ = ( + self._get_tiled_mma() + ) + self._setup_smem_layout() + + cta_group = tcgen05.CtaGroup.ONE + + self.cluster_shape_mnk = (*self.cluster_shape_mn, 1) + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (self.tiled_mma_SdP.thr_id.shape,), + ) + self.mdK_layout_enum = LayoutEnum.from_tensor(mdK) self.mdV_layout_enum = LayoutEnum.from_tensor(mdV) - self.dK_major_mode = self.mdK_layout_enum.mma_major_mode() - self.dV_major_mode = self.mdV_layout_enum.mma_major_mode() - if const_expr(self.dK_major_mode != tcgen05.OperandMajorMode.K): + dK_major_mode = self.mdK_layout_enum.mma_major_mode() + dV_major_mode = self.mdV_layout_enum.mma_major_mode() + if const_expr(dK_major_mode != tcgen05.OperandMajorMode.K): raise RuntimeError("The layout of mdK is wrong") - if const_expr(self.dV_major_mode != tcgen05.OperandMajorMode.K): + if const_expr(dV_major_mode != tcgen05.OperandMajorMode.K): raise RuntimeError("The layout of mdV is wrong") self.sdKdV_epi_tile = ( self.tile_n, @@ -442,18 +404,18 @@ def __call__( tma_atom_K, tma_tensor_K = cute.nvgpu.make_tiled_tma_atom_A( tma_load_op, mK, - cute.select(sK_layout, mode=[0, 1, 2]), + cute.select(self.sK_layout, mode=[0, 1, 2]), self.mma_tiler_kq, - tiled_mma_kq, + self.tiled_mma_SdP, self.cluster_layout_vmnk.shape, ) tma_atom_Q, tma_tensor_Q = cute.nvgpu.make_tiled_tma_atom_B( tma_load_op, mQ, - cute.select(sQ_layout, mode=[0, 1, 2]), + cute.select(self.sQ_layout, mode=[0, 1, 2]), self.mma_tiler_kq, - tiled_mma_kq, + self.tiled_mma_SdP, self.cluster_layout_vmnk.shape, ) @@ -461,9 +423,9 @@ def __call__( tma_atom_dO, tma_tensor_dO = cute.nvgpu.make_tiled_tma_atom_B( tma_load_op, mdO, - cute.select(sdO_layout, mode=[0, 1, 2]), + cute.select(self.sdO_layout, mode=[0, 1, 2]), self.mma_tiler_pdo, - tiled_mma_pdo, + self.tiled_mma_dV, self.cluster_layout_vmnk.shape, ) tma_atom_LSE, tma_tensor_LSE = cute.nvgpu.cpasync.make_tiled_tma_atom( @@ -483,23 +445,23 @@ def __call__( tma_atom_V, tma_tensor_V = cute.nvgpu.make_tiled_tma_atom_A( tma_load_op, mV, - cute.select(sV_layout, mode=[0, 1, 2]), + cute.select(self.sV_layout, mode=[0, 1, 2]), self.mma_tiler_vdo, - tiled_mma_vdo, + self.tiled_mma_SdP, self.cluster_layout_vmnk.shape, ) self.tma_copy_q_bytes = cute.size_in_bytes( - self.q_dtype, cute.select(sQ_layout, mode=[0, 1, 2]) + self.q_dtype, cute.select(self.sQ_layout, mode=[0, 1, 2]) ) self.tma_copy_k_bytes = cute.size_in_bytes( - self.k_dtype, cute.select(sK_layout, mode=[0, 1, 2]) + self.k_dtype, cute.select(self.sK_layout, mode=[0, 1, 2]) ) self.tma_copy_v_bytes = cute.size_in_bytes( - self.v_dtype, cute.select(sV_layout, mode=[0, 1, 2]) + self.v_dtype, cute.select(self.sV_layout, mode=[0, 1, 2]) ) self.tma_copy_do_bytes = cute.size_in_bytes( - self.do_dtype, cute.select(sdO_layout, mode=[0, 1, 2]) + self.do_dtype, cute.select(self.sdO_layout, mode=[0, 1, 2]) ) self.tma_copy_lse_bytes = self.tile_m * 4 self.tma_copy_psum_bytes = self.tile_m * 4 @@ -554,35 +516,35 @@ class SharedStorage: # Smem tensors sQ: cute.struct.Align[ - cute.struct.MemRange[self.q_dtype, cute.cosize(sQ_layout)], + cute.struct.MemRange[self.q_dtype, cute.cosize(self.sQ_layout)], self.buffer_align_bytes, ] sK: cute.struct.Align[ - cute.struct.MemRange[self.k_dtype, cute.cosize(sK_layout)], + cute.struct.MemRange[self.k_dtype, cute.cosize(self.sK_layout)], self.buffer_align_bytes, ] sV: cute.struct.Align[ - cute.struct.MemRange[self.v_dtype, cute.cosize(sV_layout)], + cute.struct.MemRange[self.v_dtype, cute.cosize(self.sV_layout)], self.buffer_align_bytes, ] sdO: cute.struct.Align[ - cute.struct.MemRange[self.do_dtype, cute.cosize(sdO_layout)], + cute.struct.MemRange[self.do_dtype, cute.cosize(self.sdO_layout)], self.buffer_align_bytes, ] sdS: cute.struct.Align[ - cute.struct.MemRange[self.ds_dtype, cute.cosize(sdSt_layout)], + cute.struct.MemRange[self.ds_dtype, cute.cosize(self.sdSt_layout)], 128, ] sLSE: cute.struct.Align[ - cute.struct.MemRange[self.lse_dtype, cute.cosize(sLSE_layout)], + cute.struct.MemRange[self.lse_dtype, cute.cosize(self.sLSE_layout)], 128, ] sPsum: cute.struct.Align[ - cute.struct.MemRange[self.psum_dtype, cute.cosize(sPsum_layout)], + cute.struct.MemRange[self.psum_dtype, cute.cosize(self.sPsum_layout)], 128, ] sdQaccum: cute.struct.Align[ - cute.struct.MemRange[self.dqaccum_dtype, cute.cosize(sdQaccum_layout)], + cute.struct.MemRange[self.dqaccum_dtype, cute.cosize(self.sdQaccum_layout)], self.buffer_align_bytes, ] @@ -613,24 +575,23 @@ class SharedStorage: tma_atom_dO, tma_atom_dV, tma_atom_dK, - sQ_layout, - sQt_layout, - sK_layout, - sV_layout, - sLSE_layout, - sPsum_layout, - sdO_layout, - sdOt_layout, - sdSt_layout, - sdS_layout, - sKt_layout, - sdQaccum_layout, + self.sQ_layout, + self.sQt_layout, + self.sK_layout, + self.sV_layout, + self.sLSE_layout, + self.sPsum_layout, + self.sdO_layout, + self.sdOt_layout, + self.sdSt_layout, + self.sdS_layout, + self.sKt_layout, + self.sdQaccum_layout, sdKdV_layout, - tiled_mma_kq, - tiled_mma_pdo, - tiled_mma_vdo, - tiled_mma_dsq, - tiled_mma_dsk, + self.tiled_mma_SdP, + self.tiled_mma_dV, + self.tiled_mma_dK, + self.tiled_mma_dQ, tiled_copy_r2s_dKdV, softmax_scale, softmax_scale_log2, @@ -638,7 +599,7 @@ class SharedStorage: ).launch( grid=grid_dim, block=[self.threads_per_cta, 1, 1], - cluster=self.cluster_shape_mnk, + cluster=self.cluster_shape_mnk if cute.size(self.cluster_shape_mnk) > 1 else None, smem=self.shared_storage.size_in_bytes(), stream=stream, min_blocks_per_mp=1, @@ -682,11 +643,10 @@ def kernel( sKt_layout: cute.ComposedLayout, sdQaccum_layout: cute.Layout, sdKdV_layout: cute.ComposedLayout, - tiled_mma_kq: cute.TiledMma, - tiled_mma_pdo: cute.TiledMma, - tiled_mma_vdo: cute.TiledMma, - tiled_mma_dsq: cute.TiledMma, - tiled_mma_dsk: cute.TiledMma, + tiled_mma_SdP: cute.TiledMma, + tiled_mma_dV: cute.TiledMma, + tiled_mma_dK: cute.TiledMma, + tiled_mma_dQ: cute.TiledMma, tiled_copy_r2s_dKdV: cute.TiledCopy, softmax_scale: cutlass.Float32, softmax_scale_log2: cutlass.Float32, @@ -826,7 +786,6 @@ def kernel( sQt = cute.make_tensor( cute.recast_ptr(sQ.iterator, swizzle_=sQt_layout.inner), sQt_layout.outer ) - sQ_pi = storage.sQ.get_tensor(sQ_layout) sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) sKt = cute.make_tensor( @@ -876,31 +835,31 @@ def kernel( # TMEM # S - thr_mma_kq = tiled_mma_kq.get_slice(0) + thr_mma_kq = tiled_mma_SdP.get_slice(0) Sacc_shape = thr_mma_kq.partition_shape_C(self.mma_tiler_kq[:2]) # (M, N) tStS = thr_mma_kq.make_fragment_C(Sacc_shape) tStS = cute.make_tensor(tStS.iterator, tStS.layout) # dV - thr_mma_pdo = tiled_mma_pdo.get_slice(0) + thr_mma_pdo = tiled_mma_dV.get_slice(0) dvacc_shape = thr_mma_pdo.partition_shape_C(self.mma_tiler_pdo[:2]) tdVtdV = thr_mma_pdo.make_fragment_C(dvacc_shape) tdVtdV = cute.make_tensor(tdVtdV.iterator + self.tmem_dV_offset, tdVtdV.layout) # dK - thr_mma_dsq = tiled_mma_dsq.get_slice(0) + thr_mma_dsq = tiled_mma_dK.get_slice(0) dkacc_shape = thr_mma_dsq.partition_shape_C(self.mma_tiler_dsq[:2]) tdKtdK = thr_mma_dsq.make_fragment_C(dkacc_shape) tdKtdK = cute.make_tensor(tdKtdK.iterator + self.tmem_dK_offset, tdKtdK.layout) # dQ - thr_mma_dsk = tiled_mma_dsk.get_slice(0) + thr_mma_dsk = tiled_mma_dQ.get_slice(0) dQacc_shape = thr_mma_dsk.partition_shape_C(self.mma_tiler_dsk[:2]) tdQtdQ = thr_mma_dsk.make_fragment_C(dQacc_shape) tdQtdQ = cute.make_tensor(tdQtdQ.iterator + self.tmem_dQaccum_offset, tdQtdQ.layout) # dP - thr_mma_vdo = tiled_mma_vdo.get_slice(0) + thr_mma_vdo = tiled_mma_SdP.get_slice(0) dPacc_shape = thr_mma_vdo.partition_shape_C(self.mma_tiler_vdo[:2]) tdPtdP = thr_mma_vdo.make_fragment_C(dPacc_shape) tdPtdP = cute.make_tensor(tdPtdP.iterator + self.tmem_dP_offset, tdPtdP.layout) @@ -995,11 +954,10 @@ def kernel( cute.arch.sync_warp() self.mma( - tiled_mma_kq, - tiled_mma_pdo, - tiled_mma_vdo, - tiled_mma_dsq, - tiled_mma_dsk, + tiled_mma_SdP, + tiled_mma_dV, + tiled_mma_dK, + tiled_mma_dQ, thr_mma_kq, thr_mma_pdo, thr_mma_vdo, @@ -1353,11 +1311,10 @@ def load( @cute.jit def mma( self, - tiled_mma_kq: cute.core.TiledMma, - tiled_mma_pdo: cute.core.TiledMma, - tiled_mma_vdo: cute.core.TiledMma, - tiled_mma_dsq: cute.core.TiledMma, - tiled_mma_dsk: cute.core.TiledMma, + tiled_mma_SdP: cute.TiledMma, + tiled_mma_dV: cute.TiledMma, + tiled_mma_dK: cute.TiledMma, + tiled_mma_dQ: cute.TiledMma, thr_mma_kq: cute.core.ThrMma, thr_mma_pdo: cute.core.ThrMma, thr_mma_vdo: cute.core.ThrMma, @@ -1457,7 +1414,7 @@ def mma( # dV = P @ dO.T tdVrdO = thr_mma_pdo.make_fragment_B(sdO) p_tmem_layout = sm100_utils_basic.make_smem_layout_a( - tiled_mma_pdo, + tiled_mma_dV, self.mma_tiler_pdo, self.q_dtype, self.acc_stage, @@ -1484,9 +1441,9 @@ def mma( num_k_phases = cute.size(tSrK, mode=[2]) for kphase_idx in cutlass.range_constexpr(num_k_phases, unroll=1): - tiled_mma_kq.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) + tiled_mma_SdP.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) cute.gemm( - tiled_mma_kq, + tiled_mma_SdP, tStS, tSrK[(None, None, kphase_idx, 0)], tSrQ[(None, None, kphase_idx, q_consumer_state.index)], @@ -1504,9 +1461,9 @@ def mma( pipeline_dQaccum.producer_acquire(dQaccum_producer_state) for kphase_idx in cutlass.range_constexpr(cute.size(tdPrV, mode=[2]), unroll=1): - tiled_mma_vdo.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) + tiled_mma_SdP.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) cute.gemm( - tiled_mma_vdo, + tiled_mma_SdP, tdPtdP, tdPrV[(None, None, kphase_idx, 0)], tdPrdOt[(None, None, kphase_idx, do_consumer_state.index)], @@ -1520,9 +1477,9 @@ def mma( num_kphases = cute.size(tdVrP, mode=[2]) for kphase_idx in cutlass.range_constexpr(num_kphases): - tiled_mma_pdo.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) + tiled_mma_dV.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) cute.gemm( - tiled_mma_pdo, + tiled_mma_dV, tdVtdV, tdVrP[(None, None, kphase_idx)], tdVrdO[(None, None, kphase_idx, do_consumer_state.index)], @@ -1547,9 +1504,9 @@ def mma( pipeline_s.producer_acquire(s_producer_state) #''' for kphase_idx in cutlass.range_constexpr(num_k_phases, unroll=1): - tiled_mma_kq.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) + tiled_mma_SdP.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) cute.gemm( - tiled_mma_kq, + tiled_mma_SdP, tStS, tSrK[(None, None, kphase_idx, 0)], tSrQ[(None, None, kphase_idx, q_consumer_state.index)], @@ -1566,9 +1523,9 @@ def mma( num_kphases = cute.size(tdQaccrdS, mode=[2]) for kphase_idx in cutlass.range_constexpr(num_kphases): - tiled_mma_dsk.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) + tiled_mma_dQ.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) cute.gemm( - tiled_mma_dsk, + tiled_mma_dQ, tdQacctdQacc, tdQaccrdS[(None, None, kphase_idx, dS_consumer_state.index)], tdQaccrK[(None, None, kphase_idx, 0)], @@ -1580,9 +1537,9 @@ def mma( # 3) dK = dS.T @ Q num_kphases = cute.size(tdKrdS, mode=[2]) for kphase_idx in cutlass.range_constexpr(num_kphases, unroll=1): - tiled_mma_dsq.set(tcgen05.Field.ACCUMULATE, accumulate_dK) + tiled_mma_dK.set(tcgen05.Field.ACCUMULATE, accumulate_dK) cute.gemm( - tiled_mma_dsq, + tiled_mma_dK, tdKtdK, tdKrdS[(None, None, kphase_idx, 0)], tdKrQ[(None, None, kphase_idx, q_dk_consumer_state.index)], @@ -1601,9 +1558,9 @@ def mma( pipeline_dQaccum.producer_acquire(dQaccum_producer_state) for kphase_idx in cutlass.range_constexpr(cute.size(tdPrV, mode=[2]), unroll=1): - tiled_mma_vdo.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) + tiled_mma_SdP.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) cute.gemm( - tiled_mma_vdo, + tiled_mma_SdP, tdPtdP, tdPrV[(None, None, kphase_idx, 0)], tdPrdOt[(None, None, kphase_idx, do_consumer_state.index)], @@ -1617,9 +1574,9 @@ def mma( num_kphases = cute.size(tdVrP, mode=[2]) for kphase_idx in cutlass.range_constexpr(num_kphases): - tiled_mma_pdo.set(tcgen05.Field.ACCUMULATE, True) + tiled_mma_dV.set(tcgen05.Field.ACCUMULATE, True) cute.gemm( - tiled_mma_pdo, + tiled_mma_dV, tdVtdV, tdVrP[(None, None, kphase_idx)], tdVrdO[(None, None, kphase_idx, do_consumer_state.index)], @@ -1647,9 +1604,9 @@ def mma( num_kphases = cute.size(tdKrdS, mode=[2]) for kphase_idx in cutlass.range_constexpr(num_kphases): - tiled_mma_dsq.set(tcgen05.Field.ACCUMULATE, accumulate_dK) + tiled_mma_dK.set(tcgen05.Field.ACCUMULATE, accumulate_dK) cute.gemm( - tiled_mma_dsq, + tiled_mma_dK, tdKtdK, tdKrdS[(None, None, kphase_idx, dS_consumer_state.index)], tdKrQ[(None, None, kphase_idx, q_dk_consumer_state.index)], @@ -1664,9 +1621,9 @@ def mma( # 2) dQaccum = dS @ K num_kphases = cute.size(tdQaccrdS, mode=[2]) for kphase_idx in cutlass.range_constexpr(num_kphases, unroll=1): - tiled_mma_dsk.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) + tiled_mma_dQ.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) cute.gemm( - tiled_mma_dsk, + tiled_mma_dQ, tdQacctdQacc, tdQaccrdS[(None, None, kphase_idx, dS_consumer_state.index)], tdQaccrK[(None, None, kphase_idx, 0)], diff --git a/flash_attn/cute/named_barrier.py b/flash_attn/cute/named_barrier.py index 48229ccd25d..777c44079a0 100644 --- a/flash_attn/cute/named_barrier.py +++ b/flash_attn/cute/named_barrier.py @@ -23,8 +23,9 @@ class NamedBarrierBwd(enum.IntEnum): dQEmptyWG0 = enum.auto() dQEmptyWG1 = enum.auto() + class NamedBarrierBwdSm100(enum.IntEnum): EpilogueWG1 = enum.auto() EpilogueWG2 = enum.auto() - Compute = enum.auto() - dQaccReduce = enum.auto() \ No newline at end of file + Compute = enum.auto() + dQaccReduce = enum.auto() From 796564dd75e4bf9e15ebb3fe53cd9d2bdb099e84 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 19 Oct 2025 16:13:20 -0400 Subject: [PATCH 160/258] [Cute,Bwd,Sm100] Load LSE with cpasync_bulk --- flash_attn/cute/flash_bwd_sm100.py | 84 +++++++++--------------------- 1 file changed, 25 insertions(+), 59 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 2d0d36d588f..d9cfd9edeec 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -352,10 +352,6 @@ def __call__( self.sdKdVaccum_stage, ) - self.tma_copy_dKdV_bytes = cute.size_in_bytes( - self.dk_dtype, cute.select(sdKdV_layout, mode=[0, 1]) - ) - if const_expr(self.use_tma_store): if const_expr(self.dk_dtype.width == 32): tma_copy_op_dKdV = cpasync.CopyReduceBulkTensorTileS2GOp() @@ -428,12 +424,6 @@ def __call__( self.tiled_mma_dV, self.cluster_layout_vmnk.shape, ) - tma_atom_LSE, tma_tensor_LSE = cute.nvgpu.cpasync.make_tiled_tma_atom( - tma_load_op, - mLSE, - cute.make_layout((self.tile_m)), - (self.tile_m,), - ) tma_atom_Psum, tma_tensor_Psum = cute.nvgpu.cpasync.make_tiled_tma_atom( tma_load_op, mPsum, @@ -451,20 +441,17 @@ def __call__( self.cluster_layout_vmnk.shape, ) - self.tma_copy_q_bytes = cute.size_in_bytes( - self.q_dtype, cute.select(self.sQ_layout, mode=[0, 1, 2]) - ) - self.tma_copy_k_bytes = cute.size_in_bytes( - self.k_dtype, cute.select(self.sK_layout, mode=[0, 1, 2]) - ) - self.tma_copy_v_bytes = cute.size_in_bytes( - self.v_dtype, cute.select(self.sV_layout, mode=[0, 1, 2]) - ) - self.tma_copy_do_bytes = cute.size_in_bytes( - self.do_dtype, cute.select(self.sdO_layout, mode=[0, 1, 2]) - ) - self.tma_copy_lse_bytes = self.tile_m * 4 - self.tma_copy_psum_bytes = self.tile_m * 4 + self.tma_copy_bytes = { + name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2])) + for name, mX, layout in [ + ("Q", mQ, self.sQ_layout), + ("K", mK, self.sK_layout), + ("V", mV, self.sV_layout), + ("dO", mdO, self.sdO_layout), + ] + } + self.tma_copy_bytes["LSE"] = self.tile_m * Float32.width // 8 + self.tma_copy_bytes["dPsum"] = self.tile_m * Float32.width // 8 TileScheduler = SingleTileScheduler # TODO -- optimizer scheduler for causal @@ -556,7 +543,7 @@ class SharedStorage: tma_tensor_Q, tma_tensor_K, tma_tensor_V, - tma_tensor_LSE, + mLSE, tma_tensor_Psum, tma_tensor_dO, mdV, @@ -570,7 +557,6 @@ class SharedStorage: tma_atom_Q, tma_atom_K, tma_atom_V, - tma_atom_LSE, tma_atom_Psum, tma_atom_dO, tma_atom_dV, @@ -625,7 +611,6 @@ def kernel( tma_atom_Q: cute.CopyAtom, tma_atom_K: cute.CopyAtom, tma_atom_V: cute.CopyAtom, - tma_atom_LSE: cute.CopyAtom, tma_atom_Psum: cute.CopyAtom, tma_atom_dO: cute.CopyAtom, tma_atom_dV: Optional[cute.CopyAtom], @@ -660,7 +645,6 @@ def kernel( cpasync.prefetch_descriptor(tma_atom_Q) cpasync.prefetch_descriptor(tma_atom_K) cpasync.prefetch_descriptor(tma_atom_V) - cpasync.prefetch_descriptor(tma_atom_LSE) cpasync.prefetch_descriptor(tma_atom_Psum) cpasync.prefetch_descriptor(tma_atom_dO) if const_expr(tma_atom_dV is not None): @@ -705,7 +689,7 @@ def kernel( num_stages=self.q_stage, producer_group=pipeline_producer_group, consumer_group=pipeline_consumer_group, - tx_count=self.tma_copy_q_bytes, + tx_count=self.tma_copy_bytes["Q"], ) pipeline_do = cutlass.pipeline.PipelineTmaUmma.create( @@ -713,7 +697,7 @@ def kernel( num_stages=self.do_stage, producer_group=pipeline_producer_group, consumer_group=pipeline_consumer_group, - tx_count=self.tma_copy_do_bytes, + tx_count=self.tma_copy_bytes["dO"], ) # UMMA producers and AsyncThread consumers @@ -927,7 +911,6 @@ def kernel( tma_atom_Q, tma_atom_K, tma_atom_V, - tma_atom_LSE, tma_atom_Psum, tma_atom_dO, pipeline_q, @@ -1091,7 +1074,6 @@ def load( tma_atom_Q: cute.CopyAtom, tma_atom_K: cute.CopyAtom, tma_atom_V: cute.CopyAtom, - tma_atom_LSE: cute.CopyAtom, tma_atom_Psum: cute.CopyAtom, tma_atom_dO: cute.CopyAtom, pipeline_q: PipelineAsync, @@ -1174,13 +1156,7 @@ def load( cute.group_modes(sdO, 0, 3), cute.group_modes(tdVgdO, 0, 3), ) - tLSEsLSE, tLSEgLSE = cpasync.tma_partition( - tma_atom_LSE, - 0, - cute.make_layout(1), - sLSE, - gLSE, - ) + load_LSE = copy_utils.cpasync_bulk_get_copy_fn(gLSE, sLSE) tPsumsPsum, tPsumgPsum = cpasync.tma_partition( tma_atom_Psum, 0, @@ -1190,7 +1166,7 @@ def load( ) # K with cute.arch.elect_one(): - cute.arch.mbarrier_arrive_and_expect_tx(k_full_mbar_ptr, self.tma_copy_k_bytes) + cute.arch.mbarrier_arrive_and_expect_tx(k_full_mbar_ptr, self.tma_copy_bytes["K"]) cute.copy(tma_atom_K, tKgK, tKsK[None, 0], tma_bar_ptr=k_full_mbar_ptr) ###### Prologue @@ -1207,18 +1183,14 @@ def load( # LSE with cute.arch.elect_one(): - cute.arch.mbarrier_arrive_and_expect_tx(lse_full_mbar_ptr, self.tma_copy_lse_bytes) - - cute.copy( - tma_atom_LSE, - tLSEgLSE[None, m_block_max - 1], - tLSEsLSE[None, 0], - tma_bar_ptr=lse_full_mbar_ptr, - ) + cute.arch.mbarrier_arrive_and_expect_tx( + lse_full_mbar_ptr, self.tma_copy_bytes["LSE"] + ) + load_LSE(src_idx=m_block_max - 1, dst_idx=0, tma_bar_ptr=lse_full_mbar_ptr) # V with cute.arch.elect_one(): - cute.arch.mbarrier_arrive_and_expect_tx(v_full_mbar_ptr, self.tma_copy_v_bytes) + cute.arch.mbarrier_arrive_and_expect_tx(v_full_mbar_ptr, self.tma_copy_bytes["V"]) cute.copy(tma_atom_V, tVgV, tVsV[None, 0], tma_bar_ptr=v_full_mbar_ptr) # dO @@ -1235,7 +1207,7 @@ def load( # Psum with cute.arch.elect_one(): cute.arch.mbarrier_arrive_and_expect_tx( - psum_full_mbar_ptr, self.tma_copy_psum_bytes + psum_full_mbar_ptr, self.tma_copy_bytes["dPsum"] ) cute.copy( @@ -1263,15 +1235,9 @@ def load( with cute.arch.elect_one(): cute.arch.mbarrier_arrive_and_expect_tx( - lse_full_mbar_ptr, self.tma_copy_lse_bytes + lse_full_mbar_ptr, self.tma_copy_bytes["LSE"] ) - - cute.copy( - tma_atom_LSE, - tLSEgLSE[None, m_block], - tLSEsLSE[None, 0], - tma_bar_ptr=lse_full_mbar_ptr, - ) + load_LSE(src_idx=m_block, dst_idx=0, tma_bar_ptr=lse_full_mbar_ptr) # dO self.load_M_tile( @@ -1291,7 +1257,7 @@ def load( with cute.arch.elect_one(): cute.arch.mbarrier_arrive_and_expect_tx( - psum_full_mbar_ptr, self.tma_copy_psum_bytes + psum_full_mbar_ptr, self.tma_copy_bytes["dPsum"] ) cute.copy( From d0399b62a9bdc1150a875ce89e4065f83f977896 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 19 Oct 2025 16:19:54 -0400 Subject: [PATCH 161/258] [Cute,Bwd,Sm100] Load dPsum with cpasync_bulk --- flash_attn/cute/flash_bwd_sm100.py | 135 +++++++++++------------------ 1 file changed, 52 insertions(+), 83 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index d9cfd9edeec..867a48b6c9f 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -142,7 +142,7 @@ def _setup_attributes(self): self.dS_stage = 1 self.dQaccum_mma_stage = 1 self.sdQaccum_stage = 2 - self.psum_stage = 1 + self.dpsum_stage = 1 self.p_tmem_stage = 1 self.sdKdVaccum_stage = 2 @@ -253,8 +253,8 @@ def _setup_smem_layout(self): shape=(self.tile_m, self.lse_stage), stride=(1, cute.round_up(self.tile_m, 64)), ) - self.sPsum_layout = cute.make_layout( - shape=(self.tile_m, self.psum_stage), + self.sdPsum_layout = cute.make_layout( + shape=(self.tile_m, self.dpsum_stage), stride=(1, cute.round_up(self.tile_m, 64)), ) @@ -266,7 +266,7 @@ def __call__( mV: cute.Tensor, mdO: cute.Tensor, mLSE: cute.Tensor, - mPsum: cute.Tensor, + mdPsum: cute.Tensor, mdQaccum: cute.Tensor, mdK: cute.Tensor, mdV: cute.Tensor, @@ -281,7 +281,7 @@ def __call__( self.v_dtype = mV.element_type self.do_dtype = mdO.element_type self.lse_dtype = mLSE.element_type - self.psum_dtype = mPsum.element_type + self.dpsum_dtype = mdPsum.element_type self.dqaccum_dtype = mdQaccum.element_type self.dk_dtype = mdK.element_type self.dv_dtype = mdV.element_type @@ -295,9 +295,9 @@ def __call__( mQ, mK, mV, mdO, mdK, mdV = [ utils.select(t, mode=layout_transpose) for t in (mQ, mK, mV, mdO, mdK, mdV) ] - LSE_Psum_dQaccum_transpose = [2, 1, 0] # (b, n, s) --> (s, n, b) - mLSE, mPsum, mdQaccum = [ - utils.select(t, mode=LSE_Psum_dQaccum_transpose) for t in (mLSE, mPsum, mdQaccum) + LSE_dPsum_dQaccum_transpose = [2, 1, 0] # (b, n, s) --> (s, n, b) + mLSE, mdPsum, mdQaccum = [ + utils.select(t, mode=LSE_dPsum_dQaccum_transpose) for t in (mLSE, mdPsum, mdQaccum) ] dO_transpose = [1, 0, 2, 3] mdO = utils.select(mdO, mode=dO_transpose) @@ -405,7 +405,6 @@ def __call__( self.tiled_mma_SdP, self.cluster_layout_vmnk.shape, ) - tma_atom_Q, tma_tensor_Q = cute.nvgpu.make_tiled_tma_atom_B( tma_load_op, mQ, @@ -414,7 +413,6 @@ def __call__( self.tiled_mma_SdP, self.cluster_layout_vmnk.shape, ) - # dV += P @ dO tma_atom_dO, tma_tensor_dO = cute.nvgpu.make_tiled_tma_atom_B( tma_load_op, @@ -424,13 +422,6 @@ def __call__( self.tiled_mma_dV, self.cluster_layout_vmnk.shape, ) - tma_atom_Psum, tma_tensor_Psum = cute.nvgpu.cpasync.make_tiled_tma_atom( - tma_load_op, - mPsum, - cute.make_layout((self.tile_m)), - (self.tile_m,), - ) - # dP = V @ dO.T tma_atom_V, tma_tensor_V = cute.nvgpu.make_tiled_tma_atom_A( tma_load_op, @@ -486,8 +477,8 @@ class SharedStorage: do_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.do_stage] lse_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.k_stage] lse_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.k_stage] - psum_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.psum_stage] - psum_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.psum_stage] + dpsum_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.dpsum_stage] + dpsum_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.dpsum_stage] s_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.s_stage] dP_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dP_stage] p_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.s_stage] @@ -526,8 +517,8 @@ class SharedStorage: cute.struct.MemRange[self.lse_dtype, cute.cosize(self.sLSE_layout)], 128, ] - sPsum: cute.struct.Align[ - cute.struct.MemRange[self.psum_dtype, cute.cosize(self.sPsum_layout)], + sdPsum: cute.struct.Align[ + cute.struct.MemRange[self.dpsum_dtype, cute.cosize(self.sdPsum_layout)], 128, ] sdQaccum: cute.struct.Align[ @@ -544,7 +535,7 @@ class SharedStorage: tma_tensor_K, tma_tensor_V, mLSE, - tma_tensor_Psum, + mdPsum, tma_tensor_dO, mdV, mdK, @@ -557,7 +548,7 @@ class SharedStorage: tma_atom_Q, tma_atom_K, tma_atom_V, - tma_atom_Psum, + # tma_atom_Psum, tma_atom_dO, tma_atom_dV, tma_atom_dK, @@ -566,7 +557,7 @@ class SharedStorage: self.sK_layout, self.sV_layout, self.sLSE_layout, - self.sPsum_layout, + self.sdPsum_layout, self.sdO_layout, self.sdOt_layout, self.sdSt_layout, @@ -598,7 +589,7 @@ def kernel( mK: cute.Tensor, mV: cute.Tensor, mLSE: cute.Tensor, - mPsum: cute.Tensor, + mdPsum: cute.Tensor, mdO: cute.Tensor, mdV: cute.Tensor, mdK: cute.Tensor, @@ -611,7 +602,6 @@ def kernel( tma_atom_Q: cute.CopyAtom, tma_atom_K: cute.CopyAtom, tma_atom_V: cute.CopyAtom, - tma_atom_Psum: cute.CopyAtom, tma_atom_dO: cute.CopyAtom, tma_atom_dV: Optional[cute.CopyAtom], tma_atom_dK: Optional[cute.CopyAtom], @@ -620,7 +610,7 @@ def kernel( sK_layout: cute.ComposedLayout, sV_layout: cute.ComposedLayout, sLSE_layout: cute.Layout, - sPsum_layout: cute.Layout, + sdPsum_layout: cute.Layout, sdO_layout: cute.ComposedLayout, sdOt_layout: cute.ComposedLayout, sdSt_layout: cute.ComposedLayout, @@ -645,7 +635,6 @@ def kernel( cpasync.prefetch_descriptor(tma_atom_Q) cpasync.prefetch_descriptor(tma_atom_K) cpasync.prefetch_descriptor(tma_atom_V) - cpasync.prefetch_descriptor(tma_atom_Psum) cpasync.prefetch_descriptor(tma_atom_dO) if const_expr(tma_atom_dV is not None): cpasync.prefetch_descriptor(tma_atom_dV) @@ -661,8 +650,8 @@ def kernel( tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr.data_ptr() lse_full_mbar_ptr = storage.lse_full_mbar_ptr.data_ptr() lse_empty_mbar_ptr = storage.lse_empty_mbar_ptr.data_ptr() - psum_full_mbar_ptr = storage.psum_full_mbar_ptr.data_ptr() - psum_empty_mbar_ptr = storage.psum_empty_mbar_ptr.data_ptr() + dpsum_full_mbar_ptr = storage.dpsum_full_mbar_ptr.data_ptr() + dpsum_empty_mbar_ptr = storage.dpsum_empty_mbar_ptr.data_ptr() dQaccum_reduce_mbar_ptr = storage.dQaccum_reduce_mbar_ptr.data_ptr() if warp_idx == self.load_warp_id: @@ -673,8 +662,8 @@ def kernel( ) cute.arch.mbarrier_init(lse_full_mbar_ptr, len([self.compute_warp_ids])) cute.arch.mbarrier_init(lse_empty_mbar_ptr, len([self.compute_warp_ids])) - cute.arch.mbarrier_init(psum_full_mbar_ptr, len([self.compute_warp_ids])) - cute.arch.mbarrier_init(psum_empty_mbar_ptr, len([self.compute_warp_ids])) + cute.arch.mbarrier_init(dpsum_full_mbar_ptr, len([self.compute_warp_ids])) + cute.arch.mbarrier_init(dpsum_empty_mbar_ptr, len([self.compute_warp_ids])) cute.arch.mbarrier_init(dQaccum_reduce_mbar_ptr, 1) pipeline_producer_group = cutlass.pipeline.CooperativeGroup( @@ -795,9 +784,9 @@ def kernel( cute.make_layout(shape=(self.tile_m, self.tile_n, self.lse_stage), stride=(0, 1, 0)) ) - sPsum_load = storage.sPsum.get_tensor(sPsum_layout) - sPsum_mma = storage.sPsum.get_tensor( - cute.make_layout(shape=(self.tile_m, self.tile_n, self.psum_stage), stride=(0, 1, 0)) + sdPsum_load = storage.sdPsum.get_tensor(sdPsum_layout) + sdPsum_mma = storage.sdPsum.get_tensor( + cute.make_layout(shape=(self.tile_m, self.tile_n, self.dpsum_stage), stride=(0, 1, 0)) ) sdV = storage.sdO.get_tensor( @@ -900,24 +889,23 @@ def kernel( mK, mV, mLSE, - mPsum, + mdPsum, mdO, sQ, sK, sV, sLSE_load, - sPsum_load, + sdPsum_load, sdO, tma_atom_Q, tma_atom_K, tma_atom_V, - tma_atom_Psum, tma_atom_dO, pipeline_q, lse_full_mbar_ptr, lse_empty_mbar_ptr, - psum_full_mbar_ptr, - psum_empty_mbar_ptr, + dpsum_full_mbar_ptr, + dpsum_empty_mbar_ptr, pipeline_do, k_full_mbar_ptr, v_full_mbar_ptr, @@ -997,7 +985,7 @@ def kernel( thr_mma_dsq, tStS, sLSE_mma, - sPsum_mma, + sdPsum_mma, tdVtdV, tdKtdK, mdV, @@ -1007,8 +995,8 @@ def kernel( tdPtdP, lse_full_mbar_ptr, lse_empty_mbar_ptr, - psum_full_mbar_ptr, - psum_empty_mbar_ptr, + dpsum_full_mbar_ptr, + dpsum_empty_mbar_ptr, pipeline_s, pipeline_p, pipeline_dS, @@ -1063,24 +1051,23 @@ def load( mK: cute.Tensor, mV: cute.Tensor, mLSE: cute.Tensor, - mPsum: cute.Tensor, + mdPsum: cute.Tensor, mdO: cute.Tensor, sQ: cute.Tensor, sK: cute.Tensor, sV: cute.Tensor, sLSE: cute.Tensor, - sPsum: cute.Tensor, + sdPsum: cute.Tensor, sdO: cute.Tensor, tma_atom_Q: cute.CopyAtom, tma_atom_K: cute.CopyAtom, tma_atom_V: cute.CopyAtom, - tma_atom_Psum: cute.CopyAtom, tma_atom_dO: cute.CopyAtom, pipeline_q: PipelineAsync, lse_full_mbar_ptr: cute.Pointer, lse_empty_mbar_ptr: cute.Pointer, - psum_full_mbar_ptr: cute.Pointer, - psum_empty_mbar_ptr: cute.Pointer, + dpsum_full_mbar_ptr: cute.Pointer, + dpsum_empty_mbar_ptr: cute.Pointer, pipeline_do: PipelineAsync, k_full_mbar_ptr: cute.Pointer, v_full_mbar_ptr: cute.Pointer, @@ -1111,7 +1098,7 @@ def load( mV_cur = mV[None, None, head_idx_kv, batch_idx] mdO_cur = mdO[None, None, head_idx, batch_idx] mLSE_cur = mLSE[None, head_idx, batch_idx] - mPsum_cur = mPsum[None, head_idx, batch_idx] + mPsum_cur = mdPsum[None, head_idx, batch_idx] gK = cute.local_tile(mK_cur, cute.select(self.mma_tiler_kq, mode=[0, 2]), (n_block, 0)) tSgK = thr_mma_kq.partition_A(gK) @@ -1123,7 +1110,7 @@ def load( tSgQ = thr_mma_kq.partition_B(gQ) gLSE = cute.local_tile(mLSE_cur, (self.tile_n,), (None,)) - gPsum = cute.local_tile(mPsum_cur, (self.tile_n,), (None,)) + gdPsum = cute.local_tile(mPsum_cur, (self.tile_n,), (None,)) gdO = cute.local_tile(mdO_cur, cute.select(self.mma_tiler_pdo, mode=[1, 2]), (0, None)) tdVgdO = thr_mma_pdo.partition_B(gdO) @@ -1157,13 +1144,8 @@ def load( cute.group_modes(tdVgdO, 0, 3), ) load_LSE = copy_utils.cpasync_bulk_get_copy_fn(gLSE, sLSE) - tPsumsPsum, tPsumgPsum = cpasync.tma_partition( - tma_atom_Psum, - 0, - cute.make_layout(1), - sPsum, - gPsum, - ) + load_dPsum = copy_utils.cpasync_bulk_get_copy_fn(gdPsum, sdPsum) + # K with cute.arch.elect_one(): cute.arch.mbarrier_arrive_and_expect_tx(k_full_mbar_ptr, self.tma_copy_bytes["K"]) @@ -1204,20 +1186,15 @@ def load( pipeline_do.producer_commit(do_producer_state) do_producer_state.advance() - # Psum + # dPsum with cute.arch.elect_one(): cute.arch.mbarrier_arrive_and_expect_tx( - psum_full_mbar_ptr, self.tma_copy_bytes["dPsum"] + dpsum_full_mbar_ptr, self.tma_copy_bytes["dPsum"] ) + load_dPsum(src_idx=m_block_max - 1, dst_idx=0, tma_bar_ptr=dpsum_full_mbar_ptr) - cute.copy( - tma_atom_Psum, - tPsumgPsum[None, m_block_max - 1], - tPsumsPsum[None, 0], - tma_bar_ptr=psum_full_mbar_ptr, - ) lse_empty_consumer_phase = cute.Int32(0) - psum_empty_consumer_phase = cute.Int32(0) + dpsum_empty_consumer_phase = cute.Int32(0) for i in cutlass.range(m_block_max - m_block_min - 1, unroll=1): m_block = m_block_max - 2 - i @@ -1232,7 +1209,6 @@ def load( # LSE cute.arch.mbarrier_wait(lse_empty_mbar_ptr, lse_empty_consumer_phase) lse_empty_consumer_phase ^= 1 - with cute.arch.elect_one(): cute.arch.mbarrier_arrive_and_expect_tx( lse_full_mbar_ptr, self.tma_copy_bytes["LSE"] @@ -1251,21 +1227,14 @@ def load( pipeline_do.producer_commit(do_producer_state) do_producer_state.advance() - # Psum - cute.arch.mbarrier_wait(psum_empty_mbar_ptr, psum_empty_consumer_phase) - psum_empty_consumer_phase ^= 1 - + # dPsum + cute.arch.mbarrier_wait(dpsum_empty_mbar_ptr, dpsum_empty_consumer_phase) + dpsum_empty_consumer_phase ^= 1 with cute.arch.elect_one(): cute.arch.mbarrier_arrive_and_expect_tx( - psum_full_mbar_ptr, self.tma_copy_bytes["dPsum"] + dpsum_full_mbar_ptr, self.tma_copy_bytes["dPsum"] ) - - cute.copy( - tma_atom_Psum, - tPsumgPsum[None, m_block], - tPsumsPsum[None, 0], - tma_bar_ptr=psum_full_mbar_ptr, - ) + load_dPsum(src_idx=m_block, dst_idx=0, tma_bar_ptr=dpsum_full_mbar_ptr) pipeline_q.producer_tail(q_producer_state) pipeline_do.producer_tail(do_producer_state) @@ -1669,8 +1638,8 @@ def compute_loop( tdPtdP: cute.Tensor, lse_full_mbar_ptr: cute.Pointer, lse_empty_mbar_ptr: cute.Pointer, - psum_full_mbar_ptr: cute.Pointer, - psum_empty_mbar_ptr: cute.Pointer, + dpsum_full_mbar_ptr: cute.Pointer, + dpsum_empty_mbar_ptr: cute.Pointer, pipeline_s: PipelineAsync, pipeline_p: PipelineAsync, pipeline_dS: PipelineAsync, @@ -1890,7 +1859,7 @@ def compute_loop( # dS.T = P.T * (dP.T - D) # --------------------------------------------- if warp_idx == self.compute_warp_ids[0]: - cute.arch.mbarrier_wait(psum_full_mbar_ptr, psum_consumer_phase) + cute.arch.mbarrier_wait(dpsum_full_mbar_ptr, psum_consumer_phase) psum_consumer_phase ^= 1 pipeline_dP.consumer_wait(dP_consumer_state) @@ -1989,7 +1958,7 @@ def compute_loop( if warp_idx == self.compute_warp_ids[0]: with cute.arch.elect_one(): - cute.arch.mbarrier_arrive(psum_empty_mbar_ptr) + cute.arch.mbarrier_arrive(dpsum_empty_mbar_ptr) if const_expr(not self.use_tma_store): self.epilogue_dKV( From 372f3e2ba78cb984f8296e7b2b2cec25e330eca6 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 19 Oct 2025 16:33:00 -0400 Subject: [PATCH 162/258] [Cute,Bwd,Sm100] Use copy_utils functions to load Q & dO --- flash_attn/cute/flash_bwd_sm100.py | 119 +++++++++++------------------ 1 file changed, 43 insertions(+), 76 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 867a48b6c9f..5572845a884 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -673,7 +673,7 @@ def kernel( cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) ) - pipeline_q = cutlass.pipeline.PipelineTmaUmma.create( + pipeline_Q = cutlass.pipeline.PipelineTmaUmma.create( barrier_storage=storage.q_mbar_ptr.data_ptr(), num_stages=self.q_stage, producer_group=pipeline_producer_group, @@ -681,7 +681,7 @@ def kernel( tx_count=self.tma_copy_bytes["Q"], ) - pipeline_do = cutlass.pipeline.PipelineTmaUmma.create( + pipeline_dO = cutlass.pipeline.PipelineTmaUmma.create( barrier_storage=storage.do_mbar_ptr.data_ptr(), num_stages=self.do_stage, producer_group=pipeline_producer_group, @@ -901,12 +901,12 @@ def kernel( tma_atom_K, tma_atom_V, tma_atom_dO, - pipeline_q, + pipeline_Q, lse_full_mbar_ptr, lse_empty_mbar_ptr, dpsum_full_mbar_ptr, dpsum_empty_mbar_ptr, - pipeline_do, + pipeline_dO, k_full_mbar_ptr, v_full_mbar_ptr, block_info, @@ -950,8 +950,8 @@ def kernel( tdKtdK, tdPtdP, tdQtdQ, - pipeline_q, - pipeline_do, + pipeline_Q, + pipeline_dO, pipeline_s, pipeline_p, pipeline_dS, @@ -1063,25 +1063,22 @@ def load( tma_atom_K: cute.CopyAtom, tma_atom_V: cute.CopyAtom, tma_atom_dO: cute.CopyAtom, - pipeline_q: PipelineAsync, + pipeline_Q: PipelineAsync, lse_full_mbar_ptr: cute.Pointer, lse_empty_mbar_ptr: cute.Pointer, dpsum_full_mbar_ptr: cute.Pointer, dpsum_empty_mbar_ptr: cute.Pointer, - pipeline_do: PipelineAsync, + pipeline_dO: PipelineAsync, k_full_mbar_ptr: cute.Pointer, v_full_mbar_ptr: cute.Pointer, block_info: BlockInfo, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, ): - warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - tidx = cute.arch.thread_idx()[0] - - q_producer_state = cutlass.pipeline.make_pipeline_state( + producer_state_Q = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Producer, self.q_stage ) - do_producer_state = cutlass.pipeline.make_pipeline_state( + producer_state_dO = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Producer, self.do_stage ) @@ -1089,7 +1086,6 @@ def load( work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: n_block, head_idx, batch_idx = work_tile.tile_idx - seqlen = SeqlenInfoCls(batch_idx) m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) head_idx_kv = head_idx // self.qhead_per_kvhead @@ -1129,20 +1125,12 @@ def load( cute.group_modes(sV, 0, 3), cute.group_modes(tdPgV, 0, 3), ) - tQsQ, tQgQ = cpasync.tma_partition( - tma_atom_Q, - 0, # no multicast - cute.make_layout(1), - cute.group_modes(sQ, 0, 3), - cute.group_modes(tSgQ, 0, 3), - ) - tdOsdO, tdOgdO = cpasync.tma_partition( - tma_atom_dO, - 0, # no multicast - cute.make_layout(1), - cute.group_modes(sdO, 0, 3), - cute.group_modes(tdVgdO, 0, 3), + load_Q, _, _ = copy_utils.tma_get_copy_fn(tma_atom_Q, 0, cute.make_layout(1), tSgQ, sQ) + load_Q = copy_utils.tma_producer_copy_fn(load_Q, pipeline_Q) + load_dO, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_dO, 0, cute.make_layout(1), tdVgdO, sdO ) + load_dO = copy_utils.tma_producer_copy_fn(load_dO, pipeline_dO) load_LSE = copy_utils.cpasync_bulk_get_copy_fn(gLSE, sLSE) load_dPsum = copy_utils.cpasync_bulk_get_copy_fn(gdPsum, sdPsum) @@ -1153,15 +1141,10 @@ def load( ###### Prologue # Q0 - pipeline_q.producer_acquire(q_producer_state) - cute.copy( - tma_atom_Q, - tQgQ[None, m_block_max - 1], - tQsQ[None, q_producer_state.index], - tma_bar_ptr=pipeline_q.producer_get_barrier(q_producer_state), - ) - pipeline_q.producer_commit(q_producer_state) - q_producer_state.advance() + pipeline_Q.producer_acquire(producer_state_Q) + load_Q(m_block_max - 1, producer_state=producer_state_Q) + pipeline_Q.producer_commit(producer_state_Q) + producer_state_Q.advance() # LSE with cute.arch.elect_one(): @@ -1176,15 +1159,10 @@ def load( cute.copy(tma_atom_V, tVgV, tVsV[None, 0], tma_bar_ptr=v_full_mbar_ptr) # dO - pipeline_do.producer_acquire(do_producer_state) - cute.copy( - tma_atom_dO, - tdOgdO[None, m_block_max - 1], - tdOsdO[None, do_producer_state.index], - tma_bar_ptr=pipeline_do.producer_get_barrier(do_producer_state), - ) - pipeline_do.producer_commit(do_producer_state) - do_producer_state.advance() + pipeline_dO.producer_acquire(producer_state_dO) + load_dO(m_block_max - 1, producer_state=producer_state_dO) + pipeline_dO.producer_commit(producer_state_dO) + producer_state_dO.advance() # dPsum with cute.arch.elect_one(): @@ -1198,14 +1176,11 @@ def load( for i in cutlass.range(m_block_max - m_block_min - 1, unroll=1): m_block = m_block_max - 2 - i - # Q - self.load_M_tile( - tma_atom_Q, tQgQ, tQsQ, pipeline_q, m_block, producer_state=q_producer_state - ) - pipeline_q.producer_commit(q_producer_state) - q_producer_state.advance() - + pipeline_Q.producer_acquire(producer_state_Q) + load_Q(m_block, producer_state=producer_state_Q) + pipeline_Q.producer_commit(producer_state_Q) + producer_state_Q.advance() # LSE cute.arch.mbarrier_wait(lse_empty_mbar_ptr, lse_empty_consumer_phase) lse_empty_consumer_phase ^= 1 @@ -1214,19 +1189,11 @@ def load( lse_full_mbar_ptr, self.tma_copy_bytes["LSE"] ) load_LSE(src_idx=m_block, dst_idx=0, tma_bar_ptr=lse_full_mbar_ptr) - # dO - self.load_M_tile( - tma_atom_dO, - tdOgdO, - tdOsdO, - pipeline_do, - m_block, - producer_state=do_producer_state, - ) - pipeline_do.producer_commit(do_producer_state) - do_producer_state.advance() - + pipeline_dO.producer_acquire(producer_state_dO) + load_dO(m_block, producer_state=producer_state_dO) + pipeline_dO.producer_commit(producer_state_dO) + producer_state_dO.advance() # dPsum cute.arch.mbarrier_wait(dpsum_empty_mbar_ptr, dpsum_empty_consumer_phase) dpsum_empty_consumer_phase ^= 1 @@ -1236,8 +1203,8 @@ def load( ) load_dPsum(src_idx=m_block, dst_idx=0, tma_bar_ptr=dpsum_full_mbar_ptr) - pipeline_q.producer_tail(q_producer_state) - pipeline_do.producer_tail(do_producer_state) + pipeline_Q.producer_tail(producer_state_Q) + pipeline_dO.producer_tail(producer_state_dO) tile_scheduler.prefetch_next_work() tile_scheduler.advance_to_next_work() @@ -1271,8 +1238,8 @@ def mma( tdKtdK: cute.Tensor, tdPtdP: cute.Tensor, tdQacctdQacc: cute.Tensor, - pipeline_q: PipelineAsync, - pipeline_do: PipelineAsync, + pipeline_Q: PipelineAsync, + pipeline_dO: PipelineAsync, pipeline_s: PipelineAsync, pipeline_p: PipelineAsync, pipeline_dS: PipelineAsync, @@ -1371,7 +1338,7 @@ def mma( # 3. dV = P @ dO # 1) S = Q0 @ K.T - pipeline_q.consumer_wait(q_consumer_state) + pipeline_Q.consumer_wait(q_consumer_state) pipeline_s.producer_acquire(s_producer_state) num_k_phases = cute.size(tSrK, mode=[2]) @@ -1390,7 +1357,7 @@ def mma( s_producer_state.advance() # 2) dP = V @ dO.T - pipeline_do.consumer_wait(do_consumer_state) + pipeline_dO.consumer_wait(do_consumer_state) pipeline_dP.producer_acquire(dP_producer_state) pipeline_dQaccum.producer_acquire(dQaccum_producer_state) @@ -1422,7 +1389,7 @@ def mma( ) pipeline_p.consumer_release(p_consumer_state) p_consumer_state.advance() - pipeline_do.consumer_release(do_consumer_state) + pipeline_dO.consumer_release(do_consumer_state) do_consumer_state.advance() # ----------------------------------------------------------- ###### MAIN LOOP @@ -1435,7 +1402,7 @@ def mma( for i in cutlass.range(m_block_max - m_block_min - 1, unroll=1): # 1) S = K @ Q_i - pipeline_q.consumer_wait(q_consumer_state) + pipeline_Q.consumer_wait(q_consumer_state) pipeline_s.producer_acquire(s_producer_state) #''' for kphase_idx in cutlass.range_constexpr(num_k_phases, unroll=1): @@ -1482,13 +1449,13 @@ def mma( ) accumulate_dK = True - pipeline_q.consumer_release(q_dk_consumer_state) + pipeline_Q.consumer_release(q_dk_consumer_state) q_dk_consumer_state.advance() pipeline_dS.consumer_release(dS_consumer_state) dS_consumer_state.advance() # 4) dP = V @ dO.T - pipeline_do.consumer_wait(do_consumer_state) + pipeline_dO.consumer_wait(do_consumer_state) pipeline_dQaccum.producer_acquire(dQaccum_producer_state) @@ -1520,7 +1487,7 @@ def mma( pipeline_p.consumer_release(p_consumer_state) p_consumer_state.advance() - pipeline_do.consumer_release(do_consumer_state) + pipeline_dO.consumer_release(do_consumer_state) do_consumer_state.advance() pipeline_dV.producer_acquire(dV_producer_state) @@ -1566,7 +1533,7 @@ def mma( ) pipeline_dQaccum.producer_commit(dQaccum_producer_state) dQaccum_producer_state.advance() - pipeline_q.consumer_release(q_dk_consumer_state) + pipeline_Q.consumer_release(q_dk_consumer_state) q_dk_consumer_state.advance() pipeline_dS.consumer_release(dS_consumer_state) dS_consumer_state.advance() From c0c8c2df3e0c2187486c2390595abfab58379770 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 19 Oct 2025 16:48:21 -0400 Subject: [PATCH 163/258] [Cute,Bwd,Sm100] Load K & Q, V & dO in the first iteration --- flash_attn/cute/flash_bwd_sm100.py | 88 +++++++----------------------- 1 file changed, 19 insertions(+), 69 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 5572845a884..eb754048e08 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -15,6 +15,7 @@ from flash_attn.cute import utils from flash_attn.cute import copy_utils +from flash_attn.cute import pipeline from flash_attn.cute.mask import AttentionMask from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo @@ -471,8 +472,6 @@ def __call__( @cute.struct class SharedStorage: q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.q_stage] - k_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.k_stage] - v_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.v_stage] lse_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.lse_stage] do_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.do_stage] lse_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.k_stage] @@ -645,8 +644,6 @@ def kernel( smem = cutlass.utils.SmemAllocator() storage = smem.allocate(self.shared_storage) - k_full_mbar_ptr = storage.k_full_mbar_ptr.data_ptr() - v_full_mbar_ptr = storage.v_full_mbar_ptr.data_ptr() tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr.data_ptr() lse_full_mbar_ptr = storage.lse_full_mbar_ptr.data_ptr() lse_empty_mbar_ptr = storage.lse_empty_mbar_ptr.data_ptr() @@ -655,8 +652,6 @@ def kernel( dQaccum_reduce_mbar_ptr = storage.dQaccum_reduce_mbar_ptr.data_ptr() if warp_idx == self.load_warp_id: - cute.arch.mbarrier_init(k_full_mbar_ptr, len([self.load_warp_id])) - cute.arch.mbarrier_init(v_full_mbar_ptr, len([self.load_warp_id])) cute.arch.mbarrier_init( tmem_dealloc_mbar_ptr, cute.arch.WARP_SIZE * len(self.compute_warp_ids) ) @@ -673,20 +668,22 @@ def kernel( cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) ) - pipeline_Q = cutlass.pipeline.PipelineTmaUmma.create( + pipeline_Q = pipeline.PipelineTmaUmma.create( barrier_storage=storage.q_mbar_ptr.data_ptr(), num_stages=self.q_stage, producer_group=pipeline_producer_group, consumer_group=pipeline_consumer_group, tx_count=self.tma_copy_bytes["Q"], + init_wait=False, ) - pipeline_dO = cutlass.pipeline.PipelineTmaUmma.create( + pipeline_dO = pipeline.PipelineTmaUmma.create( barrier_storage=storage.do_mbar_ptr.data_ptr(), num_stages=self.do_stage, producer_group=pipeline_producer_group, consumer_group=pipeline_consumer_group, tx_count=self.tma_copy_bytes["dO"], + init_wait=False, ) # UMMA producers and AsyncThread consumers @@ -907,8 +904,6 @@ def kernel( dpsum_full_mbar_ptr, dpsum_empty_mbar_ptr, pipeline_dO, - k_full_mbar_ptr, - v_full_mbar_ptr, block_info, SeqlenInfoCls, TileSchedulerCls, @@ -959,8 +954,6 @@ def kernel( pipeline_dK, pipeline_dP, pipeline_dQaccum, - k_full_mbar_ptr, - v_full_mbar_ptr, block_info, SeqlenInfoCls, TileSchedulerCls, @@ -1069,8 +1062,6 @@ def load( dpsum_full_mbar_ptr: cute.Pointer, dpsum_empty_mbar_ptr: cute.Pointer, pipeline_dO: PipelineAsync, - k_full_mbar_ptr: cute.Pointer, - v_full_mbar_ptr: cute.Pointer, block_info: BlockInfo, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, @@ -1111,19 +1102,16 @@ def load( gdO = cute.local_tile(mdO_cur, cute.select(self.mma_tiler_pdo, mode=[1, 2]), (0, None)) tdVgdO = thr_mma_pdo.partition_B(gdO) - tKsK, tKgK = cpasync.tma_partition( - tma_atom_K, - 0, # no multicast - cute.make_layout(1), - cute.group_modes(sK, 0, 3), - cute.group_modes(tSgK, 0, 3), + load_K, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_K, 0, cute.make_layout(1), tSgK, sK[None, None, None, 0], single_stage=True ) - tVsV, tVgV = cpasync.tma_partition( + load_V, _, _ = copy_utils.tma_get_copy_fn( tma_atom_V, - 0, # no multicast + 0, cute.make_layout(1), - cute.group_modes(sV, 0, 3), - cute.group_modes(tdPgV, 0, 3), + tdPgV, + sV[None, None, None, 0], + single_stage=True, ) load_Q, _, _ = copy_utils.tma_get_copy_fn(tma_atom_Q, 0, cute.make_layout(1), tSgQ, sQ) load_Q = copy_utils.tma_producer_copy_fn(load_Q, pipeline_Q) @@ -1134,36 +1122,25 @@ def load( load_LSE = copy_utils.cpasync_bulk_get_copy_fn(gLSE, sLSE) load_dPsum = copy_utils.cpasync_bulk_get_copy_fn(gdPsum, sdPsum) - # K - with cute.arch.elect_one(): - cute.arch.mbarrier_arrive_and_expect_tx(k_full_mbar_ptr, self.tma_copy_bytes["K"]) - cute.copy(tma_atom_K, tKgK, tKsK[None, 0], tma_bar_ptr=k_full_mbar_ptr) - - ###### Prologue - # Q0 - pipeline_Q.producer_acquire(producer_state_Q) + # First iteration: load K together w Q & LSE, then V together w dO & dPsum + # K & Q + pipeline_Q.producer_acquire(producer_state_Q, extra_tx_count=self.tma_copy_bytes["K"]) + load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q)) load_Q(m_block_max - 1, producer_state=producer_state_Q) pipeline_Q.producer_commit(producer_state_Q) producer_state_Q.advance() - # LSE with cute.arch.elect_one(): cute.arch.mbarrier_arrive_and_expect_tx( lse_full_mbar_ptr, self.tma_copy_bytes["LSE"] ) load_LSE(src_idx=m_block_max - 1, dst_idx=0, tma_bar_ptr=lse_full_mbar_ptr) - - # V - with cute.arch.elect_one(): - cute.arch.mbarrier_arrive_and_expect_tx(v_full_mbar_ptr, self.tma_copy_bytes["V"]) - cute.copy(tma_atom_V, tVgV, tVsV[None, 0], tma_bar_ptr=v_full_mbar_ptr) - - # dO - pipeline_dO.producer_acquire(producer_state_dO) + # V & dO + pipeline_dO.producer_acquire(producer_state_dO, extra_tx_count=self.tma_copy_bytes["V"]) + load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO)) load_dO(m_block_max - 1, producer_state=producer_state_dO) pipeline_dO.producer_commit(producer_state_dO) producer_state_dO.advance() - # dPsum with cute.arch.elect_one(): cute.arch.mbarrier_arrive_and_expect_tx( @@ -1247,15 +1224,10 @@ def mma( pipeline_dK: PipelineAsync, pipeline_dP: PipelineAsync, pipeline_dQaccum: PipelineAsync, - full_key_mbar_ptr: cute.Pointer, - full_value_mbar_ptr: cute.Pointer, block_info: BlockInfo, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, ): - warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - key_consumer_phase = cutlass.Int32(0) - q_consumer_state = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.q_stage ) @@ -1294,10 +1266,6 @@ def mma( seqlen = SeqlenInfoCls(batch_idx) # must be seqlen_k m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) - cute.arch.mbarrier_wait(full_key_mbar_ptr, phase=key_consumer_phase) - cute.arch.mbarrier_wait(full_value_mbar_ptr, phase=key_consumer_phase) - - key_consumer_phase ^= 1 # S = K @ Q.T sK and sQ tSrK = thr_mma_kq.make_fragment_A(sK) @@ -2460,21 +2428,3 @@ def epilogue_dK_or_dV_tma( pipeline.consumer_release(consumer_state) consumer_state.advance() - - @cute.jit - def load_M_tile( - self, - tma_atom: cute.CopyAtom, - tQgQ: cute.Tensor, - tQsQ: cute.Tensor, - pipeline: PipelineAsync, - block: cutlass.Int32, - producer_state: cutlass.pipeline.PipelineState, - ): - pipeline.producer_acquire(producer_state) - cute.copy( - tma_atom, - tQgQ[None, block], - tQsQ[None, producer_state.index], - tma_bar_ptr=pipeline.producer_get_barrier(producer_state), - ) From 7b17cd8b693661097d5586358db63d5607e0efea Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 19 Oct 2025 17:44:12 -0400 Subject: [PATCH 164/258] [Cute,Bwd,Sm100] Simplify mma by using functools.partial --- flash_attn/cute/blackwell_helpers.py | 261 ++++++++------- flash_attn/cute/flash_bwd_sm100.py | 455 ++++++++++----------------- 2 files changed, 309 insertions(+), 407 deletions(-) diff --git a/flash_attn/cute/blackwell_helpers.py b/flash_attn/cute/blackwell_helpers.py index 4f61a40cdc3..aefb6182575 100644 --- a/flash_attn/cute/blackwell_helpers.py +++ b/flash_attn/cute/blackwell_helpers.py @@ -1,7 +1,9 @@ # Copyright (c) 2025, Tri Dao. from typing import Optional, Tuple + import cutlass import cutlass.cute as cute +from cutlass import Int32, Boolean, const_expr from cutlass.cute.nvgpu import tcgen05 from cutlass._mlir.dialects import llvm @@ -9,13 +11,37 @@ from flash_attn.cute.utils import parse_swizzle_from_pointer +@cute.jit +def gemm_w_idx( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + A_idx: Optional[Int32] = None, + B_idx: Optional[Int32] = None, + zero_init: bool | Boolean = False, + swap_AB: bool = False, +) -> None: + if const_expr(swap_AB): + return gemm_w_idx( + tiled_mma, acc, tCrB, tCrA, B_idx, A_idx, zero_init=zero_init, swap_AB=False + ) + else: + rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] + rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] + mma_atom = cute.make_mma_atom(tiled_mma.op) + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): + mma_atom.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0) + cute.gemm(mma_atom, acc, rA[None, None, k], rB[None, None, k], acc) + + @cute.jit def gemm( tiled_mma: cute.TiledMma, acc: cute.Tensor, tCrA: cute.Tensor, tCrB: cute.Tensor, - zero_init: bool | cutlass.Boolean = False, + zero_init: bool | Boolean = False, ) -> cute.TiledMma: for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): tiled_mma.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0) @@ -36,56 +62,56 @@ def gemm_ptx( tCrB: cute.Tensor, sA: Optional[cute.Tensor], sB: cute.Tensor, - zero_init: bool | cutlass.Boolean = False, + zero_init: bool | Boolean = False, ) -> None: is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM - if cutlass.const_expr(not is_ts): + if const_expr(not is_ts): assert sA is not None, "sA must be provided when a_src is not TMEM" sA_layout = sA.layout if sA is not None else None sB_layout = sB.layout - idesc: int = cutlass.const_expr(sm100_desc.mma_op_to_idesc(op)) - if cutlass.const_expr(not is_ts): + idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + if const_expr(not is_ts): sA_swizzle = parse_swizzle_from_pointer(sA.iterator) - smem_desc_base_a: int = cutlass.const_expr( + smem_desc_base_a: int = const_expr( sm100_desc.make_smem_desc_base( cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), sA_swizzle, sm100_desc.Major.K - if cutlass.const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN, ) ) smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) - smem_desc_base_a_lo = cutlass.const_expr(smem_desc_base_a_lo) - smem_desc_a_hi = cutlass.const_expr(smem_desc_a_hi) + smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = const_expr(smem_desc_a_hi) else: smem_desc_base_a = None smem_desc_base_a_lo, smem_desc_a_hi = None, None sB_swizzle = parse_swizzle_from_pointer(sB.iterator) - smem_desc_base_b: int = cutlass.const_expr( + smem_desc_base_b: int = const_expr( sm100_desc.make_smem_desc_base( cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), sB_swizzle, sm100_desc.Major.K - if cutlass.const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN, ) ) smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) - smem_desc_base_b_lo = cutlass.const_expr(smem_desc_base_b_lo) - smem_desc_b_hi = cutlass.const_expr(smem_desc_b_hi) + smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = const_expr(smem_desc_b_hi) - if cutlass.const_expr(not is_ts): - smem_desc_start_a_lo = cutlass.Int32( - smem_desc_base_a_lo - ) | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator) + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32(smem_desc_base_a_lo) | sm100_desc.make_smem_desc_start_addr( + sA[None, None, 0].iterator + ) else: smem_desc_start_a_lo = None - smem_desc_start_b_lo = cutlass.Int32( - smem_desc_base_b_lo - ) | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator) + smem_desc_start_b_lo = Int32(smem_desc_base_b_lo) | sm100_desc.make_smem_desc_start_addr( + sB[None, None, 0].iterator + ) for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): - if cutlass.const_expr(not is_ts): + if const_expr(not is_ts): smem_desc_a_lo = smem_desc_start_a_lo + ( (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4 ) @@ -96,14 +122,14 @@ def gemm_ptx( # cute.printf("smem_desc_a_lo = {}, smem_desc_b_lo = {}", smem_desc_a_lo, smem_desc_b_lo) # cute.printf("smem_desc_a_lo_correct = {}, smem_desc_b_lo_correct = {}", smem_desc_a_lo_correct, smem_desc_b_lo_correct) with cute.arch.elect_one(): - if cutlass.const_expr(not is_ts): + if const_expr(not is_ts): llvm.inline_asm( None, [ acc.iterator.toint().ir_value(), smem_desc_a_lo.ir_value(), smem_desc_b_lo.ir_value(), - cutlass.Int32(not zero_init or k != 0).ir_value(), + Int32(not zero_init or k != 0).ir_value(), ], "{\n\t" ".reg .pred p;\n\t" @@ -127,7 +153,7 @@ def gemm_ptx( acc.iterator.toint().ir_value(), tCrA[None, None, k].iterator.toint().ir_value(), smem_desc_b_lo.ir_value(), - cutlass.Int32(not zero_init or k != 0).ir_value(), + Int32(not zero_init or k != 0).ir_value(), ], "{\n\t" ".reg .pred p;\n\t" @@ -151,46 +177,46 @@ def gemm_ptx_loop( tCrB: cute.Tensor, sA: Optional[cute.Tensor], sB: cute.Tensor, - zero_init: bool | cutlass.Boolean = False, + zero_init: bool | Boolean = False, ) -> None: is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM - if cutlass.const_expr(not is_ts): + if const_expr(not is_ts): assert sA is not None, "sA must be provided when a_src is not TMEM" sA_layout = sA.layout if sA is not None else tCrA.layout sB_layout = sB.layout - idesc: int = cutlass.const_expr(sm100_desc.mma_op_to_idesc(op)) - if cutlass.const_expr(not is_ts): + idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + if const_expr(not is_ts): sA_swizzle = parse_swizzle_from_pointer(sA.iterator) - smem_desc_base_a: int = cutlass.const_expr( + smem_desc_base_a: int = const_expr( sm100_desc.make_smem_desc_base( cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), sA_swizzle, sm100_desc.Major.K - if cutlass.const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN, ) ) smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) - smem_desc_base_a_lo = cutlass.const_expr(smem_desc_base_a_lo) - smem_desc_a_hi = cutlass.const_expr(smem_desc_a_hi) + smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = const_expr(smem_desc_a_hi) else: smem_desc_base_a = None smem_desc_base_a_lo, smem_desc_a_hi = None, None sB_swizzle = parse_swizzle_from_pointer(sB.iterator) - smem_desc_base_b: int = cutlass.const_expr( + smem_desc_base_b: int = const_expr( sm100_desc.make_smem_desc_base( cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), sB_swizzle, sm100_desc.Major.K - if cutlass.const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN, ) ) smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) - smem_desc_base_b_lo = cutlass.const_expr(smem_desc_base_b_lo) - smem_desc_b_hi = cutlass.const_expr(smem_desc_b_hi) + smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = const_expr(smem_desc_b_hi) - if cutlass.const_expr(not is_ts): + if const_expr(not is_ts): offset_a = [ (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4 for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])) @@ -211,24 +237,24 @@ def gemm_ptx_loop( offset_b[k] - offset_b[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrB.shape[2])) ] - if cutlass.const_expr(not is_ts): - smem_desc_start_a_lo = cutlass.Int32( + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32( smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator) ) else: smem_desc_start_a_lo = None - smem_desc_start_b_lo = cutlass.Int32( + smem_desc_start_b_lo = Int32( smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator) ) - pred_str = "p" if isinstance(zero_init, cutlass.Boolean) else "0" if zero_init else "1" - if cutlass.const_expr(not is_ts): + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): llvm.inline_asm( None, [ acc.iterator.toint().ir_value(), - cutlass.Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), - cutlass.Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), - cutlass.Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), ], "{\n\t" ".reg .pred leader_thread;\n\t" @@ -268,9 +294,9 @@ def gemm_ptx_loop( None, [ acc.iterator.toint().ir_value(), - cutlass.Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), - cutlass.Int32(smem_desc_start_b_lo).ir_value(), - cutlass.Int32(not zero_init).ir_value(), + Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + Int32(smem_desc_start_b_lo).ir_value(), + Int32(not zero_init).ir_value(), ], "{\n\t" ".reg .pred leader_thread;\n\t" @@ -315,49 +341,49 @@ def gemm_ptx_partial( sA: Optional[cute.Tensor], sB: cute.Tensor, mbar_ptr: Optional[cutlass.Pointer] = None, - mbar_phase: Optional[cutlass.Int32] = None, - zero_init: bool | cutlass.Boolean = False, + mbar_phase: Optional[Int32] = None, + zero_init: bool | Boolean = False, ) -> None: is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM - if cutlass.const_expr(not is_ts): + if const_expr(not is_ts): assert sA is not None, "sA must be provided when a_src is not TMEM" sA_layout = sA.layout if sA is not None else tCrA.layout sB_layout = sB.layout - idesc: int = cutlass.const_expr(sm100_desc.mma_op_to_idesc(op)) - if cutlass.const_expr(not is_ts): + idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + if const_expr(not is_ts): sA_swizzle = parse_swizzle_from_pointer(sA.iterator) - smem_desc_base_a: int = cutlass.const_expr( + smem_desc_base_a: int = const_expr( sm100_desc.make_smem_desc_base( cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), sA_swizzle, sm100_desc.Major.K - if cutlass.const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN, ) ) smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) - smem_desc_base_a_lo = cutlass.const_expr(smem_desc_base_a_lo) - smem_desc_a_hi = cutlass.const_expr(smem_desc_a_hi) + smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = const_expr(smem_desc_a_hi) else: smem_desc_base_a = None smem_desc_base_a_lo, smem_desc_a_hi = None, None sB_swizzle = parse_swizzle_from_pointer(sB.iterator) - smem_desc_base_b: int = cutlass.const_expr( + smem_desc_base_b: int = const_expr( sm100_desc.make_smem_desc_base( cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), sB_swizzle, sm100_desc.Major.K - if cutlass.const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN, ) ) smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) - smem_desc_base_b_lo = cutlass.const_expr(smem_desc_base_b_lo) - smem_desc_b_hi = cutlass.const_expr(smem_desc_b_hi) + smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = const_expr(smem_desc_b_hi) tCrA_layout = ( tCrA.layout - if cutlass.const_expr(not is_ts) + if const_expr(not is_ts) else cute.recast_layout(32, tCrA.element_type.width, tCrA.layout) ) offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(cute.size(tCrA.shape[2]))] @@ -365,25 +391,25 @@ def gemm_ptx_partial( offset_b = [cute.crd2idx((0, 0, k), tCrB.layout) for k in range(cute.size(tCrB.shape[2]))] offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))] - if cutlass.const_expr(not is_ts): - smem_desc_start_a_lo = cutlass.Int32( + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32( smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator) ) else: smem_desc_start_a_lo = None - smem_desc_start_b_lo = cutlass.Int32( + smem_desc_start_b_lo = Int32( smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator) ) - pred_str = "p" if isinstance(zero_init, cutlass.Boolean) else "0" if zero_init else "1" - if cutlass.const_expr(not is_ts): + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): assert mbar_ptr is None, "mbar_ptr must be None when a_src is not TMEM" llvm.inline_asm( None, [ # acc.iterator.toint().ir_value(), - cutlass.Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), - cutlass.Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), - cutlass.Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), ], "{\n\t" ".reg .pred leader_thread;\n\t" @@ -422,16 +448,14 @@ def gemm_ptx_partial( ) else: input_args = [ - cutlass.Int32( - cute.arch.make_warp_uniform(tCrA[None, None, 0].iterator.toint()) - ).ir_value(), - cutlass.Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), - cutlass.Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(tCrA[None, None, 0].iterator.toint())).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), ] - if cutlass.const_expr(mbar_ptr is not None): + if const_expr(mbar_ptr is not None): assert mbar_phase is not None, "mbar_phase must be provided when mbar_ptr is not None" input_args.append(mbar_ptr.toint().ir_value()) - input_args.append(cutlass.Int32(mbar_phase).ir_value()) + input_args.append(Int32(mbar_phase).ir_value()) mbar_wait_str = ( ".reg .pred P1; \n\t" "LAB_WAIT: \n\t" @@ -446,9 +470,9 @@ def gemm_ptx_partial( None, # [ # # acc.iterator.toint().ir_value(), - # cutlass.Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), - # cutlass.Int32(smem_desc_start_b_lo).ir_value(), - # cutlass.Int32(not zero_init).ir_value(), + # Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + # Int32(smem_desc_start_b_lo).ir_value(), + # Int32(not zero_init).ir_value(), # ], input_args, "{\n\t" @@ -480,7 +504,7 @@ def gemm_ptx_partial( for k in range( 1, cute.size(tCrA.shape[2]) - if cutlass.const_expr(mbar_ptr is None) + if const_expr(mbar_ptr is None) else cute.size(tCrA.shape[2]) // 4 * 3, ) ) @@ -494,12 +518,11 @@ def gemm_ptx_partial( ) for k in range(cute.size(tCrA.shape[2]) // 4 * 3, cute.size(tCrA.shape[2])) ) - if cutlass.const_expr(mbar_ptr is not None) + if const_expr(mbar_ptr is not None) else "" ) + "}\n", - # "r,r,r", - "r,r,r" if cutlass.const_expr(mbar_ptr is None) else "r,r,r,r,r", + "r,r,r" if const_expr(mbar_ptr is None) else "r,r,r,r,r", has_side_effects=True, is_align_stack=False, asm_dialect=llvm.AsmDialect.AD_ATT, @@ -512,54 +535,54 @@ def gemm_ptx_partial1( acc_tmem_addr: cutlass.Constexpr[int], tCrA: cute.Tensor, tCrB: cute.Tensor, - sA_base_addr_for_desc: cutlass.Int32, + sA_base_addr_for_desc: Int32, sA_addr_offset_for_desc: cutlass.Constexpr[int], - sA_stage: cutlass.Int32, - sB_base_addr_for_desc: cutlass.Int32, + sA_stage: Int32, + sB_base_addr_for_desc: Int32, sB_addr_offset_for_desc: cutlass.Constexpr[int], - sB_stage: cutlass.Int32, + sB_stage: Int32, sA_layout: Optional[cute.Layout], sB_layout: Optional[cute.Layout], sA_swizzle: Optional[cute.Swizzle], sB_swizzle: cute.Swizzle, - zero_init: bool | cutlass.Boolean = False, + zero_init: bool | Boolean = False, ) -> None: is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM - if cutlass.const_expr(not is_ts): + if const_expr(not is_ts): assert sA_layout is not None, "sA_layout must be provided when a_src is not TMEM" assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM" - idesc: int = cutlass.const_expr(sm100_desc.mma_op_to_idesc(op)) - if cutlass.const_expr(not is_ts): - smem_desc_base_a: int = cutlass.const_expr( + idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + if const_expr(not is_ts): + smem_desc_base_a: int = const_expr( sm100_desc.make_smem_desc_base( cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), sA_swizzle, sm100_desc.Major.K - if cutlass.const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN, ) ) smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) - smem_desc_base_a_lo = cutlass.const_expr(smem_desc_base_a_lo) - smem_desc_a_hi = cutlass.const_expr(smem_desc_a_hi) + smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = const_expr(smem_desc_a_hi) else: smem_desc_base_a = None smem_desc_base_a_lo, smem_desc_a_hi = None, None - smem_desc_base_b: int = cutlass.const_expr( + smem_desc_base_b: int = const_expr( sm100_desc.make_smem_desc_base( cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), sB_swizzle, sm100_desc.Major.K - if cutlass.const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN, ) ) smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) - smem_desc_base_b_lo = cutlass.const_expr(smem_desc_base_b_lo) - smem_desc_b_hi = cutlass.const_expr(smem_desc_b_hi) - mask = [cutlass.Int32(0)] * 4 + smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = const_expr(smem_desc_b_hi) + mask = [Int32(0)] * 4 - if cutlass.const_expr(not is_ts): + if const_expr(not is_ts): offset_a = [ (cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 8) >> 4 for k in range(cute.size(tCrA.shape[2])) @@ -576,26 +599,26 @@ def gemm_ptx_partial1( ] offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))] - if cutlass.const_expr(not is_ts): - # smem_desc_start_a_lo = cutlass.Int32(smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)) - smem_desc_start_a_lo = cutlass.const_expr(smem_desc_base_a_lo) + if const_expr(not is_ts): + # smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)) + smem_desc_start_a_lo = const_expr(smem_desc_base_a_lo) else: smem_desc_start_a_lo = None - # smem_desc_start_b_lo = cutlass.Int32(smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)) - smem_desc_start_b_lo = cutlass.const_expr(smem_desc_base_b_lo) - pred_str = "p" if isinstance(zero_init, cutlass.Boolean) else "0" if zero_init else "1" - if cutlass.const_expr(not is_ts): + # smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)) + smem_desc_start_b_lo = const_expr(smem_desc_base_b_lo) + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): llvm.inline_asm( None, [ # acc.iterator.toint().ir_value(), - # cutlass.Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), - cutlass.Int32(sA_base_addr_for_desc).ir_value(), - cutlass.Int32(sA_stage).ir_value(), - # cutlass.Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), - cutlass.Int32(sB_base_addr_for_desc).ir_value(), - cutlass.Int32(sB_stage).ir_value(), - cutlass.Int32(not zero_init).ir_value(), + # Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + Int32(sA_base_addr_for_desc).ir_value(), + Int32(sA_stage).ir_value(), + # Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(sB_base_addr_for_desc).ir_value(), + Int32(sB_stage).ir_value(), + Int32(not zero_init).ir_value(), mask[0].ir_value(), mask[1].ir_value(), mask[2].ir_value(), @@ -644,9 +667,9 @@ def gemm_ptx_partial1( None, [ # acc.iterator.toint().ir_value(), - cutlass.Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), - cutlass.Int32(smem_desc_start_b_lo).ir_value(), - cutlass.Int32(not zero_init).ir_value(), + Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + Int32(smem_desc_start_b_lo).ir_value(), + Int32(not zero_init).ir_value(), mask[0].ir_value(), mask[1].ir_value(), mask[2].ir_value(), diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index eb754048e08..247dc669b02 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -16,6 +16,7 @@ from flash_attn.cute import utils from flash_attn.cute import copy_utils from flash_attn.cute import pipeline +from flash_attn.cute.blackwell_helpers import gemm_w_idx from flash_attn.cute.mask import AttentionMask from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo @@ -694,7 +695,7 @@ def kernel( cutlass.pipeline.Agent.Thread, cute.arch.WARP_SIZE * len(self.compute_warp_ids) ) - pipeline_s = cutlass.pipeline.PipelineUmmaAsync.create( + pipeline_S = cutlass.pipeline.PipelineUmmaAsync.create( num_stages=self.s_stage, producer_group=pipeline_producer_group_MMA_AsyncThread, consumer_group=pipeline_consumer_group_MMA_AsyncThread, @@ -717,7 +718,7 @@ def kernel( cute.arch.WARP_SIZE * len(self.reduce_warp_ids), alignment=128, ) # Compute - pipeline_dQaccum = cutlass.pipeline.PipelineUmmaAsync.create( + pipeline_dQ = cutlass.pipeline.PipelineUmmaAsync.create( num_stages=self.dQaccum_mma_stage, producer_group=pipeline_producer_group_MMA_AsyncThread, consumer_group=pipeline_consumer_group_MMA_AsyncThread_dQ, @@ -738,7 +739,7 @@ def kernel( cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) ) # MMA - pipeline_p = cutlass.pipeline.PipelineAsyncUmma.create( + pipeline_P = cutlass.pipeline.PipelineAsyncUmma.create( num_stages=self.s_stage, producer_group=pipeline_pdS_producer_group, consumer_group=pipeline_pdS_consumer_group, @@ -805,33 +806,28 @@ def kernel( # TMEM # S - thr_mma_kq = tiled_mma_SdP.get_slice(0) - Sacc_shape = thr_mma_kq.partition_shape_C(self.mma_tiler_kq[:2]) # (M, N) - tStS = thr_mma_kq.make_fragment_C(Sacc_shape) + thr_mma_SdP = tiled_mma_SdP.get_slice(0) + Sacc_shape = thr_mma_SdP.partition_shape_C(self.mma_tiler_kq[:2]) # (M, N) + tStS = thr_mma_SdP.make_fragment_C(Sacc_shape) tStS = cute.make_tensor(tStS.iterator, tStS.layout) - # dV - thr_mma_pdo = tiled_mma_dV.get_slice(0) - dvacc_shape = thr_mma_pdo.partition_shape_C(self.mma_tiler_pdo[:2]) - tdVtdV = thr_mma_pdo.make_fragment_C(dvacc_shape) + thr_mma_dV = tiled_mma_dV.get_slice(0) + dvacc_shape = thr_mma_dV.partition_shape_C(self.mma_tiler_pdo[:2]) + tdVtdV = thr_mma_dV.make_fragment_C(dvacc_shape) tdVtdV = cute.make_tensor(tdVtdV.iterator + self.tmem_dV_offset, tdVtdV.layout) - # dK - thr_mma_dsq = tiled_mma_dK.get_slice(0) - dkacc_shape = thr_mma_dsq.partition_shape_C(self.mma_tiler_dsq[:2]) - tdKtdK = thr_mma_dsq.make_fragment_C(dkacc_shape) + thr_mma_dK = tiled_mma_dK.get_slice(0) + dkacc_shape = thr_mma_dK.partition_shape_C(self.mma_tiler_dsq[:2]) + tdKtdK = thr_mma_dK.make_fragment_C(dkacc_shape) tdKtdK = cute.make_tensor(tdKtdK.iterator + self.tmem_dK_offset, tdKtdK.layout) - # dQ - thr_mma_dsk = tiled_mma_dQ.get_slice(0) - dQacc_shape = thr_mma_dsk.partition_shape_C(self.mma_tiler_dsk[:2]) - tdQtdQ = thr_mma_dsk.make_fragment_C(dQacc_shape) + thr_mma_dQ = tiled_mma_dQ.get_slice(0) + dQacc_shape = thr_mma_dQ.partition_shape_C(self.mma_tiler_dsk[:2]) + tdQtdQ = thr_mma_dQ.make_fragment_C(dQacc_shape) tdQtdQ = cute.make_tensor(tdQtdQ.iterator + self.tmem_dQaccum_offset, tdQtdQ.layout) - # dP - thr_mma_vdo = tiled_mma_SdP.get_slice(0) - dPacc_shape = thr_mma_vdo.partition_shape_C(self.mma_tiler_vdo[:2]) - tdPtdP = thr_mma_vdo.make_fragment_C(dPacc_shape) + dPacc_shape = thr_mma_SdP.partition_shape_C(self.mma_tiler_vdo[:2]) + tdPtdP = thr_mma_SdP.make_fragment_C(dPacc_shape) tdPtdP = cute.make_tensor(tdPtdP.iterator + self.tmem_dP_offset, tdPtdP.layout) block_info = BlockInfo( @@ -879,9 +875,8 @@ def kernel( if warp_idx == self.load_warp_id: cute.arch.warpgroup_reg_dealloc(self.num_regs_load) self.load( - thr_mma_kq, - thr_mma_pdo, - thr_mma_vdo, + thr_mma_SdP, + thr_mma_dV, mQ, mK, mV, @@ -924,11 +919,6 @@ def kernel( tiled_mma_dV, tiled_mma_dK, tiled_mma_dQ, - thr_mma_kq, - thr_mma_pdo, - thr_mma_vdo, - thr_mma_dsq, - thr_mma_dsk, sQ, sQt, sK, @@ -938,8 +928,6 @@ def kernel( sdSt, sdS, sKt, - sK_layout.inner, - sQ_layout.inner, tStS, tdVtdV, tdKtdK, @@ -947,13 +935,13 @@ def kernel( tdQtdQ, pipeline_Q, pipeline_dO, - pipeline_s, - pipeline_p, + pipeline_S, + pipeline_P, pipeline_dS, pipeline_dV, pipeline_dK, pipeline_dP, - pipeline_dQaccum, + pipeline_dQ, block_info, SeqlenInfoCls, TileSchedulerCls, @@ -972,10 +960,9 @@ def kernel( if warp_idx >= self.compute_warp_ids[0] and warp_idx <= self.compute_warp_ids[-1]: cute.arch.warpgroup_reg_dealloc(self.num_regs_compute) # 8 warps self.compute_loop( - thr_mma_kq, - thr_mma_pdo, - thr_mma_vdo, - thr_mma_dsq, + thr_mma_SdP, + thr_mma_dV, + thr_mma_dK, tStS, sLSE_mma, sdPsum_mma, @@ -990,8 +977,8 @@ def kernel( lse_empty_mbar_ptr, dpsum_full_mbar_ptr, dpsum_empty_mbar_ptr, - pipeline_s, - pipeline_p, + pipeline_S, + pipeline_P, pipeline_dS, pipeline_dV, pipeline_dK, @@ -1022,9 +1009,9 @@ def kernel( self.dQacc_reduce( mdQaccum, sdQaccum, - thr_mma_dsk, + thr_mma_dQ, tdQtdQ, - pipeline_dQaccum, + pipeline_dQ, dQaccum_reduce_mbar_ptr, block_info, SeqlenInfoCls, @@ -1037,9 +1024,8 @@ def kernel( @cute.jit def load( self, - thr_mma_kq: cute.core.ThrMma, - thr_mma_pdo: cute.core.ThrMma, - thr_mma_vdo: cute.core.ThrMma, + thr_mma_SdP: cute.core.ThrMma, + thr_mma_dV: cute.core.ThrMma, mQ: cute.Tensor, mK: cute.Tensor, mV: cute.Tensor, @@ -1088,19 +1074,15 @@ def load( mPsum_cur = mdPsum[None, head_idx, batch_idx] gK = cute.local_tile(mK_cur, cute.select(self.mma_tiler_kq, mode=[0, 2]), (n_block, 0)) - tSgK = thr_mma_kq.partition_A(gK) - + tSgK = thr_mma_SdP.partition_A(gK) gV = cute.local_tile(mV_cur, cute.select(self.mma_tiler_vdo, mode=[0, 2]), (n_block, 0)) - tdPgV = thr_mma_vdo.partition_A(gV) - + tdPgV = thr_mma_SdP.partition_A(gV) gQ = cute.local_tile(mQ_cur, cute.select(self.mma_tiler_kq, mode=[1, 2]), (None, 0)) - tSgQ = thr_mma_kq.partition_B(gQ) - + tSgQ = thr_mma_SdP.partition_B(gQ) gLSE = cute.local_tile(mLSE_cur, (self.tile_n,), (None,)) gdPsum = cute.local_tile(mPsum_cur, (self.tile_n,), (None,)) - gdO = cute.local_tile(mdO_cur, cute.select(self.mma_tiler_pdo, mode=[1, 2]), (0, None)) - tdVgdO = thr_mma_pdo.partition_B(gdO) + tdVgdO = thr_mma_dV.partition_B(gdO) load_K, _, _ = copy_utils.tma_get_copy_fn( tma_atom_K, 0, cute.make_layout(1), tSgK, sK[None, None, None, 0], single_stage=True @@ -1194,11 +1176,6 @@ def mma( tiled_mma_dV: cute.TiledMma, tiled_mma_dK: cute.TiledMma, tiled_mma_dQ: cute.TiledMma, - thr_mma_kq: cute.core.ThrMma, - thr_mma_pdo: cute.core.ThrMma, - thr_mma_vdo: cute.core.ThrMma, - thr_mma_dsq: cute.core.ThrMma, - thr_mma_dsk: cute.core.ThrMma, sQ: cute.Tensor, sQt: cute.Tensor, sK: cute.Tensor, @@ -1208,44 +1185,81 @@ def mma( sdSt: cute.Tensor, sdS: cute.Tensor, sKt: cute.Tensor, - sK_swizzle: cute.Swizzle, - sQ_swizzle: cute.Swizzle, tStS: cute.Tensor, tdVtdV: cute.Tensor, tdKtdK: cute.Tensor, tdPtdP: cute.Tensor, - tdQacctdQacc: cute.Tensor, + tdQtdQ: cute.Tensor, pipeline_Q: PipelineAsync, pipeline_dO: PipelineAsync, - pipeline_s: PipelineAsync, - pipeline_p: PipelineAsync, + pipeline_S: PipelineAsync, + pipeline_P: PipelineAsync, pipeline_dS: PipelineAsync, pipeline_dV: PipelineAsync, pipeline_dK: PipelineAsync, pipeline_dP: PipelineAsync, - pipeline_dQaccum: PipelineAsync, + pipeline_dQ: PipelineAsync, block_info: BlockInfo, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, ): - q_consumer_state = cutlass.pipeline.make_pipeline_state( + thr_mma_SdP = tiled_mma_SdP.get_slice(0) + thr_mma_dV = tiled_mma_dV.get_slice(0) + thr_mma_dK = tiled_mma_dK.get_slice(0) + thr_mma_dQ = tiled_mma_dQ.get_slice(0) + # Partition smem / tmem tensors + # S = K @ Q.T + tSrK = thr_mma_SdP.make_fragment_A(sK) + tSrQ = thr_mma_SdP.make_fragment_B(sQ) + # dP = V @ dO.T + tdPrV = thr_mma_SdP.make_fragment_A(sV) + tdPrdOt = thr_mma_SdP.make_fragment_B(sdOt) + # dK = dS.T @ Q + tdKrdS = thr_mma_dK.make_fragment_A(sdSt) + tdKrQ = thr_mma_dK.make_fragment_B(sQt) + # dQ = dS @ K + tdQrdS = thr_mma_dQ.make_fragment_A(sdS) + tdQrK = thr_mma_dQ.make_fragment_B(sKt) + # dV = P @ dO.T + tdVrdO = thr_mma_dV.make_fragment_B(sdO) + p_tmem_layout = sm100_utils_basic.make_smem_layout_a( + tiled_mma_dV, + self.mma_tiler_pdo, + self.q_dtype, + self.acc_stage, + ) + tP = cute.make_tensor(tStS.iterator, p_tmem_layout.outer) + tdVrP = thr_mma_dV.make_fragment_A(tP)[None, None, None, 0] + tdVrP = cute.make_tensor(tdVrP.iterator, tdVrP.layout) + + mma_qk_fn = partial(gemm_w_idx, tiled_mma_SdP, tStS, tSrK, tSrQ, A_idx=0, zero_init=True) + mma_dov_fn = partial( + gemm_w_idx, tiled_mma_SdP, tdPtdP, tdPrV, tdPrdOt, A_idx=0, zero_init=True + ) + mma_pdo_fn = partial(gemm_w_idx, tiled_mma_dV, tdVtdV, tdVrP, tdVrdO, A_idx=None) + mma_dsk_fn = partial( + gemm_w_idx, tiled_mma_dQ, tdQtdQ, tdQrdS, tdQrK, B_idx=0, zero_init=True + ) + mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, tdKtdK, tdKrdS, tdKrQ, A_idx=0) + + consumer_state_Q = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.q_stage ) - q_dk_consumer_state = q_consumer_state - do_consumer_state = cutlass.pipeline.make_pipeline_state( + q_dk_consumer_state = consumer_state_Q + consumer_state_dO = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.do_stage ) - s_producer_state = cutlass.pipeline.make_pipeline_state( + producer_state_S = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Producer, self.s_stage ) - dP_producer_state = cutlass.pipeline.make_pipeline_state( + producer_state_dP = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Producer, self.dP_stage ) - p_consumer_state = cutlass.pipeline.make_pipeline_state( + consumer_state_P = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.s_stage ) - dS_consumer_state = cutlass.pipeline.make_pipeline_state( + consumer_state_dS = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.dS_stage ) dV_producer_state = cutlass.pipeline.make_pipeline_state( @@ -1254,7 +1268,7 @@ def mma( dK_producer_state = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Producer, self.dK_stage ) - dQaccum_producer_state = cutlass.pipeline.make_pipeline_state( + producer_state_dQ = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Producer, self.dQaccum_mma_stage ) @@ -1264,40 +1278,9 @@ def mma( while work_tile.is_valid_tile: n_block, head_idx, batch_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) # must be seqlen_k - m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) - # S = K @ Q.T sK and sQ - tSrK = thr_mma_kq.make_fragment_A(sK) - tSrQ = thr_mma_kq.make_fragment_B(sQ) - - # dP = V @ dOt - tdPrV = thr_mma_vdo.make_fragment_A(sV) - tdPrdOt = thr_mma_vdo.make_fragment_B(sdOt) - - # dK = dS.T @ Q - tdKrdS = thr_mma_dsq.make_fragment_A(sdSt) - tdKrQ = thr_mma_dsq.make_fragment_B(sQt) - accumulate_dK = False - - # dV = P @ dO.T - tdVrdO = thr_mma_pdo.make_fragment_B(sdO) - p_tmem_layout = sm100_utils_basic.make_smem_layout_a( - tiled_mma_dV, - self.mma_tiler_pdo, - self.q_dtype, - self.acc_stage, - ) - - tP = cute.make_tensor(tStS.iterator, p_tmem_layout.outer) - tdVrP = thr_mma_pdo.make_fragment_A(tP)[None, None, None, 0] - tdVrP = cute.make_tensor(tdVrP.iterator, tdVrP.layout) - - # dQ = dS @ K - tdQaccrdS = thr_mma_dsk.make_fragment_A(sdS) - tdQaccrK = thr_mma_dsk.make_fragment_B(sKt) - # ----------------------------------------------------------- ###### Prologue # ----------------------------------------------------------- @@ -1306,59 +1289,30 @@ def mma( # 3. dV = P @ dO # 1) S = Q0 @ K.T - pipeline_Q.consumer_wait(q_consumer_state) - pipeline_s.producer_acquire(s_producer_state) - - num_k_phases = cute.size(tSrK, mode=[2]) - for kphase_idx in cutlass.range_constexpr(num_k_phases, unroll=1): - tiled_mma_SdP.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) - cute.gemm( - tiled_mma_SdP, - tStS, - tSrK[(None, None, kphase_idx, 0)], - tSrQ[(None, None, kphase_idx, q_consumer_state.index)], - tStS, - ) - - q_consumer_state.advance() - pipeline_s.producer_commit(s_producer_state) - s_producer_state.advance() + pipeline_Q.consumer_wait(consumer_state_Q) + pipeline_S.producer_acquire(producer_state_S) + mma_qk_fn(B_idx=consumer_state_Q.index) + # Don't release Q yet + consumer_state_Q.advance() + pipeline_S.producer_commit(producer_state_S) + producer_state_S.advance() # 2) dP = V @ dO.T - pipeline_dO.consumer_wait(do_consumer_state) - pipeline_dP.producer_acquire(dP_producer_state) - - pipeline_dQaccum.producer_acquire(dQaccum_producer_state) - - for kphase_idx in cutlass.range_constexpr(cute.size(tdPrV, mode=[2]), unroll=1): - tiled_mma_SdP.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) - cute.gemm( - tiled_mma_SdP, - tdPtdP, - tdPrV[(None, None, kphase_idx, 0)], - tdPrdOt[(None, None, kphase_idx, do_consumer_state.index)], - tdPtdP, - ) - pipeline_dP.producer_commit(dP_producer_state) - dP_producer_state.advance() + pipeline_dO.consumer_wait(consumer_state_dO) + pipeline_dP.producer_acquire(producer_state_dP) + pipeline_dQ.producer_acquire(producer_state_dQ) + mma_dov_fn(B_idx=consumer_state_dO.index) + # Don't release dO yet + pipeline_dP.producer_commit(producer_state_dP) + producer_state_dP.advance() # 3) dV = P.T @ dO - pipeline_p.consumer_wait(p_consumer_state) - - num_kphases = cute.size(tdVrP, mode=[2]) - for kphase_idx in cutlass.range_constexpr(num_kphases): - tiled_mma_dV.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) - cute.gemm( - tiled_mma_dV, - tdVtdV, - tdVrP[(None, None, kphase_idx)], - tdVrdO[(None, None, kphase_idx, do_consumer_state.index)], - tdVtdV, - ) - pipeline_p.consumer_release(p_consumer_state) - p_consumer_state.advance() - pipeline_dO.consumer_release(do_consumer_state) - do_consumer_state.advance() + pipeline_P.consumer_wait(consumer_state_P) + mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=True) + pipeline_P.consumer_release(consumer_state_P) + consumer_state_P.advance() + pipeline_dO.consumer_release(consumer_state_dO) + consumer_state_dO.advance() # ----------------------------------------------------------- ###### MAIN LOOP # ----------------------------------------------------------- @@ -1370,144 +1324,72 @@ def mma( for i in cutlass.range(m_block_max - m_block_min - 1, unroll=1): # 1) S = K @ Q_i - pipeline_Q.consumer_wait(q_consumer_state) - pipeline_s.producer_acquire(s_producer_state) - #''' - for kphase_idx in cutlass.range_constexpr(num_k_phases, unroll=1): - tiled_mma_SdP.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) - cute.gemm( - tiled_mma_SdP, - tStS, - tSrK[(None, None, kphase_idx, 0)], - tSrQ[(None, None, kphase_idx, q_consumer_state.index)], - tStS, - ) - - pipeline_s.producer_commit(s_producer_state) - s_producer_state.advance() - q_consumer_state.advance() + pipeline_Q.consumer_wait(consumer_state_Q) + pipeline_S.producer_acquire(producer_state_S) + mma_qk_fn(B_idx=consumer_state_Q.index) + pipeline_S.producer_commit(producer_state_S) + producer_state_S.advance() + consumer_state_Q.advance() # 2) dQ = dS @ K - pipeline_dS.consumer_wait(dS_consumer_state) - pipeline_dP.producer_acquire(dP_producer_state) - - num_kphases = cute.size(tdQaccrdS, mode=[2]) - for kphase_idx in cutlass.range_constexpr(num_kphases): - tiled_mma_dQ.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) - cute.gemm( - tiled_mma_dQ, - tdQacctdQacc, - tdQaccrdS[(None, None, kphase_idx, dS_consumer_state.index)], - tdQaccrK[(None, None, kphase_idx, 0)], - tdQacctdQacc, - ) - pipeline_dQaccum.producer_commit(dQaccum_producer_state) - dQaccum_producer_state.advance() + pipeline_dS.consumer_wait(consumer_state_dS) + pipeline_dP.producer_acquire(producer_state_dP) + mma_dsk_fn(A_idx=consumer_state_dS.index) + pipeline_dQ.producer_commit(producer_state_dQ) + producer_state_dQ.advance() # 3) dK = dS.T @ Q - num_kphases = cute.size(tdKrdS, mode=[2]) - for kphase_idx in cutlass.range_constexpr(num_kphases, unroll=1): - tiled_mma_dK.set(tcgen05.Field.ACCUMULATE, accumulate_dK) - cute.gemm( - tiled_mma_dK, - tdKtdK, - tdKrdS[(None, None, kphase_idx, 0)], - tdKrQ[(None, None, kphase_idx, q_dk_consumer_state.index)], - tdKtdK, - ) - accumulate_dK = True - + mma_dsq_fn(B_idx=q_dk_consumer_state.index, zero_init=not accumulate_dK) + accumulate_dK = True pipeline_Q.consumer_release(q_dk_consumer_state) q_dk_consumer_state.advance() - pipeline_dS.consumer_release(dS_consumer_state) - dS_consumer_state.advance() + pipeline_dS.consumer_release(consumer_state_dS) + consumer_state_dS.advance() # 4) dP = V @ dO.T - pipeline_dO.consumer_wait(do_consumer_state) - - pipeline_dQaccum.producer_acquire(dQaccum_producer_state) - - for kphase_idx in cutlass.range_constexpr(cute.size(tdPrV, mode=[2]), unroll=1): - tiled_mma_SdP.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) - cute.gemm( - tiled_mma_SdP, - tdPtdP, - tdPrV[(None, None, kphase_idx, 0)], - tdPrdOt[(None, None, kphase_idx, do_consumer_state.index)], - tdPtdP, - ) - pipeline_dP.producer_commit(dP_producer_state) - dP_producer_state.advance() + pipeline_dO.consumer_wait(consumer_state_dO) + pipeline_dQ.producer_acquire(producer_state_dQ) + mma_dov_fn(B_idx=consumer_state_dO.index) + pipeline_dP.producer_commit(producer_state_dP) + producer_state_dP.advance() # 5) dV += P @ dO - pipeline_p.consumer_wait(p_consumer_state) - - num_kphases = cute.size(tdVrP, mode=[2]) - for kphase_idx in cutlass.range_constexpr(num_kphases): - tiled_mma_dV.set(tcgen05.Field.ACCUMULATE, True) - cute.gemm( - tiled_mma_dV, - tdVtdV, - tdVrP[(None, None, kphase_idx)], - tdVrdO[(None, None, kphase_idx, do_consumer_state.index)], - tdVtdV, - ) - - pipeline_p.consumer_release(p_consumer_state) - p_consumer_state.advance() - pipeline_dO.consumer_release(do_consumer_state) - do_consumer_state.advance() + pipeline_P.consumer_wait(consumer_state_P) + mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=False) + pipeline_P.consumer_release(consumer_state_P) + consumer_state_P.advance() + pipeline_dO.consumer_release(consumer_state_dO) + consumer_state_dO.advance() pipeline_dV.producer_acquire(dV_producer_state) pipeline_dV.producer_commit(dV_producer_state) dV_producer_state.advance() - pipeline_s.producer_tail(s_producer_state) - pipeline_dP.producer_tail(dP_producer_state) + pipeline_S.producer_tail(producer_state_S) + pipeline_dP.producer_tail(producer_state_dP) pipeline_dV.producer_tail(dV_producer_state) # ----------------------------------------------------------- ###### Remaining 2 # ----------------------------------------------------------- # 1) dK += dS.T @ Q - pipeline_dS.consumer_wait(dS_consumer_state) - - num_kphases = cute.size(tdKrdS, mode=[2]) - for kphase_idx in cutlass.range_constexpr(num_kphases): - tiled_mma_dK.set(tcgen05.Field.ACCUMULATE, accumulate_dK) - cute.gemm( - tiled_mma_dK, - tdKtdK, - tdKrdS[(None, None, kphase_idx, dS_consumer_state.index)], - tdKrQ[(None, None, kphase_idx, q_dk_consumer_state.index)], - tdKtdK, - ) - accumulate_dK = True - + pipeline_dS.consumer_wait(consumer_state_dS) + mma_dsq_fn(B_idx=q_dk_consumer_state.index, zero_init=not accumulate_dK) pipeline_dK.producer_acquire(dK_producer_state) pipeline_dK.producer_commit(dK_producer_state) dK_producer_state.advance() - # 2) dQaccum = dS @ K - num_kphases = cute.size(tdQaccrdS, mode=[2]) - for kphase_idx in cutlass.range_constexpr(num_kphases, unroll=1): - tiled_mma_dQ.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) - cute.gemm( - tiled_mma_dQ, - tdQacctdQacc, - tdQaccrdS[(None, None, kphase_idx, dS_consumer_state.index)], - tdQaccrK[(None, None, kphase_idx, 0)], - tdQacctdQacc, - ) - pipeline_dQaccum.producer_commit(dQaccum_producer_state) - dQaccum_producer_state.advance() + # 2) dQ = dS @ K + mma_dsk_fn(A_idx=consumer_state_dS.index) + pipeline_dQ.producer_commit(producer_state_dQ) + producer_state_dQ.advance() pipeline_Q.consumer_release(q_dk_consumer_state) q_dk_consumer_state.advance() - pipeline_dS.consumer_release(dS_consumer_state) - dS_consumer_state.advance() + pipeline_dS.consumer_release(consumer_state_dS) + consumer_state_dS.advance() pipeline_dK.producer_tail(dK_producer_state) - pipeline_dQaccum.producer_tail(dQaccum_producer_state) + pipeline_dQ.producer_tail(producer_state_dQ) tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() @@ -1557,10 +1439,9 @@ def split_wg( @cute.jit def compute_loop( self, - thr_mma_kq: cute.core.ThrMma, - thr_mma_pdo: cute.core.ThrMma, - thr_mma_vdo: cute.core.ThrMma, - thr_mma_dsq: cute.core.ThrMma, + thr_mma_SdP: cute.core.ThrMma, + thr_mma_dV: cute.core.ThrMma, + thr_mma_dK: cute.core.ThrMma, tStS: cute.Tensor, sLSE_2D: cute.Tensor, sPsum_2D: cute.Tensor, @@ -1575,8 +1456,8 @@ def compute_loop( lse_empty_mbar_ptr: cute.Pointer, dpsum_full_mbar_ptr: cute.Pointer, dpsum_empty_mbar_ptr: cute.Pointer, - pipeline_s: PipelineAsync, - pipeline_p: PipelineAsync, + pipeline_S: PipelineAsync, + pipeline_P: PipelineAsync, pipeline_dS: PipelineAsync, pipeline_dV: PipelineAsync, pipeline_dK: PipelineAsync, @@ -1655,8 +1536,8 @@ def compute_loop( for i in cutlass.range(m_block_max - m_block_min, unroll=1): m_block = m_block_max - 1 - i - pipeline_s.consumer_wait(s_consumer_state) - pipeline_p.producer_acquire(p_producer_state) + pipeline_S.consumer_wait(s_consumer_state) + pipeline_P.producer_acquire(p_producer_state) if warp_idx == self.compute_warp_ids[0]: cute.arch.mbarrier_wait(lse_full_mbar_ptr, lse_consumer_phase) @@ -1679,9 +1560,7 @@ def compute_loop( tStS_t2r = self.split_wg(tStS_t2r_p, wg_idx, num_wg) #### RMEM - tScS = thr_mma_kq.partition_C( - cute.make_identity_tensor((self.mma_tiler_kq[0], self.mma_tiler_kq[1])) - ) + tScS = thr_mma_SdP.partition_C(cute.make_identity_tensor(self.mma_tiler_kq[:2])) tScS_tensor = cute.make_tensor(tScS.iterator, tScS.layout) tScS_t2r_p = thr_tmem_ld.partition_D(tScS_tensor) tScS_t2r = self.split_wg(tScS_t2r_p, wg_idx, num_wg) @@ -1780,10 +1659,10 @@ def compute_loop( number_of_threads=self.num_compute_threads, ) - pipeline_p.producer_commit(p_producer_state) + pipeline_P.producer_commit(p_producer_state) p_producer_state.advance() - pipeline_s.consumer_release(s_consumer_state) + pipeline_S.consumer_release(s_consumer_state) s_consumer_state.advance() if warp_idx == self.compute_warp_ids[0]: @@ -1809,7 +1688,7 @@ def compute_loop( #### TMEM->RMEM (Load dP from TMEM) cdP = cute.make_identity_tensor((self.mma_tiler_vdo[0], self.mma_tiler_vdo[1])) - tdPcdP = thr_mma_vdo.partition_C(cdP) + tdPcdP = thr_mma_SdP.partition_C(cdP) tdPcdP_tensor = cute.make_tensor(tdPcdP.iterator, tdPcdP.layout) tdPcdP_t2r_p = thr_tmem_ld_dP.partition_D(tdPcdP_tensor) @@ -1902,8 +1781,8 @@ def compute_loop( batch_idx, head_idx, n_block, - thr_mma_pdo, - thr_mma_dsq, + thr_mma_dV, + thr_mma_dK, tdVtdV, tdKtdK, mdV, @@ -1920,7 +1799,7 @@ def compute_loop( batch_idx, head_idx, n_block, - thr_mma_pdo, + thr_mma_dV, tdVtdV, mdV_tma_tensor, sdV, @@ -1938,7 +1817,7 @@ def compute_loop( batch_idx, head_idx, n_block, - thr_mma_dsq, + thr_mma_dK, tdKtdK, mdK_tma_tensor, sdK, @@ -1959,7 +1838,7 @@ def dQacc_reduce( self, mdQaccum: cute.Tensor, sdQaccum: cute.Tensor, - thr_mma_dsk: cute.core.ThrMma, + thr_mma_dQ: cute.core.ThrMma, tdQtdQ: cute.Tensor, pipeline_dQ: PipelineAsync, dQaccum_reduce_mbar_ptr: cute.Pointer, @@ -1988,7 +1867,7 @@ def dQacc_reduce( tdQtdQ_t2r = thr_tmem_ld.partition_S(tdQtdQ) cdQ = cute.make_identity_tensor((self.mma_tiler_dsk[0], self.mma_tiler_dsk[1])) - tdQcdQ = thr_mma_dsk.partition_C(cdQ) + tdQcdQ = thr_mma_dQ.partition_C(cdQ) tdQcdQ_tensor = cute.make_tensor(tdQcdQ.iterator, tdQcdQ.layout) tdQrdQ = thr_tmem_ld.partition_D(tdQcdQ_tensor) @@ -2130,8 +2009,8 @@ def epilogue_dKV( batch_idx: Int32, head_idx: Int32, n_block: Int32, - thr_mma_pdo: cute.core.ThrMma, - thr_mma_dsq: cute.core.ThrMma, + thr_mma_dV: cute.core.ThrMma, + thr_mma_dK: cute.core.ThrMma, tdVtdV: cute.Tensor, tdKtdK: cute.Tensor, mdV: cute.Tensor, @@ -2170,7 +2049,7 @@ def epilogue_dKV( tdVtdV_t2r = self.split_wg(tdVtdV_t2r_p, wg_idx, num_wg) cdV = cute.make_identity_tensor((self.mma_tiler_pdo[0], self.mma_tiler_pdo[1])) - tdVcdV = thr_mma_pdo.partition_C(cdV) + tdVcdV = thr_mma_dV.partition_C(cdV) tdVcdV_tensor = cute.make_tensor(tdVcdV.iterator, tdVcdV.layout) tdVcdV_t2r_p = thr_tmem_ld_dV.partition_D(tdVcdV_tensor) @@ -2200,7 +2079,7 @@ def epilogue_dKV( gdV = cute.local_tile(mdV_cur, (self.tile_m, self.tile_hdimv), (None, 0)) gdV_tile = gdV[None, None, n_block] - tdVgdV = thr_mma_pdo.partition_C(gdV_tile) + tdVgdV = thr_mma_dV.partition_C(gdV_tile) tdVgdV_r2g_p = thr_tmem_ld_dV.partition_D(tdVgdV) tdVgdV_r2g = self.split_wg(tdVgdV_r2g_p, wg_idx, num_wg) @@ -2219,7 +2098,7 @@ def epilogue_dKV( tdKtdK_t2r = self.split_wg(tdKtdK_t2r_p, wg_idx, num_wg) cdK = cute.make_identity_tensor((self.mma_tiler_dsq[0], self.mma_tiler_dsq[1])) - tdKcdK = thr_mma_dsq.partition_C(cdK) + tdKcdK = thr_mma_dK.partition_C(cdK) tdKcdK_tensor = cute.make_tensor(tdKcdK.iterator, tdKcdK.layout) tdKcdK_t2r_p = thr_tmem_ld_dK.partition_D(tdKcdK_tensor) @@ -2251,7 +2130,7 @@ def epilogue_dKV( gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdimv), (None, 0)) gdK_tile = gdK[None, None, n_block] - tdKgdK = thr_mma_dsq.partition_C(gdK_tile) + tdKgdK = thr_mma_dK.partition_C(gdK_tile) tdKgdK_r2g_p = thr_tmem_ld_dK.partition_D(tdKgdK) tdKgdK_r2g = self.split_wg(tdKgdK_r2g_p, wg_idx, num_wg) From 5c685eaa7d2bca7eeaae5068f061fabd00fb4d7d Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 19 Oct 2025 17:57:14 -0400 Subject: [PATCH 165/258] [Cute,Bwd,Sm100] Don't need q_dk_consumer_state --- flash_attn/cute/flash_bwd_sm100.py | 41 ++++++++++++++---------------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 247dc669b02..dffdf227acb 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -133,8 +133,8 @@ def __init__( def _setup_attributes(self): self.q_stage = 2 self.k_stage = self.v_stage = 1 - self.do_stage = 1 - self.ds_stage = 1 + self.dO_stage = 1 + self.dS_stage = 1 self.lse_stage = 1 self.acc_stage = 1 self.s_stage = 1 @@ -208,7 +208,7 @@ def _setup_smem_layout(self): self.tiled_mma_dV, self.mma_tiler_pdo, self.do_dtype, - self.do_stage, + self.dO_stage, ) # dP = V @ dO.T self.sV_layout = sm100_utils_basic.make_smem_layout_a( @@ -221,14 +221,14 @@ def _setup_smem_layout(self): self.tiled_mma_SdP, self.mma_tiler_vdo, self.do_dtype, - self.do_stage, + self.dO_stage, ) # dK += dS.T @ Q self.sdSt_layout = sm100_utils_basic.make_smem_layout_a( self.tiled_mma_dK, self.mma_tiler_dsq, self.ds_dtype, - self.ds_stage, + self.dS_stage, ) self.sQt_layout = sm100_utils_basic.make_smem_layout_b( self.tiled_mma_dK, @@ -241,7 +241,7 @@ def _setup_smem_layout(self): self.tiled_mma_dQ, self.mma_tiler_dsk, self.q_dtype, - self.ds_stage, + self.dS_stage, ) self.sKt_layout = sm100_utils_basic.make_smem_layout_b( self.tiled_mma_dQ, @@ -474,7 +474,7 @@ def __call__( class SharedStorage: q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.q_stage] lse_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.lse_stage] - do_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.do_stage] + do_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dO_stage] lse_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.k_stage] lse_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.k_stage] dpsum_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.dpsum_stage] @@ -482,7 +482,7 @@ class SharedStorage: s_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.s_stage] dP_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dP_stage] p_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.s_stage] - dS_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.ds_stage] + dS_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dS_stage] dV_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dV_stage] dK_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dK_stage] dQaccum_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dQaccum_mma_stage] @@ -680,7 +680,7 @@ def kernel( pipeline_dO = pipeline.PipelineTmaUmma.create( barrier_storage=storage.do_mbar_ptr.data_ptr(), - num_stages=self.do_stage, + num_stages=self.dO_stage, producer_group=pipeline_producer_group, consumer_group=pipeline_consumer_group, tx_count=self.tma_copy_bytes["dO"], @@ -1056,7 +1056,7 @@ def load( cutlass.pipeline.PipelineUserType.Producer, self.q_stage ) producer_state_dO = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Producer, self.do_stage + cutlass.pipeline.PipelineUserType.Producer, self.dO_stage ) tile_scheduler = TileSchedulerCls() @@ -1245,11 +1245,9 @@ def mma( consumer_state_Q = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.q_stage ) - q_dk_consumer_state = consumer_state_Q consumer_state_dO = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Consumer, self.do_stage + cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage ) - producer_state_S = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Producer, self.s_stage ) @@ -1293,7 +1291,6 @@ def mma( pipeline_S.producer_acquire(producer_state_S) mma_qk_fn(B_idx=consumer_state_Q.index) # Don't release Q yet - consumer_state_Q.advance() pipeline_S.producer_commit(producer_state_S) producer_state_S.advance() @@ -1324,12 +1321,13 @@ def mma( for i in cutlass.range(m_block_max - m_block_min - 1, unroll=1): # 1) S = K @ Q_i + consumer_state_Q_prev = consumer_state_Q.clone() + consumer_state_Q.advance() pipeline_Q.consumer_wait(consumer_state_Q) pipeline_S.producer_acquire(producer_state_S) mma_qk_fn(B_idx=consumer_state_Q.index) pipeline_S.producer_commit(producer_state_S) producer_state_S.advance() - consumer_state_Q.advance() # 2) dQ = dS @ K pipeline_dS.consumer_wait(consumer_state_dS) @@ -1339,10 +1337,9 @@ def mma( producer_state_dQ.advance() # 3) dK = dS.T @ Q - mma_dsq_fn(B_idx=q_dk_consumer_state.index, zero_init=not accumulate_dK) + mma_dsq_fn(B_idx=consumer_state_Q_prev.index, zero_init=not accumulate_dK) accumulate_dK = True - pipeline_Q.consumer_release(q_dk_consumer_state) - q_dk_consumer_state.advance() + pipeline_Q.consumer_release(consumer_state_Q_prev) pipeline_dS.consumer_release(consumer_state_dS) consumer_state_dS.advance() @@ -1374,7 +1371,7 @@ def mma( # ----------------------------------------------------------- # 1) dK += dS.T @ Q pipeline_dS.consumer_wait(consumer_state_dS) - mma_dsq_fn(B_idx=q_dk_consumer_state.index, zero_init=not accumulate_dK) + mma_dsq_fn(B_idx=consumer_state_Q.index, zero_init=not accumulate_dK) pipeline_dK.producer_acquire(dK_producer_state) pipeline_dK.producer_commit(dK_producer_state) dK_producer_state.advance() @@ -1383,8 +1380,8 @@ def mma( mma_dsk_fn(A_idx=consumer_state_dS.index) pipeline_dQ.producer_commit(producer_state_dQ) producer_state_dQ.advance() - pipeline_Q.consumer_release(q_dk_consumer_state) - q_dk_consumer_state.advance() + pipeline_Q.consumer_release(consumer_state_Q) + consumer_state_Q.advance() pipeline_dS.consumer_release(consumer_state_dS) consumer_state_dS.advance() @@ -1505,7 +1502,7 @@ def compute_loop( cutlass.pipeline.PipelineUserType.Producer, self.s_stage ) dS_producer_state = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Producer, self.ds_stage + cutlass.pipeline.PipelineUserType.Producer, self.dS_stage ) dP_consumer_state = cutlass.pipeline.make_pipeline_state( From 8790c6ec23d4e8270ee3033e512314144800f86b Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 19 Oct 2025 21:17:41 -0400 Subject: [PATCH 166/258] [Cute,Bwd,Sm100] Simplify dQacc_reduce, don't need mbarrier --- flash_attn/cute/flash_bwd_sm100.py | 179 ++++++++++------------------- 1 file changed, 60 insertions(+), 119 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index dffdf227acb..faf4bf4a96a 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -61,8 +61,6 @@ def __init__( self.tile_m = tile_m self.tile_n = tile_n - # number of tma reduce adds per dQacc mma - self.dQaccum_reduce_stage = self.tile_hdim // 32 # CTA tiler self.cta_tiler = (tile_m, tile_n, self.tile_hdim) @@ -147,6 +145,8 @@ def _setup_attributes(self): self.dpsum_stage = 1 self.p_tmem_stage = 1 self.sdKdVaccum_stage = 2 + # number of tma reduce adds per dQacc mma + self.dQaccum_reduce_stage = self.tile_hdim // 32 def _get_tiled_mma(self): cta_group = tcgen05.CtaGroup.ONE @@ -445,6 +445,7 @@ def __call__( } self.tma_copy_bytes["LSE"] = self.tile_m * Float32.width // 8 self.tma_copy_bytes["dPsum"] = self.tile_m * Float32.width // 8 + self.tma_copy_bytes["dQ"] = self.tile_m * 32 * Float32.width // 8 TileScheduler = SingleTileScheduler # TODO -- optimizer scheduler for causal @@ -486,7 +487,6 @@ class SharedStorage: dV_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dV_stage] dK_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dK_stage] dQaccum_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dQaccum_mma_stage] - dQaccum_reduce_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dQaccum_mma_stage] # TMEM tmem_holding_buf: Int32 @@ -650,7 +650,6 @@ def kernel( lse_empty_mbar_ptr = storage.lse_empty_mbar_ptr.data_ptr() dpsum_full_mbar_ptr = storage.dpsum_full_mbar_ptr.data_ptr() dpsum_empty_mbar_ptr = storage.dpsum_empty_mbar_ptr.data_ptr() - dQaccum_reduce_mbar_ptr = storage.dQaccum_reduce_mbar_ptr.data_ptr() if warp_idx == self.load_warp_id: cute.arch.mbarrier_init( @@ -660,7 +659,6 @@ def kernel( cute.arch.mbarrier_init(lse_empty_mbar_ptr, len([self.compute_warp_ids])) cute.arch.mbarrier_init(dpsum_full_mbar_ptr, len([self.compute_warp_ids])) cute.arch.mbarrier_init(dpsum_empty_mbar_ptr, len([self.compute_warp_ids])) - cute.arch.mbarrier_init(dQaccum_reduce_mbar_ptr, 1) pipeline_producer_group = cutlass.pipeline.CooperativeGroup( cutlass.pipeline.Agent.Thread, len([self.load_warp_id]) @@ -1012,7 +1010,6 @@ def kernel( thr_mma_dQ, tdQtdQ, pipeline_dQ, - dQaccum_reduce_mbar_ptr, block_info, SeqlenInfoCls, TileSchedulerCls, @@ -1541,7 +1538,7 @@ def compute_loop( lse_consumer_phase ^= 1 tiled_tmem_ld = tcgen05.make_tmem_copy(tmem_load_atom, tStS) - thr_tmem_ld = tiled_tmem_ld.get_slice(tidx) + thr_tmem_load = tiled_tmem_ld.get_slice(tidx) tileP_f32_like = self.mma_tiler_kq[0] // 32 * self.v_dtype.width # (128, 64) tStP = cute.make_tensor( @@ -1553,13 +1550,13 @@ def compute_loop( thr_tmem_st = tiled_tmem_st.get_slice(tidx) #### TMEM - tStS_t2r_p = thr_tmem_ld.partition_S(tStS) + tStS_t2r_p = thr_tmem_load.partition_S(tStS) tStS_t2r = self.split_wg(tStS_t2r_p, wg_idx, num_wg) #### RMEM tScS = thr_mma_SdP.partition_C(cute.make_identity_tensor(self.mma_tiler_kq[:2])) tScS_tensor = cute.make_tensor(tScS.iterator, tScS.layout) - tScS_t2r_p = thr_tmem_ld.partition_D(tScS_tensor) + tScS_t2r_p = thr_tmem_load.partition_D(tScS_tensor) tScS_t2r = self.split_wg(tScS_t2r_p, wg_idx, num_wg) tSrS_t2r = cute.make_fragment(tScS_t2r.shape, Float32) # 64 @@ -1599,7 +1596,7 @@ def compute_loop( tStP_r2t = self.split_wg(tStP_r2t_p, wg_idx, num_wg) #### Compute P = exp(S * scale - LSE) - tLSE = thr_tmem_ld.partition_D(sLSE_2D) + tLSE = thr_tmem_load.partition_D(sLSE_2D) # split to wg0 & wg1 tLSErLSE_p = cute.make_tensor( cute.recast_ptr(tLSE.iterator), @@ -1713,7 +1710,7 @@ def compute_loop( cute.recast_ptr(tSrS_t2r.iterator, dtype=self.ds_dtype), tSrS_t2r.shape ) - tPsum = thr_tmem_ld.partition_D(sPsum_2D) + tPsum = thr_tmem_load.partition_D(sPsum_2D) tPsumrPsum_p = cute.make_tensor( cute.recast_ptr(tPsum.iterator), cute.make_layout( @@ -1838,163 +1835,107 @@ def dQacc_reduce( thr_mma_dQ: cute.core.ThrMma, tdQtdQ: cute.Tensor, pipeline_dQ: PipelineAsync, - dQaccum_reduce_mbar_ptr: cute.Pointer, block_info: BlockInfo, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, mdQ_semaphore: Optional[cute.Tensor], ): - warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * 4) - - dQ_consumer_state = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Consumer, self.dQaccum_mma_stage - ) - - tile_scheduler = TileSchedulerCls() - work_tile = tile_scheduler.initial_work_tile_info() - + num_reduce_threads = cute.arch.WARP_SIZE * len(self.reduce_warp_ids) + tidx = cute.arch.thread_idx()[0] % num_reduce_threads + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx() % len(self.reduce_warp_ids)) # TMEM -> RMEM - tmem_ld_atom = cute.make_copy_atom( + tmem_load_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32 ) - tiled_tmem_ld = tcgen05.make_tmem_copy(tmem_ld_atom, tdQtdQ) - thr_tmem_ld = tiled_tmem_ld.get_slice(tidx) - - tdQtdQ_t2r = thr_tmem_ld.partition_S(tdQtdQ) - - cdQ = cute.make_identity_tensor((self.mma_tiler_dsk[0], self.mma_tiler_dsk[1])) - tdQcdQ = thr_mma_dQ.partition_C(cdQ) - tdQcdQ_tensor = cute.make_tensor(tdQcdQ.iterator, tdQcdQ.layout) - tdQrdQ = thr_tmem_ld.partition_D(tdQcdQ_tensor) - - num_reduce_threads = cute.arch.WARP_SIZE * len(self.reduce_warp_ids) - - atom_universal_copy = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), self.dqaccum_dtype, num_bits_per_copy=128 + thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tdQtdQ).get_slice(tidx) + tdQtdQ_t2r = thr_tmem_load.partition_S(tdQtdQ) + tdQcdQ = thr_mma_dQ.partition_C(cute.make_identity_tensor(self.mma_tiler_dsk[:2])) + tdQrdQ_t2r_shape = thr_tmem_load.partition_D(tdQcdQ).shape + assert cute.size(tdQrdQ_t2r_shape, mode=[1]) == self.dQaccum_reduce_stage, ( + "dQaccum reduce stage mismatch" ) - thr_layout = cute.make_layout(shape=128, stride=1) - val_layout = cute.make_layout(shape=4, stride=1) - tiler_mn, layout_tv = cute.make_layout_tv(thr_layout=thr_layout, val_layout=val_layout) - tiled_smem_store = cute.make_tiled_copy( - atom_universal_copy, layout_tv=layout_tv, tiler_mn=tiler_mn - ) + thr_copy_dQaccum_r2s = copy_utils.tiled_copy_1d( + self.dqaccum_dtype, num_reduce_threads, num_copy_elems=128 // self.dqaccum_dtype.width + ).get_slice(tidx) + tdQsdQ = thr_copy_dQaccum_r2s.partition_D(sdQaccum) - smem_thr_copy_dQaccum = tiled_smem_store.get_slice(tidx) - tdQsdQ = smem_thr_copy_dQaccum.partition_D(sdQaccum) - store_bytes = cutlass.Int32(self.tile_m * 32 * 4) - - if const_expr(self.deterministic): - read_flag = False - else: - read_flag = True + read_flag = const_expr(not self.deterministic) reduce_phase = cutlass.Int32(0) - if cute.arch.thread_idx()[0] == 0: - cute.arch.mbarrier_arrive(dQaccum_reduce_mbar_ptr) - cute.arch.barrier( - barrier_id=int(NamedBarrierBwdSm100.dQaccReduce), number_of_threads=num_reduce_threads + dQacc_reduce_barrier = cutlass.pipeline.NamedBarrier( + barrier_id=int(NamedBarrierBwdSm100.dQaccReduce), + num_threads=num_reduce_threads, ) + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + dQ_consumer_state = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.dQaccum_mma_stage + ) while work_tile.is_valid_tile: n_block, head_idx, batch_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) - mdQaccum_cur = mdQaccum[None, head_idx, batch_idx] - + gdQaccum_ = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (None,)) + # (M * K / STAGE, STAGE, _) + gdQaccum = cute.flat_divide( + gdQaccum_, (self.tile_m * self.tile_hdim // self.dQaccum_reduce_stage,) + ) + mdQ_semaphore_cur = None if const_expr(self.deterministic): mdQ_semaphore_cur = mdQ_semaphore[None, None, head_idx, batch_idx] for i in cutlass.range(m_block_max - m_block_min, unroll=1): m_block = m_block_max - 1 - i - pipeline_dQ.consumer_wait(dQ_consumer_state) - # TMEM -> RMEM - tdQrdQ_t2r = cute.make_fragment(tdQrdQ.shape, Float32) - assert self.dQaccum_reduce_stage == cute.size(tdQrdQ_t2r, mode=[1]), ( - "dQaccum reduce stage mismatch" - ) - - cute.copy(thr_tmem_ld, tdQtdQ_t2r, tdQrdQ_t2r) + tdQrdQ_t2r = cute.make_fragment(tdQrdQ_t2r_shape, Float32) + cute.copy(thr_tmem_load, tdQtdQ_t2r, tdQrdQ_t2r) cute.arch.fence_view_async_tmem_load() - pipeline_dQ.consumer_release(dQ_consumer_state) dQ_consumer_state.advance() # semaphore acquire if const_expr(self.deterministic): barrier.wait_eq(mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, n_block) - cute.arch.barrier( - barrier_id=int(NamedBarrierBwdSm100.dQaccReduce), - number_of_threads=num_reduce_threads, - ) + dQacc_reduce_barrier.arrive_and_wait() for stage in cutlass.range_constexpr(cute.size(tdQrdQ_t2r, mode=[1])): # 4 - if stage >= 2 and cute.arch.thread_idx()[0] == 0: - cute.arch.cp_async_bulk_wait_group(1, read=read_flag) - - cute.arch.mbarrier_wait(dQaccum_reduce_mbar_ptr, reduce_phase) - - tdQrdQ_r2s = tdQrdQ_t2r[None, stage, None, None] tdQsdQ_r2s = tdQsdQ[None, None, reduce_phase] tdQrdQ_r2s = cute.make_tensor( - tdQrdQ_r2s.iterator, cute.make_layout(tdQsdQ_r2s.shape) + tdQrdQ_t2r[None, stage, None, None].iterator, tdQsdQ_r2s.shape ) - - cute.copy(smem_thr_copy_dQaccum, tdQrdQ_r2s, tdQsdQ_r2s) - + cute.copy(thr_copy_dQaccum_r2s, tdQrdQ_r2s, tdQsdQ_r2s) cute.arch.fence_proxy( cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta ) - cute.arch.barrier( - barrier_id=int(NamedBarrierBwdSm100.dQaccReduce), - number_of_threads=num_reduce_threads, - ) - - if cute.arch.thread_idx()[0] == 0: - smem_ptr = sdQaccum[None, reduce_phase].iterator - g_stage_index_elems = m_block * (self.tile_m * self.tile_hdimv) + stage * ( - self.tile_m * 32 - ) - gmem_row_ptr = cute.domain_offset( - (g_stage_index_elems,), mdQaccum_cur - ).iterator - - copy_utils.cpasync_reduce_bulk_add_f32(smem_ptr, gmem_row_ptr, store_bytes) + dQacc_reduce_barrier.arrive_and_wait() + if warp_idx == 0: + with cute.arch.elect_one(): + copy_utils.cpasync_reduce_bulk_add_f32( + sdQaccum[None, reduce_phase].iterator, + gdQaccum[None, stage, m_block].iterator, + self.tma_copy_bytes["dQ"], + ) cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(1, read=read_flag) - - cute.arch.mbarrier_arrive(dQaccum_reduce_mbar_ptr) - + dQacc_reduce_barrier.arrive_and_wait() reduce_phase ^= 1 - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta - ) - cute.arch.barrier( - barrier_id=int(NamedBarrierBwdSm100.dQaccReduce), - number_of_threads=num_reduce_threads, - ) - # semaphore release # NOTE: arrive_inc calls red_release which issues membar if const_expr(self.deterministic): - if cute.arch.thread_idx()[0] == 0: + if tidx == 0: cute.arch.cp_async_bulk_wait_group(0, read=read_flag) - cute.arch.barrier( - barrier_id=int(NamedBarrierBwdSm100.dQaccReduce), - number_of_threads=num_reduce_threads, - ) + dQacc_reduce_barrier.arrive_and_wait() barrier.arrive_inc(mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, 1) - if cute.arch.thread_idx()[0] == 0: + if warp_idx == 0: cute.arch.cp_async_bulk_wait_group(0, read=read_flag) - tile_scheduler.prefetch_next_work() tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() @@ -2189,7 +2130,7 @@ def epilogue_dK_or_dV_tma( num_epi_stages = cute.size(tdKVgdKV.shape[1]) assert num_epi_stages == 1 or num_epi_stages == 2, "Wrong number of epi stages" - tmem_ld_atom = cute.make_copy_atom( + tmem_load_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32 ) @@ -2213,17 +2154,17 @@ def epilogue_dK_or_dV_tma( for s in cutlass.range_constexpr(num_epi_stages): # TMEM -> RMEM -- setup - tiled_tmem_ld = tcgen05.make_tmem_copy(tmem_ld_atom, tdKVtdKV) - thr_tmem_ld = tiled_tmem_ld.get_slice(tidx) + tiled_tmem_ld = tcgen05.make_tmem_copy(tmem_load_atom, tdKVtdKV) + thr_tmem_load = tiled_tmem_ld.get_slice(tidx) - tdKVtdKV_t2r_p = thr_tmem_ld.partition_S(tdKVtdKV) + tdKVtdKV_t2r_p = thr_tmem_load.partition_S(tdKVtdKV) tdKVtdKV_t2r = self.split_wg(tdKVtdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0] if const_expr(num_epi_stages > 1): tdKVtdKV_t2r = tdKVtdKV_t2r[None, s] cdKV = cute.make_identity_tensor((self.tile_n, self.tile_hdim)) tdKVcdKV = thr_mma.partition_C(cdKV) - tdKVcdKV_t2r_p = thr_tmem_ld.partition_D(tdKVcdKV) + tdKVcdKV_t2r_p = thr_tmem_load.partition_D(tdKVcdKV) tdKVcdKV_t2r = self.split_wg(tdKVcdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0] if const_expr(num_epi_stages > 1): tdKVcdKV_t2r = tdKVcdKV_t2r[None, s] @@ -2235,7 +2176,7 @@ def epilogue_dK_or_dV_tma( ) # TMEM -> RMEM -- copy and fence - cute.copy(thr_tmem_ld, tdKVtdKV_t2r, tdKVrdKV_t2r) + cute.copy(thr_tmem_load, tdKVtdKV_t2r, tdKVrdKV_t2r) cute.arch.fence_view_async_tmem_load() # RMEM -- scale and convert From 7254904b5e8ad84e9625d8f70cd8cf4bab1f2a1c Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 19 Oct 2025 21:41:38 -0400 Subject: [PATCH 167/258] [Cute,Bwd,Sm100] Iterate from m_block_min -> m_block_max --- flash_attn/cute/flash_bwd_sm100.py | 38 ++++++++++++++---------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index faf4bf4a96a..8a653cb9912 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -713,8 +713,7 @@ def kernel( ) pipeline_consumer_group_MMA_AsyncThread_dQ = cutlass.pipeline.CooperativeGroup( cutlass.pipeline.Agent.Thread, - cute.arch.WARP_SIZE * len(self.reduce_warp_ids), - alignment=128, + len(self.reduce_warp_ids), ) # Compute pipeline_dQ = cutlass.pipeline.PipelineUmmaAsync.create( num_stages=self.dQaccum_mma_stage, @@ -1105,7 +1104,7 @@ def load( # K & Q pipeline_Q.producer_acquire(producer_state_Q, extra_tx_count=self.tma_copy_bytes["K"]) load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q)) - load_Q(m_block_max - 1, producer_state=producer_state_Q) + load_Q(m_block_min, producer_state=producer_state_Q) pipeline_Q.producer_commit(producer_state_Q) producer_state_Q.advance() # LSE @@ -1113,11 +1112,11 @@ def load( cute.arch.mbarrier_arrive_and_expect_tx( lse_full_mbar_ptr, self.tma_copy_bytes["LSE"] ) - load_LSE(src_idx=m_block_max - 1, dst_idx=0, tma_bar_ptr=lse_full_mbar_ptr) + load_LSE(src_idx=m_block_min, dst_idx=0, tma_bar_ptr=lse_full_mbar_ptr) # V & dO pipeline_dO.producer_acquire(producer_state_dO, extra_tx_count=self.tma_copy_bytes["V"]) load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO)) - load_dO(m_block_max - 1, producer_state=producer_state_dO) + load_dO(m_block_min, producer_state=producer_state_dO) pipeline_dO.producer_commit(producer_state_dO) producer_state_dO.advance() # dPsum @@ -1125,13 +1124,12 @@ def load( cute.arch.mbarrier_arrive_and_expect_tx( dpsum_full_mbar_ptr, self.tma_copy_bytes["dPsum"] ) - load_dPsum(src_idx=m_block_max - 1, dst_idx=0, tma_bar_ptr=dpsum_full_mbar_ptr) + load_dPsum(src_idx=m_block_min, dst_idx=0, tma_bar_ptr=dpsum_full_mbar_ptr) lse_empty_consumer_phase = cute.Int32(0) dpsum_empty_consumer_phase = cute.Int32(0) - for i in cutlass.range(m_block_max - m_block_min - 1, unroll=1): - m_block = m_block_max - 2 - i + for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1): # Q pipeline_Q.producer_acquire(producer_state_Q) load_Q(m_block, producer_state=producer_state_Q) @@ -1316,7 +1314,7 @@ def mma( # 4. dP = V @ dO.T # 5. dV = P.T @ dO - for i in cutlass.range(m_block_max - m_block_min - 1, unroll=1): + for _ in cutlass.range(m_block_min + 1, m_block_max, unroll=1): # 1) S = K @ Q_i consumer_state_Q_prev = consumer_state_Q.clone() consumer_state_Q.advance() @@ -1527,9 +1525,7 @@ def compute_loop( ) # Mainloop - for i in cutlass.range(m_block_max - m_block_min, unroll=1): - m_block = m_block_max - 1 - i - + for m_block in cutlass.range(m_block_min, m_block_max, unroll=1): pipeline_S.consumer_wait(s_consumer_state) pipeline_P.producer_acquire(p_producer_state) @@ -1537,8 +1533,8 @@ def compute_loop( cute.arch.mbarrier_wait(lse_full_mbar_ptr, lse_consumer_phase) lse_consumer_phase ^= 1 - tiled_tmem_ld = tcgen05.make_tmem_copy(tmem_load_atom, tStS) - thr_tmem_load = tiled_tmem_ld.get_slice(tidx) + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS) + thr_tmem_load = tiled_tmem_load.get_slice(tidx) tileP_f32_like = self.mma_tiler_kq[0] // 32 * self.v_dtype.width # (128, 64) tStP = cute.make_tensor( @@ -1562,7 +1558,7 @@ def compute_loop( tSrS_t2r = cute.make_fragment(tScS_t2r.shape, Float32) # 64 #### TMEM->RMEM (Load S from TMEM) - cute.copy(tiled_tmem_ld, tStS_t2r, tSrS_t2r) + cute.copy(tiled_tmem_load, tStS_t2r, tSrS_t2r) cute.arch.fence_view_async_tmem_load() #### Sync for load fence and LSE @@ -1862,6 +1858,7 @@ def dQacc_reduce( read_flag = const_expr(not self.deterministic) + # TODO: reduce_phase is currently hardcoded for 2 stages reduce_phase = cutlass.Int32(0) dQacc_reduce_barrier = cutlass.pipeline.NamedBarrier( @@ -1888,14 +1885,15 @@ def dQacc_reduce( if const_expr(self.deterministic): mdQ_semaphore_cur = mdQ_semaphore[None, None, head_idx, batch_idx] - for i in cutlass.range(m_block_max - m_block_min, unroll=1): - m_block = m_block_max - 1 - i + for m_block in cutlass.range(m_block_min, m_block_max, unroll=1): pipeline_dQ.consumer_wait(dQ_consumer_state) # TMEM -> RMEM tdQrdQ_t2r = cute.make_fragment(tdQrdQ_t2r_shape, Float32) cute.copy(thr_tmem_load, tdQtdQ_t2r, tdQrdQ_t2r) cute.arch.fence_view_async_tmem_load() - pipeline_dQ.consumer_release(dQ_consumer_state) + cute.arch.sync_warp() + with cute.arch.elect_one(): + pipeline_dQ.consumer_release(dQ_consumer_state) dQ_consumer_state.advance() # semaphore acquire @@ -2154,8 +2152,8 @@ def epilogue_dK_or_dV_tma( for s in cutlass.range_constexpr(num_epi_stages): # TMEM -> RMEM -- setup - tiled_tmem_ld = tcgen05.make_tmem_copy(tmem_load_atom, tdKVtdKV) - thr_tmem_load = tiled_tmem_ld.get_slice(tidx) + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tdKVtdKV) + thr_tmem_load = tiled_tmem_load.get_slice(tidx) tdKVtdKV_t2r_p = thr_tmem_load.partition_S(tdKVtdKV) tdKVtdKV_t2r = self.split_wg(tdKVtdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0] From 21876951ef2aa3ae7cc94c6bc79428fd7b4ce8c0 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 19 Oct 2025 22:52:22 -0400 Subject: [PATCH 168/258] [Cute,Bwd,Sm100] Try direct atomicadd rmem -> gmem --- flash_attn/cute/copy_utils.py | 40 ++++++++++++++++++++++++++++-- flash_attn/cute/flash_bwd_sm100.py | 38 +++++++++++++++++----------- 2 files changed, 61 insertions(+), 17 deletions(-) diff --git a/flash_attn/cute/copy_utils.py b/flash_attn/cute/copy_utils.py index 25263f2bd1f..a97344768de 100644 --- a/flash_attn/cute/copy_utils.py +++ b/flash_attn/cute/copy_utils.py @@ -5,8 +5,7 @@ import cutlass import cutlass.cute as cute - -from cutlass import Int32, Boolean, const_expr +from cutlass import Float32, Int32, Boolean, const_expr from cutlass.cute.nvgpu import cpasync from cutlass.cutlass_dsl import dsl_user_op from cutlass._mlir.dialects import llvm @@ -92,6 +91,43 @@ def tiled_copy_2d( return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) +@dsl_user_op +def atomic_add_fp32x4( + a: Float32, b: Float32, c: Float32, d: Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None +) -> None: + gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value() + # cache_hint = cutlass.Int64(0x12F0000000000000) + llvm.inline_asm( + None, + [ + gmem_ptr_i64, + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + Float32(c).ir_value(loc=loc, ip=ip), + Float32(d).ir_value(loc=loc, ip=ip), + ], + # [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip), cache_hint.ir_value()], + "{\n\t" + # ".reg .b128 abcd;\n\t" + # "mov.b128 abcd, {$1, $2, $3, $4};\n\t" + ".reg .v4 .f32 abcd;\n\t" + # "mov.b128 abcd, {$1, $2, $3, $4};\n\t" + "mov.f32 abcd.x, $1;\n\t" + "mov.f32 abcd.y, $2;\n\t" + "mov.f32 abcd.z, $3;\n\t" + "mov.f32 abcd.w, $4;\n\t" + "red.global.add.v4.f32 [$0], abcd;\n\t" + # "red.global.add.L2::cache_hint.v4.f32 [$0], abcd, 0x14F0000000000000;\n\t" + "}\n", + # "red.global.add.L2::cache_hint.f32 [$0], $1, 0x12F0000000000000;", + # "red.global.add.L2::cache_hint.f32 [$0], $1, $2;", + "l,f,f,f,f", + # "l,f,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + @dsl_user_op def cpasync_bulk_g2s( gmem_ptr: cute.Pointer, diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 8a653cb9912..aec993b998e 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -120,7 +120,8 @@ def __init__( self.num_regs_reduce = 144 self.num_regs_compute = 128 - self.num_regs_load = 96 + # self.num_regs_load = 96 + self.num_regs_load = 112 self.num_regs_mma = 112 self.num_regs_empty = 24 @@ -1629,7 +1630,7 @@ def compute_loop( own1, offset=j, mask=FULL, mask_and_clamp=MAC ) - tSrS_t2r[j, i, 0, 0], tSrS_t2r[j + 1, i, 0, 0] = cute.arch.fma_packed_f32x2( + tSrS_t2r[j, i, 0, 0], tSrS_t2r[j + 1, i, 0, 0] = utils.fma_packed_f32x2( ((tSrS_t2r[j, i, 0, 0], tSrS_t2r[j + 1, i, 0, 0])), (softmax_scale_log2, softmax_scale_log2), (-lse_j, -lse_j1), @@ -1736,7 +1737,7 @@ def compute_loop( (tdPrdP_t2r[j, 0, 0], tdPrdP_t2r[j + 1, 0, 0]), (psum_j, psum_j1) ) - tSrS_t2r[j, i, 0, 0], tSrS_t2r[j + 1, i, 0, 0] = cute.arch.mul_packed_f32x2( + tSrS_t2r[j, i, 0, 0], tSrS_t2r[j + 1, i, 0, 0] = utils.mul_packed_f32x2( (tSrS_t2r[j, i, 0, 0], tSrS_t2r[j + 1, i, 0, 0]), (tdPrdP_t2r[j, 0, 0], tdPrdP_t2r[j + 1, 0, 0]), ) @@ -1796,8 +1797,7 @@ def compute_loop( tma_atom_dV, thr_copy_r2s_dKdV, pipeline_dV, - softmax_scale, - False, # apply scale + None, # Don't scale int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id mdV_semaphore, ) @@ -1815,7 +1815,6 @@ def compute_loop( thr_copy_r2s_dKdV, pipeline_dK, softmax_scale, - True, # apply scale int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id mdK_semaphore, ) @@ -1922,6 +1921,17 @@ def dQacc_reduce( cute.arch.cp_async_bulk_wait_group(1, read=read_flag) dQacc_reduce_barrier.arrive_and_wait() reduce_phase ^= 1 + # Directly add to gmem, much slower + # tdQgdQ = thr_copy_dQaccum_r2s.partition_D(gdQaccum[None, stage, m_block]) + # assert cute.size(tdQrdQ_r2s) == cute.size(tdQgdQ) + # for i in cutlass.range(cute.size(tdQrdQ_r2s) // 4, unroll_full=True): + # copy_utils.atomic_add_fp32x4( + # tdQrdQ_r2s[4 * i], + # tdQrdQ_r2s[4 * i + 1], + # tdQrdQ_r2s[4 * i + 2], + # tdQrdQ_r2s[4 * i + 3], + # utils.elem_pointer(tdQgdQ, 4 * i), + # ) # semaphore release # NOTE: arrive_inc calls red_release which issues membar @@ -2089,8 +2099,7 @@ def epilogue_dK_or_dV_tma( tma_atom_dKV: cute.CopyAtom, thr_copy_r2s_dKdV: cute.TiledCopy, pipeline: PipelineAsync, - softmax_scale: Float32, - do_scale: cutlass.Constexpr[cutlass.Boolean], + scale: Optional[Float32], barrier_id: Int32, mdKV_semaphore: Optional[cute.Tensor], ): @@ -2178,14 +2187,13 @@ def epilogue_dK_or_dV_tma( cute.arch.fence_view_async_tmem_load() # RMEM -- scale and convert + if const_expr(scale is not None): + for i in cutlass.range(cute.size(tdKVrdKV_t2r.shape) // 2, unroll_full=True): + tdKVrdKV_t2r[2 * i], tdKVrdKV_t2r[2 * i + 1] = utils.mul_packed_f32x2( + (tdKVrdKV_t2r[2 * i], tdKVrdKV_t2r[2 * i + 1]), (scale, scale) + ) tdKVrdKV = cute.make_fragment(tdKVrdKV_t2r.shape, self.dv_dtype) - if const_expr(do_scale): - scale = softmax_scale - else: - scale = Float32(1) - - dKV_vec = tdKVrdKV_t2r.load() * scale - tdKVrdKV.store(dKV_vec.to(self.dv_dtype)) + tdKVrdKV.store(tdKVrdKV_t2r.load().to(self.dv_dtype)) # RMEM -> SMEM -- setup tdKVcdKV_r2s_p = thr_copy_r2s_dKdV.partition_S(cdKV) From 12e1c0498cf520458c290064e5493dc92f02a697 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 19 Oct 2025 23:34:31 -0400 Subject: [PATCH 169/258] [Cute,Bwd,Sm100] Combine pipeline_dK and pipeline_dV into one --- flash_attn/cute/flash_bwd_sm100.py | 358 +++++++++++++---------------- 1 file changed, 157 insertions(+), 201 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index aec993b998e..41a14180d55 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -130,22 +130,20 @@ def __init__( self.num_compute_threads = cute.arch.WARP_SIZE * len(self.compute_warp_ids) def _setup_attributes(self): - self.q_stage = 2 + self.Q_stage = 2 self.k_stage = self.v_stage = 1 self.dO_stage = 1 self.dS_stage = 1 - self.lse_stage = 1 + self.LSE_stage = 1 self.acc_stage = 1 - self.s_stage = 1 + self.S_stage = 1 self.dP_stage = 1 - self.dV_stage = 1 - self.dK_stage = 1 self.dS_stage = 1 self.dQaccum_mma_stage = 1 self.sdQaccum_stage = 2 - self.dpsum_stage = 1 + self.dPsum_stage = 1 self.p_tmem_stage = 1 - self.sdKdVaccum_stage = 2 + self.sdKVaccum_stage = 2 # number of tma reduce adds per dQacc mma self.dQaccum_reduce_stage = self.tile_hdim // 32 @@ -202,7 +200,7 @@ def _setup_smem_layout(self): self.tiled_mma_SdP, self.mma_tiler_kq, self.q_dtype, - self.q_stage, + self.Q_stage, ) # dV += P @ dO self.sdO_layout = sm100_utils_basic.make_smem_layout_b( @@ -235,7 +233,7 @@ def _setup_smem_layout(self): self.tiled_mma_dK, self.mma_tiler_dsq, self.q_dtype, - self.q_stage, + self.Q_stage, ) # dQaccum = dS @ K self.sdS_layout = sm100_utils_basic.make_smem_layout_a( @@ -253,11 +251,11 @@ def _setup_smem_layout(self): self.sdQaccum_layout = cute.make_layout((self.tile_m * 32, self.sdQaccum_stage)) self.sLSE_layout = cute.make_layout( - shape=(self.tile_m, self.lse_stage), + shape=(self.tile_m, self.LSE_stage), stride=(1, cute.round_up(self.tile_m, 64)), ) self.sdPsum_layout = cute.make_layout( - shape=(self.tile_m, self.dpsum_stage), + shape=(self.tile_m, self.dPsum_stage), stride=(1, cute.round_up(self.tile_m, 64)), ) @@ -344,35 +342,35 @@ def __call__( raise RuntimeError("The layout of mdK is wrong") if const_expr(dV_major_mode != tcgen05.OperandMajorMode.K): raise RuntimeError("The layout of mdV is wrong") - self.sdKdV_epi_tile = ( + self.sdKV_epi_tile = ( self.tile_n, 128 // (self.dk_dtype.width // 8), ) # subtiles mma_tiler_dsq[:2] = mma_tiler_pdo[:2] - sdKdV_layout = sm100_utils_basic.make_smem_layout_epi( + sdKV_layout = sm100_utils_basic.make_smem_layout_epi( self.dk_dtype, self.mdK_layout_enum, - self.sdKdV_epi_tile, - self.sdKdVaccum_stage, + self.sdKV_epi_tile, + self.sdKVaccum_stage, ) if const_expr(self.use_tma_store): if const_expr(self.dk_dtype.width == 32): - tma_copy_op_dKdV = cpasync.CopyReduceBulkTensorTileS2GOp() + tma_copy_op_dKV = cpasync.CopyReduceBulkTensorTileS2GOp() else: - tma_copy_op_dKdV = cpasync.CopyBulkTensorTileS2GOp() + tma_copy_op_dKV = cpasync.CopyBulkTensorTileS2GOp() tma_atom_dK, mdK_tma_tensor = cpasync.make_tiled_tma_atom( - tma_copy_op_dKdV, + tma_copy_op_dKV, mdK, - cute.select(sdKdV_layout, mode=[0, 1]), - self.sdKdV_epi_tile, + cute.select(sdKV_layout, mode=[0, 1]), + self.sdKV_epi_tile, 1, # no mcast ) tma_atom_dV, mdV_tma_tensor = cpasync.make_tiled_tma_atom( - tma_copy_op_dKdV, + tma_copy_op_dKV, mdV, - cute.select(sdKdV_layout, mode=[0, 1]), - self.sdKdV_epi_tile, + cute.select(sdKV_layout, mode=[0, 1]), + self.sdKV_epi_tile, 1, # no mcast ) else: @@ -382,19 +380,17 @@ def __call__( tma_atom_dV = None tma_atom_dK = None - thr_layout_r2s_dKdV = cute.make_ordered_layout( - (self.tile_n, 1), order=(1, 0) - ) # 128 threads - val_layout_r2s_dKdV = cute.make_ordered_layout( + thr_layout_r2s_dKV = cute.make_ordered_layout((self.tile_n, 1), order=(1, 0)) # 128 threads + val_layout_r2s_dKV = cute.make_ordered_layout( (1, 128 // self.dk_dtype.width), order=(1, 0) ) # 4 or 8 vals for 16 byte store - r2s_copy_atom_r2s_dKdV = cute.make_copy_atom( + r2s_copy_atom_r2s_dKV = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), self.dk_dtype, num_bits_per_copy=128, ) - tiled_copy_r2s_dKdV = cute.make_tiled_copy_tv( - r2s_copy_atom_r2s_dKdV, thr_layout_r2s_dKdV, val_layout_r2s_dKdV + tiled_copy_r2s_dKV = cute.make_tiled_copy_tv( + r2s_copy_atom_r2s_dKV, thr_layout_r2s_dKV, val_layout_r2s_dKV ) tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) @@ -474,19 +470,17 @@ def __call__( @cute.struct class SharedStorage: - q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.q_stage] - lse_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.lse_stage] - do_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dO_stage] - lse_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.k_stage] - lse_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.k_stage] - dpsum_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.dpsum_stage] - dpsum_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.dpsum_stage] - s_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.s_stage] + Q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.Q_stage] + dO_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dO_stage] + lse_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.LSE_stage] + lse_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.LSE_stage] + dpsum_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.dPsum_stage] + dpsum_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.dPsum_stage] + S_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.S_stage] dP_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dP_stage] - p_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.s_stage] + P_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.S_stage] dS_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dS_stage] - dV_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dV_stage] - dK_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dK_stage] + dKV_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 2] dQaccum_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dQaccum_mma_stage] # TMEM @@ -565,12 +559,12 @@ class SharedStorage: self.sdS_layout, self.sKt_layout, self.sdQaccum_layout, - sdKdV_layout, + sdKV_layout, self.tiled_mma_SdP, self.tiled_mma_dV, self.tiled_mma_dK, self.tiled_mma_dQ, - tiled_copy_r2s_dKdV, + tiled_copy_r2s_dKV, softmax_scale, softmax_scale_log2, tile_sched_params, @@ -618,12 +612,12 @@ def kernel( sdS_layout: cute.ComposedLayout, sKt_layout: cute.ComposedLayout, sdQaccum_layout: cute.Layout, - sdKdV_layout: cute.ComposedLayout, + sdKV_layout: cute.ComposedLayout, tiled_mma_SdP: cute.TiledMma, tiled_mma_dV: cute.TiledMma, tiled_mma_dK: cute.TiledMma, tiled_mma_dQ: cute.TiledMma, - tiled_copy_r2s_dKdV: cute.TiledCopy, + tiled_copy_r2s_dKV: cute.TiledCopy, softmax_scale: cutlass.Float32, softmax_scale_log2: cutlass.Float32, tile_sched_params: ParamsBase, @@ -667,18 +661,16 @@ def kernel( pipeline_consumer_group = cutlass.pipeline.CooperativeGroup( cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) ) - pipeline_Q = pipeline.PipelineTmaUmma.create( - barrier_storage=storage.q_mbar_ptr.data_ptr(), - num_stages=self.q_stage, + barrier_storage=storage.Q_mbar_ptr.data_ptr(), + num_stages=self.Q_stage, producer_group=pipeline_producer_group, consumer_group=pipeline_consumer_group, tx_count=self.tma_copy_bytes["Q"], init_wait=False, ) - pipeline_dO = pipeline.PipelineTmaUmma.create( - barrier_storage=storage.do_mbar_ptr.data_ptr(), + barrier_storage=storage.dO_mbar_ptr.data_ptr(), num_stages=self.dO_stage, producer_group=pipeline_producer_group, consumer_group=pipeline_consumer_group, @@ -690,27 +682,27 @@ def kernel( pipeline_producer_group_MMA_AsyncThread = cutlass.pipeline.CooperativeGroup( cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) ) + # Only 1 thread per warp will signal pipeline_consumer_group_MMA_AsyncThread = cutlass.pipeline.CooperativeGroup( - cutlass.pipeline.Agent.Thread, cute.arch.WARP_SIZE * len(self.compute_warp_ids) + cutlass.pipeline.Agent.Thread, len(self.compute_warp_ids) ) - pipeline_S = cutlass.pipeline.PipelineUmmaAsync.create( - num_stages=self.s_stage, + num_stages=self.S_stage, producer_group=pipeline_producer_group_MMA_AsyncThread, consumer_group=pipeline_consumer_group_MMA_AsyncThread, - barrier_storage=storage.s_mbar_ptr.data_ptr(), + barrier_storage=storage.S_mbar_ptr.data_ptr(), ) - pipeline_dV = cutlass.pipeline.PipelineUmmaAsync.create( - num_stages=self.dV_stage, + pipeline_dP = cutlass.pipeline.PipelineUmmaAsync.create( + num_stages=self.dP_stage, producer_group=pipeline_producer_group_MMA_AsyncThread, consumer_group=pipeline_consumer_group_MMA_AsyncThread, - barrier_storage=storage.dV_mbar_ptr.data_ptr(), + barrier_storage=storage.dP_mbar_ptr.data_ptr(), ) - pipeline_dK = cutlass.pipeline.PipelineUmmaAsync.create( - num_stages=self.dK_stage, + pipeline_dKV = cutlass.pipeline.PipelineUmmaAsync.create( + num_stages=2, producer_group=pipeline_producer_group_MMA_AsyncThread, consumer_group=pipeline_consumer_group_MMA_AsyncThread, - barrier_storage=storage.dK_mbar_ptr.data_ptr(), + barrier_storage=storage.dKV_mbar_ptr.data_ptr(), ) pipeline_consumer_group_MMA_AsyncThread_dQ = cutlass.pipeline.CooperativeGroup( cutlass.pipeline.Agent.Thread, @@ -722,32 +714,26 @@ def kernel( consumer_group=pipeline_consumer_group_MMA_AsyncThread_dQ, barrier_storage=storage.dQaccum_mbar_ptr.data_ptr(), ) - pipeline_dP = cutlass.pipeline.PipelineUmmaAsync.create( - num_stages=self.dP_stage, - producer_group=pipeline_producer_group_MMA_AsyncThread, - consumer_group=pipeline_consumer_group_MMA_AsyncThread, - barrier_storage=storage.dP_mbar_ptr.data_ptr(), - ) # AsyncThread producers and UMMA consumers - pipeline_pdS_producer_group = cutlass.pipeline.CooperativeGroup( + pipeline_PdS_producer_group = cutlass.pipeline.CooperativeGroup( cutlass.pipeline.Agent.Thread, cute.arch.WARP_SIZE * len(self.compute_warp_ids) ) # Compute - pipeline_pdS_consumer_group = cutlass.pipeline.CooperativeGroup( + pipeline_PdS_consumer_group = cutlass.pipeline.CooperativeGroup( cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) ) # MMA pipeline_P = cutlass.pipeline.PipelineAsyncUmma.create( - num_stages=self.s_stage, - producer_group=pipeline_pdS_producer_group, - consumer_group=pipeline_pdS_consumer_group, - barrier_storage=storage.p_mbar_ptr.data_ptr(), + num_stages=self.S_stage, + producer_group=pipeline_PdS_producer_group, + consumer_group=pipeline_PdS_consumer_group, + barrier_storage=storage.P_mbar_ptr.data_ptr(), ) pipeline_dS = cutlass.pipeline.PipelineAsyncUmma.create( num_stages=self.dS_stage, - producer_group=pipeline_pdS_producer_group, - consumer_group=pipeline_pdS_consumer_group, + producer_group=pipeline_PdS_producer_group, + consumer_group=pipeline_PdS_consumer_group, barrier_storage=storage.dS_mbar_ptr.data_ptr(), ) @@ -777,19 +763,19 @@ def kernel( sLSE_load = storage.sLSE.get_tensor(sLSE_layout) sLSE_mma = storage.sLSE.get_tensor( - cute.make_layout(shape=(self.tile_m, self.tile_n, self.lse_stage), stride=(0, 1, 0)) + cute.make_layout(shape=(self.tile_m, self.tile_n, self.LSE_stage), stride=(0, 1, 0)) ) sdPsum_load = storage.sdPsum.get_tensor(sdPsum_layout) sdPsum_mma = storage.sdPsum.get_tensor( - cute.make_layout(shape=(self.tile_m, self.tile_n, self.dpsum_stage), stride=(0, 1, 0)) + cute.make_layout(shape=(self.tile_m, self.tile_n, self.dPsum_stage), stride=(0, 1, 0)) ) sdV = storage.sdO.get_tensor( - sdKdV_layout.outer, swizzle=sdKdV_layout.inner, dtype=self.dk_dtype + sdKV_layout.outer, swizzle=sdKV_layout.inner, dtype=self.dk_dtype ) sdK = storage.sQ.get_tensor( - sdKdV_layout.outer, swizzle=sdKdV_layout.inner, dtype=self.dk_dtype + sdKV_layout.outer, swizzle=sdKV_layout.inner, dtype=self.dk_dtype ) assert cute.cosize(sdV) * self.dv_dtype.width <= cute.cosize(sdO) * self.do_dtype.width, ( @@ -936,8 +922,7 @@ def kernel( pipeline_S, pipeline_P, pipeline_dS, - pipeline_dV, - pipeline_dK, + pipeline_dKV, pipeline_dP, pipeline_dQ, block_info, @@ -978,8 +963,7 @@ def kernel( pipeline_S, pipeline_P, pipeline_dS, - pipeline_dV, - pipeline_dK, + pipeline_dKV, pipeline_dP, softmax_scale, softmax_scale_log2, @@ -993,7 +977,7 @@ def kernel( mdK_tma_tensor, tma_atom_dV, tma_atom_dK, - tiled_copy_r2s_dKdV, + tiled_copy_r2s_dKV, mdK_semaphore, mdV_semaphore, ) @@ -1050,7 +1034,7 @@ def load( TileSchedulerCls: Callable, ): producer_state_Q = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Producer, self.q_stage + cutlass.pipeline.PipelineUserType.Producer, self.Q_stage ) producer_state_dO = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Producer, self.dO_stage @@ -1191,8 +1175,7 @@ def mma( pipeline_S: PipelineAsync, pipeline_P: PipelineAsync, pipeline_dS: PipelineAsync, - pipeline_dV: PipelineAsync, - pipeline_dK: PipelineAsync, + pipeline_dKV: PipelineAsync, pipeline_dP: PipelineAsync, pipeline_dQ: PipelineAsync, block_info: BlockInfo, @@ -1239,28 +1222,25 @@ def mma( mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, tdKtdK, tdKrdS, tdKrQ, A_idx=0) consumer_state_Q = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Consumer, self.q_stage + cutlass.pipeline.PipelineUserType.Consumer, self.Q_stage ) consumer_state_dO = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage ) producer_state_S = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Producer, self.s_stage + cutlass.pipeline.PipelineUserType.Producer, self.S_stage ) producer_state_dP = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Producer, self.dP_stage ) consumer_state_P = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Consumer, self.s_stage + cutlass.pipeline.PipelineUserType.Consumer, self.S_stage ) consumer_state_dS = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.dS_stage ) - dV_producer_state = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Producer, self.dV_stage - ) - dK_producer_state = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Producer, self.dK_stage + producer_state_dKV = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, 2 ) producer_state_dQ = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Producer, self.dQaccum_mma_stage @@ -1354,13 +1334,9 @@ def mma( pipeline_dO.consumer_release(consumer_state_dO) consumer_state_dO.advance() - pipeline_dV.producer_acquire(dV_producer_state) - pipeline_dV.producer_commit(dV_producer_state) - dV_producer_state.advance() - - pipeline_S.producer_tail(producer_state_S) - pipeline_dP.producer_tail(producer_state_dP) - pipeline_dV.producer_tail(dV_producer_state) + pipeline_dKV.producer_acquire(producer_state_dKV) + pipeline_dKV.producer_commit(producer_state_dKV) + producer_state_dKV.advance() # ----------------------------------------------------------- ###### Remaining 2 @@ -1368,9 +1344,9 @@ def mma( # 1) dK += dS.T @ Q pipeline_dS.consumer_wait(consumer_state_dS) mma_dsq_fn(B_idx=consumer_state_Q.index, zero_init=not accumulate_dK) - pipeline_dK.producer_acquire(dK_producer_state) - pipeline_dK.producer_commit(dK_producer_state) - dK_producer_state.advance() + pipeline_dKV.producer_acquire(producer_state_dKV) + pipeline_dKV.producer_commit(producer_state_dKV) + producer_state_dKV.advance() # 2) dQ = dS @ K mma_dsk_fn(A_idx=consumer_state_dS.index) @@ -1381,12 +1357,14 @@ def mma( pipeline_dS.consumer_release(consumer_state_dS) consumer_state_dS.advance() - pipeline_dK.producer_tail(dK_producer_state) - pipeline_dQ.producer_tail(producer_state_dQ) - tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() + pipeline_S.producer_tail(producer_state_S) + pipeline_dP.producer_tail(producer_state_dP) + pipeline_dKV.producer_tail(producer_state_dKV) + pipeline_dQ.producer_tail(producer_state_dQ) + @cute.jit def split_wg( self, @@ -1452,8 +1430,7 @@ def compute_loop( pipeline_S: PipelineAsync, pipeline_P: PipelineAsync, pipeline_dS: PipelineAsync, - pipeline_dV: PipelineAsync, - pipeline_dK: PipelineAsync, + pipeline_dKV: PipelineAsync, pipeline_dP: PipelineAsync, softmax_scale: cutlass.Float32, softmax_scale_log2: cutlass.Float32, @@ -1467,7 +1444,7 @@ def compute_loop( mdK_tma_tensor: Optional[cute.Tensor], tma_atom_dV: Optional[cute.CopyAtom], tma_atom_dK: Optional[cute.CopyAtom], - tiled_copy_r2s_dKdV: Optional[cute.TiledCopy], + tiled_copy_r2s_dKV: Optional[cute.TiledCopy], mdK_semaphore: Optional[cute.Tensor], mdV_semaphore: Optional[cute.Tensor], ): @@ -1491,19 +1468,21 @@ def compute_loop( tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), Float32 ) - s_consumer_state = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Consumer, self.s_stage + consumer_state_S = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.S_stage ) - p_producer_state = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Producer, self.s_stage + producer_state_P = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.S_stage ) - dS_producer_state = cutlass.pipeline.make_pipeline_state( + producer_state_dS = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Producer, self.dS_stage ) - - dP_consumer_state = cutlass.pipeline.make_pipeline_state( + consumer_state_dP = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.dP_stage ) + consumer_state_dKV = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, 2 + ) lse_consumer_phase = psum_consumer_phase = cute.Int32(0) @@ -1527,8 +1506,8 @@ def compute_loop( # Mainloop for m_block in cutlass.range(m_block_min, m_block_max, unroll=1): - pipeline_S.consumer_wait(s_consumer_state) - pipeline_P.producer_acquire(p_producer_state) + pipeline_S.consumer_wait(consumer_state_S) + pipeline_P.producer_acquire(producer_state_P) if warp_idx == self.compute_warp_ids[0]: cute.arch.mbarrier_wait(lse_full_mbar_ptr, lse_consumer_phase) @@ -1603,11 +1582,6 @@ def compute_loop( ) tLSErLSE = tLSErLSE_p[None, (None, wg_idx), None, None] - WIDTH = cute.arch.WARP_SIZE - CLAMP = WIDTH - 1 - MAC = (0 << 8) | CLAMP - FULL = cute.arch.FULL_MASK - lidx = cute.arch.lane_idx() tSrP_r2t_f32 = cute.make_fragment(tScP_r2t[None, None, 0].shape, Float32) # 16 @@ -1619,17 +1593,9 @@ def compute_loop( for i in cutlass.range_constexpr(cute.size(tStP_r2t, mode=[2]), unroll=1): own0 = tLSErLSE[(lidx, 0), i, 0, 0] own1 = tLSErLSE[(lidx + 1, 0), i, 0, 0] - # own1 = cute.arch.shuffle_sync(own0, offset=((lidx + 1) & CLAMP), - # mask=FULL, mask_and_clamp=MAC) - for j in cutlass.range_constexpr(0, cute.size(tSrP_r2t), 2, unroll=1): - lse_j = cute.arch.shuffle_sync( - own0, offset=j, mask=FULL, mask_and_clamp=MAC - ) - lse_j1 = cute.arch.shuffle_sync( - own1, offset=j, mask=FULL, mask_and_clamp=MAC - ) - + lse_j = utils.shuffle_sync(own0, offset=j) + lse_j1 = utils.shuffle_sync(own1, offset=j) tSrS_t2r[j, i, 0, 0], tSrS_t2r[j + 1, i, 0, 0] = utils.fma_packed_f32x2( ((tSrS_t2r[j, i, 0, 0], tSrS_t2r[j + 1, i, 0, 0])), (softmax_scale_log2, softmax_scale_log2), @@ -1650,11 +1616,13 @@ def compute_loop( number_of_threads=self.num_compute_threads, ) - pipeline_P.producer_commit(p_producer_state) - p_producer_state.advance() + pipeline_P.producer_commit(producer_state_P) + producer_state_P.advance() - pipeline_S.consumer_release(s_consumer_state) - s_consumer_state.advance() + cute.arch.sync_warp() + with cute.arch.elect_one(): + pipeline_S.consumer_release(consumer_state_S) + consumer_state_S.advance() if warp_idx == self.compute_warp_ids[0]: with cute.arch.elect_one(): @@ -1667,8 +1635,8 @@ def compute_loop( cute.arch.mbarrier_wait(dpsum_full_mbar_ptr, psum_consumer_phase) psum_consumer_phase ^= 1 - pipeline_dP.consumer_wait(dP_consumer_state) - pipeline_dS.producer_acquire(dS_producer_state) + pipeline_dP.consumer_wait(consumer_state_dP) + pipeline_dS.producer_acquire(producer_state_dS) #### TMEM->RMEM (Load dP from TMEM) tiled_tmem_ld_dP = tcgen05.make_tmem_copy(tmem_load_atom, tdPtdP) @@ -1721,34 +1689,27 @@ def compute_loop( for i in cutlass.range_constexpr(cute.size(tSrS_t2r, mode=[1]), unroll=1): cute.copy(thr_tmem_ld_dP, tdPtdP_t2r[None, i, None, None], tdPrdP_t2r) cute.arch.fence_view_async_tmem_load() - own0 = tPsumrPsum[(lidx, 0), i, 0, 0] own1 = tPsumrPsum[(lidx + 1, 0), i, 0, 0] - for j in cutlass.range_constexpr(0, cute.size(tdPrdP_t2r), 2, unroll=1): - psum_j = cute.arch.shuffle_sync( - own0, offset=j, mask=FULL, mask_and_clamp=MAC - ) - psum_j1 = cute.arch.shuffle_sync( - own1, offset=j, mask=FULL, mask_and_clamp=MAC - ) - + psum_j = utils.shuffle_sync(own0, offset=j) + psum_j1 = utils.shuffle_sync(own1, offset=j) tdPrdP_t2r[j, 0, 0], tdPrdP_t2r[j + 1, 0, 0] = utils.sub_packed_f32x2( (tdPrdP_t2r[j, 0, 0], tdPrdP_t2r[j + 1, 0, 0]), (psum_j, psum_j1) ) - tSrS_t2r[j, i, 0, 0], tSrS_t2r[j + 1, i, 0, 0] = utils.mul_packed_f32x2( (tSrS_t2r[j, i, 0, 0], tSrS_t2r[j + 1, i, 0, 0]), (tdPrdP_t2r[j, 0, 0], tdPrdP_t2r[j + 1, 0, 0]), ) - tSrS_t2r_bf16[j, i, 0, 0] = tSrS_t2r[j, i, 0, 0].to(self.ds_dtype) tSrS_t2r_bf16[j + 1, i, 0, 0] = tSrS_t2r[j + 1, i, 0, 0].to(self.ds_dtype) cute.autovec_copy(tSrS_t2r_bf16[None, i, 0, 0], tdKsdS[None, i, 0, 0]) - pipeline_dP.consumer_release(dP_consumer_state) - dP_consumer_state.advance() + cute.arch.sync_warp() + with cute.arch.elect_one(): + pipeline_dP.consumer_release(consumer_state_dP) + consumer_state_dP.advance() cute.arch.fence_proxy( cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta @@ -1758,15 +1719,15 @@ def compute_loop( number_of_threads=self.num_compute_threads, ) - pipeline_dS.producer_commit(dS_producer_state) - dS_producer_state.advance() + pipeline_dS.producer_commit(producer_state_dS) + producer_state_dS.advance() if warp_idx == self.compute_warp_ids[0]: with cute.arch.elect_one(): cute.arch.mbarrier_arrive(dpsum_empty_mbar_ptr) if const_expr(not self.use_tma_store): - self.epilogue_dKV( + consumer_state_dKV = self.epilogue_dKV( tidx, warp_idx, batch_idx, @@ -1778,14 +1739,14 @@ def compute_loop( tdKtdK, mdV, mdK, - pipeline_dV, - pipeline_dK, + pipeline_dKV, + consumer_state_dKV, softmax_scale, ) else: - thr_copy_r2s_dKdV = tiled_copy_r2s_dKdV.get_slice(tidx) + thr_copy_r2s_dKV = tiled_copy_r2s_dKV.get_slice(tidx) #### STORE dV - self.epilogue_dK_or_dV_tma( + consumer_state_dKV = self.epilogue_dK_or_dV_tma( tidx, batch_idx, head_idx, @@ -1795,14 +1756,15 @@ def compute_loop( mdV_tma_tensor, sdV, tma_atom_dV, - thr_copy_r2s_dKdV, - pipeline_dV, + thr_copy_r2s_dKV, + pipeline_dKV, + consumer_state_dKV, None, # Don't scale int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id mdV_semaphore, ) #### STORE dK - self.epilogue_dK_or_dV_tma( + consumer_state_dKV = self.epilogue_dK_or_dV_tma( tidx, batch_idx, head_idx, @@ -1812,8 +1774,9 @@ def compute_loop( mdK_tma_tensor, sdK, tma_atom_dK, - thr_copy_r2s_dKdV, - pipeline_dK, + thr_copy_r2s_dKV, + pipeline_dKV, + consumer_state_dKV, softmax_scale, int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id mdK_semaphore, @@ -1961,8 +1924,8 @@ def epilogue_dKV( tdKtdK: cute.Tensor, mdV: cute.Tensor, mdK: cute.Tensor, - pipeline_dV: PipelineAsync, - pipeline_dK: PipelineAsync, + pipeline_dKV: PipelineAsync, + consumer_state_dKV: cutlass.pipeline.PipelineState, softmax_scale: Float32, ): wg_idx = ( @@ -1970,13 +1933,6 @@ def epilogue_dKV( ) // 128 num_wg = cute.arch.WARP_SIZE * len(self.compute_warp_ids) // 128 - dV_consumer_state = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Consumer, self.dV_stage - ) - dK_consumer_state = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Consumer, self.dK_stage - ) - assert self.qhead_per_kvhead == 1, "This epilogue path is only for MHA" mdV_cur = mdV[None, None, head_idx, batch_idx] mdK_cur = mdK[None, None, head_idx, batch_idx] @@ -1986,7 +1942,7 @@ def epilogue_dKV( ) # dV - pipeline_dV.consumer_wait(dV_consumer_state) + pipeline_dKV.consumer_wait(consumer_state_dKV) tiled_tmem_ld_dV = tcgen05.make_tmem_copy(tmem_load_atom, tdVtdV) thr_tmem_ld_dV = tiled_tmem_ld_dV.get_slice(tidx) @@ -2031,11 +1987,13 @@ def epilogue_dKV( cute.copy(tiled_gmem_store_dV, tdVrdV_r2s, tdVgdV_r2g) - pipeline_dV.consumer_release(dV_consumer_state) - dV_consumer_state.advance() + cute.arch.sync_warp() + with cute.arch.elect_one(): + pipeline_dKV.consumer_release(consumer_state_dKV) + consumer_state_dKV.advance() # dK - pipeline_dK.consumer_wait(dK_consumer_state) + pipeline_dKV.consumer_wait(consumer_state_dKV) tiled_tmem_ld_dK = tcgen05.make_tmem_copy(tmem_load_atom, tdKtdK) thr_tmem_ld_dK = tiled_tmem_ld_dK.get_slice(tidx) @@ -2082,8 +2040,11 @@ def epilogue_dKV( cute.copy(tiled_gmem_store_dK, tdKrdK_r2s, tdKgdK_r2g) - pipeline_dK.consumer_release(dK_consumer_state) - dK_consumer_state.advance() + cute.arch.sync_warp() + with cute.arch.elect_one(): + pipeline_dKV.consumer_release(consumer_state_dKV) + consumer_state_dKV.advance() + return consumer_state_dKV @cute.jit def epilogue_dK_or_dV_tma( @@ -2097,12 +2058,13 @@ def epilogue_dK_or_dV_tma( mdKV: cute.Tensor, sdKV: cute.Tensor, tma_atom_dKV: cute.CopyAtom, - thr_copy_r2s_dKdV: cute.TiledCopy, - pipeline: PipelineAsync, + thr_copy_r2s_dKV: cute.TiledCopy, + pipeline_dKV: PipelineAsync, + consumer_state_dKV: cutlass.pipeline.PipelineState, scale: Optional[Float32], barrier_id: Int32, mdKV_semaphore: Optional[cute.Tensor], - ): + ) -> cutlass.pipeline.PipelineState: # assumes mma_tiler_pdo = mma_tiler_dsq = (tile_n, head_dim) # head_dim = head_dim_v, dk_dtype = dv_dtype @@ -2117,7 +2079,7 @@ def epilogue_dK_or_dV_tma( gdKV_p = cute.local_tile(mdKV_cur, (self.tile_m, self.tile_hdimv), (n_block, 0)) gdKV = self.split_wg(gdKV_p, wg_idx, num_wg) - gdKV_epi = cute.local_tile(gdKV, self.sdKdV_epi_tile, (0, None)) + gdKV_epi = cute.local_tile(gdKV, self.sdKV_epi_tile, (0, None)) if const_expr(self.deterministic and self.qhead_per_kvhead > 1): mdKV_semaphore_cur = mdKV_semaphore[n_block, None, head_idx_kv, batch_idx] @@ -2141,16 +2103,9 @@ def epilogue_dK_or_dV_tma( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32 ) - if const_expr(self.deterministic): - read_flag = False - else: - read_flag = True + read_flag = const_expr(not self.deterministic) - # TODO: maybe support more than 1 stage - consumer_state = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Consumer, 1 - ) - pipeline.consumer_wait(consumer_state) + pipeline_dKV.consumer_wait(consumer_state_dKV) # semaphore acquire if const_expr(self.deterministic): @@ -2161,9 +2116,7 @@ def epilogue_dK_or_dV_tma( for s in cutlass.range_constexpr(num_epi_stages): # TMEM -> RMEM -- setup - tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tdKVtdKV) - thr_tmem_load = tiled_tmem_load.get_slice(tidx) - + thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tdKVtdKV).get_slice(tidx) tdKVtdKV_t2r_p = thr_tmem_load.partition_S(tdKVtdKV) tdKVtdKV_t2r = self.split_wg(tdKVtdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0] if const_expr(num_epi_stages > 1): @@ -2196,7 +2149,7 @@ def epilogue_dK_or_dV_tma( tdKVrdKV.store(tdKVrdKV_t2r.load().to(self.dv_dtype)) # RMEM -> SMEM -- setup - tdKVcdKV_r2s_p = thr_copy_r2s_dKdV.partition_S(cdKV) + tdKVcdKV_r2s_p = thr_copy_r2s_dKV.partition_S(cdKV) tdKVcdKV_r2s = self.split_wg(tdKVcdKV_r2s_p, wg_idx, num_wg) tdKVcdKV_r2s = cute.logical_divide( tdKVcdKV_r2s, @@ -2209,14 +2162,14 @@ def epilogue_dK_or_dV_tma( tdKVrdKV_r2s = cute.make_tensor(tdKVrdKV.iterator, tdKVcdKV_r2s.shape) - tdKVsdKV_r2s = thr_copy_r2s_dKdV.partition_D(sdKV) + tdKVsdKV_r2s = thr_copy_r2s_dKV.partition_D(sdKV) assert cute.size(tdKVrdKV_r2s) == cute.size(tdKVsdKV_r2s), ( "RMEM<->SMEM fragment size mismatch" ) # RMEM -> SMEM -- copy, fence and barrier - cute.copy(thr_copy_r2s_dKdV, tdKVrdKV_r2s, tdKVsdKV_r2s) + cute.copy(thr_copy_r2s_dKV, tdKVrdKV_r2s, tdKVsdKV_r2s) cute.arch.fence_proxy( cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta ) @@ -2249,5 +2202,8 @@ def epilogue_dK_or_dV_tma( cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128) barrier.arrive_inc(mdKV_semaphore_cur.iterator, tidx, wg_idx, 1) - pipeline.consumer_release(consumer_state) - consumer_state.advance() + cute.arch.sync_warp() + with cute.arch.elect_one(): + pipeline_dKV.consumer_release(consumer_state_dKV) + consumer_state_dKV.advance() + return consumer_state_dKV From d101fa73c6a8ccb4e0b95eb2aea77d1dfc1ad39e Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 20 Oct 2025 00:21:55 -0400 Subject: [PATCH 170/258] [Cute,Bwd,Sm100] All compute warps wait for lse_barrier --- flash_attn/cute/flash_bwd_sm100.py | 157 +++++++++++++---------------- 1 file changed, 68 insertions(+), 89 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 41a14180d55..ff0a74d5d2d 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -26,8 +26,7 @@ ParamsBase, ) -# from flash_attn.cute import barrier -from flash_attn.cute import named_barrier as barrier # TODO: temp, to make linter pass +from flash_attn.cute import barrier from flash_attn.cute.named_barrier import NamedBarrierBwdSm100 @@ -139,7 +138,6 @@ def _setup_attributes(self): self.S_stage = 1 self.dP_stage = 1 self.dS_stage = 1 - self.dQaccum_mma_stage = 1 self.sdQaccum_stage = 2 self.dPsum_stage = 1 self.p_tmem_stage = 1 @@ -472,16 +470,16 @@ def __call__( class SharedStorage: Q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.Q_stage] dO_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dO_stage] - lse_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.LSE_stage] - lse_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.LSE_stage] - dpsum_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.dPsum_stage] - dpsum_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.dPsum_stage] + LSE_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.LSE_stage] + LSE_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.LSE_stage] + dPsum_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.dPsum_stage] + dPsum_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.dPsum_stage] S_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.S_stage] dP_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dP_stage] P_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.S_stage] dS_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dS_stage] dKV_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 2] - dQaccum_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dQaccum_mma_stage] + dQaccum_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2] # TMEM tmem_holding_buf: Int32 @@ -641,19 +639,19 @@ def kernel( storage = smem.allocate(self.shared_storage) tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr.data_ptr() - lse_full_mbar_ptr = storage.lse_full_mbar_ptr.data_ptr() - lse_empty_mbar_ptr = storage.lse_empty_mbar_ptr.data_ptr() - dpsum_full_mbar_ptr = storage.dpsum_full_mbar_ptr.data_ptr() - dpsum_empty_mbar_ptr = storage.dpsum_empty_mbar_ptr.data_ptr() + LSE_full_mbar_ptr = storage.LSE_full_mbar_ptr.data_ptr() + LSE_empty_mbar_ptr = storage.LSE_empty_mbar_ptr.data_ptr() + dPsum_full_mbar_ptr = storage.dPsum_full_mbar_ptr.data_ptr() + dPsum_empty_mbar_ptr = storage.dPsum_empty_mbar_ptr.data_ptr() if warp_idx == self.load_warp_id: cute.arch.mbarrier_init( tmem_dealloc_mbar_ptr, cute.arch.WARP_SIZE * len(self.compute_warp_ids) ) - cute.arch.mbarrier_init(lse_full_mbar_ptr, len([self.compute_warp_ids])) - cute.arch.mbarrier_init(lse_empty_mbar_ptr, len([self.compute_warp_ids])) - cute.arch.mbarrier_init(dpsum_full_mbar_ptr, len([self.compute_warp_ids])) - cute.arch.mbarrier_init(dpsum_empty_mbar_ptr, len([self.compute_warp_ids])) + cute.arch.mbarrier_init(LSE_full_mbar_ptr, 1) + cute.arch.mbarrier_init(LSE_empty_mbar_ptr, len(self.compute_warp_ids)) + cute.arch.mbarrier_init(dPsum_full_mbar_ptr, 1) + cute.arch.mbarrier_init(dPsum_empty_mbar_ptr, len([self.compute_warp_ids])) pipeline_producer_group = cutlass.pipeline.CooperativeGroup( cutlass.pipeline.Agent.Thread, len([self.load_warp_id]) @@ -709,7 +707,7 @@ def kernel( len(self.reduce_warp_ids), ) # Compute pipeline_dQ = cutlass.pipeline.PipelineUmmaAsync.create( - num_stages=self.dQaccum_mma_stage, + num_stages=1, producer_group=pipeline_producer_group_MMA_AsyncThread, consumer_group=pipeline_consumer_group_MMA_AsyncThread_dQ, barrier_storage=storage.dQaccum_mbar_ptr.data_ptr(), @@ -878,10 +876,10 @@ def kernel( tma_atom_V, tma_atom_dO, pipeline_Q, - lse_full_mbar_ptr, - lse_empty_mbar_ptr, - dpsum_full_mbar_ptr, - dpsum_empty_mbar_ptr, + LSE_full_mbar_ptr, + LSE_empty_mbar_ptr, + dPsum_full_mbar_ptr, + dPsum_empty_mbar_ptr, pipeline_dO, block_info, SeqlenInfoCls, @@ -956,10 +954,10 @@ def kernel( sdSt, sdS, tdPtdP, - lse_full_mbar_ptr, - lse_empty_mbar_ptr, - dpsum_full_mbar_ptr, - dpsum_empty_mbar_ptr, + LSE_full_mbar_ptr, + LSE_empty_mbar_ptr, + dPsum_full_mbar_ptr, + dPsum_empty_mbar_ptr, pipeline_S, pipeline_P, pipeline_dS, @@ -1024,10 +1022,10 @@ def load( tma_atom_V: cute.CopyAtom, tma_atom_dO: cute.CopyAtom, pipeline_Q: PipelineAsync, - lse_full_mbar_ptr: cute.Pointer, - lse_empty_mbar_ptr: cute.Pointer, - dpsum_full_mbar_ptr: cute.Pointer, - dpsum_empty_mbar_ptr: cute.Pointer, + LSE_full_mbar_ptr: cute.Pointer, + LSE_empty_mbar_ptr: cute.Pointer, + dPsum_full_mbar_ptr: cute.Pointer, + dPsum_empty_mbar_ptr: cute.Pointer, pipeline_dO: PipelineAsync, block_info: BlockInfo, SeqlenInfoCls: Callable, @@ -1095,9 +1093,9 @@ def load( # LSE with cute.arch.elect_one(): cute.arch.mbarrier_arrive_and_expect_tx( - lse_full_mbar_ptr, self.tma_copy_bytes["LSE"] + LSE_full_mbar_ptr, self.tma_copy_bytes["LSE"] ) - load_LSE(src_idx=m_block_min, dst_idx=0, tma_bar_ptr=lse_full_mbar_ptr) + load_LSE(src_idx=m_block_min, dst_idx=0, tma_bar_ptr=LSE_full_mbar_ptr) # V & dO pipeline_dO.producer_acquire(producer_state_dO, extra_tx_count=self.tma_copy_bytes["V"]) load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO)) @@ -1107,9 +1105,9 @@ def load( # dPsum with cute.arch.elect_one(): cute.arch.mbarrier_arrive_and_expect_tx( - dpsum_full_mbar_ptr, self.tma_copy_bytes["dPsum"] + dPsum_full_mbar_ptr, self.tma_copy_bytes["dPsum"] ) - load_dPsum(src_idx=m_block_min, dst_idx=0, tma_bar_ptr=dpsum_full_mbar_ptr) + load_dPsum(src_idx=m_block_min, dst_idx=0, tma_bar_ptr=dPsum_full_mbar_ptr) lse_empty_consumer_phase = cute.Int32(0) dpsum_empty_consumer_phase = cute.Int32(0) @@ -1121,26 +1119,26 @@ def load( pipeline_Q.producer_commit(producer_state_Q) producer_state_Q.advance() # LSE - cute.arch.mbarrier_wait(lse_empty_mbar_ptr, lse_empty_consumer_phase) + cute.arch.mbarrier_wait(LSE_empty_mbar_ptr, lse_empty_consumer_phase) lse_empty_consumer_phase ^= 1 with cute.arch.elect_one(): cute.arch.mbarrier_arrive_and_expect_tx( - lse_full_mbar_ptr, self.tma_copy_bytes["LSE"] + LSE_full_mbar_ptr, self.tma_copy_bytes["LSE"] ) - load_LSE(src_idx=m_block, dst_idx=0, tma_bar_ptr=lse_full_mbar_ptr) + load_LSE(src_idx=m_block, dst_idx=0, tma_bar_ptr=LSE_full_mbar_ptr) # dO pipeline_dO.producer_acquire(producer_state_dO) load_dO(m_block, producer_state=producer_state_dO) pipeline_dO.producer_commit(producer_state_dO) producer_state_dO.advance() # dPsum - cute.arch.mbarrier_wait(dpsum_empty_mbar_ptr, dpsum_empty_consumer_phase) + cute.arch.mbarrier_wait(dPsum_empty_mbar_ptr, dpsum_empty_consumer_phase) dpsum_empty_consumer_phase ^= 1 with cute.arch.elect_one(): cute.arch.mbarrier_arrive_and_expect_tx( - dpsum_full_mbar_ptr, self.tma_copy_bytes["dPsum"] + dPsum_full_mbar_ptr, self.tma_copy_bytes["dPsum"] ) - load_dPsum(src_idx=m_block, dst_idx=0, tma_bar_ptr=dpsum_full_mbar_ptr) + load_dPsum(src_idx=m_block, dst_idx=0, tma_bar_ptr=dPsum_full_mbar_ptr) pipeline_Q.producer_tail(producer_state_Q) pipeline_dO.producer_tail(producer_state_dO) @@ -1243,7 +1241,7 @@ def mma( cutlass.pipeline.PipelineUserType.Producer, 2 ) producer_state_dQ = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Producer, self.dQaccum_mma_stage + cutlass.pipeline.PipelineUserType.Producer, 1 ) tile_scheduler = TileSchedulerCls() @@ -1423,10 +1421,10 @@ def compute_loop( sdSt: cute.Tensor, sdSt_pi: cute.Tensor, tdPtdP: cute.Tensor, - lse_full_mbar_ptr: cute.Pointer, - lse_empty_mbar_ptr: cute.Pointer, - dpsum_full_mbar_ptr: cute.Pointer, - dpsum_empty_mbar_ptr: cute.Pointer, + LSE_full_mbar_ptr: cute.Pointer, + LSE_empty_mbar_ptr: cute.Pointer, + dPsum_full_mbar_ptr: cute.Pointer, + dPsum_empty_mbar_ptr: cute.Pointer, pipeline_S: PipelineAsync, pipeline_P: PipelineAsync, pipeline_dS: PipelineAsync, @@ -1450,7 +1448,6 @@ def compute_loop( ): # tix: [128...384] 8 warps warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) # 4-11 - tidx = cute.arch.thread_idx()[0] % 128 # 0...128 wg_idx = ( cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.compute_warp_ids)) @@ -1484,7 +1481,7 @@ def compute_loop( cutlass.pipeline.PipelineUserType.Consumer, 2 ) - lse_consumer_phase = psum_consumer_phase = cute.Int32(0) + consumer_phase_LSE = consumer_phase_dPsum = cute.Int32(0) tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() @@ -1508,22 +1505,14 @@ def compute_loop( for m_block in cutlass.range(m_block_min, m_block_max, unroll=1): pipeline_S.consumer_wait(consumer_state_S) pipeline_P.producer_acquire(producer_state_P) + cute.arch.mbarrier_wait(LSE_full_mbar_ptr, consumer_phase_LSE) + consumer_phase_LSE ^= 1 - if warp_idx == self.compute_warp_ids[0]: - cute.arch.mbarrier_wait(lse_full_mbar_ptr, lse_consumer_phase) - lse_consumer_phase ^= 1 - - tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS) - thr_tmem_load = tiled_tmem_load.get_slice(tidx) - + thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS).get_slice(tidx) tileP_f32_like = self.mma_tiler_kq[0] // 32 * self.v_dtype.width # (128, 64) - tStP = cute.make_tensor( - tStS.iterator, - cute.composition(tStS.layout, cute.make_layout((self.tile_m, tileP_f32_like))), - ) + tStP = cute.composition(tStS, cute.make_layout((self.tile_m, tileP_f32_like))) - tiled_tmem_st = tcgen05.make_tmem_copy(tmem_store_atom, tStP) - thr_tmem_st = tiled_tmem_st.get_slice(tidx) + thr_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP).get_slice(tidx) #### TMEM tStS_t2r_p = thr_tmem_load.partition_S(tStS) @@ -1531,17 +1520,17 @@ def compute_loop( #### RMEM tScS = thr_mma_SdP.partition_C(cute.make_identity_tensor(self.mma_tiler_kq[:2])) - tScS_tensor = cute.make_tensor(tScS.iterator, tScS.layout) - tScS_t2r_p = thr_tmem_load.partition_D(tScS_tensor) + tScS_t2r_p = thr_tmem_load.partition_D(tScS) tScS_t2r = self.split_wg(tScS_t2r_p, wg_idx, num_wg) tSrS_t2r = cute.make_fragment(tScS_t2r.shape, Float32) # 64 #### TMEM->RMEM (Load S from TMEM) - cute.copy(tiled_tmem_load, tStS_t2r, tSrS_t2r) + cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r) cute.arch.fence_view_async_tmem_load() - #### Sync for load fence and LSE + # Without this barrier, we could have 1 warp writing to P in tmem while + # another warp is still reading S from tmem. cute.arch.barrier( barrier_id=int(NamedBarrierBwdSm100.Compute), number_of_threads=self.num_compute_threads, @@ -1549,11 +1538,7 @@ def compute_loop( #### APPLY MASK if const_expr(self.is_causal or self.is_local): - mask_fn( - tSrS_t2r, - tScS_t2r, - m_block=m_block, - ) + mask_fn(tSrS_t2r, tScS_t2r, m_block=m_block) # --------------------------------------------- #### P = exp(S - LSE) @@ -1565,10 +1550,10 @@ def compute_loop( cute.composition(tScS.layout, cute.make_layout((self.tile_m, tileP_f32_like))), ) - tScP_r2t_p = thr_tmem_st.partition_S(cP_f32) + tScP_r2t_p = thr_tmem_store.partition_S(cP_f32) tScP_r2t = self.split_wg(tScP_r2t_p, wg_idx, num_wg) - tStP_r2t_p = thr_tmem_st.partition_D(tStP) + tStP_r2t_p = thr_tmem_store.partition_D(tStP) tStP_r2t = self.split_wg(tStP_r2t_p, wg_idx, num_wg) #### Compute P = exp(S * scale - LSE) @@ -1582,7 +1567,7 @@ def compute_loop( ) tLSErLSE = tLSErLSE_p[None, (None, wg_idx), None, None] - lidx = cute.arch.lane_idx() + lane_idx = cute.arch.lane_idx() tSrP_r2t_f32 = cute.make_fragment(tScP_r2t[None, None, 0].shape, Float32) # 16 tSrP_r2t = cute.make_tensor( @@ -1591,8 +1576,8 @@ def compute_loop( ) for i in cutlass.range_constexpr(cute.size(tStP_r2t, mode=[2]), unroll=1): - own0 = tLSErLSE[(lidx, 0), i, 0, 0] - own1 = tLSErLSE[(lidx + 1, 0), i, 0, 0] + own0 = tLSErLSE[(lane_idx, 0), i, 0, 0] + own1 = tLSErLSE[(lane_idx + 1, 0), i, 0, 0] for j in cutlass.range_constexpr(0, cute.size(tSrP_r2t), 2, unroll=1): lse_j = utils.shuffle_sync(own0, offset=j) lse_j1 = utils.shuffle_sync(own1, offset=j) @@ -1601,20 +1586,14 @@ def compute_loop( (softmax_scale_log2, softmax_scale_log2), (-lse_j, -lse_j1), ) - tSrS_t2r[j, i, 0, 0] = cute.arch.exp2(tSrS_t2r[j, i, 0, 0]) tSrS_t2r[j + 1, i, 0, 0] = cute.arch.exp2(tSrS_t2r[j + 1, i, 0, 0]) - tSrP_r2t[j, 0, 0] = tSrS_t2r[j, i, 0, 0].to(self.q_dtype) tSrP_r2t[j + 1, 0, 0] = tSrS_t2r[j + 1, i, 0, 0].to(self.q_dtype) - cute.copy(thr_tmem_st, tSrP_r2t_f32[None, None], tStP_r2t[None, None, i]) + cute.copy(thr_tmem_store, tSrP_r2t_f32[None, None], tStP_r2t[None, None, i]) cute.arch.fence_view_async_tmem_store() - cute.arch.barrier( - barrier_id=int(NamedBarrierBwdSm100.Compute), - number_of_threads=self.num_compute_threads, - ) pipeline_P.producer_commit(producer_state_P) producer_state_P.advance() @@ -1624,16 +1603,16 @@ def compute_loop( pipeline_S.consumer_release(consumer_state_S) consumer_state_S.advance() - if warp_idx == self.compute_warp_ids[0]: - with cute.arch.elect_one(): - cute.arch.mbarrier_arrive(lse_empty_mbar_ptr) + # Already sync_warp before this + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(LSE_empty_mbar_ptr) # --------------------------------------------- # dS.T = P.T * (dP.T - D) # --------------------------------------------- if warp_idx == self.compute_warp_ids[0]: - cute.arch.mbarrier_wait(dpsum_full_mbar_ptr, psum_consumer_phase) - psum_consumer_phase ^= 1 + cute.arch.mbarrier_wait(dPsum_full_mbar_ptr, consumer_phase_dPsum) + consumer_phase_dPsum ^= 1 pipeline_dP.consumer_wait(consumer_state_dP) pipeline_dS.producer_acquire(producer_state_dS) @@ -1689,8 +1668,8 @@ def compute_loop( for i in cutlass.range_constexpr(cute.size(tSrS_t2r, mode=[1]), unroll=1): cute.copy(thr_tmem_ld_dP, tdPtdP_t2r[None, i, None, None], tdPrdP_t2r) cute.arch.fence_view_async_tmem_load() - own0 = tPsumrPsum[(lidx, 0), i, 0, 0] - own1 = tPsumrPsum[(lidx + 1, 0), i, 0, 0] + own0 = tPsumrPsum[(lane_idx, 0), i, 0, 0] + own1 = tPsumrPsum[(lane_idx + 1, 0), i, 0, 0] for j in cutlass.range_constexpr(0, cute.size(tdPrdP_t2r), 2, unroll=1): psum_j = utils.shuffle_sync(own0, offset=j) psum_j1 = utils.shuffle_sync(own1, offset=j) @@ -1724,7 +1703,7 @@ def compute_loop( if warp_idx == self.compute_warp_ids[0]: with cute.arch.elect_one(): - cute.arch.mbarrier_arrive(dpsum_empty_mbar_ptr) + cute.arch.mbarrier_arrive(dPsum_empty_mbar_ptr) if const_expr(not self.use_tma_store): consumer_state_dKV = self.epilogue_dKV( @@ -1831,7 +1810,7 @@ def dQacc_reduce( tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() dQ_consumer_state = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Consumer, self.dQaccum_mma_stage + cutlass.pipeline.PipelineUserType.Consumer, 1 ) while work_tile.is_valid_tile: n_block, head_idx, batch_idx = work_tile.tile_idx From 82c9cbb97fe4c406c63a47a6bc8afc79041ae82f Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 20 Oct 2025 00:28:02 -0400 Subject: [PATCH 171/258] [Cute,Bwd,Sm100] sdQaccum doesn't need swizzle --- flash_attn/cute/flash_bwd_sm100.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index ff0a74d5d2d..0c2bdad1ced 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -110,11 +110,11 @@ def __init__( SM100_TMEM_CAPACITY_COLUMNS = 512 self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS - self.tmem_s_offset = 0 - self.tmem_p_offset = 0 # overlap with S - self.tmem_dV_offset = self.tmem_s_offset + self.tile_n + self.tmem_S_offset = 0 + self.tmem_P_offset = 0 # overlap with S + self.tmem_dV_offset = self.tmem_S_offset + self.tile_n self.tmem_dP_offset = self.tmem_dV_offset + self.tile_hdimv - self.tmem_dQaccum_offset = self.tmem_dP_offset # overlap with dP + self.tmem_dQ_offset = self.tmem_dP_offset # overlap with dP self.tmem_dK_offset = self.tmem_dP_offset + self.tile_m self.num_regs_reduce = 144 @@ -783,8 +783,7 @@ def kernel( "Not enough space for sdK" ) - swz128 = cute.make_swizzle(3, 4, 3) - sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout, swizzle=swz128) + sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout) # TMEM # S @@ -806,7 +805,7 @@ def kernel( thr_mma_dQ = tiled_mma_dQ.get_slice(0) dQacc_shape = thr_mma_dQ.partition_shape_C(self.mma_tiler_dsk[:2]) tdQtdQ = thr_mma_dQ.make_fragment_C(dQacc_shape) - tdQtdQ = cute.make_tensor(tdQtdQ.iterator + self.tmem_dQaccum_offset, tdQtdQ.layout) + tdQtdQ = cute.make_tensor(tdQtdQ.iterator + self.tmem_dQ_offset, tdQtdQ.layout) # dP dPacc_shape = thr_mma_SdP.partition_shape_C(self.mma_tiler_vdo[:2]) tdPtdP = thr_mma_SdP.make_fragment_C(dPacc_shape) From 91f14ca07b792645b72efbb05b233907a831c898 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 20 Oct 2025 01:15:54 -0400 Subject: [PATCH 172/258] [Cute,Bwd,Sm100] Try gemm_ptx --- flash_attn/cute/blackwell_helpers.py | 23 ++++++++++ flash_attn/cute/flash_bwd_sm100.py | 64 ++++++++++++++++++---------- 2 files changed, 64 insertions(+), 23 deletions(-) diff --git a/flash_attn/cute/blackwell_helpers.py b/flash_attn/cute/blackwell_helpers.py index aefb6182575..83ba1cd518d 100644 --- a/flash_attn/cute/blackwell_helpers.py +++ b/flash_attn/cute/blackwell_helpers.py @@ -35,6 +35,29 @@ def gemm_w_idx( cute.gemm(mma_atom, acc, rA[None, None, k], rB[None, None, k], acc) +@cute.jit +def gemm_ptx_w_idx( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + A_idx: Optional[Int32] = None, + B_idx: Optional[Int32] = None, + zero_init: bool | Boolean = False, +) -> None: + rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] + rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] + sA_cur = None + if const_expr(sA is not None): + sA_cur = sA if const_expr(A_idx is None) else sA[None, None, None, A_idx] + sB_cur = sB if const_expr(B_idx is None) else sB[None, None, None, B_idx] + mma_atom = cute.make_mma_atom(tiled_mma.op) + acc_tmem_addr = acc.iterator.toint() + gemm_ptx_partial(mma_atom.op, acc_tmem_addr, rA, rB, sA_cur, sB_cur, zero_init=zero_init) + + @cute.jit def gemm( tiled_mma: cute.TiledMma, diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 0c2bdad1ced..a3cf59b697e 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -16,7 +16,7 @@ from flash_attn.cute import utils from flash_attn.cute import copy_utils from flash_attn.cute import pipeline -from flash_attn.cute.blackwell_helpers import gemm_w_idx +from flash_attn.cute.blackwell_helpers import gemm_w_idx, gemm_ptx_w_idx # noqa from flash_attn.cute.mask import AttentionMask from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo @@ -759,18 +759,18 @@ def kernel( cute.recast_ptr(sdO.iterator, swizzle_=sdOt_layout.inner), sdOt_layout.outer ) - sLSE_load = storage.sLSE.get_tensor(sLSE_layout) + sLSE = storage.sLSE.get_tensor(sLSE_layout) sLSE_mma = storage.sLSE.get_tensor( cute.make_layout(shape=(self.tile_m, self.tile_n, self.LSE_stage), stride=(0, 1, 0)) ) - sdPsum_load = storage.sdPsum.get_tensor(sdPsum_layout) + sdPsum = storage.sdPsum.get_tensor(sdPsum_layout) sdPsum_mma = storage.sdPsum.get_tensor( cute.make_layout(shape=(self.tile_m, self.tile_n, self.dPsum_stage), stride=(0, 1, 0)) ) sdV = storage.sdO.get_tensor( - sdKV_layout.outer, swizzle=sdKV_layout.inner, dtype=self.dk_dtype + sdKV_layout.outer, swizzle=sdKV_layout.inner, dtype=self.dv_dtype ) sdK = storage.sQ.get_tensor( sdKV_layout.outer, swizzle=sdKV_layout.inner, dtype=self.dk_dtype @@ -867,8 +867,8 @@ def kernel( sQ, sK, sV, - sLSE_load, - sdPsum_load, + sLSE, + sdPsum, sdO, tma_atom_Q, tma_atom_K, @@ -1209,14 +1209,29 @@ def mma( tdVrP = cute.make_tensor(tdVrP.iterator, tdVrP.layout) mma_qk_fn = partial(gemm_w_idx, tiled_mma_SdP, tStS, tSrK, tSrQ, A_idx=0, zero_init=True) + # mma_qk_fn = partial( + # gemm_ptx_w_idx, tiled_mma_SdP, tStS, tSrK, tSrQ, sA=sK, sB=sQ, A_idx=0, zero_init=True + # ) mma_dov_fn = partial( gemm_w_idx, tiled_mma_SdP, tdPtdP, tdPrV, tdPrdOt, A_idx=0, zero_init=True ) + # mma_dov_fn = partial( + # gemm_ptx_w_idx, tiled_mma_SdP, tdPtdP, tdPrV, tdPrdOt, sA=sV, sB=sdOt, A_idx=0, zero_init=True + # ) mma_pdo_fn = partial(gemm_w_idx, tiled_mma_dV, tdVtdV, tdVrP, tdVrdO, A_idx=None) + # mma_pdo_fn = partial( + # gemm_ptx_w_idx, tiled_mma_dV, tdVtdV, tdVrP, tdVrdO, sA=None, sB=sdO, A_idx=None + # ) mma_dsk_fn = partial( gemm_w_idx, tiled_mma_dQ, tdQtdQ, tdQrdS, tdQrK, B_idx=0, zero_init=True ) + # mma_dsk_fn = partial( + # gemm_ptx_w_idx, tiled_mma_dQ, tdQtdQ, tdQrdS, tdQrK, sA=sdS, sB=sKt, B_idx=0, zero_init=True + # ) mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, tdKtdK, tdKrdS, tdKrQ, A_idx=0) + # mma_dsq_fn = partial( + # gemm_ptx_w_idx, tiled_mma_dK, tdKtdK, tdKrdS, tdKrQ, sA=sdSt, sB=sQt, A_idx=0 + # ) consumer_state_Q = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.Q_stage @@ -1270,7 +1285,7 @@ def mma( # 2) dP = V @ dO.T pipeline_dO.consumer_wait(consumer_state_dO) pipeline_dP.producer_acquire(producer_state_dP) - pipeline_dQ.producer_acquire(producer_state_dQ) + pipeline_dQ.producer_acquire(producer_state_dQ) # dQ uses the same tmem as dP mma_dov_fn(B_idx=consumer_state_dO.index) # Don't release dO yet pipeline_dP.producer_commit(producer_state_dP) @@ -1304,7 +1319,7 @@ def mma( # 2) dQ = dS @ K pipeline_dS.consumer_wait(consumer_state_dS) - pipeline_dP.producer_acquire(producer_state_dP) + pipeline_dP.producer_acquire(producer_state_dP) # dP uses the same tmem as dQ mma_dsk_fn(A_idx=consumer_state_dS.index) pipeline_dQ.producer_commit(producer_state_dQ) producer_state_dQ.advance() @@ -1318,7 +1333,7 @@ def mma( # 4) dP = V @ dO.T pipeline_dO.consumer_wait(consumer_state_dO) - pipeline_dQ.producer_acquire(producer_state_dQ) + pipeline_dQ.producer_acquire(producer_state_dQ) # dQ uses the same tmem as dP mma_dov_fn(B_idx=consumer_state_dO.index) pipeline_dP.producer_commit(producer_state_dP) producer_state_dP.advance() @@ -1331,9 +1346,11 @@ def mma( pipeline_dO.consumer_release(consumer_state_dO) consumer_state_dO.advance() + # signal to the epilogue that dV is ready pipeline_dKV.producer_acquire(producer_state_dKV) pipeline_dKV.producer_commit(producer_state_dKV) producer_state_dKV.advance() + pipeline_dKV.producer_acquire(producer_state_dKV) # ----------------------------------------------------------- ###### Remaining 2 @@ -1341,7 +1358,7 @@ def mma( # 1) dK += dS.T @ Q pipeline_dS.consumer_wait(consumer_state_dS) mma_dsq_fn(B_idx=consumer_state_Q.index, zero_init=not accumulate_dK) - pipeline_dKV.producer_acquire(producer_state_dKV) + # signal to the epilogue that dK is ready pipeline_dKV.producer_commit(producer_state_dKV) producer_state_dKV.advance() @@ -1349,6 +1366,7 @@ def mma( mma_dsk_fn(A_idx=consumer_state_dS.index) pipeline_dQ.producer_commit(producer_state_dQ) producer_state_dQ.advance() + # Wait until dQ is done before releasing Q, since K and Q0 uses the same mbarrier pipeline_Q.consumer_release(consumer_state_Q) consumer_state_Q.advance() pipeline_dS.consumer_release(consumer_state_dS) @@ -1556,15 +1574,15 @@ def compute_loop( tStP_r2t = self.split_wg(tStP_r2t_p, wg_idx, num_wg) #### Compute P = exp(S * scale - LSE) - tLSE = thr_tmem_load.partition_D(sLSE_2D) + tLSEsLSE_s2r = thr_tmem_load.partition_D(sLSE_2D) # split to wg0 & wg1 - tLSErLSE_p = cute.make_tensor( - cute.recast_ptr(tLSE.iterator), + tLSEsLSE_p = cute.make_tensor( + cute.recast_ptr(tLSEsLSE_s2r.iterator), cute.make_layout( (tScS_t2r_p.shape[0], (tScS_t2r_p.shape[1] // num_wg, num_wg), 1, 1) ), ) - tLSErLSE = tLSErLSE_p[None, (None, wg_idx), None, None] + tLSEsLSE = tLSEsLSE_p[None, (None, wg_idx), None, None] lane_idx = cute.arch.lane_idx() @@ -1575,8 +1593,8 @@ def compute_loop( ) for i in cutlass.range_constexpr(cute.size(tStP_r2t, mode=[2]), unroll=1): - own0 = tLSErLSE[(lane_idx, 0), i, 0, 0] - own1 = tLSErLSE[(lane_idx + 1, 0), i, 0, 0] + own0 = tLSEsLSE[(lane_idx, 0), i, 0, 0] + own1 = tLSEsLSE[(lane_idx + 1, 0), i, 0, 0] for j in cutlass.range_constexpr(0, cute.size(tSrP_r2t), 2, unroll=1): lse_j = utils.shuffle_sync(own0, offset=j) lse_j1 = utils.shuffle_sync(own1, offset=j) @@ -1653,22 +1671,22 @@ def compute_loop( cute.recast_ptr(tSrS_t2r.iterator, dtype=self.ds_dtype), tSrS_t2r.shape ) - tPsum = thr_tmem_load.partition_D(sPsum_2D) - tPsumrPsum_p = cute.make_tensor( - cute.recast_ptr(tPsum.iterator), + tLSEsdPsum_s2r = thr_tmem_load.partition_D(sPsum_2D) + tLSEsdPsum_p = cute.make_tensor( + cute.recast_ptr(tLSEsdPsum_s2r.iterator), cute.make_layout( (tScS_t2r_p.shape[0], (tScS_t2r_p.shape[1] // num_wg, num_wg), 1, 1) ), ) - tPsumrPsum = tPsumrPsum_p[ + tLSEsdPsum = tLSEsdPsum_p[ None, (None, wg_idx), None, None - ] # self.split_wg(tLSErLSE_p, wg_idx, num_wg) + ] # self.split_wg(tLSEsLSE_p, wg_idx, num_wg) for i in cutlass.range_constexpr(cute.size(tSrS_t2r, mode=[1]), unroll=1): cute.copy(thr_tmem_ld_dP, tdPtdP_t2r[None, i, None, None], tdPrdP_t2r) cute.arch.fence_view_async_tmem_load() - own0 = tPsumrPsum[(lane_idx, 0), i, 0, 0] - own1 = tPsumrPsum[(lane_idx + 1, 0), i, 0, 0] + own0 = tLSEsdPsum[(lane_idx, 0), i, 0, 0] + own1 = tLSEsdPsum[(lane_idx + 1, 0), i, 0, 0] for j in cutlass.range_constexpr(0, cute.size(tdPrdP_t2r), 2, unroll=1): psum_j = utils.shuffle_sync(own0, offset=j) psum_j1 = utils.shuffle_sync(own1, offset=j) From 53c884b793cbd882de438f1082fa740415a06105 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 20 Oct 2025 20:39:09 -0400 Subject: [PATCH 173/258] [Cute,Bwd,Sm100] Clean up compute fn --- flash_attn/cute/flash_bwd_sm100.py | 221 ++++++++++++----------------- 1 file changed, 93 insertions(+), 128 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index a3cf59b697e..c6eea6e5260 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -651,7 +651,7 @@ def kernel( cute.arch.mbarrier_init(LSE_full_mbar_ptr, 1) cute.arch.mbarrier_init(LSE_empty_mbar_ptr, len(self.compute_warp_ids)) cute.arch.mbarrier_init(dPsum_full_mbar_ptr, 1) - cute.arch.mbarrier_init(dPsum_empty_mbar_ptr, len([self.compute_warp_ids])) + cute.arch.mbarrier_init(dPsum_empty_mbar_ptr, len(self.compute_warp_ids)) pipeline_producer_group = cutlass.pipeline.CooperativeGroup( cutlass.pipeline.Agent.Thread, len([self.load_warp_id]) @@ -748,8 +748,6 @@ def kernel( sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) sdSt = storage.sdS.get_tensor(sdSt_layout.outer, swizzle=sdSt_layout.inner) - sdSt_pi = storage.sdS.get_tensor(sdSt_layout) - sdS = cute.make_tensor( cute.recast_ptr(sdSt.iterator, swizzle_=sdS_layout.inner), sdS_layout.outer ) @@ -760,14 +758,7 @@ def kernel( ) sLSE = storage.sLSE.get_tensor(sLSE_layout) - sLSE_mma = storage.sLSE.get_tensor( - cute.make_layout(shape=(self.tile_m, self.tile_n, self.LSE_stage), stride=(0, 1, 0)) - ) - sdPsum = storage.sdPsum.get_tensor(sdPsum_layout) - sdPsum_mma = storage.sdPsum.get_tensor( - cute.make_layout(shape=(self.tile_m, self.tile_n, self.dPsum_stage), stride=(0, 1, 0)) - ) sdV = storage.sdO.get_tensor( sdKV_layout.outer, swizzle=sdKV_layout.inner, dtype=self.dv_dtype @@ -944,8 +935,8 @@ def kernel( thr_mma_dV, thr_mma_dK, tStS, - sLSE_mma, - sdPsum_mma, + sLSE, + sdPsum, tdVtdV, tdKtdK, mdV, @@ -1429,14 +1420,14 @@ def compute_loop( thr_mma_dV: cute.core.ThrMma, thr_mma_dK: cute.core.ThrMma, tStS: cute.Tensor, - sLSE_2D: cute.Tensor, - sPsum_2D: cute.Tensor, + sLSE: cute.Tensor, + sdPsum: cute.Tensor, tdVtdV: cute.Tensor, tdKtdK: cute.Tensor, mdV: cute.Tensor, mdK: cute.Tensor, sdSt: cute.Tensor, - sdSt_pi: cute.Tensor, + sdS: cute.Tensor, tdPtdP: cute.Tensor, LSE_full_mbar_ptr: cute.Pointer, LSE_empty_mbar_ptr: cute.Pointer, @@ -1463,24 +1454,65 @@ def compute_loop( mdK_semaphore: Optional[cute.Tensor], mdV_semaphore: Optional[cute.Tensor], ): + sLSE_2D = cute.make_tensor( + sLSE.iterator, + cute.make_layout( + (self.tile_m, self.tile_n, self.LSE_stage), + stride=(1, 0, cute.round_up(self.tile_m, 64)), + ), + ) + sdPsum_2D = cute.make_tensor( + sdPsum.iterator, + cute.make_layout( + (self.tile_m, self.tile_n, self.dPsum_stage), + stride=(1, 0, cute.round_up(self.tile_m, 64)), + ), + ) + # if const_expr(self.SdP_swapAB): + if const_expr(True): + sLSE_2D = utils.transpose_view(sLSE_2D) + sdPsum_2D = utils.transpose_view(sdPsum_2D) # tix: [128...384] 8 warps warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) # 4-11 tidx = cute.arch.thread_idx()[0] % 128 # 0...128 wg_idx = ( cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.compute_warp_ids)) ) // 128 + wg_idx = cute.arch.make_warp_uniform(wg_idx) num_wg = cute.arch.WARP_SIZE * len(self.compute_warp_ids) // 128 # 2 # wg_idx: # 0: [256...384] # 1: [128...256] + tileP_f32_like = self.mma_tiler_kq[0] // 32 * self.v_dtype.width # (128, 64) + tStP = cute.composition(tStS, cute.make_layout((self.tile_m, tileP_f32_like))) + tScS = thr_mma_SdP.partition_C(cute.make_identity_tensor(self.mma_tiler_kq[:2])) + tScP = cute.composition(tScS, cute.make_layout((self.tile_m, tileP_f32_like))) + tmem_load_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32 ) + thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS).get_slice(tidx) + tStS_t2r_p = thr_tmem_load.partition_S(tStS) + tStS_t2r = self.split_wg(tStS_t2r_p, wg_idx, num_wg) + tdPtdP_t2r_p = thr_tmem_load.partition_S(tdPtdP) + tdPtdP_t2r = self.split_wg(tdPtdP_t2r_p, wg_idx, num_wg) + tScS_t2r_p = thr_tmem_load.partition_D(tScS) + tScS_t2r = self.split_wg(tScS_t2r_p, wg_idx, num_wg) + tSsLSE_p = thr_tmem_load.partition_D(thr_mma_SdP.partition_C(sLSE_2D)) + tSsLSE = self.split_wg(tSsLSE_p, wg_idx, num_wg) # ((32, 1), 2, 1, 1, STAGE) + tSsdPsum_p = thr_tmem_load.partition_D(thr_mma_SdP.partition_C(sdPsum_2D)) + tSsdPsum = self.split_wg(tSsdPsum_p, wg_idx, num_wg) + tmem_store_atom = cute.make_copy_atom( tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), Float32 ) + thr_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP).get_slice(tidx) + tScP_r2t_p = thr_tmem_store.partition_S(tScP) + tScP_r2t = self.split_wg(tScP_r2t_p, wg_idx, num_wg) + tStP_r2t_p = thr_tmem_store.partition_D(tStP) + tStP_r2t = self.split_wg(tStP_r2t_p, wg_idx, num_wg) consumer_state_S = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.S_stage @@ -1521,31 +1553,14 @@ def compute_loop( # Mainloop for m_block in cutlass.range(m_block_min, m_block_max, unroll=1): pipeline_S.consumer_wait(consumer_state_S) - pipeline_P.producer_acquire(producer_state_P) - cute.arch.mbarrier_wait(LSE_full_mbar_ptr, consumer_phase_LSE) - consumer_phase_LSE ^= 1 - - thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS).get_slice(tidx) - tileP_f32_like = self.mma_tiler_kq[0] // 32 * self.v_dtype.width # (128, 64) - tStP = cute.composition(tStS, cute.make_layout((self.tile_m, tileP_f32_like))) - - thr_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP).get_slice(tidx) - - #### TMEM - tStS_t2r_p = thr_tmem_load.partition_S(tStS) - tStS_t2r = self.split_wg(tStS_t2r_p, wg_idx, num_wg) - - #### RMEM - tScS = thr_mma_SdP.partition_C(cute.make_identity_tensor(self.mma_tiler_kq[:2])) - tScS_t2r_p = thr_tmem_load.partition_D(tScS) - tScS_t2r = self.split_wg(tScS_t2r_p, wg_idx, num_wg) - - tSrS_t2r = cute.make_fragment(tScS_t2r.shape, Float32) # 64 - #### TMEM->RMEM (Load S from TMEM) + tSrS_t2r = cute.make_fragment(tScS_t2r.shape, Float32) # 64 cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r) cute.arch.fence_view_async_tmem_load() + cute.arch.mbarrier_wait(LSE_full_mbar_ptr, consumer_phase_LSE) + consumer_phase_LSE ^= 1 + # Without this barrier, we could have 1 warp writing to P in tmem while # another warp is still reading S from tmem. cute.arch.barrier( @@ -1561,29 +1576,6 @@ def compute_loop( #### P = exp(S - LSE) # --------------------------------------------- - #### RMEM (coordinates for P) - cP_f32 = cute.make_tensor( - tScS.iterator, - cute.composition(tScS.layout, cute.make_layout((self.tile_m, tileP_f32_like))), - ) - - tScP_r2t_p = thr_tmem_store.partition_S(cP_f32) - tScP_r2t = self.split_wg(tScP_r2t_p, wg_idx, num_wg) - - tStP_r2t_p = thr_tmem_store.partition_D(tStP) - tStP_r2t = self.split_wg(tStP_r2t_p, wg_idx, num_wg) - - #### Compute P = exp(S * scale - LSE) - tLSEsLSE_s2r = thr_tmem_load.partition_D(sLSE_2D) - # split to wg0 & wg1 - tLSEsLSE_p = cute.make_tensor( - cute.recast_ptr(tLSEsLSE_s2r.iterator), - cute.make_layout( - (tScS_t2r_p.shape[0], (tScS_t2r_p.shape[1] // num_wg, num_wg), 1, 1) - ), - ) - tLSEsLSE = tLSEsLSE_p[None, (None, wg_idx), None, None] - lane_idx = cute.arch.lane_idx() tSrP_r2t_f32 = cute.make_fragment(tScP_r2t[None, None, 0].shape, Float32) # 16 @@ -1592,26 +1584,27 @@ def compute_loop( tSrS_t2r[None, 0, None, None].layout, ) - for i in cutlass.range_constexpr(cute.size(tStP_r2t, mode=[2]), unroll=1): - own0 = tLSEsLSE[(lane_idx, 0), i, 0, 0] - own1 = tLSEsLSE[(lane_idx + 1, 0), i, 0, 0] - for j in cutlass.range_constexpr(0, cute.size(tSrP_r2t), 2, unroll=1): - lse_j = utils.shuffle_sync(own0, offset=j) - lse_j1 = utils.shuffle_sync(own1, offset=j) - tSrS_t2r[j, i, 0, 0], tSrS_t2r[j + 1, i, 0, 0] = utils.fma_packed_f32x2( - ((tSrS_t2r[j, i, 0, 0], tSrS_t2r[j + 1, i, 0, 0])), + pipeline_P.producer_acquire(producer_state_P) + for stage in cutlass.range_constexpr(cute.size(tStP_r2t, mode=[2]), unroll=1): + tSrS_cur = tSrS_t2r[None, stage, 0, 0] + tSsLSE_cur = tSsLSE[None, stage, 0, 0, 0] # TODO: have stages + lse_val = tSsLSE_cur[lane_idx] + for v in cutlass.range_constexpr(cute.size(tSrP_r2t) // 2, unroll_full=True): + lse_pair = ( + utils.shuffle_sync(lse_val, offset=2 * v), + utils.shuffle_sync(lse_val, offset=2 * v + 1), + ) + tSrS_cur[2 * v], tSrS_cur[2 * v + 1] = utils.fma_packed_f32x2( + ((tSrS_cur[2 * v], tSrS_cur[2 * v + 1])), (softmax_scale_log2, softmax_scale_log2), - (-lse_j, -lse_j1), + (-lse_pair[0], -lse_pair[1]), ) - tSrS_t2r[j, i, 0, 0] = cute.arch.exp2(tSrS_t2r[j, i, 0, 0]) - tSrS_t2r[j + 1, i, 0, 0] = cute.arch.exp2(tSrS_t2r[j + 1, i, 0, 0]) - tSrP_r2t[j, 0, 0] = tSrS_t2r[j, i, 0, 0].to(self.q_dtype) - tSrP_r2t[j + 1, 0, 0] = tSrS_t2r[j + 1, i, 0, 0].to(self.q_dtype) - - cute.copy(thr_tmem_store, tSrP_r2t_f32[None, None], tStP_r2t[None, None, i]) + tSrS_cur[2 * v] = cute.math.exp2(tSrS_cur[2 * v], fastmath=True) + tSrS_cur[2 * v + 1] = cute.math.exp2(tSrS_cur[2 * v + 1], fastmath=True) + utils.cvt_f16(tSrS_cur, tSrP_r2t[None, 0, 0]) + cute.copy(thr_tmem_store, tSrP_r2t_f32, tStP_r2t[None, None, stage]) cute.arch.fence_view_async_tmem_store() - pipeline_P.producer_commit(producer_state_P) producer_state_P.advance() @@ -1627,30 +1620,15 @@ def compute_loop( # --------------------------------------------- # dS.T = P.T * (dP.T - D) # --------------------------------------------- - if warp_idx == self.compute_warp_ids[0]: - cute.arch.mbarrier_wait(dPsum_full_mbar_ptr, consumer_phase_dPsum) + cute.arch.mbarrier_wait(dPsum_full_mbar_ptr, consumer_phase_dPsum) consumer_phase_dPsum ^= 1 pipeline_dP.consumer_wait(consumer_state_dP) pipeline_dS.producer_acquire(producer_state_dS) #### TMEM->RMEM (Load dP from TMEM) - tiled_tmem_ld_dP = tcgen05.make_tmem_copy(tmem_load_atom, tdPtdP) - thr_tmem_ld_dP = tiled_tmem_ld_dP.get_slice(tidx) - - tdPtdP_t2r_p = thr_tmem_ld_dP.partition_S(tdPtdP) # - tdPtdP_t2r = self.split_wg(tdPtdP_t2r_p, wg_idx, num_wg) - - #### TMEM->RMEM (Load dP from TMEM) - cdP = cute.make_identity_tensor((self.mma_tiler_vdo[0], self.mma_tiler_vdo[1])) - tdPcdP = thr_mma_SdP.partition_C(cdP) - tdPcdP_tensor = cute.make_tensor(tdPcdP.iterator, tdPcdP.layout) - - tdPcdP_t2r_p = thr_tmem_ld_dP.partition_D(tdPcdP_tensor) - tdPcdP_t2r = self.split_wg(tdPcdP_t2r_p, wg_idx, num_wg) - tdPrdP_t2r = cute.make_fragment( - tdPcdP_t2r[(None, 0, None, None)].shape, Float32 - ) # ((32,1),1,1) + # ((32,1),1,1) + tdPrdP_t2r = cute.make_fragment(tScS_t2r[None, 0, None, None].shape, Float32) #### Sync for load fence and Psum cute.arch.barrier( @@ -1659,48 +1637,35 @@ def compute_loop( ) ##### dS.T = P.T * (dP.T - Psum) - sdSt_mn = cute.make_tensor( - sdSt_pi.iterator, - cute.composition(sdSt_pi.layout, cute.make_layout((self.tile_m, self.tile_n))), - ) + sdSt_mn = cute.composition(sdS, cute.make_layout((self.tile_m, self.tile_n))) tdKsdS = cute.composition( sdSt_mn[(None, wg_idx), tidx], cute.make_layout(tSrS_t2r.shape) ) - tSrS_t2r_bf16 = cute.make_tensor( cute.recast_ptr(tSrS_t2r.iterator, dtype=self.ds_dtype), tSrS_t2r.shape ) - tLSEsdPsum_s2r = thr_tmem_load.partition_D(sPsum_2D) - tLSEsdPsum_p = cute.make_tensor( - cute.recast_ptr(tLSEsdPsum_s2r.iterator), - cute.make_layout( - (tScS_t2r_p.shape[0], (tScS_t2r_p.shape[1] // num_wg, num_wg), 1, 1) - ), - ) - tLSEsdPsum = tLSEsdPsum_p[ - None, (None, wg_idx), None, None - ] # self.split_wg(tLSEsLSE_p, wg_idx, num_wg) - - for i in cutlass.range_constexpr(cute.size(tSrS_t2r, mode=[1]), unroll=1): - cute.copy(thr_tmem_ld_dP, tdPtdP_t2r[None, i, None, None], tdPrdP_t2r) + for stage in cutlass.range_constexpr(cute.size(tSrS_t2r, mode=[1]), unroll=1): + cute.copy(thr_tmem_load, tdPtdP_t2r[None, stage, None, None], tdPrdP_t2r) cute.arch.fence_view_async_tmem_load() - own0 = tLSEsdPsum[(lane_idx, 0), i, 0, 0] - own1 = tLSEsdPsum[(lane_idx + 1, 0), i, 0, 0] - for j in cutlass.range_constexpr(0, cute.size(tdPrdP_t2r), 2, unroll=1): - psum_j = utils.shuffle_sync(own0, offset=j) - psum_j1 = utils.shuffle_sync(own1, offset=j) - tdPrdP_t2r[j, 0, 0], tdPrdP_t2r[j + 1, 0, 0] = utils.sub_packed_f32x2( - (tdPrdP_t2r[j, 0, 0], tdPrdP_t2r[j + 1, 0, 0]), (psum_j, psum_j1) + tdPrdP_cur = tdPrdP_t2r[None, 0, 0] + tSrS_cur = tSrS_t2r[None, stage, 0, 0] + tSsdPsum_cur = tSsdPsum[None, stage, 0, 0, 0] # TODO: have stages + dPsum_val = tSsdPsum_cur[lane_idx] + for v in cutlass.range_constexpr(cute.size(tdPrdP_t2r) // 2, unroll=1): + dPsum_pair = ( + utils.shuffle_sync(dPsum_val, offset=2 * v), + utils.shuffle_sync(dPsum_val, offset=2 * v + 1), ) - tSrS_t2r[j, i, 0, 0], tSrS_t2r[j + 1, i, 0, 0] = utils.mul_packed_f32x2( - (tSrS_t2r[j, i, 0, 0], tSrS_t2r[j + 1, i, 0, 0]), - (tdPrdP_t2r[j, 0, 0], tdPrdP_t2r[j + 1, 0, 0]), + tdPrdP_cur[2 * v], tdPrdP_t2r[2 * v + 1] = utils.sub_packed_f32x2( + (tdPrdP_cur[2 * v], tdPrdP_t2r[2 * v + 1]), dPsum_pair ) - tSrS_t2r_bf16[j, i, 0, 0] = tSrS_t2r[j, i, 0, 0].to(self.ds_dtype) - tSrS_t2r_bf16[j + 1, i, 0, 0] = tSrS_t2r[j + 1, i, 0, 0].to(self.ds_dtype) - - cute.autovec_copy(tSrS_t2r_bf16[None, i, 0, 0], tdKsdS[None, i, 0, 0]) + tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1] = utils.mul_packed_f32x2( + (tSrS_cur[2 * v], tSrS_cur[2 * v + 1]), + (tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1]), + ) + utils.cvt_f16(tdPrdP_cur, tSrS_t2r_bf16[None, stage, 0, 0]) + cute.autovec_copy(tSrS_t2r_bf16[None, stage, 0, 0], tdKsdS[None, stage, 0, 0]) cute.arch.sync_warp() with cute.arch.elect_one(): @@ -1718,9 +1683,9 @@ def compute_loop( pipeline_dS.producer_commit(producer_state_dS) producer_state_dS.advance() - if warp_idx == self.compute_warp_ids[0]: - with cute.arch.elect_one(): - cute.arch.mbarrier_arrive(dPsum_empty_mbar_ptr) + # Already sync_warp before this + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(dPsum_empty_mbar_ptr) if const_expr(not self.use_tma_store): consumer_state_dKV = self.epilogue_dKV( From 0f56550a69ab0f597e07ba85110a46a1e5f11ed6 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 20 Oct 2025 21:28:39 -0400 Subject: [PATCH 174/258] [Cute,Bwd,Sm100] Combine pipeline_S and pipeline_P into 1 --- flash_attn/cute/flash_bwd_sm100.py | 93 ++++++++++++------------------ 1 file changed, 38 insertions(+), 55 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index c6eea6e5260..6cb87b3970d 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -135,7 +135,6 @@ def _setup_attributes(self): self.dS_stage = 1 self.LSE_stage = 1 self.acc_stage = 1 - self.S_stage = 1 self.dP_stage = 1 self.dS_stage = 1 self.sdQaccum_stage = 2 @@ -474,9 +473,8 @@ class SharedStorage: LSE_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.LSE_stage] dPsum_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.dPsum_stage] dPsum_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.dPsum_stage] - S_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.S_stage] + S_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 1] dP_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dP_stage] - P_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.S_stage] dS_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dS_stage] dKV_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 2] dQaccum_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2] @@ -644,12 +642,14 @@ def kernel( dPsum_full_mbar_ptr = storage.dPsum_full_mbar_ptr.data_ptr() dPsum_empty_mbar_ptr = storage.dPsum_empty_mbar_ptr.data_ptr() - if warp_idx == self.load_warp_id: + if warp_idx == 1: cute.arch.mbarrier_init( tmem_dealloc_mbar_ptr, cute.arch.WARP_SIZE * len(self.compute_warp_ids) ) + if warp_idx == 2: cute.arch.mbarrier_init(LSE_full_mbar_ptr, 1) cute.arch.mbarrier_init(LSE_empty_mbar_ptr, len(self.compute_warp_ids)) + if warp_idx == 3: cute.arch.mbarrier_init(dPsum_full_mbar_ptr, 1) cute.arch.mbarrier_init(dPsum_empty_mbar_ptr, len(self.compute_warp_ids)) @@ -684,8 +684,8 @@ def kernel( pipeline_consumer_group_MMA_AsyncThread = cutlass.pipeline.CooperativeGroup( cutlass.pipeline.Agent.Thread, len(self.compute_warp_ids) ) - pipeline_S = cutlass.pipeline.PipelineUmmaAsync.create( - num_stages=self.S_stage, + pipeline_S_P = cutlass.pipeline.PipelineUmmaAsync.create( + num_stages=1, producer_group=pipeline_producer_group_MMA_AsyncThread, consumer_group=pipeline_consumer_group_MMA_AsyncThread, barrier_storage=storage.S_mbar_ptr.data_ptr(), @@ -721,13 +721,6 @@ def kernel( cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) ) # MMA - pipeline_P = cutlass.pipeline.PipelineAsyncUmma.create( - num_stages=self.S_stage, - producer_group=pipeline_PdS_producer_group, - consumer_group=pipeline_PdS_consumer_group, - barrier_storage=storage.P_mbar_ptr.data_ptr(), - ) - pipeline_dS = cutlass.pipeline.PipelineAsyncUmma.create( num_stages=self.dS_stage, producer_group=pipeline_PdS_producer_group, @@ -907,8 +900,7 @@ def kernel( tdQtdQ, pipeline_Q, pipeline_dO, - pipeline_S, - pipeline_P, + pipeline_S_P, pipeline_dS, pipeline_dKV, pipeline_dP, @@ -948,8 +940,7 @@ def kernel( LSE_empty_mbar_ptr, dPsum_full_mbar_ptr, dPsum_empty_mbar_ptr, - pipeline_S, - pipeline_P, + pipeline_S_P, pipeline_dS, pipeline_dKV, pipeline_dP, @@ -1160,8 +1151,7 @@ def mma( tdQtdQ: cute.Tensor, pipeline_Q: PipelineAsync, pipeline_dO: PipelineAsync, - pipeline_S: PipelineAsync, - pipeline_P: PipelineAsync, + pipeline_S_P: PipelineAsync, pipeline_dS: PipelineAsync, pipeline_dKV: PipelineAsync, pipeline_dP: PipelineAsync, @@ -1230,15 +1220,12 @@ def mma( consumer_state_dO = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage ) - producer_state_S = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Producer, self.S_stage + producer_state_S_P = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, 1 ) producer_state_dP = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Producer, self.dP_stage ) - consumer_state_P = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Consumer, self.S_stage - ) consumer_state_dS = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.dS_stage ) @@ -1267,11 +1254,11 @@ def mma( # 1) S = Q0 @ K.T pipeline_Q.consumer_wait(consumer_state_Q) - pipeline_S.producer_acquire(producer_state_S) + pipeline_S_P.producer_acquire(producer_state_S_P) mma_qk_fn(B_idx=consumer_state_Q.index) # Don't release Q yet - pipeline_S.producer_commit(producer_state_S) - producer_state_S.advance() + pipeline_S_P.producer_commit(producer_state_S_P) + producer_state_S_P.advance() # 2) dP = V @ dO.T pipeline_dO.consumer_wait(consumer_state_dO) @@ -1283,10 +1270,9 @@ def mma( producer_state_dP.advance() # 3) dV = P.T @ dO - pipeline_P.consumer_wait(consumer_state_P) + # wait for P to be ready, which uses the same tmem as S + pipeline_S_P.producer_acquire(producer_state_S_P) mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=True) - pipeline_P.consumer_release(consumer_state_P) - consumer_state_P.advance() pipeline_dO.consumer_release(consumer_state_dO) consumer_state_dO.advance() # ----------------------------------------------------------- @@ -1303,10 +1289,10 @@ def mma( consumer_state_Q_prev = consumer_state_Q.clone() consumer_state_Q.advance() pipeline_Q.consumer_wait(consumer_state_Q) - pipeline_S.producer_acquire(producer_state_S) + # Don't need to wait for S, as P must have been ready ealier, i.e., S is ready mma_qk_fn(B_idx=consumer_state_Q.index) - pipeline_S.producer_commit(producer_state_S) - producer_state_S.advance() + pipeline_S_P.producer_commit(producer_state_S_P) + producer_state_S_P.advance() # 2) dQ = dS @ K pipeline_dS.consumer_wait(consumer_state_dS) @@ -1330,13 +1316,15 @@ def mma( producer_state_dP.advance() # 5) dV += P @ dO - pipeline_P.consumer_wait(consumer_state_P) + # wait for P to be ready, which uses the same tmem as S + pipeline_S_P.producer_acquire(producer_state_S_P) mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=False) - pipeline_P.consumer_release(consumer_state_P) - consumer_state_P.advance() pipeline_dO.consumer_release(consumer_state_dO) consumer_state_dO.advance() + pipeline_S_P.producer_commit(producer_state_S_P) + producer_state_S_P.advance() + # signal to the epilogue that dV is ready pipeline_dKV.producer_acquire(producer_state_dKV) pipeline_dKV.producer_commit(producer_state_dKV) @@ -1366,7 +1354,8 @@ def mma( tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() - pipeline_S.producer_tail(producer_state_S) + # Currently it hangs if we have this S_P.producer_tail, will need to understand why + # pipeline_S_P.producer_tail(producer_state_S_P) pipeline_dP.producer_tail(producer_state_dP) pipeline_dKV.producer_tail(producer_state_dKV) pipeline_dQ.producer_tail(producer_state_dQ) @@ -1433,8 +1422,7 @@ def compute_loop( LSE_empty_mbar_ptr: cute.Pointer, dPsum_full_mbar_ptr: cute.Pointer, dPsum_empty_mbar_ptr: cute.Pointer, - pipeline_S: PipelineAsync, - pipeline_P: PipelineAsync, + pipeline_S_P: PipelineAsync, pipeline_dS: PipelineAsync, pipeline_dKV: PipelineAsync, pipeline_dP: PipelineAsync, @@ -1493,6 +1481,10 @@ def compute_loop( tmem_load_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32 ) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), Float32 + ) + thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS).get_slice(tidx) tStS_t2r_p = thr_tmem_load.partition_S(tStS) tStS_t2r = self.split_wg(tStS_t2r_p, wg_idx, num_wg) @@ -1505,20 +1497,14 @@ def compute_loop( tSsdPsum_p = thr_tmem_load.partition_D(thr_mma_SdP.partition_C(sdPsum_2D)) tSsdPsum = self.split_wg(tSsdPsum_p, wg_idx, num_wg) - tmem_store_atom = cute.make_copy_atom( - tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), Float32 - ) thr_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP).get_slice(tidx) tScP_r2t_p = thr_tmem_store.partition_S(tScP) tScP_r2t = self.split_wg(tScP_r2t_p, wg_idx, num_wg) tStP_r2t_p = thr_tmem_store.partition_D(tStP) tStP_r2t = self.split_wg(tStP_r2t_p, wg_idx, num_wg) - consumer_state_S = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Consumer, self.S_stage - ) - producer_state_P = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Producer, self.S_stage + consumer_state_S_P = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, 1 ) producer_state_dS = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Producer, self.dS_stage @@ -1552,7 +1538,7 @@ def compute_loop( # Mainloop for m_block in cutlass.range(m_block_min, m_block_max, unroll=1): - pipeline_S.consumer_wait(consumer_state_S) + pipeline_S_P.consumer_wait(consumer_state_S_P) #### TMEM->RMEM (Load S from TMEM) tSrS_t2r = cute.make_fragment(tScS_t2r.shape, Float32) # 64 cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r) @@ -1584,7 +1570,6 @@ def compute_loop( tSrS_t2r[None, 0, None, None].layout, ) - pipeline_P.producer_acquire(producer_state_P) for stage in cutlass.range_constexpr(cute.size(tStP_r2t, mode=[2]), unroll=1): tSrS_cur = tSrS_t2r[None, stage, 0, 0] tSsLSE_cur = tSsLSE[None, stage, 0, 0, 0] # TODO: have stages @@ -1605,13 +1590,11 @@ def compute_loop( cute.copy(thr_tmem_store, tSrP_r2t_f32, tStP_r2t[None, None, stage]) cute.arch.fence_view_async_tmem_store() - pipeline_P.producer_commit(producer_state_P) - producer_state_P.advance() cute.arch.sync_warp() with cute.arch.elect_one(): - pipeline_S.consumer_release(consumer_state_S) - consumer_state_S.advance() + pipeline_S_P.consumer_release(consumer_state_S_P) + consumer_state_S_P.advance() # Already sync_warp before this with cute.arch.elect_one(): @@ -1657,8 +1640,8 @@ def compute_loop( utils.shuffle_sync(dPsum_val, offset=2 * v), utils.shuffle_sync(dPsum_val, offset=2 * v + 1), ) - tdPrdP_cur[2 * v], tdPrdP_t2r[2 * v + 1] = utils.sub_packed_f32x2( - (tdPrdP_cur[2 * v], tdPrdP_t2r[2 * v + 1]), dPsum_pair + tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1] = utils.sub_packed_f32x2( + (tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1]), dPsum_pair ) tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1] = utils.mul_packed_f32x2( (tSrS_cur[2 * v], tSrS_cur[2 * v + 1]), From 22f7daab93d531c5945de850bb245ac668313924 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 20 Oct 2025 23:38:27 -0400 Subject: [PATCH 175/258] [Cute,Bwd,Sm100] Don't shuffle LSE & dPsum, reduce state variables --- flash_attn/cute/flash_bwd_sm100.py | 199 +++++++++++++++++------------ 1 file changed, 114 insertions(+), 85 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 6cb87b3970d..8f62dd617b4 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -87,6 +87,10 @@ def __init__( self.use_tma_store = True self.deterministic = deterministic + # Speed optimizations, does not affect correctness + self.shuffle_LSE = False + self.shuffle_dPsum = False + self.reduce_warp_ids = (0, 1, 2, 3) self.compute_warp_ids = (4, 5, 6, 7, 8, 9, 10, 11) self.mma_warp_id = 12 @@ -117,12 +121,11 @@ def __init__( self.tmem_dQ_offset = self.tmem_dP_offset # overlap with dP self.tmem_dK_offset = self.tmem_dP_offset + self.tile_m - self.num_regs_reduce = 144 + self.num_regs_reduce = 160 self.num_regs_compute = 128 - # self.num_regs_load = 96 - self.num_regs_load = 112 - self.num_regs_mma = 112 + self.num_regs_other = 80 self.num_regs_empty = 24 + assert self.num_regs_reduce + self.num_regs_compute * 2 + self.num_regs_other <= 512 self.buffer_align_bytes = 1024 @@ -135,7 +138,6 @@ def _setup_attributes(self): self.dS_stage = 1 self.LSE_stage = 1 self.acc_stage = 1 - self.dP_stage = 1 self.dS_stage = 1 self.sdQaccum_stage = 2 self.dPsum_stage = 1 @@ -474,7 +476,7 @@ class SharedStorage: dPsum_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.dPsum_stage] dPsum_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.dPsum_stage] S_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 1] - dP_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dP_stage] + dP_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 1] dS_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dS_stage] dKV_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 2] dQaccum_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2] @@ -691,7 +693,7 @@ def kernel( barrier_storage=storage.S_mbar_ptr.data_ptr(), ) pipeline_dP = cutlass.pipeline.PipelineUmmaAsync.create( - num_stages=self.dP_stage, + num_stages=1, producer_group=pipeline_producer_group_MMA_AsyncThread, consumer_group=pipeline_consumer_group_MMA_AsyncThread, barrier_storage=storage.dP_mbar_ptr.data_ptr(), @@ -838,7 +840,7 @@ def kernel( # LOAD # (13) if warp_idx == self.load_warp_id: - cute.arch.warpgroup_reg_dealloc(self.num_regs_load) + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) self.load( thr_mma_SdP, thr_mma_dV, @@ -872,7 +874,7 @@ def kernel( # MMA # (12) if warp_idx == self.mma_warp_id: - cute.arch.warpgroup_reg_dealloc(self.num_regs_mma) + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) # Alloc tmem buffer tmem_alloc_cols = Int32(self.tmem_alloc_cols) @@ -1220,25 +1222,29 @@ def mma( consumer_state_dO = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage ) - producer_state_S_P = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Producer, 1 - ) - producer_state_dP = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Producer, self.dP_stage - ) + # producer_state_S_P = cutlass.pipeline.make_pipeline_state( + # cutlass.pipeline.PipelineUserType.Producer, 1 + # ) + producer_phase_S_P = Int32(1) + # producer_state_dP = cutlass.pipeline.make_pipeline_state( + # cutlass.pipeline.PipelineUserType.Producer, 1 + # ) + producer_phase_dP = Int32(1) consumer_state_dS = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.dS_stage ) - producer_state_dKV = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Producer, 2 - ) - producer_state_dQ = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Producer, 1 - ) + # producer_state_dQ = cutlass.pipeline.make_pipeline_state( + # cutlass.pipeline.PipelineUserType.Producer, 1 + # ) + producer_phase_dQ = Int32(1) + # producer_state_dKV = cutlass.pipeline.make_pipeline_state( + # cutlass.pipeline.PipelineUserType.Producer, 2 + # ) + producer_phase_dKV = Int32(1) + cta_group = pipeline_S_P.cta_group tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() - while work_tile.is_valid_tile: n_block, head_idx, batch_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) # must be seqlen_k @@ -1254,24 +1260,32 @@ def mma( # 1) S = Q0 @ K.T pipeline_Q.consumer_wait(consumer_state_Q) - pipeline_S_P.producer_acquire(producer_state_S_P) + # pipeline_S_P.producer_acquire(producer_state_S_P) + pipeline_S_P.sync_object_empty.wait(0, producer_phase_S_P) mma_qk_fn(B_idx=consumer_state_Q.index) # Don't release Q yet - pipeline_S_P.producer_commit(producer_state_S_P) - producer_state_S_P.advance() + # pipeline_S_P.producer_commit(producer_state_S_P) + pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) + # producer_state_S_P.advance() + producer_phase_S_P ^= 1 # 2) dP = V @ dO.T pipeline_dO.consumer_wait(consumer_state_dO) - pipeline_dP.producer_acquire(producer_state_dP) - pipeline_dQ.producer_acquire(producer_state_dQ) # dQ uses the same tmem as dP + # pipeline_dP.producer_acquire(producer_state_dP) + pipeline_dP.sync_object_empty.wait(0, producer_phase_dP) + # pipeline_dQ.producer_acquire(producer_state_dQ) # dQ uses the same tmem as dP + pipeline_dQ.sync_object_empty.wait(0, producer_phase_dQ) mma_dov_fn(B_idx=consumer_state_dO.index) # Don't release dO yet - pipeline_dP.producer_commit(producer_state_dP) - producer_state_dP.advance() + # pipeline_dP.producer_commit(producer_state_dP) + pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group) + # producer_state_dP.advance() + producer_phase_dP ^= 1 # 3) dV = P.T @ dO # wait for P to be ready, which uses the same tmem as S - pipeline_S_P.producer_acquire(producer_state_S_P) + # pipeline_S_P.producer_acquire(producer_state_S_P) + pipeline_S_P.sync_object_empty.wait(0, producer_phase_S_P) mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=True) pipeline_dO.consumer_release(consumer_state_dO) consumer_state_dO.advance() @@ -1291,15 +1305,20 @@ def mma( pipeline_Q.consumer_wait(consumer_state_Q) # Don't need to wait for S, as P must have been ready ealier, i.e., S is ready mma_qk_fn(B_idx=consumer_state_Q.index) - pipeline_S_P.producer_commit(producer_state_S_P) - producer_state_S_P.advance() + # pipeline_S_P.producer_commit(producer_state_S_P) + pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) + # producer_state_S_P.advance() + producer_phase_S_P ^= 1 # 2) dQ = dS @ K pipeline_dS.consumer_wait(consumer_state_dS) - pipeline_dP.producer_acquire(producer_state_dP) # dP uses the same tmem as dQ + # pipeline_dP.producer_acquire(producer_state_dP) # dP uses the same tmem as dQ + pipeline_dP.sync_object_empty.wait(0, producer_phase_dP) mma_dsk_fn(A_idx=consumer_state_dS.index) - pipeline_dQ.producer_commit(producer_state_dQ) - producer_state_dQ.advance() + # pipeline_dQ.producer_commit(producer_state_dQ) + pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group) + # producer_state_dQ.advance() + producer_phase_dQ ^= 1 # 3) dK = dS.T @ Q mma_dsq_fn(B_idx=consumer_state_Q_prev.index, zero_init=not accumulate_dK) @@ -1310,26 +1329,35 @@ def mma( # 4) dP = V @ dO.T pipeline_dO.consumer_wait(consumer_state_dO) - pipeline_dQ.producer_acquire(producer_state_dQ) # dQ uses the same tmem as dP + # pipeline_dQ.producer_acquire(producer_state_dQ) # dQ uses the same tmem as dP + pipeline_dQ.sync_object_empty.wait(0, producer_phase_dQ) mma_dov_fn(B_idx=consumer_state_dO.index) - pipeline_dP.producer_commit(producer_state_dP) - producer_state_dP.advance() + # pipeline_dP.producer_commit(producer_state_dP) + pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group) + # producer_state_dP.advance() + producer_phase_dP ^= 1 # 5) dV += P @ dO # wait for P to be ready, which uses the same tmem as S - pipeline_S_P.producer_acquire(producer_state_S_P) + # pipeline_S_P.producer_acquire(producer_state_S_P) + pipeline_S_P.sync_object_empty.wait(0, producer_phase_S_P) mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=False) pipeline_dO.consumer_release(consumer_state_dO) consumer_state_dO.advance() - pipeline_S_P.producer_commit(producer_state_S_P) - producer_state_S_P.advance() + # pipeline_S_P.producer_commit(producer_state_S_P) + pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) + # producer_state_S_P.advance() + producer_phase_S_P ^= 1 # signal to the epilogue that dV is ready - pipeline_dKV.producer_acquire(producer_state_dKV) - pipeline_dKV.producer_commit(producer_state_dKV) - producer_state_dKV.advance() - pipeline_dKV.producer_acquire(producer_state_dKV) + # pipeline_dKV.producer_acquire(producer_state_dKV) + pipeline_dKV.sync_object_empty.wait(0, producer_phase_dKV) + # pipeline_dKV.producer_commit(producer_state_dKV) + pipeline_dKV.sync_object_full.arrive(0, pipeline_dKV.producer_mask, cta_group) + # producer_state_dKV.advance() + # pipeline_dKV.producer_acquire(producer_state_dKV) + pipeline_dKV.sync_object_empty.wait(1, producer_phase_dKV) # ----------------------------------------------------------- ###### Remaining 2 @@ -1338,13 +1366,17 @@ def mma( pipeline_dS.consumer_wait(consumer_state_dS) mma_dsq_fn(B_idx=consumer_state_Q.index, zero_init=not accumulate_dK) # signal to the epilogue that dK is ready - pipeline_dKV.producer_commit(producer_state_dKV) - producer_state_dKV.advance() + # pipeline_dKV.producer_commit(producer_state_dKV) + pipeline_dKV.sync_object_full.arrive(1, pipeline_dKV.producer_mask, cta_group) + # producer_state_dKV.advance() + producer_phase_dKV ^= 1 # 2) dQ = dS @ K mma_dsk_fn(A_idx=consumer_state_dS.index) - pipeline_dQ.producer_commit(producer_state_dQ) - producer_state_dQ.advance() + # pipeline_dQ.producer_commit(producer_state_dQ) + pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group) + # producer_state_dQ.advance() + producer_phase_dQ ^= 1 # Wait until dQ is done before releasing Q, since K and Q0 uses the same mbarrier pipeline_Q.consumer_release(consumer_state_Q) consumer_state_Q.advance() @@ -1356,9 +1388,9 @@ def mma( # Currently it hangs if we have this S_P.producer_tail, will need to understand why # pipeline_S_P.producer_tail(producer_state_S_P) - pipeline_dP.producer_tail(producer_state_dP) - pipeline_dKV.producer_tail(producer_state_dKV) - pipeline_dQ.producer_tail(producer_state_dQ) + # pipeline_dP.producer_tail(producer_state_dP) + # pipeline_dKV.producer_tail(producer_state_dKV) + # pipeline_dQ.producer_tail(producer_state_dQ) @cute.jit def split_wg( @@ -1510,7 +1542,7 @@ def compute_loop( cutlass.pipeline.PipelineUserType.Producer, self.dS_stage ) consumer_state_dP = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Consumer, self.dP_stage + cutlass.pipeline.PipelineUserType.Consumer, 1 ) consumer_state_dKV = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, 2 @@ -1544,9 +1576,6 @@ def compute_loop( cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r) cute.arch.fence_view_async_tmem_load() - cute.arch.mbarrier_wait(LSE_full_mbar_ptr, consumer_phase_LSE) - consumer_phase_LSE ^= 1 - # Without this barrier, we could have 1 warp writing to P in tmem while # another warp is still reading S from tmem. cute.arch.barrier( @@ -1554,6 +1583,9 @@ def compute_loop( number_of_threads=self.num_compute_threads, ) + cute.arch.mbarrier_wait(LSE_full_mbar_ptr, consumer_phase_LSE) + consumer_phase_LSE ^= 1 + #### APPLY MASK if const_expr(self.is_causal or self.is_local): mask_fn(tSrS_t2r, tScS_t2r, m_block=m_block) @@ -1573,12 +1605,19 @@ def compute_loop( for stage in cutlass.range_constexpr(cute.size(tStP_r2t, mode=[2]), unroll=1): tSrS_cur = tSrS_t2r[None, stage, 0, 0] tSsLSE_cur = tSsLSE[None, stage, 0, 0, 0] # TODO: have stages - lse_val = tSsLSE_cur[lane_idx] + if const_expr(not self.shuffle_LSE): + tSrLSE = cute.make_fragment_like(tSsLSE_cur, Float32) + cute.autovec_copy(tSsLSE_cur, tSrLSE) + else: + tSrLSE = tSsLSE_cur[lane_idx] for v in cutlass.range_constexpr(cute.size(tSrP_r2t) // 2, unroll_full=True): - lse_pair = ( - utils.shuffle_sync(lse_val, offset=2 * v), - utils.shuffle_sync(lse_val, offset=2 * v + 1), - ) + if const_expr(not self.shuffle_LSE): + lse_pair = (tSrLSE[2 * v], tSrLSE[2 * v + 1]) + else: + lse_pair = ( + utils.shuffle_sync(tSrLSE, offset=2 * v), + utils.shuffle_sync(tSrLSE, offset=2 * v + 1), + ) tSrS_cur[2 * v], tSrS_cur[2 * v + 1] = utils.fma_packed_f32x2( ((tSrS_cur[2 * v], tSrS_cur[2 * v + 1])), (softmax_scale_log2, softmax_scale_log2), @@ -1594,11 +1633,8 @@ def compute_loop( cute.arch.sync_warp() with cute.arch.elect_one(): pipeline_S_P.consumer_release(consumer_state_S_P) - consumer_state_S_P.advance() - - # Already sync_warp before this - with cute.arch.elect_one(): cute.arch.mbarrier_arrive(LSE_empty_mbar_ptr) + consumer_state_S_P.advance() # --------------------------------------------- # dS.T = P.T * (dP.T - D) @@ -1613,12 +1649,6 @@ def compute_loop( # ((32,1),1,1) tdPrdP_t2r = cute.make_fragment(tScS_t2r[None, 0, None, None].shape, Float32) - #### Sync for load fence and Psum - cute.arch.barrier( - barrier_id=int(NamedBarrierBwdSm100.Compute), - number_of_threads=self.num_compute_threads, - ) - ##### dS.T = P.T * (dP.T - Psum) sdSt_mn = cute.composition(sdS, cute.make_layout((self.tile_m, self.tile_n))) tdKsdS = cute.composition( @@ -1634,12 +1664,19 @@ def compute_loop( tdPrdP_cur = tdPrdP_t2r[None, 0, 0] tSrS_cur = tSrS_t2r[None, stage, 0, 0] tSsdPsum_cur = tSsdPsum[None, stage, 0, 0, 0] # TODO: have stages - dPsum_val = tSsdPsum_cur[lane_idx] + if const_expr(not self.shuffle_dPsum): + tSrdPsum = cute.make_fragment_like(tSsdPsum_cur, Float32) + cute.autovec_copy(tSsdPsum_cur, tSrdPsum) + else: + tSrdPsum = tSsdPsum_cur[lane_idx] for v in cutlass.range_constexpr(cute.size(tdPrdP_t2r) // 2, unroll=1): - dPsum_pair = ( - utils.shuffle_sync(dPsum_val, offset=2 * v), - utils.shuffle_sync(dPsum_val, offset=2 * v + 1), - ) + if const_expr(not self.shuffle_dPsum): + dPsum_pair = (tSrdPsum[2 * v], tSrdPsum[2 * v + 1]) + else: + dPsum_pair = ( + utils.shuffle_sync(tSrdPsum, offset=2 * v), + utils.shuffle_sync(tSrdPsum, offset=2 * v + 1), + ) tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1] = utils.sub_packed_f32x2( (tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1]), dPsum_pair ) @@ -1653,23 +1690,15 @@ def compute_loop( cute.arch.sync_warp() with cute.arch.elect_one(): pipeline_dP.consumer_release(consumer_state_dP) + cute.arch.mbarrier_arrive(dPsum_empty_mbar_ptr) consumer_state_dP.advance() cute.arch.fence_proxy( cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta ) - cute.arch.barrier( - barrier_id=int(NamedBarrierBwdSm100.Compute), - number_of_threads=self.num_compute_threads, - ) - pipeline_dS.producer_commit(producer_state_dS) producer_state_dS.advance() - # Already sync_warp before this - with cute.arch.elect_one(): - cute.arch.mbarrier_arrive(dPsum_empty_mbar_ptr) - if const_expr(not self.use_tma_store): consumer_state_dKV = self.epilogue_dKV( tidx, From 3cac07ac752d390196d31dff2b5ac0db1d4a22d6 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 21 Oct 2025 00:26:34 -0400 Subject: [PATCH 176/258] [Cute,Bwd,Sm100] Hardcode dS_stage = 1 --- flash_attn/cute/flash_bwd_sm100.py | 51 +++++++++++++++--------------- flash_attn/cute/pipeline.py | 15 +++++++-- 2 files changed, 37 insertions(+), 29 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 8f62dd617b4..1c8d60b46e6 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -135,13 +135,9 @@ def _setup_attributes(self): self.Q_stage = 2 self.k_stage = self.v_stage = 1 self.dO_stage = 1 - self.dS_stage = 1 self.LSE_stage = 1 - self.acc_stage = 1 - self.dS_stage = 1 self.sdQaccum_stage = 2 self.dPsum_stage = 1 - self.p_tmem_stage = 1 self.sdKVaccum_stage = 2 # number of tma reduce adds per dQacc mma self.dQaccum_reduce_stage = self.tile_hdim // 32 @@ -226,7 +222,7 @@ def _setup_smem_layout(self): self.tiled_mma_dK, self.mma_tiler_dsq, self.ds_dtype, - self.dS_stage, + 1, ) self.sQt_layout = sm100_utils_basic.make_smem_layout_b( self.tiled_mma_dK, @@ -239,7 +235,7 @@ def _setup_smem_layout(self): self.tiled_mma_dQ, self.mma_tiler_dsk, self.q_dtype, - self.dS_stage, + 1, ) self.sKt_layout = sm100_utils_basic.make_smem_layout_b( self.tiled_mma_dQ, @@ -477,7 +473,7 @@ class SharedStorage: dPsum_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.dPsum_stage] S_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 1] dP_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 1] - dS_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dS_stage] + dS_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 1] dKV_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 2] dQaccum_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2] @@ -724,7 +720,7 @@ def kernel( ) # MMA pipeline_dS = cutlass.pipeline.PipelineAsyncUmma.create( - num_stages=self.dS_stage, + num_stages=1, producer_group=pipeline_PdS_producer_group, consumer_group=pipeline_PdS_consumer_group, barrier_storage=storage.dS_mbar_ptr.data_ptr(), @@ -1185,7 +1181,7 @@ def mma( tiled_mma_dV, self.mma_tiler_pdo, self.q_dtype, - self.acc_stage, + 1, ) tP = cute.make_tensor(tStS.iterator, p_tmem_layout.outer) tdVrP = thr_mma_dV.make_fragment_A(tP)[None, None, None, 0] @@ -1206,10 +1202,10 @@ def mma( # gemm_ptx_w_idx, tiled_mma_dV, tdVtdV, tdVrP, tdVrdO, sA=None, sB=sdO, A_idx=None # ) mma_dsk_fn = partial( - gemm_w_idx, tiled_mma_dQ, tdQtdQ, tdQrdS, tdQrK, B_idx=0, zero_init=True + gemm_w_idx, tiled_mma_dQ, tdQtdQ, tdQrdS, tdQrK, A_idx=0, B_idx=0, zero_init=True ) # mma_dsk_fn = partial( - # gemm_ptx_w_idx, tiled_mma_dQ, tdQtdQ, tdQrdS, tdQrK, sA=sdS, sB=sKt, B_idx=0, zero_init=True + # gemm_ptx_w_idx, tiled_mma_dQ, tdQtdQ, tdQrdS, tdQrK, sA=sdS, sB=sKt, A_idx=0, B_idx=0, zero_init=True # ) mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, tdKtdK, tdKrdS, tdKrQ, A_idx=0) # mma_dsq_fn = partial( @@ -1231,7 +1227,7 @@ def mma( # ) producer_phase_dP = Int32(1) consumer_state_dS = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Consumer, self.dS_stage + cutlass.pipeline.PipelineUserType.Consumer, 1 ) # producer_state_dQ = cutlass.pipeline.make_pipeline_state( # cutlass.pipeline.PipelineUserType.Producer, 1 @@ -1314,7 +1310,7 @@ def mma( pipeline_dS.consumer_wait(consumer_state_dS) # pipeline_dP.producer_acquire(producer_state_dP) # dP uses the same tmem as dQ pipeline_dP.sync_object_empty.wait(0, producer_phase_dP) - mma_dsk_fn(A_idx=consumer_state_dS.index) + mma_dsk_fn() # pipeline_dQ.producer_commit(producer_state_dQ) pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group) # producer_state_dQ.advance() @@ -1372,7 +1368,7 @@ def mma( producer_phase_dKV ^= 1 # 2) dQ = dS @ K - mma_dsk_fn(A_idx=consumer_state_dS.index) + mma_dsk_fn() # pipeline_dQ.producer_commit(producer_state_dQ) pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group) # producer_state_dQ.advance() @@ -1535,14 +1531,12 @@ def compute_loop( tStP_r2t_p = thr_tmem_store.partition_D(tStP) tStP_r2t = self.split_wg(tStP_r2t_p, wg_idx, num_wg) - consumer_state_S_P = cutlass.pipeline.make_pipeline_state( + consumer_state_S_P_dP = pipeline.make_pipeline_state( # Our impl has shortcut for stage==1 cutlass.pipeline.PipelineUserType.Consumer, 1 ) - producer_state_dS = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Producer, self.dS_stage - ) - consumer_state_dP = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Consumer, 1 + # consumer_phase_S_P_dP = Int32(0) + producer_state_dS = pipeline.make_pipeline_state( # Our impl has shortcut for stage==1 + cutlass.pipeline.PipelineUserType.Producer, 1 ) consumer_state_dKV = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, 2 @@ -1570,7 +1564,8 @@ def compute_loop( # Mainloop for m_block in cutlass.range(m_block_min, m_block_max, unroll=1): - pipeline_S_P.consumer_wait(consumer_state_S_P) + pipeline_S_P.consumer_wait(consumer_state_S_P_dP) + # pipeline_S_P.sync_object_full.wait(0, consumer_phase_S_P_dP) #### TMEM->RMEM (Load S from TMEM) tSrS_t2r = cute.make_fragment(tScS_t2r.shape, Float32) # 64 cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r) @@ -1632,9 +1627,10 @@ def compute_loop( cute.arch.sync_warp() with cute.arch.elect_one(): - pipeline_S_P.consumer_release(consumer_state_S_P) + pipeline_S_P.consumer_release(consumer_state_S_P_dP) + # pipeline_S_P.sync_object_empty.arrive(0, pipeline_S_P.consumer_mask) cute.arch.mbarrier_arrive(LSE_empty_mbar_ptr) - consumer_state_S_P.advance() + # consumer_state_S_P_dP.advance() # --------------------------------------------- # dS.T = P.T * (dP.T - D) @@ -1642,7 +1638,8 @@ def compute_loop( cute.arch.mbarrier_wait(dPsum_full_mbar_ptr, consumer_phase_dPsum) consumer_phase_dPsum ^= 1 - pipeline_dP.consumer_wait(consumer_state_dP) + pipeline_dP.consumer_wait(consumer_state_S_P_dP) + # pipeline_dP.sync_object_full.wait(0, consumer_phase_S_P_dP) pipeline_dS.producer_acquire(producer_state_dS) #### TMEM->RMEM (Load dP from TMEM) @@ -1689,9 +1686,11 @@ def compute_loop( cute.arch.sync_warp() with cute.arch.elect_one(): - pipeline_dP.consumer_release(consumer_state_dP) + # pipeline_dP.consumer_release(consumer_state_dP) + pipeline_dP.sync_object_empty.arrive(0, pipeline_dP.consumer_mask) cute.arch.mbarrier_arrive(dPsum_empty_mbar_ptr) - consumer_state_dP.advance() + consumer_state_S_P_dP.advance() + # consumer_phase_S_P_dP ^= 1 cute.arch.fence_proxy( cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta diff --git a/flash_attn/cute/pipeline.py b/flash_attn/cute/pipeline.py index 541b0b5bed7..6228037d203 100644 --- a/flash_attn/cute/pipeline.py +++ b/flash_attn/cute/pipeline.py @@ -72,7 +72,10 @@ def stages(self) -> int: def index(self) -> Int32: # return self._phase_index & 0xFFFF # return self._phase_index & ((1 << self._log_stages) - 1) - return self._phase_index % self._stages + if const_expr(self._stages == 1): + return Int32(0) + else: + return self._phase_index % self._stages @property def phase(self) -> Int32: @@ -81,10 +84,16 @@ def phase(self) -> Int32: # take modulo 2. But in practice just passing the phase in without modulo works fine. # return (self._phase_index >> self._log_stages) % 2 # return self._phase_index >> self._log_stages - return self._phase_index // self._stages + if const_expr(self._stages == 1): + return self._phase_index + else: + return self._phase_index // self._stages def advance(self): - self._phase_index += 1 + if const_expr(self._stages == 1): + self._phase_index ^= 1 + else: + self._phase_index += 1 # def then_body(phase_index): # # XOR the phase bit and set the index to 0 From f29df7a1d32f466d5cae71894c83da2fbd0ea580 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 21 Oct 2025 01:23:58 -0400 Subject: [PATCH 177/258] [Cute,Bwd,Sm100] Add option for delay tma store --- flash_attn/cute/flash_bwd_sm100.py | 42 ++++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 1c8d60b46e6..f3c6c307b69 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -712,8 +712,9 @@ def kernel( ) # AsyncThread producers and UMMA consumers + # Only 1 thread per warp will signal pipeline_PdS_producer_group = cutlass.pipeline.CooperativeGroup( - cutlass.pipeline.Agent.Thread, cute.arch.WARP_SIZE * len(self.compute_warp_ids) + cutlass.pipeline.Agent.Thread, len(self.compute_warp_ids) ) # Compute pipeline_PdS_consumer_group = cutlass.pipeline.CooperativeGroup( cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) @@ -1695,7 +1696,9 @@ def compute_loop( cute.arch.fence_proxy( cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta ) - pipeline_dS.producer_commit(producer_state_dS) + cute.arch.sync_warp() + with cute.arch.elect_one(): + pipeline_dS.producer_commit(producer_state_dS) producer_state_dS.advance() if const_expr(not self.use_tma_store): @@ -1773,6 +1776,7 @@ def dQacc_reduce( num_reduce_threads = cute.arch.WARP_SIZE * len(self.reduce_warp_ids) tidx = cute.arch.thread_idx()[0] % num_reduce_threads warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx() % len(self.reduce_warp_ids)) + is_tma_warp = warp_idx == 0 # TMEM -> RMEM tmem_load_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32 @@ -1835,26 +1839,40 @@ def dQacc_reduce( barrier.wait_eq(mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, n_block) dQacc_reduce_barrier.arrive_and_wait() - for stage in cutlass.range_constexpr(cute.size(tdQrdQ_t2r, mode=[1])): # 4 - tdQsdQ_r2s = tdQsdQ[None, None, reduce_phase] - tdQrdQ_r2s = cute.make_tensor( - tdQrdQ_t2r[None, stage, None, None].iterator, tdQsdQ_r2s.shape - ) - cute.copy(thr_copy_dQaccum_r2s, tdQrdQ_r2s, tdQsdQ_r2s) + # We could delay the TMA store by 1 epi tile to better overlap the non-TMA ops + delay_tma_store = False + + def tma_store_fn(src_idx, dst_idx): + # Fence and barrier to make sure shared memory store is visible to TMA store cute.arch.fence_proxy( cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta ) dQacc_reduce_barrier.arrive_and_wait() - if warp_idx == 0: + # Copy from shared memory to global memory + if is_tma_warp: with cute.arch.elect_one(): copy_utils.cpasync_reduce_bulk_add_f32( - sdQaccum[None, reduce_phase].iterator, - gdQaccum[None, stage, m_block].iterator, + sdQaccum[None, src_idx].iterator, + gdQaccum[None, dst_idx, m_block].iterator, self.tma_copy_bytes["dQ"], ) cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(1, read=read_flag) dQacc_reduce_barrier.arrive_and_wait() + + reduce_phase_prev, stage_prev = None, -1 + for stage in cutlass.range_constexpr(cute.size(tdQrdQ_t2r, mode=[1])): # 4 + tdQsdQ_r2s = tdQsdQ[None, None, reduce_phase] + tdQrdQ_r2s = cute.make_tensor( + tdQrdQ_t2r[None, stage, None, None].iterator, tdQsdQ_r2s.shape + ) + if const_expr(delay_tma_store): + if const_expr(stage > 0): + tma_store_fn(src_idx=reduce_phase_prev, dst_idx=stage_prev) + reduce_phase_prev, stage_prev = reduce_phase, stage + cute.copy(thr_copy_dQaccum_r2s, tdQrdQ_r2s, tdQsdQ_r2s) + if const_expr(not delay_tma_store): + tma_store_fn(reduce_phase, stage) reduce_phase ^= 1 # Directly add to gmem, much slower # tdQgdQ = thr_copy_dQaccum_r2s.partition_D(gdQaccum[None, stage, m_block]) @@ -1867,6 +1885,8 @@ def dQacc_reduce( # tdQrdQ_r2s[4 * i + 3], # utils.elem_pointer(tdQgdQ, 4 * i), # ) + if const_expr(delay_tma_store): + tma_store_fn(src_idx=reduce_phase_prev, dst_idx=stage_prev) # semaphore release # NOTE: arrive_inc calls red_release which issues membar From 933b2c3ebb8a3da378f5fefb4e398c8a9970ad81 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Tue, 21 Oct 2025 14:50:53 -0400 Subject: [PATCH 178/258] Fix hopper cuda 13 build (#1949) --- hopper/setup.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/hopper/setup.py b/hopper/setup.py index 74713208aa0..519d1c04f42 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -399,11 +399,18 @@ def nvcc_threads_args(): _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) if bare_metal_version < Version("12.3"): raise RuntimeError("FlashAttention-3 is only supported on CUDA 12.3 and above") + elif bare_metal_version >= Version("13.0"): + # CUDA 13.0+ uses system nvcc and CCCL headers are in /usr/local/cuda/include/cccl/ + cccl_include = os.path.join(CUDA_HOME, "include", "cccl") + for env_var in ["CPLUS_INCLUDE_PATH", "C_INCLUDE_PATH"]: + current = os.environ.get(env_var, "") + os.environ[env_var] = cccl_include + (":" + current if current else "") # ptxas 12.8 gives the best perf currently # We want to use the nvcc front end from 12.6 however, since if we use nvcc 12.8 # Cutlass 3.8 will expect the new data types in cuda.h from CTK 12.8, which we don't have. - if bare_metal_version != Version("12.8"): + # For CUDA 13.0+, use system nvcc instead of downloading CUDA 12.x toolchain + if bare_metal_version >= Version("12.3") and bare_metal_version < Version("13.0") and bare_metal_version != Version("12.8"): download_and_copy( name="nvcc", src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin", From a098f98b40f1d7761b0da6f7e5cfa9e9dfaeeeb4 Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Tue, 21 Oct 2025 15:07:21 -0700 Subject: [PATCH 179/258] [CuteDSL] Fix hash function for cute.jit decorator (#1953) --- flash_attn/cute/utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 4db768e328c..f26f2cb8d80 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -29,6 +29,11 @@ def hash_callable(func: Callable) -> str: """Hash a callable based on the source code or bytecode and closure values.""" + if hasattr(func, "__wrapped__"): + # cute.jit returns a wrapper whose repr/closure changes per compile; hash the undecorated function. + base_func = func.__wrapped__ + func = base_func + try: data = inspect.getsource(func).encode() except (OSError, TypeError): @@ -40,7 +45,7 @@ def hash_callable(func: Callable) -> str: hasher = hashlib.sha256(data) if hasattr(func, "__closure__") and func.__closure__ is not None: - for cell in func.__closure__: + for idx, cell in enumerate(func.__closure__): cell_value = cell.cell_contents hasher.update(repr(cell_value).encode()) @@ -50,6 +55,7 @@ def hash_callable(func: Callable) -> str: def create_softcap_scoremod(softcap_val): inv_softcap = 1.0 / softcap_val + @cute.jit def scoremod_premask_fn(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, buffers): scores = acc_S_SSA * inv_softcap return scores * cute.math.tanh(scores, fastmath=True) From 143b0ba20df0aca7d968d8ef5852ed10fe09caab Mon Sep 17 00:00:00 2001 From: Reuben Stern <107093092+reubenconducts@users.noreply.github.com> Date: Tue, 21 Oct 2025 18:11:37 -0400 Subject: [PATCH 180/258] Block Sparsity and Flex Attention mask mod support (#1942) * clean up and rebase for PR * add mask mod tests * add benchmarking files * refactor for better style * remove extraneous csrc * type hint buffers * refactor: order of non/overlap and modify blocksparse producer to agree with dense * change variable name back to buffers * remove unnecessary variable in first_half_block * restore erroneous packgqa deletion * add blocksparsity and mask_mod asserts to interface.py * fix rebase issues * Restore submodule and reset pointer to upstream/main * rename cutlass.const_expr to const_expr * support fully masked m blocks (i.e. skipped tiles) * remove outdated commented code --- flash_attn/cute/benchmark_mask_mod.py | 714 ++++++++++++++++++++++++++ flash_attn/cute/block_sparsity.py | 372 ++++++++++++++ flash_attn/cute/flash_fwd.py | 655 ++++++++++++++++++----- flash_attn/cute/interface.py | 94 +++- flash_attn/cute/mask.py | 64 ++- flash_attn/cute/mask_definitions.py | 220 ++++++++ tests/cute/test_flash_attn.py | 14 +- tests/cute/test_mask_mod.py | 570 ++++++++++++++++++++ 8 files changed, 2556 insertions(+), 147 deletions(-) create mode 100644 flash_attn/cute/benchmark_mask_mod.py create mode 100644 flash_attn/cute/block_sparsity.py create mode 100644 flash_attn/cute/mask_definitions.py create mode 100644 tests/cute/test_mask_mod.py diff --git a/flash_attn/cute/benchmark_mask_mod.py b/flash_attn/cute/benchmark_mask_mod.py new file mode 100644 index 00000000000..071b4e02a58 --- /dev/null +++ b/flash_attn/cute/benchmark_mask_mod.py @@ -0,0 +1,714 @@ +""" +FlashAttention benchmarking script with Flex Attention-style +mask mod support and varlen sequences. +""" + +from dataclasses import dataclass +import math +from pickle import FALSE +from typing import Any, Dict, Optional, Tuple + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +from cutlass.cute.runtime import from_dlpack +import numpy as np +import torch + +from flash_fwd import FlashAttentionForwardSm90 +from mask_definitions import ( + MASK_FUNCTIONS, + random_doc_id_tensor, + create_cute_sliding_window_mask, + create_flex_sliding_window_mask, +) +from block_sparsity import compute_block_sparsity + + +@dataclass +class BenchmarkConfig: + """Benchmark configuration""" + + # Model parameters + headdim: int + headdim_v: int + nheads: int + nheads_kv: int + dtype: torch.dtype + + # Sequence parameters + batch_size: int = 2 + seqlen_q: int = 8192 + seqlen_k: int = 8192 + + # Varlen parameters + use_varlen: bool = False + min_seqlen_q: Optional[int] = None # If None, use seqlen_q // 2 + max_seqlen_q: Optional[int] = None # If None, use seqlen_q + min_seqlen_k: Optional[int] = None # If None, use seqlen_k // 2 + max_seqlen_k: Optional[int] = None # If None, use seqlen_k + + # Mask parameters + use_mask_mod: bool = True + mask_mod_name: str = "causal" + has_buffers: bool = mask_mod_name == "document" + + # Sliding window parameter (used when mask_mod_name == "sliding_window") + window_size: int = 128 + + # Attention parameters + causal: bool = False + is_local: bool = False + window_left: Optional[int] = 128 # For base Flash Attention local + window_right: Optional[int] = 0 # For base Flash Attention local + softcap: Optional[float] = None + use_learnable_sink: bool = False + + # Kernel configuration + tile_m: int = 128 + tile_n: int = 128 + num_stages: int = 2 + num_threads: int = 384 + intra_wg_overlap: bool = True + mma_pv_is_rs: bool = True + + # Benchmark parameters + warmup_iters: int = 5 + benchmark_iters: int = 20 + verbose: bool = False + seed: int = 42 + + +class FlashAttentionBenchmark: + def __init__(self, config: BenchmarkConfig): + self.config = config + + torch.manual_seed(config.seed) + np.random.seed(config.seed) + + # Verify SM90 compute capability + compute_capability = torch.cuda.get_device_capability() + assert compute_capability >= (9, 0), ( + f"Requires SM90+, got SM{compute_capability[0]}{compute_capability[1]}" + ) + # causal overrides use_mask_mod + if config.causal: + config.use_mask_mod = False + + if config.use_mask_mod: + if config.mask_mod_name == "sliding_window": + # Use factory function for custom window size + self.mask_mod_cute = create_cute_sliding_window_mask(config.window_size) + self.mask_mod_flex = create_flex_sliding_window_mask(config.window_size) + else: + self.mask_mod_cute, self.mask_mod_flex = MASK_FUNCTIONS[config.mask_mod_name] + else: + self.mask_mod_cute = None + self.mask_mod_flex = None + + self._validate_config() + + def _validate_config(self): + config = self.config + + assert config.headdim <= 256, "headdim must be <= 256" + assert config.headdim_v <= 256, "headdim_v must be <= 256" + assert config.nheads % config.nheads_kv == 0, "nheads must be divisible by nheads_kv" + + alignment = 16 // config.dtype.itemsize + assert config.headdim % alignment == 0, f"headdim must be divisible by {alignment}" + assert config.headdim_v % alignment == 0, f"headdim_v must be divisible by {alignment}" + + # Validate is_local configuration + if config.is_local: + assert config.window_left is not None or config.window_right is not None, ( + "When is_local=True, at least one of window_left or window_right must be set" + ) + assert not config.use_mask_mod, ( + "Cannot use both is_local and use_mask_mod simultaneously" + ) + assert not config.causal, "Cannot use both is_local and causal simultaneously" + + # Validate mask_mod configuration + if config.use_mask_mod and config.mask_mod_name == "sliding_window": + assert config.window_size > 0, ( + "window_size must be positive when using sliding_window mask" + ) + + def _generate_varlen_seqlens(self, min_len: int, max_len: int) -> Tuple[torch.Tensor, int]: + """Generate random sequence lengths and compute cumulative lengths.""" + seqlens = torch.randint( + min_len, max_len + 1, (self.config.batch_size,), dtype=torch.int32, device="cuda" + ) + cu_seqlens = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device="cuda"), + torch.cumsum(seqlens, dtype=torch.int32, dim=0), + ] + ) + + total_tokens = cu_seqlens[-1].item() + return cu_seqlens, total_tokens + + def _create_tensors(self) -> Dict[str, torch.Tensor]: + config = self.config + device = "cuda" + + if config.use_varlen: + # Set defaults for varlen range + min_q = config.min_seqlen_q if config.min_seqlen_q is not None else config.seqlen_q // 2 + max_q = config.max_seqlen_q if config.max_seqlen_q is not None else config.seqlen_q + min_k = config.min_seqlen_k if config.min_seqlen_k is not None else config.seqlen_k // 2 + max_k = config.max_seqlen_k if config.max_seqlen_k is not None else config.seqlen_k + + # Generate cu_seqlens + cu_seqlens_q, total_q = self._generate_varlen_seqlens(min_q, max_q) + cu_seqlens_k, total_k = self._generate_varlen_seqlens(min_k, max_k) + + # Varlen shape: (total_tokens, nheads, headdim) + q = torch.randn( + total_q, config.nheads, config.headdim, dtype=config.dtype, device=device + ) + k = torch.randn( + total_k, config.nheads_kv, config.headdim, dtype=config.dtype, device=device + ) + v = torch.randn( + total_k, config.nheads_kv, config.headdim_v, dtype=config.dtype, device=device + ) + out = torch.empty( + total_q, config.nheads, config.headdim_v, dtype=config.dtype, device=device + ) + lse = torch.empty(config.nheads, total_q, dtype=torch.float32, device=device) + + tensors = { + "q": q.contiguous(), + "k": k.contiguous(), + "v": v.contiguous(), + "out": out.contiguous(), + "lse": lse.contiguous(), + "cu_seqlens_q": cu_seqlens_q.contiguous(), + "cu_seqlens_k": cu_seqlens_k.contiguous(), + } + + if config.verbose: + print(f"Varlen: total_q={total_q}, total_k={total_k}") + print(f"Q seqlens: {cu_seqlens_q[1:] - cu_seqlens_q[:-1]}") + print(f"K seqlens: {cu_seqlens_k[1:] - cu_seqlens_k[:-1]}") + else: + # Standard shape: (batch, seqlen, nheads, headdim) + q = torch.randn( + config.batch_size, + config.seqlen_q, + config.nheads, + config.headdim, + dtype=config.dtype, + device=device, + ) + k = torch.randn( + config.batch_size, + config.seqlen_k, + config.nheads_kv, + config.headdim, + dtype=config.dtype, + device=device, + ) + v = torch.randn( + config.batch_size, + config.seqlen_k, + config.nheads_kv, + config.headdim_v, + dtype=config.dtype, + device=device, + ) + out = torch.empty( + config.batch_size, + config.seqlen_q, + config.nheads, + config.headdim_v, + dtype=config.dtype, + device=device, + ) + lse = torch.empty( + config.batch_size, + config.nheads, + config.seqlen_q, + dtype=torch.float32, + device=device, + ) + + + tensors = { + "q": q.contiguous(), + "k": k.contiguous(), + "v": v.contiguous(), + "out": out.contiguous(), + "lse": lse.contiguous(), + } + + if config.use_learnable_sink: + learnable_sink = torch.rand(config.nheads, dtype=torch.bfloat16, device=device) + + tensors["learnable_sink"] = learnable_sink.contiguous() + + # Compute block sparsity when using mask_mod + if config.use_mask_mod: + if config.mask_mod_name == "document": + doc_id = random_doc_id_tensor( + config.batch_size, config.nheads, config.seqlen_q, device=device + ) + tensors["buffers"] = [doc_id.contiguous()] + full_cnt, full_idx, mask_cnt, mask_idx = compute_block_sparsity( + config=self.config, + mask_mod_flex=self.mask_mod_flex, + device=device, + cu_seqlens_q=tensors.get("cu_seqlens_q"), + cu_seqlens_k=tensors.get("cu_seqlens_k"), + buffers=tensors.get("buffers"), + ) + + if all(t is not None for t in [full_cnt, full_idx, mask_cnt, mask_idx]): + tensors["full_block_cnt"] = full_cnt.contiguous() + tensors["full_block_idx"] = full_idx.contiguous() + tensors["mask_block_cnt"] = mask_cnt.contiguous() + tensors["mask_block_idx"] = mask_idx.contiguous() + + if config.verbose: + total_full = full_cnt.sum().item() + total_partial = mask_cnt.sum().item() + + if config.use_varlen: + # Compute max possible blocks across all sequences + max_blocks = 0 + for i in range(config.batch_size): + seq_len_q = ( + tensors["cu_seqlens_q"][i + 1] - tensors["cu_seqlens_q"][i] + ).item() + seq_len_k = ( + tensors["cu_seqlens_k"][i + 1] - tensors["cu_seqlens_k"][i] + ).item() + n_blocks_q = (seq_len_q + config.tile_m - 1) // config.tile_m + n_blocks_k = (seq_len_k + config.tile_n - 1) // config.tile_n + max_blocks += n_blocks_q * n_blocks_k * config.nheads + else: + n_blocks_k = (config.seqlen_k + config.tile_n - 1) // config.tile_n + n_blocks_q = (config.seqlen_q + config.tile_m - 1) // config.tile_m + max_blocks = n_blocks_k * n_blocks_q * config.nheads * config.batch_size + + skipped = max_blocks - total_full - total_partial + print( + f"Block stats: Full={total_full}, Partial={total_partial}, " + f"Skipped={skipped}/{max_blocks}" + ) + + return tensors + + def _compile_kernel(self, tensors: Dict[str, torch.Tensor]) -> Tuple[Any, tuple]: + config = self.config + + dtype_map = { + torch.float16: cutlass.Float16, + torch.bfloat16: cutlass.BFloat16, + torch.float32: cutlass.Float32, + } + cute_dtype = dtype_map[config.dtype] + + qhead_per_kvhead = config.nheads // config.nheads_kv + kernel = FlashAttentionForwardSm90( + cute_dtype, + config.headdim, + config.headdim_v, + qhead_per_kvhead, + is_causal=config.causal, + is_local=config.is_local, + pack_gqa=False, + tile_m=config.tile_m, + tile_n=config.tile_n, + num_stages=config.num_stages, + num_threads=config.num_threads, + intra_wg_overlap=config.intra_wg_overlap, + mma_pv_is_rs=config.mma_pv_is_rs, + mask_mod=self.mask_mod_cute, + Q_in_regs=False, + has_buffers=config.has_buffers, + ) + + softmax_scale = 1.0 / math.sqrt(config.headdim) + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + # Convert tensors to cute + q_cute = from_dlpack(tensors["q"].detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=tensors["q"].ndim - 1 + ) + k_cute = from_dlpack(tensors["k"].detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=tensors["k"].ndim - 1 + ) + v_cute = from_dlpack(tensors["v"].detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=tensors["v"].ndim - 1 + ) + out_cute = from_dlpack(tensors["out"].detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=tensors["out"].ndim - 1 + ) + lse_cute = from_dlpack(tensors["lse"].detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=tensors["lse"].ndim - 1 + ) + + # Varlen tensors + cu_seqlens_q_cute = ( + from_dlpack(tensors["cu_seqlens_q"].detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=0 + ) + if "cu_seqlens_q" in tensors + else None + ) + cu_seqlens_k_cute = ( + from_dlpack(tensors["cu_seqlens_k"].detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=0 + ) + if "cu_seqlens_k" in tensors + else None + ) + learnable_sink_cute = ( + from_dlpack(tensors["learnable_sink"].detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=0 + ) + if "learnable_sink" in tensors + else None + ) + + # Block sparsity tensors + full_block_cnt_cute = ( + from_dlpack(tensors["full_block_cnt"].detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=2 + ) + if "full_block_cnt" in tensors + else None + ) + full_block_idx_cute = ( + from_dlpack(tensors["full_block_idx"].detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=3 + ) + if "full_block_idx" in tensors + else None + ) + mask_block_cnt_cute = ( + from_dlpack(tensors["mask_block_cnt"].detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=2 + ) + if "mask_block_cnt" in tensors + else None + ) + mask_block_idx_cute = ( + from_dlpack(tensors["mask_block_idx"].detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=3 + ) + if "mask_block_idx" in tensors + else None + ) + + if "buffers" in tensors: + buffers_cute = [] + for i in range(len(tensors["buffers"])): + buf = from_dlpack(tensors["buffers"][i].detach(), assumed_align=4) + buffers_cute.append(buf.mark_layout_dynamic(leading_dim=2)) + + else: + buffers_cute = None + + # Window parameters for is_local + window_left_cute = ( + cutlass.Int32(config.window_left) if config.window_left is not None else None + ) + window_right_cute = ( + cutlass.Int32(config.window_right) if config.window_right is not None else None + ) + + compiled = cute.compile( + kernel, + q_cute, + k_cute, + v_cute, + out_cute, + lse_cute, + softmax_scale, + current_stream, + cu_seqlens_q_cute, + cu_seqlens_k_cute, + None, # seqused_q + None, # seqused_k + None, # page_table + window_left_cute, + window_right_cute, + learnable_sink_cute, # learnable_sink + full_block_cnt_cute, + full_block_idx_cute, + mask_block_cnt_cute, + mask_block_idx_cute, + buffers_cute, + # None, + ) + + args = ( + q_cute, + k_cute, + v_cute, + out_cute, + lse_cute, + softmax_scale, + current_stream, + cu_seqlens_q_cute, + cu_seqlens_k_cute, + None, + None, + None, + window_left_cute, + window_right_cute, + learnable_sink_cute, + full_block_cnt_cute, + full_block_idx_cute, + mask_block_cnt_cute, + mask_block_idx_cute, + buffers_cute, + # None, + ) + + return compiled, args + + def _calculate_flops(self, tensors: Dict[str, torch.Tensor]) -> float: + config = self.config + + # Estimate sparsity for known mask patterns + if config.is_local: + # Local attention with window_left and window_right + window_left = config.window_left if config.window_left is not None else 0 + window_right = config.window_right if config.window_right is not None else 0 + total_window = window_left + window_right + 1 # +1 for current position + sparsity_ratio = min(1.0, total_window / config.seqlen_k) + elif config.use_mask_mod: + if config.mask_mod_name in ["identity", "identity_partial"]: + sparsity_ratio = 1.0 + elif config.mask_mod_name in ["causal", "block_causal"]: + sparsity_ratio = 0.5 + elif config.mask_mod_name == "sliding_window": + # Use configured window size + sparsity_ratio = min(1.0, config.window_size / config.seqlen_k) + elif config.mask_mod_name == "block_diagonal": + block_size = 64 + num_blocks = (config.seqlen_k + block_size - 1) // block_size + sparsity_ratio = 1.0 / num_blocks if num_blocks > 1 else 1.0 + elif config.mask_mod_name == "document": + vals = tensors["buffers"][0] + val_mask = torch.ones_like(vals, dtype=torch.bool) + val_mask[..., 1:] = vals[..., 1:] != vals[..., :-1] + total = torch.where(val_mask, vals.square(), 0).sum() + sparsity_ratio = total / (config.seqlen_q * config.seqlen_k) + else: + sparsity_ratio = 1.0 + elif config.causal: + sparsity_ratio = 0.5 + else: + sparsity_ratio = 1.0 + + if config.use_varlen: + # Compute FLOPs per sequence and sum + total_flops = 0 + cu_q = tensors["cu_seqlens_q"] + cu_k = tensors["cu_seqlens_k"] + for i in range(config.batch_size): + seq_len_q = (cu_q[i + 1] - cu_q[i]).item() + seq_len_k = (cu_k[i + 1] - cu_k[i]).item() + + # Adjust sparsity for local attention in varlen case + if config.is_local: + window_left = config.window_left if config.window_left is not None else 0 + window_right = config.window_right if config.window_right is not None else 0 + total_window = window_left + window_right + 1 + seq_sparsity = min(1.0, total_window / seq_len_k) + elif config.use_mask_mod and config.mask_mod_name == "sliding_window": + seq_sparsity = min(1.0, config.window_size / seq_len_k) + else: + seq_sparsity = sparsity_ratio + + num_cells = int(seq_len_q * seq_len_k * seq_sparsity) + + if config.headdim == config.headdim_v: + flops_this_seq = 4 * config.nheads * num_cells * config.headdim + else: + flops_this_seq = ( + 2 * config.nheads * num_cells * config.headdim + + 2 * config.nheads * num_cells * config.headdim_v + ) + total_flops += flops_this_seq + return total_flops + else: + num_cells = int(config.seqlen_q * config.seqlen_k * sparsity_ratio) + if config.headdim == config.headdim_v: + flops_per_batch = 4 * config.nheads * num_cells * config.headdim + else: + flops_per_batch = ( + 2 * config.nheads * num_cells * config.headdim + + 2 * config.nheads * num_cells * config.headdim_v + ) + return flops_per_batch * config.batch_size + + def benchmark(self) -> Dict[str, Any]: + config = self.config + + tensors = self._create_tensors() + compiled_kernel, args = self._compile_kernel(tensors) + + # Warmup + for _ in range(config.warmup_iters): + compiled_kernel(*args) + torch.cuda.synchronize() + + # Benchmark + times = [] + for _ in range(config.benchmark_iters): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + compiled_kernel(*args) + end.record() + torch.cuda.synchronize() + + times.append(start.elapsed_time(end)) + + times_tensor = torch.tensor(times) + mean_time = times_tensor.mean().item() + std_time = times_tensor.std().item() if len(times) > 1 else 0.0 + + total_flops = self._calculate_flops(tensors) + tflops = total_flops / (mean_time * 1e-3) / 1e12 + + # Bandwidth calculation + bytes_per_element = config.dtype.itemsize + if config.use_varlen: + total_q = tensors["q"].shape[0] + total_k = tensors["k"].shape[0] + memory_accessed = ( + total_q * config.nheads * config.headdim * bytes_per_element + + total_k * config.nheads_kv * config.headdim * bytes_per_element + + total_k * config.nheads_kv * config.headdim_v * bytes_per_element + + total_q * config.nheads * config.headdim_v * bytes_per_element + ) + else: + memory_accessed = ( + config.batch_size + * config.seqlen_q + * config.nheads + * config.headdim + * bytes_per_element + + config.batch_size + * config.seqlen_k + * config.nheads_kv + * config.headdim + * bytes_per_element + + config.batch_size + * config.seqlen_k + * config.nheads_kv + * config.headdim_v + * bytes_per_element + + config.batch_size + * config.seqlen_q + * config.nheads + * config.headdim_v + * bytes_per_element + ) + bandwidth_gbps = memory_accessed / (mean_time * 1e-3) / 1e9 + + results = { + "mean_time_ms": mean_time, + "std_time_ms": std_time, + "tflops": tflops, + "bandwidth_gbps": bandwidth_gbps, + } + + if config.verbose: + self._print_results(results) + + return results + + def _print_results(self, results: Dict[str, Any]): + config = self.config + + # Basic configuration + if config.use_varlen: + print( + f"Shape: B={config.batch_size} (varlen), HD={config.headdim}, " + f"NH={config.nheads}, NKV={config.nheads_kv}" + ) + else: + print( + f"Shape: B={config.batch_size}, Q={config.seqlen_q}, K={config.seqlen_k}, " + f"HD={config.headdim}, NH={config.nheads}, NKV={config.nheads_kv}" + ) + + # Attention pattern + attn_info = [] + if config.causal: + attn_info.append("causal") + if config.is_local: + window_info = f"local(L={config.window_left},R={config.window_right})" + attn_info.append(window_info) + if config.use_mask_mod: + if config.mask_mod_name == "sliding_window": + attn_info.append(f"mask_mod={config.mask_mod_name}(w={config.window_size})") + else: + attn_info.append(f"mask_mod={config.mask_mod_name}") + if config.use_varlen: + attn_info.append("varlen") + if attn_info: + print(f"Attention: {', '.join(attn_info)}") + + # Performance metrics + print(f"Time: {results['mean_time_ms']:.3f} ± {results['std_time_ms']:.3f} ms") + print(f"Throughput: {results['tflops']:.2f} TFLOPS") + print(f"Bandwidth: {results['bandwidth_gbps']:.1f} GB/s") + + +if __name__ == "__main__": + B = 2 + config = BenchmarkConfig( + headdim=128, + headdim_v=128, + nheads=16, + nheads_kv=16, + dtype=torch.bfloat16, + batch_size=B, + # batch_size=1, + seqlen_q=16384 // B, + # seqlen_q=128, + seqlen_k=16384 // B, + # seqlen_k=192, + use_varlen=False, + use_mask_mod=True, + mask_mod_name="identity", + window_size=128, # Configurable window size for mask_mod + use_learnable_sink=False, + causal=False, + is_local=False, + verbose=True, + ) + + # Example 2: Base Flash Attention Local + # config = BenchmarkConfig( + # headdim=64, + # headdim_v=64, + # nheads=64, + # nheads_kv=8, + # dtype=torch.bfloat16, + # batch_size=2, + # seqlen_q=8192, + # seqlen_k=8192, + # use_varlen=False, + # use_mask_mod=False, + # causal=False, + # is_local=True, + # window_left=128, # Left window size for base local attention + # window_right=0, # Right window size for base local attention + # verbose=True, + # ) + + benchmark = FlashAttentionBenchmark(config) + results = benchmark.benchmark() diff --git a/flash_attn/cute/block_sparsity.py b/flash_attn/cute/block_sparsity.py new file mode 100644 index 00000000000..ce05cae1438 --- /dev/null +++ b/flash_attn/cute/block_sparsity.py @@ -0,0 +1,372 @@ +""" +Computes block-sparse attention masks for Flex Attention. + +This utility generates block sparsity patterns based on common attention masking +strategies (e.g., causal, sliding window). The resulting tensors define which +blocks are fully computed, which are partially computed (requiring a mask), and +which are skipped entirely. This is a temporary solution intended to be replaced +by a more robust preprocessing kernel in the future. +""" + +from typing import Tuple, Optional, Callable, List +import torch + +# placeholder +Config = type("Config", (), {}) + +def compute_block_sparsity( + config: Config, + mask_mod_flex: Optional[Callable], + device: str, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + buffers: Optional[List[torch.Tensor]] = None, +) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Computes block sparsity tensors from a given masking function. + + This function serves as the main entry point for generating block-sparse masks. + It dispatches to specialized handlers for variable-length and fixed-length + sequences. + + Args: + config: A configuration object containing model and tiling parameters. + mask_mod_flex: The mask function for generic flex attention patterns. + device: The device to create tensors on (e.g., 'cuda'). + cu_seqlens_q: Cumulative sequence lengths for Q (for varlen). + cu_seqlens_k: Cumulative sequence lengths for K (for varlen). + buffers: A list of auxiliary tensors, e.g., for document masking. + + Returns: + A tuple of four tensors: + - `full_block_cnt`: (batch, nheads, n_blocks_q) - Count of full n blocks per m block. + - `full_block_idx`: (batch, nheads, n_blocks_q, max_n_blocks) - Indices of full n blocks. + - `mask_block_cnt`: (batch, nheads, n_blocks_q) - Count of partial n blocks per m block. + - `mask_block_idx`: (batch, nheads, n_blocks_q, max_n_blocks) - Indices of partial n blocks. + Returns (None, None, None, None) if masking is disabled. + """ + if not config.use_mask_mod or mask_mod_flex is None: + return None, None, None, None + + if cu_seqlens_q is not None: + # Handle variable-length sequences + return _compute_varlen_sparsity(config, mask_mod_flex, device, cu_seqlens_q, cu_seqlens_k) + else: + # Handle fixed-length sequences + return _compute_sparsity(config, device, buffers) + +## --------------------------------------------------------------------------- +## Fixed-Length Sequence Kernels +## --------------------------------------------------------------------------- + +def _compute_sparsity( + config: Config, device: str, buffers: Optional[List[torch.Tensor]] +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Computes block sparsity for fixed-length sequences.""" + n_blocks_q = (config.seqlen_q + config.tile_m - 1) // config.tile_m + n_blocks_k = (config.seqlen_k + config.tile_n - 1) // config.tile_n + + # Pre-allocate output tensors + full_block_cnt = torch.zeros((config.batch_size, config.nheads, n_blocks_q), device=device, dtype=torch.int32) + mask_block_cnt = torch.zeros((config.batch_size, config.nheads, n_blocks_q), device=device, dtype=torch.int32) + full_block_idx = torch.zeros((config.batch_size, config.nheads, n_blocks_q, n_blocks_k), device=device, dtype=torch.int32) + mask_block_idx = torch.zeros((config.batch_size, config.nheads, n_blocks_q, n_blocks_k), device=device, dtype=torch.int32) + + # --- Identity Mask --- + # All blocks are fully computed. + if config.mask_mod_name == "identity": + k_blocks = torch.arange(n_blocks_k, device=device) + for q_block_idx in range(n_blocks_q): + full_block_cnt[:, :, q_block_idx] = n_blocks_k + full_block_idx[:, :, q_block_idx, :n_blocks_k] = k_blocks + + # --- Identity Partial Mask --- + # All blocks are partially computed (masked). + elif config.mask_mod_name == "identity_partial": + k_blocks = torch.arange(n_blocks_k, device=device) + for q_block_idx in range(n_blocks_q): + mask_block_cnt[:, :, q_block_idx] = n_blocks_k + mask_block_idx[:, :, q_block_idx, :n_blocks_k] = k_blocks + + # --- Block Causal Mask --- + elif config.mask_mod_name == "block_causal": + k_blocks = torch.arange(n_blocks_k, device=device) + for q_block_idx in range(n_blocks_q): + causal_indices = k_blocks[k_blocks <= q_block_idx] + num_causal_indices = len(causal_indices) + if num_causal_indices > 0: + full_block_cnt[:, :, q_block_idx] = num_causal_indices + full_block_idx[:, :, q_block_idx, :num_causal_indices] = causal_indices + + # --- Causal and Sliding Window Masks --- + elif config.mask_mod_name in ["causal", "sliding_window"]: + q_block_indices = torch.arange(n_blocks_q, device=device) + k_block_indices = torch.arange(n_blocks_k, device=device) + + q_starts = q_block_indices * config.tile_m + q_ends = torch.minimum((q_block_indices + 1) * config.tile_m, torch.tensor(config.seqlen_q, device=device)) + k_starts = k_block_indices * config.tile_n + k_ends = torch.minimum((k_block_indices + 1) * config.tile_n, torch.tensor(config.seqlen_k, device=device)) + + # Expand dims for broadcasting: (n_blocks_q, 1) and (1, n_blocks_k) + q_starts, q_ends = q_starts.unsqueeze(1), q_ends.unsqueeze(1) + k_starts, k_ends = k_starts.unsqueeze(0), k_ends.unsqueeze(0) + + offset = config.seqlen_k - config.seqlen_q + + if config.mask_mod_name == "causal": + is_full = (k_ends - 1) <= (q_starts + offset) + # min(k_pos) <= max(q_pos) AND not is_full. + is_partial = (k_starts <= (q_ends - 1 + offset)) & ~is_full + + else: # sliding_window + window_size = getattr(config, 'window_size', 1024) + is_full = (k_ends - 1 <= q_starts + offset) & (k_starts >= q_ends - 1 + offset - (window_size - 1)) + # A block is EMPTY if no (q, k) pairs satisfy the constraint. + is_empty = (k_starts > q_ends - 1 + offset) | (k_ends - 1 < q_starts + offset - (window_size - 1)) + # A block is PARTIAL if it's not empty and not full. + is_partial = ~is_empty & ~is_full + + # Populate indices based on the computed block classifications + for q_block_idx in range(n_blocks_q): + full_indices = k_block_indices[is_full[q_block_idx]] + if len(full_indices) > 0: + full_block_cnt[:, :, q_block_idx] = len(full_indices) + full_block_idx[:, :, q_block_idx, :len(full_indices)] = full_indices + + partial_indices = k_block_indices[is_partial[q_block_idx]] + if len(partial_indices) > 0: + mask_block_cnt[:, :, q_block_idx] = len(partial_indices) + mask_block_idx[:, :, q_block_idx, :len(partial_indices)] = partial_indices + + elif config.mask_mod_name == "document": + raise NotImplementedError("Block sparsity for document masking not yet implemented") + + return full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx + +## --------------------------------------------------------------------------- +## Variable-Length Sequence Kernels +## --------------------------------------------------------------------------- + +def _compute_varlen_sparsity( + config: Config, + mask_mod_flex: Callable, + device: str, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Computes block sparsity for variable-length sequences.""" + assert cu_seqlens_k is not None, "cu_seqlens_k is required for varlen attention" + assert cu_seqlens_q.shape[0] == config.batch_size + 1 + assert cu_seqlens_k.shape[0] == config.batch_size + 1 + + # In varlen, each sequence can have a different number of Q blocks. + # We pad up to the maximum number of Q blocks in the batch. + max_m_blocks = 0 + for seq_idx in range(config.batch_size): + seq_len_q = (cu_seqlens_q[seq_idx + 1] - cu_seqlens_q[seq_idx]).item() + n_blocks_q = (seq_len_q + config.tile_m - 1) // config.tile_m + max_m_blocks = max(max_m_blocks, n_blocks_q) + + # The number of K blocks is determined by the total length of all sequences. + total_k_len = cu_seqlens_k[-1].item() + max_n_blocks = (total_k_len + config.tile_n - 1) // config.tile_n + + # Pre-allocate padded output tensors + full_block_cnt = torch.zeros((config.batch_size, config.nheads, max_m_blocks), device=device, dtype=torch.int32) + mask_block_cnt = torch.zeros((config.batch_size, config.nheads, max_m_blocks), device=device, dtype=torch.int32) + full_block_idx = torch.zeros((config.batch_size, config.nheads, max_m_blocks, max_n_blocks), device=device, dtype=torch.int32) + mask_block_idx = torch.zeros((config.batch_size, config.nheads, max_m_blocks, max_n_blocks), device=device, dtype=torch.int32) + + # Process each sequence in the batch individually + for seq_idx in range(config.batch_size): + seq_start_q = cu_seqlens_q[seq_idx].item() + seq_end_q = cu_seqlens_q[seq_idx + 1].item() + seq_len_q = seq_end_q - seq_start_q + + seq_start_k = cu_seqlens_k[seq_idx].item() + seq_end_k = cu_seqlens_k[seq_idx + 1].item() + seq_len_k = seq_end_k - seq_start_k + + n_blocks_q = (seq_len_q + config.tile_m - 1) // config.tile_m + n_blocks_k = (seq_len_k + config.tile_n - 1) // config.tile_n + + # Global block indices are relative to the start of the entire batch tensor + first_m_block_global = seq_start_q // config.tile_m + first_n_block_global = seq_start_k // config.tile_n + + common_args = { + "full_block_cnt": full_block_cnt, "full_block_idx": full_block_idx, + "mask_block_cnt": mask_block_cnt, "mask_block_idx": mask_block_idx, + "seq_idx": seq_idx, "n_blocks_q": n_blocks_q, "n_blocks_k": n_blocks_k, + "seq_start_q": seq_start_q, "seq_end_q": seq_end_q, + "seq_start_k": seq_start_k, "seq_end_k": seq_end_k, + "first_n_block_global": first_n_block_global, + "tile_m": config.tile_m, "tile_n": config.tile_n, "device": device + } + + if config.mask_mod_name == "causal": + _compute_causal_varlen_blocks(**common_args) + elif config.mask_mod_name == "sliding_window": + window_size = getattr(config, 'window_size', 1024) + _compute_sliding_window_varlen_blocks(**common_args, window_size=window_size) + elif config.mask_mod_name == "identity": + _compute_identity_varlen_blocks( + full_block_cnt, full_block_idx, seq_idx, + n_blocks_q, n_blocks_k, first_n_block_global, device + ) + else: + # Generic case relies on sampling the user-provided mask function + _compute_generic_varlen_blocks( + **common_args, mask_mod_flex=mask_mod_flex, + seq_len_q=seq_len_q, seq_len_k=seq_len_k, + num_heads=config.nheads, nheads_kv=config.nheads_kv, + ) + + return full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx + +def _classify_varlen_block( + m_local: int, n_local: int, seq_start_q: int, seq_end_q: int, + seq_start_k: int, seq_end_k: int, tile_m: int, tile_n: int, + is_full_fn: Callable, is_partial_fn: Callable +) -> Tuple[bool, bool]: + """Helper to classify a varlen block as full, partial, or empty.""" + m_start_global = seq_start_q + m_local * tile_m + m_end_global = min(seq_start_q + (m_local + 1) * tile_m, seq_end_q) + n_start_global = seq_start_k + n_local * tile_n + n_end_global = min(seq_start_k + (n_local + 1) * tile_n, seq_end_k) + + # Use sequence-local coordinates for the logical check + m_start_local = m_start_global - seq_start_q + m_end_local = m_end_global - seq_start_q + n_start_local = n_start_global - seq_start_k + n_end_local = n_end_global - seq_start_k + + is_full = is_full_fn(m_start_local, m_end_local, n_start_local, n_end_local) + is_partial = is_partial_fn(m_start_local, m_end_local, n_start_local, n_end_local) and not is_full + + # Any block that touches the sequence boundary is partial because it requires masking. + at_boundary = (m_end_global > seq_end_q) or (n_end_global > seq_end_k) + + return is_full and not at_boundary, is_partial or (is_full and at_boundary) + +def _compute_causal_varlen_blocks( + full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx, + seq_idx, n_blocks_q, n_blocks_k, + seq_start_q, seq_end_q, seq_start_k, seq_end_k, + first_n_block_global, tile_m, tile_n, device, **kwargs +): + """Computes causal block sparsity for a single varlen sequence.""" + is_full_fn = lambda m_start, m_end, n_start, n_end: (m_start >= n_end - 1) + is_partial_fn = lambda m_start, m_end, n_start, n_end: (m_end - 1 >= n_start) + + for m_local in range(n_blocks_q): + full_blocks, partial_blocks = [], [] + for n_local in range(n_blocks_k): + is_full, is_partial = _classify_varlen_block( + m_local, n_local, seq_start_q, seq_end_q, seq_start_k, seq_end_k, + tile_m, tile_n, is_full_fn, is_partial_fn + ) + n_block_global = first_n_block_global + n_local + if is_full: + full_blocks.append(n_block_global) + elif is_partial: + partial_blocks.append(n_block_global) + + if full_blocks: + full_block_cnt[seq_idx, :, m_local] = len(full_blocks) + full_block_idx[seq_idx, :, m_local, :len(full_blocks)] = torch.tensor(full_blocks, device=device) + if partial_blocks: + mask_block_cnt[seq_idx, :, m_local] = len(partial_blocks) + mask_block_idx[seq_idx, :, m_local, :len(partial_blocks)] = torch.tensor(partial_blocks, device=device) + +def _compute_sliding_window_varlen_blocks( + full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx, + seq_idx, n_blocks_q, n_blocks_k, + seq_start_q, seq_end_q, seq_start_k, seq_end_k, + first_n_block_global, tile_m, tile_n, window_size, device, **kwargs +): + """Computes sliding window block sparsity for a single varlen sequence.""" + is_full_fn = lambda m_start, m_end, n_start, n_end: \ + (n_end - 1 <= m_start) and (n_start >= m_start - window_size + 1) + is_partial_fn = lambda m_start, m_end, n_start, n_end: \ + not ((n_start > m_end - 1) or (n_end - 1 < m_start - window_size + 1)) + + for m_local in range(n_blocks_q): + full_blocks, partial_blocks = [], [] + for n_local in range(n_blocks_k): + is_full, is_partial = _classify_varlen_block( + m_local, n_local, seq_start_q, seq_end_q, seq_start_k, seq_end_k, + tile_m, tile_n, is_full_fn, is_partial_fn + ) + n_block_global = first_n_block_global + n_local + if is_full: + full_blocks.append(n_block_global) + elif is_partial: + partial_blocks.append(n_block_global) + + if full_blocks: + full_block_cnt[seq_idx, :, m_local] = len(full_blocks) + full_block_idx[seq_idx, :, m_local, :len(full_blocks)] = torch.tensor(full_blocks, device=device) + if partial_blocks: + mask_block_cnt[seq_idx, :, m_local] = len(partial_blocks) + mask_block_idx[seq_idx, :, m_local, :len(partial_blocks)] = torch.tensor(partial_blocks, device=device) + +def _compute_identity_varlen_blocks( + full_block_cnt, full_block_idx, seq_idx, n_blocks_q, + n_blocks_k, first_n_block_global, device, **kwargs +): + """Computes identity (all-attend) block sparsity for a single varlen sequence.""" + n_blocks_global = torch.arange( + first_n_block_global, first_n_block_global + n_blocks_k, + device=device, dtype=torch.int32 + ) + for m_local in range(n_blocks_q): + full_block_cnt[seq_idx, :, m_local] = n_blocks_k + full_block_idx[seq_idx, :, m_local, :n_blocks_k] = n_blocks_global + +def _compute_generic_varlen_blocks( + full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx, + mask_mod_flex, seq_idx, num_heads, n_blocks_q, n_blocks_k, + seq_len_q, seq_len_k, first_n_block_global, + tile_m, tile_n, nheads_kv, device, **kwargs +): + """Generic sampling-based block classification for a varlen sequence.""" + qhead_per_kvhead = num_heads // nheads_kv + + for h_q in range(num_heads): + h_kv = h_q // qhead_per_kvhead + for m_local in range(n_blocks_q): + m_start_local = m_local * tile_m + m_end_local = min((m_local + 1) * tile_m, seq_len_q) + + full_blocks, partial_blocks = [], [] + for n_local in range(n_blocks_k): + n_start_local = n_local * tile_n + n_end_local = min((n_local + 1) * tile_n, seq_len_k) + + # Sample points within the block (corners and center) to classify it. + # Coordinates are sequence-local, as required by mask_mod_flex. + sample_positions = [ + (m_start_local, n_start_local), (m_start_local, n_end_local - 1), + (m_end_local - 1, n_start_local), (m_end_local - 1, n_end_local - 1), + ((m_start_local + m_end_local) // 2, (n_start_local + n_end_local) // 2), + ] + + unmasked_count = sum( + 1 for q_pos, k_pos in sample_positions + if mask_mod_flex(seq_idx, h_q, q_pos, k_pos, seq_len_q, seq_len_k) + ) + + n_block_global = first_n_block_global + n_local + if unmasked_count == len(sample_positions): # All samples unmasked -> full + full_blocks.append(n_block_global) + elif unmasked_count > 0: # Some unmasked -> partial + partial_blocks.append(n_block_global) + + if full_blocks: + full_block_cnt[seq_idx, h_q, m_local] = len(full_blocks) + full_block_idx[seq_idx, h_q, m_local, :len(full_blocks)] = torch.tensor(full_blocks, device=device) + if partial_blocks: + mask_block_cnt[seq_idx, h_q, m_local] = len(partial_blocks) + mask_block_idx[seq_idx, h_q, m_local, :len(partial_blocks)] = torch.tensor(partial_blocks, device=device) \ No newline at end of file diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 92382ae8b42..4922a1534c9 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -7,14 +7,14 @@ import math from types import SimpleNamespace -from typing import Type, Callable, Optional +from typing import Type, Callable, Optional, List from functools import partial import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute -from cutlass import Float32, Int32, Boolean, const_expr +from cutlass import Constexpr, Float32, Int32, const_expr, Boolean from cutlass.cute.nvgpu import cpasync, warp, warpgroup from cutlass.cute.arch import ProxyKind, SharedSpace import cutlass.utils as utils_basic @@ -54,7 +54,8 @@ def __init__( num_stages: int = 1, num_threads: int = 128, Q_in_regs: bool = False, - score_mod: cutlass.Constexpr | None = None, + score_mod: Optional[cutlass.Constexpr] = None, + mask_mod: Optional[cutlass.Constexpr] = None, has_buffers: bool = False, ): """Initializes the configuration for a flash attention kernel. @@ -73,6 +74,8 @@ def __init__( :param is_causal: is causal :param score_mod: A callable that takes the attention scores and applies a modification. Callable signature: ``score_mod(scores, batch_idx, head_idx, q_idx, kv_idx, buffers) -> Any`` + :param mask_mod: A callable that takes the attention scores and returns a boolean representing whether that score should be masked. + Callable signature: ``mask_mod(batch_idx, head_idx, q_idx, kv_idx, buffers) -> Boolean`` """ self.dtype = dtype # padding head_dim to a multiple of 16 as k_block_size @@ -94,8 +97,9 @@ def __init__( self.num_stages = num_stages self.Q_in_regs = Q_in_regs self.score_mod = score_mod + self.mask_mod = mask_mod self.qk_acc_dtype = Float32 - if cutlass.const_expr(has_buffers): + if const_expr(has_buffers): self.vec_size: cutlass.Constexpr = 1 else: self.vec_size: cutlass.Constexpr = 2 @@ -601,7 +605,7 @@ def __call__( softmax_scale = Float32(softmax_scale) fastdiv_mods = None - if cutlass.const_expr(buffers is not None): + if const_expr(buffers is not None): seqlen_q = cute.size(mQ.shape[0]) // (self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1) seqlen_k = cute.size(mK.shape[0]) seqlen_q_divmod = FastDivmod.create(seqlen_q) @@ -938,7 +942,7 @@ def load_V_next(): # hook_fn=load_V_next, A_in_regs=self.Q_in_regs, ) - if cutlass.const_expr(score_mod is not None): + if const_expr(score_mod is not None): self.apply_score_mod( mma_params.thr_mma_qk, batch_idx, @@ -984,10 +988,17 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase): arch = 90 - def __init__(self, *args, intra_wg_overlap: bool = True, mma_pv_is_rs: bool = True, **kwargs): + def __init__( + self, + *args, + intra_wg_overlap: bool = True, + mma_pv_is_rs: bool = True, + **kwargs, + ): super().__init__(*args, **kwargs) self.intra_wg_overlap = intra_wg_overlap self.mma_pv_is_rs = mma_pv_is_rs + def _get_smem_layout_atom(self): sQ_layout_atom = warpgroup.make_smem_layout_atom( @@ -1107,19 +1118,26 @@ def __call__( window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, learnable_sink: Optional[cute.Tensor] = None, - buffers=None, + full_block_cnt: Optional[cute.Tensor] = None, # (b, h, m_block) + full_block_idx: Optional[cute.Tensor] = None, # (b, h, m_block, n_block) + mask_block_cnt: Optional[cute.Tensor] = None, # (b, h, m_block) + mask_block_idx: Optional[cute.Tensor] = None, # (b, h, m_block, n_block) + buffers: Optional[list[cute.Tensor]] = None, ): """Configures and launches the flash attention kernel. mQ/mK/mV/mO has same data types(supports fp16 and bf16) and same layout: (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1) """ + self._check_type( *(t.element_type if t is not None else None for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)) ) + # Assume all strides are divisible by 128 bits except the last stride new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1]) + mQ, mK, mV, mO = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) for t in (mQ, mK, mV, mO)] QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] mQ, mO = [utils.select(t, QO_layout_transpose) for t in (mQ, mO)] @@ -1146,6 +1164,7 @@ def __call__( ) # self.num_mma_regs = 232 # self.num_producer_regs = 40 + self.use_block_sparsity = const_expr(mask_block_cnt is not None and full_block_cnt is not None) self.use_scheduler_barrier = (self.num_mma_warp_groups >= 2 and self.tile_hdim <= 128) if const_expr(self.intra_wg_overlap) else (self.num_mma_warp_groups == 2) self.use_tma_Q = self.arch >= 90 and not (self.pack_gqa and self.tile_m % self.qhead_per_kvhead != 0) self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None and not self.pack_gqa @@ -1255,7 +1274,7 @@ def __call__( window_size_right = Int32(window_size_right) fastdiv_mods = None - if cutlass.const_expr(buffers is not None): + if const_expr(buffers is not None): seqlen_q = cute.size(mQ.shape[0]) // (self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1) seqlen_k = cute.size(mK.shape[0]) seqlen_q_divmod = FastDivmod.create(seqlen_q) @@ -1281,6 +1300,10 @@ def __call__( window_size_left, window_size_right, learnable_sink, + full_block_cnt, + full_block_idx, + mask_block_cnt, + mask_block_idx, self.sQ_layout, self.sK_layout, self.sV_layout, @@ -1327,6 +1350,10 @@ def kernel( window_size_left: Optional[Int32], window_size_right: Optional[Int32], learnable_sink: Optional[cute.Tensor], + full_block_cnt: Optional[cute.Tensor], + full_block_idx: Optional[cute.Tensor], + mask_block_cnt: Optional[cute.Tensor], + mask_block_idx: Optional[cute.Tensor], sQ_layout: cute.ComposedLayout, sK_layout: cute.ComposedLayout, sV_layout: cute.ComposedLayout, @@ -1342,7 +1369,7 @@ def kernel( tile_sched_params: ParamsBase, TileScheduler: cutlass.Constexpr[Callable], SharedStorage: cutlass.Constexpr[Callable], - buffers=None, + buffers=Optional[list[cute.Tensor]], fastdiv_mods=None, ): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) @@ -1436,6 +1463,10 @@ def kernel( pipeline_k, pipeline_v, mbar_ptr_Q, + full_block_cnt, + full_block_idx, + mask_block_cnt, + mask_block_idx, block_info, SeqlenInfoCls, TileSchedulerCls, @@ -1474,6 +1505,10 @@ def kernel( SeqlenInfoCls, AttentionMaskCls, TileSchedulerCls, + full_block_cnt, + full_block_idx, + mask_block_cnt, + mask_block_idx, buffers, fastdiv_mods, ) @@ -1493,6 +1528,10 @@ def load( pipeline_k: cutlass.pipeline.PipelineAsync, pipeline_v: cutlass.pipeline.PipelineAsync, mbar_ptr_Q: cutlass.Pointer, + full_block_cnt: Optional[cute.Tensor], + full_block_idx: Optional[cute.Tensor], + mask_block_cnt: Optional[cute.Tensor], + mask_block_idx: Optional[cute.Tensor], block_info: BlockInfo, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, @@ -1527,44 +1566,175 @@ def load( load_V, _, _ = copy_utils.tma_get_copy_fn(tma_atom_V, 0, cute.make_layout(1), gV, sV) load_V = copy_utils.tma_producer_copy_fn(load_V, pipeline_v) - n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) - # if cute.arch.thread_idx()[0] == 0: - # cute.printf("m_block = %d, n_block_min: %d, n_block_max: %d", m_block, n_block_min, n_block_max) - # First iteration: load both Q & K with the same mbarrier - n_block = n_block_max - 1 - pipeline_k.producer_acquire( - kv_producer_state, - extra_tx_count=self.tma_copy_bytes["Q"] if const_expr(self.use_tma_Q) else 0 - ) - if const_expr(self.use_tma_Q): - load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) - load_K(src_idx=n_block, producer_state=kv_producer_state) - if const_expr(not self.intra_wg_overlap): - pipeline_v.producer_acquire(kv_producer_state) - load_V(src_idx=n_block, producer_state=kv_producer_state) - kv_producer_state.advance() - for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): - n_block = n_block_max - 1 - i - 1 - pipeline_k.producer_acquire(kv_producer_state) - load_K(src_idx=n_block, producer_state=kv_producer_state) + if const_expr(not self.use_block_sparsity): + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + # if cute.arch.thread_idx()[0] == 0: + # cute.printf("m_block = %d, n_block_min: %d, n_block_max: %d", m_block, n_block_min, n_block_max) + # First iteration: load both Q & K with the same mbarrier + n_block = n_block_max - 1 + pipeline_k.producer_acquire( + kv_producer_state, + extra_tx_count=self.tma_copy_bytes["Q"] if const_expr(self.use_tma_Q) else 0 + ) + if const_expr(self.use_tma_Q): + load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) + load_K(src_idx=n_block, producer_state=kv_producer_state) + + if const_expr(not self.intra_wg_overlap): pipeline_v.producer_acquire(kv_producer_state) load_V(src_idx=n_block, producer_state=kv_producer_state) kv_producer_state.advance() - else: - for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): - n_block_prev = n_block_max - i - 1 - n_block = n_block_prev - 1 - kv_producer_state_prev = kv_producer_state.clone() + for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): + n_block = n_block_max - 1 - i - 1 + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block, producer_state=kv_producer_state) + kv_producer_state.advance() + else: + for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): + n_block_prev = n_block_max - i - 1 + n_block = n_block_prev - 1 + kv_producer_state_prev = kv_producer_state.clone() + kv_producer_state.advance() + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state_prev) + load_V(src_idx=n_block_prev, producer_state=kv_producer_state_prev) + n_block = n_block_min + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block, producer_state=kv_producer_state) kv_producer_state.advance() - pipeline_k.producer_acquire(kv_producer_state) - load_K(src_idx=n_block, producer_state=kv_producer_state) - pipeline_v.producer_acquire(kv_producer_state_prev) - load_V(src_idx=n_block_prev, producer_state=kv_producer_state_prev) - n_block = n_block_min - pipeline_v.producer_acquire(kv_producer_state) - load_V(src_idx=n_block, producer_state=kv_producer_state) - kv_producer_state.advance() + else: + # ========================================== + # Flex Attention blocksparsity + # ========================================== + curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block] + curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None] + curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] + curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None] + + if const_expr(not self.intra_wg_overlap): + if curr_mask_block_cnt > 0: + # First mask block - load with Q + n_block_mask = curr_mask_block_idx[curr_mask_block_cnt - 1] + pipeline_k.producer_acquire( + kv_producer_state, + extra_tx_count=self.tma_copy_bytes["Q"] if const_expr(self.use_tma_Q) else 0 + ) + if const_expr(self.use_tma_Q): + load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) + load_K(src_idx=n_block_mask, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block_mask, producer_state=kv_producer_state) + kv_producer_state.advance() + + # Remaining mask blocks + for i in cutlass.range(1, curr_mask_block_cnt): + n_block_mask = curr_mask_block_idx[curr_mask_block_cnt - 1 - i] + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block_mask, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block_mask, producer_state=kv_producer_state) + kv_producer_state.advance() + + if curr_full_block_cnt > 0: + n_block_full = curr_full_block_idx[curr_full_block_cnt - 1] + if curr_mask_block_cnt == 0: + # must load Q if not loaded in mask loop + pipeline_k.producer_acquire( + kv_producer_state, + extra_tx_count=self.tma_copy_bytes["Q"] if const_expr(self.use_tma_Q) else 0 + ) + if const_expr(self.use_tma_Q): + load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) + load_K(src_idx=n_block_full, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block_full, producer_state=kv_producer_state) + kv_producer_state.advance() + else: + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block_full, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block_full, producer_state=kv_producer_state) + kv_producer_state.advance() + for j in cutlass.range(1, curr_full_block_cnt): + n_block_full = curr_full_block_idx[curr_full_block_cnt - 1 - j] + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block_full, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block_full, producer_state=kv_producer_state) + kv_producer_state.advance() + + else: + # ========================================== + # Overlap path + # ========================================== + + # Load Q with the first K block (whether mask or full) + n_block_first = -1 + if curr_mask_block_cnt > 0: + n_block_first = curr_mask_block_idx[curr_mask_block_cnt - 1] + elif curr_full_block_cnt > 0: + n_block_first = curr_full_block_idx[curr_full_block_cnt - 1] + + if n_block_first >= 0: + pipeline_k.producer_acquire( + kv_producer_state, + extra_tx_count=self.tma_copy_bytes["Q"] if const_expr(self.use_tma_Q) else 0 + ) + if const_expr(self.use_tma_Q): + load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) + load_K(src_idx=n_block_first, producer_state=kv_producer_state) + + if curr_mask_block_cnt > 0: + # Staggered loading for remaining mask blocks + for i in cutlass.range(1, curr_mask_block_cnt): + n_block_mask_prev = curr_mask_block_idx[curr_mask_block_cnt - i] + n_block_mask = curr_mask_block_idx[curr_mask_block_cnt - 1 - i] + kv_producer_state_prev = kv_producer_state.clone() + kv_producer_state.advance() + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block_mask, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state_prev) + load_V(src_idx=n_block_mask_prev, producer_state=kv_producer_state_prev) + + # Handle transition from mask to full blocks + if curr_full_block_cnt > 0: + # Load first full block K, last mask block V + n_block_mask_last = curr_mask_block_idx[0] + n_block_full = curr_full_block_idx[curr_full_block_cnt - 1] + kv_producer_state_prev = kv_producer_state.clone() + kv_producer_state.advance() + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block_full, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state_prev) + load_V(src_idx=n_block_mask_last, producer_state=kv_producer_state_prev) + else: + # No full blocks, just load last mask block V + n_block_mask_last = curr_mask_block_idx[0] + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block_mask_last, producer_state=kv_producer_state) + kv_producer_state.advance() + + if curr_full_block_cnt > 0: + # Staggered loading for remaining full blocks ( + for j in cutlass.range(1, curr_full_block_cnt): + n_block_full_prev = curr_full_block_idx[curr_full_block_cnt - j] + n_block_full = curr_full_block_idx[curr_full_block_cnt - 1 - j] + kv_producer_state_prev = kv_producer_state.clone() + kv_producer_state.advance() + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block_full, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state_prev) + load_V(src_idx=n_block_full_prev, producer_state=kv_producer_state_prev) + + # Load last full block V + n_block_full_last = curr_full_block_idx[0] + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block_full_last, producer_state=kv_producer_state) + kv_producer_state.advance() tile_scheduler.prefetch_next_work() tile_scheduler.advance_to_next_work() @@ -1601,7 +1771,11 @@ def mma( SeqlenInfoCls: Callable, AttentionMaskCls: Callable, TileSchedulerCls: Callable, - buffers=None, + full_block_cnt: Optional[cute.Tensor], + full_block_idx: Optional[cute.Tensor], + mask_block_cnt: Optional[cute.Tensor], + mask_block_idx: Optional[cute.Tensor], + buffers: Optional[list[cute.Tensor]], fastdiv_mods=None, ): warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) @@ -1663,6 +1837,20 @@ def mma( tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() softmax = Softmax.create(softmax_scale_log2, num_rows=acc_O.shape[0][0] * acc_O.shape[1], softmax_scale=softmax_scale) + + process_first_half_block = partial( + self.first_half_block_overlap, + mma_qk_fn=mma_qk_fn, + pipeline_k=pipeline_k, + tOrP=tOrP, + smem_copy_params=smem_copy_params, + softmax=softmax, + ) + process_last_half_block = partial( + self.last_half_block_overlap, + pipeline_v=pipeline_v, + mma_pv_fn=mma_pv_fn, + ) while work_tile.is_valid_tile: # if work_tile.is_valid_tile: @@ -1671,18 +1859,31 @@ def mma( seqlen = SeqlenInfoCls(batch_idx) mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) mask_fn = partial( - mask.apply_mask, m_block=m_block, thr_mma=thr_mma_qk, - mask_causal=self.is_causal, mask_local=self.is_local, + mask.apply_mask, + batch_idx=batch_idx, + head_idx=head_idx, + m_block=m_block, + thr_mma=thr_mma_qk, + mask_causal=self.is_causal, + mask_local=self.is_local, + buffers=buffers, ) score_mod_fn = None if const_expr(self.score_mod is not None): score_mod_fn = partial( self.apply_score_mod, - thr_mma_qk, batch_idx, head_idx, m_block, - softmax_scale=softmax_scale, buffers=buffers, fastdiv_mods=fastdiv_mods, + thr_mma_qk=thr_mma_qk, + batch_idx=batch_idx, + head_idx=head_idx, + m_block=m_block, + softmax_scale=softmax_scale, + buffers=buffers, + fastdiv_mods=fastdiv_mods, ) mma_one_n_block = partial( - mma_one_n_block_all, softmax=softmax, score_mod_fn=score_mod_fn + mma_one_n_block_all, + softmax=softmax, + score_mod_fn=score_mod_fn, ) # Load Q if not TMA_Q if const_expr(not self.use_tma_Q): @@ -1705,87 +1906,226 @@ def mma( # We also need masking on S if it's causal, for the last several blocks. # softmax.reset() # Don't need reset as we explicitly call softmax w is_first=True O_should_accumulate = False - # First iteration with seqlen masking - if const_expr(self.intra_wg_overlap): - pipeline_k.consumer_wait(kv_consumer_state, pipeline_k.consumer_try_wait(kv_consumer_state)) - acc_S = mma_qk_fn(B_idx=kv_consumer_state.index, wg_wait=0) - pipeline_k.consumer_release(kv_consumer_state) - # Use vectorized score modification - if cutlass.const_expr(score_mod_fn is not None): - score_mod_fn(acc_S, n_block=n_block_max - 1) - # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) - mask_fn(acc_S, n_block=n_block_max - 1, mask_seqlen=True) - # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) - softmax.online_softmax(acc_S, is_first=True) - tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) - tOrP_cur = tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) - tOrP_cur.store(tOrP_acc.load().to(self.dtype)) - if const_expr(not self.mma_pv_is_rs): - tPrP = smem_thr_copy_P.retile(tOrP_cur) - cute.copy(smem_thr_copy_P, tPrP, tPsP) - # Fence and barrier to make sure smem store is visible to WGMMA - cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) - cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV - # Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter - # acc_O.fill(0.0) - else: - self.warp_scheduler_barrier_sync() - kv_consumer_state = mma_one_n_block( - kv_consumer_state, - n_block=n_block_max - 1, - mma_pv_fn=partial(mma_pv_fn, zero_init=True), - is_first_n_block=True, - mask_fn=partial(mask_fn, mask_seqlen=True), - ) - O_should_accumulate = True - # if cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block_max = {}, n_block_min = {}", m_block, n_block_max, n_block_min) - n_block_max -= 1 - # Next couple of iterations with causal masking - if const_expr(self.is_causal or self.is_local): - n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( - seqlen, m_block, n_block_min - ) - # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_causal_local_mask = {}", n_block_min_causal_local_mask) - for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1): + + + # ========================================== + # MAINLOOP + # ========================================== + if const_expr(not self.use_block_sparsity): + # ========================================== + # No block-sparsity (original path) + # ========================================== + # First iteration with seqlen masking + if const_expr(self.intra_wg_overlap): + kv_consumer_state = process_first_half_block( + n_block=n_block_max - 1, + kv_consumer_state=kv_consumer_state, + mask_fn=mask_fn, + is_first_block=True, + ) + # Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter + # acc_O.fill(0.0) + else: + self.warp_scheduler_barrier_sync() kv_consumer_state = mma_one_n_block( kv_consumer_state, - n_block=n_block_max - 1 - n_tile, - mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), - mask_fn=partial(mask_fn, mask_seqlen=False), + n_block=n_block_max - 1, + mma_pv_fn=partial(mma_pv_fn, zero_init=True), + is_first_n_block=True, + mask_fn=partial(mask_fn, mask_seqlen=True), ) O_should_accumulate = True - n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) - # The remaining iterations have no masking - n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( - seqlen, m_block, n_block_min - ) - # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_before_local_mask = {}, n_block_min = {}", n_block_min_before_local_mask, n_block_min) - for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): - kv_consumer_state = mma_one_n_block( - kv_consumer_state, - n_block=n_block_max - 1 - n_tile, - mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + # if cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block_max = {}, n_block_min = {}", m_block, n_block_max, n_block_min) + n_block_max -= 1 + # Next couple of iterations with causal masking + if const_expr(self.is_causal or self.is_local): + n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( + seqlen, m_block, n_block_min + ) + # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_causal_local_mask = {}", n_block_min_causal_local_mask) + for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1): + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=n_block_max - 1 - n_tile, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_seqlen=False), + ) + O_should_accumulate = True + n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) + # The remaining iterations have no masking + n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( + seqlen, m_block, n_block_min ) - O_should_accumulate = True - # Separate iterations with local masking on the left - if const_expr(self.is_local and block_info.window_size_left is not None): - n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) - for n_tile in cutlass.range(n_block_max - n_block_min, unroll=1): + # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_before_local_mask = {}, n_block_min = {}", n_block_min_before_local_mask, n_block_min) + for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): kv_consumer_state = mma_one_n_block( kv_consumer_state, n_block=n_block_max - 1 - n_tile, mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), - mask_fn=partial(mask_fn, mask_seqlen=False), ) O_should_accumulate = True - # Last "half" iteration - if const_expr(self.intra_wg_overlap): - pipeline_v.consumer_wait(kv_consumer_state, pipeline_v.consumer_try_wait(kv_consumer_state)) - mma_pv_fn(B_idx=kv_consumer_state.index, zero_init=not O_should_accumulate, wg_wait=0) - pipeline_v.consumer_release(kv_consumer_state) - kv_consumer_state.advance() + # Separate iterations with local masking on the left + if const_expr(self.is_local and block_info.window_size_left is not None): + n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) + for n_tile in cutlass.range(n_block_max - n_block_min, unroll=1): + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=n_block_max - 1 - n_tile, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_seqlen=False), + ) + O_should_accumulate = True + # Last "half" iteration + if const_expr(self.intra_wg_overlap): + kv_consumer_state = process_last_half_block( + kv_consumer_state=kv_consumer_state, + zero_init=not O_should_accumulate, + ) + O_should_accumulate = True + else: + self.warp_scheduler_barrier_arrive() + else: - self.warp_scheduler_barrier_arrive() + # ========================================== + # Block sparsity + # ========================================== + curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block] + curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None] + curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] + curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None] + + # first masked and full blocks + mask_n_block = 0 + full_n_block = 0 + if curr_mask_block_cnt > 0: + mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1] + if curr_full_block_cnt > 0: + full_n_block = curr_full_block_idx[curr_full_block_cnt - 1] + + if const_expr(not self.intra_wg_overlap): + # ========================================== + # Non-overlap path + # ========================================== + if curr_mask_block_cnt > 0: + self.warp_scheduler_barrier_sync() + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=mask_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True), + is_first_n_block=True, + ) + O_should_accumulate = True + for i in cutlass.range(1, curr_mask_block_cnt): + mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i] + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=mask_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), + is_first_n_block=False, + ) + if curr_full_block_cnt == 0: + self.warp_scheduler_barrier_arrive() + + if curr_full_block_cnt > 0: + if curr_mask_block_cnt == 0: + self.warp_scheduler_barrier_sync() + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=full_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_seqlen=True), + is_first_n_block=True, + ) + O_should_accumulate = True + else: + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=full_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_seqlen=True), + is_first_n_block=False, + ) + O_should_accumulate = True + for i in cutlass.range(1, curr_full_block_cnt): + full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i] + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=full_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_seqlen=False), + is_first_n_block=False, + ) + self.warp_scheduler_barrier_arrive() + else: + # ========================================== + # Overlap path + # ========================================== + + # Process first block + if curr_mask_block_cnt > 0: + kv_consumer_state = process_first_half_block( + n_block=mask_n_block, + kv_consumer_state=kv_consumer_state, + mask_fn=partial(mask_fn, mask_mod=self.mask_mod), + is_first_block=True, + ) + + # Process remaining mask blocks + for i in cutlass.range(1, curr_mask_block_cnt): + mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i] + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=mask_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), + ) + O_should_accumulate = True + + # Process full blocks + if curr_full_block_cnt > 0: + # If no mask blocks, first full block is the overall first + if curr_mask_block_cnt == 0: + kv_consumer_state = process_first_half_block( + n_block=full_n_block, + kv_consumer_state=kv_consumer_state, + mask_fn=partial(mask_fn, mask_mod=None), + is_first_block=True, + ) + + else: + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=full_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True), + ) + O_should_accumulate = True + + # Process remaining full blocks + for i in cutlass.range(1, curr_full_block_cnt): + full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i] + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=full_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False), + ) + O_should_accumulate = True + + # Final PV gemm for last block + if curr_mask_block_cnt > 0 or curr_full_block_cnt > 0: + kv_consumer_state = process_last_half_block( + kv_consumer_state=kv_consumer_state, + zero_init=not O_should_accumulate, + ) + O_should_accumulate = True + + if curr_mask_block_cnt + curr_full_block_cnt == 0: + softmax.reset() + acc_O.fill(0.0) + sink_val = None if const_expr(learnable_sink is not None): @@ -1815,6 +2155,74 @@ def mma( tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() + @cute.jit + def first_half_block_overlap( + self, + n_block: Int32, + mma_qk_fn: Callable, + kv_consumer_state, + pipeline_k, + tOrP: cute.Tensor, + smem_copy_params: SimpleNamespace, + softmax: Softmax, + mask_fn: Callable = None, + score_mod_fn: Optional[Callable] = None, + is_first_block: bool = False, + ): + """Processes the first half block when using intra-warpgroup-overlap""" + + pipeline_k.consumer_wait(kv_consumer_state, pipeline_k.consumer_try_wait(kv_consumer_state)) + acc_S = mma_qk_fn(B_idx=kv_consumer_state.index, wg_wait=0) + pipeline_k.consumer_release(kv_consumer_state) + + # Apply score modification if present + if const_expr(score_mod_fn is not None): + score_mod_fn(acc_S=acc_S, n_block=n_block) + + # Apply mask; mask_seqlen always True for first block + # Caveat: if full block further right than mask block, seqlen masking is redundant; + # however, masking is being applied anyway, so essentially no perf hit + mask_fn(acc_S, n_block=n_block, mask_seqlen=True) + + softmax.online_softmax(acc_S, is_first=is_first_block) + + tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) + tOrP_cur = ( + tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) + ) + tOrP_cur.store(tOrP_acc.load().to(self.dtype)) + + # if pv gemm not rs + if const_expr(not self.mma_pv_is_rs): + tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur) + cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) + # Fence and barrier to make smem store visible to WGMMA + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) + cute.arch.sync_warp() + + return kv_consumer_state + + @cute.jit + def last_half_block_overlap( + self, + kv_consumer_state, + pipeline_v, + mma_pv_fn: Callable, + zero_init: bool, + ): + """Processes the final PV GEMM when using intra-warpgroup-overlap""" + + pipeline_v.consumer_wait(kv_consumer_state, pipeline_v.consumer_try_wait(kv_consumer_state)) + mma_pv_fn(B_idx=kv_consumer_state.index, zero_init=zero_init, wg_wait=0) + pipeline_v.consumer_release(kv_consumer_state) + + # Advance state for next iteration + kv_consumer_state.advance() + + return kv_consumer_state + @cute.jit def mma_one_n_block( self, @@ -1840,10 +2248,13 @@ def mma_one_n_block( self.warp_scheduler_barrier_arrive() warpgroup.wait_group(0) pipeline_k.consumer_release(smem_pipe_read) + + # handle score mods and masking if const_expr(score_mod_fn is not None): score_mod_fn(acc_S, n_block=n_block) if const_expr(mask_fn is not None): mask_fn(acc_S, n_block=n_block) + row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf) # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) @@ -1899,12 +2310,14 @@ def mma_one_n_block_intrawg_overlap( self.warp_scheduler_barrier_arrive() warpgroup.wait_group(1) pipeline_k.consumer_release(smem_pipe_read) + + # handle score mods and masking if const_expr(score_mod_fn is not None): score_mod_fn(acc_S, n_block=n_block) - # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) - if const_expr(mask_fn is not None): + if const_expr(mask_fn is not None): mask_fn(acc_S, n_block=n_block) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) + row_scale = softmax.online_softmax(acc_S, check_inf=check_inf) warpgroup.wait_group(0) pipeline_v.consumer_release(smem_pipe_read_v) @@ -1945,7 +2358,7 @@ def apply_score_mod( acc_S, n_block, softmax_scale, - buffers=None, + buffers=Optional[list[cute.Tensor]], fastdiv_mods=None, ): # Prepare index tensor diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 07a6c48bfbf..0615061a541 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -1,5 +1,6 @@ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. # [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0. +# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0. # Supported features: # - BF16 & FP16 dtype @@ -73,7 +74,12 @@ def _flash_attn_fwd( num_threads: int = 384, pack_gqa: Optional[bool] = None, _compute_capability: Optional[int] = None, - score_mod: Callable | None = None, + score_mod: Optional[Callable] = None, + mask_mod: Optional[Callable] = None, + full_block_cnt: Optional[torch.Tensor] = None, + full_block_idx: Optional[torch.Tensor] = None, + mask_block_cnt: Optional[torch.Tensor] = None, + mask_block_idx: Optional[torch.Tensor] = None, return_lse: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, @@ -135,7 +141,22 @@ def _flash_attn_fwd( if learnable_sink is not None: assert learnable_sink.shape == (num_head,) assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16" - assert all(t is None or t.is_cuda for t in (q, k, v, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, page_table, learnable_sink)), "inputs must be on CUDA device" + for t in [full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx]: + if t is not None: + assert t.dtype == torch.int32, "blocksparse mask tensors must be int32" + assert t.stride(0) == 1, "blocksparse mask tensors must be contiguous" + assert all( + t is None or t.is_cuda + for t in ( + q, k, v, + cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k, + page_table, + learnable_sink, + full_block_cnt, full_block_idx, + mask_block_cnt, mask_block_idx, + ) + ), "inputs must be on CUDA device" assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" assert head_dim <= 256, "head_dim must be less than or equal to 256" alignment = 16 // q.element_size() @@ -183,6 +204,13 @@ def _flash_attn_fwd( for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink) ] page_table_tensor = from_dlpack(page_table.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=1) if page_table is not None else None + + full_block_cnt_tensor = from_dlpack(full_block_cnt.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=2) if full_block_cnt is not None else None + full_block_idx_tensor = from_dlpack(full_block_idx.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=3) if full_block_idx is not None else None + mask_block_cnt_tensor = from_dlpack(mask_block_cnt.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=2) if mask_block_cnt is not None else None + mask_block_idx_tensor = from_dlpack(mask_block_idx.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=3) if mask_block_idx is not None else None + + if causal: window_size_right = 0 local = window_size_left is not None or window_size_right is not None @@ -202,22 +230,44 @@ def _flash_attn_fwd( # TODO: fix the varlen case if pack_gqa and (128 % qhead_per_kvhead != 0) or (cu_seqlens_q is not None or seqused_q is not None): pack_gqa = False - + + # hash score and mask mods for compile cache + score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else None + mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else None + if softcap is not None: assert score_mod is None, "softcap and score_mod cannot be used together" score_mod = utils.create_softcap_scoremod(softcap) + is_varlen = cu_seqlens_q is not None or cu_seqlens_k is not None or seqused_q is not None or seqused_k is not None + use_block_sparsity = full_block_cnt is not None or mask_block_cnt is not None if score_mod is not None: - is_varlen = cu_seqlens_q is not None or cu_seqlens_k is not None or seqused_q is not None or seqused_k is not None if is_varlen: raise NotImplementedError("score_mod with buffers is not yet supported for varlen sequences. This will be fixed in a future PR.") + if pack_gqa: + raise NotImplementedError("score_mod with buffers is not yet supported with pack_gqa=True. This will be fixed in a future PR.") + if mask_mod is not None: + if not use_block_sparsity: + raise NotImplementedError("mask_mod requires the use of block sparsity. This will be fixed in a future PR.") + if is_varlen: + raise NotImplementedError("mask_mod with buffers is not yet supported for varlen sequences. This will be fixed in a future PR.") + if pack_gqa: + raise NotImplementedError("mask_mod with buffers is not yet supported with pack_gqa=True. This will be fixed in a future PR.") + + if use_block_sparsity: + if is_varlen: + raise NotImplementedError("Block sparsity is not yet supported for varlen sequences. This will be fixed in a future PR.") + if pack_gqa: + raise NotImplementedError("Block sparsity is not yet supported with pack_gqa=True. This will be fixed in a future PR.") + cute_buffers = None if buffers is not None: cute_buffers = [from_dlpack(buf) for buf in buffers] compile_key = ( - dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, utils.hash_callable(score_mod) if score_mod is not None else None, + dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, + score_mod_hash, mask_mod_hash, buffers is not None, lse is None, cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, seqused_k is None, page_table is not None, @@ -245,6 +295,9 @@ def _flash_attn_fwd( num_stages=2, num_threads=num_threads, Q_in_regs=False, + intra_wg_overlap=True, + mma_pv_is_rs=True, + mask_mod=mask_mod, score_mod=score_mod, has_buffers=buffers is not None, ) @@ -264,18 +317,21 @@ def _flash_attn_fwd( else: raise ValueError(f"Unsupported compute capability: {compute_capability}. Supported: 9.x, 10.x") # TODO: check @can_implement - # TODO caching for buffers; cute_buffers _flash_attn_fwd.compile_cache[compile_key] = cute.compile( fa_fwd, q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream, cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, page_table_tensor, - window_size_left, window_size_right, learnable_sink_tensor, cute_buffers, + window_size_left, window_size_right, learnable_sink_tensor, + full_block_cnt_tensor, full_block_idx_tensor, mask_block_cnt_tensor, mask_block_idx_tensor, + cute_buffers, ) _flash_attn_fwd.compile_cache[compile_key]( q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream, cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, page_table_tensor, - window_size_left, window_size_right, learnable_sink_tensor, cute_buffers + window_size_left, window_size_right, learnable_sink_tensor, + full_block_cnt_tensor, full_block_idx_tensor, mask_block_cnt_tensor, mask_block_idx_tensor, + cute_buffers, ) return out, lse @@ -591,6 +647,11 @@ def forward( learnable_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, pack_gqa: Optional[bool] = None, + mask_mod: Optional[Callable] = None, + full_block_cnt: Optional[torch.Tensor] = None, + full_block_idx: Optional[torch.Tensor] = None, + mask_block_cnt: Optional[torch.Tensor] = None, + mask_block_idx: Optional[torch.Tensor] = None, ): out, lse = _flash_attn_fwd( q, @@ -603,6 +664,11 @@ def forward( learnable_sink=learnable_sink, softcap=softcap, pack_gqa=pack_gqa, + mask_mod=mask_mod, + full_block_cnt=full_block_cnt, + full_block_idx=full_block_idx, + mask_block_cnt=mask_block_cnt, + mask_block_idx=mask_block_idx, ) ctx.save_for_backward(q, k, v, out, lse) ctx.softmax_scale = softmax_scale @@ -706,6 +772,11 @@ def flash_attn_func( learnable_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, pack_gqa: Optional[bool] = None, + mask_mod: Optional[Callable] = None, + full_block_cnt: Optional[torch.Tensor] = None, + full_block_idx: Optional[torch.Tensor] = None, + mask_block_cnt: Optional[torch.Tensor] = None, + mask_block_idx: Optional[torch.Tensor] = None, ): return FlashAttnFunc.apply( q, @@ -717,6 +788,11 @@ def flash_attn_func( learnable_sink, softcap, pack_gqa, + mask_mod, + full_block_cnt, + full_block_idx, + mask_block_cnt, + mask_block_idx, ) @@ -973,4 +1049,4 @@ def flash_attn_combine( lse = None _flash_attn_fwd_combine(out_partial, lse_partial, out, lse) - return out, lse + return out, lse \ No newline at end of file diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index 25c69a69bc0..0d78eb9e948 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -1,6 +1,6 @@ # Copyright (c) 2025, Tri Dao. -from typing import Optional +from typing import Optional, Callable from dataclasses import dataclass import cutlass @@ -9,7 +9,6 @@ import flash_attn.cute.utils as utils - @cute.jit def mask_r2p(X: cute.Tensor, col_limit: Int32, arch: int = 90, rank1: bool = False) -> None: # Bit manipulation, compiles down to the R2P instruction @@ -39,7 +38,6 @@ def mask_r2p(X: cute.Tensor, col_limit: Int32, arch: int = 90, rank1: bool = Fal for r in cutlass.range_constexpr(cute.size(X.shape[0])): X[r, c] = X[r, c] if in_bound else -Float32.inf - @dataclass(frozen=True) class AttentionMask: tile_m: cutlass.Constexpr[int] @@ -55,12 +53,16 @@ class AttentionMask: def apply_mask( self, acc_S: cute.Tensor, - m_block: Int32, - n_block: Int32, + batch_idx: cutlass.Int32, + head_idx: cutlass.Int32, + m_block: cutlass.Int32, + n_block: cutlass.Int32, thr_mma: cute.TiledMma, mask_seqlen: cutlass.Constexpr[bool], mask_causal: cutlass.Constexpr[bool], mask_local: cutlass.Constexpr[bool] = False, + mask_mod: cutlass.Constexpr[Optional[Callable]] = None, + buffers: Optional[list[cute.Tensor]] = None, ) -> None: assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" acc_S_mn = utils.make_acc_tensor_mn_view(acc_S, transpose=self.swap_AB) @@ -76,17 +78,55 @@ def apply_mask( COL = 1 if const_expr(not self.swap_AB) else 0 thr_col_offset = tScS_mn[0][COL] seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset - if const_expr(not mask_causal and not mask_local): + if const_expr(not mask_causal and not mask_local and mask_mod is None): if const_expr(mask_seqlen): # The compiler now choses not to use R2P r2p = const_expr(False and not self.swap_AB) if const_expr(not r2p): + # traverse column index. for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): oob = t0ScS_mn[0, c][COL] >= seqlenk_col_limit for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): acc_S_mn[r, c] = -Float32.inf if oob else acc_S_mn[r, c] else: mask_r2p(acc_S_mn, seqlenk_col_limit, arch=90) + + elif const_expr(not mask_causal and not mask_local and mask_mod is not None): # FlexAttention mask mod + nrow = const_expr(cute.size(tScS_mn.shape[0])) + ncol = const_expr(cute.size(tScS_mn.shape[1])) + thr_col_offset = tScS_mn[0, 0][1] + + for r in cutlass.range_constexpr(nrow): + global_row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m + + for col in cutlass.range_constexpr(ncol): + col_idx_local = t0ScS_mn[0, col][1] + # Convert to absolute column index + global_col_idx = thr_col_offset + col_idx_local + n_block * self.tile_n + + cond = cutlass.Boolean( + mask_mod( + batch_idx, + head_idx, + tScS_mn[r, 0][0] + m_block * self.tile_m, + thr_col_offset + t0ScS_mn[0, col][1] + n_block * self.tile_n, + self.seqlen_q, + self.seqlen_k, + buffers, + ) + ) + if const_expr(mask_seqlen): + out_of_bounds = (global_row_idx >= self.seqlen_q) or ( + global_col_idx >= self.seqlen_k + ) + if out_of_bounds: + acc_S_mn[r, col] = -cutlass.Float32.inf + else: + acc_S_mn[r, col] = acc_S_mn[r, col] if cond else -cutlass.Float32.inf + else: + acc_S_mn[r, col] = acc_S_mn[r, col] if cond else -cutlass.Float32.inf + + else: # Causal or local if const_expr(not self.swap_AB): # If PackGQA, we split the work of compute divmod among threads in the same row @@ -303,9 +343,9 @@ def apply_mask_sm100_transposed( tidx = cute.arch.thread_idx()[0] % 128 seqlenk_row_limit = self.seqlen_k - n_block * self.tile_n - if cutlass.const_expr(not mask_causal and not mask_local): - if cutlass.const_expr(mask_seqlen): - ncol = cutlass.const_expr(cute.size(tScS_t2r.shape)) + if const_expr(not mask_causal and not mask_local): + if const_expr(mask_seqlen): + ncol = const_expr(cute.size(tScS_t2r.shape)) if tScS_t2r[0][0] >= seqlenk_row_limit: for i in cutlass.range(ncol, unroll_full=True): acc_S[i] = -cutlass.Float32.inf @@ -313,12 +353,12 @@ def apply_mask_sm100_transposed( causal_row_offset = (self.seqlen_q - self.seqlen_k - 1) - m_block * self.tile_m row_idx = tScS_t2r[0][0] + n_block * self.tile_n - if cutlass.const_expr(mask_causal): + if const_expr(mask_causal): col_limit_left = row_idx + causal_row_offset - ncol = cutlass.const_expr(cute.size(tScS_t2r.shape)) + ncol = const_expr(cute.size(tScS_t2r.shape)) # if tidx == 32 and wg_idx == 1: # cute.printf("row idx = {}, causal_row_offset = {}, col_limit_left = {}, first column = {}, last column = {} ", row_idx, causal_row_offset, col_limit_left, tScS_t2r[0][1], tScS_t2r[ncol - 1][1]) - if cutlass.const_expr(mask_seqlen): + if const_expr(mask_seqlen): if tScS_t2r[0][0] >= seqlenk_row_limit: col_limit_left = self.tile_m for i in cutlass.range(ncol, unroll_full=True): diff --git a/flash_attn/cute/mask_definitions.py b/flash_attn/cute/mask_definitions.py new file mode 100644 index 00000000000..6b206fd6026 --- /dev/null +++ b/flash_attn/cute/mask_definitions.py @@ -0,0 +1,220 @@ +from typing import Callable, Optional + +import random +import math + +import cutlass +import cutlass.cute as cute +import torch + + +MaskModCallable = Optional[ + Callable[ + ["cutlass.Int32", "cutlass.Int32", "cutlass.Int32", "cutlass.Int32", "cutlass.Int32", "cutlass.Int32"], + "cutlass.Boolean", + ] +] + + +# Flex Attention mask functions (PyTorch signatures for reference implementation) + + +def flex_identity_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None): + if torch.is_tensor(q_idx): + return torch.ones_like(q_idx, dtype=torch.bool) + return True + + +def flex_identity_partial_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None): + if torch.is_tensor(q_idx): + return torch.ones_like(q_idx, dtype=torch.bool) + return True + + +def flex_causal_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None): + # Right-aligned causal masking + if seqlen_q is not None and seqlen_k is not None: + offset = seqlen_k - seqlen_q + return kv_idx <= q_idx + offset + return kv_idx <= q_idx + + +def flex_block_causal_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None): + # Right-aligned causal masking + if seqlen_q is not None and seqlen_k is not None: + offset = seqlen_k - seqlen_q + return kv_idx <= q_idx + offset + return kv_idx <= q_idx + + +def create_flex_sliding_window_mask(window_size=1024): + """Factory function to create a sliding window mask with configurable window size""" + def flex_sliding_window_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None): + # Sliding window: q_idx - window_size <= kv_idx <= q_idx + if seqlen_q is not None and seqlen_k is not None: + offset = seqlen_k - seqlen_q + return (kv_idx <= q_idx + offset) & (kv_idx >= q_idx + offset - window_size) + return (kv_idx <= q_idx) & (kv_idx >= q_idx - window_size) + return flex_sliding_window_mask + + +# Default sliding window mask with window_size=1024 for backward compatibility +def flex_sliding_window_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None): + window_size = 1024 + if seqlen_q is not None and seqlen_k is not None: + offset = seqlen_k - seqlen_q + # Sliding window: q_pos - window_size < kv_pos <= q_pos + # Note: using strict inequality on the left to match typical sliding window behavior + return (kv_idx <= q_idx + offset) & (kv_idx > q_idx + offset - window_size) + return (kv_idx <= q_idx) & (kv_idx > q_idx - window_size) + + +def flex_block_diagonal_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None, block_size=64): + return (q_idx // block_size) == (kv_idx // block_size) + + +def flex_mini_causal_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None): + return (q_idx % 128) >= (kv_idx % 128) + + +def flex_half_identity_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None): + """Even k-blocks are full blocks, odd k-blocks are masked blocks (both return True)""" + if torch.is_tensor(kv_idx): + return torch.ones_like(kv_idx, dtype=torch.bool) + return True + +def flex_document_mask(b, h, q_idx, kv_idx, doc_id: torch.Tensor): + return doc_id[b, h, q_idx] == doc_id[b, h, kv_idx] + +# CuTe versions for kernel compilation + + +@cute.jit +def cute_identity_mask( + batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers: None, +) -> cutlass.Boolean: + return cutlass.Boolean(True) + + +@cute.jit +def cute_identity_partial_mask( + batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers: None, +) -> cutlass.Boolean: + return cutlass.Boolean(True) + + +@cute.jit +def cute_causal_mask( + batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers: None, +) -> cutlass.Boolean: + # Right-aligned causal masking + offset = seqlen_k - seqlen_q + return cutlass.Boolean(n_idx <= m_idx + offset) + + +@cute.jit +def cute_block_causal_mask( + batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers: None, +) -> cutlass.Boolean: + # Right-aligned causal masking + offset = seqlen_k - seqlen_q + return cutlass.Boolean(n_idx <= m_idx + offset) + + +def create_cute_sliding_window_mask(window_size=1024): + """Factory function to create a CuTe sliding window mask with configurable window size""" + @cute.jit + def cute_sliding_window_mask( + batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers + ) -> cutlass.Boolean: + offset = seqlen_k - seqlen_q + + return cutlass.Boolean((n_idx <= m_idx + offset) and (n_idx >= m_idx + offset - window_size)) + return cute_sliding_window_mask + + +# Default sliding window mask with window_size=1024 for backward compatibility +@cute.jit +def cute_sliding_window_mask( + batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers +) -> cutlass.Boolean: + window_size = 1024 + # offset = seqlen_k - seqlen_q + offset = 0 + return cutlass.Boolean((n_idx <= m_idx + offset) and (n_idx >= m_idx + offset - window_size)) + + +@cute.jit +def cute_document_mask( + batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers: list, +): + doc_id = buffers[0] + return cutlass.Boolean(doc_id[batch, head, m_idx] == doc_id[batch, head, n_idx]) + + +@cute.jit +def cute_block_diagonal_mask( + batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers +) -> cutlass.Boolean: + return cutlass.Boolean((m_idx // 64) == (n_idx // 64)) + + +@cute.jit +def cute_mini_causal_mask( + batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers +) -> cutlass.Boolean: + """Each tile is locally causal-masked""" + m_mod = m_idx % 128 + n_mod = n_idx % 128 + return cutlass.Boolean(m_mod >= n_mod) + + +@cute.jit +def cute_half_identity_mask( + batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32 +) -> cutlass.Boolean: + return cutlass.Boolean(True) + + +def random_doc_id_tensor(nheads, batch, seqlen_q, device="cpu"): + doc_ids_tensor = torch.zeros(batch, nheads, seqlen_q, dtype=torch.int32, device=device) + for b in range(batch): + for h in range(nheads): + N = seqlen_q + n = random.randint(1, math.ceil(math.sqrt(N // 4))) + cuts = sorted(random.sample(range(1, N), n-1)) + lengths = [b - a for a, b in zip((0, *cuts), (*cuts, N))] + + doc_ids = [] + for i, length in enumerate(lengths): + doc_ids += [i for _ in range(length)] + + doc_ids_tensor[b, h, :] = torch.tensor(doc_ids, dtype=torch.int32, device=device) + print(f"{doc_ids_tensor.shape = }") + return doc_ids_tensor + + +MASK_FUNCTIONS = { + "identity": (cute_identity_mask, flex_identity_mask), + "identity_partial": (cute_identity_partial_mask, flex_identity_partial_mask), + "causal": (cute_causal_mask, flex_causal_mask), + "block_causal": (cute_block_causal_mask, flex_block_causal_mask), + "sliding_window": (cute_sliding_window_mask, flex_sliding_window_mask), + "block_diagonal": (cute_block_diagonal_mask, flex_block_diagonal_mask), + "mini_causal": (cute_mini_causal_mask, flex_mini_causal_mask), + "half_identity": (cute_half_identity_mask, flex_half_identity_mask), + "document": (cute_document_mask, flex_document_mask), +} + +if __name__ == "__main__": + doc_ids = random_doc_id_tensor(1, 2, 128) + print(f"{doc_ids = }") \ No newline at end of file diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index a654e90d23e..644936d8d2d 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -52,6 +52,8 @@ "seqlen_q,seqlen_k", [ (1, 1), + (3, 3), + (64, 32), (64, 128), (128, 192), (256, 256), @@ -82,6 +84,8 @@ def test_flash_attn_output( device = "cuda" # set seed torch.random.manual_seed(0) + torch.cuda.empty_cache() + torch.cuda.synchronize() batch_size = 9 if seqlen_k <= 2048 else 2 # batch_size = 1 nheads = 6 @@ -256,8 +260,8 @@ def test_flash_attn_output( @pytest.mark.parametrize("deterministic", [False]) # @pytest.mark.parametrize("softcap", [0.0, 15.0]) @pytest.mark.parametrize("softcap", [0.0]) -@pytest.mark.parametrize("local", [False, True]) -# @pytest.mark.parametrize("local", [False]) +# @pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [False]) # @pytest.mark.parametrize("add_unused_qkv", [False, True]) @@ -268,8 +272,8 @@ def test_flash_attn_output( # @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128]) -@pytest.mark.parametrize("d", [128, 192]) -# @pytest.mark.parametrize("d", [192]) +# @pytest.mark.parametrize("d", [128, 192]) +@pytest.mark.parametrize("d", [128]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -1040,4 +1044,4 @@ def test_flash_attn_combine(num_splits, seqlen, d, dtype): # Test with LSE not returned out_no_lse, lse_no_lse = flash_attn_combine(out_partial, lse_partial, out_dtype=dtype, return_lse=False) assert lse_no_lse is None, "LSE should be None when return_lse=False" - assert torch.allclose(out_no_lse, out, atol=1e-5, rtol=1e-5), "Output should be the same regardless of return_lse" + assert torch.allclose(out_no_lse, out, atol=1e-5, rtol=1e-5), "Output should be the same regardless of return_lse" \ No newline at end of file diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py new file mode 100644 index 00000000000..3e6707b5fb9 --- /dev/null +++ b/tests/cute/test_mask_mod.py @@ -0,0 +1,570 @@ +# mask mod test script + +import math + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +from cutlass.cute.runtime import from_dlpack +import pytest +import torch +from torch.nn.attention.flex_attention import create_block_mask, flex_attention +import torch.nn.functional as F + +from flash_attn.cute.block_sparsity import compute_block_sparsity +from flash_attn.cute.flash_fwd import ( + FlashAttentionForwardSm80, + FlashAttentionForwardSm90, +) +from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 +from flash_attn.cute.mask_definitions import MASK_FUNCTIONS, flex_causal_mask, create_flex_sliding_window_mask, create_cute_sliding_window_mask +from flash_attn.cute.testing import attention_ref + + +def create_tensors( + batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, headdim, headdim_v, dtype +): + device = "cuda" + q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype) + k = torch.randn( + batch_size, seqlen_k, nheads_kv, headdim, device=device, dtype=dtype + ) + v = torch.randn( + batch_size, seqlen_k, nheads_kv, headdim_v, device=device, dtype=dtype + ) + out = torch.empty( + batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype + ) + lse = torch.empty(batch_size, nheads, seqlen_q, device=device, dtype=torch.float32) + + return { + "q": q.contiguous(), + "k": k.contiguous(), + "v": v.contiguous(), + "out": out.contiguous(), + "lse": lse.contiguous(), + } + + +def compile_and_run_kernel( + tensors, + mask_mod_cute, + causal, + is_local, + window_left, + window_right, + tile_m, + tile_n, + full_block_cnt=None, + full_block_idx=None, + mask_block_cnt=None, + mask_block_idx=None, +): + dtype_map = { + torch.float16: cutlass.Float16, + torch.bfloat16: cutlass.BFloat16, + torch.float32: cutlass.Float32, + } + cute_dtype = dtype_map[tensors["q"].dtype] + + batch_size, seqlen_q, nheads, headdim = tensors["q"].shape + _, seqlen_k, nheads_kv, _ = tensors["k"].shape + headdim_v = tensors["v"].shape[-1] + + compute_capability = torch.cuda.get_device_capability() + if compute_capability >= (10, 0): + kernel_class = FlashAttentionForwardSm100 + elif compute_capability >= (9, 0): + kernel_class = FlashAttentionForwardSm90 + else: + kernel_class = FlashAttentionForwardSm80 + + qhead_per_kvhead = nheads // nheads_kv + kernel = kernel_class( + cute_dtype, + headdim, + headdim_v, + qhead_per_kvhead, + is_causal=causal, + is_local=is_local, + pack_gqa=False, + tile_m=tile_m, + tile_n=tile_n, + num_stages=2, + num_threads=384, + intra_wg_overlap=True, + mma_pv_is_rs=True, + mask_mod=mask_mod_cute, + has_buffers=False, + Q_in_regs=False, + ) + + softmax_scale = 1.0 / math.sqrt(headdim) + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + q_cute = from_dlpack(tensors["q"].detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=tensors["q"].ndim - 1 + ) + k_cute = from_dlpack(tensors["k"].detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=tensors["k"].ndim - 1 + ) + v_cute = from_dlpack(tensors["v"].detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=tensors["v"].ndim - 1 + ) + out_cute = from_dlpack( + tensors["out"].detach(), assumed_align=16 + ).mark_layout_dynamic(leading_dim=tensors["out"].ndim - 1) + lse_cute = from_dlpack( + tensors["lse"].detach(), assumed_align=4 + ).mark_layout_dynamic(leading_dim=tensors["lse"].ndim - 1) + + full_block_cnt_cute = ( + from_dlpack(full_block_cnt.detach(), assumed_align=4) + if full_block_cnt is not None + else None + ) + full_block_idx_cute = ( + from_dlpack(full_block_idx.detach(), assumed_align=4) + if full_block_idx is not None + else None + ) + mask_block_cnt_cute = ( + from_dlpack(mask_block_cnt.detach(), assumed_align=4) + if mask_block_cnt is not None + else None + ) + mask_block_idx_cute = ( + from_dlpack(mask_block_idx.detach(), assumed_align=4) + if mask_block_idx is not None + else None + ) + + # Window parameters for is_local + window_left_cute = ( + cutlass.Int32(window_left) if window_left is not None else None + ) + window_right_cute = ( + cutlass.Int32(window_right) if window_right is not None else None + ) + + compiled = cute.compile( + kernel, + q_cute, + k_cute, + v_cute, + out_cute, + lse_cute, + softmax_scale, + current_stream, + None, # cu_seqlens_q + None, # cu_seqlens_k + None, # seqused_q + None, # seqused_k + None, # page_table + window_left_cute, + window_right_cute, + None, # learnable_sink + full_block_cnt_cute, + full_block_idx_cute, + mask_block_cnt_cute, + mask_block_idx_cute, + None, # buffers + ) + + compiled( + q_cute, + k_cute, + v_cute, + out_cute, + lse_cute, + softmax_scale, + current_stream, + None, # cu_seqlens_q + None, # cu_seqlens_k + None, # seqused_q + None, # seqused_k + None, # page_table + window_left_cute, + window_right_cute, + None, # learnable_sink + full_block_cnt_cute, + full_block_idx_cute, + mask_block_cnt_cute, + mask_block_idx_cute, + None, # buffers + ) + + torch.cuda.synchronize() + return tensors["out"] + + +def compute_reference_flash_attn( + tensors, causal, window_size, dtype_ref, upcast=True +): + """Compute reference using FlashAttention's attention_ref function""" + batch_size, seqlen_q, nheads, headdim = tensors["q"].shape + _, seqlen_k, nheads_kv, _ = tensors["k"].shape + + q = tensors["q"].to(dtype_ref) + k = tensors["k"].to(dtype_ref) + v = tensors["v"].to(dtype_ref) + + out_ref, attn_ref = attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + causal=causal, + window_size=window_size, + upcast=upcast, + reorder_ops=False, + ) + + return out_ref + + +def compute_reference_flex_attn( + tensors, mask_mod_flex, mask_mod_name, tile_m, tile_n +): + """Compute reference using flex_attention for custom mask_mods""" + batch_size, seqlen_q, nheads, headdim = tensors["q"].shape + _, seqlen_k, nheads_kv, _ = tensors["k"].shape + + q = tensors["q"].transpose(1, 2) + k = tensors["k"].transpose(1, 2) + v = tensors["v"].transpose(1, 2) + + if nheads != nheads_kv: + repeat_factor = nheads // nheads_kv + k = k.repeat_interleave(repeat_factor, dim=1) + v = v.repeat_interleave(repeat_factor, dim=1) + + scale = 1.0 / math.sqrt(headdim) + + # Handle identity (no masking) case + if mask_mod_flex is None: + out_ref = F.scaled_dot_product_attention(q, k, v, scale=scale) + return out_ref.transpose(1, 2).contiguous() + + # Wrap mask_mod_flex to pass seqlen_q and seqlen_k + def mask_fn(b, h, q_idx, kv_idx): + return mask_mod_flex(b, h, q_idx, kv_idx, seqlen_q, seqlen_k) + + if mask_mod_name == "block_causal": + n_blocks_q = (seqlen_q + tile_m - 1) // tile_m + n_blocks_k = (seqlen_k + tile_n - 1) // tile_n + + mask = torch.zeros(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device) + + for q_block in range(n_blocks_q): + q_start = q_block * tile_m + q_end = min((q_block + 1) * tile_m, seqlen_q) + for k_block in range(n_blocks_k): + if k_block <= q_block: + k_start = k_block * tile_n + k_end = min((k_block + 1) * tile_n, seqlen_k) + mask[q_start:q_end, k_start:k_end] = True + + attn_mask = ( + mask.unsqueeze(0).unsqueeze(0).expand(batch_size, nheads, -1, -1) + ) + out_ref = F.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, scale=scale + ) + else: + block_mask = create_block_mask( + mask_fn, + B=batch_size, + H=nheads, + Q_LEN=seqlen_q, + KV_LEN=seqlen_k, + ).to(q.device) + out_ref = flex_attention(q, k, v, block_mask=block_mask, scale=scale) + + return out_ref.transpose(1, 2).contiguous() + + +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 1), + (64, 128), + (128, 192), + (256, 256), + (239, 1), + (799, 3), + (113, 203), + (113, 128), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (384, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (4096, 4096), + (4224, 4224), + ], +) +# @pytest.mark.parametrize("nheads", [4, 16, 32]) +@pytest.mark.parametrize("nheads", [16]) +@pytest.mark.parametrize("kv_mode", ["mha", "gqa", "mqa"]) +# @pytest.mark.parametrize("headdim", [64, 128]) +@pytest.mark.parametrize("headdim", [128]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize( + "use_mask_mod,is_local,mask_name,window_size,window_left,window_right", + [ + (False, False, "identity", None, None, None), + (False, False, "causal", None, None, None), + (True, False, "identity", None, None, None), + (True, False, "causal", None, None, None), + # (True, False, "block_causal", None, None, None), + # Mask mod sliding window + (True, False, "sliding_window", 128, None, None), + (True, False, "sliding_window", 256, None, None), + (True, False, "sliding_window", 512, None, None), + # Base local attention + # (False, True, None, None, 128, 0), + # (False, True, None, None, 256, 0), + # (False, True, None, None, 512, 0), + ], +) +@pytest.mark.parametrize("tile_m,tile_n", [(128, 128),]) +def test_mask_mod_output( + seqlen_q, seqlen_k, nheads, kv_mode, headdim, dtype, + use_mask_mod, is_local, mask_name, window_size, window_left, window_right, + tile_m, tile_n +): + torch.manual_seed(42) + + # Validate configuration + if is_local: + assert not use_mask_mod, "Cannot use both is_local and use_mask_mod" + assert window_left is not None or window_right is not None, \ + "Must specify window_left or window_right for is_local" + + if use_mask_mod and mask_name == "sliding_window": + assert window_size is not None, "window_size must be specified for sliding_window" + # Skip if seqlen_k is too small for the window + # if seqlen_k < window_size // 2: + # pytest.skip(f"seqlen_k={seqlen_k} too small for window_size={window_size}") + # Skip if seqlen_q > seqlen_k (problematic for sliding window) + if seqlen_q > seqlen_k: + pytest.skip(f"seqlen_q={seqlen_q} > seqlen_k={seqlen_k} not supported for sliding_window") + + if is_local: + window_left_val = window_left if window_left is not None else 0 + window_right_val = window_right if window_right is not None else 0 + total_window = window_left_val + window_right_val + 1 + # Skip if seqlen_k is too small for the window + if seqlen_k < total_window // 2: + pytest.skip(f"seqlen_k={seqlen_k} too small for window={total_window}") + # Skip if seqlen_q > seqlen_k (problematic for local window) + if seqlen_q > seqlen_k: + pytest.skip(f"seqlen_q={seqlen_q} > seqlen_k={seqlen_k} not supported for is_local") + + # Determine nheads_kv based on mode + if kv_mode == "mha": + nheads_kv = nheads + elif kv_mode == "gqa": + nheads_kv = nheads // 2 + elif kv_mode == "mqa": + nheads_kv = 1 + else: + raise ValueError(f"Unknown kv_mode: {kv_mode}") + + batch_size = 2 + headdim_v = headdim + + # Determine mask_mod functions and causal flag + if use_mask_mod: + if mask_name == "sliding_window": + # Use factory function for custom window size + mask_mod_cute = create_cute_sliding_window_mask(window_size) + mask_mod_flex = create_flex_sliding_window_mask(window_size) + else: + mask_mod_cute, mask_mod_flex = MASK_FUNCTIONS[mask_name] + causal = (mask_name == "causal") + elif is_local: + # Base local attention - no mask_mod + mask_mod_cute = None + mask_mod_flex = None + causal = False + else: + mask_mod_cute = None + mask_mod_flex = None + causal = (mask_name == "causal") if mask_name else False + + if causal and seqlen_k < seqlen_q: + pytest.skip("causal masking requires seqlen_k >= seqlen_q") + + tensors = create_tensors( + batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, headdim, headdim_v, dtype + ) + + # Compute block sparsity for mask_mod + full_cnt, full_idx, mask_cnt, mask_idx = None, None, None, None + if use_mask_mod: + from dataclasses import dataclass + + @dataclass + class Config: + seqlen_q: int + seqlen_k: int + nheads: int + nheads_kv: int + batch_size: int + tile_m: int + tile_n: int + use_mask_mod: bool + mask_mod_name: str + window_size: int = 1024 + verbose: bool = False + + config = Config( + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + nheads=nheads, + nheads_kv=nheads_kv, + batch_size=batch_size, + tile_m=tile_m, + tile_n=tile_n, + use_mask_mod=True, + mask_mod_name=mask_name, + window_size=window_size if window_size is not None else 1024, + ) + + full_cnt, full_idx, mask_cnt, mask_idx = compute_block_sparsity( + config=config, mask_mod_flex=mask_mod_flex, device="cuda" + ) + + # Run kernel + out_cute = compile_and_run_kernel( + tensors, + mask_mod_cute, + causal=causal, + is_local=is_local, + window_left=window_left, + window_right=window_right, + tile_m=tile_m, + tile_n=tile_n, + full_block_cnt=full_cnt, + full_block_idx=full_idx, + mask_block_cnt=mask_cnt, + mask_block_idx=mask_idx, + ) + + # Determine which reference implementation to use + dtype_ref = torch.bfloat16 + use_flash_attn_ref = False + + # Use FlashAttention reference for causal and local window cases + if mask_name == "causal" and not use_mask_mod: + use_flash_attn_ref = True + window_size_ref = (None, None) # attention_ref handles causal internally + elif mask_name == "identity" and not use_mask_mod and not is_local: + use_flash_attn_ref = True + window_size_ref = (None, None) # No window for identity + elif is_local: + use_flash_attn_ref = True + # For is_local, we need to pass the window parameters + # When window_right=0, this is inherently causal + window_size_ref = (window_left, window_right) + if window_right == 0: + causal = True # Override causal flag for reference computation + elif use_mask_mod and mask_name == "sliding_window": + use_flash_attn_ref = True + # For sliding window mask_mod, window_size corresponds directly to window_left + # in attention_ref (number of previous tokens that can be attended to) + # Sliding window with window_right=0 is inherently causal + window_size_ref = (window_size, 0) + causal = True # Override causal flag for reference computation + + if use_flash_attn_ref: + # Compute reference using FlashAttention's attention_ref + out_ref_fp32 = compute_reference_flash_attn( + tensors, causal=causal, window_size=window_size_ref, dtype_ref=torch.float32, upcast=True + ) + out_ref = compute_reference_flash_attn( + tensors, causal=causal, window_size=window_size_ref, dtype_ref=dtype_ref, upcast=False + ) + + # Also compute PyTorch reference for comparison (with reorder_ops for better accuracy) + out_pt = compute_reference_flash_attn( + tensors, causal=causal, window_size=window_size_ref, dtype_ref=dtype, upcast=False + ) + else: + # Use flex_attention for custom mask_mods + tensors_fp32 = { + k: v.float() if v.dtype in [torch.float16, torch.bfloat16] else v + for k, v in tensors.items() + } + + out_ref_fp32 = compute_reference_flex_attn( + tensors_fp32, mask_mod_flex, mask_name, tile_m, tile_n + ) + out_ref = compute_reference_flex_attn( + tensors, mask_mod_flex, mask_name, tile_m, tile_n + ) + out_pt = out_ref.clone() + + # Check for invalid values + assert out_cute.shape == out_ref_fp32.shape == out_ref.shape + assert not torch.isnan(out_cute).any() + assert not torch.isnan(out_ref_fp32).any() + assert torch.isfinite(out_cute).all() + assert torch.isfinite(out_ref_fp32).all() + + # Compute numerical tolerance (matching flash attention tests) + fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() + rtol = 2 + + ref_error = (out_ref - out_ref_fp32).abs().max().item() + pt_error = (out_pt - out_ref_fp32).abs().max().item() + cute_error = (out_cute - out_ref_fp32).abs().max().item() + + # Build description string + if is_local: + mask_desc = f"is_local(L={window_left},R={window_right})" + elif use_mask_mod: + mask_desc = f"mask_mod={mask_name}" + if mask_name == "sliding_window" and window_size is not None: + mask_desc += f"(w={window_size})" + else: + mask_desc = mask_name if mask_name else "identity" + + print( + f"\n{mask_desc} @ Q={seqlen_q}, K={seqlen_k}, H={nheads}/{nheads_kv} ({kv_mode}), " + f"D={headdim}, M={tile_m}, N={tile_n}" + ) + print(f" Reference implementation: {'FlashAttention' if use_flash_attn_ref else 'FlexAttention'}") + print(f" Reference vs FP32: {ref_error:.2e}") + print(f" PyTorch vs FP32: {pt_error:.2e}") + print(f" Kernel vs FP32: {cute_error:.2e}") + print(f" Tolerance: rtol={rtol} * {pt_error:.2e} + {fwd_atol:.2e}") + print(f" Error ratio: {cute_error / max(pt_error, 1e-10):.2f}") + + # Debug: show some sample values if error is large + if cute_error > 1e-2: + print(f" DEBUG: Sample kernel output: {out_cute[0, 0, 0, :5]}") + print(f" DEBUG: Sample reference output: {out_ref_fp32[0, 0, 0, :5]}") + print(f" DEBUG: Max diff location: {(out_cute - out_ref_fp32).abs().argmax()}") + max_diff_idx = (out_cute - out_ref_fp32).abs().argmax() + max_diff_coords = torch.unravel_index(max_diff_idx, out_cute.shape) + print(f" DEBUG: Max diff at coords: {max_diff_coords}") + print(f" DEBUG: Kernel value: {out_cute[max_diff_coords]:.6f}") + print(f" DEBUG: Reference value: {out_ref_fp32[max_diff_coords]:.6f}") + + # Use the same assertion logic as FlashAttention tests + assert cute_error <= rtol * pt_error + fwd_atol, ( + f"Kernel error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) \ No newline at end of file From 16c7f0f647db325506691e0810114ef198df0d0a Mon Sep 17 00:00:00 2001 From: Johnny Date: Tue, 21 Oct 2025 15:19:49 -0700 Subject: [PATCH 181/258] cutlass v4.3.0 (#1952) --- csrc/cutlass | 2 +- flash_attn/cute/pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/cutlass b/csrc/cutlass index c6aeb9179c5..b1d6e2c9b33 160000 --- a/csrc/cutlass +++ b/csrc/cutlass @@ -1 +1 @@ -Subproject commit c6aeb9179c5f74a0fcdbd28527bf4b6ba8c60752 +Subproject commit b1d6e2c9b334dfa811e4183dfbd02419249e4b52 diff --git a/flash_attn/cute/pyproject.toml b/flash_attn/cute/pyproject.toml index 0c34f83f1cf..a5d829a908b 100644 --- a/flash_attn/cute/pyproject.toml +++ b/flash_attn/cute/pyproject.toml @@ -20,7 +20,7 @@ classifiers = [ ] dependencies = [ - "nvidia-cutlass-dsl==4.2.1", + "nvidia-cutlass-dsl==4.3.0.dev0", "torch", "einops", ] From 9dbed03d1a7a5862998c182c83d8265fea9dc21b Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 21 Oct 2025 18:31:55 -0400 Subject: [PATCH 182/258] [Cute,Bwd,Sm100] Use CopyBulkG2SOp copy op instead of calling ptx --- flash_attn/cute/flash_bwd_sm100.py | 44 ++++++++++++++---------------- flash_attn/cute/interface.py | 10 +++---- 2 files changed, 26 insertions(+), 28 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index f3c6c307b69..b6d7fbe9fb1 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -11,7 +11,7 @@ from cutlass.utils import LayoutEnum from cutlass.cute.nvgpu import cpasync, tcgen05 import cutlass.utils.blackwell_helpers as sm100_utils_basic -from cutlass.pipeline import PipelineAsync +from cutlass.pipeline import PipelineAsync, PipelineConsumer from flash_attn.cute import utils from flash_attn.cute import copy_utils @@ -897,7 +897,7 @@ def kernel( tdKtdK, tdPtdP, tdQtdQ, - pipeline_Q, + pipeline_Q.make_consumer(), pipeline_dO, pipeline_S_P, pipeline_dS, @@ -1060,8 +1060,10 @@ def load( tma_atom_dO, 0, cute.make_layout(1), tdVgdO, sdO ) load_dO = copy_utils.tma_producer_copy_fn(load_dO, pipeline_dO) - load_LSE = copy_utils.cpasync_bulk_get_copy_fn(gLSE, sLSE) - load_dPsum = copy_utils.cpasync_bulk_get_copy_fn(gdPsum, sdPsum) + copy_atom_stats = cute.make_copy_atom( + cpasync.CopyBulkG2SOp(), Float32, num_bits_per_copy=self.tma_copy_bytes["LSE"] * 8 + ) + copy_stats = partial(cute.copy, copy_atom_stats) # First iteration: load K together w Q & LSE, then V together w dO & dPsum # K & Q @@ -1075,7 +1077,7 @@ def load( cute.arch.mbarrier_arrive_and_expect_tx( LSE_full_mbar_ptr, self.tma_copy_bytes["LSE"] ) - load_LSE(src_idx=m_block_min, dst_idx=0, tma_bar_ptr=LSE_full_mbar_ptr) + copy_stats(gLSE[None, m_block_min], sLSE[None, 0], mbar_ptr=LSE_full_mbar_ptr) # V & dO pipeline_dO.producer_acquire(producer_state_dO, extra_tx_count=self.tma_copy_bytes["V"]) load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO)) @@ -1087,7 +1089,7 @@ def load( cute.arch.mbarrier_arrive_and_expect_tx( dPsum_full_mbar_ptr, self.tma_copy_bytes["dPsum"] ) - load_dPsum(src_idx=m_block_min, dst_idx=0, tma_bar_ptr=dPsum_full_mbar_ptr) + copy_stats(gdPsum[None, m_block_min], sdPsum[None, 0], mbar_ptr=dPsum_full_mbar_ptr) lse_empty_consumer_phase = cute.Int32(0) dpsum_empty_consumer_phase = cute.Int32(0) @@ -1105,7 +1107,7 @@ def load( cute.arch.mbarrier_arrive_and_expect_tx( LSE_full_mbar_ptr, self.tma_copy_bytes["LSE"] ) - load_LSE(src_idx=m_block, dst_idx=0, tma_bar_ptr=LSE_full_mbar_ptr) + copy_stats(gLSE[None, m_block], sLSE[None, 0], mbar_ptr=LSE_full_mbar_ptr) # dO pipeline_dO.producer_acquire(producer_state_dO) load_dO(m_block, producer_state=producer_state_dO) @@ -1118,7 +1120,7 @@ def load( cute.arch.mbarrier_arrive_and_expect_tx( dPsum_full_mbar_ptr, self.tma_copy_bytes["dPsum"] ) - load_dPsum(src_idx=m_block, dst_idx=0, tma_bar_ptr=dPsum_full_mbar_ptr) + copy_stats(gdPsum[None, m_block], sdPsum[None, 0], mbar_ptr=dPsum_full_mbar_ptr) pipeline_Q.producer_tail(producer_state_Q) pipeline_dO.producer_tail(producer_state_dO) @@ -1148,7 +1150,7 @@ def mma( tdKtdK: cute.Tensor, tdPtdP: cute.Tensor, tdQtdQ: cute.Tensor, - pipeline_Q: PipelineAsync, + pipeline_Q_consumer: PipelineConsumer, pipeline_dO: PipelineAsync, pipeline_S_P: PipelineAsync, pipeline_dS: PipelineAsync, @@ -1213,9 +1215,6 @@ def mma( # gemm_ptx_w_idx, tiled_mma_dK, tdKtdK, tdKrdS, tdKrQ, sA=sdSt, sB=sQt, A_idx=0 # ) - consumer_state_Q = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Consumer, self.Q_stage - ) consumer_state_dO = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage ) @@ -1256,10 +1255,10 @@ def mma( # 3. dV = P @ dO # 1) S = Q0 @ K.T - pipeline_Q.consumer_wait(consumer_state_Q) + handle_Q = pipeline_Q_consumer.wait_and_advance() # pipeline_S_P.producer_acquire(producer_state_S_P) pipeline_S_P.sync_object_empty.wait(0, producer_phase_S_P) - mma_qk_fn(B_idx=consumer_state_Q.index) + mma_qk_fn(B_idx=handle_Q.index) # Don't release Q yet # pipeline_S_P.producer_commit(producer_state_S_P) pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) @@ -1297,11 +1296,9 @@ def mma( for _ in cutlass.range(m_block_min + 1, m_block_max, unroll=1): # 1) S = K @ Q_i - consumer_state_Q_prev = consumer_state_Q.clone() - consumer_state_Q.advance() - pipeline_Q.consumer_wait(consumer_state_Q) + handle_Q_next = pipeline_Q_consumer.wait_and_advance() # Don't need to wait for S, as P must have been ready ealier, i.e., S is ready - mma_qk_fn(B_idx=consumer_state_Q.index) + mma_qk_fn(B_idx=handle_Q_next.index) # pipeline_S_P.producer_commit(producer_state_S_P) pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) # producer_state_S_P.advance() @@ -1318,9 +1315,9 @@ def mma( producer_phase_dQ ^= 1 # 3) dK = dS.T @ Q - mma_dsq_fn(B_idx=consumer_state_Q_prev.index, zero_init=not accumulate_dK) + mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK) accumulate_dK = True - pipeline_Q.consumer_release(consumer_state_Q_prev) + handle_Q.release() pipeline_dS.consumer_release(consumer_state_dS) consumer_state_dS.advance() @@ -1342,6 +1339,8 @@ def mma( pipeline_dO.consumer_release(consumer_state_dO) consumer_state_dO.advance() + handle_Q = handle_Q_next + # pipeline_S_P.producer_commit(producer_state_S_P) pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) # producer_state_S_P.advance() @@ -1361,7 +1360,7 @@ def mma( # ----------------------------------------------------------- # 1) dK += dS.T @ Q pipeline_dS.consumer_wait(consumer_state_dS) - mma_dsq_fn(B_idx=consumer_state_Q.index, zero_init=not accumulate_dK) + mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK) # signal to the epilogue that dK is ready # pipeline_dKV.producer_commit(producer_state_dKV) pipeline_dKV.sync_object_full.arrive(1, pipeline_dKV.producer_mask, cta_group) @@ -1375,8 +1374,7 @@ def mma( # producer_state_dQ.advance() producer_phase_dQ ^= 1 # Wait until dQ is done before releasing Q, since K and Q0 uses the same mbarrier - pipeline_Q.consumer_release(consumer_state_Q) - consumer_state_Q.advance() + handle_Q.release() pipeline_dS.consumer_release(consumer_state_dS) consumer_state_dS.advance() diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 0615061a541..8c2e5903fc4 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -323,7 +323,7 @@ def _flash_attn_fwd( page_table_tensor, window_size_left, window_size_right, learnable_sink_tensor, full_block_cnt_tensor, full_block_idx_tensor, mask_block_cnt_tensor, mask_block_idx_tensor, - cute_buffers, + buffers=cute_buffers, ) _flash_attn_fwd.compile_cache[compile_key]( q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream, @@ -331,7 +331,7 @@ def _flash_attn_fwd( page_table_tensor, window_size_left, window_size_right, learnable_sink_tensor, full_block_cnt_tensor, full_block_idx_tensor, mask_block_cnt_tensor, mask_block_idx_tensor, - cute_buffers, + buffers=cute_buffers, ) return out, lse @@ -691,7 +691,7 @@ def backward(ctx, dout, *args): ctx.causal, ctx.softcap, ) - return dq, dk, dv, *((None,) * 10) # Extra Nones is fine + return dq, dk, dv, *((None,) * 20) # Extra Nones is fine class FlashAttnVarlenFunc(torch.autograd.Function): @@ -759,7 +759,7 @@ def backward(ctx, dout, *args): seqused_k=seqused_k, ) - return dq, dk, dv, *((None,) * 11) + return dq, dk, dv, *((None,) * 20) def flash_attn_func( @@ -1049,4 +1049,4 @@ def flash_attn_combine( lse = None _flash_attn_fwd_combine(out_partial, lse_partial, out, lse) - return out, lse \ No newline at end of file + return out, lse From 1b8e1e641c6a179be9a0538b7f40fd595050b735 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 21 Oct 2025 23:17:14 -0400 Subject: [PATCH 183/258] [Cute,Bwd,Sm100] More cleanup --- flash_attn/cute/flash_bwd_sm100.py | 326 ++++++++++++++--------------- 1 file changed, 161 insertions(+), 165 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index b6d7fbe9fb1..7eaf7b95849 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -110,10 +110,33 @@ def __init__( ) ) + # NamedBarrier + self.compute_sync_barrier = cutlass.pipeline.NamedBarrier( + barrier_id=int(NamedBarrierBwdSm100.Compute), + num_threads=len(self.compute_warp_ids) * cute.arch.WARP_SIZE, + ) + # self.epilogue_sync_barrier = pipeline.NamedBarrier( + # barrier_id=2, + # num_threads=self.num_compute_warps * self.threads_per_warp, + # ) + self.reduce_sync_barrier = cutlass.pipeline.NamedBarrier( + barrier_id=int(NamedBarrierBwdSm100.dQaccReduce), + num_threads=len(self.reduce_warp_ids) * cute.arch.WARP_SIZE, + ) + # TMEM setup SM100_TMEM_CAPACITY_COLUMNS = 512 self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + # self.tmem_dK_offset = 0 + # self.tmem_dV_offset = self.tmem_dK_offset + self.tile_hdim + # self.tmem_dQ_offset = self.tmem_dV_offset + self.tile_hdimv + # self.tmem_dP_offset = self.tmem_dQ_offset # overlap with dQ + # self.tmem_S_offset = self.tmem_dQ_offset + max(self.tile_m, self.tile_hdim) + # self.tmem_P_offset = self.tmem_S_offset # overlap with S + # self.tmem_total = self.tmem_S_offset + self.tile_n + # assert self.tmem_total <= self.tmem_alloc_cols + self.tmem_S_offset = 0 self.tmem_P_offset = 0 # overlap with S self.tmem_dV_offset = self.tmem_S_offset + self.tile_n @@ -123,24 +146,23 @@ def __init__( self.num_regs_reduce = 160 self.num_regs_compute = 128 - self.num_regs_other = 80 + self.num_regs_other = 96 self.num_regs_empty = 24 assert self.num_regs_reduce + self.num_regs_compute * 2 + self.num_regs_other <= 512 self.buffer_align_bytes = 1024 - self.num_compute_threads = cute.arch.WARP_SIZE * len(self.compute_warp_ids) - def _setup_attributes(self): self.Q_stage = 2 - self.k_stage = self.v_stage = 1 self.dO_stage = 1 self.LSE_stage = 1 - self.sdQaccum_stage = 2 self.dPsum_stage = 1 self.sdKVaccum_stage = 2 # number of tma reduce adds per dQacc mma - self.dQaccum_reduce_stage = self.tile_hdim // 32 + self.dQ_reduce_ncol = 32 + self.sdQaccum_stage = 64 // self.dQ_reduce_ncol + assert self.tile_hdim % self.dQ_reduce_ncol == 0 + self.dQaccum_reduce_stage = self.tile_hdim // self.dQ_reduce_ncol def _get_tiled_mma(self): cta_group = tcgen05.CtaGroup.ONE @@ -189,7 +211,7 @@ def _setup_smem_layout(self): self.tiled_mma_SdP, self.mma_tiler_kq, self.k_dtype, - self.k_stage, + 1, ) self.sQ_layout = sm100_utils_basic.make_smem_layout_b( self.tiled_mma_SdP, @@ -197,19 +219,12 @@ def _setup_smem_layout(self): self.q_dtype, self.Q_stage, ) - # dV += P @ dO - self.sdO_layout = sm100_utils_basic.make_smem_layout_b( - self.tiled_mma_dV, - self.mma_tiler_pdo, - self.do_dtype, - self.dO_stage, - ) # dP = V @ dO.T self.sV_layout = sm100_utils_basic.make_smem_layout_a( self.tiled_mma_SdP, self.mma_tiler_vdo, self.v_dtype, - self.v_stage, + 1, ) self.sdOt_layout = sm100_utils_basic.make_smem_layout_b( self.tiled_mma_SdP, @@ -217,6 +232,19 @@ def _setup_smem_layout(self): self.do_dtype, self.dO_stage, ) + # dV += P @ dO + self.tP_layout = sm100_utils_basic.make_smem_layout_a( + self.tiled_mma_dV, + self.mma_tiler_pdo, + self.do_dtype, + 1, + ) + self.sdO_layout = sm100_utils_basic.make_smem_layout_b( + self.tiled_mma_dV, + self.mma_tiler_pdo, + self.do_dtype, + self.dO_stage, + ) # dK += dS.T @ Q self.sdSt_layout = sm100_utils_basic.make_smem_layout_a( self.tiled_mma_dK, @@ -230,21 +258,22 @@ def _setup_smem_layout(self): self.q_dtype, self.Q_stage, ) - # dQaccum = dS @ K + # dQ = dS @ K self.sdS_layout = sm100_utils_basic.make_smem_layout_a( self.tiled_mma_dQ, self.mma_tiler_dsk, - self.q_dtype, + self.ds_dtype, 1, ) self.sKt_layout = sm100_utils_basic.make_smem_layout_b( self.tiled_mma_dQ, self.mma_tiler_dsk, self.k_dtype, - self.k_stage, + 1, + ) + self.sdQaccum_layout = cute.make_layout( + (self.tile_m * self.dQ_reduce_ncol, self.sdQaccum_stage) ) - - self.sdQaccum_layout = cute.make_layout((self.tile_m * 32, self.sdQaccum_stage)) self.sLSE_layout = cute.make_layout( shape=(self.tile_m, self.LSE_stage), stride=(1, cute.round_up(self.tile_m, 64)), @@ -253,6 +282,17 @@ def _setup_smem_layout(self): shape=(self.tile_m, self.dPsum_stage), stride=(1, cute.round_up(self.tile_m, 64)), ) + self.sdKV_epi_tile = ( + self.tile_n, + 128 // (self.dk_dtype.width // 8), + ) # subtiles mma_tiler_dsq[:2] = mma_tiler_pdo[:2] + # TODO: dK and dV could have different shapes + self.sdKV_layout = sm100_utils_basic.make_smem_layout_epi( + self.dk_dtype, + LayoutEnum.ROW_MAJOR, + self.sdKV_epi_tile, + self.sdKVaccum_stage, + ) @cute.jit def __call__( @@ -337,16 +377,6 @@ def __call__( raise RuntimeError("The layout of mdK is wrong") if const_expr(dV_major_mode != tcgen05.OperandMajorMode.K): raise RuntimeError("The layout of mdV is wrong") - self.sdKV_epi_tile = ( - self.tile_n, - 128 // (self.dk_dtype.width // 8), - ) # subtiles mma_tiler_dsq[:2] = mma_tiler_pdo[:2] - sdKV_layout = sm100_utils_basic.make_smem_layout_epi( - self.dk_dtype, - self.mdK_layout_enum, - self.sdKV_epi_tile, - self.sdKVaccum_stage, - ) if const_expr(self.use_tma_store): if const_expr(self.dk_dtype.width == 32): @@ -357,14 +387,14 @@ def __call__( tma_atom_dK, mdK_tma_tensor = cpasync.make_tiled_tma_atom( tma_copy_op_dKV, mdK, - cute.select(sdKV_layout, mode=[0, 1]), + cute.select(self.sdKV_layout, mode=[0, 1]), self.sdKV_epi_tile, 1, # no mcast ) tma_atom_dV, mdV_tma_tensor = cpasync.make_tiled_tma_atom( tma_copy_op_dKV, mdV, - cute.select(sdKV_layout, mode=[0, 1]), + cute.select(self.sdKV_layout, mode=[0, 1]), self.sdKV_epi_tile, 1, # no mcast ) @@ -389,6 +419,7 @@ def __call__( ) tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) + tma_load_op_multicast = cpasync.CopyBulkTensorTileG2SMulticastOp(cta_group) # S = K @ Q.T tma_atom_K, tma_tensor_K = cute.nvgpu.make_tiled_tma_atom_A( @@ -400,22 +431,13 @@ def __call__( self.cluster_layout_vmnk.shape, ) tma_atom_Q, tma_tensor_Q = cute.nvgpu.make_tiled_tma_atom_B( - tma_load_op, + tma_load_op if const_expr(self.cluster_shape_mnk[1] == 1) else tma_load_op_multicast, mQ, cute.select(self.sQ_layout, mode=[0, 1, 2]), self.mma_tiler_kq, self.tiled_mma_SdP, self.cluster_layout_vmnk.shape, ) - # dV += P @ dO - tma_atom_dO, tma_tensor_dO = cute.nvgpu.make_tiled_tma_atom_B( - tma_load_op, - mdO, - cute.select(self.sdO_layout, mode=[0, 1, 2]), - self.mma_tiler_pdo, - self.tiled_mma_dV, - self.cluster_layout_vmnk.shape, - ) # dP = V @ dO.T tma_atom_V, tma_tensor_V = cute.nvgpu.make_tiled_tma_atom_A( tma_load_op, @@ -425,6 +447,14 @@ def __call__( self.tiled_mma_SdP, self.cluster_layout_vmnk.shape, ) + tma_atom_dO, tma_tensor_dO = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op if const_expr(self.cluster_shape_mnk[1] == 1) else tma_load_op_multicast, + mdO, + cute.select(self.sdO_layout, mode=[0, 1, 2]), + self.mma_tiler_pdo, + self.tiled_mma_dV, + self.cluster_layout_vmnk.shape, + ) self.tma_copy_bytes = { name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2])) @@ -437,7 +467,7 @@ def __call__( } self.tma_copy_bytes["LSE"] = self.tile_m * Float32.width // 8 self.tma_copy_bytes["dPsum"] = self.tile_m * Float32.width // 8 - self.tma_copy_bytes["dQ"] = self.tile_m * 32 * Float32.width // 8 + self.tma_copy_bytes["dQ"] = self.tile_m * self.dQ_reduce_ncol * Float32.width // 8 TileScheduler = SingleTileScheduler # TODO -- optimizer scheduler for causal @@ -475,9 +505,7 @@ class SharedStorage: dP_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 1] dS_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 1] dKV_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 2] - dQaccum_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2] - - # TMEM + dQ_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2] tmem_holding_buf: Int32 tmem_dealloc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1] @@ -537,7 +565,6 @@ class SharedStorage: tma_atom_Q, tma_atom_K, tma_atom_V, - # tma_atom_Psum, tma_atom_dO, tma_atom_dV, tma_atom_dK, @@ -553,7 +580,8 @@ class SharedStorage: self.sdS_layout, self.sKt_layout, self.sdQaccum_layout, - sdKV_layout, + self.sdKV_layout, + self.tP_layout, self.tiled_mma_SdP, self.tiled_mma_dV, self.tiled_mma_dK, @@ -607,6 +635,7 @@ def kernel( sKt_layout: cute.ComposedLayout, sdQaccum_layout: cute.Layout, sdKV_layout: cute.ComposedLayout, + tP_layout: cute.ComposedLayout, tiled_mma_SdP: cute.TiledMma, tiled_mma_dV: cute.TiledMma, tiled_mma_dK: cute.TiledMma, @@ -708,7 +737,7 @@ def kernel( num_stages=1, producer_group=pipeline_producer_group_MMA_AsyncThread, consumer_group=pipeline_consumer_group_MMA_AsyncThread_dQ, - barrier_storage=storage.dQaccum_mbar_ptr.data_ptr(), + barrier_storage=storage.dQ_mbar_ptr.data_ptr(), ) # AsyncThread producers and UMMA consumers @@ -728,44 +757,28 @@ def kernel( ) sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) - sQt = cute.make_tensor( - cute.recast_ptr(sQ.iterator, swizzle_=sQt_layout.inner), sQt_layout.outer - ) - + sQt = cute.make_tensor(cute.recast_ptr(sQ.iterator, sQt_layout.inner), sQt_layout.outer) sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) - sKt = cute.make_tensor( - cute.recast_ptr(sK.iterator, swizzle_=sKt_layout.inner), sKt_layout.outer - ) - + sKt = cute.make_tensor(cute.recast_ptr(sK.iterator, sKt_layout.inner), sKt_layout.outer) sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) - sdSt = storage.sdS.get_tensor(sdSt_layout.outer, swizzle=sdSt_layout.inner) - sdS = cute.make_tensor( - cute.recast_ptr(sdSt.iterator, swizzle_=sdS_layout.inner), sdS_layout.outer - ) - + sdS = cute.make_tensor(cute.recast_ptr(sdSt.iterator, sdS_layout.inner), sdS_layout.outer) sdO = storage.sdO.get_tensor(sdO_layout.outer, swizzle=sdO_layout.inner) - sdOt = cute.make_tensor( - cute.recast_ptr(sdO.iterator, swizzle_=sdOt_layout.inner), sdOt_layout.outer - ) - + sdOt = cute.make_tensor(cute.recast_ptr(sdO.iterator, sdOt_layout.inner), sdOt_layout.outer) sLSE = storage.sLSE.get_tensor(sLSE_layout) sdPsum = storage.sdPsum.get_tensor(sdPsum_layout) - sdV = storage.sdO.get_tensor( sdKV_layout.outer, swizzle=sdKV_layout.inner, dtype=self.dv_dtype ) sdK = storage.sQ.get_tensor( sdKV_layout.outer, swizzle=sdKV_layout.inner, dtype=self.dk_dtype ) - - assert cute.cosize(sdV) * self.dv_dtype.width <= cute.cosize(sdO) * self.do_dtype.width, ( - "Not enough space for sdV" - ) - assert cute.cosize(sdK) * self.dk_dtype.width <= cute.cosize(sQ) * self.q_dtype.width, ( - "Not enough space for sdK" - ) - + assert cute.size_in_bytes(self.do_dtype, sdO_layout) >= cute.size_in_bytes( + self.dv_dtype, sdKV_layout + ), "Not enough space for sdV" + assert cute.size_in_bytes(self.q_dtype, sQ_layout) >= cute.size_in_bytes( + self.dk_dtype, sdKV_layout + ), "Not enough space for sdK" sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout) # TMEM @@ -773,12 +786,19 @@ def kernel( thr_mma_SdP = tiled_mma_SdP.get_slice(0) Sacc_shape = thr_mma_SdP.partition_shape_C(self.mma_tiler_kq[:2]) # (M, N) tStS = thr_mma_SdP.make_fragment_C(Sacc_shape) - tStS = cute.make_tensor(tStS.iterator, tStS.layout) + # (MMA, MMA_M, MMA_N) + tStS = cute.make_tensor(tStS.iterator + self.tmem_S_offset, tStS.layout) + # dP + dPacc_shape = thr_mma_SdP.partition_shape_C(self.mma_tiler_vdo[:2]) + tdPtdP = thr_mma_SdP.make_fragment_C(dPacc_shape) + tdPtdP = cute.make_tensor(tdPtdP.iterator + self.tmem_dP_offset, tdPtdP.layout) # dV thr_mma_dV = tiled_mma_dV.get_slice(0) dvacc_shape = thr_mma_dV.partition_shape_C(self.mma_tiler_pdo[:2]) tdVtdV = thr_mma_dV.make_fragment_C(dvacc_shape) tdVtdV = cute.make_tensor(tdVtdV.iterator + self.tmem_dV_offset, tdVtdV.layout) + tP_ptr = cute.make_ptr(self.do_dtype, self.tmem_P_offset, cute.AddressSpace.tmem) + tP = cute.make_tensor(tP_ptr, tP_layout.outer) # dK thr_mma_dK = tiled_mma_dK.get_slice(0) dkacc_shape = thr_mma_dK.partition_shape_C(self.mma_tiler_dsq[:2]) @@ -789,10 +809,6 @@ def kernel( dQacc_shape = thr_mma_dQ.partition_shape_C(self.mma_tiler_dsk[:2]) tdQtdQ = thr_mma_dQ.make_fragment_C(dQacc_shape) tdQtdQ = cute.make_tensor(tdQtdQ.iterator + self.tmem_dQ_offset, tdQtdQ.layout) - # dP - dPacc_shape = thr_mma_SdP.partition_shape_C(self.mma_tiler_vdo[:2]) - tdPtdP = thr_mma_SdP.make_fragment_C(dPacc_shape) - tdPtdP = cute.make_tensor(tdPtdP.iterator + self.tmem_dP_offset, tdPtdP.layout) block_info = BlockInfo( self.tile_m, @@ -857,11 +873,11 @@ def kernel( tma_atom_K, tma_atom_V, tma_atom_dO, - pipeline_Q, LSE_full_mbar_ptr, LSE_empty_mbar_ptr, dPsum_full_mbar_ptr, dPsum_empty_mbar_ptr, + pipeline_Q, pipeline_dO, block_info, SeqlenInfoCls, @@ -892,10 +908,11 @@ def kernel( sdSt, sdS, sKt, + tP, tStS, + tdPtdP, tdVtdV, tdKtdK, - tdPtdP, tdQtdQ, pipeline_Q.make_consumer(), pipeline_dO, @@ -1001,11 +1018,11 @@ def load( tma_atom_K: cute.CopyAtom, tma_atom_V: cute.CopyAtom, tma_atom_dO: cute.CopyAtom, - pipeline_Q: PipelineAsync, LSE_full_mbar_ptr: cute.Pointer, LSE_empty_mbar_ptr: cute.Pointer, dPsum_full_mbar_ptr: cute.Pointer, dPsum_empty_mbar_ptr: cute.Pointer, + pipeline_Q: PipelineAsync, pipeline_dO: PipelineAsync, block_info: BlockInfo, SeqlenInfoCls: Callable, @@ -1060,9 +1077,7 @@ def load( tma_atom_dO, 0, cute.make_layout(1), tdVgdO, sdO ) load_dO = copy_utils.tma_producer_copy_fn(load_dO, pipeline_dO) - copy_atom_stats = cute.make_copy_atom( - cpasync.CopyBulkG2SOp(), Float32, num_bits_per_copy=self.tma_copy_bytes["LSE"] * 8 - ) + copy_atom_stats = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), Float32) copy_stats = partial(cute.copy, copy_atom_stats) # First iteration: load K together w Q & LSE, then V together w dO & dPsum @@ -1093,7 +1108,6 @@ def load( lse_empty_consumer_phase = cute.Int32(0) dpsum_empty_consumer_phase = cute.Int32(0) - for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1): # Q pipeline_Q.producer_acquire(producer_state_Q) @@ -1145,10 +1159,11 @@ def mma( sdSt: cute.Tensor, sdS: cute.Tensor, sKt: cute.Tensor, + tP: cute.Tensor, tStS: cute.Tensor, + tdPtdP: cute.Tensor, tdVtdV: cute.Tensor, tdKtdK: cute.Tensor, - tdPtdP: cute.Tensor, tdQtdQ: cute.Tensor, pipeline_Q_consumer: PipelineConsumer, pipeline_dO: PipelineAsync, @@ -1161,34 +1176,24 @@ def mma( SeqlenInfoCls: Callable, TileSchedulerCls: Callable, ): - thr_mma_SdP = tiled_mma_SdP.get_slice(0) - thr_mma_dV = tiled_mma_dV.get_slice(0) - thr_mma_dK = tiled_mma_dK.get_slice(0) - thr_mma_dQ = tiled_mma_dQ.get_slice(0) + # [2025-10-21] For reasons I don't understand, putting these partitioning in the main + # kernel (before warp specialization) is a lot slower tha putting them here. # Partition smem / tmem tensors # S = K @ Q.T - tSrK = thr_mma_SdP.make_fragment_A(sK) - tSrQ = thr_mma_SdP.make_fragment_B(sQ) + tSrK = tiled_mma_SdP.make_fragment_A(sK) + tSrQ = tiled_mma_SdP.make_fragment_B(sQ) # dP = V @ dO.T - tdPrV = thr_mma_SdP.make_fragment_A(sV) - tdPrdOt = thr_mma_SdP.make_fragment_B(sdOt) + tdPrV = tiled_mma_SdP.make_fragment_A(sV) + tdPrdOt = tiled_mma_SdP.make_fragment_B(sdOt) # dK = dS.T @ Q - tdKrdS = thr_mma_dK.make_fragment_A(sdSt) - tdKrQ = thr_mma_dK.make_fragment_B(sQt) + tdKrdS = tiled_mma_dK.make_fragment_A(sdSt) + tdKrQ = tiled_mma_dK.make_fragment_B(sQt) # dQ = dS @ K - tdQrdS = thr_mma_dQ.make_fragment_A(sdS) - tdQrK = thr_mma_dQ.make_fragment_B(sKt) + tdQrdS = tiled_mma_dQ.make_fragment_A(sdS) + tdQrK = tiled_mma_dQ.make_fragment_B(sKt) # dV = P @ dO.T - tdVrdO = thr_mma_dV.make_fragment_B(sdO) - p_tmem_layout = sm100_utils_basic.make_smem_layout_a( - tiled_mma_dV, - self.mma_tiler_pdo, - self.q_dtype, - 1, - ) - tP = cute.make_tensor(tStS.iterator, p_tmem_layout.outer) - tdVrP = thr_mma_dV.make_fragment_A(tP)[None, None, None, 0] - tdVrP = cute.make_tensor(tdVrP.iterator, tdVrP.layout) + tdVrdO = tiled_mma_dV.make_fragment_B(sdO) + tdVrP = tiled_mma_dV.make_fragment_A(tP)[None, None, None, 0] mma_qk_fn = partial(gemm_w_idx, tiled_mma_SdP, tStS, tSrK, tSrQ, A_idx=0, zero_init=True) # mma_qk_fn = partial( @@ -1390,21 +1395,21 @@ def mma( @cute.jit def split_wg( self, - thr_tensor: cute.Tensor, + t: cute.Tensor, wg_idx: cutlass.Int32, - num_wg: cutlass.Constexpr[cutlass.Int32], + num_wg: cutlass.Constexpr[int], ): - reduced_shape = cute.product_each(thr_tensor.shape) + reduced_shape = cute.product_each(t.shape) rank = len(reduced_shape) if const_expr(reduced_shape[1] > 1): - assert rank >= 2, "Need rank >= 2 for thr_tensor in split_wg" - t = cute.logical_divide(thr_tensor, (reduced_shape[0], reduced_shape[1] // num_wg)) + assert rank >= 2, "Need rank >= 2 for t in split_wg" + t = cute.logical_divide(t, (reduced_shape[0], reduced_shape[1] // num_wg)) coord = (None, (None, wg_idx)) + (None,) * (rank - 2) else: - assert rank >= 3, "Need rank >= 3 for thr_tensor in split_wg" + assert rank >= 3, "Need rank >= 3 for t in split_wg" if const_expr(rank == 3): t = cute.logical_divide( - thr_tensor, (reduced_shape[0], reduced_shape[1], reduced_shape[2] // num_wg) + t, (reduced_shape[0], reduced_shape[1], reduced_shape[2] // num_wg) ) coord = ( None, @@ -1413,7 +1418,7 @@ def split_wg( ) + (None,) * (rank - 3) else: t = cute.logical_divide( - thr_tensor, + t, ( reduced_shape[0], reduced_shape[1], @@ -1487,15 +1492,14 @@ def compute_loop( if const_expr(True): sLSE_2D = utils.transpose_view(sLSE_2D) sdPsum_2D = utils.transpose_view(sdPsum_2D) + # tix: [128...384] 8 warps warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) # 4-11 - tidx = cute.arch.thread_idx()[0] % 128 # 0...128 - wg_idx = ( - cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.compute_warp_ids)) - ) // 128 + tidx = cute.arch.thread_idx()[0] + dp_idx = tidx % 128 + wg_idx = (tidx % (cute.arch.WARP_SIZE * len(self.compute_warp_ids))) // 128 wg_idx = cute.arch.make_warp_uniform(wg_idx) - num_wg = cute.arch.WARP_SIZE * len(self.compute_warp_ids) // 128 # 2 - + num_wg = len(self.compute_warp_ids) // 4 # 2 # wg_idx: # 0: [256...384] # 1: [128...256] @@ -1512,7 +1516,7 @@ def compute_loop( tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), Float32 ) - thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS).get_slice(tidx) + thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS).get_slice(dp_idx) tStS_t2r_p = thr_tmem_load.partition_S(tStS) tStS_t2r = self.split_wg(tStS_t2r_p, wg_idx, num_wg) tdPtdP_t2r_p = thr_tmem_load.partition_S(tdPtdP) @@ -1524,7 +1528,7 @@ def compute_loop( tSsdPsum_p = thr_tmem_load.partition_D(thr_mma_SdP.partition_C(sdPsum_2D)) tSsdPsum = self.split_wg(tSsdPsum_p, wg_idx, num_wg) - thr_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP).get_slice(tidx) + thr_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP).get_slice(dp_idx) tScP_r2t_p = thr_tmem_store.partition_S(tScP) tScP_r2t = self.split_wg(tScP_r2t_p, wg_idx, num_wg) tStP_r2t_p = thr_tmem_store.partition_D(tStP) @@ -1568,15 +1572,6 @@ def compute_loop( #### TMEM->RMEM (Load S from TMEM) tSrS_t2r = cute.make_fragment(tScS_t2r.shape, Float32) # 64 cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r) - cute.arch.fence_view_async_tmem_load() - - # Without this barrier, we could have 1 warp writing to P in tmem while - # another warp is still reading S from tmem. - cute.arch.barrier( - barrier_id=int(NamedBarrierBwdSm100.Compute), - number_of_threads=self.num_compute_threads, - ) - cute.arch.mbarrier_wait(LSE_full_mbar_ptr, consumer_phase_LSE) consumer_phase_LSE ^= 1 @@ -1620,6 +1615,11 @@ def compute_loop( tSrS_cur[2 * v] = cute.math.exp2(tSrS_cur[2 * v], fastmath=True) tSrS_cur[2 * v + 1] = cute.math.exp2(tSrS_cur[2 * v + 1], fastmath=True) utils.cvt_f16(tSrS_cur, tSrP_r2t[None, 0, 0]) + if const_expr(stage == 0): + cute.arch.fence_view_async_tmem_load() + # Without this barrier, we could have 1 warp writing to P in tmem while + # another warp is still reading S from tmem. + self.compute_sync_barrier.arrive_and_wait() cute.copy(thr_tmem_store, tSrP_r2t_f32, tStP_r2t[None, None, stage]) cute.arch.fence_view_async_tmem_store() @@ -1648,7 +1648,7 @@ def compute_loop( ##### dS.T = P.T * (dP.T - Psum) sdSt_mn = cute.composition(sdS, cute.make_layout((self.tile_m, self.tile_n))) tdKsdS = cute.composition( - sdSt_mn[(None, wg_idx), tidx], cute.make_layout(tSrS_t2r.shape) + sdSt_mn[(None, wg_idx), dp_idx], cute.make_layout(tSrS_t2r.shape) ) tSrS_t2r_bf16 = cute.make_tensor( cute.recast_ptr(tSrS_t2r.iterator, dtype=self.ds_dtype), tSrS_t2r.shape @@ -1701,7 +1701,7 @@ def compute_loop( if const_expr(not self.use_tma_store): consumer_state_dKV = self.epilogue_dKV( - tidx, + dp_idx, warp_idx, batch_idx, head_idx, @@ -1717,10 +1717,10 @@ def compute_loop( softmax_scale, ) else: - thr_copy_r2s_dKV = tiled_copy_r2s_dKV.get_slice(tidx) + thr_copy_r2s_dKV = tiled_copy_r2s_dKV.get_slice(dp_idx) #### STORE dV consumer_state_dKV = self.epilogue_dK_or_dV_tma( - tidx, + dp_idx, batch_idx, head_idx, n_block, @@ -1738,7 +1738,7 @@ def compute_loop( ) #### STORE dK consumer_state_dKV = self.epilogue_dK_or_dV_tma( - tidx, + dp_idx, batch_idx, head_idx, n_block, @@ -1777,7 +1777,7 @@ def dQacc_reduce( is_tma_warp = warp_idx == 0 # TMEM -> RMEM tmem_load_atom = cute.make_copy_atom( - tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32 + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(self.dQ_reduce_ncol)), Float32 ) thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tdQtdQ).get_slice(tidx) tdQtdQ_t2r = thr_tmem_load.partition_S(tdQtdQ) @@ -1794,19 +1794,14 @@ def dQacc_reduce( read_flag = const_expr(not self.deterministic) - # TODO: reduce_phase is currently hardcoded for 2 stages - reduce_phase = cutlass.Int32(0) - - dQacc_reduce_barrier = cutlass.pipeline.NamedBarrier( - barrier_id=int(NamedBarrierBwdSm100.dQaccReduce), - num_threads=num_reduce_threads, - ) - tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() dQ_consumer_state = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, 1 ) + dQ_tma_store_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.sdQaccum_stage + ) while work_tile.is_valid_tile: n_block, head_idx, batch_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) @@ -1835,7 +1830,7 @@ def dQacc_reduce( # semaphore acquire if const_expr(self.deterministic): barrier.wait_eq(mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, n_block) - dQacc_reduce_barrier.arrive_and_wait() + self.reduce_sync_barrier.arrive_and_wait() # We could delay the TMA store by 1 epi tile to better overlap the non-TMA ops delay_tma_store = False @@ -1845,33 +1840,34 @@ def tma_store_fn(src_idx, dst_idx): cute.arch.fence_proxy( cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta ) - dQacc_reduce_barrier.arrive_and_wait() + self.reduce_sync_barrier.arrive_and_wait() # Copy from shared memory to global memory if is_tma_warp: with cute.arch.elect_one(): copy_utils.cpasync_reduce_bulk_add_f32( sdQaccum[None, src_idx].iterator, gdQaccum[None, dst_idx, m_block].iterator, - self.tma_copy_bytes["dQ"], + self.tma_copy_bytes["dQ"] // 1, ) cute.arch.cp_async_bulk_commit_group() - cute.arch.cp_async_bulk_wait_group(1, read=read_flag) - dQacc_reduce_barrier.arrive_and_wait() + cute.arch.cp_async_bulk_wait_group(self.sdQaccum_stage - 1, read=read_flag) + self.reduce_sync_barrier.arrive_and_wait() - reduce_phase_prev, stage_prev = None, -1 + smem_idx_prev, stage_prev = None, -1 for stage in cutlass.range_constexpr(cute.size(tdQrdQ_t2r, mode=[1])): # 4 - tdQsdQ_r2s = tdQsdQ[None, None, reduce_phase] + smem_idx = dQ_tma_store_producer_state.index + tdQsdQ_r2s = tdQsdQ[None, None, smem_idx] tdQrdQ_r2s = cute.make_tensor( tdQrdQ_t2r[None, stage, None, None].iterator, tdQsdQ_r2s.shape ) if const_expr(delay_tma_store): if const_expr(stage > 0): - tma_store_fn(src_idx=reduce_phase_prev, dst_idx=stage_prev) - reduce_phase_prev, stage_prev = reduce_phase, stage + tma_store_fn(src_idx=smem_idx_prev, dst_idx=stage_prev) + smem_idx_prev, stage_prev = smem_idx, stage cute.copy(thr_copy_dQaccum_r2s, tdQrdQ_r2s, tdQsdQ_r2s) if const_expr(not delay_tma_store): - tma_store_fn(reduce_phase, stage) - reduce_phase ^= 1 + tma_store_fn(smem_idx, stage) + dQ_tma_store_producer_state.advance() # Directly add to gmem, much slower # tdQgdQ = thr_copy_dQaccum_r2s.partition_D(gdQaccum[None, stage, m_block]) # assert cute.size(tdQrdQ_r2s) == cute.size(tdQgdQ) @@ -1884,14 +1880,14 @@ def tma_store_fn(src_idx, dst_idx): # utils.elem_pointer(tdQgdQ, 4 * i), # ) if const_expr(delay_tma_store): - tma_store_fn(src_idx=reduce_phase_prev, dst_idx=stage_prev) + tma_store_fn(src_idx=smem_idx_prev, dst_idx=stage_prev) # semaphore release # NOTE: arrive_inc calls red_release which issues membar if const_expr(self.deterministic): if tidx == 0: cute.arch.cp_async_bulk_wait_group(0, read=read_flag) - dQacc_reduce_barrier.arrive_and_wait() + self.reduce_sync_barrier.arrive_and_wait() barrier.arrive_inc(mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, 1) if warp_idx == 0: @@ -2057,9 +2053,9 @@ def epilogue_dK_or_dV_tma( ) -> cutlass.pipeline.PipelineState: # assumes mma_tiler_pdo = mma_tiler_dsq = (tile_n, head_dim) # head_dim = head_dim_v, dk_dtype = dv_dtype - - wg_idx = (cute.arch.thread_idx()[0] % self.num_compute_threads) // 128 - num_wg = self.num_compute_threads // 128 + num_compute_threads = cute.arch.WARP_SIZE * len(self.compute_warp_ids) + wg_idx = (cute.arch.thread_idx()[0] % num_compute_threads) // 128 + num_wg = num_compute_threads // 128 leader_warp = (cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4) == 0 sdKV = sdKV[None, None, wg_idx] From e4d25a432ab5dec54cbe6aff40a0b7f1febfaf54 Mon Sep 17 00:00:00 2001 From: Reuben Stern <107093092+reubenconducts@users.noreply.github.com> Date: Thu, 23 Oct 2025 23:41:37 -0400 Subject: [PATCH 184/258] [CuTe DSL] Update "buffers" name to "aux_tensors"; fix flex bugs (#1961) * clean up and rebase for PR * add mask mod tests * add benchmarking files * refactor for better style * remove extraneous csrc * type hint buffers * refactor: order of non/overlap and modify blocksparse producer to agree with dense * change variable name back to buffers * remove unnecessary variable in first_half_block * restore erroneous packgqa deletion * add blocksparsity and mask_mod asserts to interface.py * fix rebase issues * Restore submodule and reset pointer to upstream/main * rename cutlass.const_expr to const_expr * support fully masked m blocks (i.e. skipped tiles) * remove outdated commented code * rename buffers -> aux_tensors, fix score_mod test in sm90 fwd * fix mask mod interface issues and tests * remove newline at end of file * format with ruff * format mask & sm100 with ruff * format more files with ruff * format barrier.py with ruff --- flash_attn/cute/barrier.py | 31 +- flash_attn/cute/benchmark_mask_mod.py | 36 +- flash_attn/cute/block_sparsity.py | 327 ++++++++---- flash_attn/cute/flash_fwd.py | 690 ++++++++++++++++++-------- flash_attn/cute/flash_fwd_sm100.py | 623 +++++++++++++++++------ flash_attn/cute/interface.py | 604 ++++++++++++++++------ flash_attn/cute/mask.py | 30 +- flash_attn/cute/mask_definitions.py | 121 +++-- flash_attn/cute/softmax.py | 14 +- tests/cute/test_flash_attn.py | 505 +++++++++++++++---- tests/cute/test_mask_mod.py | 340 +++++-------- tests/cute/test_score_mod.py | 68 ++- 12 files changed, 2362 insertions(+), 1027 deletions(-) diff --git a/flash_attn/cute/barrier.py b/flash_attn/cute/barrier.py index 744e3a56507..c999b180167 100644 --- a/flash_attn/cute/barrier.py +++ b/flash_attn/cute/barrier.py @@ -4,8 +4,9 @@ from cutlass.cutlass_dsl import T, dsl_user_op from cutlass._mlir.dialects import llvm + @dsl_user_op -def ld_acquire(lock_ptr : cute.Pointer, *, loc=None, ip=None) -> cutlass.Int32: +def ld_acquire(lock_ptr: cute.Pointer, *, loc=None, ip=None) -> cutlass.Int32: lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value() state = llvm.inline_asm( T.i32(), @@ -18,8 +19,11 @@ def ld_acquire(lock_ptr : cute.Pointer, *, loc=None, ip=None) -> cutlass.Int32: ) return cutlass.Int32(state) + @dsl_user_op -def red_relaxed(lock_ptr : cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None) -> None: +def red_relaxed( + lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None +) -> None: lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value() llvm.inline_asm( None, @@ -31,8 +35,11 @@ def red_relaxed(lock_ptr : cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=N asm_dialect=llvm.AsmDialect.AD_ATT, ) + @dsl_user_op -def red_release(lock_ptr : cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None) -> None: +def red_release( + lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None +) -> None: lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value() llvm.inline_asm( None, @@ -43,28 +50,22 @@ def red_release(lock_ptr : cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=N is_align_stack=False, asm_dialect=llvm.AsmDialect.AD_ATT, ) - + + @cute.jit -def wait_eq( - lock_ptr : cute.Pointer, - thread_idx : int | Int32, - flag_offset : int, - val : Int32 -) -> None: +def wait_eq(lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: Int32) -> None: flag_ptr = lock_ptr + flag_offset if thread_idx == 0: read_val = Int32(0) while read_val != val: read_val = ld_acquire(flag_ptr) + @cute.jit def arrive_inc( - lock_ptr : cute.Pointer, - thread_idx : int | Int32, - flag_offset : int, - val : cutlass.Constexpr[Int32] + lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: cutlass.Constexpr[Int32] ) -> None: flag_ptr = lock_ptr + flag_offset if thread_idx == 0: red_release(flag_ptr, val) - # red_relaxed(flag_ptr, val) \ No newline at end of file + # red_relaxed(flag_ptr, val) diff --git a/flash_attn/cute/benchmark_mask_mod.py b/flash_attn/cute/benchmark_mask_mod.py index 071b4e02a58..b1aadd89395 100644 --- a/flash_attn/cute/benchmark_mask_mod.py +++ b/flash_attn/cute/benchmark_mask_mod.py @@ -5,7 +5,6 @@ from dataclasses import dataclass import math -from pickle import FALSE from typing import Any, Dict, Optional, Tuple import cuda.bindings.driver as cuda @@ -51,7 +50,7 @@ class BenchmarkConfig: # Mask parameters use_mask_mod: bool = True mask_mod_name: str = "causal" - has_buffers: bool = mask_mod_name == "document" + has_aux_tensors: bool = mask_mod_name == "document" # Sliding window parameter (used when mask_mod_name == "sliding_window") window_size: int = 128 @@ -235,7 +234,6 @@ def _create_tensors(self) -> Dict[str, torch.Tensor]: dtype=torch.float32, device=device, ) - tensors = { "q": q.contiguous(), @@ -244,10 +242,10 @@ def _create_tensors(self) -> Dict[str, torch.Tensor]: "out": out.contiguous(), "lse": lse.contiguous(), } - + if config.use_learnable_sink: learnable_sink = torch.rand(config.nheads, dtype=torch.bfloat16, device=device) - + tensors["learnable_sink"] = learnable_sink.contiguous() # Compute block sparsity when using mask_mod @@ -256,14 +254,14 @@ def _create_tensors(self) -> Dict[str, torch.Tensor]: doc_id = random_doc_id_tensor( config.batch_size, config.nheads, config.seqlen_q, device=device ) - tensors["buffers"] = [doc_id.contiguous()] + tensors["aux_tensors"] = [doc_id.contiguous()] full_cnt, full_idx, mask_cnt, mask_idx = compute_block_sparsity( config=self.config, mask_mod_flex=self.mask_mod_flex, device=device, cu_seqlens_q=tensors.get("cu_seqlens_q"), cu_seqlens_k=tensors.get("cu_seqlens_k"), - buffers=tensors.get("buffers"), + aux_tensors=tensors.get("aux_tensors"), ) if all(t is not None for t in [full_cnt, full_idx, mask_cnt, mask_idx]): @@ -329,7 +327,7 @@ def _compile_kernel(self, tensors: Dict[str, torch.Tensor]) -> Tuple[Any, tuple] mma_pv_is_rs=config.mma_pv_is_rs, mask_mod=self.mask_mod_cute, Q_in_regs=False, - has_buffers=config.has_buffers, + has_aux_tensors=config.has_aux_tensors, ) softmax_scale = 1.0 / math.sqrt(config.headdim) @@ -405,14 +403,14 @@ def _compile_kernel(self, tensors: Dict[str, torch.Tensor]) -> Tuple[Any, tuple] else None ) - if "buffers" in tensors: - buffers_cute = [] - for i in range(len(tensors["buffers"])): - buf = from_dlpack(tensors["buffers"][i].detach(), assumed_align=4) - buffers_cute.append(buf.mark_layout_dynamic(leading_dim=2)) + if "aux_tensors" in tensors: + aux_tensors_cute = [] + for i in range(len(tensors["aux_tensors"])): + buf = from_dlpack(tensors["aux_tensors"][i].detach(), assumed_align=4) + aux_tensors_cute.append(buf.mark_layout_dynamic(leading_dim=2)) else: - buffers_cute = None + aux_tensors_cute = None # Window parameters for is_local window_left_cute = ( @@ -443,7 +441,7 @@ def _compile_kernel(self, tensors: Dict[str, torch.Tensor]) -> Tuple[Any, tuple] full_block_idx_cute, mask_block_cnt_cute, mask_block_idx_cute, - buffers_cute, + aux_tensors_cute, # None, ) @@ -467,7 +465,7 @@ def _compile_kernel(self, tensors: Dict[str, torch.Tensor]) -> Tuple[Any, tuple] full_block_idx_cute, mask_block_cnt_cute, mask_block_idx_cute, - buffers_cute, + aux_tensors_cute, # None, ) @@ -496,7 +494,7 @@ def _calculate_flops(self, tensors: Dict[str, torch.Tensor]) -> float: num_blocks = (config.seqlen_k + block_size - 1) // block_size sparsity_ratio = 1.0 / num_blocks if num_blocks > 1 else 1.0 elif config.mask_mod_name == "document": - vals = tensors["buffers"][0] + vals = tensors["aux_tensors"][0] val_mask = torch.ones_like(vals, dtype=torch.bool) val_mask[..., 1:] = vals[..., 1:] != vals[..., :-1] total = torch.where(val_mask, vals.square(), 0).sum() @@ -573,7 +571,7 @@ def benchmark(self) -> Dict[str, Any]: torch.cuda.synchronize() times.append(start.elapsed_time(end)) - + times_tensor = torch.tensor(times) mean_time = times_tensor.mean().item() std_time = times_tensor.std().item() if len(times) > 1 else 0.0 @@ -683,7 +681,7 @@ def _print_results(self, results: Dict[str, Any]): # seqlen_k=192, use_varlen=False, use_mask_mod=True, - mask_mod_name="identity", + mask_mod_name="causal", window_size=128, # Configurable window size for mask_mod use_learnable_sink=False, causal=False, diff --git a/flash_attn/cute/block_sparsity.py b/flash_attn/cute/block_sparsity.py index ce05cae1438..be685dea5d4 100644 --- a/flash_attn/cute/block_sparsity.py +++ b/flash_attn/cute/block_sparsity.py @@ -14,14 +14,17 @@ # placeholder Config = type("Config", (), {}) + def compute_block_sparsity( config: Config, mask_mod_flex: Optional[Callable], device: str, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_k: Optional[torch.Tensor] = None, - buffers: Optional[List[torch.Tensor]] = None, -) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + aux_tensors: Optional[List[torch.Tensor]] = None, +) -> Tuple[ + Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor] +]: """ Computes block sparsity tensors from a given masking function. @@ -35,7 +38,7 @@ def compute_block_sparsity( device: The device to create tensors on (e.g., 'cuda'). cu_seqlens_q: Cumulative sequence lengths for Q (for varlen). cu_seqlens_k: Cumulative sequence lengths for K (for varlen). - buffers: A list of auxiliary tensors, e.g., for document masking. + aux_tensors: A list of auxiliary tensors, e.g., for document masking. Returns: A tuple of four tensors: @@ -53,25 +56,35 @@ def compute_block_sparsity( return _compute_varlen_sparsity(config, mask_mod_flex, device, cu_seqlens_q, cu_seqlens_k) else: # Handle fixed-length sequences - return _compute_sparsity(config, device, buffers) + return _compute_sparsity(config, device, aux_tensors) + ## --------------------------------------------------------------------------- ## Fixed-Length Sequence Kernels ## --------------------------------------------------------------------------- + def _compute_sparsity( - config: Config, device: str, buffers: Optional[List[torch.Tensor]] + config: Config, device: str, aux_tensors: Optional[List[torch.Tensor]] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Computes block sparsity for fixed-length sequences.""" n_blocks_q = (config.seqlen_q + config.tile_m - 1) // config.tile_m n_blocks_k = (config.seqlen_k + config.tile_n - 1) // config.tile_n - + # Pre-allocate output tensors - full_block_cnt = torch.zeros((config.batch_size, config.nheads, n_blocks_q), device=device, dtype=torch.int32) - mask_block_cnt = torch.zeros((config.batch_size, config.nheads, n_blocks_q), device=device, dtype=torch.int32) - full_block_idx = torch.zeros((config.batch_size, config.nheads, n_blocks_q, n_blocks_k), device=device, dtype=torch.int32) - mask_block_idx = torch.zeros((config.batch_size, config.nheads, n_blocks_q, n_blocks_k), device=device, dtype=torch.int32) - + full_block_cnt = torch.zeros( + (config.batch_size, config.nheads, n_blocks_q), device=device, dtype=torch.int32 + ) + mask_block_cnt = torch.zeros( + (config.batch_size, config.nheads, n_blocks_q), device=device, dtype=torch.int32 + ) + full_block_idx = torch.zeros( + (config.batch_size, config.nheads, n_blocks_q, n_blocks_k), device=device, dtype=torch.int32 + ) + mask_block_idx = torch.zeros( + (config.batch_size, config.nheads, n_blocks_q, n_blocks_k), device=device, dtype=torch.int32 + ) + # --- Identity Mask --- # All blocks are fully computed. if config.mask_mod_name == "identity": @@ -79,7 +92,7 @@ def _compute_sparsity( for q_block_idx in range(n_blocks_q): full_block_cnt[:, :, q_block_idx] = n_blocks_k full_block_idx[:, :, q_block_idx, :n_blocks_k] = k_blocks - + # --- Identity Partial Mask --- # All blocks are partially computed (masked). elif config.mask_mod_name == "identity_partial": @@ -104,26 +117,34 @@ def _compute_sparsity( k_block_indices = torch.arange(n_blocks_k, device=device) q_starts = q_block_indices * config.tile_m - q_ends = torch.minimum((q_block_indices + 1) * config.tile_m, torch.tensor(config.seqlen_q, device=device)) + q_ends = torch.minimum( + (q_block_indices + 1) * config.tile_m, torch.tensor(config.seqlen_q, device=device) + ) k_starts = k_block_indices * config.tile_n - k_ends = torch.minimum((k_block_indices + 1) * config.tile_n, torch.tensor(config.seqlen_k, device=device)) + k_ends = torch.minimum( + (k_block_indices + 1) * config.tile_n, torch.tensor(config.seqlen_k, device=device) + ) # Expand dims for broadcasting: (n_blocks_q, 1) and (1, n_blocks_k) q_starts, q_ends = q_starts.unsqueeze(1), q_ends.unsqueeze(1) k_starts, k_ends = k_starts.unsqueeze(0), k_ends.unsqueeze(0) - + offset = config.seqlen_k - config.seqlen_q if config.mask_mod_name == "causal": is_full = (k_ends - 1) <= (q_starts + offset) # min(k_pos) <= max(q_pos) AND not is_full. is_partial = (k_starts <= (q_ends - 1 + offset)) & ~is_full - - else: # sliding_window - window_size = getattr(config, 'window_size', 1024) - is_full = (k_ends - 1 <= q_starts + offset) & (k_starts >= q_ends - 1 + offset - (window_size - 1)) + + else: # sliding_window + window_size = getattr(config, "window_size", 1024) + is_full = (k_ends - 1 <= q_starts + offset) & ( + k_starts >= q_ends - 1 + offset - (window_size - 1) + ) # A block is EMPTY if no (q, k) pairs satisfy the constraint. - is_empty = (k_starts > q_ends - 1 + offset) | (k_ends - 1 < q_starts + offset - (window_size - 1)) + is_empty = (k_starts > q_ends - 1 + offset) | ( + k_ends - 1 < q_starts + offset - (window_size - 1) + ) # A block is PARTIAL if it's not empty and not full. is_partial = ~is_empty & ~is_full @@ -132,22 +153,24 @@ def _compute_sparsity( full_indices = k_block_indices[is_full[q_block_idx]] if len(full_indices) > 0: full_block_cnt[:, :, q_block_idx] = len(full_indices) - full_block_idx[:, :, q_block_idx, :len(full_indices)] = full_indices + full_block_idx[:, :, q_block_idx, : len(full_indices)] = full_indices partial_indices = k_block_indices[is_partial[q_block_idx]] if len(partial_indices) > 0: mask_block_cnt[:, :, q_block_idx] = len(partial_indices) - mask_block_idx[:, :, q_block_idx, :len(partial_indices)] = partial_indices - + mask_block_idx[:, :, q_block_idx, : len(partial_indices)] = partial_indices + elif config.mask_mod_name == "document": raise NotImplementedError("Block sparsity for document masking not yet implemented") return full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx + ## --------------------------------------------------------------------------- ## Variable-Length Sequence Kernels ## --------------------------------------------------------------------------- + def _compute_varlen_sparsity( config: Config, mask_mod_flex: Callable, @@ -159,7 +182,7 @@ def _compute_varlen_sparsity( assert cu_seqlens_k is not None, "cu_seqlens_k is required for varlen attention" assert cu_seqlens_q.shape[0] == config.batch_size + 1 assert cu_seqlens_k.shape[0] == config.batch_size + 1 - + # In varlen, each sequence can have a different number of Q blocks. # We pad up to the maximum number of Q blocks in the batch. max_m_blocks = 0 @@ -173,62 +196,98 @@ def _compute_varlen_sparsity( max_n_blocks = (total_k_len + config.tile_n - 1) // config.tile_n # Pre-allocate padded output tensors - full_block_cnt = torch.zeros((config.batch_size, config.nheads, max_m_blocks), device=device, dtype=torch.int32) - mask_block_cnt = torch.zeros((config.batch_size, config.nheads, max_m_blocks), device=device, dtype=torch.int32) - full_block_idx = torch.zeros((config.batch_size, config.nheads, max_m_blocks, max_n_blocks), device=device, dtype=torch.int32) - mask_block_idx = torch.zeros((config.batch_size, config.nheads, max_m_blocks, max_n_blocks), device=device, dtype=torch.int32) + full_block_cnt = torch.zeros( + (config.batch_size, config.nheads, max_m_blocks), device=device, dtype=torch.int32 + ) + mask_block_cnt = torch.zeros( + (config.batch_size, config.nheads, max_m_blocks), device=device, dtype=torch.int32 + ) + full_block_idx = torch.zeros( + (config.batch_size, config.nheads, max_m_blocks, max_n_blocks), + device=device, + dtype=torch.int32, + ) + mask_block_idx = torch.zeros( + (config.batch_size, config.nheads, max_m_blocks, max_n_blocks), + device=device, + dtype=torch.int32, + ) # Process each sequence in the batch individually for seq_idx in range(config.batch_size): seq_start_q = cu_seqlens_q[seq_idx].item() seq_end_q = cu_seqlens_q[seq_idx + 1].item() seq_len_q = seq_end_q - seq_start_q - + seq_start_k = cu_seqlens_k[seq_idx].item() seq_end_k = cu_seqlens_k[seq_idx + 1].item() seq_len_k = seq_end_k - seq_start_k - + n_blocks_q = (seq_len_q + config.tile_m - 1) // config.tile_m n_blocks_k = (seq_len_k + config.tile_n - 1) // config.tile_n # Global block indices are relative to the start of the entire batch tensor first_m_block_global = seq_start_q // config.tile_m first_n_block_global = seq_start_k // config.tile_n - + common_args = { - "full_block_cnt": full_block_cnt, "full_block_idx": full_block_idx, - "mask_block_cnt": mask_block_cnt, "mask_block_idx": mask_block_idx, - "seq_idx": seq_idx, "n_blocks_q": n_blocks_q, "n_blocks_k": n_blocks_k, - "seq_start_q": seq_start_q, "seq_end_q": seq_end_q, - "seq_start_k": seq_start_k, "seq_end_k": seq_end_k, + "full_block_cnt": full_block_cnt, + "full_block_idx": full_block_idx, + "mask_block_cnt": mask_block_cnt, + "mask_block_idx": mask_block_idx, + "seq_idx": seq_idx, + "n_blocks_q": n_blocks_q, + "n_blocks_k": n_blocks_k, + "seq_start_q": seq_start_q, + "seq_end_q": seq_end_q, + "seq_start_k": seq_start_k, + "seq_end_k": seq_end_k, "first_n_block_global": first_n_block_global, - "tile_m": config.tile_m, "tile_n": config.tile_n, "device": device + "tile_m": config.tile_m, + "tile_n": config.tile_n, + "device": device, } if config.mask_mod_name == "causal": _compute_causal_varlen_blocks(**common_args) elif config.mask_mod_name == "sliding_window": - window_size = getattr(config, 'window_size', 1024) + window_size = getattr(config, "window_size", 1024) _compute_sliding_window_varlen_blocks(**common_args, window_size=window_size) elif config.mask_mod_name == "identity": _compute_identity_varlen_blocks( - full_block_cnt, full_block_idx, seq_idx, - n_blocks_q, n_blocks_k, first_n_block_global, device + full_block_cnt, + full_block_idx, + seq_idx, + n_blocks_q, + n_blocks_k, + first_n_block_global, + device, ) else: # Generic case relies on sampling the user-provided mask function _compute_generic_varlen_blocks( - **common_args, mask_mod_flex=mask_mod_flex, - seq_len_q=seq_len_q, seq_len_k=seq_len_k, - num_heads=config.nheads, nheads_kv=config.nheads_kv, + **common_args, + mask_mod_flex=mask_mod_flex, + seq_len_q=seq_len_q, + seq_len_k=seq_len_k, + num_heads=config.nheads, + nheads_kv=config.nheads_kv, ) - + return full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx + def _classify_varlen_block( - m_local: int, n_local: int, seq_start_q: int, seq_end_q: int, - seq_start_k: int, seq_end_k: int, tile_m: int, tile_n: int, - is_full_fn: Callable, is_partial_fn: Callable + m_local: int, + n_local: int, + seq_start_q: int, + seq_end_q: int, + seq_start_k: int, + seq_end_k: int, + tile_m: int, + tile_n: int, + is_full_fn: Callable, + is_partial_fn: Callable, ) -> Tuple[bool, bool]: """Helper to classify a varlen block as full, partial, or empty.""" m_start_global = seq_start_q + m_local * tile_m @@ -241,20 +300,35 @@ def _classify_varlen_block( m_end_local = m_end_global - seq_start_q n_start_local = n_start_global - seq_start_k n_end_local = n_end_global - seq_start_k - + is_full = is_full_fn(m_start_local, m_end_local, n_start_local, n_end_local) - is_partial = is_partial_fn(m_start_local, m_end_local, n_start_local, n_end_local) and not is_full - + is_partial = ( + is_partial_fn(m_start_local, m_end_local, n_start_local, n_end_local) and not is_full + ) + # Any block that touches the sequence boundary is partial because it requires masking. at_boundary = (m_end_global > seq_end_q) or (n_end_global > seq_end_k) - + return is_full and not at_boundary, is_partial or (is_full and at_boundary) + def _compute_causal_varlen_blocks( - full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx, - seq_idx, n_blocks_q, n_blocks_k, - seq_start_q, seq_end_q, seq_start_k, seq_end_k, - first_n_block_global, tile_m, tile_n, device, **kwargs + full_block_cnt, + full_block_idx, + mask_block_cnt, + mask_block_idx, + seq_idx, + n_blocks_q, + n_blocks_k, + seq_start_q, + seq_end_q, + seq_start_k, + seq_end_k, + first_n_block_global, + tile_m, + tile_n, + device, + **kwargs, ): """Computes causal block sparsity for a single varlen sequence.""" is_full_fn = lambda m_start, m_end, n_start, n_end: (m_start >= n_end - 1) @@ -264,8 +338,16 @@ def _compute_causal_varlen_blocks( full_blocks, partial_blocks = [], [] for n_local in range(n_blocks_k): is_full, is_partial = _classify_varlen_block( - m_local, n_local, seq_start_q, seq_end_q, seq_start_k, seq_end_k, - tile_m, tile_n, is_full_fn, is_partial_fn + m_local, + n_local, + seq_start_q, + seq_end_q, + seq_start_k, + seq_end_k, + tile_m, + tile_n, + is_full_fn, + is_partial_fn, ) n_block_global = first_n_block_global + n_local if is_full: @@ -275,98 +357,157 @@ def _compute_causal_varlen_blocks( if full_blocks: full_block_cnt[seq_idx, :, m_local] = len(full_blocks) - full_block_idx[seq_idx, :, m_local, :len(full_blocks)] = torch.tensor(full_blocks, device=device) + full_block_idx[seq_idx, :, m_local, : len(full_blocks)] = torch.tensor( + full_blocks, device=device + ) if partial_blocks: mask_block_cnt[seq_idx, :, m_local] = len(partial_blocks) - mask_block_idx[seq_idx, :, m_local, :len(partial_blocks)] = torch.tensor(partial_blocks, device=device) + mask_block_idx[seq_idx, :, m_local, : len(partial_blocks)] = torch.tensor( + partial_blocks, device=device + ) + def _compute_sliding_window_varlen_blocks( - full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx, - seq_idx, n_blocks_q, n_blocks_k, - seq_start_q, seq_end_q, seq_start_k, seq_end_k, - first_n_block_global, tile_m, tile_n, window_size, device, **kwargs + full_block_cnt, + full_block_idx, + mask_block_cnt, + mask_block_idx, + seq_idx, + n_blocks_q, + n_blocks_k, + seq_start_q, + seq_end_q, + seq_start_k, + seq_end_k, + first_n_block_global, + tile_m, + tile_n, + window_size, + device, + **kwargs, ): """Computes sliding window block sparsity for a single varlen sequence.""" - is_full_fn = lambda m_start, m_end, n_start, n_end: \ - (n_end - 1 <= m_start) and (n_start >= m_start - window_size + 1) - is_partial_fn = lambda m_start, m_end, n_start, n_end: \ - not ((n_start > m_end - 1) or (n_end - 1 < m_start - window_size + 1)) + is_full_fn = lambda m_start, m_end, n_start, n_end: (n_end - 1 <= m_start) and ( + n_start >= m_start - window_size + 1 + ) + is_partial_fn = lambda m_start, m_end, n_start, n_end: not ( + (n_start > m_end - 1) or (n_end - 1 < m_start - window_size + 1) + ) for m_local in range(n_blocks_q): full_blocks, partial_blocks = [], [] for n_local in range(n_blocks_k): is_full, is_partial = _classify_varlen_block( - m_local, n_local, seq_start_q, seq_end_q, seq_start_k, seq_end_k, - tile_m, tile_n, is_full_fn, is_partial_fn + m_local, + n_local, + seq_start_q, + seq_end_q, + seq_start_k, + seq_end_k, + tile_m, + tile_n, + is_full_fn, + is_partial_fn, ) n_block_global = first_n_block_global + n_local if is_full: full_blocks.append(n_block_global) elif is_partial: partial_blocks.append(n_block_global) - + if full_blocks: full_block_cnt[seq_idx, :, m_local] = len(full_blocks) - full_block_idx[seq_idx, :, m_local, :len(full_blocks)] = torch.tensor(full_blocks, device=device) + full_block_idx[seq_idx, :, m_local, : len(full_blocks)] = torch.tensor( + full_blocks, device=device + ) if partial_blocks: mask_block_cnt[seq_idx, :, m_local] = len(partial_blocks) - mask_block_idx[seq_idx, :, m_local, :len(partial_blocks)] = torch.tensor(partial_blocks, device=device) + mask_block_idx[seq_idx, :, m_local, : len(partial_blocks)] = torch.tensor( + partial_blocks, device=device + ) + def _compute_identity_varlen_blocks( - full_block_cnt, full_block_idx, seq_idx, n_blocks_q, - n_blocks_k, first_n_block_global, device, **kwargs + full_block_cnt, + full_block_idx, + seq_idx, + n_blocks_q, + n_blocks_k, + first_n_block_global, + device, + **kwargs, ): """Computes identity (all-attend) block sparsity for a single varlen sequence.""" n_blocks_global = torch.arange( - first_n_block_global, first_n_block_global + n_blocks_k, - device=device, dtype=torch.int32 + first_n_block_global, first_n_block_global + n_blocks_k, device=device, dtype=torch.int32 ) for m_local in range(n_blocks_q): full_block_cnt[seq_idx, :, m_local] = n_blocks_k full_block_idx[seq_idx, :, m_local, :n_blocks_k] = n_blocks_global + def _compute_generic_varlen_blocks( - full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx, - mask_mod_flex, seq_idx, num_heads, n_blocks_q, n_blocks_k, - seq_len_q, seq_len_k, first_n_block_global, - tile_m, tile_n, nheads_kv, device, **kwargs + full_block_cnt, + full_block_idx, + mask_block_cnt, + mask_block_idx, + mask_mod_flex, + seq_idx, + num_heads, + n_blocks_q, + n_blocks_k, + seq_len_q, + seq_len_k, + first_n_block_global, + tile_m, + tile_n, + nheads_kv, + device, + **kwargs, ): """Generic sampling-based block classification for a varlen sequence.""" qhead_per_kvhead = num_heads // nheads_kv - + for h_q in range(num_heads): h_kv = h_q // qhead_per_kvhead for m_local in range(n_blocks_q): m_start_local = m_local * tile_m m_end_local = min((m_local + 1) * tile_m, seq_len_q) - + full_blocks, partial_blocks = [], [] for n_local in range(n_blocks_k): n_start_local = n_local * tile_n n_end_local = min((n_local + 1) * tile_n, seq_len_k) - + # Sample points within the block (corners and center) to classify it. # Coordinates are sequence-local, as required by mask_mod_flex. sample_positions = [ - (m_start_local, n_start_local), (m_start_local, n_end_local - 1), - (m_end_local - 1, n_start_local), (m_end_local - 1, n_end_local - 1), + (m_start_local, n_start_local), + (m_start_local, n_end_local - 1), + (m_end_local - 1, n_start_local), + (m_end_local - 1, n_end_local - 1), ((m_start_local + m_end_local) // 2, (n_start_local + n_end_local) // 2), ] - + unmasked_count = sum( - 1 for q_pos, k_pos in sample_positions + 1 + for q_pos, k_pos in sample_positions if mask_mod_flex(seq_idx, h_q, q_pos, k_pos, seq_len_q, seq_len_k) ) - + n_block_global = first_n_block_global + n_local - if unmasked_count == len(sample_positions): # All samples unmasked -> full + if unmasked_count == len(sample_positions): # All samples unmasked -> full full_blocks.append(n_block_global) - elif unmasked_count > 0: # Some unmasked -> partial + elif unmasked_count > 0: # Some unmasked -> partial partial_blocks.append(n_block_global) - + if full_blocks: full_block_cnt[seq_idx, h_q, m_local] = len(full_blocks) - full_block_idx[seq_idx, h_q, m_local, :len(full_blocks)] = torch.tensor(full_blocks, device=device) + full_block_idx[seq_idx, h_q, m_local, : len(full_blocks)] = torch.tensor( + full_blocks, device=device + ) if partial_blocks: mask_block_cnt[seq_idx, h_q, m_local] = len(partial_blocks) - mask_block_idx[seq_idx, h_q, m_local, :len(partial_blocks)] = torch.tensor(partial_blocks, device=device) \ No newline at end of file + mask_block_idx[seq_idx, h_q, m_local, : len(partial_blocks)] = torch.tensor( + partial_blocks, device=device + ) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 4922a1534c9..b49a693dfcd 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -32,12 +32,17 @@ from flash_attn.cute import pipeline from flash_attn.cute.pack_gqa import PackGQA from flash_attn.cute.named_barrier import NamedBarrierFwd -from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, SingleTileLPTScheduler, SingleTileVarlenScheduler, ParamsBase +from flash_attn.cute.tile_scheduler import ( + TileSchedulerArguments, + SingleTileScheduler, + SingleTileLPTScheduler, + SingleTileVarlenScheduler, + ParamsBase, +) from flash_attn.cute.fast_math import FastDivmod class FlashAttentionForwardBase: - arch: int = 80 def __init__( @@ -56,7 +61,7 @@ def __init__( Q_in_regs: bool = False, score_mod: Optional[cutlass.Constexpr] = None, mask_mod: Optional[cutlass.Constexpr] = None, - has_buffers: bool = False, + has_aux_tensors: bool = False, ): """Initializes the configuration for a flash attention kernel. @@ -73,9 +78,9 @@ def __init__( :type num_threads: int :param is_causal: is causal :param score_mod: A callable that takes the attention scores and applies a modification. - Callable signature: ``score_mod(scores, batch_idx, head_idx, q_idx, kv_idx, buffers) -> Any`` + Callable signature: ``score_mod(scores, batch_idx, head_idx, q_idx, kv_idx, aux_tensors) -> Any`` :param mask_mod: A callable that takes the attention scores and returns a boolean representing whether that score should be masked. - Callable signature: ``mask_mod(batch_idx, head_idx, q_idx, kv_idx, buffers) -> Boolean`` + Callable signature: ``mask_mod(batch_idx, head_idx, q_idx, kv_idx, aux_tensors) -> Boolean`` """ self.dtype = dtype # padding head_dim to a multiple of 16 as k_block_size @@ -99,15 +104,22 @@ def __init__( self.score_mod = score_mod self.mask_mod = mask_mod self.qk_acc_dtype = Float32 - if const_expr(has_buffers): + if const_expr(has_aux_tensors): self.vec_size: cutlass.Constexpr = 1 else: self.vec_size: cutlass.Constexpr = 2 @staticmethod def can_implement( - dtype, head_dim, head_dim_v, tile_m, tile_n, num_stages, num_threads, is_causal, - Q_in_regs=False + dtype, + head_dim, + head_dim_v, + tile_m, + tile_n, + num_stages, + num_threads, + is_causal, + Q_in_regs=False, ) -> bool: """Check if the kernel can be implemented with the given parameters. @@ -142,7 +154,9 @@ def can_implement( smem_usage_Q = tile_m * head_dim * 2 smem_usage_K = tile_n * head_dim * num_stages * 2 smem_usage_V = tile_n * head_dim_v * num_stages * 2 - smem_usage_QV = (smem_usage_Q + smem_usage_V) if not Q_in_regs else max(smem_usage_Q, smem_usage_V) + smem_usage_QV = ( + (smem_usage_Q + smem_usage_V) if not Q_in_regs else max(smem_usage_Q, smem_usage_V) + ) smem_usage = smem_usage_QV + smem_usage_K # TODO: sm86 and sm89 smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_80") @@ -186,22 +200,34 @@ def _setup_attributes(self): # /////////////////////////////////////////////////////////////////////////////// # Shared memory layout: Q/K/V # /////////////////////////////////////////////////////////////////////////////// - sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom = self._get_smem_layout_atom() + sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom = ( + self._get_smem_layout_atom() + ) self.sQ_layout = cute.tile_to_shape( - sQ_layout_atom, (self.tile_m, self.tile_hdim), (0, 1), + sQ_layout_atom, + (self.tile_m, self.tile_hdim), + (0, 1), ) self.sK_layout = cute.tile_to_shape( - sK_layout_atom, (self.tile_n, self.tile_hdim, self.num_stages), (0, 1, 2), + sK_layout_atom, + (self.tile_n, self.tile_hdim, self.num_stages), + (0, 1, 2), ) self.sV_layout = cute.tile_to_shape( - sV_layout_atom, (self.tile_n, self.tile_hdimv, self.num_stages), (0, 1, 2), + sV_layout_atom, + (self.tile_n, self.tile_hdimv, self.num_stages), + (0, 1, 2), ) self.sO_layout = cute.tile_to_shape( - sO_layout_atom, (self.tile_m, self.tile_hdimv), (0, 1), + sO_layout_atom, + (self.tile_m, self.tile_hdimv), + (0, 1), ) if const_expr(sP_layout_atom is not None): self.sP_layout = cute.tile_to_shape( - sP_layout_atom, (self.tile_m, self.tile_n), (0, 1), + sP_layout_atom, + (self.tile_m, self.tile_n), + (0, 1), ) else: self.sP_layout = None @@ -220,28 +246,38 @@ def _setup_attributes(self): ) # atom_universal_copy: universal copy atom for O store atom_universal_copy = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=universal_copy_bits, + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=universal_copy_bits, ) # tQ_layout and tK_layout: thread layout for QK load tQK_shape_dim_1 = sQ_layout_atom.outer.shape[1] // async_copy_elems - assert self.num_Q_load_threads % tQK_shape_dim_1 == 0, "num_threads must be divisible by tQK_shape_dim_1" - assert self.num_producer_threads % tQK_shape_dim_1 == 0, "num_threads must be divisible by tQK_shape_dim_1" + assert self.num_Q_load_threads % tQK_shape_dim_1 == 0, ( + "num_threads must be divisible by tQK_shape_dim_1" + ) + assert self.num_producer_threads % tQK_shape_dim_1 == 0, ( + "num_threads must be divisible by tQK_shape_dim_1" + ) tQ_layout = cute.make_ordered_layout( - (self.num_Q_load_threads // tQK_shape_dim_1, tQK_shape_dim_1), order=(1, 0), + (self.num_Q_load_threads // tQK_shape_dim_1, tQK_shape_dim_1), + order=(1, 0), ) tK_layout = cute.make_ordered_layout( - (self.num_producer_threads // tQK_shape_dim_1, tQK_shape_dim_1), order=(1, 0), + (self.num_producer_threads // tQK_shape_dim_1, tQK_shape_dim_1), + order=(1, 0), ) # So that we don't have to check if we overshoot kBlockM when we load Q assert self.tile_m % tQ_layout.shape[0] == 0 tV_shape_dim_1 = sV_layout_atom.outer.shape[1] // async_copy_elems tV_layout = cute.make_ordered_layout( - (self.num_producer_threads // tV_shape_dim_1, tV_shape_dim_1), order=(1, 0), + (self.num_producer_threads // tV_shape_dim_1, tV_shape_dim_1), + order=(1, 0), ) # TODO: need a different layout for O if O dtype is not the same as V dtype # tO_layout: thread layout for O store tO_layout = cute.make_ordered_layout( - (self.num_epilogue_threads // tV_shape_dim_1, tV_shape_dim_1), order=(1, 0), + (self.num_epilogue_threads // tV_shape_dim_1, tV_shape_dim_1), + order=(1, 0), ) # So that we don't have to check if we overshoot kBlockM when we store O assert self.tile_m % tO_layout.shape[0] == 0 @@ -304,7 +340,9 @@ def epilogue( rO = cute.make_fragment_like(acc_O, self.dtype) rO.store(acc_O.load().to(self.dtype)) # Make sure all threads have finished reading V - cute.arch.barrier(barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads) + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads + ) smem_copy_atom_O = utils.get_smem_store_atom(self.arch, self.dtype) smem_thr_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma).get_slice(tidx) taccOrO = smem_thr_copy_O.retile(rO) @@ -313,7 +351,9 @@ def epilogue( cute.copy(smem_copy_atom_O, taccOrO, taccOsO) cO = cute.make_identity_tensor((self.tile_m, self.tile_hdimv)) - pack_gqa = PackGQA(self.tile_m, self.tile_hdimv, self.check_hdim_v_oob, self.qhead_per_kvhead) + pack_gqa = PackGQA( + self.tile_m, self.tile_hdimv, self.check_hdim_v_oob, self.qhead_per_kvhead + ) # Write LSE from rmem -> gmem if const_expr(mLSE is not None): @@ -336,7 +376,10 @@ def epilogue( # Only the thread corresponding to column 0 writes out the lse to gmem if taccOcO[0][1] == 0: for m in cutlass.range_constexpr(cute.size(taccOgLSE.shape[1])): - if t0accOcO[m, 0][0] < seqlen.seqlen_q - m_block * self.tile_m - taccOcO[0][0]: + if ( + t0accOcO[m, 0][0] + < seqlen.seqlen_q - m_block * self.tile_m - taccOcO[0][0] + ): taccOgLSE[m, 0] = lse[m] else: pack_gqa.store_LSE(mLSE_cur, lse, tiled_mma, tidx, m_block, seqlen.seqlen_q) @@ -353,19 +396,28 @@ def epilogue( if const_expr(self.use_tma_O): # ensure smem writes are visible to TMA cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) - cute.arch.barrier_arrive(barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE) + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierFwd.Epilogue), + number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE, + ) gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0)) store_O, _, _ = copy_utils.tma_get_copy_fn( tma_atom_O, 0, cute.make_layout(1), sO, gO, single_stage=True ) warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) if warp_idx == 4: - cute.arch.barrier(barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE) + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), + number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE, + ) store_O() cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=True) else: - cute.arch.barrier(barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads) + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), + number_of_threads=self.num_epilogue_threads, + ) gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) tOsO = gmem_thr_copy_O.partition_S(sO) tOrO = cute.make_fragment_like(tOsO, self.dtype) @@ -379,12 +431,17 @@ def epilogue( tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) # copy acc O from rmem to gmem for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): - if t0OcO[0, rest_m, 0][0] < seqlen.seqlen_q - m_block * self.tile_m - tOcO[0][0]: + if ( + t0OcO[0, rest_m, 0][0] + < seqlen.seqlen_q - m_block * self.tile_m - tOcO[0][0] + ): cute.copy( gmem_tiled_copy_O, tOrO[None, rest_m, None], tOgO[None, rest_m, None], - pred=tOpO[None, rest_m, None] if const_expr(self.check_hdim_v_oob) else None, + pred=tOpO[None, rest_m, None] + if const_expr(self.check_hdim_v_oob) + else None, ) else: pack_gqa.store_O(mO_cur, tOrO, gmem_tiled_copy_O, tidx, m_block, seqlen.seqlen_q) @@ -452,7 +509,9 @@ def load_K( cute.copy( gmem_tiled_copy, tKgK[None, n, None, block], - tKsK[None, n, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0], + tKsK[ + None, n, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0 + ], pred=tKpK[None, n, None] if const_expr(self.check_hdim_oob) else None, ) # We don't need to clear the sK smem tiles since we'll mask out the scores anyway. @@ -483,7 +542,11 @@ def load_V( if const_expr(need_predicates or not is_even_n_smem_v): for n in cutlass.range_constexpr(cute.size(tVsV.shape[1])): # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked - if is_even_n_smem_v or n < cute.size(tVsV.shape[1]) - 1 or tVcV[0, n, 0][0] < self.tile_n: + if ( + is_even_n_smem_v + or n < cute.size(tVsV.shape[1]) - 1 + or tVcV[0, n, 0][0] < self.tile_n + ): predicate = tVpV[None, n, None] if const_expr(self.check_hdim_v_oob) else None if const_expr(need_predicates): seqlen_limit = seqlen - block * self.tile_n - tVcV[0][0] @@ -491,11 +554,15 @@ def load_V( predicate = cute.make_fragment_like(tVpV[None, 0, None]) for k in cutlass.range_constexpr(cute.size(predicate.shape[1])): for i in cutlass.range_constexpr(cute.size(predicate.shape[0])): - predicate[i, k] = (tVpV[i, n, k] if const_expr(self.check_hdim_v_oob) else True) and predicate_n + predicate[i, k] = ( + tVpV[i, n, k] if const_expr(self.check_hdim_v_oob) else True + ) and predicate_n cute.copy( gmem_tiled_copy, tVgV[None, n, None, block], - tVsV[None, n, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0], + tVsV[ + None, n, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0 + ], pred=predicate, ) else: @@ -508,7 +575,6 @@ def load_V( class FlashAttentionForwardSm80(FlashAttentionForwardBase): - def _get_smem_layout_atom(self): sQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.tile_hdim) sK_layout_atom = sQ_layout_atom @@ -564,7 +630,7 @@ def __call__( window_size_left: Optional[Int32] = None, window_size_right: Optional[Int32] = None, learnable_sink: Optional[cute.Tensor] = None, - buffers=None, + aux_tensors=None, ): """Configures and launches the flash attention kernel. @@ -572,7 +638,9 @@ def __call__( (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1) """ assert learnable_sink is None, "Learnable sink is not supported in this kernel" - self._check_type(*(t.element_type if t is not None else None for t in (mQ, mK, mV, mO, mLSE))) + self._check_type( + *(t.element_type if t is not None else None for t in (mQ, mK, mV, mO, mLSE)) + ) tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma() self.num_mma_threads = tiled_mma_pv.size self.num_producer_threads = self.num_threads @@ -583,9 +651,18 @@ def __call__( self._setup_attributes() SharedStorage = self._get_shared_storage_cls() # Assume all strides are divisible by 128 bits except the last stride - new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1]) - mQ, mK, mV, mO = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) for t in (mQ, mK, mV, mO)] - mQ, mK, mV, mO = [cute.make_tensor(t.iterator, cute.select(t.layout, mode=[1, 3, 2, 0])) for t in (mQ, mK, mV, mO)] + new_stride = lambda t: ( + *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + t.stride[-1], + ) + mQ, mK, mV, mO = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + for t in (mQ, mK, mV, mO) + ] + mQ, mK, mV, mO = [ + cute.make_tensor(t.iterator, cute.select(t.layout, mode=[1, 3, 2, 0])) + for t in (mQ, mK, mV, mO) + ] mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=[2, 1, 0])) # grid_dim: (m_block, num_head, batch_size) grid_dim = ( @@ -605,8 +682,10 @@ def __call__( softmax_scale = Float32(softmax_scale) fastdiv_mods = None - if const_expr(buffers is not None): - seqlen_q = cute.size(mQ.shape[0]) // (self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1) + if const_expr(aux_tensors is not None): + seqlen_q = cute.size(mQ.shape[0]) // ( + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 + ) seqlen_k = cute.size(mK.shape[0]) seqlen_q_divmod = FastDivmod.create(seqlen_q) seqlen_k_divmod = FastDivmod.create(seqlen_k) @@ -634,7 +713,7 @@ def __call__( tiled_mma_qk, tiled_mma_pv, SharedStorage, - buffers, + aux_tensors, fastdiv_mods, ).launch( grid=grid_dim, @@ -667,7 +746,7 @@ def kernel( tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, SharedStorage: cutlass.Constexpr, - buffers=None, + aux_tensors=None, fastdiv_mods=None, ): # Thread index, block index @@ -675,8 +754,12 @@ def kernel( m_block, num_head, batch_size = cute.arch.block_idx() block_info = BlockInfo( - self.tile_m, self.tile_n, self.is_causal, self.is_local, - window_size_left, window_size_right, + self.tile_m, + self.tile_n, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) seqlen = SeqlenInfoQK(seqlen_q_static=mQ.shape[0], seqlen_k_static=mK.shape[0]) @@ -735,10 +818,12 @@ def kernel( # Smem copy atom tiling # /////////////////////////////////////////////////////////////////////////////// smem_copy_atom_QK = cute.make_copy_atom( - warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), self.dtype, + warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), + self.dtype, ) smem_copy_atom_V = cute.make_copy_atom( - warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4), self.dtype, + warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4), + self.dtype, ) smem_thr_copy_Q = utils.make_tiled_copy_A(smem_copy_atom_QK, tiled_mma_qk).get_slice(tidx) smem_thr_copy_K = utils.make_tiled_copy_B(smem_copy_atom_QK, tiled_mma_qk).get_slice(tidx) @@ -773,29 +858,49 @@ def kernel( tVpV = utils.predicate_k(tVcV, limit=mV.shape[1]) # shape: (atom_v_m * rest_m) - softmax = Softmax.create(softmax_scale_log2, num_rows=acc_O.shape[0][0] * acc_O.shape[1], softmax_scale=softmax_scale) + softmax = Softmax.create( + softmax_scale_log2, + num_rows=acc_O.shape[0][0] * acc_O.shape[1], + softmax_scale=softmax_scale, + ) softmax.reset() # group parameters for compute_one_n_block mma_params = SimpleNamespace( - thr_mma_qk=thr_mma_qk, thr_mma_pv=thr_mma_pv, - tSrQ=tSrQ, tSrK=tSrK, tOrVt=tOrVt, acc_O=acc_O, + thr_mma_qk=thr_mma_qk, + thr_mma_pv=thr_mma_pv, + tSrQ=tSrQ, + tSrK=tSrK, + tOrVt=tOrVt, + acc_O=acc_O, ) smem_copy_params = SimpleNamespace( smem_thr_copy_Q=smem_thr_copy_Q, smem_thr_copy_K=smem_thr_copy_K, smem_thr_copy_V=smem_thr_copy_V, - tSsQ=tSsQ, tSsK=tSsK, tOsVt=tOsVt, + tSsQ=tSsQ, + tSsK=tSsK, + tOsVt=tOsVt, + ) + load_K = partial( + self.load_K, gmem_tiled_copy_K, tKgK, tKsK, tKcK, t0KcK, tKpK, seqlen=seqlen.seqlen_k + ) + load_V = partial( + self.load_V, gmem_tiled_copy_V, tVgV, tVsV, tVcV, t0VcV, tVpV, seqlen=seqlen.seqlen_k ) - load_K = partial(self.load_K, gmem_tiled_copy_K, tKgK, tKsK, tKcK, t0KcK, tKpK, - seqlen=seqlen.seqlen_k) - load_V = partial(self.load_V, gmem_tiled_copy_V, tVgV, tVsV, tVcV, t0VcV, tVpV, - seqlen=seqlen.seqlen_k) compute_one_n_block = partial( - self.compute_one_n_block, mma_params=mma_params, smem_copy_params=smem_copy_params, - softmax=softmax, load_K=load_K, load_V=load_V, score_mod=self.score_mod, - batch_idx=batch_size, head_idx=num_head, m_block=m_block, buffers=buffers, + self.compute_one_n_block, + mma_params=mma_params, + smem_copy_params=smem_copy_params, + softmax=softmax, + load_K=load_K, + load_V=load_V, + score_mod=self.score_mod, + batch_idx=batch_size, + head_idx=num_head, + m_block=m_block, + aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods, ) @@ -826,11 +931,11 @@ def preprocess_Q(): for stage in cutlass.range_constexpr(self.num_stages): if const_expr(not self.Q_in_regs or stage > 0): if stage == 0 or n_block - stage >= 0: - load_K(n_block - stage, smem_pipe_write=stage, need_predicates=stage==0) + load_K(n_block - stage, smem_pipe_write=stage, need_predicates=stage == 0) cute.arch.cp_async_commit_group() if const_expr(stage < self.num_stages - 1): if stage == 0 or n_block - stage >= 0: - load_V(n_block - stage, smem_pipe_write=stage, need_predicates=stage==0) + load_V(n_block - stage, smem_pipe_write=stage, need_predicates=stage == 0) cute.arch.cp_async_commit_group() if const_expr(not self.Q_in_regs): preprocess_Q() @@ -844,20 +949,33 @@ def preprocess_Q(): # We need masking on S for the very last block when K and V has length not multiple of tile_n. # We also need masking on S if it's causal, for the last several blocks. mask = AttentionMask( - self.tile_m, self.tile_n, seqlen.seqlen_q, seqlen.seqlen_k, - window_size_left, window_size_right, + self.tile_m, + self.tile_n, + seqlen.seqlen_q, + seqlen.seqlen_k, + window_size_left, + window_size_right, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) mask_fn = partial( - mask.apply_mask, m_block=m_block, thr_mma=thr_mma_qk, - mask_causal=self.is_causal, mask_local=self.is_local, + mask.apply_mask, + m_block=m_block, + thr_mma=thr_mma_qk, + mask_causal=self.is_causal, + mask_local=self.is_local, ) # First iteration with seqlen masking smem_pipe_read = Int32(0) smem_pipe_write = Int32(self.num_stages - 1) - compute_one_n_block(n_block, smem_pipe_read, smem_pipe_write, is_first_n_block=True, - check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=True)) + compute_one_n_block( + n_block, + smem_pipe_read, + smem_pipe_write, + is_first_n_block=True, + check_inf=True, + mask_fn=partial(mask_fn, mask_seqlen=True), + ) smem_pipe_read = self.advance_pipeline(smem_pipe_read) smem_pipe_write = self.advance_pipeline(smem_pipe_write) # Next couple of iterations with causal masking @@ -867,13 +985,20 @@ def preprocess_Q(): ) for n_tile in cutlass.range(n_block_max - 1 - n_block_min_causal_local_mask, unroll=1): n_block = n_block_max - 2 - n_tile - compute_one_n_block(n_block, smem_pipe_read, smem_pipe_write, check_inf=True, - mask_fn=partial(mask_fn, mask_seqlen=False)) + compute_one_n_block( + n_block, + smem_pipe_read, + smem_pipe_write, + check_inf=True, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) smem_pipe_read = self.advance_pipeline(smem_pipe_read) smem_pipe_write = self.advance_pipeline(smem_pipe_write) # The remaining iterations have no masking for n_tile in cutlass.range(n_block, unroll=1): - compute_one_n_block(n_block - n_tile - 1, smem_pipe_read, smem_pipe_write, check_inf=True) + compute_one_n_block( + n_block - n_tile - 1, smem_pipe_read, smem_pipe_write, check_inf=True + ) smem_pipe_read = self.advance_pipeline(smem_pipe_read) smem_pipe_write = self.advance_pipeline(smem_pipe_write) # TODO: local @@ -888,8 +1013,19 @@ def preprocess_Q(): # reuse sQ's data iterator sO = cute.make_tensor(sQ.iterator, sO_layout) self.epilogue( - acc_O, softmax.row_sum, mO, mLSE, sO, seqlen, - gmem_tiled_copy_O, None, tiled_mma_pv, tidx, m_block, num_head, batch_size + acc_O, + softmax.row_sum, + mO, + mLSE, + sO, + seqlen, + gmem_tiled_copy_O, + None, + tiled_mma_pv, + tidx, + m_block, + num_head, + batch_size, ) @cute.jit @@ -907,7 +1043,7 @@ def compute_one_n_block( batch_idx: cutlass.Int32, head_idx: cutlass.Int32, m_block: cutlass.Int32, - buffers=None, + aux_tensors=None, fastdiv_mods=None, mask_fn: Optional[Callable] = None, is_first_n_block: cutlass.Constexpr = False, @@ -918,6 +1054,7 @@ def compute_one_n_block( This function provides different variants for processing the first n block versus subsequent blocks. """ + def sync(): cute.arch.cp_async_wait_group(self.num_stages * 2 - 2) cute.arch.barrier() @@ -927,18 +1064,29 @@ def sync(): acc_S.fill(0.0) # wait for smem tile QK before mma calculation for S sync() + # need predicates for the first tile def load_V_next(): if self.num_stages == 1 or n_block - self.num_stages + 1 >= 0: - load_V(n_block - self.num_stages + 1, smem_pipe_write, - need_predicates=is_first_n_block and self.num_stages == 1) + load_V( + n_block - self.num_stages + 1, + smem_pipe_write, + need_predicates=is_first_n_block and self.num_stages == 1, + ) cute.arch.cp_async_commit_group() + load_V_next() sm80_utils.gemm( - mma_params.thr_mma_qk, acc_S, mma_params.tSrQ, mma_params.tSrK, + mma_params.thr_mma_qk, + acc_S, + mma_params.tSrQ, + mma_params.tSrK, smem_copy_params.tSsQ, - smem_copy_params.tSsK[None, None, None, smem_pipe_read if const_expr(self.num_stages > 1) else 0], - smem_copy_params.smem_thr_copy_Q, smem_copy_params.smem_thr_copy_K, + smem_copy_params.tSsK[ + None, None, None, smem_pipe_read if const_expr(self.num_stages > 1) else 0 + ], + smem_copy_params.smem_thr_copy_Q, + smem_copy_params.smem_thr_copy_K, # hook_fn=load_V_next, A_in_regs=self.Q_in_regs, ) @@ -951,15 +1099,17 @@ def load_V_next(): acc_S, n_block, softmax_scale=softmax.softmax_scale, - buffers=buffers, + aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods, ) smem_pipe_write = self.advance_pipeline(smem_pipe_write) + def load_K_next(): if n_block - self.num_stages >= 0: load_K(n_block - self.num_stages, smem_pipe_write, need_predicates=False) cute.arch.cp_async_commit_group() + # wait for smem tile V for O if const_expr(self.num_stages == 1): sync() @@ -975,8 +1125,13 @@ def load_K_next(): sync() load_K_next() sm80_utils.gemm_rs( - mma_params.thr_mma_pv, mma_params.acc_O, tOrP, mma_params.tOrVt, - smem_copy_params.tOsVt[None, None, None, smem_pipe_read if const_expr(self.num_stages > 1) else 0], + mma_params.thr_mma_pv, + mma_params.acc_O, + tOrP, + mma_params.tOrVt, + smem_copy_params.tOsVt[ + None, None, None, smem_pipe_read if const_expr(self.num_stages > 1) else 0 + ], smem_copy_params.smem_thr_copy_V, # hook_fn=load_K_next, ) @@ -985,7 +1140,6 @@ def load_K_next(): class FlashAttentionForwardSm90(FlashAttentionForwardBase): - arch = 90 def __init__( @@ -998,21 +1152,18 @@ def __init__( super().__init__(*args, **kwargs) self.intra_wg_overlap = intra_wg_overlap self.mma_pv_is_rs = mma_pv_is_rs - def _get_smem_layout_atom(self): sQ_layout_atom = warpgroup.make_smem_layout_atom( - sm90_utils_basic.get_smem_layout_atom( - LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdim - ), - self.dtype + sm90_utils_basic.get_smem_layout_atom(LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdim), + self.dtype, ) sK_layout_atom = sQ_layout_atom sV_layout_atom = warpgroup.make_smem_layout_atom( sm90_utils_basic.get_smem_layout_atom( LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdimv ), - self.dtype + self.dtype, ) sO_layout_atom = sV_layout_atom if not self.mma_pv_is_rs: @@ -1020,7 +1171,7 @@ def _get_smem_layout_atom(self): sm90_utils_basic.get_smem_layout_atom( LayoutEnum.ROW_MAJOR, self.dtype, self.tile_n ), - self.dtype + self.dtype, ) else: sP_layout_atom = None @@ -1044,7 +1195,9 @@ def _get_tiled_mma(self): Float32, atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 tiler_mn=(64, self.tile_hdimv), - a_source=warpgroup.OperandSource.RMEM if self.mma_pv_is_rs else warpgroup.OperandSource.SMEM, + a_source=warpgroup.OperandSource.RMEM + if self.mma_pv_is_rs + else warpgroup.OperandSource.SMEM, ) tiled_mma_pv_rs = sm90_utils_basic.make_trivial_tiled_mma( self.dtype, @@ -1054,7 +1207,7 @@ def _get_tiled_mma(self): Float32, atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 tiler_mn=(64, self.tile_hdimv), - a_source=warpgroup.OperandSource.RMEM + a_source=warpgroup.OperandSource.RMEM, ) return tiled_mma_qk, tiled_mma_pv, tiled_mma_pv_rs @@ -1066,8 +1219,8 @@ def _get_shared_storage_cls(self): sQ_struct, sK_struct, sV_struct = [ cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], alignment] for layout, alignment in zip( - (self.sQ_layout, self.sK_layout, self.sV_layout), - (sQ_alignment, sK_alignment, sV_alignment) + (self.sQ_layout, self.sK_layout, self.sV_layout), + (sQ_alignment, sK_alignment, sV_alignment), ) ] cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout)) @@ -1122,7 +1275,7 @@ def __call__( full_block_idx: Optional[cute.Tensor] = None, # (b, h, m_block, n_block) mask_block_cnt: Optional[cute.Tensor] = None, # (b, h, m_block) mask_block_idx: Optional[cute.Tensor] = None, # (b, h, m_block, n_block) - buffers: Optional[list[cute.Tensor]] = None, + aux_tensors: Optional[list] = None, ): """Configures and launches the flash attention kernel. @@ -1131,14 +1284,22 @@ def __call__( """ self._check_type( - *(t.element_type if t is not None else None - for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)) + *( + t.element_type if t is not None else None + for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK) + ) ) # Assume all strides are divisible by 128 bits except the last stride - new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1]) + new_stride = lambda t: ( + *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + t.stride[-1], + ) - mQ, mK, mV, mO = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) for t in (mQ, mK, mV, mO)] + mQ, mK, mV, mO = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + for t in (mQ, mK, mV, mO) + ] QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] mQ, mO = [utils.select(t, QO_layout_transpose) for t in (mQ, mO)] KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1] @@ -1164,10 +1325,20 @@ def __call__( ) # self.num_mma_regs = 232 # self.num_producer_regs = 40 - self.use_block_sparsity = const_expr(mask_block_cnt is not None and full_block_cnt is not None) - self.use_scheduler_barrier = (self.num_mma_warp_groups >= 2 and self.tile_hdim <= 128) if const_expr(self.intra_wg_overlap) else (self.num_mma_warp_groups == 2) - self.use_tma_Q = self.arch >= 90 and not (self.pack_gqa and self.tile_m % self.qhead_per_kvhead != 0) - self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None and not self.pack_gqa + self.use_block_sparsity = const_expr( + mask_block_cnt is not None and full_block_cnt is not None + ) + self.use_scheduler_barrier = ( + (self.num_mma_warp_groups >= 2 and self.tile_hdim <= 128) + if const_expr(self.intra_wg_overlap) + else (self.num_mma_warp_groups == 2) + ) + self.use_tma_Q = self.arch >= 90 and not ( + self.pack_gqa and self.tile_m % self.qhead_per_kvhead != 0 + ) + self.use_tma_O = ( + self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None and not self.pack_gqa + ) # TODO: rescale_O_before_gemm self._setup_attributes() # TODO: we prob don't need most of what's in _setup_attributes @@ -1189,16 +1360,50 @@ def __call__( SharedStorage = self._get_shared_storage_cls() if const_expr(self.pack_gqa): - shape_Q_packed = ((self.qhead_per_kvhead, mQ.shape[0]), mQ.shape[1], mK.shape[2], *mQ.shape[3:]) - stride_Q_packed = ((mQ.stride[2], mQ.stride[0]), mQ.stride[1], mQ.stride[2] * self.qhead_per_kvhead, *mQ.stride[3:]) - mQ = cute.make_tensor(mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed)) - shape_O_packed = ((self.qhead_per_kvhead, mO.shape[0]), mK.shape[1], mK.shape[2], *mO.shape[3:]) - stride_O_packed = ((mO.stride[2], mO.stride[0]), mO.stride[1], mO.stride[2] * self.qhead_per_kvhead, *mO.stride[3:]) - mO = cute.make_tensor(mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed)) + shape_Q_packed = ( + (self.qhead_per_kvhead, mQ.shape[0]), + mQ.shape[1], + mK.shape[2], + *mQ.shape[3:], + ) + stride_Q_packed = ( + (mQ.stride[2], mQ.stride[0]), + mQ.stride[1], + mQ.stride[2] * self.qhead_per_kvhead, + *mQ.stride[3:], + ) + mQ = cute.make_tensor( + mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed) + ) + shape_O_packed = ( + (self.qhead_per_kvhead, mO.shape[0]), + mK.shape[1], + mK.shape[2], + *mO.shape[3:], + ) + stride_O_packed = ( + (mO.stride[2], mO.stride[0]), + mO.stride[1], + mO.stride[2] * self.qhead_per_kvhead, + *mO.stride[3:], + ) + mO = cute.make_tensor( + mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed) + ) if const_expr(mLSE is not None): - shape_LSE_packed = ((self.qhead_per_kvhead, mLSE.shape[0]), mK.shape[2], *mLSE.shape[2:]) - stride_LSE_packed = ((mLSE.stride[1], mLSE.stride[0]), mLSE.stride[1] * self.qhead_per_kvhead, *mLSE.stride[2:]) - mLSE = cute.make_tensor(mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed)) + shape_LSE_packed = ( + (self.qhead_per_kvhead, mLSE.shape[0]), + mK.shape[2], + *mLSE.shape[2:], + ) + stride_LSE_packed = ( + (mLSE.stride[1], mLSE.stride[0]), + mLSE.stride[1] * self.qhead_per_kvhead, + *mLSE.stride[2:], + ) + mLSE = cute.make_tensor( + mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed) + ) # TMA gmem_tiled_copy_Q = cpasync.CopyBulkTensorTileG2SOp() @@ -1215,39 +1420,53 @@ def __call__( tma_atom_Q, tma_tensor_Q = None, None if const_expr(self.use_tma_Q): tma_atom_Q, tma_tensor_Q = cpasync.make_tiled_tma_atom( - gmem_tiled_copy_Q, mQ, self.sQ_layout, (self.tile_m, self.tile_hdim), # No mcast + gmem_tiled_copy_Q, + mQ, + self.sQ_layout, + (self.tile_m, self.tile_hdim), # No mcast ) tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom( gmem_tiled_copy_KV, mK, cute.select(self.sK_layout, mode=[0, 1]), (self.tile_n, self.tile_hdim), - 1 # No mcast for now + 1, # No mcast for now ) tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom( gmem_tiled_copy_KV, mV, cute.select(self.sV_layout, mode=[0, 1]), (self.tile_n, self.tile_hdimv), - 1 # No mcast for now + 1, # No mcast for now ) tma_atom_O, tma_tensor_O = None, None if const_expr(self.use_tma_O): tma_atom_O, tma_tensor_O = cpasync.make_tiled_tma_atom( - gmem_tiled_copy_O, mO, self.sO_layout, (self.tile_m, self.tile_hdimv), # No mcast + gmem_tiled_copy_O, + mO, + self.sO_layout, + (self.tile_m, self.tile_hdimv), # No mcast ) if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): TileScheduler = SingleTileVarlenScheduler else: - TileScheduler = SingleTileScheduler if const_expr(not self.is_causal or self.is_local) else SingleTileLPTScheduler + TileScheduler = ( + SingleTileScheduler + if const_expr(not self.is_causal or self.is_local) + else SingleTileLPTScheduler + ) tile_sched_args = TileSchedulerArguments( cute.ceil_div(cute.size(mQ.shape[0]), self.tile_m), cute.size(mQ.shape[2]), - cute.size(mQ.shape[3]) if const_expr(mCuSeqlensQ is None) else cute.size(mCuSeqlensQ.shape[0] - 1), + cute.size(mQ.shape[3]) + if const_expr(mCuSeqlensQ is None) + else cute.size(mCuSeqlensQ.shape[0] - 1), cute.size(mK.shape[0]), mQ.shape[1], mV.shape[1], - total_q=cute.size(mQ.shape[0]) if const_expr(mCuSeqlensQ is not None) else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), + total_q=cute.size(mQ.shape[0]) + if const_expr(mCuSeqlensQ is not None) + else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), tile_shape_mn=(self.tile_m, self.tile_n), mCuSeqlensQ=mCuSeqlensQ, mSeqUsedQ=mSeqUsedQ, @@ -1274,8 +1493,10 @@ def __call__( window_size_right = Int32(window_size_right) fastdiv_mods = None - if const_expr(buffers is not None): - seqlen_q = cute.size(mQ.shape[0]) // (self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1) + if const_expr(aux_tensors is not None): + seqlen_q = cute.size(mQ.shape[0]) // ( + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 + ) seqlen_k = cute.size(mK.shape[0]) seqlen_q_divmod = FastDivmod.create(seqlen_q) seqlen_k_divmod = FastDivmod.create(seqlen_k) @@ -1319,7 +1540,7 @@ def __call__( tile_sched_params, TileScheduler, SharedStorage, - buffers, + aux_tensors, fastdiv_mods, ).launch( grid=grid_dim, @@ -1369,7 +1590,7 @@ def kernel( tile_sched_params: ParamsBase, TileScheduler: cutlass.Constexpr[Callable], SharedStorage: cutlass.Constexpr[Callable], - buffers=Optional[list[cute.Tensor]], + aux_tensors=Optional[list[cute.Tensor]], fastdiv_mods=None, ): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) @@ -1392,7 +1613,9 @@ def kernel( cute.arch.mbarrier_init(mbar_ptr_Q, self.num_Q_load_threads) # cute.arch.mbarrier_init(mbar_ptr_Q + 1, self.num_mma_threads) # We rely on pipeline_k and pipeline_v to initialize the mbarrier fence and sync - pipeline_kv_producer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread) + pipeline_kv_producer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread + ) pipeline_kv_consumer_group = cutlass.pipeline.CooperativeGroup( cutlass.pipeline.Agent.Thread, self.num_mma_threads // self.num_threads_per_warp_group ) @@ -1421,7 +1644,9 @@ def kernel( if const_expr(not self.Q_in_regs): sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) else: - sV = storage.sQ.get_tensor(sV_layout.outer, swizzle=sV_layout.inner, dtype=mV.element_type) + sV = storage.sQ.get_tensor( + sV_layout.outer, swizzle=sV_layout.inner, dtype=mV.element_type + ) # Transpose view of V to tensor with layout (head_dim_v, tile_n) for tiled mma sVt = utils.transpose_view(sV) sP = None @@ -1431,19 +1656,29 @@ def kernel( sO = storage.sQ.get_tensor(sO_layout.outer, swizzle=sO_layout.inner, dtype=self.dtype) block_info = BlockInfo( - self.tile_m, self.tile_n, self.is_causal, self.is_local, - window_size_left, window_size_right, + self.tile_m, + self.tile_n, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) SeqlenInfoCls = partial( - SeqlenInfoQK, seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1], + SeqlenInfoQK, + seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1], seqlen_k_static=mK.shape[0], - mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, - mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK, + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=mCuSeqlensK, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=mSeqUsedK, ) AttentionMaskCls = partial( - AttentionMask, self.tile_m, self.tile_n, - window_size_left=window_size_left, window_size_right=window_size_right, + AttentionMask, + self.tile_m, + self.tile_n, + window_size_left=window_size_left, + window_size_right=window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) TileSchedulerCls = partial(TileScheduler.create, tile_sched_params) @@ -1509,7 +1744,7 @@ def kernel( full_block_idx, mask_block_cnt, mask_block_idx, - buffers, + aux_tensors, fastdiv_mods, ) @@ -1545,11 +1780,13 @@ def load( tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: - # if work_tile.is_valid_tile: + # if work_tile.is_valid_tile: m_block, head_idx, batch_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] - head_idx_kv = head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx + head_idx_kv = ( + head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx + ) mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv] mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[None, None, head_idx_kv] gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (None, 0)) @@ -1561,12 +1798,15 @@ def load( ) # TODO: mcast # TODO check warp_idx if we have 128 producer threads - load_K, _, _ = copy_utils.tma_get_copy_fn(tma_atom_K, 0, cute.make_layout(1), gK, sK) + load_K, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_K, 0, cute.make_layout(1), gK, sK + ) load_K = copy_utils.tma_producer_copy_fn(load_K, pipeline_k) - load_V, _, _ = copy_utils.tma_get_copy_fn(tma_atom_V, 0, cute.make_layout(1), gV, sV) + load_V, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_V, 0, cute.make_layout(1), gV, sV + ) load_V = copy_utils.tma_producer_copy_fn(load_V, pipeline_v) - if const_expr(not self.use_block_sparsity): n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) # if cute.arch.thread_idx()[0] == 0: @@ -1575,7 +1815,9 @@ def load( n_block = n_block_max - 1 pipeline_k.producer_acquire( kv_producer_state, - extra_tx_count=self.tma_copy_bytes["Q"] if const_expr(self.use_tma_Q) else 0 + extra_tx_count=self.tma_copy_bytes["Q"] + if const_expr(self.use_tma_Q) + else 0, ) if const_expr(self.use_tma_Q): load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) @@ -1614,22 +1856,26 @@ def load( curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None] curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None] - + if const_expr(not self.intra_wg_overlap): if curr_mask_block_cnt > 0: # First mask block - load with Q n_block_mask = curr_mask_block_idx[curr_mask_block_cnt - 1] pipeline_k.producer_acquire( kv_producer_state, - extra_tx_count=self.tma_copy_bytes["Q"] if const_expr(self.use_tma_Q) else 0 + extra_tx_count=self.tma_copy_bytes["Q"] + if const_expr(self.use_tma_Q) + else 0, ) if const_expr(self.use_tma_Q): - load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) + load_Q( + tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state) + ) load_K(src_idx=n_block_mask, producer_state=kv_producer_state) pipeline_v.producer_acquire(kv_producer_state) load_V(src_idx=n_block_mask, producer_state=kv_producer_state) kv_producer_state.advance() - + # Remaining mask blocks for i in cutlass.range(1, curr_mask_block_cnt): n_block_mask = curr_mask_block_idx[curr_mask_block_cnt - 1 - i] @@ -1638,17 +1884,23 @@ def load( pipeline_v.producer_acquire(kv_producer_state) load_V(src_idx=n_block_mask, producer_state=kv_producer_state) kv_producer_state.advance() - + if curr_full_block_cnt > 0: n_block_full = curr_full_block_idx[curr_full_block_cnt - 1] - if curr_mask_block_cnt == 0: + if curr_mask_block_cnt == 0: # must load Q if not loaded in mask loop pipeline_k.producer_acquire( kv_producer_state, - extra_tx_count=self.tma_copy_bytes["Q"] if const_expr(self.use_tma_Q) else 0 + extra_tx_count=self.tma_copy_bytes["Q"] + if const_expr(self.use_tma_Q) + else 0, ) if const_expr(self.use_tma_Q): - load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) + load_Q( + tma_bar_ptr=pipeline_k.producer_get_barrier( + kv_producer_state + ) + ) load_K(src_idx=n_block_full, producer_state=kv_producer_state) pipeline_v.producer_acquire(kv_producer_state) load_V(src_idx=n_block_full, producer_state=kv_producer_state) @@ -1666,28 +1918,32 @@ def load( pipeline_v.producer_acquire(kv_producer_state) load_V(src_idx=n_block_full, producer_state=kv_producer_state) kv_producer_state.advance() - + else: # ========================================== # Overlap path # ========================================== - + # Load Q with the first K block (whether mask or full) n_block_first = -1 if curr_mask_block_cnt > 0: n_block_first = curr_mask_block_idx[curr_mask_block_cnt - 1] elif curr_full_block_cnt > 0: n_block_first = curr_full_block_idx[curr_full_block_cnt - 1] - + if n_block_first >= 0: pipeline_k.producer_acquire( kv_producer_state, - extra_tx_count=self.tma_copy_bytes["Q"] if const_expr(self.use_tma_Q) else 0 + extra_tx_count=self.tma_copy_bytes["Q"] + if const_expr(self.use_tma_Q) + else 0, ) if const_expr(self.use_tma_Q): - load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) + load_Q( + tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state) + ) load_K(src_idx=n_block_first, producer_state=kv_producer_state) - + if curr_mask_block_cnt > 0: # Staggered loading for remaining mask blocks for i in cutlass.range(1, curr_mask_block_cnt): @@ -1698,8 +1954,10 @@ def load( pipeline_k.producer_acquire(kv_producer_state) load_K(src_idx=n_block_mask, producer_state=kv_producer_state) pipeline_v.producer_acquire(kv_producer_state_prev) - load_V(src_idx=n_block_mask_prev, producer_state=kv_producer_state_prev) - + load_V( + src_idx=n_block_mask_prev, producer_state=kv_producer_state_prev + ) + # Handle transition from mask to full blocks if curr_full_block_cnt > 0: # Load first full block K, last mask block V @@ -1710,14 +1968,16 @@ def load( pipeline_k.producer_acquire(kv_producer_state) load_K(src_idx=n_block_full, producer_state=kv_producer_state) pipeline_v.producer_acquire(kv_producer_state_prev) - load_V(src_idx=n_block_mask_last, producer_state=kv_producer_state_prev) + load_V( + src_idx=n_block_mask_last, producer_state=kv_producer_state_prev + ) else: # No full blocks, just load last mask block V n_block_mask_last = curr_mask_block_idx[0] pipeline_v.producer_acquire(kv_producer_state) load_V(src_idx=n_block_mask_last, producer_state=kv_producer_state) kv_producer_state.advance() - + if curr_full_block_cnt > 0: # Staggered loading for remaining full blocks ( for j in cutlass.range(1, curr_full_block_cnt): @@ -1728,8 +1988,10 @@ def load( pipeline_k.producer_acquire(kv_producer_state) load_K(src_idx=n_block_full, producer_state=kv_producer_state) pipeline_v.producer_acquire(kv_producer_state_prev) - load_V(src_idx=n_block_full_prev, producer_state=kv_producer_state_prev) - + load_V( + src_idx=n_block_full_prev, producer_state=kv_producer_state_prev + ) + # Load last full block V n_block_full_last = curr_full_block_idx[0] pipeline_v.producer_acquire(kv_producer_state) @@ -1775,7 +2037,7 @@ def mma( full_block_idx: Optional[cute.Tensor], mask_block_cnt: Optional[cute.Tensor], mask_block_idx: Optional[cute.Tensor], - buffers: Optional[list[cute.Tensor]], + aux_tensors: Optional[list], fastdiv_mods=None, ): warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) @@ -1820,11 +2082,15 @@ def mma( mma_pv_fn = partial(sm90_utils.gemm_w_idx, tiled_mma_pv, acc_O, tOrP, tOrVt) mma_one_n_block_all = partial( - self.mma_one_n_block_intrawg_overlap if const_expr(self.intra_wg_overlap) else self.mma_one_n_block, + self.mma_one_n_block_intrawg_overlap + if const_expr(self.intra_wg_overlap) + else self.mma_one_n_block, mma_qk_fn=mma_qk_fn, tiled_mma_pv_rs=tiled_mma_pv_rs, - pipeline_k=pipeline_k, pipeline_v=pipeline_v, - acc_O=acc_O, tOrP=tOrP, + pipeline_k=pipeline_k, + pipeline_v=pipeline_v, + acc_O=acc_O, + tOrP=tOrP, smem_copy_params=smem_copy_params, check_inf=True, ) @@ -1836,8 +2102,12 @@ def mma( tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() - softmax = Softmax.create(softmax_scale_log2, num_rows=acc_O.shape[0][0] * acc_O.shape[1], softmax_scale=softmax_scale) - + softmax = Softmax.create( + softmax_scale_log2, + num_rows=acc_O.shape[0][0] * acc_O.shape[1], + softmax_scale=softmax_scale, + ) + process_first_half_block = partial( self.first_half_block_overlap, mma_qk_fn=mma_qk_fn, @@ -1852,7 +2122,7 @@ def mma( mma_pv_fn=mma_pv_fn, ) while work_tile.is_valid_tile: - # if work_tile.is_valid_tile: + # if work_tile.is_valid_tile: # shape: (atom_v_m * rest_m) m_block, head_idx, batch_idx = work_tile.tile_idx @@ -1866,18 +2136,18 @@ def mma( thr_mma=thr_mma_qk, mask_causal=self.is_causal, mask_local=self.is_local, - buffers=buffers, + aux_tensors=aux_tensors, ) score_mod_fn = None if const_expr(self.score_mod is not None): score_mod_fn = partial( self.apply_score_mod, - thr_mma_qk=thr_mma_qk, - batch_idx=batch_idx, - head_idx=head_idx, - m_block=m_block, + thr_mma_qk, + batch_idx, + head_idx, + m_block, softmax_scale=softmax_scale, - buffers=buffers, + aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods, ) mma_one_n_block = partial( @@ -1887,7 +2157,9 @@ def mma( ) # Load Q if not TMA_Q if const_expr(not self.use_tma_Q): - pack_gqa = PackGQA(self.tile_m, self.tile_hdim, self.check_hdim_oob, self.qhead_per_kvhead) + pack_gqa = PackGQA( + self.tile_m, self.tile_hdim, self.check_hdim_oob, self.qhead_per_kvhead + ) mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] # gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx) # gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0)) @@ -1906,10 +2178,9 @@ def mma( # We also need masking on S if it's causal, for the last several blocks. # softmax.reset() # Don't need reset as we explicitly call softmax w is_first=True O_should_accumulate = False - - + # ========================================== - # MAINLOOP + # MAINLOOP # ========================================== if const_expr(not self.use_block_sparsity): # ========================================== @@ -1921,6 +2192,7 @@ def mma( n_block=n_block_max - 1, kv_consumer_state=kv_consumer_state, mask_fn=mask_fn, + score_mod_fn=score_mod_fn, is_first_block=True, ) # Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter @@ -1943,7 +2215,9 @@ def mma( seqlen, m_block, n_block_min ) # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_causal_local_mask = {}", n_block_min_causal_local_mask) - for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1): + for n_tile in cutlass.range( + n_block_max - n_block_min_causal_local_mask, unroll=1 + ): kv_consumer_state = mma_one_n_block( kv_consumer_state, n_block=n_block_max - 1 - n_tile, @@ -1984,7 +2258,7 @@ def mma( O_should_accumulate = True else: self.warp_scheduler_barrier_arrive() - + else: # ========================================== # Block sparsity @@ -2069,6 +2343,7 @@ def mma( n_block=mask_n_block, kv_consumer_state=kv_consumer_state, mask_fn=partial(mask_fn, mask_mod=self.mask_mod), + score_mod_fn=score_mod_fn, is_first_block=True, ) @@ -2091,6 +2366,7 @@ def mma( n_block=full_n_block, kv_consumer_state=kv_consumer_state, mask_fn=partial(mask_fn, mask_mod=None), + score_mod_fn=score_mod_fn, is_first_block=True, ) @@ -2124,8 +2400,7 @@ def mma( if curr_mask_block_cnt + curr_full_block_cnt == 0: softmax.reset() - acc_O.fill(0.0) - + acc_O.fill(0.0) sink_val = None if const_expr(learnable_sink is not None): @@ -2148,8 +2423,19 @@ def mma( # Epilogue # /////////////////////////////////////////////////////////////////////////////// self.epilogue( - acc_O, softmax.row_sum, mO, mLSE, sO, seqlen, - gmem_tiled_copy_O, tma_atom_O, tiled_mma_pv, tidx, m_block, head_idx, batch_idx, + acc_O, + softmax.row_sum, + mO, + mLSE, + sO, + seqlen, + gmem_tiled_copy_O, + tma_atom_O, + tiled_mma_pv, + tidx, + m_block, + head_idx, + batch_idx, ) tile_scheduler.advance_to_next_work() @@ -2177,7 +2463,7 @@ def first_half_block_overlap( # Apply score modification if present if const_expr(score_mod_fn is not None): - score_mod_fn(acc_S=acc_S, n_block=n_block) + score_mod_fn(acc_S, n_block=n_block) # Apply mask; mask_seqlen always True for first block # Caveat: if full block further right than mask block, seqlen masking is redundant; @@ -2203,7 +2489,7 @@ def first_half_block_overlap( cute.arch.sync_warp() return kv_consumer_state - + @cute.jit def last_half_block_overlap( self, @@ -2213,14 +2499,14 @@ def last_half_block_overlap( zero_init: bool, ): """Processes the final PV GEMM when using intra-warpgroup-overlap""" - + pipeline_v.consumer_wait(kv_consumer_state, pipeline_v.consumer_try_wait(kv_consumer_state)) mma_pv_fn(B_idx=kv_consumer_state.index, zero_init=zero_init, wg_wait=0) pipeline_v.consumer_release(kv_consumer_state) - + # Advance state for next iteration kv_consumer_state.advance() - + return kv_consumer_state @cute.jit @@ -2248,17 +2534,19 @@ def mma_one_n_block( self.warp_scheduler_barrier_arrive() warpgroup.wait_group(0) pipeline_k.consumer_release(smem_pipe_read) - + # handle score mods and masking if const_expr(score_mod_fn is not None): score_mod_fn(acc_S, n_block=n_block) if const_expr(mask_fn is not None): - mask_fn(acc_S, n_block=n_block) - + mask_fn(acc_S=acc_S, n_block=n_block) + row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf) # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) - tOrP_cur = tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) + tOrP_cur = ( + tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) + ) # tOrP.store(tOrP_acc.load().to(self.dtype)) # the "to(self.dtype)" conversion fails to vectorize for block sizes other # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of @@ -2310,19 +2598,21 @@ def mma_one_n_block_intrawg_overlap( self.warp_scheduler_barrier_arrive() warpgroup.wait_group(1) pipeline_k.consumer_release(smem_pipe_read) - + # handle score mods and masking if const_expr(score_mod_fn is not None): score_mod_fn(acc_S, n_block=n_block) - if const_expr(mask_fn is not None): - mask_fn(acc_S, n_block=n_block) + if const_expr(mask_fn is not None): + mask_fn(acc_S=acc_S, n_block=n_block) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) - + row_scale = softmax.online_softmax(acc_S, check_inf=check_inf) warpgroup.wait_group(0) pipeline_v.consumer_release(smem_pipe_read_v) tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) - tOrP_cur = tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) + tOrP_cur = ( + tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) + ) # tOrP_cur.store(tOrP_acc.load().to(self.dtype)) # the "to(self.dtype)" conversion fails to vectorize for block sizes other # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of @@ -2358,7 +2648,7 @@ def apply_score_mod( acc_S, n_block, softmax_scale, - buffers=Optional[list[cute.Tensor]], + aux_tensors: Optional[list] = None, fastdiv_mods=None, ): # Prepare index tensor @@ -2375,7 +2665,7 @@ def apply_score_mod( softmax_scale, self.vec_size, self.qk_acc_dtype, - buffers, + aux_tensors, fastdiv_mods, constant_q_idx=None, qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, @@ -2384,8 +2674,10 @@ def apply_score_mod( def warp_scheduler_barrier_sync(self): if const_expr(self.use_scheduler_barrier): cute.arch.barrier( - barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) - 1 + utils.canonical_warp_group_idx(sync=False), - number_of_threads=2 * self.num_threads_per_warp_group + barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + - 1 + + utils.canonical_warp_group_idx(sync=False), + number_of_threads=2 * self.num_threads_per_warp_group, ) def warp_scheduler_barrier_arrive(self): diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 7bf1480bbae..83755896d51 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -37,7 +37,14 @@ from flash_attn.cute import mma_sm100_desc as sm100_desc from flash_attn.cute import blackwell_helpers as sm100_utils from flash_attn.cute.fast_math import FastDivmod -from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, StaticPersistentTileScheduler, SingleTileLPTScheduler, SingleTileVarlenScheduler, ParamsBase +from flash_attn.cute.tile_scheduler import ( + TileSchedulerArguments, + SingleTileScheduler, + StaticPersistentTileScheduler, + SingleTileLPTScheduler, + SingleTileVarlenScheduler, + ParamsBase, +) # class NamedBarrierFwd(enum.IntEnum): @@ -50,7 +57,6 @@ class FlashAttentionForwardSm100: - arch = 100 def __init__( @@ -66,7 +72,7 @@ def __init__( n_block_size: int = 128, is_persistent: bool = True, score_mod: cutlass.Constexpr | None = None, - has_buffers: cutlass.Constexpr = False, + has_aux_tensors: cutlass.Constexpr = False, ): # self.dtype = dtype # padding head_dim to a multiple of 16 as k_block_size @@ -96,9 +102,11 @@ def __init__( self.qhead_per_kvhead = qhead_per_kvhead self.pack_gqa = pack_gqa if pack_gqa: - assert m_block_size % self.qhead_per_kvhead == 0, "For PackGQA, m_block_size must be divisible by qhead_per_kvhead" + assert m_block_size % self.qhead_per_kvhead == 0, ( + "For PackGQA, m_block_size must be divisible by qhead_per_kvhead" + ) self.score_mod = score_mod - if cutlass.const_expr(has_buffers): + if cutlass.const_expr(has_aux_tensors): self.vec_size: cutlass.Constexpr = 1 else: self.vec_size: cutlass.Constexpr = 2 @@ -133,11 +141,16 @@ def __init__( ) self.tmem_s_offset = [0, self.n_block_size] # e.g., 0, 128 - self.tmem_o_offset = [self.tmem_s_offset[-1] + self.n_block_size + i * self.head_dim_v_padded for i in range(self.q_stage)] # e.g., 256, 384 + self.tmem_o_offset = [ + self.tmem_s_offset[-1] + self.n_block_size + i * self.head_dim_v_padded + for i in range(self.q_stage) + ] # e.g., 256, 384 self.tmem_total = self.tmem_o_offset[-1] + self.head_dim_v_padded assert self.tmem_total <= SM100_TMEM_CAPACITY_COLUMNS self.tmem_s_to_p_offset = self.n_block_size // 2 - self.tmem_p_offset = [self.tmem_s_offset[i] + self.tmem_s_to_p_offset for i in range(2)] # 0, 128 + self.tmem_p_offset = [ + self.tmem_s_offset[i] + self.tmem_s_to_p_offset for i in range(2) + ] # 0, 128 # vec buffer for row_max & row_sum self.tmem_vec_offset = self.tmem_s_offset @@ -182,8 +195,14 @@ def _setup_attributes(self): # 128 x 192 and smem_small is 128 x 128. We set the stride between the stages to be # 128 * 160, so that indexing the 0th and 2nd stages will get the right address, # but for the 1st stage we need to add or subtract (depending on phase) 128 x 64. - self.uneven_kv_smem = self.head_dim_padded == 192 and self.head_dim_v_padded == 128 and self.kv_stage == 3 - self.uneven_kv_smem_offset = self.m_block_size * (self.head_dim_padded - self.head_dim_v_padded) // 2 if self.uneven_kv_smem else 0 + self.uneven_kv_smem = ( + self.head_dim_padded == 192 and self.head_dim_v_padded == 128 and self.kv_stage == 3 + ) + self.uneven_kv_smem_offset = ( + self.m_block_size * (self.head_dim_padded - self.head_dim_v_padded) // 2 + if self.uneven_kv_smem + else 0 + ) assert self.uneven_kv_smem_offset % 1024 == 0 @cute.jit @@ -204,7 +223,9 @@ def __call__( window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, learnable_sink: Optional[cute.Tensor] = None, - buffers = None # Not typing for now since conversion behaves a lil funny + aux_tensors: Optional[ + list + ] = None, # Not typing for now since conversion behaves a lil funny ): """Execute the Fused Multi-Head Attention operation on the provided tensors. @@ -226,8 +247,14 @@ def __call__( self.v_dtype = mV.element_type self.o_dtype = mO.element_type # Assume all strides are divisible by 128 bits except the last stride - new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1]) - mQ, mK, mV, mO = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) for t in (mQ, mK, mV, mO)] + new_stride = lambda t: ( + *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + t.stride[-1], + ) + mQ, mK, mV, mO = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + for t in (mQ, mK, mV, mO) + ] QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] mQ, mO = [ cute.make_tensor(t.iterator, cute.select(t.layout, mode=QO_layout_transpose)) @@ -240,7 +267,11 @@ def __call__( for t in (mK, mV) ] LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] - mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) if const_expr(mLSE is not None) else None + mLSE = ( + cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) + if const_expr(mLSE is not None) + else None + ) # (s, d, h, b) -> (d, s, h, b) V_layout_transpose = [1, 0, 2, 3] if const_expr(mCuSeqlensK is None) else [1, 0, 2] mV = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=V_layout_transpose)) @@ -266,7 +297,9 @@ def __call__( self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None # This can be tuned self.e2e_freq = 16 - if const_expr(self.head_dim_padded > 64 and not self.is_causal and not self.is_local and self.pack_gqa): + if const_expr( + self.head_dim_padded > 64 and not self.is_causal and not self.is_local and self.pack_gqa + ): self.e2e_freq = 32 if mCuSeqlensQ is not None or mSeqUsedQ is not None else 10 cta_group = tcgen05.CtaGroup.ONE @@ -300,39 +333,108 @@ def __call__( self.epi_tile = self.mma_tiler_pv[:2] sQ_layout = sm100_utils_basic.make_smem_layout_a( - tiled_mma_qk, self.mma_tiler_qk, self.q_dtype, self.q_stage, + tiled_mma_qk, + self.mma_tiler_qk, + self.q_dtype, + self.q_stage, ) sK_layout = sm100_utils_basic.make_smem_layout_b( - tiled_mma_qk, self.mma_tiler_qk, self.k_dtype, self.kv_stage, + tiled_mma_qk, + self.mma_tiler_qk, + self.k_dtype, + self.kv_stage, ) tP_layout = sm100_utils_basic.make_smem_layout_a( - tiled_mma_pv, self.mma_tiler_pv, self.q_dtype, self.acc_stage, + tiled_mma_pv, + self.mma_tiler_pv, + self.q_dtype, + self.acc_stage, ) sV_layout = sm100_utils_basic.make_smem_layout_b( - tiled_mma_pv, self.mma_tiler_pv, self.v_dtype, self.kv_stage, + tiled_mma_pv, + self.mma_tiler_pv, + self.v_dtype, + self.kv_stage, ) sO_layout = sm100_utils_basic.make_smem_layout_epi( - self.o_dtype, self.o_layout, self.epi_tile, self.epi_stage, + self.o_dtype, + self.o_layout, + self.epi_tile, + self.epi_stage, ) if const_expr(not self.same_hdim_kv_padded): # sK and sV are using the same physical smem so we need to adjust the stride so that they line up - stride_sK = const_expr(max(sK_layout.outer.stride[-1], 0)) # take max to turn tuple to Int32 + stride_sK = const_expr( + max(sK_layout.outer.stride[-1], 0) + ) # take max to turn tuple to Int32 stride_sV = const_expr(max(sV_layout.outer.stride[-1], 0)) - stage_stride = const_expr(max(stride_sK, stride_sV) if not self.uneven_kv_smem else (stride_sK + stride_sV) // 2) - sK_layout = cute.make_composed_layout(sK_layout.inner, 0, cute.make_layout((*sK_layout.outer.shape[:-1], self.kv_stage), stride=(*sK_layout.outer.stride[:-1], stage_stride))) - sV_layout = cute.make_composed_layout(sV_layout.inner, 0, cute.make_layout((*sV_layout.outer.shape[:-1], self.kv_stage), stride=(*sV_layout.outer.stride[:-1], stage_stride))) + stage_stride = const_expr( + max(stride_sK, stride_sV) + if not self.uneven_kv_smem + else (stride_sK + stride_sV) // 2 + ) + sK_layout = cute.make_composed_layout( + sK_layout.inner, + 0, + cute.make_layout( + (*sK_layout.outer.shape[:-1], self.kv_stage), + stride=(*sK_layout.outer.stride[:-1], stage_stride), + ), + ) + sV_layout = cute.make_composed_layout( + sV_layout.inner, + 0, + cute.make_layout( + (*sV_layout.outer.shape[:-1], self.kv_stage), + stride=(*sV_layout.outer.stride[:-1], stage_stride), + ), + ) if const_expr(self.pack_gqa): - shape_Q_packed = ((self.qhead_per_kvhead, mQ.shape[0]), mQ.shape[1], mK.shape[2], *mQ.shape[3:]) - stride_Q_packed = ((mQ.stride[2], mQ.stride[0]), mQ.stride[1], mQ.stride[2] * self.qhead_per_kvhead, *mQ.stride[3:]) - mQ = cute.make_tensor(mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed)) - shape_O_packed = ((self.qhead_per_kvhead, mO.shape[0]), mK.shape[1], mK.shape[2], *mO.shape[3:]) - stride_O_packed = ((mO.stride[2], mO.stride[0]), mO.stride[1], mO.stride[2] * self.qhead_per_kvhead, *mO.stride[3:]) - mO = cute.make_tensor(mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed)) + shape_Q_packed = ( + (self.qhead_per_kvhead, mQ.shape[0]), + mQ.shape[1], + mK.shape[2], + *mQ.shape[3:], + ) + stride_Q_packed = ( + (mQ.stride[2], mQ.stride[0]), + mQ.stride[1], + mQ.stride[2] * self.qhead_per_kvhead, + *mQ.stride[3:], + ) + mQ = cute.make_tensor( + mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed) + ) + shape_O_packed = ( + (self.qhead_per_kvhead, mO.shape[0]), + mK.shape[1], + mK.shape[2], + *mO.shape[3:], + ) + stride_O_packed = ( + (mO.stride[2], mO.stride[0]), + mO.stride[1], + mO.stride[2] * self.qhead_per_kvhead, + *mO.stride[3:], + ) + mO = cute.make_tensor( + mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed) + ) if const_expr(mLSE is not None): - shape_LSE_packed = ((self.qhead_per_kvhead, mLSE.shape[0]), mK.shape[2], *mLSE.shape[2:]) - stride_LSE_packed = ((mLSE.stride[1], mLSE.stride[0]), mLSE.stride[1] * self.qhead_per_kvhead, *mLSE.stride[2:]) - mLSE = cute.make_tensor(mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed)) + shape_LSE_packed = ( + (self.qhead_per_kvhead, mLSE.shape[0]), + mK.shape[2], + *mLSE.shape[2:], + ) + stride_LSE_packed = ( + (mLSE.stride[1], mLSE.stride[0]), + mLSE.stride[1] * self.qhead_per_kvhead, + *mLSE.stride[2:], + ) + mLSE = cute.make_tensor( + mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed) + ) # TMA load for Q tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) @@ -386,11 +488,14 @@ def __call__( universal_copy_bits = 128 async_copy_elems = universal_copy_bits // self.o_dtype.width atom_universal_copy = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), self.o_dtype, num_bits_per_copy=universal_copy_bits, + cute.nvgpu.CopyUniversalOp(), + self.o_dtype, + num_bits_per_copy=universal_copy_bits, ) tO_shape_dim_1 = sO_layout.outer.shape[1][0] // async_copy_elems tO_layout = cute.make_ordered_layout( - (self.num_epilogue_threads // tO_shape_dim_1, tO_shape_dim_1), order=(1, 0), + (self.num_epilogue_threads // tO_shape_dim_1, tO_shape_dim_1), + order=(1, 0), ) # So that we don't have to check if we overshoot kBlockM when we store O assert self.m_block_size % tO_layout.shape[0] == 0 @@ -412,15 +517,25 @@ def __call__( if const_expr(self.is_causal or self.is_local): TileScheduler = SingleTileLPTScheduler else: - TileScheduler = SingleTileScheduler if const_expr(not self.is_persistent) else StaticPersistentTileScheduler + TileScheduler = ( + SingleTileScheduler + if const_expr(not self.is_persistent) + else StaticPersistentTileScheduler + ) tile_sched_args = TileSchedulerArguments( cute.ceil_div(cute.size(mQ.shape[0]), self.cta_tiler[0]), cute.size(mQ.shape[2]), - cute.size(mQ.shape[3]) if const_expr(mCuSeqlensQ is None) else cute.size(mCuSeqlensQ.shape[0] - 1), - cute.size(mK.shape[0]) if const_expr(mPageTable is None) else mK.shape[0] * mPageTable.shape[1], + cute.size(mQ.shape[3]) + if const_expr(mCuSeqlensQ is None) + else cute.size(mCuSeqlensQ.shape[0] - 1), + cute.size(mK.shape[0]) + if const_expr(mPageTable is None) + else mK.shape[0] * mPageTable.shape[1], mQ.shape[1], mV.shape[0], # Note that this is different from Sm90 since we transpose mV in Sm100 - total_q=cute.size(mQ.shape[0]) if const_expr(mCuSeqlensQ is not None) else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), + total_q=cute.size(mQ.shape[0]) + if const_expr(mCuSeqlensQ is not None) + else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), tile_shape_mn=self.cta_tiler[:2], mCuSeqlensQ=mCuSeqlensQ, mSeqUsedQ=mSeqUsedQ, @@ -493,8 +608,10 @@ class SharedStorage: window_size_right = Int32(window_size_right) fastdiv_mods = None - if cutlass.const_expr(buffers is not None): - seqlen_q = cute.size(mQ.shape[0]) // (self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1) + if cutlass.const_expr(aux_tensors is not None): + seqlen_q = cute.size(mQ.shape[0]) // ( + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 + ) seqlen_k = cute.size(mK.shape[0]) seqlen_q_divmod = FastDivmod.create(seqlen_q) seqlen_k_divmod = FastDivmod.create(seqlen_k) @@ -530,7 +647,7 @@ class SharedStorage: tiled_mma_qk, tiled_mma_pv, tile_sched_params, - buffers, + aux_tensors, fastdiv_mods, ).launch( grid=grid_dim, @@ -573,8 +690,8 @@ def kernel( tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, tile_sched_params: ParamsBase, - buffers = None, - fastdiv_mods = (None, None), + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), ): """The device kernel implementation of the Fused Multi-Head Attention. @@ -609,28 +726,55 @@ def kernel( if warp_idx == 1: # Init "full" barrier with number of producers, "empty" barrier with number of consumers for i in cutlass.range_constexpr(self.q_stage): - cute.arch.mbarrier_init(mbar_ptr + self.mbar_load_q_full_offset + i, len([self.load_warp_id])) - cute.arch.mbarrier_init(mbar_ptr + self.mbar_load_q_empty_offset + i, len([self.mma_warp_id])) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_load_q_full_offset + i, len([self.load_warp_id]) + ) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_load_q_empty_offset + i, len([self.mma_warp_id]) + ) if warp_idx == 2: for i in cutlass.range_constexpr(2): - cute.arch.mbarrier_init(mbar_ptr + self.mbar_softmax_corr_empty_offset + i, cute.arch.WARP_SIZE * 4) - cute.arch.mbarrier_init(mbar_ptr + self.mbar_softmax_corr_full_offset + i, cute.arch.WARP_SIZE * 4) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_softmax_corr_empty_offset + i, cute.arch.WARP_SIZE * 4 + ) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_softmax_corr_full_offset + i, cute.arch.WARP_SIZE * 4 + ) if warp_idx == 3: if const_expr(self.s0_s1_barrier): for i in cutlass.range_constexpr(8): - cute.arch.mbarrier_init(mbar_ptr + self.mbar_s0_s1_sequence_offset + i, cute.arch.WARP_SIZE) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_s0_s1_sequence_offset + i, cute.arch.WARP_SIZE + ) if warp_idx == 4: for i in cutlass.range_constexpr(self.q_stage): - cute.arch.mbarrier_init(mbar_ptr + self.mbar_corr_epi_full_offset + i, cute.arch.WARP_SIZE * len(self.correction_warp_ids)) - cute.arch.mbarrier_init(mbar_ptr + self.mbar_corr_epi_empty_offset + i, cute.arch.WARP_SIZE * len(self.epilogue_warp_ids)) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_corr_epi_full_offset + i, + cute.arch.WARP_SIZE * len(self.correction_warp_ids), + ) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_corr_epi_empty_offset + i, + cute.arch.WARP_SIZE * len(self.epilogue_warp_ids), + ) if warp_idx == 5: for i in cutlass.range_constexpr(2): - cute.arch.mbarrier_init(mbar_ptr + self.mbar_P_full_O_rescaled_offset + i, cute.arch.WARP_SIZE * (len(self.softmax0_warp_ids) + len(self.correction_warp_ids))) - cute.arch.mbarrier_init(mbar_ptr + self.mbar_S_full_offset + i, len([self.mma_warp_id])) - cute.arch.mbarrier_init(mbar_ptr + self.mbar_O_full_offset + i, len([self.mma_warp_id])) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_P_full_O_rescaled_offset + i, + cute.arch.WARP_SIZE + * (len(self.softmax0_warp_ids) + len(self.correction_warp_ids)), + ) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_S_full_offset + i, len([self.mma_warp_id]) + ) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_O_full_offset + i, len([self.mma_warp_id]) + ) if warp_idx == 6: for i in cutlass.range_constexpr(2): - cute.arch.mbarrier_init(mbar_ptr + self.mbar_P_full_2_offset + i, cute.arch.WARP_SIZE * len(self.softmax0_warp_ids)) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_P_full_2_offset + i, + cute.arch.WARP_SIZE * len(self.softmax0_warp_ids), + ) if warp_idx == 7: cute.arch.mbarrier_init( mbar_ptr + self.mbar_tmem_dealloc_offset, @@ -668,43 +812,60 @@ def kernel( tStS_fake = thr_mma_qk.make_fragment_C(qk_acc_shape) # This is a fake tensor, by right need to retrieve tmem_ptr. But we know that we always # request 512 columns of tmem, so we know that it starts at 0. - tmem_ptr = cute.make_ptr(Float32, 0, mem_space=cute.AddressSpace.tmem, - assumed_align=16) + tmem_ptr = cute.make_ptr(Float32, 0, mem_space=cute.AddressSpace.tmem, assumed_align=16) tStS = cute.make_tensor(tmem_ptr, tStS_fake.layout) pv_acc_shape = thr_mma_pv.partition_shape_C(self.mma_tiler_pv[:2]) tOtO = thr_mma_pv.make_fragment_C(pv_acc_shape) - tStSs = tuple(cute.make_tensor(tStS.iterator + self.tmem_s_offset[stage], tStS.layout) - for stage in range(2)) - tOtOs = tuple(cute.make_tensor(tOtO.iterator + self.tmem_o_offset[stage], tOtO.layout) - for stage in range(self.q_stage)) + tStSs = tuple( + cute.make_tensor(tStS.iterator + self.tmem_s_offset[stage], tStS.layout) + for stage in range(2) + ) + tOtOs = tuple( + cute.make_tensor(tOtO.iterator + self.tmem_o_offset[stage], tOtO.layout) + for stage in range(self.q_stage) + ) tP = cute.make_tensor(tStS.iterator, tP_layout.outer) tOrP = thr_mma_pv.make_fragment_A(tP)[None, None, None, 0] - tOrPs = [cute.make_tensor( - tOrP.iterator - + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p_offset[stage], - tOrP.layout, - ) for stage in range(2)] + tOrPs = [ + cute.make_tensor( + tOrP.iterator + + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p_offset[stage], + tOrP.layout, + ) + for stage in range(2) + ] block_info = BlockInfo( # This is cta_tiler, not mma_tiler_qk, since we move by block by (2 * mma_tiler[0], mma_tiler[1]) - self.cta_tiler[0], self.cta_tiler[1], self.is_causal, self.is_local, - window_size_left, window_size_right, + self.cta_tiler[0], + self.cta_tiler[1], + self.is_causal, + self.is_local, + window_size_left, + window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) SeqlenInfoCls = partial( SeqlenInfoQK, seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1], - seqlen_k_static=mK.shape[0] if const_expr(mPageTable is None) else mK.shape[0] * mPageTable.shape[1], - mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, - mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK, + seqlen_k_static=mK.shape[0] + if const_expr(mPageTable is None) + else mK.shape[0] * mPageTable.shape[1], + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=mCuSeqlensK, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=mSeqUsedK, ) AttentionMaskCls = partial( - AttentionMask, self.m_block_size, self.n_block_size, - window_size_left=window_size_left, window_size_right=window_size_right, + AttentionMask, + self.m_block_size, + self.n_block_size, + window_size_left=window_size_left, + window_size_right=window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params) @@ -745,7 +906,7 @@ def kernel( # MMA # /////////////////////////////////////////////////////////////////////////////// if warp_idx == self.mma_warp_id: - # if warp_idx == self.mma_warp_id or warp_idx == self.empty_warp_ids: + # if warp_idx == self.mma_warp_id or warp_idx == self.empty_warp_ids: cute.arch.warpgroup_reg_dealloc(self.num_regs_other) # Alloc tmem buffer tmem_alloc_cols = Int32(self.tmem_alloc_cols) @@ -787,7 +948,9 @@ def kernel( # /////////////////////////////////////////////////////////////////////////////// if warp_idx >= self.epilogue_warp_ids[0] and warp_idx <= self.epilogue_warp_ids[-1]: cute.arch.warpgroup_reg_dealloc(self.num_regs_other) - self.epilogue_s2g(mO, sO, gmem_tiled_copy_O, tma_atom_O, mbar_ptr, SeqlenInfoCls, TileSchedulerCls) + self.epilogue_s2g( + mO, sO, gmem_tiled_copy_O, tma_atom_O, mbar_ptr, SeqlenInfoCls, TileSchedulerCls + ) # /////////////////////////////////////////////////////////////////////////////// # Softmax @@ -808,7 +971,7 @@ def kernel( SeqlenInfoCls=SeqlenInfoCls, AttentionMaskCls=AttentionMaskCls, TileSchedulerCls=TileSchedulerCls, - buffers=buffers, + aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods, ) @@ -817,8 +980,9 @@ def kernel( softmax_loop( stage=stage, tStSi=cute.make_tensor( - tStS.iterator + (self.tmem_s_offset[0] if stage == 0 else self.tmem_s_offset[1]), - tStS.layout + tStS.iterator + + (self.tmem_s_offset[0] if stage == 0 else self.tmem_s_offset[1]), + tStS.layout, ), ) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) @@ -880,7 +1044,6 @@ def load( SeqlenInfoCls: Callable, TileSchedulerCls: Callable, ): - q_producer_phase = Int32(1) kv_producer_state = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Producer, self.kv_stage @@ -893,7 +1056,9 @@ def load( mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] gQ = cute.local_tile(mQ_cur, cute.select(self.mma_tiler_qk, mode=[0, 2]), (None, 0)) - head_idx_kv = head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx + head_idx_kv = ( + head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx + ) if const_expr(mPageTable is None): if const_expr(not seqlen.has_cu_seqlens_k): mK_cur, mV_cur = [t[None, None, head_idx_kv, batch_idx] for t in (mK, mV)] @@ -905,8 +1070,12 @@ def load( else: # Need to keep batch coord None since we'll index into it with page idx mK_cur, mV_cur = [t[None, None, head_idx_kv, None] for t in (mK, mV)] - gK = cute.local_tile(mK_cur, cute.select(self.mma_tiler_qk, mode=[1, 2]), (None, 0, None)) - gV = cute.local_tile(mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None, None)) + gK = cute.local_tile( + mK_cur, cute.select(self.mma_tiler_qk, mode=[1, 2]), (None, 0, None) + ) + gV = cute.local_tile( + mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None, None) + ) tSgQ = thr_mma_qk.partition_A(gQ) tSgK = thr_mma_qk.partition_B(gK) tOgV = thr_mma_pv.partition_B(gV) @@ -929,26 +1098,40 @@ def load( ) load_Q = partial( - self.load_Q, load_Q_fn, - mbar_ptr + self.mbar_load_q_full_offset, mbar_ptr + self.mbar_load_q_empty_offset, + self.load_Q, + load_Q_fn, + mbar_ptr + self.mbar_load_q_full_offset, + mbar_ptr + self.mbar_load_q_empty_offset, phase=q_producer_phase, ) # We have to use mbarrier directly in the load for KV instead of replying on # pipeline_kv, because we could have different number of TMA bytes for K and V load_K = partial( - self.load_KV, tma_atom_K, tKgK, tKsK, - mbar_ptr + self.mbar_load_kv_full_offset, mbar_ptr + self.mbar_load_kv_empty_offset, + self.load_KV, + tma_atom_K, + tKgK, + tKsK, + mbar_ptr + self.mbar_load_kv_full_offset, + mbar_ptr + self.mbar_load_kv_empty_offset, K_or_V="K", ) load_V = partial( - self.load_KV, tma_atom_V, tVgV, tVsV, - mbar_ptr + self.mbar_load_kv_full_offset, mbar_ptr + self.mbar_load_kv_empty_offset, + self.load_KV, + tma_atom_V, + tVgV, + tVsV, + mbar_ptr + self.mbar_load_kv_full_offset, + mbar_ptr + self.mbar_load_kv_empty_offset, K_or_V="V", ) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) load_Q(block=self.q_stage * m_block + 0, stage=0) # Q0 - page_idx = mPageTable[batch_idx, n_block_max - 1] if const_expr(mPageTable is not None) else None + page_idx = ( + mPageTable[batch_idx, n_block_max - 1] + if const_expr(mPageTable is not None) + else None + ) load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # K0 kv_producer_state.advance() if const_expr(self.q_stage == 2): @@ -958,7 +1141,9 @@ def load( kv_producer_state.advance() for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): n_block = n_block_max - 2 - i - page_idx = mPageTable[batch_idx, n_block] if const_expr(mPageTable is not None) else None + page_idx = ( + mPageTable[batch_idx, n_block] if const_expr(mPageTable is not None) else None + ) # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("n_block = {}, page_idx = {}", n_block, page_idx) load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Ki kv_producer_state.advance() @@ -1005,7 +1190,7 @@ def mma( self.tmem_s_offset[stage], tSrQs[stage], sA=sQ[None, None, None, stage], - zero_init=True + zero_init=True, ) for stage in range(2) ] @@ -1036,7 +1221,9 @@ def mma( for stage in cutlass.range_constexpr(self.q_stage): # GEMM_QK00 (Q0 * K0 -> S0) or GEMM_QK01 (Q1 * K0 -> S1) # 1. wait for Q0 / Q1 - cute.arch.mbarrier_wait(mbar_ptr + self.mbar_load_q_full_offset + stage, mma_q_consumer_phase) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_load_q_full_offset + stage, mma_q_consumer_phase + ) # 2. wait for K0 if const_expr(stage == 0): pipeline_kv.consumer_wait(mma_kv_consumer_state) @@ -1049,7 +1236,9 @@ def mma( # tiled_mma_qk = sm100_utils.gemm(tiled_mma_qk, tStSs[stage], tSrQs[stage], tSrKi, zero_init=True) sK_cur = sK[None, None, None, mma_kv_consumer_state.index] if const_expr(self.uneven_kv_smem): - sK_cur = self.offset_kv_smem(sK_cur, mma_kv_consumer_state.index, mma_kv_consumer_state.phase) + sK_cur = self.offset_kv_smem( + sK_cur, mma_kv_consumer_state.index, mma_kv_consumer_state.phase + ) gemm_Si[stage](tCrB=tSrKi, sB=sK_cur) # 4. release S0 / S1 with cute.arch.elect_one(): @@ -1078,7 +1267,7 @@ def mma( # the last iteration of the previous work tile has finished. cute.arch.mbarrier_wait( mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, - P_full_O_rescaled_phase + P_full_O_rescaled_phase, ) # 3. gemm # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True) @@ -1091,7 +1280,7 @@ def mma( sB=sV_cur, zero_init=not O_should_accumulate, mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, - mbar_phase=P_full_O_rescaled_phase + mbar_phase=P_full_O_rescaled_phase, ) # 4. release accumulated O0_partial / O1_partial # Don't need to signal O_full to the correction warps anymore since the @@ -1145,8 +1334,7 @@ def mma( for stage in cutlass.range_constexpr(2): # 2. acquire corrected Oi_partial and Pi cute.arch.mbarrier_wait( - mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, - P_full_O_rescaled_phase + mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, P_full_O_rescaled_phase ) # 3. gemm # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True) @@ -1159,7 +1347,7 @@ def mma( sB=sV_cur, zero_init=not O_should_accumulate, mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, - mbar_phase=P_full_O_rescaled_phase + mbar_phase=P_full_O_rescaled_phase, ) # 4. release accumulated O0_partial # We do need O_full here since for the last tile, by the time the softmax warp @@ -1197,8 +1385,8 @@ def softmax_loop( SeqlenInfoCls: Callable, AttentionMaskCls: Callable, TileSchedulerCls: Callable, - buffers = None, - fastdiv_mods = (None, None) + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), ): """Compute softmax on attention scores from QK matrix multiplication. @@ -1214,8 +1402,7 @@ def softmax_loop( tidx = cute.arch.thread_idx()[0] % ( cute.arch.WARP_SIZE # * (len(self.softmax0_warp_ids) if stage == 0 else len(self.softmax1_warp_ids) - * (len(self.softmax0_warp_ids) - ) + * (len(self.softmax0_warp_ids)) ) tStScale = cute.composition(tStSi, cute.make_layout((self.m_block_size, 1))) @@ -1223,23 +1410,30 @@ def softmax_loop( tScScale = cute.composition(tScS, cute.make_layout((self.m_block_size, 1))) tilePlikeFP32 = self.mma_tiler_qk[1] // 32 * self.v_dtype.width - tStP_layout = cute.composition(tStSi.layout, cute.make_layout((self.m_block_size, tilePlikeFP32))) + tStP_layout = cute.composition( + tStSi.layout, cute.make_layout((self.m_block_size, tilePlikeFP32)) + ) tStP = cute.make_tensor(tStSi.iterator + self.tmem_s_to_p_offset, tStP_layout) tmem_load_atom = cute.make_copy_atom( - tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32, + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + Float32, ) thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStSi).get_slice(tidx) tStS_t2r = thr_tmem_load.partition_S(tStSi) tmem_store_scale_atom = cute.make_copy_atom( - tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(1)), Float32, + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(1)), + Float32, + ) + thr_tmem_store_scale = tcgen05.make_tmem_copy(tmem_store_scale_atom, tStScale).get_slice( + tidx ) - thr_tmem_store_scale = tcgen05.make_tmem_copy(tmem_store_scale_atom, tStScale).get_slice(tidx) tStScale_r2t = thr_tmem_store_scale.partition_D(tStScale) tmem_store_atom = cute.make_copy_atom( - tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), Float32, + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), + Float32, ) thr_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP).get_slice(tidx) tStP_r2t = thr_tmem_store.partition_D(tStP) @@ -1266,9 +1460,13 @@ def softmax_loop( thr_mma=thr_mma_qk, thr_tmem_load=thr_tmem_load, mask_causal=self.is_causal, - mask_local=self.is_local + mask_local=self.is_local, + ) + softmax = SoftmaxSm100.create( + softmax_scale_log2, + rescale_threshold=8.0 if const_expr(self.q_dtype.width == 16) else 0.0, + softmax_scale=softmax_scale, ) - softmax = SoftmaxSm100.create(softmax_scale_log2, rescale_threshold=8.0 if const_expr(self.q_dtype.width == 16) else 0.0, softmax_scale=softmax_scale) softmax.reset() softmax_step = partial( @@ -1289,15 +1487,24 @@ def softmax_loop( head_idx=head_idx, m_block=self.q_stage * m_block + stage, seqlen=seqlen, - buffers=buffers, + aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods, ) - cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase + ) si_corr_producer_phase ^= 1 # 1 masking iter - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block_max - 1, is_first=True, mask_fn=partial(mask_fn, mask_seqlen=True)) + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + n_block_max - 1, + is_first=True, + mask_fn=partial(mask_fn, mask_seqlen=True), + ) n_block_max -= 1 # Next couple of iterations with causal masking if const_expr(self.is_causal or self.is_local): @@ -1306,7 +1513,15 @@ def softmax_loop( ) for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1): n_block = n_block_max - 1 - n_tile - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=partial(mask_fn, mask_seqlen=False)) + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = ( + softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + n_block, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) + ) n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) # The remaining iterations have no masking n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( @@ -1314,13 +1529,23 @@ def softmax_loop( ) for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): n_block = n_block_max - n_tile - 1 - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block) + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step( + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block + ) # Separate iterations with local masking on the left if const_expr(self.is_local and block_info.window_size_left is not None): n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) for n_tile in cutlass.range(0, n_block_max - n_block_min, unroll=1): n_block = n_block_max - 1 - n_tile - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=partial(mask_fn, mask_seqlen=False)) + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = ( + softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + n_block, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) + ) # Now that we no longer already have the 1st iteration, need mask_seqlen=True here # tSrScale_r2t_shape = thr_tmem_store_scale.partition_S(tScScale).shape @@ -1330,7 +1555,9 @@ def softmax_loop( # cute.arch.fence_view_async_tmem_store() sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0] if const_expr(mLSE is not None or learnable_sink is not None): - sScale[tidx + stage * self.m_block_size + self.m_block_size * 2] = softmax.row_max[0] + sScale[tidx + stage * self.m_block_size + self.m_block_size * 2] = softmax.row_max[ + 0 + ] # if tidx == 0: # cute.printf("softmax row sum stage %d: %f, row_max = %f\n", stage, softmax.row_sum[0], softmax.row_max[0]) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage) @@ -1383,8 +1610,8 @@ def softmax_step( head_idx: Int32, m_block: Int32, seqlen, - buffers = None, - fastdiv_mods = (None, None), + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), mask_fn: Optional[Callable] = None, is_first: bool = False, ) -> Tuple[cute.Int32, cute.Int32, cute.Int32]: @@ -1422,8 +1649,8 @@ def softmax_step( m_block, n_block, softmax, - buffers, - fastdiv_mods + aux_tensors, + fastdiv_mods, ) if const_expr(mask_fn is not None): @@ -1446,14 +1673,21 @@ def softmax_step( softmax.scale_subtract_rowmax(tSrS_t2r, row_max) # Sequence barrier wait if const_expr(self.s0_s1_barrier): - cute.arch.mbarrier_wait(mbar_ptr + mbar_s0_s1_sequence_offset + stage * 4, s0_s1_sequence_phase) + cute.arch.mbarrier_wait( + mbar_ptr + mbar_s0_s1_sequence_offset + stage * 4, s0_s1_sequence_phase + ) tSrP_r2t_f32 = cute.make_fragment(thr_tmem_store.partition_S(tScP).shape, Float32) tSrP_r2t = cute.make_tensor( - cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), tSrS_t2r.layout, + cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), + tSrS_t2r.layout, ) # softmax.scale_apply_exp2_convert(tSrS_t2r, row_max, tSrP_r2t) - softmax.apply_exp2_convert(tSrS_t2r, tSrP_r2t, e2e=mask_fn is None and self.head_dim_padded <= 128, - e2e_freq=self.e2e_freq) + softmax.apply_exp2_convert( + tSrS_t2r, + tSrP_r2t, + e2e=mask_fn is None and self.head_dim_padded <= 128, + e2e_freq=self.e2e_freq, + ) # Sequence barrier arrive if const_expr(self.s0_s1_barrier): cute.arch.mbarrier_arrive(mbar_ptr + mbar_s0_s1_sequence_offset + (1 - stage) * 4) @@ -1464,12 +1698,16 @@ def softmax_step( cute.arch.fence_view_async_tmem_store() # Notify mma warp that P is ready cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) - for i in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2]) // 4 * 3, cute.size(tStP_r2t.shape[2])): + for i in cutlass.range_constexpr( + cute.size(tStP_r2t.shape[2]) // 4 * 3, cute.size(tStP_r2t.shape[2]) + ): cute.copy(thr_tmem_store, tSrP_r2t_f32[None, None, i], tStP_r2t[None, None, i]) cute.arch.fence_view_async_tmem_store() # Notify mma warp that the 2nd half of P is ready cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_2_offset + stage) - cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase + ) softmax.update_row_sum(tSrS_t2r.load(), acc_scale, is_first) # acc_scale = cute.arch.exp2(acc_scale_) return mma_si_consumer_phase ^ 1, si_corr_producer_phase ^ 1, s0_s1_sequence_phase ^ 1 @@ -1496,11 +1734,14 @@ def correction_loop( tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.correction_warp_ids)) tScS = thr_mma_qk.partition_C(cute.make_identity_tensor(self.mma_tiler_qk[:2])) tStScale_layout = cute.composition(tStS.layout, cute.make_layout((self.m_block_size, 1))) - tStScales = tuple(cute.make_tensor(tStS.iterator + self.tmem_vec_offset[stage], tStScale_layout) - for stage in range(2)) + tStScales = tuple( + cute.make_tensor(tStS.iterator + self.tmem_vec_offset[stage], tStScale_layout) + for stage in range(2) + ) tScScale = cute.composition(tScS, cute.make_layout((self.m_block_size, 1))) tmem_load_v_atom = cute.make_copy_atom( - tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(1)), self.qk_acc_dtype, + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(1)), + self.qk_acc_dtype, ) thr_tmem_load_vec = tcgen05.make_tmem_copy(tmem_load_v_atom, tStScales[0]).get_slice(tidx) @@ -1523,16 +1764,23 @@ def correction_loop( n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) # Ignore first signal from softmax as no correction is required - cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_full_offset + 0, softmax_corr_consumer_phase) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_full_offset + 0, softmax_corr_consumer_phase + ) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + 0) - cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_full_offset + 1, softmax_corr_consumer_phase) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_full_offset + 1, softmax_corr_consumer_phase + ) softmax_corr_consumer_phase ^= 1 tSrScale_t2r = cute.make_fragment(tSrScale_t2r_shape, Float32) for i in cutlass.range(n_block_max - n_block_min - 1, unroll=1): for stage in cutlass.range_constexpr(2): # wait for S0 / S1 - cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_full_offset + stage, softmax_corr_consumer_phase) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_full_offset + stage, + softmax_corr_consumer_phase, + ) # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r) # cute.arch.fence_view_async_tmem_load() # scale = tSrScale_t2r[0] @@ -1548,7 +1796,9 @@ def correction_loop( thr_mma_pv, tOtOs[stage if self.q_stage == 2 else 0], tidx, scale ) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + (1 - stage)) + cute.arch.mbarrier_arrive( + mbar_ptr + self.mbar_softmax_corr_empty_offset + (1 - stage) + ) softmax_corr_consumer_phase ^= 1 # o_corr_consumer_phase ^= 1 # End of seqlen_corr_loop_steps @@ -1566,10 +1816,15 @@ def correction_loop( learnable_sink_val = [sink_val] * self.q_stage else: # Each thread might have a different sink value due to different q_head for stage in cutlass.range_constexpr(self.q_stage): - q_head_idx = ((self.q_stage * m_block + stage) * self.m_block_size + tidx) % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead + q_head_idx = ( + (self.q_stage * m_block + stage) * self.m_block_size + tidx + ) % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead learnable_sink_val[stage] = Float32(learnable_sink[q_head_idx]) for stage in cutlass.range_constexpr(self.q_stage): - cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_full_offset + stage, softmax_corr_consumer_phase) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_full_offset + stage, + softmax_corr_consumer_phase, + ) # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r) # cute.arch.fence_view_async_tmem_load() # scale = tSrScale_t2r[0] @@ -1581,14 +1836,24 @@ def correction_loop( cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage) if const_expr(learnable_sink is not None): LOG2_E = math.log2(math.e) - row_sum += utils.exp2f(learnable_sink_val[stage] * LOG2_E - row_max * softmax_scale_log2) + row_sum += utils.exp2f( + learnable_sink_val[stage] * LOG2_E - row_max * softmax_scale_log2 + ) acc_O_mn_row_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum stats[stage] = (row_sum, row_max, acc_O_mn_row_is_zero_or_nan) scale = cute.arch.rcp_approx(row_sum if not acc_O_mn_row_is_zero_or_nan else 1.0) - cute.arch.mbarrier_wait(mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase) - cute.arch.mbarrier_wait(mbar_ptr + self.mbar_corr_epi_empty_offset + stage, corr_epi_producer_phase) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase + ) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_corr_epi_empty_offset + stage, corr_epi_producer_phase + ) self.correction_epilogue( - thr_mma_pv, tOtOs[stage], tidx, scale, sO[None, None, stage], + thr_mma_pv, + tOtOs[stage], + tidx, + scale, + sO[None, None, stage], ) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_full_offset + stage) # Signal for the next work tile that O buffers in tmem are already read, so @@ -1599,19 +1864,28 @@ def correction_loop( if const_expr(not seqlen.has_cu_seqlens_q): mLSE_cur = mLSE[None, head_idx, batch_idx] else: - offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) + offset = ( + seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) + ) mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx]) for stage in cutlass.range_constexpr(self.q_stage): - gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (self.q_stage * m_block + stage,)) + gLSE = cute.local_tile( + mLSE_cur, (self.m_block_size,), (self.q_stage * m_block + stage,) + ) row_sum, row_max, acc_O_mn_row_is_zero_or_nan = stats[stage] # if tidx == 0 and stage <= 1: # cute.printf("row_sum = {}, row_max = {}, acc_O_mn_row_is_zero_or_nan = {}\n", row_sum, row_max, acc_O_mn_row_is_zero_or_nan) LN2 = math.log(2.0) lse = ( (row_max * softmax_scale_log2 + utils.log2f(row_sum)) * LN2 - if not acc_O_mn_row_is_zero_or_nan else -Float32.inf + if not acc_O_mn_row_is_zero_or_nan + else -Float32.inf + ) + seqlen_q = ( + seqlen.seqlen_q + if const_expr(not self.pack_gqa) + else seqlen.seqlen_q * self.qhead_per_kvhead ) - seqlen_q = seqlen.seqlen_q if const_expr(not self.pack_gqa) else seqlen.seqlen_q * self.qhead_per_kvhead if tidx < seqlen_q - (self.q_stage * m_block + stage) * self.m_block_size: # This actually just works with PackGQA too gLSE[tidx] = lse @@ -1693,7 +1967,8 @@ def correction_rescale( cute.copy(thr_tmem_load, tOtO_t2r_i, tOrO_frg) for j in cutlass.range(0, cute.size(tOrO_frg), 2, unroll_full=True): tOrO_frg[j], tOrO_frg[j + 1] = cute.arch.mul_packed_f32x2( - (tOrO_frg[j], tOrO_frg[j + 1]), (scale, scale), + (tOrO_frg[j], tOrO_frg[j + 1]), + (scale, scale), ) tOtO_r2t_i = cute.make_tensor(tOtO_r2t.iterator + i * corr_tile_size, tOtO_r2t.layout) cute.copy(thr_tmem_store, tOrO_frg, tOtO_r2t_i) @@ -1748,7 +2023,9 @@ def correction_epilogue( epi_subtile, use_2cta_instrs=False, ) - tiled_tmem_load = tcgen05.make_tmem_copy(tmem_copy_atom, tOtO_i[(None, None), 0]).get_slice(tidx) + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_copy_atom, tOtO_i[(None, None), 0]).get_slice( + tidx + ) thr_tmem_load = tiled_tmem_load.get_slice(tidx) smem_copy_atom = sm100_utils_basic.get_smem_store_op( self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load @@ -1765,14 +2042,16 @@ def correction_epilogue( cute.copy(tiled_tmem_load, tOtO_t2r_i, tOrO_frg) for j in cutlass.range_constexpr(0, cute.size(tOrO_frg), 2): tOrO_frg[j], tOrO_frg[j + 1] = cute.arch.mul_packed_f32x2( - (tOrO_frg[j], tOrO_frg[j + 1]), (scale, scale), + (tOrO_frg[j], tOrO_frg[j + 1]), + (scale, scale), ) tOrO_frg_cvt = cute.make_fragment(tOrO_frg.shape, self.o_dtype) tOrO_frg_cvt.store(tOrO_frg.load().to(self.o_dtype)) cute.copy(tiled_smem_store, tOrO_frg_cvt, tOsO_r2s_i) # fence view async shared cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta, + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, ) @cute.jit @@ -1812,7 +2091,9 @@ def epilogue_s2g( cute.arch.cp_async_bulk_wait_group(1 - stage, read=True) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) else: - tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.epilogue_warp_ids)) + tidx = cute.arch.thread_idx()[0] % ( + cute.arch.WARP_SIZE * len(self.epilogue_warp_ids) + ) gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) tOsO = gmem_thr_copy_O.partition_S(sO) cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded)) @@ -1822,11 +2103,18 @@ def epilogue_s2g( tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) # TODO: the packgqa case isn't correct rn (sometimes IMA), disabling it assert not self.pack_gqa - pack_gqa = PackGQA(self.m_block_size, self.head_dim_v_padded, self.check_hdim_v_oob, self.qhead_per_kvhead) + pack_gqa = PackGQA( + self.m_block_size, + self.head_dim_v_padded, + self.check_hdim_v_oob, + self.qhead_per_kvhead, + ) for stage in cutlass.range_constexpr(self.q_stage): # wait from corr, issue tma store on smem # 1. wait for O0 / O1 final - cute.arch.mbarrier_wait(mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase + ) # 2. copy O0 / O1 to gmem # load acc O from smem to rmem for wider vectorization tOrO = cute.make_fragment_like(tOsO[None, None, None, 0], self.o_dtype) @@ -1834,15 +2122,29 @@ def epilogue_s2g( # copy acc O from rmem to gmem if const_expr(not self.pack_gqa): for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): - if t0OcO[0, rest_m, 0][0] < seqlen.seqlen_q - (self.q_stage * m_block + stage) * self.m_block_size - tOcO[0][0]: + if ( + t0OcO[0, rest_m, 0][0] + < seqlen.seqlen_q + - (self.q_stage * m_block + stage) * self.m_block_size + - tOcO[0][0] + ): cute.copy( gmem_tiled_copy_O, tOrO[None, rest_m, None], tOgO[None, rest_m, None, self.q_stage * m_block + stage], - pred=tOpO[None, rest_m, None] if self.check_hdim_v_oob else None, + pred=tOpO[None, rest_m, None] + if self.check_hdim_v_oob + else None, ) else: - pack_gqa.store_O(mO_cur, tOrO, gmem_tiled_copy_O, tidx, self.q_stage * m_block + stage, seqlen.seqlen_q) + pack_gqa.store_O( + mO_cur, + tOrO, + gmem_tiled_copy_O, + tidx, + self.q_stage * m_block + stage, + seqlen.seqlen_q, + ) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) # Advance to next tile @@ -1886,7 +2188,9 @@ def load_KV( if stage == 0: cute.arch.mbarrier_wait(mbar_empty_ptr + 1, phase) with cute.arch.elect_one(): - cute.arch.mbarrier_arrive_and_expect_tx(mbar_full_ptr + stage, self.tma_copy_bytes[K_or_V]) + cute.arch.mbarrier_arrive_and_expect_tx( + mbar_full_ptr + stage, self.tma_copy_bytes[K_or_V] + ) tXsX_cur = tXsX[None, stage] if const_expr(self.uneven_kv_smem): # Since this is the producer_state, the phase starts at 1, so we have to invert it @@ -1907,9 +2211,12 @@ def offset_kv_smem(self, sX: cute.Tensor, stage: Int32, phase: Int32): return sX def make_and_init_load_kv_pipeline(self, load_kv_mbar_ptr): - load_kv_producer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, len([self.load_warp_id]) + load_kv_producer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len([self.load_warp_id]) + ) + load_kv_consumer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) ) - load_kv_consumer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, len([self.mma_warp_id])) return cutlass.pipeline.PipelineTmaUmma.create( barrier_storage=load_kv_mbar_ptr, num_stages=self.kv_stage, @@ -1950,7 +2257,7 @@ def apply_score_mod( m_block, n_block, softmax, - buffers=None, + aux_tensors=None, fastdiv_mods=(None, None), ): """Apply score modification for SM100 (constant q_idx).""" @@ -1971,7 +2278,7 @@ def apply_score_mod( head_offset = q_physical - q_idx_logical * self.qhead_per_kvhead head_idx = head_idx * self.qhead_per_kvhead + head_offset - if cutlass.const_expr(buffers is not None): + if cutlass.const_expr(aux_tensors is not None): seqlen_q_divmod, _ = fastdiv_mods _, q_idx_logical = seqlen_q_divmod.divmod(q_idx_logical) @@ -1984,7 +2291,7 @@ def apply_score_mod( softmax.softmax_scale, self.vec_size, self.qk_acc_dtype, - buffers, + aux_tensors, fastdiv_mods, constant_q_idx=q_idx_logical, qhead_per_kvhead=self.qhead_per_kvhead if cutlass.const_expr(self.pack_gqa) else 1, diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 8c2e5903fc4..e3d2eb0891b 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -1,6 +1,7 @@ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. # [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0. # [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0. +# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0. # Supported features: # - BF16 & FP16 dtype @@ -51,6 +52,7 @@ def maybe_contiguous(x): torch.float32: cutlass.Float32, } + def _flash_attn_fwd( q: torch.Tensor, k: torch.Tensor, @@ -83,7 +85,7 @@ def _flash_attn_fwd( return_lse: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, - buffers: Optional[list[torch.Tensor]] = None, + aux_tensors: Optional[list[torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Forward pass for FlashAttention. @@ -93,7 +95,7 @@ def _flash_attn_fwd( return_lse: Whether to return the log softmax of the attention scores. If set to True will always calculate out: Optional pre-allocated output tensor. If None, will be allocated internally. lse: Optional pre-allocated log-sum-exp tensor. If None, will be allocated when needed. - buffers: Some score_mods will want to read from global buffers. This is how we thread them through to the inner kernel. + aux_tensors: Some score_mods will want to read from global aux_tensors. This is how we thread them through to the inner kernel. """ q, k, v = [maybe_contiguous(t) for t in (q, k, v)] num_head, head_dim = q.shape[-2:] @@ -127,34 +129,52 @@ def _flash_attn_fwd( else: assert k.shape == (seqlen_k, num_head_kv, head_dim) assert v.shape == (seqlen_k, num_head_kv, head_dim_v) - assert cu_seqlens_k.shape == (batch_size + 1,), "cu_seqlens_k must have shape (batch_size + 1,)" + assert cu_seqlens_k.shape == (batch_size + 1,), ( + "cu_seqlens_k must have shape (batch_size + 1,)" + ) if cu_seqlens_q is not None: - assert cu_seqlens_q.shape == (batch_size + 1,), "cu_seqlens_q must have shape (batch_size + 1,)" - assert seqused_q is None or seqused_q.shape == (batch_size,), "seqused_q must have shape (batch_size,)" - assert seqused_k is None or seqused_k.shape == (batch_size,), "seqused_k must have shape (batch_size,)" + assert cu_seqlens_q.shape == (batch_size + 1,), ( + "cu_seqlens_q must have shape (batch_size + 1,)" + ) + assert seqused_q is None or seqused_q.shape == (batch_size,), ( + "seqused_q must have shape (batch_size,)" + ) + assert seqused_k is None or seqused_k.shape == (batch_size,), ( + "seqused_k must have shape (batch_size,)" + ) assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16" assert q.dtype == k.dtype == v.dtype, "inputs must have the same dtype" for t in [cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k]: if t is not None: - assert t.dtype == torch.int32, "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be int32" - assert t.stride(0) == 1, "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be contiguous" + assert t.dtype == torch.int32, ( + "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be int32" + ) + assert t.stride(0) == 1, ( + "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be contiguous" + ) if learnable_sink is not None: assert learnable_sink.shape == (num_head,) assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16" for t in [full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx]: if t is not None: assert t.dtype == torch.int32, "blocksparse mask tensors must be int32" - assert t.stride(0) == 1, "blocksparse mask tensors must be contiguous" + # assert t.stride(0) == 1, "blocksparse mask tensors must be contiguous" assert all( t is None or t.is_cuda for t in ( - q, k, v, - cu_seqlens_q, cu_seqlens_k, - seqused_q, seqused_k, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, page_table, learnable_sink, - full_block_cnt, full_block_idx, - mask_block_cnt, mask_block_idx, + full_block_cnt, + full_block_idx, + mask_block_cnt, + mask_block_idx, ) ), "inputs must be on CUDA device" assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" @@ -177,20 +197,38 @@ def _flash_attn_fwd( requires_grad = q.requires_grad or k.requires_grad or v.requires_grad if out is None: - out = torch.empty(*q_batch_seqlen_shape, num_head, head_dim_v, dtype=out_torch_dtype, device=device) + out = torch.empty( + *q_batch_seqlen_shape, num_head, head_dim_v, dtype=out_torch_dtype, device=device + ) else: expected_out_shape = (*q_batch_seqlen_shape, num_head, head_dim_v) - assert out.shape == expected_out_shape, f"out tensor shape {out.shape} does not match expected shape {expected_out_shape}" - assert out.dtype == out_torch_dtype, f"out tensor dtype {out.dtype} does not match expected dtype {out_torch_dtype}" - assert out.device == device, f"out tensor device {out.device} does not match input device {device}" + assert out.shape == expected_out_shape, ( + f"out tensor shape {out.shape} does not match expected shape {expected_out_shape}" + ) + assert out.dtype == out_torch_dtype, ( + f"out tensor dtype {out.dtype} does not match expected dtype {out_torch_dtype}" + ) + assert out.device == device, ( + f"out tensor device {out.device} does not match input device {device}" + ) assert out.is_cuda, "out tensor must be on CUDA device" if lse is None: - lse = torch.empty(lse_shape, dtype=torch.float32, device=device) if requires_grad or return_lse else None + lse = ( + torch.empty(lse_shape, dtype=torch.float32, device=device) + if requires_grad or return_lse + else None + ) elif lse is not None: - assert lse.shape == lse_shape, f"lse tensor shape {lse.shape} does not match expected shape {lse_shape}" - assert lse.dtype == torch.float32, f"lse tensor dtype {lse.dtype} does not match expected dtype torch.float32" - assert lse.device == device, f"lse tensor device {lse.device} does not match input device {device}" + assert lse.shape == lse_shape, ( + f"lse tensor shape {lse.shape} does not match expected shape {lse_shape}" + ) + assert lse.dtype == torch.float32, ( + f"lse tensor dtype {lse.dtype} does not match expected dtype torch.float32" + ) + assert lse.device == device, ( + f"lse tensor device {lse.device} does not match input device {device}" + ) assert lse.is_cuda, "lse tensor must be on CUDA device" dtype = torch2cute_dtype_map[q.dtype] @@ -198,82 +236,156 @@ def _flash_attn_fwd( from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) for t in (q, k, v, out) ] - lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 1) if lse is not None else None - cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, learnable_sink_tensor = [ - from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) if t is not None else None + lse_tensor = ( + from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 1) + if lse is not None + else None + ) + ( + cu_seqlens_q_tensor, + cu_seqlens_k_tensor, + seqused_q_tensor, + seqused_k_tensor, + learnable_sink_tensor, + ) = [ + from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) + if t is not None + else None for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink) ] - page_table_tensor = from_dlpack(page_table.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=1) if page_table is not None else None - - full_block_cnt_tensor = from_dlpack(full_block_cnt.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=2) if full_block_cnt is not None else None - full_block_idx_tensor = from_dlpack(full_block_idx.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=3) if full_block_idx is not None else None - mask_block_cnt_tensor = from_dlpack(mask_block_cnt.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=2) if mask_block_cnt is not None else None - mask_block_idx_tensor = from_dlpack(mask_block_idx.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=3) if mask_block_idx is not None else None - - - if causal: - window_size_right = 0 - local = window_size_left is not None or window_size_right is not None - if window_size_left is not None or window_size_right is not None: - if window_size_left is None and window_size_right == 0: - causal, local = True, False - else: - causal, local = False, True - compute_capability = torch.cuda.get_device_capability()[0] if _compute_capability is None else _compute_capability + page_table_tensor = ( + from_dlpack(page_table.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=1) + if page_table is not None + else None + ) + + full_block_cnt_tensor = ( + from_dlpack(full_block_cnt.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=2) + if full_block_cnt is not None + else None + ) + full_block_idx_tensor = ( + from_dlpack(full_block_idx.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=3) + if full_block_idx is not None + else None + ) + mask_block_cnt_tensor = ( + from_dlpack(mask_block_cnt.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=2) + if mask_block_cnt is not None + else None + ) + mask_block_idx_tensor = ( + from_dlpack(mask_block_idx.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=3) + if mask_block_idx is not None + else None + ) + use_block_sparsity = full_block_cnt is not None or mask_block_cnt is not None + + if mask_mod is None: + if causal: + window_size_right = 0 + local = window_size_left is not None or window_size_right is not None + if window_size_left is not None or window_size_right is not None: + if window_size_left is None and window_size_right == 0: + causal, local = True, False + else: + causal, local = False, True + else: + causal, local = False, False + + compute_capability = ( + torch.cuda.get_device_capability()[0] + if _compute_capability is None + else _compute_capability + ) assert compute_capability in [9, 10], "Unsupported compute capability. Supported: 9.x, 10.x" current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - if compute_capability == 9: # TODO: tune block size according to hdim - if head_dim == head_dim_v == 128 and not causal and not local: + if compute_capability == 9: # TODO: tune block size according to hdim. + if head_dim == head_dim_v == 128 and not causal and not local and not use_block_sparsity: n_block_size = 192 if compute_capability == 10: # TODO: fix the varlen case - if pack_gqa and (128 % qhead_per_kvhead != 0) or (cu_seqlens_q is not None or seqused_q is not None): + if ( + pack_gqa + and (128 % qhead_per_kvhead != 0) + or (cu_seqlens_q is not None or seqused_q is not None) + ): pack_gqa = False - + # hash score and mask mods for compile cache - score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else None - mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else None - + score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False + mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False + + print(mask_mod_hash) + if softcap is not None: assert score_mod is None, "softcap and score_mod cannot be used together" score_mod = utils.create_softcap_scoremod(softcap) - is_varlen = cu_seqlens_q is not None or cu_seqlens_k is not None or seqused_q is not None or seqused_k is not None - use_block_sparsity = full_block_cnt is not None or mask_block_cnt is not None + is_varlen = ( + cu_seqlens_q is not None + or cu_seqlens_k is not None + or seqused_q is not None + or seqused_k is not None + ) if score_mod is not None: if is_varlen: - raise NotImplementedError("score_mod with buffers is not yet supported for varlen sequences. This will be fixed in a future PR.") - if pack_gqa: - raise NotImplementedError("score_mod with buffers is not yet supported with pack_gqa=True. This will be fixed in a future PR.") + raise NotImplementedError( + "score_mod with aux_tensors is not yet supported for varlen sequences. This will be fixed in a future PR." + ) if mask_mod is not None: if not use_block_sparsity: - raise NotImplementedError("mask_mod requires the use of block sparsity. This will be fixed in a future PR.") + raise NotImplementedError( + "mask_mod requires the use of block sparsity. This will be fixed in a future PR." + ) if is_varlen: - raise NotImplementedError("mask_mod with buffers is not yet supported for varlen sequences. This will be fixed in a future PR.") + raise NotImplementedError( + "mask_mod with aux_tensors is not yet supported for varlen sequences. This will be fixed in a future PR." + ) if pack_gqa: - raise NotImplementedError("mask_mod with buffers is not yet supported with pack_gqa=True. This will be fixed in a future PR.") - + raise NotImplementedError( + "mask_mod with aux_tensors is not yet supported with pack_gqa=True. This will be fixed in a future PR." + ) + if use_block_sparsity: if is_varlen: - raise NotImplementedError("Block sparsity is not yet supported for varlen sequences. This will be fixed in a future PR.") + raise NotImplementedError( + "Block sparsity is not yet supported for varlen sequences. This will be fixed in a future PR." + ) if pack_gqa: - raise NotImplementedError("Block sparsity is not yet supported with pack_gqa=True. This will be fixed in a future PR.") - - cute_buffers = None - if buffers is not None: - cute_buffers = [from_dlpack(buf) for buf in buffers] + raise NotImplementedError( + "Block sparsity is not yet supported with pack_gqa=True. This will be fixed in a future PR." + ) + + cute_aux_tensors = None + if aux_tensors is not None: + cute_aux_tensors = [from_dlpack(buf) for buf in aux_tensors] compile_key = ( - dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, - score_mod_hash, mask_mod_hash, - buffers is not None, - lse is None, cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, seqused_k is None, + dtype, + head_dim, + head_dim_v, + qhead_per_kvhead, + causal, + score_mod_hash, + mask_mod_hash, + use_block_sparsity, + aux_tensors is not None, + lse is None, + cu_seqlens_q is None, + cu_seqlens_k is None, + seqused_q is None, + seqused_k is None, page_table is not None, - window_size_left is not None, window_size_right is not None, + window_size_left is not None, + window_size_right is not None, learnable_sink is not None, - m_block_size, n_block_size, num_threads, pack_gqa, + m_block_size, + n_block_size, + num_threads, + pack_gqa, compute_capability, ) @@ -299,10 +411,12 @@ def _flash_attn_fwd( mma_pv_is_rs=True, mask_mod=mask_mod, score_mod=score_mod, - has_buffers=buffers is not None, + has_aux_tensors=aux_tensors is not None, ) elif compute_capability == 10: - assert page_size in [None, 128], "Only page_size=128 is supported for paged KV on SM 10.0" + assert page_size in [None, 128], ( + "Only page_size=128 is supported for paged KV on SM 10.0" + ) fa_fwd = FlashAttentionForwardSm100( head_dim, head_dim_v, @@ -310,34 +424,69 @@ def _flash_attn_fwd( is_causal=causal, is_local=local, pack_gqa=pack_gqa, - is_persistent=not causal and not local and cu_seqlens_q is None and seqused_q is None, + is_persistent=not causal + and not local + and cu_seqlens_q is None + and seqused_q is None, score_mod=score_mod, - has_buffers=buffers is not None, + has_aux_tensors=aux_tensors is not None, ) else: - raise ValueError(f"Unsupported compute capability: {compute_capability}. Supported: 9.x, 10.x") + raise ValueError( + f"Unsupported compute capability: {compute_capability}. Supported: 9.x, 10.x" + ) # TODO: check @can_implement _flash_attn_fwd.compile_cache[compile_key] = cute.compile( - fa_fwd, q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream, - cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, + fa_fwd, + q_tensor, + k_tensor, + v_tensor, + o_tensor, + lse_tensor, + softmax_scale, + current_stream, + cu_seqlens_q_tensor, + cu_seqlens_k_tensor, + seqused_q_tensor, + seqused_k_tensor, page_table_tensor, - window_size_left, window_size_right, learnable_sink_tensor, - full_block_cnt_tensor, full_block_idx_tensor, mask_block_cnt_tensor, mask_block_idx_tensor, - buffers=cute_buffers, + window_size_left, + window_size_right, + learnable_sink_tensor, + full_block_cnt_tensor, + full_block_idx_tensor, + mask_block_cnt_tensor, + mask_block_idx_tensor, + cute_aux_tensors, ) _flash_attn_fwd.compile_cache[compile_key]( - q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream, - cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, + q_tensor, + k_tensor, + v_tensor, + o_tensor, + lse_tensor, + softmax_scale, + current_stream, + cu_seqlens_q_tensor, + cu_seqlens_k_tensor, + seqused_q_tensor, + seqused_k_tensor, page_table_tensor, - window_size_left, window_size_right, learnable_sink_tensor, - full_block_cnt_tensor, full_block_idx_tensor, mask_block_cnt_tensor, mask_block_idx_tensor, - buffers=cute_buffers, + window_size_left, + window_size_right, + learnable_sink_tensor, + full_block_cnt_tensor, + full_block_idx_tensor, + mask_block_cnt_tensor, + mask_block_idx_tensor, + cute_aux_tensors, ) return out, lse _flash_attn_fwd.compile_cache = {} + def _flash_attn_bwd( q: torch.Tensor, k: torch.Tensor, @@ -407,10 +556,14 @@ def _flash_attn_bwd( else: assert k.shape == (total_k, num_head_kv, head_dim) assert v.shape == (total_k, num_head_kv, head_dim_v) - assert cu_seqlens_k.shape == (batch_size + 1,), "cu_seqlens_k must have shape (batch_size + 1,)" + assert cu_seqlens_k.shape == (batch_size + 1,), ( + "cu_seqlens_k must have shape (batch_size + 1,)" + ) if cu_seqlens_q is not None: - assert cu_seqlens_q.shape == (batch_size + 1,), "cu_seqlens_q must have shape (batch_size + 1,)" + assert cu_seqlens_q.shape == (batch_size + 1,), ( + "cu_seqlens_q must have shape (batch_size + 1,)" + ) assert out.shape == (total_q, num_head, head_dim_v) assert dout.shape == (total_q, num_head, head_dim_v) @@ -418,15 +571,21 @@ def _flash_attn_bwd( else: assert out.shape == (batch_size, seqlen_q, num_head, head_dim_v) assert dout.shape == (batch_size, seqlen_q, num_head, head_dim_v) - assert lse.shape == (batch_size, num_head, seqlen_q), "lse must have shape (batch_size, num_head, seqlen_q)" + assert lse.shape == (batch_size, num_head, seqlen_q), ( + "lse must have shape (batch_size, num_head, seqlen_q)" + ) assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16" - assert q.dtype == k.dtype == v.dtype == out.dtype == dout.dtype, "inputs must have the same dtype" + assert q.dtype == k.dtype == v.dtype == out.dtype == dout.dtype, ( + "inputs must have the same dtype" + ) for t in [cu_seqlens_q, cu_seqlens_k]: if t is not None: assert t.dtype == torch.int32, "cu_seqlens_q, cu_seqlens_k must be int32" assert lse.dtype == torch.float32, "lse must be float32" - assert all(t is None or t.is_cuda for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k)), "inputs must be on CUDA device" + assert all( + t is None or t.is_cuda for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k) + ), "inputs must be on CUDA device" assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" assert head_dim <= 256, "head_dim must be less than or equal to 256" alignment = 16 // q.element_size() @@ -448,12 +607,26 @@ def _flash_attn_bwd( if cu_seqlens_q is None: seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size - dq_accum = torch.empty(batch_size, num_head, seqlen_q_rounded * head_dim_rounded, dtype=torch.float32, device=device) - dpsum = torch.empty(batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device) - lse_log2 = torch.empty(batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device) + dq_accum = torch.empty( + batch_size, + num_head, + seqlen_q_rounded * head_dim_rounded, + dtype=torch.float32, + device=device, + ) + dpsum = torch.empty( + batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device + ) + lse_log2 = torch.empty( + batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device + ) else: - total_q_rounded_padded = (total_q + cu_seqlens_q.shape[0] * m_block_size - 1) // m_block_size * m_block_size - dq_accum = torch.empty(num_head, total_q_rounded_padded * head_dim_rounded, dtype=torch.float32, device=device) + total_q_rounded_padded = ( + (total_q + cu_seqlens_q.shape[0] * m_block_size - 1) // m_block_size * m_block_size + ) + dq_accum = torch.empty( + num_head, total_q_rounded_padded * head_dim_rounded, dtype=torch.float32, device=device + ) dpsum = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device) lse_log2 = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device) @@ -461,19 +634,45 @@ def _flash_attn_bwd( head_dim_v_rounded = (head_dim_v + 32 - 1) // 32 * 32 if cu_seqlens_k is None: seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size - dk_accum = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded * head_dim_rounded, dtype=torch.float32, device=device) - dv_accum = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded * head_dim_v_rounded, dtype=torch.float32, device=device) + dk_accum = torch.zeros( + batch_size, + num_head_kv, + seqlen_k_rounded * head_dim_rounded, + dtype=torch.float32, + device=device, + ) + dv_accum = torch.zeros( + batch_size, + num_head_kv, + seqlen_k_rounded * head_dim_v_rounded, + dtype=torch.float32, + device=device, + ) else: - total_k_rounded_padded = (total_k + cu_seqlens_k.shape[0] * n_block_size - 1) // n_block_size * n_block_size - dk_accum = torch.zeros(num_head_kv, total_k_rounded_padded * head_dim_rounded, dtype=torch.float32, device=device) - dv_accum = torch.zeros(num_head_kv, total_k_rounded_padded * head_dim_v_rounded, dtype=torch.float32, device=device) + total_k_rounded_padded = ( + (total_k + cu_seqlens_k.shape[0] * n_block_size - 1) // n_block_size * n_block_size + ) + dk_accum = torch.zeros( + num_head_kv, + total_k_rounded_padded * head_dim_rounded, + dtype=torch.float32, + device=device, + ) + dv_accum = torch.zeros( + num_head_kv, + total_k_rounded_padded * head_dim_v_rounded, + dtype=torch.float32, + device=device, + ) dtype = torch2cute_dtype_map[q.dtype] q_tensor, k_tensor, v_tensor, o_tensor, do_tensor, dq_tensor, dk_tensor, dv_tensor = [ from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) for t in (q, k, v, out, dout, dq, dk, dv) ] - lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 1) + lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=lse.ndim - 1 + ) dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [ from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) for t in (dq_accum, dpsum, lse_log2) @@ -484,7 +683,9 @@ def _flash_attn_bwd( for t in (dk_accum, dv_accum) ] cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor = [ - from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=t.ndim-1) if t is not None else None + from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=t.ndim - 1) + if t is not None + else None for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) ] current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) @@ -493,23 +694,57 @@ def _flash_attn_bwd( compile_key_pre = (dtype, head_dim_v, m_block_size, num_threads) if compile_key_pre not in _flash_attn_bwd.compile_cache_pre: fa_bwd_pre = FlashAttentionBackwardPreprocess( - dtype, head_dim_v, m_block_size, num_threads=num_threads, + dtype, + head_dim_v, + m_block_size, + num_threads=num_threads, ) # TODO: check @can_implement _flash_attn_bwd.compile_cache_pre[compile_key_pre] = cute.compile( - fa_bwd_pre, o_tensor, do_tensor, dpsum_tensor, lse_tensor, lse_log2_tensor, - dq_accum_tensor, cu_seqlens_q_tensor, seqused_q_tensor, current_stream + fa_bwd_pre, + o_tensor, + do_tensor, + dpsum_tensor, + lse_tensor, + lse_log2_tensor, + dq_accum_tensor, + cu_seqlens_q_tensor, + seqused_q_tensor, + current_stream, ) _flash_attn_bwd.compile_cache_pre[compile_key_pre]( - o_tensor, do_tensor, dpsum_tensor, lse_tensor, lse_log2_tensor, dq_accum_tensor, - cu_seqlens_q_tensor, seqused_q_tensor, current_stream + o_tensor, + do_tensor, + dpsum_tensor, + lse_tensor, + lse_log2_tensor, + dq_accum_tensor, + cu_seqlens_q_tensor, + seqused_q_tensor, + current_stream, ) # Backward kernel: compute dk, dv, dq_accum. compile_key = ( - dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, softcap != 0.0, m_block_size, - n_block_size, num_threads, pack_gqa, num_stages_Q, num_stages_dO, SdP_swapAB, dKV_swapAB, dQ_swapAB, - AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs + dtype, + head_dim, + head_dim_v, + qhead_per_kvhead, + causal, + softcap != 0.0, + m_block_size, + n_block_size, + num_threads, + pack_gqa, + num_stages_Q, + num_stages_dO, + SdP_swapAB, + dKV_swapAB, + dQ_swapAB, + AtomLayoutMSdP, + AtomLayoutNdKV, + AtomLayoutMdQ, + V_in_regs, ) num_threads = 384 if compile_key not in _flash_attn_bwd.compile_cache: @@ -557,7 +792,12 @@ def _flash_attn_bwd( _flash_attn_bwd.compile_cache[compile_key] = cute.compile( # fa_bwd_sm80, fa_bwd_sm90, - q_tensor, k_tensor, v_tensor, do_tensor, lse_log2_tensor, dpsum_tensor, + q_tensor, + k_tensor, + v_tensor, + do_tensor, + lse_log2_tensor, + dpsum_tensor, dq_accum_tensor, dk_tensor if qhead_per_kvhead == 1 else dk_accum_tensor, dv_tensor if qhead_per_kvhead == 1 else dv_accum_tensor, @@ -569,7 +809,12 @@ def _flash_attn_bwd( seqused_k_tensor, ) _flash_attn_bwd.compile_cache[compile_key]( - q_tensor, k_tensor, v_tensor, do_tensor, lse_log2_tensor, dpsum_tensor, + q_tensor, + k_tensor, + v_tensor, + do_tensor, + lse_log2_tensor, + dpsum_tensor, dq_accum_tensor, dk_tensor if qhead_per_kvhead == 1 else dk_accum_tensor, dv_tensor if qhead_per_kvhead == 1 else dv_accum_tensor, @@ -591,11 +836,21 @@ def _flash_attn_bwd( ) # TODO: check @can_implement _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile( - fa_bwd_post, dq_accum_tensor, dq_tensor, softmax_scale, cu_seqlens_q_tensor, - seqused_q_tensor, current_stream + fa_bwd_post, + dq_accum_tensor, + dq_tensor, + softmax_scale, + cu_seqlens_q_tensor, + seqused_q_tensor, + current_stream, ) _flash_attn_bwd.compile_cache_post[compile_key_post]( - dq_accum_tensor, dq_tensor, softmax_scale, cu_seqlens_q_tensor, seqused_q_tensor, current_stream + dq_accum_tensor, + dq_tensor, + softmax_scale, + cu_seqlens_q_tensor, + seqused_q_tensor, + current_stream, ) if qhead_per_kvhead > 1: @@ -607,22 +862,51 @@ def _flash_attn_bwd( ) # TODO: check @can_implement _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile( - fa_bwd_post, dk_accum_tensor, dk_tensor, softmax_scale, cu_seqlens_k_tensor, seqused_k_tensor, current_stream + fa_bwd_post, + dk_accum_tensor, + dk_tensor, + softmax_scale, + cu_seqlens_k_tensor, + seqused_k_tensor, + current_stream, ) _flash_attn_bwd.compile_cache_post[compile_key_post]( - dk_accum_tensor, dk_tensor, softmax_scale, cu_seqlens_k_tensor, seqused_k_tensor, current_stream + dk_accum_tensor, + dk_tensor, + softmax_scale, + cu_seqlens_k_tensor, + seqused_k_tensor, + current_stream, + ) + compile_key_post = ( + dtype, + head_dim_v, + n_block_size, + num_threads, + AtomLayoutNdKV, + dKV_swapAB, ) - compile_key_post = (dtype, head_dim_v, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB) if compile_key_post not in _flash_attn_bwd.compile_cache_post: fa_bwd_post = FlashAttentionBackwardPostprocess( dtype, head_dim_v, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB ) # TODO: check @can_implement _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile( - fa_bwd_post, dv_accum_tensor, dv_tensor, cutlass.Float32(1.0), cu_seqlens_k_tensor, seqused_k_tensor, current_stream + fa_bwd_post, + dv_accum_tensor, + dv_tensor, + cutlass.Float32(1.0), + cu_seqlens_k_tensor, + seqused_k_tensor, + current_stream, ) _flash_attn_bwd.compile_cache_post[compile_key_post]( - dv_accum_tensor, dv_tensor, cutlass.Float32(1.0), cu_seqlens_k_tensor, seqused_k_tensor, current_stream + dv_accum_tensor, + dv_tensor, + cutlass.Float32(1.0), + cu_seqlens_k_tensor, + seqused_k_tensor, + current_stream, ) return dq, dk, dv @@ -634,7 +918,6 @@ def _flash_attn_bwd( class FlashAttnFunc(torch.autograd.Function): - @staticmethod def forward( ctx, @@ -695,7 +978,6 @@ def backward(ctx, dout, *args): class FlashAttnVarlenFunc(torch.autograd.Function): - @staticmethod def forward( ctx, @@ -864,7 +1146,9 @@ def _flash_attn_fwd_combine( # Input validation assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions" assert lse_partial.dim() in [3, 4], "lse_partial must have 3 or 4 dimensions" - assert out_partial.dtype in [torch.float16, torch.bfloat16, torch.float32], "out_partial must be fp16, bf16, or fp32" + assert out_partial.dtype in [torch.float16, torch.bfloat16, torch.float32], ( + "out_partial must be fp16, bf16, or fp32" + ) assert lse_partial.dtype == torch.float32, "lse_partial must be fp32" assert out_partial.is_cuda and lse_partial.is_cuda, "tensors must be on CUDA device" assert out_partial.stride(-1) == 1, "out_partial must be contiguous in the last dimension" @@ -881,7 +1165,11 @@ def _flash_attn_fwd_combine( assert lse.dtype == torch.float32, "lse must be fp32" # Validate optional tensors - for t, name in [(cu_seqlens, "cu_seqlens"), (seqused, "seqused"), (num_splits_dynamic_ptr, "num_splits_dynamic_ptr")]: + for t, name in [ + (cu_seqlens, "cu_seqlens"), + (seqused, "seqused"), + (num_splits_dynamic_ptr, "num_splits_dynamic_ptr"), + ]: if t is not None: assert t.dtype == torch.int32, f"{name} must be int32" assert t.is_cuda, f"{name} must be on CUDA device" @@ -903,16 +1191,28 @@ def _flash_attn_fwd_combine( log_max_splits = max(log_max_splits, 5) # Convert to cute tensors (using kernel-formatted tensors) - out_partial_tensor = from_dlpack(out_partial.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=4) - lse_partial_tensor = from_dlpack(lse_partial.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse_partial.ndim - 2) + out_partial_tensor = from_dlpack(out_partial.detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=4 + ) + lse_partial_tensor = from_dlpack(lse_partial.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=lse_partial.ndim - 2 + ) out_tensor = from_dlpack(out.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=3) - lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 2) if lse is not None else None + lse_tensor = ( + from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 2) + if lse is not None + else None + ) optional_tensors = [ - from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) if t is not None else None + from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) + if t is not None + else None for t in (cu_seqlens, seqused, num_splits_dynamic_ptr, semaphore_to_reset) ] - cu_seqlens_tensor, seqused_tensor, num_splits_dynamic_tensor, semaphore_tensor = optional_tensors + cu_seqlens_tensor, seqused_tensor, num_splits_dynamic_tensor, semaphore_tensor = ( + optional_tensors + ) current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) @@ -921,9 +1221,15 @@ def _flash_attn_fwd_combine( dtype_partial = torch2cute_dtype_map[out_partial.dtype] compile_key = ( - dtype, dtype_partial, head_dim, m_block_size, k_block_size, + dtype, + dtype_partial, + head_dim, + m_block_size, + k_block_size, log_max_splits, - cu_seqlens is not None, seqused is not None, lse is not None, + cu_seqlens is not None, + seqused is not None, + lse is not None, ) if compile_key not in _flash_attn_fwd_combine.compile_cache: @@ -938,9 +1244,17 @@ def _flash_attn_fwd_combine( # Check if implementation is supported if not fa_combine.can_implement( - dtype, dtype_partial, head_dim, m_block_size, k_block_size, log_max_splits, num_threads=256 + dtype, + dtype_partial, + head_dim, + m_block_size, + k_block_size, + log_max_splits, + num_threads=256, ): - raise RuntimeError(f"FlashAttention combine kernel cannot be implemented with given parameters") + raise RuntimeError( + f"FlashAttention combine kernel cannot be implemented with given parameters" + ) _flash_attn_fwd_combine.compile_cache[compile_key] = cute.compile( fa_combine, @@ -952,7 +1266,7 @@ def _flash_attn_fwd_combine( seqused_tensor, num_splits_dynamic_tensor, semaphore_tensor, - current_stream + current_stream, ) _flash_attn_fwd_combine.compile_cache[compile_key]( @@ -964,7 +1278,7 @@ def _flash_attn_fwd_combine( seqused_tensor, num_splits_dynamic_tensor, semaphore_tensor, - current_stream + current_stream, ) @@ -1019,13 +1333,17 @@ def flash_attn_combine( if is_varlen: # Variable length: (num_splits, total_q, num_heads, head_size) num_splits, total_q, num_heads, head_size = out_partial.shape - assert lse_partial.shape == (num_splits, total_q, num_heads), "lse_partial shape mismatch for varlen" + assert lse_partial.shape == (num_splits, total_q, num_heads), ( + "lse_partial shape mismatch for varlen" + ) batch_size = 1 # Treat as single batch for varlen seqlen = total_q else: # Regular batched: (num_splits, batch_size, seqlen, num_heads, head_size) num_splits, batch_size, seqlen, num_heads, head_size = out_partial.shape - assert lse_partial.shape == (num_splits, batch_size, seqlen, num_heads), "lse_partial shape mismatch" + assert lse_partial.shape == (num_splits, batch_size, seqlen, num_heads), ( + "lse_partial shape mismatch" + ) # Determine output dtype if out_dtype is None: @@ -1037,14 +1355,20 @@ def flash_attn_combine( if is_varlen: out = torch.empty(total_q, num_heads, head_size, dtype=out_dtype, device=device) else: - out = torch.empty(batch_size, seqlen, num_heads, head_size, dtype=out_dtype, device=device) + out = torch.empty( + batch_size, seqlen, num_heads, head_size, dtype=out_dtype, device=device + ) # Create lse output only if requested if return_lse: if is_varlen: - lse = torch.empty(num_heads, total_q, dtype=torch.float32, device=device).transpose(0, 1) + lse = torch.empty(num_heads, total_q, dtype=torch.float32, device=device).transpose( + 0, 1 + ) else: - lse = torch.empty(batch_size, num_heads, seqlen, dtype=torch.float32, device=device).transpose(1, 2) + lse = torch.empty( + batch_size, num_heads, seqlen, dtype=torch.float32, device=device + ).transpose(1, 2) else: lse = None diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index 0d78eb9e948..7b830f42c4e 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -9,6 +9,7 @@ import flash_attn.cute.utils as utils + @cute.jit def mask_r2p(X: cute.Tensor, col_limit: Int32, arch: int = 90, rank1: bool = False) -> None: # Bit manipulation, compiles down to the R2P instruction @@ -38,6 +39,7 @@ def mask_r2p(X: cute.Tensor, col_limit: Int32, arch: int = 90, rank1: bool = Fal for r in cutlass.range_constexpr(cute.size(X.shape[0])): X[r, c] = X[r, c] if in_bound else -Float32.inf + @dataclass(frozen=True) class AttentionMask: tile_m: cutlass.Constexpr[int] @@ -62,7 +64,7 @@ def apply_mask( mask_causal: cutlass.Constexpr[bool], mask_local: cutlass.Constexpr[bool] = False, mask_mod: cutlass.Constexpr[Optional[Callable]] = None, - buffers: Optional[list[cute.Tensor]] = None, + aux_tensors: Optional[list] = None, ) -> None: assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" acc_S_mn = utils.make_acc_tensor_mn_view(acc_S, transpose=self.swap_AB) @@ -90,20 +92,22 @@ def apply_mask( acc_S_mn[r, c] = -Float32.inf if oob else acc_S_mn[r, c] else: mask_r2p(acc_S_mn, seqlenk_col_limit, arch=90) - - elif const_expr(not mask_causal and not mask_local and mask_mod is not None): # FlexAttention mask mod + + elif const_expr( + not mask_causal and not mask_local and mask_mod is not None + ): # FlexAttention mask mod nrow = const_expr(cute.size(tScS_mn.shape[0])) ncol = const_expr(cute.size(tScS_mn.shape[1])) thr_col_offset = tScS_mn[0, 0][1] - + for r in cutlass.range_constexpr(nrow): global_row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m - + for col in cutlass.range_constexpr(ncol): col_idx_local = t0ScS_mn[0, col][1] # Convert to absolute column index global_col_idx = thr_col_offset + col_idx_local + n_block * self.tile_n - + cond = cutlass.Boolean( mask_mod( batch_idx, @@ -112,7 +116,7 @@ def apply_mask( thr_col_offset + t0ScS_mn[0, col][1] + n_block * self.tile_n, self.seqlen_q, self.seqlen_k, - buffers, + aux_tensors, ) ) if const_expr(mask_seqlen): @@ -126,7 +130,6 @@ def apply_mask( else: acc_S_mn[r, col] = acc_S_mn[r, col] if cond else -cutlass.Float32.inf - else: # Causal or local if const_expr(not self.swap_AB): # If PackGQA, we split the work of compute divmod among threads in the same row @@ -321,12 +324,11 @@ def apply_mask_sm100( else acc_S[i] ) - @cute.jit def apply_mask_sm100_transposed( self, acc_S: cute.Tensor, - tScS_t2r : cute.Tensor, + tScS_t2r: cute.Tensor, m_block: cutlass.Int32, n_block: cutlass.Int32, wg_idx: cutlass.Int32, @@ -335,9 +337,9 @@ def apply_mask_sm100_transposed( mask_causal: cutlass.Constexpr, mask_local: cutlass.Constexpr, ) -> None: - ''' + """ Backward pass: mask S = K @ Q.T where n_block tiles seqlen_k and m_block tiles seqlen_q. - ''' + """ assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" tidx = cute.arch.thread_idx()[0] % 128 @@ -352,7 +354,7 @@ def apply_mask_sm100_transposed( else: # Causal or local causal_row_offset = (self.seqlen_q - self.seqlen_k - 1) - m_block * self.tile_m row_idx = tScS_t2r[0][0] + n_block * self.tile_n - + if const_expr(mask_causal): col_limit_left = row_idx + causal_row_offset ncol = const_expr(cute.size(tScS_t2r.shape)) @@ -365,4 +367,4 @@ def apply_mask_sm100_transposed( acc_S[i] = ( -cutlass.Float32.inf if tScS_t2r[i][1] <= col_limit_left else acc_S[i] ) - # TODO: local \ No newline at end of file + # TODO: local diff --git a/flash_attn/cute/mask_definitions.py b/flash_attn/cute/mask_definitions.py index 6b206fd6026..23c4f026b1c 100644 --- a/flash_attn/cute/mask_definitions.py +++ b/flash_attn/cute/mask_definitions.py @@ -1,7 +1,7 @@ from typing import Callable, Optional import random -import math +import math import cutlass import cutlass.cute as cute @@ -10,7 +10,14 @@ MaskModCallable = Optional[ Callable[ - ["cutlass.Int32", "cutlass.Int32", "cutlass.Int32", "cutlass.Int32", "cutlass.Int32", "cutlass.Int32"], + [ + "cutlass.Int32", + "cutlass.Int32", + "cutlass.Int32", + "cutlass.Int32", + "cutlass.Int32", + "cutlass.Int32", + ], "cutlass.Boolean", ] ] @@ -49,12 +56,14 @@ def flex_block_causal_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None): def create_flex_sliding_window_mask(window_size=1024): """Factory function to create a sliding window mask with configurable window size""" + def flex_sliding_window_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None): # Sliding window: q_idx - window_size <= kv_idx <= q_idx if seqlen_q is not None and seqlen_k is not None: offset = seqlen_k - seqlen_q return (kv_idx <= q_idx + offset) & (kv_idx >= q_idx + offset - window_size) return (kv_idx <= q_idx) & (kv_idx >= q_idx - window_size) + return flex_sliding_window_mask @@ -83,32 +92,49 @@ def flex_half_identity_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None): return torch.ones_like(kv_idx, dtype=torch.bool) return True + def flex_document_mask(b, h, q_idx, kv_idx, doc_id: torch.Tensor): return doc_id[b, h, q_idx] == doc_id[b, h, kv_idx] + # CuTe versions for kernel compilation @cute.jit def cute_identity_mask( - batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, - seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers: None, + batch: cutlass.Int32, + head: cutlass.Int32, + m_idx: cutlass.Int32, + n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, + seqlen_k: cutlass.Int32, + aux_tensors: None, ) -> cutlass.Boolean: return cutlass.Boolean(True) @cute.jit def cute_identity_partial_mask( - batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, - seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers: None, + batch: cutlass.Int32, + head: cutlass.Int32, + m_idx: cutlass.Int32, + n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, + seqlen_k: cutlass.Int32, + aux_tensors: None, ) -> cutlass.Boolean: return cutlass.Boolean(True) @cute.jit def cute_causal_mask( - batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, - seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers: None, + batch: cutlass.Int32, + head: cutlass.Int32, + m_idx: cutlass.Int32, + n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, + seqlen_k: cutlass.Int32, + aux_tensors: None, ) -> cutlass.Boolean: # Right-aligned causal masking offset = seqlen_k - seqlen_q @@ -117,8 +143,13 @@ def cute_causal_mask( @cute.jit def cute_block_causal_mask( - batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, - seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers: None, + batch: cutlass.Int32, + head: cutlass.Int32, + m_idx: cutlass.Int32, + n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, + seqlen_k: cutlass.Int32, + aux_tensors: None, ) -> cutlass.Boolean: # Right-aligned causal masking offset = seqlen_k - seqlen_q @@ -127,22 +158,36 @@ def cute_block_causal_mask( def create_cute_sliding_window_mask(window_size=1024): """Factory function to create a CuTe sliding window mask with configurable window size""" + @cute.jit def cute_sliding_window_mask( - batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, - seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers + batch: cutlass.Int32, + head: cutlass.Int32, + m_idx: cutlass.Int32, + n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, + seqlen_k: cutlass.Int32, + aux_tensors, ) -> cutlass.Boolean: offset = seqlen_k - seqlen_q - return cutlass.Boolean((n_idx <= m_idx + offset) and (n_idx >= m_idx + offset - window_size)) + return cutlass.Boolean( + (n_idx <= m_idx + offset) and (n_idx >= m_idx + offset - window_size) + ) + return cute_sliding_window_mask # Default sliding window mask with window_size=1024 for backward compatibility @cute.jit def cute_sliding_window_mask( - batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, - seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers + batch: cutlass.Int32, + head: cutlass.Int32, + m_idx: cutlass.Int32, + n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, + seqlen_k: cutlass.Int32, + aux_tensors, ) -> cutlass.Boolean: window_size = 1024 # offset = seqlen_k - seqlen_q @@ -152,24 +197,40 @@ def cute_sliding_window_mask( @cute.jit def cute_document_mask( - batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers: list, + batch: cutlass.Int32, + head: cutlass.Int32, + m_idx: cutlass.Int32, + n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, + seqlen_k: cutlass.Int32, + aux_tensors: list, ): - doc_id = buffers[0] + doc_id = aux_tensors[0] return cutlass.Boolean(doc_id[batch, head, m_idx] == doc_id[batch, head, n_idx]) - + @cute.jit def cute_block_diagonal_mask( - batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, - seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers + batch: cutlass.Int32, + head: cutlass.Int32, + m_idx: cutlass.Int32, + n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, + seqlen_k: cutlass.Int32, + aux_tensors, ) -> cutlass.Boolean: return cutlass.Boolean((m_idx // 64) == (n_idx // 64)) @cute.jit def cute_mini_causal_mask( - batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, - seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32, buffers + batch: cutlass.Int32, + head: cutlass.Int32, + m_idx: cutlass.Int32, + n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, + seqlen_k: cutlass.Int32, + aux_tensors, ) -> cutlass.Boolean: """Each tile is locally causal-masked""" m_mod = m_idx % 128 @@ -179,8 +240,12 @@ def cute_mini_causal_mask( @cute.jit def cute_half_identity_mask( - batch: cutlass.Int32, head: cutlass.Int32, m_idx: cutlass.Int32, n_idx: cutlass.Int32, - seqlen_q: cutlass.Int32, seqlen_k: cutlass.Int32 + batch: cutlass.Int32, + head: cutlass.Int32, + m_idx: cutlass.Int32, + n_idx: cutlass.Int32, + seqlen_q: cutlass.Int32, + seqlen_k: cutlass.Int32, ) -> cutlass.Boolean: return cutlass.Boolean(True) @@ -191,17 +256,17 @@ def random_doc_id_tensor(nheads, batch, seqlen_q, device="cpu"): for h in range(nheads): N = seqlen_q n = random.randint(1, math.ceil(math.sqrt(N // 4))) - cuts = sorted(random.sample(range(1, N), n-1)) + cuts = sorted(random.sample(range(1, N), n - 1)) lengths = [b - a for a, b in zip((0, *cuts), (*cuts, N))] doc_ids = [] for i, length in enumerate(lengths): doc_ids += [i for _ in range(length)] - + doc_ids_tensor[b, h, :] = torch.tensor(doc_ids, dtype=torch.int32, device=device) print(f"{doc_ids_tensor.shape = }") return doc_ids_tensor - + MASK_FUNCTIONS = { "identity": (cute_identity_mask, flex_identity_mask), @@ -217,4 +282,4 @@ def random_doc_id_tensor(nheads, batch, seqlen_q, device="cpu"): if __name__ == "__main__": doc_ids = random_doc_id_tensor(1, 2, 128) - print(f"{doc_ids = }") \ No newline at end of file + print(f"{doc_ids = }") diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index 72de115732a..0ca08f3f2e3 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -337,7 +337,7 @@ def apply_score_mod_inner( softmax_scale, vec_size: cutlass.Constexpr, qk_acc_dtype: cutlass.Constexpr, - buffers, + aux_tensors, fastdiv_mods, constant_q_idx: cutlass.Constexpr, qhead_per_kvhead: cutlass.Constexpr[int] = 1, @@ -353,7 +353,7 @@ def apply_score_mod_inner( softmax_scale: Scale to apply vec_size: Vector size for processing elements qk_acc_dtype: Data type for accumulator - buffers: Optional buffers for FlexAttention + aux_tensors: Optional aux_tensors for FlexAttention fastdiv_mods: Tuple of (seqlen_q_divmod, seqlen_k_divmod) for wrapping constant_q_idx: If provided, use this constant for all q_idx values If None, compute q_idx per-element @@ -388,7 +388,7 @@ def apply_score_mod_inner( head_idx_vec[j] = head_idx * qhead_per_kvhead + head_offset # If we will do loads we mod, in order to not read OOB - if cutlass.const_expr(buffers is not None and fastdiv_mods is not None): + if cutlass.const_expr(aux_tensors is not None and fastdiv_mods is not None): if cutlass.const_expr(constant_q_idx is None): seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods q_idx_floored = floor_if_packed(index_tensor[i + j][0], qhead_per_kvhead) @@ -421,9 +421,9 @@ def apply_score_mod_inner( else: head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32).broadcast_to((vec_size,)) - buffer_args = [] - if cutlass.const_expr(buffers is not None): - buffer_args = buffers + aux_args = [] + if cutlass.const_expr(aux_tensors is not None): + aux_args = aux_tensors post_mod_scores = score_mod( score_ssa, @@ -431,7 +431,7 @@ def apply_score_mod_inner( head_idx_ssa, q_idx=q_idx_ssa, kv_idx=kv_idx_ssa, - buffers=buffer_args, + aux_tensors=aux_args, ) # Write back modified scores diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 644936d8d2d..6c3a679a613 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -7,6 +7,7 @@ import torch from einops import rearrange, repeat + try: from flash_attn.layers.rotary import apply_rotary_emb except ImportError: @@ -19,7 +20,11 @@ pad_input, unpad_input, ) -from flash_attn.cute.interface import flash_attn_func, flash_attn_varlen_func, flash_attn_combine +from flash_attn.cute.interface import ( + flash_attn_func, + flash_attn_varlen_func, + flash_attn_combine, +) # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @@ -77,7 +82,17 @@ ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) def test_flash_attn_output( - seqlen_q, seqlen_k, d, causal, local, softcap, deterministic, has_qv, has_learnable_sink, mha_type, dtype + seqlen_q, + seqlen_k, + d, + causal, + local, + softcap, + deterministic, + has_qv, + has_learnable_sink, + mha_type, + dtype, ): if (causal or local) and seqlen_k < seqlen_q: pytest.skip("Causal attention requires seqlen_k >= seqlen_q") @@ -99,26 +114,54 @@ def test_flash_attn_output( # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] attention_chunk_vals = [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): - q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + q_ref = torch.randn( + batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref + ) if softcap > 0.0: # Ensure the values of qk are at least within softcap range. - q_ref = (q_ref * softcap / 4) + q_ref = q_ref * softcap / 4 q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() - k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() - v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + k_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + v_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) if has_qv: - qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + qv_ref = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) else: qv_ref = None # Put window_size after QKV randn so that window_size changes from test to test - window_size = (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + window_size = ( + (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + ) # window_size = (-1, -1) if not local else (16, 0) if has_learnable_sink: learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) else: learnable_sink = None if dtype == torch.float8_e4m3fn: - q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] + q_descale, k_descale, v_descale = [ + torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) + * 2 + for _ in range(3) + ] else: q_descale, k_descale, v_descale = None, None, None q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] @@ -131,11 +174,13 @@ def test_flash_attn_output( None, causal=causal, qv=qv_ref, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, learnable_sink=learnable_sink, - softcap=softcap + softcap=softcap, ) out_pt, attn_pt = attention_ref( q_ref, @@ -145,7 +190,9 @@ def test_flash_attn_output( None, causal=causal, qv=qv_ref, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, learnable_sink=learnable_sink, @@ -197,7 +244,9 @@ def test_flash_attn_output( # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. - assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol + assert (out - out_ref).abs().max().item() <= rtol * ( + out_pt - out_ref + ).abs().max().item() + fwd_atol if ( dtype != torch.float8_e4m3fn @@ -225,7 +274,9 @@ def test_flash_attn_output( # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) - dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad( + out_ref, (q_ref, k_ref, v_ref), g + ) dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") @@ -240,12 +291,24 @@ def test_flash_attn_output( print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") # breakpoint() - dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol - dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol - dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dq - dq_ref).abs().max().item() <= rtol * ( + dq_pt - dq_ref + ).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dk - dk_ref).abs().max().item() <= rtol * ( + dk_pt - dk_ref + ).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dv - dv_ref).abs().max().item() <= rtol * ( + dv_pt - dv_ref + ).abs().max().item() + dv_atol # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @@ -300,9 +363,22 @@ def test_flash_attn_output( ], ) def test_flash_attn_varlen_output( - seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, has_learnable_sink, mha_type, dtype + seqlen_q, + seqlen_k, + d, + add_unused_qkv, + causal, + local, + softcap, + deterministic, + has_qv, + has_learnable_sink, + mha_type, + dtype, ): - if (causal or local): # Right now we only support causal attention with seqlen_k == seqlen_q + if ( + causal or local + ): # Right now we only support causal attention with seqlen_k == seqlen_q seqlen_k = seqlen_q device = "cuda" # set seed @@ -320,25 +396,53 @@ def test_flash_attn_varlen_output( # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k else [0] attention_chunk_vals = [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): - q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + q_ref = torch.randn( + batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref + ) if softcap > 0.0: # Ensure the values of qk are at least within softcap range. q_ref = (q_ref * softcap / 4).detach().requires_grad_() q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() - k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() - v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + k_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + v_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) if has_qv: - qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + qv_ref = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) else: qv_ref = None # Put window_size after QKV randn so that window_size changes from test to test - window_size = (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + window_size = ( + (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + ) if has_learnable_sink: learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) else: learnable_sink = None if dtype == torch.float8_e4m3fn: - q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] + q_descale, k_descale, v_descale = [ + torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) + * 2 + for _ in range(3) + ] else: q_descale, k_descale, v_descale = None, None, None q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] @@ -349,7 +453,11 @@ def test_flash_attn_varlen_output( # TODO: test zero_lengths key_padding_mask = generate_random_padding_mask( # seqlen_k, batch_size, device, mode="random", zero_lengths=True - seqlen_k, batch_size, device, mode="random", zero_lengths=False + seqlen_k, + batch_size, + device, + mode="random", + zero_lengths=False, ) def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): @@ -394,9 +502,20 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): output_pad_fn, dq_pad_fn, dk_pad_fn, - ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, qv=qv, kvpacked=False, - query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask) - q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)] + ) = generate_qkv( + q, + k, + v, + query_padding_mask, + key_padding_mask, + qv=qv, + kvpacked=False, + query_unused_mask=query_unused_mask, + key_unused_mask=key_unused_mask, + ) + q_unpad, k_unpad, v_unpad = [ + x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad) + ] out_ref, attn_ref = attention_ref( q_ref, k_ref, @@ -405,11 +524,13 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): key_padding_mask, causal=causal, qv=qv_ref, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, learnable_sink=learnable_sink, - softcap=softcap + softcap=softcap, ) out_pt, attn_pt = attention_ref( q_ref, @@ -419,7 +540,9 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): key_padding_mask, causal=causal, qv=qv_ref, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, learnable_sink=learnable_sink, @@ -473,8 +596,9 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # Check that FlashAttention's numerical error is at most 3x the numerical error # of a Pytorch implementation. - assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol - + assert (out - out_ref).abs().max().item() <= rtol * ( + out_pt - out_ref + ).abs().max().item() + fwd_atol if ( dtype != torch.float8_e4m3fn @@ -510,7 +634,9 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # deterministic, # 0, # sm_margin # ) - dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad) + dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad( + out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad + ) dq = dq_pad_fn(dq_unpad) dk = dk_pad_fn(dk_unpad) dv = dk_pad_fn(dv_unpad) @@ -534,9 +660,10 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) - # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) - dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad( + out_ref, (q_ref, k_ref, v_ref), g + ) dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") @@ -551,12 +678,24 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") # breakpoint() - dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol - dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol - dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dq - dq_ref).abs().max().item() <= rtol * ( + dq_pt - dq_ref + ).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dk - dk_ref).abs().max().item() <= rtol * ( + dk_pt - dk_ref + ).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dv - dv_ref).abs().max().item() <= rtol * ( + dv_pt - dv_ref + ).abs().max().item() + dv_atol # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @@ -664,45 +803,107 @@ def test_flash_attn_kvcache( for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): # has_qv = d == 64 and dv >= 256 has_qv = False - q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + q = ( + torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + .to(dtype) + .to(dtype_ref) + ) if has_qv: - qv = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + qv = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) else: qv = None if varlen_q: - query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") - q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(q, query_padding_mask) - output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q) - qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if has_qv else None + query_padding_mask = generate_random_padding_mask( + seqlen_q, batch_size, device, mode="random" + ) + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input( + q, query_padding_mask + ) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) + qv_unpad = ( + rearrange(qv, "b s ... -> (b s) ...")[indices_q] if has_qv else None + ) else: query_padding_mask = None q_unpad = q qv_unpad = qv cu_seqlens_q, max_seqlen_q = None, None # Put window_size after QKV randn so that window_size changes from test to test - window_size = (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + window_size = ( + (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + ) if has_learnable_sink: learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) else: learnable_sink = None - seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() + seqlen_new = ( + seqlen_q + if seqlen_new_eq_seqlen_q + else torch.randint(1, seqlen_q + 1, (1,)).item() + ) cu_seqlens_k_new = None key_new_padding_mask = None if new_kv: - k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - v = torch.randn(batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + k = ( + torch.randn( + batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + v = ( + torch.randn( + batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) if varlen_q: # k & v are also varlen - key_new_padding_mask = generate_random_padding_mask(seqlen_new, batch_size, device, mode="random") - k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input(k, key_new_padding_mask) + key_new_padding_mask = generate_random_padding_mask( + seqlen_new, batch_size, device, mode="random" + ) + k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input( + k, key_new_padding_mask + ) v_unpad, *rest = unpad_input(v, key_new_padding_mask) else: k_unpad, v_unpad = k, v else: k, v, k_unpad, v_unpad = None, None, None, None if page_size is None: - k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + k_cache = ( + torch.randn( + batch_size_cache, + seqlen_k, + nheads_k, + d, + device=device, + dtype=dtype_ref, + ) + .to(dtype) + .to(dtype_ref) + ) + v_cache = ( + torch.randn( + batch_size_cache, + seqlen_k, + nheads_k, + dv, + device=device, + dtype=dtype_ref, + ) + .to(dtype) + .to(dtype_ref) + ) page_table = None else: ( @@ -713,13 +914,25 @@ def test_flash_attn_kvcache( v_cache_paged, num_blocks, ) = _generate_block_kvcache( - seqlen_k, page_size, batch_size_cache, nheads_k, d, dv, device, dtype, dtype_ref + seqlen_k, + page_size, + batch_size_cache, + nheads_k, + d, + dv, + device, + dtype, + dtype_ref, ) cache_seqlens = torch.randint( 0 if new_kv else 1, # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough ( - (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1) + ( + seqlen_k + - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + + 1 + ) if new_kv else (seqlen_k + 1) ), @@ -728,15 +941,26 @@ def test_flash_attn_kvcache( device=device, ) if has_leftpad: - cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device) - if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device) - for i in range(batch_size)]) + cache_leftpad = torch.cat( + [ + torch.randint( + 0, + cache_seqlens[i].item(), + (1,), + dtype=torch.int32, + device=device, + ) + if cache_seqlens[i].item() > 0 + else torch.zeros(1, dtype=torch.int32, device=device) + for i in range(batch_size) + ] + ) else: cache_leftpad = None if has_batch_idx: - cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[ - :batch_size - ] + cache_batch_idx = torch.randperm( + batch_size_cache, dtype=torch.int32, device=device + )[:batch_size] else: cache_batch_idx = None arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") @@ -744,11 +968,14 @@ def test_flash_attn_kvcache( if not new_kv: key_padding_mask = arange < cache_seqlens_expanded else: - k_new_seqlens = key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new + k_new_seqlens = ( + key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new + ) key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens if has_leftpad: key_padding_mask = torch.logical_and( - key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k) + key_padding_mask, + arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k), ) # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) rotary_seqlens = cache_seqlens if not has_rotary_seqlens else cache_seqlens // 2 @@ -766,7 +993,11 @@ def test_flash_attn_kvcache( sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) if causal or local: q_ro = apply_rotary_emb( - q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved + q, + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, ) else: q_ro = rearrange( @@ -782,17 +1013,26 @@ def test_flash_attn_kvcache( ) # q_ro = q k_ro = apply_rotary_emb( - k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved + k, + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, ) else: cos, sin = None, None q_ro, k_ro = q, k # k_cache[:, 64:] = -1 - k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone() - v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone() + k_cache_ref = ( + k_cache if not has_batch_idx else k_cache[cache_batch_idx] + ).clone() + v_cache_ref = ( + v_cache if not has_batch_idx else v_cache[cache_batch_idx] + ).clone() if new_kv: update_mask = torch.logical_and( - cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + k_new_seqlens + cache_seqlens_expanded <= arange, + arange < cache_seqlens_expanded + k_new_seqlens, ) k_to_update = rearrange(k_ro, "b s ... -> (b s) ...") v_to_update = rearrange(v, "b s ... -> (b s) ...") @@ -801,8 +1041,12 @@ def test_flash_attn_kvcache( v_to_update = v_to_update[indices_k] k_cache_ref[update_mask] = k_to_update v_cache_ref[update_mask] = v_to_update - k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) - v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) + k_cache_rep = repeat( + k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k + ) + v_cache_rep = repeat( + v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k + ) out_ref, _ = attention_ref( q_ro, k_cache_rep, @@ -830,7 +1074,7 @@ def test_flash_attn_kvcache( upcast=False, reorder_ops=True, key_leftpad=cache_leftpad, - intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, ) q = q.to(dtype) q_unpad = q_unpad.to(dtype) if varlen_q else None @@ -852,7 +1096,9 @@ def test_flash_attn_kvcache( num_splits_vals = [1] # precompute_metadata_vals = [False, True] precompute_metadata_vals = [False] - for num_splits, precompute_metadata in itertools.product(num_splits_vals, precompute_metadata_vals): + for num_splits, precompute_metadata in itertools.product( + num_splits_vals, precompute_metadata_vals + ): # if precompute_metadata: # scheduler_metadata = get_scheduler_metadata( # batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d, @@ -922,19 +1168,35 @@ def test_flash_attn_kvcache( if new_kv: if page_size is None: k_cache_select = ( - k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx] + k_cache.to(dtype_ref) + if not has_batch_idx + else k_cache.to(dtype_ref)[cache_batch_idx] ) v_cache_select = ( - v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx] + v_cache.to(dtype_ref) + if not has_batch_idx + else v_cache.to(dtype_ref)[cache_batch_idx] ) else: k_cache_select = rearrange( - k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + k_cache_paged.to(dtype_ref)[ + ( + page_table + if not has_batch_idx + else page_table[cache_batch_idx] + ).flatten() + ], "(b nblocks) block_size ... -> b (nblocks block_size) ...", b=batch_size, )[:, :seqlen_k].to(dtype_ref) v_cache_select = rearrange( - v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + v_cache_paged.to(dtype_ref)[ + ( + page_table + if not has_batch_idx + else page_table[cache_batch_idx] + ).flatten() + ], "(b nblocks) block_size ... -> b (nblocks block_size) ...", b=batch_size, )[:, :seqlen_k].to(dtype_ref) @@ -943,7 +1205,9 @@ def test_flash_attn_kvcache( if dtype is not torch.float8_e4m3fn: assert torch.equal(v_cache_select, v_cache_ref) else: - assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3) + assert torch.allclose( + v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3 + ) # breakpoint() # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: if rotary_dim == 0: @@ -952,23 +1216,37 @@ def test_flash_attn_kvcache( # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): # breakpoint() if dtype is not torch.float8_e4m3fn: - assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) + assert torch.allclose( + k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3 + ) else: - assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1) + assert torch.allclose( + k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1 + ) mult = 4 if dtype == torch.float8_e4m3fn else 2 - assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 + assert (out - out_ref).abs().max().item() <= mult * ( + out_pt - out_ref + ).abs().max().item() + 1e-5 mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 - assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item() + assert (out - out_ref).abs().mean().item() <= mult_mean * ( + out_pt - out_ref + ).abs().mean().item() -def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref): +def _generate_block_kvcache( + seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref +): num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 - k_cache_paged = torch.randn( - num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref - ).to(dtype).to(dtype_ref) - v_cache_paged = torch.randn( - num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref - ).to(dtype).to(dtype_ref) + k_cache_paged = ( + torch.randn(num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref) + .to(dtype) + .to(dtype_ref) + ) + v_cache_paged = ( + torch.randn(num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref) + .to(dtype) + .to(dtype_ref) + ) page_table = rearrange( torch.randperm(num_blocks, dtype=torch.int32, device=device), "(b nblocks) -> b nblocks", @@ -994,7 +1272,9 @@ def attention_combine_ref(out_partial, lse_partial): """ lse = torch.logsumexp(lse_partial, dim=0) scale = torch.exp(lse_partial - lse) - scale = torch.where(torch.isinf(scale) | torch.isnan(scale), torch.zeros_like(scale), scale) + scale = torch.where( + torch.isinf(scale) | torch.isnan(scale), torch.zeros_like(scale), scale + ) out = (scale.unsqueeze(-1) * out_partial).sum(0) return out, lse @@ -1019,13 +1299,25 @@ def test_flash_attn_combine(num_splits, seqlen, d, dtype): # batch_size = 1 # nheads = 1 # Create tensors in the expected format: (num_splits, batch_size, seqlen, nheads, d) and (num_splits, batch_size, seqlen, nheads) - out_partial = torch.randn(num_splits * 2, batch_size, nheads, seqlen, d, device=device, dtype=torch.float32).transpose(2, 3)[:num_splits] # To test non-contiguous tensor - lse_partial = torch.randn(num_splits, batch_size, nheads * 2, seqlen, device=device, dtype=torch.float32).transpose(-1, -2)[:, :, :, :nheads] # To test non-contiguous tensor + out_partial = torch.randn( + num_splits * 2, + batch_size, + nheads, + seqlen, + d, + device=device, + dtype=torch.float32, + ).transpose(2, 3)[:num_splits] # To test non-contiguous tensor + lse_partial = torch.randn( + num_splits, batch_size, nheads * 2, seqlen, device=device, dtype=torch.float32 + ).transpose(-1, -2)[:, :, :, :nheads] # To test non-contiguous tensor # To test short-circuiting based on num_splits - lse_partial[num_splits // 2:, :batch_size // 3] = -float("inf") + lse_partial[num_splits // 2 :, : batch_size // 3] = -float("inf") # Test with LSE returned (default behavior) - out, lse = flash_attn_combine(out_partial, lse_partial, out_dtype=dtype, return_lse=True) + out, lse = flash_attn_combine( + out_partial, lse_partial, out_dtype=dtype, return_lse=True + ) out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial) out_pt = out_ref.to(dtype) @@ -1039,9 +1331,16 @@ def test_flash_attn_combine(num_splits, seqlen, d, dtype): assert torch.allclose(lse, lse_ref, atol=1e-5, rtol=1e-5) multiple = 2 - assert ((out - out_ref).abs().max().item() <= multiple * (out_pt - out_ref).abs().max().item()) or torch.allclose(out, out_pt, atol=1e-5, rtol=1e-5) + assert ( + (out - out_ref).abs().max().item() + <= multiple * (out_pt - out_ref).abs().max().item() + ) or torch.allclose(out, out_pt, atol=1e-5, rtol=1e-5) # Test with LSE not returned - out_no_lse, lse_no_lse = flash_attn_combine(out_partial, lse_partial, out_dtype=dtype, return_lse=False) + out_no_lse, lse_no_lse = flash_attn_combine( + out_partial, lse_partial, out_dtype=dtype, return_lse=False + ) assert lse_no_lse is None, "LSE should be None when return_lse=False" - assert torch.allclose(out_no_lse, out, atol=1e-5, rtol=1e-5), "Output should be the same regardless of return_lse" \ No newline at end of file + assert torch.allclose(out_no_lse, out, atol=1e-5, rtol=1e-5), ( + "Output should be the same regardless of return_lse" + ) diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py index 3e6707b5fb9..ce3a28b82c6 100644 --- a/tests/cute/test_mask_mod.py +++ b/tests/cute/test_mask_mod.py @@ -1,23 +1,22 @@ # mask mod test script +# REFACTORED to use _flash_attn_fwd as the kernel entrypoint import math +from typing import Optional, Callable -import cuda.bindings.driver as cuda -import cutlass -import cutlass.cute as cute -from cutlass.cute.runtime import from_dlpack import pytest import torch from torch.nn.attention.flex_attention import create_block_mask, flex_attention import torch.nn.functional as F +from flash_attn.cute.interface import _flash_attn_fwd from flash_attn.cute.block_sparsity import compute_block_sparsity -from flash_attn.cute.flash_fwd import ( - FlashAttentionForwardSm80, - FlashAttentionForwardSm90, +from flash_attn.cute.mask_definitions import ( + MASK_FUNCTIONS, + flex_causal_mask, + create_flex_sliding_window_mask, + create_cute_sliding_window_mask, ) -from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 -from flash_attn.cute.mask_definitions import MASK_FUNCTIONS, flex_causal_mask, create_flex_sliding_window_mask, create_cute_sliding_window_mask from flash_attn.cute.testing import attention_ref @@ -46,169 +45,12 @@ def create_tensors( } -def compile_and_run_kernel( - tensors, - mask_mod_cute, - causal, - is_local, - window_left, - window_right, - tile_m, - tile_n, - full_block_cnt=None, - full_block_idx=None, - mask_block_cnt=None, - mask_block_idx=None, -): - dtype_map = { - torch.float16: cutlass.Float16, - torch.bfloat16: cutlass.BFloat16, - torch.float32: cutlass.Float32, - } - cute_dtype = dtype_map[tensors["q"].dtype] - - batch_size, seqlen_q, nheads, headdim = tensors["q"].shape - _, seqlen_k, nheads_kv, _ = tensors["k"].shape - headdim_v = tensors["v"].shape[-1] - - compute_capability = torch.cuda.get_device_capability() - if compute_capability >= (10, 0): - kernel_class = FlashAttentionForwardSm100 - elif compute_capability >= (9, 0): - kernel_class = FlashAttentionForwardSm90 - else: - kernel_class = FlashAttentionForwardSm80 - - qhead_per_kvhead = nheads // nheads_kv - kernel = kernel_class( - cute_dtype, - headdim, - headdim_v, - qhead_per_kvhead, - is_causal=causal, - is_local=is_local, - pack_gqa=False, - tile_m=tile_m, - tile_n=tile_n, - num_stages=2, - num_threads=384, - intra_wg_overlap=True, - mma_pv_is_rs=True, - mask_mod=mask_mod_cute, - has_buffers=False, - Q_in_regs=False, - ) - - softmax_scale = 1.0 / math.sqrt(headdim) - current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - - q_cute = from_dlpack(tensors["q"].detach(), assumed_align=16).mark_layout_dynamic( - leading_dim=tensors["q"].ndim - 1 - ) - k_cute = from_dlpack(tensors["k"].detach(), assumed_align=16).mark_layout_dynamic( - leading_dim=tensors["k"].ndim - 1 - ) - v_cute = from_dlpack(tensors["v"].detach(), assumed_align=16).mark_layout_dynamic( - leading_dim=tensors["v"].ndim - 1 - ) - out_cute = from_dlpack( - tensors["out"].detach(), assumed_align=16 - ).mark_layout_dynamic(leading_dim=tensors["out"].ndim - 1) - lse_cute = from_dlpack( - tensors["lse"].detach(), assumed_align=4 - ).mark_layout_dynamic(leading_dim=tensors["lse"].ndim - 1) - - full_block_cnt_cute = ( - from_dlpack(full_block_cnt.detach(), assumed_align=4) - if full_block_cnt is not None - else None - ) - full_block_idx_cute = ( - from_dlpack(full_block_idx.detach(), assumed_align=4) - if full_block_idx is not None - else None - ) - mask_block_cnt_cute = ( - from_dlpack(mask_block_cnt.detach(), assumed_align=4) - if mask_block_cnt is not None - else None - ) - mask_block_idx_cute = ( - from_dlpack(mask_block_idx.detach(), assumed_align=4) - if mask_block_idx is not None - else None - ) - - # Window parameters for is_local - window_left_cute = ( - cutlass.Int32(window_left) if window_left is not None else None - ) - window_right_cute = ( - cutlass.Int32(window_right) if window_right is not None else None - ) - - compiled = cute.compile( - kernel, - q_cute, - k_cute, - v_cute, - out_cute, - lse_cute, - softmax_scale, - current_stream, - None, # cu_seqlens_q - None, # cu_seqlens_k - None, # seqused_q - None, # seqused_k - None, # page_table - window_left_cute, - window_right_cute, - None, # learnable_sink - full_block_cnt_cute, - full_block_idx_cute, - mask_block_cnt_cute, - mask_block_idx_cute, - None, # buffers - ) - - compiled( - q_cute, - k_cute, - v_cute, - out_cute, - lse_cute, - softmax_scale, - current_stream, - None, # cu_seqlens_q - None, # cu_seqlens_k - None, # seqused_q - None, # seqused_k - None, # page_table - window_left_cute, - window_right_cute, - None, # learnable_sink - full_block_cnt_cute, - full_block_idx_cute, - mask_block_cnt_cute, - mask_block_idx_cute, - None, # buffers - ) - - torch.cuda.synchronize() - return tensors["out"] - - -def compute_reference_flash_attn( - tensors, causal, window_size, dtype_ref, upcast=True -): +def compute_reference_flash_attn(tensors, causal, window_size, dtype_ref, upcast=True): """Compute reference using FlashAttention's attention_ref function""" - batch_size, seqlen_q, nheads, headdim = tensors["q"].shape - _, seqlen_k, nheads_kv, _ = tensors["k"].shape - q = tensors["q"].to(dtype_ref) k = tensors["k"].to(dtype_ref) v = tensors["v"].to(dtype_ref) - + out_ref, attn_ref = attention_ref( q, k, @@ -220,13 +62,11 @@ def compute_reference_flash_attn( upcast=upcast, reorder_ops=False, ) - + return out_ref -def compute_reference_flex_attn( - tensors, mask_mod_flex, mask_mod_name, tile_m, tile_n -): +def compute_reference_flex_attn(tensors, mask_mod_flex, mask_mod_name, tile_m, tile_n): """Compute reference using flex_attention for custom mask_mods""" batch_size, seqlen_q, nheads, headdim = tensors["q"].shape _, seqlen_k, nheads_kv, _ = tensors["k"].shape @@ -266,9 +106,7 @@ def mask_fn(b, h, q_idx, kv_idx): k_end = min((k_block + 1) * tile_n, seqlen_k) mask[q_start:q_end, k_start:k_end] = True - attn_mask = ( - mask.unsqueeze(0).unsqueeze(0).expand(batch_size, nheads, -1, -1) - ) + attn_mask = mask.unsqueeze(0).unsqueeze(0).expand(batch_size, nheads, -1, -1) out_ref = F.scaled_dot_product_attention( q, k, v, attn_mask=attn_mask, scale=scale ) @@ -319,11 +157,11 @@ def mask_fn(b, h, q_idx, kv_idx): @pytest.mark.parametrize( "use_mask_mod,is_local,mask_name,window_size,window_left,window_right", [ - (False, False, "identity", None, None, None), - (False, False, "causal", None, None, None), + # (False, False, "identity", None, None, None), + # (False, False, "causal", None, None, None), (True, False, "identity", None, None, None), (True, False, "causal", None, None, None), - # (True, False, "block_causal", None, None, None), + (True, False, "block_causal", None, None, None), # Mask mod sliding window (True, False, "sliding_window", 128, None, None), (True, False, "sliding_window", 256, None, None), @@ -334,39 +172,46 @@ def mask_fn(b, h, q_idx, kv_idx): # (False, True, None, None, 512, 0), ], ) -@pytest.mark.parametrize("tile_m,tile_n", [(128, 128),]) +@pytest.mark.parametrize("tile_m,tile_n", [(128, 128), (128, 112)]) def test_mask_mod_output( - seqlen_q, seqlen_k, nheads, kv_mode, headdim, dtype, - use_mask_mod, is_local, mask_name, window_size, window_left, window_right, - tile_m, tile_n + seqlen_q, + seqlen_k, + nheads, + kv_mode, + headdim, + dtype, + use_mask_mod, + is_local, + mask_name, + window_size, + window_left, + window_right, + tile_m, + tile_n, ): torch.manual_seed(42) # Validate configuration if is_local: assert not use_mask_mod, "Cannot use both is_local and use_mask_mod" - assert window_left is not None or window_right is not None, \ + assert window_left is not None or window_right is not None, ( "Must specify window_left or window_right for is_local" - + ) + if use_mask_mod and mask_name == "sliding_window": - assert window_size is not None, "window_size must be specified for sliding_window" - # Skip if seqlen_k is too small for the window - # if seqlen_k < window_size // 2: - # pytest.skip(f"seqlen_k={seqlen_k} too small for window_size={window_size}") - # Skip if seqlen_q > seqlen_k (problematic for sliding window) + assert window_size is not None, ( + "window_size must be specified for sliding_window" + ) if seqlen_q > seqlen_k: - pytest.skip(f"seqlen_q={seqlen_q} > seqlen_k={seqlen_k} not supported for sliding_window") - + pytest.skip( + f"seqlen_q={seqlen_q} > seqlen_k={seqlen_k} not supported for sliding_window" + ) + if is_local: - window_left_val = window_left if window_left is not None else 0 - window_right_val = window_right if window_right is not None else 0 - total_window = window_left_val + window_right_val + 1 - # Skip if seqlen_k is too small for the window - if seqlen_k < total_window // 2: - pytest.skip(f"seqlen_k={seqlen_k} too small for window={total_window}") - # Skip if seqlen_q > seqlen_k (problematic for local window) if seqlen_q > seqlen_k: - pytest.skip(f"seqlen_q={seqlen_q} > seqlen_k={seqlen_k} not supported for is_local") + pytest.skip( + f"seqlen_q={seqlen_q} > seqlen_k={seqlen_k} not supported for is_local" + ) # Determine nheads_kv based on mode if kv_mode == "mha": @@ -378,7 +223,7 @@ def test_mask_mod_output( else: raise ValueError(f"Unknown kv_mode: {kv_mode}") - batch_size = 2 + batch_size = 1 headdim_v = headdim # Determine mask_mod functions and causal flag @@ -389,7 +234,7 @@ def test_mask_mod_output( mask_mod_flex = create_flex_sliding_window_mask(window_size) else: mask_mod_cute, mask_mod_flex = MASK_FUNCTIONS[mask_name] - causal = (mask_name == "causal") + causal = False elif is_local: # Base local attention - no mask_mod mask_mod_cute = None @@ -399,7 +244,7 @@ def test_mask_mod_output( mask_mod_cute = None mask_mod_flex = None causal = (mask_name == "causal") if mask_name else False - + if causal and seqlen_k < seqlen_q: pytest.skip("causal masking requires seqlen_k >= seqlen_q") @@ -443,26 +288,61 @@ class Config: config=config, mask_mod_flex=mask_mod_flex, device="cuda" ) - # Run kernel - out_cute = compile_and_run_kernel( - tensors, - mask_mod_cute, + softmax_scale = 1.0 / math.sqrt(headdim) + + # if full_cnt is not None: + # print(f"Block sparsity info for {mask_name}:") + # print(f" full_cnt shape: {full_cnt.shape}") + # print(f" full_idx shape: {full_idx.shape}") + # print(f" mask_cnt shape: {mask_cnt.shape}") + # print(f" mask_idx shape: {mask_idx.shape}") + # print(f" full_cnt: {full_cnt}") + # print(f" full_idx: {full_idx}") + # print(f" mask_cnt: {mask_cnt}") + # print(f" mask_idx: {mask_idx}") + # if full_cnt[0,0,0] > 0: + # print(f" First Q block - full indices: {full_idx[0,0,0,:full_cnt[0,0,0].item()]}") + # if mask_cnt[0,0,0] > 0: + # print(f" First Q block - mask indices: {mask_idx[0,0,0,:mask_cnt[0,0,0].item()]}") + + out_tuple = _flash_attn_fwd( + q=tensors["q"], + k=tensors["k"], + v=tensors["v"], + out=tensors["out"], + lse=tensors["lse"], + cu_seqlens_q=None, + cu_seqlens_k=None, + seqused_q=None, + seqused_k=None, + page_table=None, + softmax_scale=softmax_scale, causal=causal, - is_local=is_local, - window_left=window_left, - window_right=window_right, - tile_m=tile_m, - tile_n=tile_n, + softcap=None, + window_size_left=window_left, + window_size_right=window_right, + learnable_sink=None, + m_block_size=tile_m, + n_block_size=tile_n, + num_threads=384, + pack_gqa=False, + _compute_capability=None, + score_mod=None, + mask_mod=mask_mod_cute, full_block_cnt=full_cnt, full_block_idx=full_idx, mask_block_cnt=mask_cnt, mask_block_idx=mask_idx, + return_lse=True, + aux_tensors=None, ) + out_cute = out_tuple[0] + # Determine which reference implementation to use dtype_ref = torch.bfloat16 use_flash_attn_ref = False - + # Use FlashAttention reference for causal and local window cases if mask_name == "causal" and not use_mask_mod: use_flash_attn_ref = True @@ -472,8 +352,6 @@ class Config: window_size_ref = (None, None) # No window for identity elif is_local: use_flash_attn_ref = True - # For is_local, we need to pass the window parameters - # When window_right=0, this is inherently causal window_size_ref = (window_left, window_right) if window_right == 0: causal = True # Override causal flag for reference computation @@ -484,19 +362,31 @@ class Config: # Sliding window with window_right=0 is inherently causal window_size_ref = (window_size, 0) causal = True # Override causal flag for reference computation - + if use_flash_attn_ref: # Compute reference using FlashAttention's attention_ref out_ref_fp32 = compute_reference_flash_attn( - tensors, causal=causal, window_size=window_size_ref, dtype_ref=torch.float32, upcast=True + tensors, + causal=causal, + window_size=window_size_ref, + dtype_ref=torch.float32, + upcast=True, ) out_ref = compute_reference_flash_attn( - tensors, causal=causal, window_size=window_size_ref, dtype_ref=dtype_ref, upcast=False + tensors, + causal=causal, + window_size=window_size_ref, + dtype_ref=dtype_ref, + upcast=False, ) - + # Also compute PyTorch reference for comparison (with reorder_ops for better accuracy) out_pt = compute_reference_flash_attn( - tensors, causal=causal, window_size=window_size_ref, dtype_ref=dtype, upcast=False + tensors, + causal=causal, + window_size=window_size_ref, + dtype_ref=dtype, + upcast=False, ) else: # Use flex_attention for custom mask_mods @@ -504,7 +394,7 @@ class Config: k: v.float() if v.dtype in [torch.float16, torch.bfloat16] else v for k, v in tensors.items() } - + out_ref_fp32 = compute_reference_flex_attn( tensors_fp32, mask_mod_flex, mask_name, tile_m, tile_n ) @@ -537,18 +427,20 @@ class Config: mask_desc += f"(w={window_size})" else: mask_desc = mask_name if mask_name else "identity" - + print( f"\n{mask_desc} @ Q={seqlen_q}, K={seqlen_k}, H={nheads}/{nheads_kv} ({kv_mode}), " f"D={headdim}, M={tile_m}, N={tile_n}" ) - print(f" Reference implementation: {'FlashAttention' if use_flash_attn_ref else 'FlexAttention'}") + print( + f" Reference implementation: {'FlashAttention' if use_flash_attn_ref else 'FlexAttention'}" + ) print(f" Reference vs FP32: {ref_error:.2e}") print(f" PyTorch vs FP32: {pt_error:.2e}") print(f" Kernel vs FP32: {cute_error:.2e}") print(f" Tolerance: rtol={rtol} * {pt_error:.2e} + {fwd_atol:.2e}") print(f" Error ratio: {cute_error / max(pt_error, 1e-10):.2f}") - + # Debug: show some sample values if error is large if cute_error > 1e-2: print(f" DEBUG: Sample kernel output: {out_cute[0, 0, 0, :5]}") @@ -567,4 +459,4 @@ class Config: if __name__ == "__main__": - pytest.main([__file__, "-v", "-s"]) \ No newline at end of file + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/cute/test_score_mod.py b/tests/cute/test_score_mod.py index 0d8b2234467..147e5519394 100644 --- a/tests/cute/test_score_mod.py +++ b/tests/cute/test_score_mod.py @@ -9,14 +9,14 @@ @cute.jit -def score_mod_1(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): +def score_mod_1(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): tmp0 = tSrS_ssa tSrS_ssa = tmp0 return tSrS_ssa @cute.jit -def score_mod_2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): +def score_mod_2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): tmp0 = q_idx tmp1 = kv_idx tmp2 = operator.ge(tmp0, tmp1) @@ -27,7 +27,7 @@ def score_mod_2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): @cute.jit -def score_mod_3(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): +def score_mod_3(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): tmp0 = tSrS_ssa tmp1 = q_idx tmp2 = kv_idx @@ -40,7 +40,7 @@ def score_mod_3(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): @cute.jit -def score_mod_4(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): +def score_mod_4(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): tmp0 = tSrS_ssa tmp1 = q_idx tmp2 = kv_idx @@ -54,7 +54,7 @@ def score_mod_4(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): @cute.jit -def score_mod_5(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): +def score_mod_5(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): tmp0 = tSrS_ssa tmp1 = tmp0 * cute.full_like(tmp0, 2) tSrS_ssa = tmp1 @@ -62,7 +62,7 @@ def score_mod_5(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): @cute.jit -def score_mod_6(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): +def score_mod_6(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): tmp0 = tSrS_ssa tmp1 = tmp0.to(cutlass.Float32) tmp2 = h_idx @@ -84,7 +84,7 @@ def score_mod_6(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): @cute.jit -def score_mod_7(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): +def score_mod_7(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): tmp0 = q_idx tmp1 = kv_idx tmp2 = tmp0 - tmp1 @@ -97,7 +97,7 @@ def score_mod_7(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): @cute.jit -def score_mod_8(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): +def score_mod_8(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): tmp0 = q_idx tmp1 = kv_idx tmp2 = tSrS_ssa @@ -109,7 +109,7 @@ def score_mod_8(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): @cute.jit -def score_mod_9(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): +def score_mod_9(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): tmp0 = q_idx tmp1 = kv_idx tmp2 = tmp0 - tmp1 @@ -121,8 +121,8 @@ def score_mod_9(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): @cute.jit -def score_mod_10(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): - batch_bias = buffers[0] +def score_mod_10(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): + batch_bias = aux_tensors[0] # Detect dtype from buffer element type dtype = batch_bias.element_type @@ -137,9 +137,9 @@ def score_mod_10(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): @cute.jit -def score_mod_11(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers): - head_bias = buffers[0] - pos_bias = buffers[1] +def score_mod_11(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): + head_bias = aux_tensors[0] + pos_bias = aux_tensors[1] # Detect dtype from buffer element type dtype = head_bias.element_type @@ -232,8 +232,8 @@ def dual_buffer_mod(score, b, h, q_idx, kv_idx): (score_mod_9, causal_mask_v2_eager), ] -# Test pairs with buffers: (cute_jit_function, eager_reference_function_factory) -TEST_PAIRS_WITH_BUFFERS = [ +# Test pairs with aux_tensors: (cute_jit_function, eager_reference_function_factory) +TEST_PAIRS_WITH_AUX_TENSORS = [ (score_mod_10, batch_bias), (score_mod_11, dual_buffer_bias), ] @@ -248,7 +248,9 @@ def create_tensors( return q, k, v -def run_cute_flash(q, k, v, cute_score_mod, buffers=None, pack_gqa=False) -> torch.Tensor: +def run_cute_flash( + q, k, v, cute_score_mod, aux_tensors=None, pack_gqa=False +) -> torch.Tensor: q_transposed, k_transposed, v_transposed = map( lambda x: x.transpose(1, 2), (q, k, v) ) @@ -261,7 +263,7 @@ def run_cute_flash(q, k, v, cute_score_mod, buffers=None, pack_gqa=False) -> tor score_mod=cute_score_mod, out=out, lse=None, - buffers=buffers, + aux_tensors=aux_tensors, pack_gqa=pack_gqa, ) return out.transpose(1, 2) @@ -270,7 +272,9 @@ def run_cute_flash(q, k, v, cute_score_mod, buffers=None, pack_gqa=False) -> tor def run_flex_reference(q, k, v, eager_score_mod, dtype=None) -> torch.Tensor: if dtype is not None: q, k, v = q.to(dtype), k.to(dtype), v.to(dtype) - return flex_attention(q, k, v, score_mod=eager_score_mod, enable_gqa=q.shape[1] != k.shape[1]) + return flex_attention( + q, k, v, score_mod=eager_score_mod, enable_gqa=q.shape[1] != k.shape[1] + ) @pytest.mark.parametrize( @@ -301,7 +305,9 @@ def run_flex_reference(q, k, v, eager_score_mod, dtype=None) -> torch.Tensor: @pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 2), (4, 2)]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("score_mod_pair", TEST_PAIRS) -def test_cute_vs_flex_attention(seqlen_q, seqlen_kv, qhead_per_kvhead, num_kv_heads, dtype, score_mod_pair): +def test_cute_vs_flex_attention( + seqlen_q, seqlen_kv, qhead_per_kvhead, num_kv_heads, dtype, score_mod_pair +): torch.random.manual_seed(42) cute_score_mod, eager_score_mod = score_mod_pair @@ -375,8 +381,8 @@ def test_cute_vs_flex_attention(seqlen_q, seqlen_kv, qhead_per_kvhead, num_kv_he ) @pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("score_mod_pair", TEST_PAIRS_WITH_BUFFERS) -def test_cute_vs_flex_attention_with_buffers( +@pytest.mark.parametrize("score_mod_pair", TEST_PAIRS_WITH_AUX_TENSORS) +def test_cute_vs_flex_attention_with_aux_tensors( seqlen_q, seqlen_kv, qhead_per_kvhead, num_kv_heads, dtype, score_mod_pair ): torch.random.manual_seed(42) @@ -398,13 +404,13 @@ def test_cute_vs_flex_attention_with_buffers( if cute_score_mod == score_mod_10: buffer = torch.randn(batch_size, device="cuda", dtype=dtype) * 0.1 - buffers = [buffer] + aux_tensors = [buffer] eager_score_mod = eager_score_mod_factory(buffer) assert buffer.shape == (batch_size,) elif cute_score_mod == score_mod_11: head_bias = torch.randn(num_q_heads, device="cuda", dtype=dtype) * 0.2 pos_scale = torch.arange(seqlen_q, device="cuda", dtype=dtype) * 0.01 - buffers = [head_bias, pos_scale] + aux_tensors = [head_bias, pos_scale] eager_score_mod = eager_score_mod_factory(head_bias, pos_scale) assert head_bias.shape == (num_q_heads,) assert pos_scale.shape == (seqlen_q,) @@ -412,7 +418,9 @@ def test_cute_vs_flex_attention_with_buffers( out_ref_fp32 = run_flex_reference(q, k, v, eager_score_mod, dtype=torch.float32) out_pt = run_flex_reference(q, k, v, eager_score_mod) - out_cute = run_cute_flash(q, k, v, cute_score_mod, buffers=buffers, pack_gqa=pack_gqa) + out_cute = run_cute_flash( + q, k, v, cute_score_mod, aux_tensors=aux_tensors, pack_gqa=pack_gqa + ) # Basic shape and NaN checks assert out_cute.shape == out_ref_fp32.shape == out_pt.shape @@ -443,7 +451,9 @@ def test_cute_vs_flex_attention_with_buffers( ) -@pytest.mark.xfail(raises=NotImplementedError, reason="Varlen with score_mod not yet supported") +@pytest.mark.xfail( + raises=NotImplementedError, reason="Varlen with score_mod not yet supported" +) def test_varlen_with_score_mod(): """Test that varlen (variable length sequences) works with score_mod. @@ -458,7 +468,11 @@ def test_varlen_with_score_mod(): num_heads = 4 dtype = torch.bfloat16 - cu_seqlens = torch.tensor([0] + list(torch.tensor(seqlens).cumsum(0).tolist()), device="cuda", dtype=torch.int32) + cu_seqlens = torch.tensor( + [0] + list(torch.tensor(seqlens).cumsum(0).tolist()), + device="cuda", + dtype=torch.int32, + ) q = torch.randn(total_seq, num_heads, 128, device="cuda", dtype=dtype) k = torch.randn(total_seq, num_heads, 128, device="cuda", dtype=dtype) v = torch.randn(total_seq, num_heads, 128, device="cuda", dtype=dtype) From 3effce828cd3c69cdeff96b418a6370d5d5a2430 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Fri, 24 Oct 2025 01:17:39 -0400 Subject: [PATCH 185/258] Fix FA3 segfault with custom CUDA streams in ABI stable build (#1957) The ABI stable implementation incorrectly used getCurrentStream().id() which returns a StreamId (int64_t) instead of the actual cudaStream_t pointer. Casting an integer ID to a stream pointer caused segmentation faults when using custom CUDA streams. Fixed by using the proper AOTI C API function aoti_torch_get_current_cuda_stream() which returns the actual CUDA stream pointer. --- hopper/flash_api_stable.cpp | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/hopper/flash_api_stable.cpp b/hopper/flash_api_stable.cpp index 4d2700bf271..6de5c5ac380 100644 --- a/hopper/flash_api_stable.cpp +++ b/hopper/flash_api_stable.cpp @@ -16,6 +16,10 @@ #include #include #include +#include + +// Declare the CUDA stream function that's behind #ifdef USE_CUDA in shim.h +extern "C" AOTITorchError aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream); #include #include @@ -717,7 +721,9 @@ mha_fwd_get_scheduler_metadata( int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x); auto device_idx = torch::stable::accelerator::getCurrentDeviceIndex(); - auto stream = (cudaStream_t)torch::stable::accelerator::getCurrentStream(device_idx).id(); + void* stream_ptr = nullptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(device_idx, &stream_ptr)); + cudaStream_t stream = static_cast(stream_ptr); prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/); CHECK_CUDA_KERNEL_LAUNCH(); } @@ -1227,7 +1233,9 @@ mha_fwd(Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_ if (total_q > 0 && (total_k + params.total_knew) > 0 && num_heads_k > 0) { auto device_idx = torch::stable::accelerator::getCurrentDeviceIndex(); - auto stream = (cudaStream_t)torch::stable::accelerator::getCurrentStream(device_idx).id(); + void* stream_ptr = nullptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(device_idx, &stream_ptr)); + cudaStream_t stream = static_cast(stream_ptr); run_mha_fwd(params, stream); if (params.num_splits > 1) { if (out_type == torch::headeronly::ScalarType::BFloat16) { @@ -1619,7 +1627,9 @@ std::tuple mha_b if (total_q > 0 && total_k > 0 && num_heads_k > 0) { auto device_idx = torch::stable::accelerator::getCurrentDeviceIndex(); - auto stream = (cudaStream_t)torch::stable::accelerator::getCurrentStream(device_idx).id(); + void* stream_ptr = nullptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(device_idx, &stream_ptr)); + cudaStream_t stream = static_cast(stream_ptr); run_mha_bwd(params, stream); } else if (total_k > 0 && num_heads_k > 0) { // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. @@ -1726,7 +1736,9 @@ mha_combine(Tensor out_partial, // num_splits x batch_size x seqlen x nu if (seqlen > 0 && batch_size > 0) { auto device_idx = torch::stable::accelerator::getCurrentDeviceIndex(); - auto stream = (cudaStream_t)torch::stable::accelerator::getCurrentStream(device_idx).id(); + void* stream_ptr = nullptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(device_idx, &stream_ptr)); + cudaStream_t stream = static_cast(stream_ptr); run_mha_fwd_combine(params, stream, false /*enable_pdl*/); } From 9450df6612a9eaeefbe6154b8c8731b6625dab9a Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 24 Oct 2025 14:50:00 -0400 Subject: [PATCH 186/258] [Cute,Fwd,Sm100] Fix interface w score mod to get it to run --- flash_attn/cute/flash_fwd_sm100.py | 12 +++++++----- flash_attn/cute/interface.py | 2 -- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 83755896d51..0758d3f405b 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -223,9 +223,11 @@ def __call__( window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, learnable_sink: Optional[cute.Tensor] = None, - aux_tensors: Optional[ - list - ] = None, # Not typing for now since conversion behaves a lil funny + full_block_cnt: Optional[cute.Tensor] = None, # (b, h, m_block) + full_block_idx: Optional[cute.Tensor] = None, # (b, h, m_block, n_block) + mask_block_cnt: Optional[cute.Tensor] = None, # (b, h, m_block) + mask_block_idx: Optional[cute.Tensor] = None, # (b, h, m_block, n_block) + aux_tensors: Optional[list] = None, ): """Execute the Fused Multi-Head Attention operation on the provided tensors. @@ -1966,7 +1968,7 @@ def correction_rescale( tOtO_t2r_i = cute.make_tensor(tOtO_t2r.iterator + i * corr_tile_size, tOtO_t2r.layout) cute.copy(thr_tmem_load, tOtO_t2r_i, tOrO_frg) for j in cutlass.range(0, cute.size(tOrO_frg), 2, unroll_full=True): - tOrO_frg[j], tOrO_frg[j + 1] = cute.arch.mul_packed_f32x2( + tOrO_frg[j], tOrO_frg[j + 1] = utils.mul_packed_f32x2( (tOrO_frg[j], tOrO_frg[j + 1]), (scale, scale), ) @@ -2041,7 +2043,7 @@ def correction_epilogue( tOrO_frg = cute.make_fragment(tOcO_t2r[None, 0, 0, i].shape, self.pv_acc_dtype) cute.copy(tiled_tmem_load, tOtO_t2r_i, tOrO_frg) for j in cutlass.range_constexpr(0, cute.size(tOrO_frg), 2): - tOrO_frg[j], tOrO_frg[j + 1] = cute.arch.mul_packed_f32x2( + tOrO_frg[j], tOrO_frg[j + 1] = utils.mul_packed_f32x2( (tOrO_frg[j], tOrO_frg[j + 1]), (scale, scale), ) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index e3d2eb0891b..b77a70d9211 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -317,8 +317,6 @@ def _flash_attn_fwd( score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False - print(mask_mod_hash) - if softcap is not None: assert score_mod is None, "softcap and score_mod cannot be used together" score_mod = utils.create_softcap_scoremod(softcap) From 7ef1a6f3a79958cb08b04c9da1d94ace6dd24812 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 24 Oct 2025 16:10:31 -0400 Subject: [PATCH 187/258] [Cute,Sm100] In gemm ptx, add to base smem_address instead --- flash_attn/cute/blackwell_helpers.py | 23 ++++++++++++++--------- flash_attn/cute/flash_bwd_sm100.py | 28 +++++++++++++++++++--------- flash_attn/cute/flash_fwd_sm100.py | 15 +++++++-------- 3 files changed, 40 insertions(+), 26 deletions(-) diff --git a/flash_attn/cute/blackwell_helpers.py b/flash_attn/cute/blackwell_helpers.py index 83ba1cd518d..f3335b3923e 100644 --- a/flash_attn/cute/blackwell_helpers.py +++ b/flash_attn/cute/blackwell_helpers.py @@ -439,24 +439,27 @@ def gemm_ptx_partial( ".reg .pred p;\n\t" ".reg .b32 idesc;\n\t" ".reg .b32 tmem_acc;\n\t" + ".reg .b32 smem_desc_a_lo_start, smem_desc_b_lo_start;\n\t" ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" ".reg .b64 smem_desc_a, smem_desc_b;\n\t" "elect.sync _|leader_thread, -1;\n\t" f"mov.b32 idesc, {hex(idesc)};\n\t" f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" - "mov.b32 smem_desc_a_lo, $0;\n\t" - "mov.b32 smem_desc_b_lo, $1;\n\t" + "mov.b32 smem_desc_a_lo_start, $0;\n\t" + "mov.b32 smem_desc_b_lo_start, $1;\n\t" f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" - f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" - f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" "setp.ne.b32 p, $2, 0;\n\t" f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + "".join( ( - f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" - f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + # f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"add.u32 smem_desc_a_lo, smem_desc_a_lo_start, {hex(offset_a[k])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t" @@ -504,6 +507,7 @@ def gemm_ptx_partial( ".reg .b32 idesc;\n\t" ".reg .b32 tmem_acc;\n\t" ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo_start;\n\t" ".reg .b32 smem_desc_b_lo;\n\t" ".reg .b32 smem_desc_b_hi;\n\t" ".reg .b64 smem_desc_b;\n\t" @@ -511,15 +515,16 @@ def gemm_ptx_partial( f"mov.b32 idesc, {hex(idesc)};\n\t" f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" f"mov.b32 tmem_a, $0;\n\t" - f"mov.b32 smem_desc_b_lo, $1;\n\t" + f"mov.b32 smem_desc_b_lo_start, $1;\n\t" f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" - f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" "setp.ne.b32 p, $2, 0;\n\t" f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + "".join( ( # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" - f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t" f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 7eaf7b95849..e02a05512e1 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -146,7 +146,7 @@ def __init__( self.num_regs_reduce = 160 self.num_regs_compute = 128 - self.num_regs_other = 96 + self.num_regs_other = 80 self.num_regs_empty = 24 assert self.num_regs_reduce + self.num_regs_compute * 2 + self.num_regs_other <= 512 @@ -1195,16 +1195,24 @@ def mma( tdVrdO = tiled_mma_dV.make_fragment_B(sdO) tdVrP = tiled_mma_dV.make_fragment_A(tP)[None, None, None, 0] - mma_qk_fn = partial(gemm_w_idx, tiled_mma_SdP, tStS, tSrK, tSrQ, A_idx=0, zero_init=True) - # mma_qk_fn = partial( - # gemm_ptx_w_idx, tiled_mma_SdP, tStS, tSrK, tSrQ, sA=sK, sB=sQ, A_idx=0, zero_init=True - # ) - mma_dov_fn = partial( - gemm_w_idx, tiled_mma_SdP, tdPtdP, tdPrV, tdPrdOt, A_idx=0, zero_init=True + # mma_qk_fn = partial(gemm_w_idx, tiled_mma_SdP, tStS, tSrK, tSrQ, A_idx=0, zero_init=True) + mma_qk_fn = partial( + gemm_ptx_w_idx, tiled_mma_SdP, tStS, tSrK, tSrQ, sA=sK, sB=sQ, A_idx=0, zero_init=True ) # mma_dov_fn = partial( - # gemm_ptx_w_idx, tiled_mma_SdP, tdPtdP, tdPrV, tdPrdOt, sA=sV, sB=sdOt, A_idx=0, zero_init=True + # gemm_w_idx, tiled_mma_SdP, tdPtdP, tdPrV, tdPrdOt, A_idx=0, zero_init=True # ) + mma_dov_fn = partial( + gemm_ptx_w_idx, + tiled_mma_SdP, + tdPtdP, + tdPrV, + tdPrdOt, + sA=sV, + sB=sdOt, + A_idx=0, + zero_init=True, + ) mma_pdo_fn = partial(gemm_w_idx, tiled_mma_dV, tdVtdV, tdVrP, tdVrdO, A_idx=None) # mma_pdo_fn = partial( # gemm_ptx_w_idx, tiled_mma_dV, tdVtdV, tdVrP, tdVrdO, sA=None, sB=sdO, A_idx=None @@ -1832,6 +1840,8 @@ def dQacc_reduce( barrier.wait_eq(mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, n_block) self.reduce_sync_barrier.arrive_and_wait() + gdQaccum_cur = gdQaccum[None, None, m_block] + # We could delay the TMA store by 1 epi tile to better overlap the non-TMA ops delay_tma_store = False @@ -1846,8 +1856,8 @@ def tma_store_fn(src_idx, dst_idx): with cute.arch.elect_one(): copy_utils.cpasync_reduce_bulk_add_f32( sdQaccum[None, src_idx].iterator, - gdQaccum[None, dst_idx, m_block].iterator, self.tma_copy_bytes["dQ"] // 1, + gdQaccum_cur[None, dst_idx].iterator, ) cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(self.sdQaccum_stage - 1, read=read_flag) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 0758d3f405b..9d5a814104d 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -160,7 +160,8 @@ def __init__( self.num_regs_correction = 64 self.num_regs_other = 48 else: - self.num_regs_softmax = 192 if self.is_causal or self.is_local else 184 + # self.num_regs_softmax = 192 if self.is_causal or self.is_local else 184 + self.num_regs_softmax = 200 # self.num_regs_softmax = 176 # self.num_regs_correction = 96 # self.num_regs_correction = 80 @@ -169,9 +170,9 @@ def __init__( # self.num_regs_other = 32 # self.num_regs_other = 64 # self.num_regs_other = 80 - # self.num_regs_other = 48 + self.num_regs_other = 48 # self.num_regs_other = 96 if self.is_causal or self.is_local else 80 - self.num_regs_other = 64 if self.is_causal or self.is_local else 80 + # self.num_regs_other = 64 if self.is_causal or self.is_local else 80 self.num_regs_empty = 24 self.buffer_align_bytes = 1024 @@ -1173,11 +1174,9 @@ def mma( SeqlenInfoCls: Callable, TileSchedulerCls: Callable, ): - thr_mma_qk = tiled_mma_qk.get_slice(0) # default 1SM - thr_mma_pv = tiled_mma_pv.get_slice(0) # default 1SM - tSrQ = thr_mma_qk.make_fragment_A(sQ) - tSrK = thr_mma_qk.make_fragment_B(sK) - tOrV = thr_mma_pv.make_fragment_B(sV) + tSrQ = tiled_mma_qk.make_fragment_A(sQ) + tSrK = tiled_mma_qk.make_fragment_B(sK) + tOrV = tiled_mma_pv.make_fragment_B(sV) if const_expr(self.q_stage == 2): tSrQs = (tSrQ[None, None, None, 0], tSrQ[None, None, None, 1]) else: From b3f437fbcbeb0dd38e838cae418cfec3fb3e8fa9 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 24 Oct 2025 21:45:13 -0400 Subject: [PATCH 188/258] [Cute,Bwd,Sm100] Make postprocessing work, add interface --- flash_attn/cute/flash_bwd_postprocess.py | 132 +++++++++++++++++------ flash_attn/cute/flash_bwd_sm100.py | 17 ++- flash_attn/cute/flash_bwd_sm90.py | 2 + flash_attn/cute/interface.py | 127 ++++++++++++++-------- 4 files changed, 197 insertions(+), 81 deletions(-) diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index 9aa7979adf6..45a0d102eba 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -33,7 +33,7 @@ def __init__( self, dtype: Type[cutlass.Numeric], head_dim: int, - arch: Literal[80, 90], + arch: Literal[80, 90, 100], tile_m: int = 128, num_threads: int = 256, AtomLayoutMdQ: int = 1, @@ -47,7 +47,9 @@ def __init__( """ self.dtype = dtype self.tile_m = tile_m - assert arch in [80, 90], "Only Ampere (80) and Hopper (90) are supported" + assert arch in [80, 90, 100], ( + "Only Ampere (80), Hopper (90), and Blackwell (100) are supported" + ) self.arch = arch # padding head_dim to a multiple of 32 as k_block_size hdim_multiple_of = 32 @@ -92,7 +94,7 @@ def _get_tiled_mma(self): atom_layout_dQ, permutation_mnk=(atom_layout_dQ[0] * 16, atom_layout_dQ[1] * 16, 16), ) - else: + elif const_expr(self.arch == 90): num_mma_warp_groups = self.num_threads // 128 atom_layout_dQ = (self.AtomLayoutMdQ, num_mma_warp_groups // self.AtomLayoutMdQ) tiler_mn_dQ = (self.tile_m // atom_layout_dQ[0], self.tile_hdim // atom_layout_dQ[1]) @@ -106,7 +108,18 @@ def _get_tiled_mma(self): + (1,), tiler_mn=tiler_mn_dQ if not self.dQ_swapAB else tiler_mn_dQ[::-1], ) - assert self.num_threads == tiled_mma.size + else: + cta_group = tcgen05.CtaGroup.ONE + tiled_mma = sm100_utils_basic.make_trivial_tiled_mma( + self.dtype, + tcgen05.OperandMajorMode.MN, # dS_major_mode + tcgen05.OperandMajorMode.MN, # Kt_major_mode + Float32, + cta_group, + (self.tile_m, self.tile_hdim), + ) + if const_expr(self.arch in [80, 90]): + assert self.num_threads == tiled_mma.size return tiled_mma def _setup_attributes(self): @@ -133,7 +146,8 @@ def _setup_attributes(self): self.s2r_tiled_copy_dQaccum = copy_utils.tiled_copy_1d( Float32, self.num_threads, num_s2r_copy_elems ) - else: + self.sdQaccum_layout = cute.make_layout(self.tile_m * self.tile_hdim) + elif const_expr(self.arch == 90): num_threads_per_warp_group = 128 num_mma_warp_groups = self.num_threads // 128 self.s2r_tiled_copy_dQaccum = cute.make_tiled_copy_tv( @@ -141,20 +155,26 @@ def _setup_attributes(self): cute.make_layout((num_threads_per_warp_group, num_mma_warp_groups)), # thr_layout cute.make_layout(128 // Float32.width), # val_layout ) + self.sdQaccum_layout = cute.make_layout( + (self.tile_m * self.tile_hdim // num_mma_warp_groups, num_mma_warp_groups) + ) + else: + self.dQ_reduce_ncol = 32 + dQaccum_reduce_stage = self.tile_hdim // self.dQ_reduce_ncol + assert self.num_threads == 128 # TODO: currently hard-coded + self.s2r_tiled_copy_dQaccum = copy_utils.tiled_copy_1d( + Float32, self.num_threads, num_s2r_copy_elems + ) + self.sdQaccum_layout = cute.make_layout( + (self.tile_m * self.tile_hdim // dQaccum_reduce_stage, dQaccum_reduce_stage) + ) self.gmem_tiled_copy_dQ = copy_utils.tiled_copy_2d( self.dtype, self.tile_hdim, self.num_threads ) # /////////////////////////////////////////////////////////////////////////////// - # Shared memory layout: dQaccum / dQ + # Shared memory layout: dQ # /////////////////////////////////////////////////////////////////////////////// - if const_expr(self.arch == 80): - self.sdQaccum_layout = cute.make_layout(self.tile_m * self.tile_hdim) - else: - num_mma_warp_groups = self.num_threads // 128 - self.sdQaccum_layout = cute.make_layout( - (self.tile_m * self.tile_hdim // num_mma_warp_groups, num_mma_warp_groups) - ) # We can't just use kHeadDim here. E.g. if MMA shape is 64 x 96 but split across 2 WGs, # then setting kBlockKSmem to 32 will cause "Static shape_div failure". # We want to treat it as 64 x 48, so kBlockKSmem should be 16. @@ -164,10 +184,15 @@ def _setup_attributes(self): self.sdQ_layout = cute.tile_to_shape( sdQ_layout_atom, (self.tile_m, self.tile_hdim), (0, 1) ) - else: + elif const_expr(self.arch == 90): self.sdQ_layout = sm90_utils.make_smem_layout( self.dtype, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_hdim) ) + else: + # TODO: this is hard-coded for hdim 128 + self.sdQ_layout = sm100_utils_basic.make_smem_layout_epi( + self.dtype, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_hdim), 1 + ) @cute.jit def __call__( @@ -247,7 +272,7 @@ def __call__( TileScheduler, ).launch( grid=grid_dim, - block=[self.tiled_mma.size, 1, 1], + block=[self.num_threads, 1, 1], smem=smem_size, stream=stream, ) @@ -276,7 +301,14 @@ def kernel( smem = cutlass.utils.SmemAllocator() sdQaccum = smem.allocate_tensor(cutlass.Float32, sdQaccum_layout, byte_alignment=1024) sdQaccum_flat = cute.make_tensor(sdQaccum.iterator, cute.make_layout(cute.size(sdQaccum))) - sdQ = cute.make_tensor(cute.recast_ptr(sdQaccum.iterator, dtype=self.dtype), sdQ_layout) + if const_expr(self.arch in [80, 90]): + sdQ = cute.make_tensor(cute.recast_ptr(sdQaccum.iterator, dtype=self.dtype), sdQ_layout) + else: + # extra stage dimension + sdQ = cute.make_tensor( + cute.recast_ptr(sdQaccum.iterator, sdQ_layout.inner, dtype=self.dtype), + sdQ_layout.outer, + )[None, None, 0] sdQt = utils.transpose_view(sdQ) # Thread index, block index @@ -344,11 +376,28 @@ def kernel( s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_slice(tidx) tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum) tile_shape = (self.tile_m, self.tile_hdim) - acc_shape = tiled_mma.partition_shape_C( - tile_shape if const_expr(not dQ_swapAB) else tile_shape[::-1] - ) - acc = cute.make_fragment(acc_shape, cutlass.Float32) - assert cute.size(acc) == cute.size(tdQsdQaccum) + acc = None + tiled_copy_t2r = None + if const_expr(self.arch in [80, 90]): + acc_shape = tiled_mma.partition_shape_C( + tile_shape if const_expr(not dQ_swapAB) else tile_shape[::-1] + ) + acc = cute.make_fragment(acc_shape, cutlass.Float32) + assert cute.size(acc) == cute.size(tdQsdQaccum) + else: + thr_mma = tiled_mma.get_slice(0) # 1-CTA + dQacc_shape = tiled_mma.partition_shape_C((self.tile_m, self.tile_hdim)) + tdQtdQ = tiled_mma.make_fragment_C(dQacc_shape) + tdQcdQ = thr_mma.partition_C( + cute.make_identity_tensor((self.tile_m, self.tile_hdim)) + ) + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(self.dQ_reduce_ncol)), Float32 + ) + tiled_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tdQtdQ) + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + tdQrdQ_t2r_shape = thr_copy_t2r.partition_D(tdQcdQ).shape + acc = cute.make_fragment(tdQrdQ_t2r_shape, Float32) tdQrdQaccum = cute.make_tensor(acc.iterator, cute.make_layout(tdQsdQaccum.shape)) cute.autovec_copy(tdQsdQaccum, tdQrdQaccum) # Convert tdQrdQaccum from fp32 to fp16/bf16 @@ -357,27 +406,46 @@ def kernel( # Step 3: Copy dQ from register to smem cute.arch.barrier() # make sure all threads have finished loading dQaccum - smem_copy_atom_dQ = utils.get_smem_store_atom( - self.arch, self.dtype, transpose=self.dQ_swapAB - ) - smem_thr_copy_dQ = cute.make_tiled_copy_C(smem_copy_atom_dQ, tiled_mma).get_slice(tidx) - taccdQrdQ = smem_thr_copy_dQ.retile(rdQ) - taccdQsdQ = smem_thr_copy_dQ.partition_D( - sdQ if const_expr(not self.dQ_swapAB) else sdQt - ) - cute.copy(smem_copy_atom_dQ, taccdQrdQ, taccdQsdQ) + if const_expr(self.arch in [80, 90]): + copy_atom_r2s_dQ = utils.get_smem_store_atom( + self.arch, self.dtype, transpose=self.dQ_swapAB + ) + tiled_copy_r2s_dQ = cute.make_tiled_copy_C(copy_atom_r2s_dQ, tiled_mma) + else: + # copy_atom_r2s_dQ = sm100_utils_basic.get_smem_store_op( + # LayoutEnum.ROW_MAJOR, self.dtype, Float32, tiled_copy_t2r, + # ) + # tiled_copy_r2s_dQ = cute.make_tiled_copy_D(copy_atom_r2s_dQ, tiled_copy_t2r) + thr_layout_r2s_dQ = cute.make_layout((self.num_threads, 1)) # 128 threads + val_layout_r2s_dQ = cute.make_layout((1, 128 // self.dtype.width)) + copy_atom_r2s_dQ = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=128, + ) + tiled_copy_r2s_dQ = cute.make_tiled_copy_tv( + copy_atom_r2s_dQ, thr_layout_r2s_dQ, val_layout_r2s_dQ + ) + thr_copy_r2s_dQ = tiled_copy_r2s_dQ.get_slice(tidx) + cdQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim)) + if const_expr(self.arch in [80, 90]): + taccdQrdQ = thr_copy_r2s_dQ.retile(rdQ) + else: + taccdQcdQ_shape = thr_copy_r2s_dQ.partition_S(cdQ).shape + taccdQrdQ = cute.make_tensor(rdQ.iterator, taccdQcdQ_shape) + taccdQsdQ = thr_copy_r2s_dQ.partition_D(sdQ if const_expr(not self.dQ_swapAB) else sdQt) + cute.copy(thr_copy_r2s_dQ, taccdQrdQ, taccdQsdQ) # Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem + cute.arch.barrier() # make sure all smem stores are done gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_slice(tidx) tdQgdQ = gmem_thr_copy_dQ.partition_S(gdQ) tdQsdQ = gmem_thr_copy_dQ.partition_D(sdQ) tdQrdQ = cute.make_fragment_like(tdQsdQ, self.dtype) - cute.arch.barrier() # make sure all smem stores are done # TODO: check OOB when reading from smem if kBlockM isn't evenly tiled cute.autovec_copy(tdQsdQ, tdQrdQ) # Step 5: Copy dQ from register to gmem - cdQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim)) tdQcdQ = gmem_thr_copy_dQ.partition_S(cdQ) tdQpdQ = utils.predicate_k(tdQcdQ, limit=head_dim) for rest_m in cutlass.range(cute.size(tdQrdQ.shape[1]), unroll_full=True): diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index e02a05512e1..0945376ebf9 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -45,6 +45,7 @@ def __init__( is_persistent: bool = False, deterministic: bool = False, ): + assert qhead_per_kvhead == 1, "GQA is not supported yet in FlashAttentionBackwardSm100" # padding head_dim to a multiple of 16 as k_block_size hdim_multiple_of = 16 self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) @@ -308,10 +309,20 @@ def __call__( mdV: cute.Tensor, softmax_scale: Float32, stream: cuda.CUstream, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + softcap: Float32 | float | None = None, + window_size_left: Int32 | int | None = None, + window_size_right: Int32 | int | None = None, mdQ_semaphore: Optional[cute.Tensor] = None, mdK_semaphore: Optional[cute.Tensor] = None, mdV_semaphore: Optional[cute.Tensor] = None, ): + assert all(x is None for x in (mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)), ( + "Variable sequence length is not supported yet in FlashAttentionBackwardSm100" + ) self.q_dtype = mQ.element_type self.k_dtype = mK.element_type self.v_dtype = mV.element_type @@ -409,13 +420,13 @@ def __call__( val_layout_r2s_dKV = cute.make_ordered_layout( (1, 128 // self.dk_dtype.width), order=(1, 0) ) # 4 or 8 vals for 16 byte store - r2s_copy_atom_r2s_dKV = cute.make_copy_atom( + copy_atom_r2s_dKV = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), self.dk_dtype, num_bits_per_copy=128, ) tiled_copy_r2s_dKV = cute.make_tiled_copy_tv( - r2s_copy_atom_r2s_dKV, thr_layout_r2s_dKV, val_layout_r2s_dKV + copy_atom_r2s_dKV, thr_layout_r2s_dKV, val_layout_r2s_dKV ) tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) @@ -1856,8 +1867,8 @@ def tma_store_fn(src_idx, dst_idx): with cute.arch.elect_one(): copy_utils.cpasync_reduce_bulk_add_f32( sdQaccum[None, src_idx].iterator, - self.tma_copy_bytes["dQ"] // 1, gdQaccum_cur[None, dst_idx].iterator, + self.tma_copy_bytes["dQ"] // 1, ) cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(self.sdQaccum_stage - 1, read=read_flag) diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index bfb67824be0..59d4c2c4680 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -937,6 +937,8 @@ def mma( mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) mask_fn = partial( mask.apply_mask, + batch_idx=None, + head_idx=None, n_block=n_block, thr_mma=thr_mma_SdP, mask_seqlen=True, diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index b77a70d9211..c3fb3fa3c3b 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -38,6 +38,7 @@ from flash_attn.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80 from flash_attn.cute.flash_bwd_sm90 import FlashAttentionBackwardSm90 +from flash_attn.cute.flash_bwd_sm100 import FlashAttentionBackwardSm100 from flash_attn.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess from flash_attn.cute.flash_fwd_combine import FlashAttentionForwardCombine @@ -513,17 +514,26 @@ def _flash_attn_bwd( seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - m_block_size = 80 if not causal else 64 - n_block_size = 128 - num_stages_Q = 2 - num_stages_dO = 2 - num_stages_PdS = 2 - SdP_swapAB = True - dKV_swapAB = False - dQ_swapAB = not causal - AtomLayoutMSdP = 1 - AtomLayoutNdKV = 2 - AtomLayoutMdQ = 1 + compute_capability = torch.cuda.get_device_capability()[0] + assert compute_capability in [9, 10], "Unsupported compute capability. Supported: 9.x, 10.x" + + if compute_capability == 9: + m_block_size = 80 if not causal else 64 + n_block_size = 128 + num_stages_Q = 2 + num_stages_dO = 2 + num_stages_PdS = 2 + SdP_swapAB = True + dKV_swapAB = False + dQ_swapAB = not causal + AtomLayoutMSdP = 1 + AtomLayoutNdKV = 2 + AtomLayoutMdQ = 1 + else: + m_block_size = 128 + n_block_size = 128 + dQ_swapAB = False + AtomLayoutMdQ = 1 q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = [ maybe_contiguous(t) for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) @@ -723,73 +733,98 @@ def _flash_attn_bwd( ) # Backward kernel: compute dk, dv, dq_accum. - compile_key = ( - dtype, - head_dim, - head_dim_v, - qhead_per_kvhead, - causal, - softcap != 0.0, - m_block_size, - n_block_size, - num_threads, - pack_gqa, - num_stages_Q, - num_stages_dO, - SdP_swapAB, - dKV_swapAB, - dQ_swapAB, - AtomLayoutMSdP, - AtomLayoutNdKV, - AtomLayoutMdQ, - V_in_regs, - ) - num_threads = 384 - if compile_key not in _flash_attn_bwd.compile_cache: - fa_bwd_sm80 = FlashAttentionBackwardSm80( + if compute_capability == 9: + compile_key = ( + compute_capability, dtype, head_dim, head_dim_v, qhead_per_kvhead, + causal, + softcap != 0.0, m_block_size, n_block_size, - num_stages_Q, - num_stages_dO, num_threads, pack_gqa, - causal, + num_stages_Q, + num_stages_dO, SdP_swapAB, dKV_swapAB, dQ_swapAB, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, - V_in_regs=V_in_regs, + V_in_regs, ) - fa_bwd_sm90 = FlashAttentionBackwardSm90( + else: + compile_key = ( + compute_capability, dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, + softcap != 0.0, + m_block_size, + n_block_size, + num_threads, + pack_gqa, + ) + num_threads = 384 + if compile_key not in _flash_attn_bwd.compile_cache: + fa_bwd_sm80 = FlashAttentionBackwardSm80( + dtype, + head_dim, + head_dim_v, + qhead_per_kvhead, m_block_size, n_block_size, num_stages_Q, num_stages_dO, - num_stages_PdS, + num_threads, + pack_gqa, + causal, SdP_swapAB, dKV_swapAB, dQ_swapAB, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, - num_threads, V_in_regs=V_in_regs, ) + if compute_capability == 9: + fa_bwd_obj = FlashAttentionBackwardSm90( + dtype, + head_dim, + head_dim_v, + qhead_per_kvhead, + causal, + m_block_size, + n_block_size, + num_stages_Q, + num_stages_dO, + num_stages_PdS, + SdP_swapAB, + dKV_swapAB, + dQ_swapAB, + AtomLayoutMSdP, + AtomLayoutNdKV, + AtomLayoutMdQ, + num_threads, + V_in_regs=V_in_regs, + ) + else: + fa_bwd_obj = FlashAttentionBackwardSm100( + head_dim, + head_dim_v, + is_causal=causal, + qhead_per_kvhead=qhead_per_kvhead, + # tile_m=m_block_size, + # tile_n=n_block_size, + ) # TODO: check @can_implement _flash_attn_bwd.compile_cache[compile_key] = cute.compile( - # fa_bwd_sm80, - fa_bwd_sm90, + fa_bwd_obj, q_tensor, k_tensor, v_tensor, @@ -824,11 +859,11 @@ def _flash_attn_bwd( seqused_k_tensor, ) - num_threads -= 128 + num_threads = 256 if compute_capability == 9 else 128 # Postprocess kernel: convert dq_accum from float32 to dq in bf16/fp16 compile_key_post = (dtype, head_dim, m_block_size, num_threads, AtomLayoutMdQ, dQ_swapAB) if compile_key_post not in _flash_attn_bwd.compile_cache_post: - arch = 90 + arch = compute_capability * 10 fa_bwd_post = FlashAttentionBackwardPostprocess( dtype, head_dim, arch, m_block_size, num_threads, AtomLayoutMdQ, dQ_swapAB ) From 6eb7c8037b4eadd2134f4c2b10adf7a320242a8a Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 25 Oct 2025 15:37:56 -0400 Subject: [PATCH 189/258] [Cute,Bwd,Sm100] Simplify layouts in compute_loop --- flash_attn/cute/copy_utils.py | 17 ++++ flash_attn/cute/flash_bwd_sm100.py | 122 ++++++++++++++--------------- flash_attn/cute/mask.py | 10 +-- 3 files changed, 79 insertions(+), 70 deletions(-) diff --git a/flash_attn/cute/copy_utils.py b/flash_attn/cute/copy_utils.py index a97344768de..dd314bffa60 100644 --- a/flash_attn/cute/copy_utils.py +++ b/flash_attn/cute/copy_utils.py @@ -7,6 +7,7 @@ import cutlass.cute as cute from cutlass import Float32, Int32, Boolean, const_expr from cutlass.cute.nvgpu import cpasync +import cutlass.utils.blackwell_helpers as sm100_utils from cutlass.cutlass_dsl import dsl_user_op from cutlass._mlir.dialects import llvm import cutlass.pipeline @@ -47,6 +48,22 @@ def get_copy_atom( return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) +@dsl_user_op +def make_tmem_copy( + tmem_copy_atom: cute.CopyAtom, num_wg: int = 1, *, loc=None, ip=None +) -> cute.CopyAtom: + num_dp, num_bits, num_rep, _ = sm100_utils.get_tmem_copy_properties(tmem_copy_atom) + assert num_dp == 32 + assert num_bits == 32 + tiler_mn = (cute.make_layout((128 * num_rep * num_wg // 32, 32), stride=(32, 1)),) + layout_tv = cute.make_layout( + ((32, 4, num_wg), (num_rep, 32)), + stride=((0, 1, 4 * num_rep), (4, 4 * num_rep * num_wg)) + ) + return cute.make_tiled_copy(tmem_copy_atom, layout_tv, tiler_mn) + + + @dsl_user_op def copy( src: cute.Tensor, diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 0945376ebf9..6f2f75c2b89 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -1514,19 +1514,18 @@ def compute_loop( # tix: [128...384] 8 warps warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) # 4-11 - tidx = cute.arch.thread_idx()[0] + tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.compute_warp_ids)) dp_idx = tidx % 128 - wg_idx = (tidx % (cute.arch.WARP_SIZE * len(self.compute_warp_ids))) // 128 - wg_idx = cute.arch.make_warp_uniform(wg_idx) num_wg = len(self.compute_warp_ids) // 4 # 2 # wg_idx: # 0: [256...384] # 1: [128...256] tileP_f32_like = self.mma_tiler_kq[0] // 32 * self.v_dtype.width # (128, 64) - tStP = cute.composition(tStS, cute.make_layout((self.tile_m, tileP_f32_like))) + # tStS has shape ((128, 128), 1, 1), tStP has shape ((128, 64), 1, 1) + tStP = cute.composition(tStS, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1)) tScS = thr_mma_SdP.partition_C(cute.make_identity_tensor(self.mma_tiler_kq[:2])) - tScP = cute.composition(tScS, cute.make_layout((self.tile_m, tileP_f32_like))) + tScP = cute.composition(tScS, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1)) tmem_load_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32 @@ -1535,23 +1534,33 @@ def compute_loop( tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), Float32 ) - thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS).get_slice(dp_idx) - tStS_t2r_p = thr_tmem_load.partition_S(tStS) - tStS_t2r = self.split_wg(tStS_t2r_p, wg_idx, num_wg) - tdPtdP_t2r_p = thr_tmem_load.partition_S(tdPtdP) - tdPtdP_t2r = self.split_wg(tdPtdP_t2r_p, wg_idx, num_wg) - tScS_t2r_p = thr_tmem_load.partition_D(tScS) - tScS_t2r = self.split_wg(tScS_t2r_p, wg_idx, num_wg) - tSsLSE_p = thr_tmem_load.partition_D(thr_mma_SdP.partition_C(sLSE_2D)) - tSsLSE = self.split_wg(tSsLSE_p, wg_idx, num_wg) # ((32, 1), 2, 1, 1, STAGE) - tSsdPsum_p = thr_tmem_load.partition_D(thr_mma_SdP.partition_C(sdPsum_2D)) - tSsdPsum = self.split_wg(tSsdPsum_p, wg_idx, num_wg) - - thr_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP).get_slice(dp_idx) - tScP_r2t_p = thr_tmem_store.partition_S(tScP) - tScP_r2t = self.split_wg(tScP_r2t_p, wg_idx, num_wg) - tStP_r2t_p = thr_tmem_store.partition_D(tStP) - tStP_r2t = self.split_wg(tStP_r2t_p, wg_idx, num_wg) + # tmem -> rmem + thr_copy_t2r = copy_utils.make_tmem_copy(tmem_load_atom, num_wg).get_slice(tidx) + tStS_t2r = thr_copy_t2r.partition_S(tStS) # (((32, 32), 1), 2, 1, 1) + tdPtdP_t2r = thr_copy_t2r.partition_S(tdPtdP) + tScS_t2r = thr_copy_t2r.partition_D(tScS) # ((32, 1), 2, 1, 1) + # ((32, 1), 2, 1, 1, STAGE) + tSsLSE = thr_copy_t2r.partition_D(thr_mma_SdP.partition_C(sLSE_2D)) + tSsdPsum = thr_copy_t2r.partition_D(thr_mma_SdP.partition_C(sdPsum_2D)) + # rmem -> tmem + thr_copy_r2t = copy_utils.make_tmem_copy(tmem_store_atom, num_wg).get_slice(tidx) + tScP_r2t = thr_copy_r2t.partition_S(tScP) + tStP_r2t = thr_copy_r2t.partition_D(tStP) + # rmem -> smem + # This part is a bit iffy, we might be making a lot of assumptions here + copy_atom_r2s = sm100_utils_basic.get_smem_store_op( + LayoutEnum.ROW_MAJOR, self.ds_dtype, Float32, thr_copy_t2r + ) + thr_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, thr_copy_t2r).get_slice(tidx) + # We assume the swizzle (i.e. layout.inner) stays the same + sdS_layout = sm100_utils_basic.make_smem_layout_epi( + self.ds_dtype, LayoutEnum.ROW_MAJOR, (self.tile_n, self.tile_m), 1 + ).outer # ((8,16), (64,2), (1, 1)) + sdS_layout = cute.slice_(sdS_layout, (None, None, 0)) # ((8,16), (64,2)) + # Need to group into 1 mode to be compatible w thr_copy_r2s + sdS_layout = cute.make_layout((sdS_layout.shape,), stride=(sdS_layout.stride,)) + sdS_epi = cute.make_tensor(sdS.iterator, sdS_layout) + tRS_sdS = thr_copy_r2s.partition_D(sdS_epi) consumer_state_S_P_dP = pipeline.make_pipeline_state( # Our impl has shortcut for stage==1 cutlass.pipeline.PipelineUserType.Consumer, 1 @@ -1571,9 +1580,7 @@ def compute_loop( while work_tile.is_valid_tile: n_block, head_idx, batch_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) - m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) - mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) # TODO: condition mask_seqlen mask_fn = partial( @@ -1589,28 +1596,21 @@ def compute_loop( pipeline_S_P.consumer_wait(consumer_state_S_P_dP) # pipeline_S_P.sync_object_full.wait(0, consumer_phase_S_P_dP) #### TMEM->RMEM (Load S from TMEM) - tSrS_t2r = cute.make_fragment(tScS_t2r.shape, Float32) # 64 - cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r) + tSrS_t2r = cute.make_fragment(tScS_t2r.shape, Float32) + cute.copy(thr_copy_t2r, tStS_t2r, tSrS_t2r) cute.arch.mbarrier_wait(LSE_full_mbar_ptr, consumer_phase_LSE) consumer_phase_LSE ^= 1 #### APPLY MASK - if const_expr(self.is_causal or self.is_local): - mask_fn(tSrS_t2r, tScS_t2r, m_block=m_block) + mask_fn(tSrS_t2r, tScS_t2r, m_block=m_block) # --------------------------------------------- #### P = exp(S - LSE) # --------------------------------------------- - lane_idx = cute.arch.lane_idx() - - tSrP_r2t_f32 = cute.make_fragment(tScP_r2t[None, None, 0].shape, Float32) # 16 - tSrP_r2t = cute.make_tensor( - cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), - tSrS_t2r[None, 0, None, None].layout, - ) - - for stage in cutlass.range_constexpr(cute.size(tStP_r2t, mode=[2]), unroll=1): + tSrP_r2t_f32 = cute.make_fragment(tScP_r2t.shape, Float32) # 64 + tSrP_r2t = cute.recast_tensor(tSrP_r2t_f32, self.q_dtype) + for stage in cutlass.range_constexpr(cute.size(tSrS_t2r, mode=[1]), unroll=1): tSrS_cur = tSrS_t2r[None, stage, 0, 0] tSsLSE_cur = tSsLSE[None, stage, 0, 0, 0] # TODO: have stages if const_expr(not self.shuffle_LSE): @@ -1618,7 +1618,7 @@ def compute_loop( cute.autovec_copy(tSsLSE_cur, tSrLSE) else: tSrLSE = tSsLSE_cur[lane_idx] - for v in cutlass.range_constexpr(cute.size(tSrP_r2t) // 2, unroll_full=True): + for v in cutlass.range_constexpr(cute.size(tSrS_t2r, mode=[0]) // 2): if const_expr(not self.shuffle_LSE): lse_pair = (tSrLSE[2 * v], tSrLSE[2 * v + 1]) else: @@ -1633,13 +1633,17 @@ def compute_loop( ) tSrS_cur[2 * v] = cute.math.exp2(tSrS_cur[2 * v], fastmath=True) tSrS_cur[2 * v + 1] = cute.math.exp2(tSrS_cur[2 * v + 1], fastmath=True) - utils.cvt_f16(tSrS_cur, tSrP_r2t[None, 0, 0]) + utils.cvt_f16(tSrS_cur, tSrP_r2t[None, stage, 0, 0]) if const_expr(stage == 0): cute.arch.fence_view_async_tmem_load() # Without this barrier, we could have 1 warp writing to P in tmem while # another warp is still reading S from tmem. self.compute_sync_barrier.arrive_and_wait() - cute.copy(thr_tmem_store, tSrP_r2t_f32, tStP_r2t[None, None, stage]) + cute.copy( + thr_copy_r2t, + tSrP_r2t_f32[None, stage, None, None], + tStP_r2t[None, stage, None, None], + ) cute.arch.fence_view_async_tmem_store() @@ -1660,21 +1664,10 @@ def compute_loop( # pipeline_dP.sync_object_full.wait(0, consumer_phase_S_P_dP) pipeline_dS.producer_acquire(producer_state_dS) - #### TMEM->RMEM (Load dP from TMEM) - # ((32,1),1,1) - tdPrdP_t2r = cute.make_fragment(tScS_t2r[None, 0, None, None].shape, Float32) - ##### dS.T = P.T * (dP.T - Psum) - sdSt_mn = cute.composition(sdS, cute.make_layout((self.tile_m, self.tile_n))) - tdKsdS = cute.composition( - sdSt_mn[(None, wg_idx), dp_idx], cute.make_layout(tSrS_t2r.shape) - ) - tSrS_t2r_bf16 = cute.make_tensor( - cute.recast_ptr(tSrS_t2r.iterator, dtype=self.ds_dtype), tSrS_t2r.shape - ) - for stage in cutlass.range_constexpr(cute.size(tSrS_t2r, mode=[1]), unroll=1): - cute.copy(thr_tmem_load, tdPtdP_t2r[None, stage, None, None], tdPrdP_t2r) + tdPrdP_t2r = cute.make_fragment(tScS_t2r[None, 0, None, None].shape, Float32) + cute.copy(thr_copy_t2r, tdPtdP_t2r[None, stage, None, None], tdPrdP_t2r) cute.arch.fence_view_async_tmem_load() tdPrdP_cur = tdPrdP_t2r[None, 0, 0] tSrS_cur = tSrS_t2r[None, stage, 0, 0] @@ -1684,7 +1677,7 @@ def compute_loop( cute.autovec_copy(tSsdPsum_cur, tSrdPsum) else: tSrdPsum = tSsdPsum_cur[lane_idx] - for v in cutlass.range_constexpr(cute.size(tdPrdP_t2r) // 2, unroll=1): + for v in cutlass.range_constexpr(cute.size(tdPrdP_t2r, mode=[0]) // 2): if const_expr(not self.shuffle_dPsum): dPsum_pair = (tSrdPsum[2 * v], tSrdPsum[2 * v + 1]) else: @@ -1699,8 +1692,9 @@ def compute_loop( (tSrS_cur[2 * v], tSrS_cur[2 * v + 1]), (tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1]), ) - utils.cvt_f16(tdPrdP_cur, tSrS_t2r_bf16[None, stage, 0, 0]) - cute.autovec_copy(tSrS_t2r_bf16[None, stage, 0, 0], tdKsdS[None, stage, 0, 0]) + tdPrdP_cvt = cute.make_fragment_like(tdPrdP_cur, self.ds_dtype) + utils.cvt_f16(tdPrdP_cur, tdPrdP_cvt) + cute.autovec_copy(tdPrdP_cvt, tRS_sdS[None, stage]) cute.arch.sync_warp() with cute.arch.elect_one(): @@ -1798,10 +1792,10 @@ def dQacc_reduce( tmem_load_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(self.dQ_reduce_ncol)), Float32 ) - thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tdQtdQ).get_slice(tidx) - tdQtdQ_t2r = thr_tmem_load.partition_S(tdQtdQ) + thr_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tdQtdQ).get_slice(tidx) + tdQtdQ_t2r = thr_copy_t2r.partition_S(tdQtdQ) tdQcdQ = thr_mma_dQ.partition_C(cute.make_identity_tensor(self.mma_tiler_dsk[:2])) - tdQrdQ_t2r_shape = thr_tmem_load.partition_D(tdQcdQ).shape + tdQrdQ_t2r_shape = thr_copy_t2r.partition_D(tdQcdQ).shape assert cute.size(tdQrdQ_t2r_shape, mode=[1]) == self.dQaccum_reduce_stage, ( "dQaccum reduce stage mismatch" ) @@ -1839,7 +1833,7 @@ def dQacc_reduce( pipeline_dQ.consumer_wait(dQ_consumer_state) # TMEM -> RMEM tdQrdQ_t2r = cute.make_fragment(tdQrdQ_t2r_shape, Float32) - cute.copy(thr_tmem_load, tdQtdQ_t2r, tdQrdQ_t2r) + cute.copy(thr_copy_t2r, tdQtdQ_t2r, tdQrdQ_t2r) cute.arch.fence_view_async_tmem_load() cute.arch.sync_warp() with cute.arch.elect_one(): @@ -2123,15 +2117,15 @@ def epilogue_dK_or_dV_tma( for s in cutlass.range_constexpr(num_epi_stages): # TMEM -> RMEM -- setup - thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tdKVtdKV).get_slice(tidx) - tdKVtdKV_t2r_p = thr_tmem_load.partition_S(tdKVtdKV) + thr_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tdKVtdKV).get_slice(tidx) + tdKVtdKV_t2r_p = thr_copy_t2r.partition_S(tdKVtdKV) tdKVtdKV_t2r = self.split_wg(tdKVtdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0] if const_expr(num_epi_stages > 1): tdKVtdKV_t2r = tdKVtdKV_t2r[None, s] cdKV = cute.make_identity_tensor((self.tile_n, self.tile_hdim)) tdKVcdKV = thr_mma.partition_C(cdKV) - tdKVcdKV_t2r_p = thr_tmem_load.partition_D(tdKVcdKV) + tdKVcdKV_t2r_p = thr_copy_t2r.partition_D(tdKVcdKV) tdKVcdKV_t2r = self.split_wg(tdKVcdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0] if const_expr(num_epi_stages > 1): tdKVcdKV_t2r = tdKVcdKV_t2r[None, s] @@ -2143,7 +2137,7 @@ def epilogue_dK_or_dV_tma( ) # TMEM -> RMEM -- copy and fence - cute.copy(thr_tmem_load, tdKVtdKV_t2r, tdKVrdKV_t2r) + cute.copy(thr_copy_t2r, tdKVtdKV_t2r, tdKVrdKV_t2r) cute.arch.fence_view_async_tmem_load() # RMEM -- scale and convert diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index 7b830f42c4e..fabc251bb8f 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -260,16 +260,16 @@ def apply_mask_sm100( mask_local: cutlass.Constexpr[bool] = False, ) -> None: assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" - cS = cute.make_identity_tensor((self.tile_m, self.tile_n)) + acc_shape = (self.tile_m, self.tile_n) + cS = cute.make_identity_tensor(acc_shape if not self.swap_AB else acc_shape[::-1]) tScS = thr_mma.partition_C(cS) tScS_t2r = thr_tmem_load.partition_D(tScS) seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n r2p = True if const_expr(not mask_causal and not mask_local): if const_expr(mask_seqlen): - ncol = const_expr(cute.size(tScS_t2r.shape)) if const_expr(not r2p): - for i in cutlass.range(ncol, unroll_full=True): + for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True): # if tScS_t2r[i][1] >= seqlenk_col_limit: # acc_S[i] = -Float32.inf # For some reason the 2 lines above generate really bad SASS @@ -331,8 +331,6 @@ def apply_mask_sm100_transposed( tScS_t2r: cute.Tensor, m_block: cutlass.Int32, n_block: cutlass.Int32, - wg_idx: cutlass.Int32, - num_wg: cutlass.Constexpr[cutlass.Int32], mask_seqlen: cutlass.Constexpr, mask_causal: cutlass.Constexpr, mask_local: cutlass.Constexpr, @@ -358,7 +356,7 @@ def apply_mask_sm100_transposed( if const_expr(mask_causal): col_limit_left = row_idx + causal_row_offset ncol = const_expr(cute.size(tScS_t2r.shape)) - # if tidx == 32 and wg_idx == 1: + # if tidx == 32: # cute.printf("row idx = {}, causal_row_offset = {}, col_limit_left = {}, first column = {}, last column = {} ", row_idx, causal_row_offset, col_limit_left, tScS_t2r[0][1], tScS_t2r[ncol - 1][1]) if const_expr(mask_seqlen): if tScS_t2r[0][0] >= seqlenk_row_limit: From 93a0afeb816f194c862a1b3a5c586ed52b15d675 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 25 Oct 2025 17:47:01 -0400 Subject: [PATCH 190/258] [Cute,Bwd,Sm100] Causal mask --- benchmarks/benchmark_attn.py | 1 + flash_attn/cute/flash_bwd_sm100.py | 10 ++-- flash_attn/cute/mask.py | 83 ++++++++++++++++++++++-------- 3 files changed, 68 insertions(+), 26 deletions(-) diff --git a/benchmarks/benchmark_attn.py b/benchmarks/benchmark_attn.py index 7830477a68a..511019265d1 100644 --- a/benchmarks/benchmark_attn.py +++ b/benchmarks/benchmark_attn.py @@ -183,6 +183,7 @@ def cudnn_spda_bwd_setup(q, k, v, o, g, lse, causal=False, window_size_left=None # use_causal_mask_bottom_right=causal or window_size_left is not None, use_causal_mask=causal or window_size_left is not None, sliding_window_length=window_size_left if window_size_left is not None and not causal else None, + use_deterministic_algorithm=False, ) dq.set_output(True).set_dim(dq_gpu.shape).set_stride(dq_gpu.stride()) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 6f2f75c2b89..6b9378f4cd0 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -147,7 +147,7 @@ def __init__( self.num_regs_reduce = 160 self.num_regs_compute = 128 - self.num_regs_other = 80 + self.num_regs_other = 96 self.num_regs_empty = 24 assert self.num_regs_reduce + self.num_regs_compute * 2 + self.num_regs_other <= 512 @@ -846,6 +846,7 @@ def kernel( AttentionMask, self.tile_m, self.tile_n, + swap_AB=True, ) cute.arch.sync_threads() @@ -960,7 +961,6 @@ def kernel( tdKtdK, mdV, mdK, - sdSt, sdS, tdPtdP, LSE_full_mbar_ptr, @@ -1466,7 +1466,6 @@ def compute_loop( tdKtdK: cute.Tensor, mdV: cute.Tensor, mdK: cute.Tensor, - sdSt: cute.Tensor, sdS: cute.Tensor, tdPtdP: cute.Tensor, LSE_full_mbar_ptr: cute.Pointer, @@ -1539,6 +1538,7 @@ def compute_loop( tStS_t2r = thr_copy_t2r.partition_S(tStS) # (((32, 32), 1), 2, 1, 1) tdPtdP_t2r = thr_copy_t2r.partition_S(tdPtdP) tScS_t2r = thr_copy_t2r.partition_D(tScS) # ((32, 1), 2, 1, 1) + t0ScS_t2r = thr_copy_t2r.get_slice(0).partition_D(tScS) # ((32, 1), 2, 1, 1) # ((32, 1), 2, 1, 1, STAGE) tSsLSE = thr_copy_t2r.partition_D(thr_mma_SdP.partition_C(sLSE_2D)) tSsdPsum = thr_copy_t2r.partition_D(thr_mma_SdP.partition_C(sdPsum_2D)) @@ -1585,6 +1585,8 @@ def compute_loop( # TODO: condition mask_seqlen mask_fn = partial( mask.apply_mask_sm100_transposed, + tScS_t2r=tScS_t2r, + t0ScS_t2r=t0ScS_t2r, n_block=n_block, mask_seqlen=True, mask_causal=self.is_causal, @@ -1602,7 +1604,7 @@ def compute_loop( consumer_phase_LSE ^= 1 #### APPLY MASK - mask_fn(tSrS_t2r, tScS_t2r, m_block=m_block) + mask_fn(tSrS_t2r, m_block=m_block) # --------------------------------------------- #### P = exp(S - LSE) diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index fabc251bb8f..2d65856d223 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -40,6 +40,33 @@ def mask_r2p(X: cute.Tensor, col_limit: Int32, arch: int = 90, rank1: bool = Fal X[r, c] = X[r, c] if in_bound else -Float32.inf +@cute.jit +def mask_r2p_transposed(X: cute.Tensor, row_limit_top: Int32, num_rep: int) -> None: + # Bit manipulation, compiles down to the R2P instruction + # For sm100: we know that tScS_t2r[i][0] has the form 0, 1, ..., 31, 64, ..., 127 + # or 0, 1, ..., 15, 32, ..., 47, 64, ... + # We compare a transformed version of limit to 0, 1, 2, 3, 4, 5, ... + # Here we hardcode for the case of 2 warp groups. + num_wg = 2 + row_limit_top_transformed = row_limit_top // (num_rep * num_wg) * num_rep + min( + row_limit_top % (num_rep * num_wg), num_rep + ) + ncol = cute.size(X.shape) + # Ideally we'd move by 32 instead of 24, but mask >> i isn't correct for i == 31 + for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)): + row_limit_top_s = max(row_limit_top_transformed - s * 24, 0) + # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11 + mask = (1 << row_limit_top_s) - 1 + # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction + for i in cutlass.range_constexpr(min(24, ncol - s * 24)): + out_bound = cutlass.Boolean(mask & (1 << i)) + c = s * 24 + i + X[c] = -Float32.inf if out_bound else X[c] + # tidx = cute.arch.thread_idx()[0] % 256 + # if tidx == 128: + # cute.printf("tidx = {}, s = {}, i = {}, row_limit_top = {}, row_limit_top_s = {}, mask = {}, out_bound = {}", tidx, s, i, row_limit_top, row_limit_top_s, mask, out_bound) + + @dataclass(frozen=True) class AttentionMask: tile_m: cutlass.Constexpr[int] @@ -219,7 +246,9 @@ def apply_mask( # If col0 is beyond the column limit, we want to mask out the entire # column, by setting row limit to be self.tile_m. row_limit_top = ( - self.tile_m if col0 >= seqlenk_col_limit else col0 - causal_row_offset + self.tile_m + if col0 >= seqlenk_col_limit and mask_seqlen + else col0 - causal_row_offset ) for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): acc_S_mn[r, c] = ( @@ -329,6 +358,7 @@ def apply_mask_sm100_transposed( self, acc_S: cute.Tensor, tScS_t2r: cute.Tensor, + t0ScS_t2r: cute.Tensor, m_block: cutlass.Int32, n_block: cutlass.Int32, mask_seqlen: cutlass.Constexpr, @@ -339,30 +369,39 @@ def apply_mask_sm100_transposed( Backward pass: mask S = K @ Q.T where n_block tiles seqlen_k and m_block tiles seqlen_q. """ assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" - - tidx = cute.arch.thread_idx()[0] % 128 - - seqlenk_row_limit = self.seqlen_k - n_block * self.tile_n + ROW = 0 if const_expr(not self.swap_AB) else 1 + COL = 1 if const_expr(not self.swap_AB) else 0 + thr_col_offset = tScS_t2r[0][COL] + seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset if const_expr(not mask_causal and not mask_local): if const_expr(mask_seqlen): - ncol = const_expr(cute.size(tScS_t2r.shape)) - if tScS_t2r[0][0] >= seqlenk_row_limit: - for i in cutlass.range(ncol, unroll_full=True): + if t0ScS_t2r[0][COL] >= seqlenk_col_limit: + for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): acc_S[i] = -cutlass.Float32.inf else: # Causal or local - causal_row_offset = (self.seqlen_q - self.seqlen_k - 1) - m_block * self.tile_m - row_idx = tScS_t2r[0][0] + n_block * self.tile_n - + thr_row_offset = tScS_t2r[0][ROW] + causal_row_offset = ( + seqlenk_col_limit - self.seqlen_q + m_block * self.tile_m + thr_row_offset + ) if const_expr(mask_causal): - col_limit_left = row_idx + causal_row_offset - ncol = const_expr(cute.size(tScS_t2r.shape)) - # if tidx == 32: - # cute.printf("row idx = {}, causal_row_offset = {}, col_limit_left = {}, first column = {}, last column = {} ", row_idx, causal_row_offset, col_limit_left, tScS_t2r[0][1], tScS_t2r[ncol - 1][1]) + col0 = t0ScS_t2r[0][COL] + row_limit_top = col0 - causal_row_offset + # tidx = cute.arch.thread_idx()[0] % 256 + # if tidx < 32: + # cute.printf("tidx = {}, {} {}, {} {}, col0 = {}", tidx, tScS_t2r[0][0], tScS_t2r[0][1], tScS_t2r[1][0], tScS_t2r[1][1], col0) if const_expr(mask_seqlen): - if tScS_t2r[0][0] >= seqlenk_row_limit: - col_limit_left = self.tile_m - for i in cutlass.range(ncol, unroll_full=True): - acc_S[i] = ( - -cutlass.Float32.inf if tScS_t2r[i][1] <= col_limit_left else acc_S[i] - ) - # TODO: local + # If col is beyond the column limit, we want to mask out the entire + # column, by setting row limit to be self.tile_m. + if t0ScS_t2r[0][COL] >= seqlenk_col_limit: + row_limit_top = self.tile_m + r2p = True + if const_expr(not r2p): + for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): + acc_S[i] = ( + -cutlass.Float32.inf if t0ScS_t2r[i][ROW] < row_limit_top else acc_S[i] + ) + else: + num_rep = cute.size(tScS_t2r, mode=[0]) # 16 or 32 + mask_r2p_transposed(acc_S, row_limit_top, num_rep) + else: + assert False, "Local masking isn't supported yet" From 662cf9c5b5df78d02c780608da7603901732954f Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 25 Oct 2025 17:50:14 -0400 Subject: [PATCH 191/258] [Cute,Bwd,Sm100] Enable bwd tests --- tests/cute/test_flash_attn.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 6c3a679a613..7dc132e4f7e 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -29,18 +29,18 @@ # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) -# @pytest.mark.parametrize("mha_type", ["mha"]) -@pytest.mark.parametrize("has_learnable_sink", [False, True]) -# @pytest.mark.parametrize("has_learnable_sink", [False]) +# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +@pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("has_learnable_sink", [False, True]) +@pytest.mark.parametrize("has_learnable_sink", [False]) # @pytest.mark.parametrize("has_qv", [False, True]) @pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [False]) # @pytest.mark.parametrize("softcap", [0.0, 15.0]) @pytest.mark.parametrize("softcap", [0.0]) -@pytest.mark.parametrize("local", [False, True]) -# @pytest.mark.parametrize("local", [False]) +# @pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [True]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) @@ -51,8 +51,8 @@ # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128, 192]) # @pytest.mark.parametrize("d", [64, 128]) -@pytest.mark.parametrize("d", [128, 192]) -# @pytest.mark.parametrize("d", [128]) +# @pytest.mark.parametrize("d", [128, 192]) +@pytest.mark.parametrize("d", [128]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -60,6 +60,7 @@ (3, 3), (64, 32), (64, 128), + (128, 128), (128, 192), (256, 256), (239, 1), @@ -76,6 +77,7 @@ (1024, 1024), (1023, 1024), (1024, 1023), + (2048, 2048), (4096, 4096), (4224, 4224), ], @@ -219,7 +221,8 @@ def test_flash_attn_output( print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") # num_splits_vals = [1, 3] - pack_gqa_vals = [False, True, None] + # pack_gqa_vals = [False, True, None] + pack_gqa_vals = [False] num_splits_vals = [1] for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): out, lse = flash_attn_func( @@ -257,7 +260,7 @@ def test_flash_attn_output( and not local and dv == d and learnable_sink is None - and False + # and False ): g = torch.randn_like(out) # do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) @@ -272,6 +275,7 @@ def test_flash_attn_output( # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + # breakpoint() # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) dq_ref, dk_ref, dv_ref = torch.autograd.grad( From 79b9030c14ee30091342c1d7abe260b2f594a788 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 25 Oct 2025 17:51:31 -0400 Subject: [PATCH 192/258] [Cute,Bwd] Enable bwd benchmarks --- benchmarks/benchmark_attn.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/benchmarks/benchmark_attn.py b/benchmarks/benchmark_attn.py index 511019265d1..5b3de776ec0 100644 --- a/benchmarks/benchmark_attn.py +++ b/benchmarks/benchmark_attn.py @@ -227,7 +227,7 @@ def run(*args, **kwargs): device = 'cuda' verbose = True varlen = False -has_backward = False +has_backward = True page_size = None # page_size = 128 softcap = 0.0 @@ -244,10 +244,10 @@ def run(*args, **kwargs): headdim = 256 # for headdim in [64, 128, 256]: # bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)] -bs_seqlen_vals = [(32, 1024), (16, 2048), (8, 4096), (4, 8192), (2, 16384), (1, 32768)] +# bs_seqlen_vals = [(32, 1024), (16, 2048), (8, 4096), (4, 8192), (2, 16384), (1, 32768)] # bs_seqlen_vals = [(32, 512), (16, 1024)] # bs_seqlen_vals = [(2, 64 * 132)] -# bs_seqlen_vals = [(4, 8192)] +bs_seqlen_vals = [(4, 8192)] # bs_seqlen_vals = [(1, 16 * 1024)] time_f = {} time_b = {} @@ -267,8 +267,8 @@ def run(*args, **kwargs): # seqlen = 512 # nheads = 8 # headdim = 128 - # nheads_kv = nheads - nheads_kv = nheads // 8 + nheads_kv = nheads + # nheads_kv = nheads // 8 # nheads_kv = 1 # headdim_v = headdim headdim_v = 128 if headdim == 192 else headdim @@ -383,7 +383,7 @@ def run(*args, **kwargs): _, m1b = benchmark_backward(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic, repeats=repeats, verbose=False, desc='Fav3') time_b[(causal, headdim, batch_size, seqlen), "Flash3"] = m1b.mean - # time.sleep(1) + time.sleep(1) # if not varlen: # pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, deterministic=deterministic, backward=True) # else: From 510fe92da31e1f702ad8fc2036368041f0730d5f Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 25 Oct 2025 20:44:14 -0400 Subject: [PATCH 193/258] [Cute] Add store_shared_remote_fp32x4 util function --- flash_attn/cute/copy_utils.py | 70 +++++++++++++++++++++++++++++++++-- 1 file changed, 66 insertions(+), 4 deletions(-) diff --git a/flash_attn/cute/copy_utils.py b/flash_attn/cute/copy_utils.py index dd314bffa60..45ec493aaa3 100644 --- a/flash_attn/cute/copy_utils.py +++ b/flash_attn/cute/copy_utils.py @@ -8,7 +8,7 @@ from cutlass import Float32, Int32, Boolean, const_expr from cutlass.cute.nvgpu import cpasync import cutlass.utils.blackwell_helpers as sm100_utils -from cutlass.cutlass_dsl import dsl_user_op +from cutlass.cutlass_dsl import T, dsl_user_op from cutlass._mlir.dialects import llvm import cutlass.pipeline @@ -57,13 +57,11 @@ def make_tmem_copy( assert num_bits == 32 tiler_mn = (cute.make_layout((128 * num_rep * num_wg // 32, 32), stride=(32, 1)),) layout_tv = cute.make_layout( - ((32, 4, num_wg), (num_rep, 32)), - stride=((0, 1, 4 * num_rep), (4, 4 * num_rep * num_wg)) + ((32, 4, num_wg), (num_rep, 32)), stride=((0, 1, 4 * num_rep), (4, 4 * num_rep * num_wg)) ) return cute.make_tiled_copy(tmem_copy_atom, layout_tv, tiler_mn) - @dsl_user_op def copy( src: cute.Tensor, @@ -145,6 +143,70 @@ def atomic_add_fp32x4( asm_dialect=llvm.AsmDialect.AD_ATT, ) + +@dsl_user_op +def set_block_rank( + smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None +) -> Int32: + """Map the given smem pointer to the address at another CTA rank in the cluster.""" + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + return Int32( + llvm.inline_asm( + T.i32(), + [smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()], + "mapa.shared::cluster.u32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def store_shared_remote_fp32x4( + a: Float32, + b: Float32, + c: Float32, + d: Float32, + smem_ptr: cute.Pointer, + mbar_ptr: cute.Pointer, + peer_cta_rank_in_cluster: Int32, + *, + loc=None, + ip=None, +) -> None: + remote_smem_ptr_i32 = set_block_rank( + smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + remote_mbar_ptr_i32 = set_block_rank( + mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + llvm.inline_asm( + None, + [ + remote_smem_ptr_i32, + remote_mbar_ptr_i32, + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + Float32(c).ir_value(loc=loc, ip=ip), + Float32(d).ir_value(loc=loc, ip=ip), + ], + "{\n\t" + ".reg .v4 .f32 abcd;\n\t" + "mov.f32 abcd.x, $2;\n\t" + "mov.f32 abcd.y, $3;\n\t" + "mov.f32 abcd.z, $4;\n\t" + "mov.f32 abcd.w, $5;\n\t" + "st.async.shared::cluster.mbarrier::complete_tx::bytes.v4.f32 [$0], abcd, [$1];\n\t" + "}\n", + "r,r,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + @dsl_user_op def cpasync_bulk_g2s( gmem_ptr: cute.Pointer, From b634499757f12f206c9ea9ca0d4349855bf5efe8 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 25 Oct 2025 22:21:08 -0400 Subject: [PATCH 194/258] [Cute,Bwd,Sm100] Tune registers --- flash_attn/cute/flash_bwd_sm100.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 6b9378f4cd0..357c2a469d9 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -145,9 +145,13 @@ def __init__( self.tmem_dQ_offset = self.tmem_dP_offset # overlap with dP self.tmem_dK_offset = self.tmem_dP_offset + self.tile_m - self.num_regs_reduce = 160 - self.num_regs_compute = 128 - self.num_regs_other = 96 + if not is_causal and not is_local: + self.num_regs_reduce = 152 + self.num_regs_compute = 136 + else: + self.num_regs_reduce = 136 + self.num_regs_compute = 144 + self.num_regs_other = 96 - 8 self.num_regs_empty = 24 assert self.num_regs_reduce + self.num_regs_compute * 2 + self.num_regs_other <= 512 @@ -849,8 +853,6 @@ def kernel( swap_AB=True, ) - cute.arch.sync_threads() - # EMPTY # (15) if warp_idx == self.empty_warp_id: @@ -949,7 +951,7 @@ def kernel( # Compute # (4, 5, 6, 7, 8, 9, 10, 11) --> 8 warps if warp_idx >= self.compute_warp_ids[0] and warp_idx <= self.compute_warp_ids[-1]: - cute.arch.warpgroup_reg_dealloc(self.num_regs_compute) # 8 warps + cute.arch.warpgroup_reg_alloc(self.num_regs_compute) # 8 warps self.compute_loop( thr_mma_SdP, thr_mma_dV, @@ -1664,7 +1666,6 @@ def compute_loop( pipeline_dP.consumer_wait(consumer_state_S_P_dP) # pipeline_dP.sync_object_full.wait(0, consumer_phase_S_P_dP) - pipeline_dS.producer_acquire(producer_state_dS) ##### dS.T = P.T * (dP.T - Psum) for stage in cutlass.range_constexpr(cute.size(tSrS_t2r, mode=[1]), unroll=1): @@ -1696,6 +1697,8 @@ def compute_loop( ) tdPrdP_cvt = cute.make_fragment_like(tdPrdP_cur, self.ds_dtype) utils.cvt_f16(tdPrdP_cur, tdPrdP_cvt) + if const_expr(stage == 0): + pipeline_dS.producer_acquire(producer_state_dS) cute.autovec_copy(tdPrdP_cvt, tRS_sdS[None, stage]) cute.arch.sync_warp() From e873ad00fb10bab2e300c9a342bd6639612cac10 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 25 Oct 2025 22:55:45 -0400 Subject: [PATCH 195/258] [Cute,Sm100] acc_tmem_addr is Int32 instead of constexpr --- flash_attn/cute/blackwell_helpers.py | 17 +++++++++++------ flash_attn/cute/flash_bwd_sm100.py | 13 ++++++++----- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/flash_attn/cute/blackwell_helpers.py b/flash_attn/cute/blackwell_helpers.py index f3335b3923e..1cac21f8f38 100644 --- a/flash_attn/cute/blackwell_helpers.py +++ b/flash_attn/cute/blackwell_helpers.py @@ -358,7 +358,7 @@ def gemm_ptx_loop( @cute.jit def gemm_ptx_partial( op: cute.nvgpu.tcgen05.mma.MmaOp, - acc_tmem_addr: cutlass.Constexpr[int], + acc_tmem_addr: Int32, tCrA: cute.Tensor, tCrB: cute.Tensor, sA: Optional[cute.Tensor], @@ -433,6 +433,7 @@ def gemm_ptx_partial( Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(), ], "{\n\t" ".reg .pred leader_thread;\n\t" @@ -445,7 +446,8 @@ def gemm_ptx_partial( ".reg .b64 smem_desc_a, smem_desc_b;\n\t" "elect.sync _|leader_thread, -1;\n\t" f"mov.b32 idesc, {hex(idesc)};\n\t" - f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_acc, $3;\n\t" "mov.b32 smem_desc_a_lo_start, $0;\n\t" "mov.b32 smem_desc_b_lo_start, $1;\n\t" f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" @@ -467,7 +469,8 @@ def gemm_ptx_partial( for k in range(1, cute.size(tCrA.shape[2])) ) + "}\n", - "r,r,r", + # "r,r,r", + "r,r,r,r", has_side_effects=True, is_align_stack=False, asm_dialect=llvm.AsmDialect.AD_ATT, @@ -477,6 +480,7 @@ def gemm_ptx_partial( Int32(cute.arch.make_warp_uniform(tCrA[None, None, 0].iterator.toint())).ir_value(), Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(), ] if const_expr(mbar_ptr is not None): assert mbar_phase is not None, "mbar_phase must be provided when mbar_ptr is not None" @@ -485,7 +489,7 @@ def gemm_ptx_partial( mbar_wait_str = ( ".reg .pred P1; \n\t" "LAB_WAIT: \n\t" - "mbarrier.try_wait.parity.shared::cta.b64 P1, [$3], $4, 10000000; \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [$4], $5, 10000000; \n\t" "@P1 bra DONE; \n\t" "bra LAB_WAIT; \n\t" "DONE: \n\t" @@ -513,7 +517,8 @@ def gemm_ptx_partial( ".reg .b64 smem_desc_b;\n\t" "elect.sync _|leader_thread, -1;\n\t" f"mov.b32 idesc, {hex(idesc)};\n\t" - f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_acc, $3;\n\t" f"mov.b32 tmem_a, $0;\n\t" f"mov.b32 smem_desc_b_lo_start, $1;\n\t" f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" @@ -550,7 +555,7 @@ def gemm_ptx_partial( else "" ) + "}\n", - "r,r,r" if const_expr(mbar_ptr is None) else "r,r,r,r,r", + "r,r,r,r" if const_expr(mbar_ptr is None) else "r,r,r,r,r,r", has_side_effects=True, is_align_stack=False, asm_dialect=llvm.AsmDialect.AD_ATT, diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 357c2a469d9..9f49a98aa20 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -797,33 +797,36 @@ def kernel( sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout) # TMEM + # This is a fake tensor, by right need to retrieve tmem_ptr. But we know that we always + # request 512 columns of tmem, so we know that it starts at 0. + tmem_ptr = cute.make_ptr(Float32, 0, mem_space=cute.AddressSpace.tmem, assumed_align=16) # S thr_mma_SdP = tiled_mma_SdP.get_slice(0) Sacc_shape = thr_mma_SdP.partition_shape_C(self.mma_tiler_kq[:2]) # (M, N) tStS = thr_mma_SdP.make_fragment_C(Sacc_shape) # (MMA, MMA_M, MMA_N) - tStS = cute.make_tensor(tStS.iterator + self.tmem_S_offset, tStS.layout) + tStS = cute.make_tensor(tmem_ptr + self.tmem_S_offset, tStS.layout) # dP dPacc_shape = thr_mma_SdP.partition_shape_C(self.mma_tiler_vdo[:2]) tdPtdP = thr_mma_SdP.make_fragment_C(dPacc_shape) - tdPtdP = cute.make_tensor(tdPtdP.iterator + self.tmem_dP_offset, tdPtdP.layout) + tdPtdP = cute.make_tensor(tmem_ptr + self.tmem_dP_offset, tdPtdP.layout) # dV thr_mma_dV = tiled_mma_dV.get_slice(0) dvacc_shape = thr_mma_dV.partition_shape_C(self.mma_tiler_pdo[:2]) tdVtdV = thr_mma_dV.make_fragment_C(dvacc_shape) - tdVtdV = cute.make_tensor(tdVtdV.iterator + self.tmem_dV_offset, tdVtdV.layout) + tdVtdV = cute.make_tensor(tmem_ptr + self.tmem_dV_offset, tdVtdV.layout) tP_ptr = cute.make_ptr(self.do_dtype, self.tmem_P_offset, cute.AddressSpace.tmem) tP = cute.make_tensor(tP_ptr, tP_layout.outer) # dK thr_mma_dK = tiled_mma_dK.get_slice(0) dkacc_shape = thr_mma_dK.partition_shape_C(self.mma_tiler_dsq[:2]) tdKtdK = thr_mma_dK.make_fragment_C(dkacc_shape) - tdKtdK = cute.make_tensor(tdKtdK.iterator + self.tmem_dK_offset, tdKtdK.layout) + tdKtdK = cute.make_tensor(tmem_ptr + self.tmem_dK_offset, tdKtdK.layout) # dQ thr_mma_dQ = tiled_mma_dQ.get_slice(0) dQacc_shape = thr_mma_dQ.partition_shape_C(self.mma_tiler_dsk[:2]) tdQtdQ = thr_mma_dQ.make_fragment_C(dQacc_shape) - tdQtdQ = cute.make_tensor(tdQtdQ.iterator + self.tmem_dQ_offset, tdQtdQ.layout) + tdQtdQ = cute.make_tensor(tmem_ptr + self.tmem_dQ_offset, tdQtdQ.layout) block_info = BlockInfo( self.tile_m, From 2c7177d0b0d1f1c6d195e42c5c7afc9df210e0ae Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 26 Oct 2025 00:05:44 -0400 Subject: [PATCH 196/258] [Cute,Bwd,Sm100] Reduce sync --- flash_attn/cute/flash_bwd_sm100.py | 72 +++++++++--------------------- 1 file changed, 21 insertions(+), 51 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 9f49a98aa20..b7961feda06 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -1247,21 +1247,10 @@ def mma( consumer_state_dO = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage ) - # producer_state_S_P = cutlass.pipeline.make_pipeline_state( - # cutlass.pipeline.PipelineUserType.Producer, 1 - # ) - producer_phase_S_P = Int32(1) - # producer_state_dP = cutlass.pipeline.make_pipeline_state( - # cutlass.pipeline.PipelineUserType.Producer, 1 - # ) - producer_phase_dP = Int32(1) + producer_phase_acc = Int32(1) # For S & P, dP, dQ consumer_state_dS = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, 1 ) - # producer_state_dQ = cutlass.pipeline.make_pipeline_state( - # cutlass.pipeline.PipelineUserType.Producer, 1 - # ) - producer_phase_dQ = Int32(1) # producer_state_dKV = cutlass.pipeline.make_pipeline_state( # cutlass.pipeline.PipelineUserType.Producer, 2 # ) @@ -1285,32 +1274,24 @@ def mma( # 1) S = Q0 @ K.T handle_Q = pipeline_Q_consumer.wait_and_advance() - # pipeline_S_P.producer_acquire(producer_state_S_P) - pipeline_S_P.sync_object_empty.wait(0, producer_phase_S_P) + pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc) mma_qk_fn(B_idx=handle_Q.index) # Don't release Q yet - # pipeline_S_P.producer_commit(producer_state_S_P) pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) - # producer_state_S_P.advance() - producer_phase_S_P ^= 1 # 2) dP = V @ dO.T pipeline_dO.consumer_wait(consumer_state_dO) - # pipeline_dP.producer_acquire(producer_state_dP) - pipeline_dP.sync_object_empty.wait(0, producer_phase_dP) - # pipeline_dQ.producer_acquire(producer_state_dQ) # dQ uses the same tmem as dP - pipeline_dQ.sync_object_empty.wait(0, producer_phase_dQ) + pipeline_dP.sync_object_empty.wait(0, producer_phase_acc) + # dQ uses the same tmem as dP + pipeline_dQ.sync_object_empty.wait(0, producer_phase_acc) mma_dov_fn(B_idx=consumer_state_dO.index) # Don't release dO yet - # pipeline_dP.producer_commit(producer_state_dP) pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group) - # producer_state_dP.advance() - producer_phase_dP ^= 1 + producer_phase_acc ^= 1 # 3) dV = P.T @ dO # wait for P to be ready, which uses the same tmem as S - # pipeline_S_P.producer_acquire(producer_state_S_P) - pipeline_S_P.sync_object_empty.wait(0, producer_phase_S_P) + pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc) mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=True) pipeline_dO.consumer_release(consumer_state_dO) consumer_state_dO.advance() @@ -1328,20 +1309,15 @@ def mma( handle_Q_next = pipeline_Q_consumer.wait_and_advance() # Don't need to wait for S, as P must have been ready ealier, i.e., S is ready mma_qk_fn(B_idx=handle_Q_next.index) - # pipeline_S_P.producer_commit(producer_state_S_P) pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) - # producer_state_S_P.advance() - producer_phase_S_P ^= 1 # 2) dQ = dS @ K pipeline_dS.consumer_wait(consumer_state_dS) - # pipeline_dP.producer_acquire(producer_state_dP) # dP uses the same tmem as dQ - pipeline_dP.sync_object_empty.wait(0, producer_phase_dP) + # dP uses the same tmem as dQ + # However, if dS is ready, then dP must have been ready, so we don't need to wait + # pipeline_dP.sync_object_empty.wait(0, producer_phase_acc) mma_dsk_fn() - # pipeline_dQ.producer_commit(producer_state_dQ) pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group) - # producer_state_dQ.advance() - producer_phase_dQ ^= 1 # 3) dK = dS.T @ Q mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK) @@ -1352,28 +1328,22 @@ def mma( # 4) dP = V @ dO.T pipeline_dO.consumer_wait(consumer_state_dO) - # pipeline_dQ.producer_acquire(producer_state_dQ) # dQ uses the same tmem as dP - pipeline_dQ.sync_object_empty.wait(0, producer_phase_dQ) + # dQ uses the same tmem as dP + pipeline_dQ.sync_object_empty.wait(0, producer_phase_acc) mma_dov_fn(B_idx=consumer_state_dO.index) - # pipeline_dP.producer_commit(producer_state_dP) pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group) - # producer_state_dP.advance() - producer_phase_dP ^= 1 + producer_phase_acc ^= 1 # 5) dV += P @ dO # wait for P to be ready, which uses the same tmem as S - # pipeline_S_P.producer_acquire(producer_state_S_P) - pipeline_S_P.sync_object_empty.wait(0, producer_phase_S_P) + pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc) mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=False) pipeline_dO.consumer_release(consumer_state_dO) consumer_state_dO.advance() handle_Q = handle_Q_next - # pipeline_S_P.producer_commit(producer_state_S_P) pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) - # producer_state_S_P.advance() - producer_phase_S_P ^= 1 # signal to the epilogue that dV is ready # pipeline_dKV.producer_acquire(producer_state_dKV) @@ -1397,16 +1367,16 @@ def mma( producer_phase_dKV ^= 1 # 2) dQ = dS @ K + # dS is done, so dP must have been ready, we don't need to wait mma_dsk_fn() - # pipeline_dQ.producer_commit(producer_state_dQ) pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group) - # producer_state_dQ.advance() - producer_phase_dQ ^= 1 # Wait until dQ is done before releasing Q, since K and Q0 uses the same mbarrier handle_Q.release() pipeline_dS.consumer_release(consumer_state_dS) consumer_state_dS.advance() + producer_phase_acc ^= 1 + tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() @@ -1669,6 +1639,8 @@ def compute_loop( pipeline_dP.consumer_wait(consumer_state_S_P_dP) # pipeline_dP.sync_object_full.wait(0, consumer_phase_S_P_dP) + consumer_state_S_P_dP.advance() + # consumer_phase_S_P_dP ^= 1 ##### dS.T = P.T * (dP.T - Psum) for stage in cutlass.range_constexpr(cute.size(tSrS_t2r, mode=[1]), unroll=1): @@ -1706,11 +1678,9 @@ def compute_loop( cute.arch.sync_warp() with cute.arch.elect_one(): - # pipeline_dP.consumer_release(consumer_state_dP) - pipeline_dP.sync_object_empty.arrive(0, pipeline_dP.consumer_mask) + # The mma warp no longer waits for dP (it waits for dS), so we don't have to arrive + # pipeline_dP.sync_object_empty.arrive(0, pipeline_dP.consumer_mask) cute.arch.mbarrier_arrive(dPsum_empty_mbar_ptr) - consumer_state_S_P_dP.advance() - # consumer_phase_S_P_dP ^= 1 cute.arch.fence_proxy( cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta From 6c56a0ceb4ed884a2158c0b5007d17108cbc28c3 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 26 Oct 2025 00:21:45 -0400 Subject: [PATCH 197/258] [Cute] Change utils.view_transpose back --- flash_attn/cute/flash_bwd_sm100.py | 2 +- flash_attn/cute/utils.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index b7961feda06..2eccadd9790 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -1216,7 +1216,7 @@ def mma( gemm_ptx_w_idx, tiled_mma_SdP, tStS, tSrK, tSrQ, sA=sK, sB=sQ, A_idx=0, zero_init=True ) # mma_dov_fn = partial( - # gemm_w_idx, tiled_mma_SdP, tdPtdP, tdPrV, tdPrdOt, A_idx=0, zero_init=True + # gemm_w_idx, tiled_mma_SdP, tdPtdP, tdPrV, tdPrdOt, A_idx=0, zero_init=True # ) mma_dov_fn = partial( gemm_ptx_w_idx, diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index f26f2cb8d80..6bd5123f100 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -228,10 +228,10 @@ def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor: def transpose_view(a: cute.Tensor) -> cute.Tensor: """Transpose the first two dimensions of a tensor on smem.""" shape = (a.shape[1], a.shape[0], *a.shape[2:]) - # order = (1, 0, *range(2, cute.rank(a))) - # return cute.composition(a, cute.make_ordered_layout(shape, order=order)) - stride = (a.layout.stride[1], a.layout.stride[0], *a.layout.stride[2:]) - return cute.make_tensor(a.iterator, cute.make_layout(shape, stride=stride)) + order = (1, 0, *range(2, cute.rank(a))) + return cute.composition(a, cute.make_ordered_layout(shape, order=order)) + # stride = (a.layout.stride[1], a.layout.stride[0], *a.layout.stride[2:]) + # return cute.make_tensor(a.iterator, cute.make_layout(shape, stride=stride)) def parse_swizzle_from_pointer(ptr: cute.Pointer) -> cute.Swizzle: From 285bf126bf5702f9c3731d29eb07e1214158598c Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 26 Oct 2025 16:06:14 -0400 Subject: [PATCH 198/258] [Cute,Bwd,Sm100] Remove delay_tma_store option --- flash_attn/cute/flash_bwd_sm100.py | 43 ++++++++++-------------------- 1 file changed, 14 insertions(+), 29 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 2eccadd9790..967e8fb84ea 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -436,7 +436,7 @@ def __call__( tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) tma_load_op_multicast = cpasync.CopyBulkTensorTileG2SMulticastOp(cta_group) - # S = K @ Q.T + # S.T = K @ Q.T tma_atom_K, tma_tensor_K = cute.nvgpu.make_tiled_tma_atom_A( tma_load_op, mK, @@ -453,7 +453,7 @@ def __call__( self.tiled_mma_SdP, self.cluster_layout_vmnk.shape, ) - # dP = V @ dO.T + # dP.T = V @ dO.T tma_atom_V, tma_tensor_V = cute.nvgpu.make_tiled_tma_atom_A( tma_load_op, mV, @@ -998,7 +998,6 @@ def kernel( # (0, 1, 2, 3) - dQ if warp_idx >= self.reduce_warp_ids[0] and warp_idx <= self.reduce_warp_ids[-1]: cute.arch.warpgroup_reg_alloc(self.num_regs_reduce) - self.dQacc_reduce( mdQaccum, sdQaccum, @@ -1787,7 +1786,7 @@ def dQacc_reduce( tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() - dQ_consumer_state = cutlass.pipeline.make_pipeline_state( + dQ_consumer_state = pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, 1 ) dQ_tma_store_producer_state = pipeline.make_pipeline_state( @@ -1820,15 +1819,18 @@ def dQacc_reduce( # semaphore acquire if const_expr(self.deterministic): - barrier.wait_eq(mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, n_block) + barrier.wait_eq(mdQ_semaphore_cur[m_block, None].iterator, tidx, 0, n_block) self.reduce_sync_barrier.arrive_and_wait() gdQaccum_cur = gdQaccum[None, None, m_block] - # We could delay the TMA store by 1 epi tile to better overlap the non-TMA ops - delay_tma_store = False - - def tma_store_fn(src_idx, dst_idx): + for stage in cutlass.range_constexpr(cute.size(tdQrdQ_t2r, mode=[1])): # 4 + smem_idx = dQ_tma_store_producer_state.index + tdQsdQ_r2s = tdQsdQ[None, None, smem_idx] + tdQrdQ_r2s = cute.make_tensor( + tdQrdQ_t2r[None, stage, None, None].iterator, tdQsdQ_r2s.shape + ) + cute.copy(thr_copy_dQaccum_r2s, tdQrdQ_r2s, tdQsdQ_r2s) # Fence and barrier to make sure shared memory store is visible to TMA store cute.arch.fence_proxy( cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta @@ -1838,28 +1840,13 @@ def tma_store_fn(src_idx, dst_idx): if is_tma_warp: with cute.arch.elect_one(): copy_utils.cpasync_reduce_bulk_add_f32( - sdQaccum[None, src_idx].iterator, - gdQaccum_cur[None, dst_idx].iterator, + sdQaccum[None, smem_idx].iterator, + gdQaccum_cur[None, stage].iterator, self.tma_copy_bytes["dQ"] // 1, ) cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(self.sdQaccum_stage - 1, read=read_flag) self.reduce_sync_barrier.arrive_and_wait() - - smem_idx_prev, stage_prev = None, -1 - for stage in cutlass.range_constexpr(cute.size(tdQrdQ_t2r, mode=[1])): # 4 - smem_idx = dQ_tma_store_producer_state.index - tdQsdQ_r2s = tdQsdQ[None, None, smem_idx] - tdQrdQ_r2s = cute.make_tensor( - tdQrdQ_t2r[None, stage, None, None].iterator, tdQsdQ_r2s.shape - ) - if const_expr(delay_tma_store): - if const_expr(stage > 0): - tma_store_fn(src_idx=smem_idx_prev, dst_idx=stage_prev) - smem_idx_prev, stage_prev = smem_idx, stage - cute.copy(thr_copy_dQaccum_r2s, tdQrdQ_r2s, tdQsdQ_r2s) - if const_expr(not delay_tma_store): - tma_store_fn(smem_idx, stage) dQ_tma_store_producer_state.advance() # Directly add to gmem, much slower # tdQgdQ = thr_copy_dQaccum_r2s.partition_D(gdQaccum[None, stage, m_block]) @@ -1872,8 +1859,6 @@ def tma_store_fn(src_idx, dst_idx): # tdQrdQ_r2s[4 * i + 3], # utils.elem_pointer(tdQgdQ, 4 * i), # ) - if const_expr(delay_tma_store): - tma_store_fn(src_idx=smem_idx_prev, dst_idx=stage_prev) # semaphore release # NOTE: arrive_inc calls red_release which issues membar @@ -1881,7 +1866,7 @@ def tma_store_fn(src_idx, dst_idx): if tidx == 0: cute.arch.cp_async_bulk_wait_group(0, read=read_flag) self.reduce_sync_barrier.arrive_and_wait() - barrier.arrive_inc(mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, 1) + barrier.arrive_inc(mdQ_semaphore_cur[m_block, None].iterator, tidx, 0, 1) if warp_idx == 0: cute.arch.cp_async_bulk_wait_group(0, read=read_flag) From c59ecd8936e13a8dda475e4cbe350491662509bc Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 26 Oct 2025 16:46:45 -0400 Subject: [PATCH 199/258] [Cute,Bwd,Sm100] Implement cluster Co-authored-by: Ted Zadouri --- flash_attn/cute/flash_bwd.py | 6 ++ flash_attn/cute/flash_bwd_preprocess.py | 6 +- flash_attn/cute/flash_bwd_sm100.py | 97 ++++++++++++++++++++----- flash_attn/cute/interface.py | 3 +- flash_attn/cute/pipeline.py | 7 +- flash_attn/cute/tile_scheduler.py | 8 +- 6 files changed, 103 insertions(+), 24 deletions(-) diff --git a/flash_attn/cute/flash_bwd.py b/flash_attn/cute/flash_bwd.py index 4d3bbe7d185..12f900b3970 100644 --- a/flash_attn/cute/flash_bwd.py +++ b/flash_attn/cute/flash_bwd.py @@ -11,6 +11,7 @@ import cutlass import cutlass.cute as cute from cutlass.cute.nvgpu import cpasync, warp +from cutlass import Float32, Int32 import cutlass.utils as utils_basic from flash_attn.cute import ampere_helpers as sm80_utils @@ -373,7 +374,12 @@ def __call__( mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, + softcap: Float32 | float | None = None, + window_size_left: Int32 | int | None = None, + window_size_right: Int32 | int | None = None, + mdQ_semaphore: Optional[cute.Tensor] = None, ): + assert mdQ_semaphore is None, "semaphore not supported yet" # Get the data type and check if it is fp16 or bf16 self._check_type(*(t.element_type if t is not None else None for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK))) diff --git a/flash_attn/cute/flash_bwd_preprocess.py b/flash_attn/cute/flash_bwd_preprocess.py index 1a900f83a67..dd5455b98c4 100644 --- a/flash_attn/cute/flash_bwd_preprocess.py +++ b/flash_attn/cute/flash_bwd_preprocess.py @@ -344,10 +344,10 @@ def kernel( blkdQaccum_shape = (self.m_block_size * self.head_dim_padded,) gdQaccum = cute.local_tile(mdQaccum_cur, blkdQaccum_shape, (m_block,)) gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx) - tQgQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum) - zero = cute.make_fragment_like(tQgQaccum) + tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum) + zero = cute.make_fragment_like(tdQgdQaccum) zero.fill(0.0) - cute.copy(gmem_tiled_copy_dQaccum, zero, tQgQaccum) + cute.copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum) if cutlass.const_expr(mLSE is not None): if cutlass.const_expr(not seqlen.has_cu_seqlens_q): diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 967e8fb84ea..649e85cd747 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -44,6 +44,7 @@ def __init__( tile_n: int = 128, is_persistent: bool = False, deterministic: bool = False, + cluster_size: int = 1, ): assert qhead_per_kvhead == 1, "GQA is not supported yet in FlashAttentionBackwardSm100" # padding head_dim to a multiple of 16 as k_block_size @@ -79,7 +80,8 @@ def __init__( self.dsk_acc_dtype ) = Float32 - self.cluster_shape_mn = (1, 1) + assert cluster_size in (1, 2), "Only cluster_size=1 or 2 is supported" + self.cluster_shape_mn = (cluster_size, 1) self.is_persistent = is_persistent self.is_causal = is_causal self.is_local = False @@ -342,6 +344,18 @@ def __call__( assert self.dk_dtype.width == 32, "Must accumulate dK in float precision for GQA" assert self.dv_dtype.width == 32, "Must accumulate dV in float precision for GQA" + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: ( + *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + t.stride[-1], + ) + (mdQaccum,) = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + if t is not None + else None + for t in (mdQaccum,) + ] + layout_transpose = [1, 3, 2, 0] # (b, s, n, h) --> (s, h, n, b) mQ, mK, mV, mdO, mdK, mdV = [ utils.select(t, mode=layout_transpose) for t in (mQ, mK, mV, mdO, mdK, mdV) @@ -354,7 +368,6 @@ def __call__( mdO = utils.select(mdO, mode=dO_transpose) semaphore_transpose = [2, 3, 1, 0] # (b, n, block, stage) -> (block, stage, n, b) - mdQ_semaphore = None if const_expr(self.deterministic): assert mdQ_semaphore is not None mdQ_semaphore = utils.select(mdQ_semaphore.layout, mode=semaphore_transpose) @@ -383,6 +396,8 @@ def __call__( cute.make_layout(self.cluster_shape_mnk), (self.tiled_mma_SdP.thr_id.shape,), ) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_q_do_mcast = self.num_mcast_ctas_b > 1 self.mdK_layout_enum = LayoutEnum.from_tensor(mdK) self.mdV_layout_enum = LayoutEnum.from_tensor(mdV) @@ -445,8 +460,12 @@ def __call__( self.tiled_mma_SdP, self.cluster_layout_vmnk.shape, ) + Q_tma_op = sm100_utils_basic.cluster_shape_to_tma_atom_B( + self.cluster_shape_mnk, self.tiled_mma_SdP.thr_id + ) tma_atom_Q, tma_tensor_Q = cute.nvgpu.make_tiled_tma_atom_B( - tma_load_op if const_expr(self.cluster_shape_mnk[1] == 1) else tma_load_op_multicast, + # tma_load_op if const_expr(self.cluster_shape_mnk[0] == 1) else tma_load_op_multicast, + Q_tma_op, mQ, cute.select(self.sQ_layout, mode=[0, 1, 2]), self.mma_tiler_kq, @@ -462,12 +481,16 @@ def __call__( self.tiled_mma_SdP, self.cluster_layout_vmnk.shape, ) + dO_tma_op = sm100_utils_basic.cluster_shape_to_tma_atom_B( + self.cluster_shape_mnk, self.tiled_mma_SdP.thr_id + ) tma_atom_dO, tma_tensor_dO = cute.nvgpu.make_tiled_tma_atom_B( - tma_load_op if const_expr(self.cluster_shape_mnk[1] == 1) else tma_load_op_multicast, + # tma_load_op if const_expr(self.cluster_shape_mnk[0] == 1) else tma_load_op_multicast, + dO_tma_op, mdO, cute.select(self.sdO_layout, mode=[0, 1, 2]), - self.mma_tiler_pdo, - self.tiled_mma_dV, + self.mma_tiler_vdo, + self.tiled_mma_SdP, self.cluster_layout_vmnk.shape, ) @@ -495,6 +518,7 @@ def __call__( mV.shape[1], total_q=cute.size(mQ.shape[0]), tile_shape_mn=self.cta_tiler[:2], + cluster_shape_mn=self.cluster_shape_mnk[:2], mCuSeqlensQ=None, mSeqUsedQ=None, qhead_per_kvhead_packgqa=1, @@ -674,6 +698,11 @@ def kernel( if const_expr(tma_atom_dK is not None): cpasync.prefetch_descriptor(tma_atom_dK) + cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (tiled_mma_SdP.thr_id.shape,), + ) + # Alloc smem = cutlass.utils.SmemAllocator() storage = smem.allocate(self.shared_storage) @@ -698,8 +727,9 @@ def kernel( pipeline_producer_group = cutlass.pipeline.CooperativeGroup( cutlass.pipeline.Agent.Thread, len([self.load_warp_id]) ) + # The arrive count is the number of mcast size pipeline_consumer_group = cutlass.pipeline.CooperativeGroup( - cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) + cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) * self.num_mcast_ctas_b ) pipeline_Q = pipeline.PipelineTmaUmma.create( barrier_storage=storage.Q_mbar_ptr.data_ptr(), @@ -707,6 +737,7 @@ def kernel( producer_group=pipeline_producer_group, consumer_group=pipeline_consumer_group, tx_count=self.tma_copy_bytes["Q"], + cta_layout_vmnk=cluster_layout_vmnk, init_wait=False, ) pipeline_dO = pipeline.PipelineTmaUmma.create( @@ -715,6 +746,7 @@ def kernel( producer_group=pipeline_producer_group, consumer_group=pipeline_consumer_group, tx_count=self.tma_copy_bytes["dO"], + cta_layout_vmnk=cluster_layout_vmnk, init_wait=False, ) @@ -830,7 +862,8 @@ def kernel( block_info = BlockInfo( self.tile_m, - self.tile_n, + # self.tile_n, + self.tile_n * self.cluster_shape_mnk[0], # careful, this case is not very well-tested self.is_causal, self.is_local, None, @@ -873,7 +906,6 @@ def kernel( cute.arch.warpgroup_reg_dealloc(self.num_regs_other) self.load( thr_mma_SdP, - thr_mma_dV, mQ, mK, mV, @@ -896,6 +928,7 @@ def kernel( dPsum_empty_mbar_ptr, pipeline_Q, pipeline_dO, + cluster_layout_vmnk, block_info, SeqlenInfoCls, TileSchedulerCls, @@ -1016,7 +1049,6 @@ def kernel( def load( self, thr_mma_SdP: cute.core.ThrMma, - thr_mma_dV: cute.core.ThrMma, mQ: cute.Tensor, mK: cute.Tensor, mV: cute.Tensor, @@ -1039,6 +1071,7 @@ def load( dPsum_empty_mbar_ptr: cute.Pointer, pipeline_Q: PipelineAsync, pipeline_dO: PipelineAsync, + cluster_layout_vmnk: cute.Layout, block_info: BlockInfo, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, @@ -1050,12 +1083,23 @@ def load( cutlass.pipeline.PipelineUserType.Producer, self.dO_stage ) + # Compute multicast mask for Q & dO buffer full + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + q_do_mcast_mask = None + if const_expr(self.is_q_do_mcast): + q_do_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: n_block, head_idx, batch_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) - m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) + m_block_min, m_block_max = block_info.get_m_block_min_max( + seqlen, n_block // self.cluster_shape_mnk[0] + ) head_idx_kv = head_idx // self.qhead_per_kvhead mQ_cur = mQ[None, None, head_idx, batch_idx] mK_cur = mK[None, None, head_idx_kv, batch_idx] @@ -1073,7 +1117,7 @@ def load( gLSE = cute.local_tile(mLSE_cur, (self.tile_n,), (None,)) gdPsum = cute.local_tile(mPsum_cur, (self.tile_n,), (None,)) gdO = cute.local_tile(mdO_cur, cute.select(self.mma_tiler_pdo, mode=[1, 2]), (0, None)) - tdVgdO = thr_mma_dV.partition_B(gdO) + tdPgdO = thr_mma_SdP.partition_B(gdO) load_K, _, _ = copy_utils.tma_get_copy_fn( tma_atom_K, 0, cute.make_layout(1), tSgK, sK[None, None, None, 0], single_stage=True @@ -1086,10 +1130,23 @@ def load( sV[None, None, None, 0], single_stage=True, ) - load_Q, _, _ = copy_utils.tma_get_copy_fn(tma_atom_Q, 0, cute.make_layout(1), tSgQ, sQ) + b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape) + load_Q, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, + cta_coord=block_in_cluster_coord_vmnk[1], + cta_layout=b_cta_layout, + src_tensor=tSgQ, + dst_tensor=sQ, + mcast_mask=q_do_mcast_mask, + ) load_Q = copy_utils.tma_producer_copy_fn(load_Q, pipeline_Q) load_dO, _, _ = copy_utils.tma_get_copy_fn( - tma_atom_dO, 0, cute.make_layout(1), tdVgdO, sdO + tma_atom_dO, + cta_coord=block_in_cluster_coord_vmnk[1], + cta_layout=b_cta_layout, + src_tensor=tdPgdO, + dst_tensor=sdO, + mcast_mask=q_do_mcast_mask, ) load_dO = copy_utils.tma_producer_copy_fn(load_dO, pipeline_dO) copy_atom_stats = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), Float32) @@ -1261,7 +1318,9 @@ def mma( while work_tile.is_valid_tile: n_block, head_idx, batch_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) # must be seqlen_k - m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) + m_block_min, m_block_max = block_info.get_m_block_min_max( + seqlen, n_block // self.cluster_shape_mnk[0] + ) accumulate_dK = False # ----------------------------------------------------------- @@ -1554,7 +1613,9 @@ def compute_loop( while work_tile.is_valid_tile: n_block, head_idx, batch_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) - m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) + m_block_min, m_block_max = block_info.get_m_block_min_max( + seqlen, n_block // self.cluster_shape_mnk[0] + ) mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) # TODO: condition mask_seqlen mask_fn = partial( @@ -1795,7 +1856,9 @@ def dQacc_reduce( while work_tile.is_valid_tile: n_block, head_idx, batch_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) - m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) + m_block_min, m_block_max = block_info.get_m_block_min_max( + seqlen, n_block // self.cluster_shape_mnk[0] + ) mdQaccum_cur = mdQaccum[None, head_idx, batch_idx] gdQaccum_ = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (None,)) # (M * K / STAGE, STAGE, _) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index c3fb3fa3c3b..55d415c93cc 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -699,7 +699,7 @@ def _flash_attn_bwd( current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) # Preprocess kernel: compute (o * dout).sum(dim=-1), lse * log2_e, and zero out dq_accum. - compile_key_pre = (dtype, head_dim_v, m_block_size, num_threads) + compile_key_pre = (compute_capability, dtype, head_dim_v, m_block_size, num_threads) if compile_key_pre not in _flash_attn_bwd.compile_cache_pre: fa_bwd_pre = FlashAttentionBackwardPreprocess( dtype, @@ -821,6 +821,7 @@ def _flash_attn_bwd( qhead_per_kvhead=qhead_per_kvhead, # tile_m=m_block_size, # tile_n=n_block_size, + cluster_size=2 if not causal else 2, ) # TODO: check @can_implement _flash_attn_bwd.compile_cache[compile_key] = cute.compile( diff --git a/flash_attn/cute/pipeline.py b/flash_attn/cute/pipeline.py index 6228037d203..3fca9c21c9b 100644 --- a/flash_attn/cute/pipeline.py +++ b/flash_attn/cute/pipeline.py @@ -264,6 +264,7 @@ def create( tx_count: int, barrier_storage: cute.Pointer = None, cta_layout_vmnk: Optional[cute.Layout] = None, + mcast_mode_mn: tuple[int, int] = (1, 1), init_wait: cutlass.Constexpr[bool] = True, ): """ @@ -280,6 +281,8 @@ def create( :type tx_count: int :param cta_layout_vmnk: Layout of the cluster shape :type cta_layout_vmnk: cute.Layout | None + :param mcast_mode_mn: Tuple of two integers, specifying whether mcast is enabled for the m and n modes. At least one of the two integers must be 1. + :type mcast_mode_mn: tuple[int, int] """ if not isinstance(barrier_storage, cute.Pointer): raise ValueError( @@ -305,7 +308,9 @@ def create( # All threadblocks are leaders if not using clusters is_leader_cta = True else: - producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(cta_layout_vmnk) + producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask( + cta_layout_vmnk, mcast_mode_mn + ) is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk) cta_group = ( diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index bea4496ecc2..f9359556662 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -45,6 +45,7 @@ class TileSchedulerArguments(ParamsBase): headdim_v: Int32 total_q: Int32 tile_shape_mn: cutlass.Constexpr[Tuple[int, int]] + cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) mCuSeqlensQ: Optional[cute.Tensor] = None mSeqUsedQ: Optional[cute.Tensor] = None qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 @@ -59,12 +60,13 @@ class Params(ParamsBase): num_block: Int32 num_head: Int32 num_batch: Int32 + cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) @staticmethod def create( args: TileSchedulerArguments, *, loc=None, ip=None ) -> "SingleTileScheduler.Params": - return SingleTileScheduler.Params(args.num_block, args.num_head, args.num_batch) + return SingleTileScheduler.Params(args.num_block, args.num_head, args.num_batch, args.cluster_shape_mn) def __init__(self, blk_coord: cute.Coord, *, loc=None, ip=None): self._blk_coord = blk_coord @@ -89,7 +91,9 @@ def get_grid_shape( loc=None, ip=None, ) -> Tuple[Int32, Int32, Int32]: - return params.num_block, params.num_head, params.num_batch + # TODO: this hard-codes the fact that we only use cluster = (1, 1) or (2, 1) + assert params.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported" + return cute.round_up(params.num_block, params.cluster_shape_mn[0]), params.num_head, params.num_batch def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: return cutlass.utils.WorkTileInfo(self._blk_coord, self._is_first_block) From 25e6d94496fa5d4eb39a0ee28884fc8f142af1e5 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 27 Oct 2025 17:10:10 -0400 Subject: [PATCH 200/258] [Cute] Copy benchmark util functions to cute directory Easier to benchmark without having to install FA2 --- benchmarks/benchmark_attn.py | 4 +- flash_attn/cute/benchmark.py | 268 +++++++++++++++++++++++++++++++++++ 2 files changed, 270 insertions(+), 2 deletions(-) create mode 100644 flash_attn/cute/benchmark.py diff --git a/benchmarks/benchmark_attn.py b/benchmarks/benchmark_attn.py index 5b3de776ec0..1a868e0a286 100644 --- a/benchmarks/benchmark_attn.py +++ b/benchmarks/benchmark_attn.py @@ -21,7 +21,7 @@ from einops import rearrange, repeat # from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler -from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler +from flash_attn.cute.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func from flash_attn.cute.interface import flash_attn_func as flash_attn_func_python from flash_attn.cute.interface import flash_attn_varlen_func as flash_attn_varlen_func_python @@ -409,4 +409,4 @@ def run(*args, **kwargs): if flash_attn_func_python is not None: print(f'FA Python fwd: {m1_py.mean * 1e3:.3f}ms, {(nFLOPS / m1_py.mean * 1e-12):.1f} TFLOPS') if dtype != torch.float8_e4m3fn and headdim == headdim_v and has_backward: - print(f'FAv2 Python bwd: {m1b_py.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b_py.mean * 1e-12):.1f} TFLOPS') + print(f'FA Python bwd: {m1b_py.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b_py.mean * 1e-12):.1f} TFLOPS') diff --git a/flash_attn/cute/benchmark.py b/flash_attn/cute/benchmark.py new file mode 100644 index 00000000000..9a7820e7b0c --- /dev/null +++ b/flash_attn/cute/benchmark.py @@ -0,0 +1,268 @@ +# Copyright (c) 2023, Tri Dao. +"""Useful functions for writing test code.""" + +import torch +import torch.utils.benchmark as benchmark + + +def benchmark_forward( + fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs +): + """Use Pytorch Benchmark on the forward pass of an arbitrary function.""" + if verbose: + print(desc, "- Forward pass") + + def amp_wrapper(*inputs, **kwinputs): + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + fn(*inputs, **kwinputs) + + t = benchmark.Timer( + stmt="fn_amp(*inputs, **kwinputs)", + globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs}, + num_threads=torch.get_num_threads(), + ) + m = t.timeit(repeats) + if verbose: + print(m) + return t, m + + +def benchmark_backward( + fn, + *inputs, + grad=None, + repeats=10, + desc="", + verbose=True, + amp=False, + amp_dtype=torch.float16, + **kwinputs, +): + """Use Pytorch Benchmark on the backward pass of an arbitrary function.""" + if verbose: + print(desc, "- Backward pass") + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + y = fn(*inputs, **kwinputs) + if type(y) is tuple: + y = y[0] + if grad is None: + grad = torch.randn_like(y) + else: + if grad.shape != y.shape: + raise RuntimeError("Grad shape does not match output shape") + + def f(*inputs, y, grad): + # Set .grad to None to avoid extra operation of gradient accumulation + for x in inputs: + if isinstance(x, torch.Tensor): + x.grad = None + y.backward(grad, retain_graph=True) + + t = benchmark.Timer( + stmt="f(*inputs, y=y, grad=grad)", + globals={"f": f, "inputs": inputs, "y": y, "grad": grad}, + num_threads=torch.get_num_threads(), + ) + m = t.timeit(repeats) + if verbose: + print(m) + return t, m + + +def benchmark_combined( + fn, + *inputs, + grad=None, + repeats=10, + desc="", + verbose=True, + amp=False, + amp_dtype=torch.float16, + **kwinputs, +): + """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.""" + if verbose: + print(desc, "- Forward + Backward pass") + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + y = fn(*inputs, **kwinputs) + if type(y) is tuple: + y = y[0] + if grad is None: + grad = torch.randn_like(y) + else: + if grad.shape != y.shape: + raise RuntimeError("Grad shape does not match output shape") + + def f(grad, *inputs, **kwinputs): + for x in inputs: + if isinstance(x, torch.Tensor): + x.grad = None + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + y = fn(*inputs, **kwinputs) + if type(y) is tuple: + y = y[0] + y.backward(grad, retain_graph=True) + + t = benchmark.Timer( + stmt="f(grad, *inputs, **kwinputs)", + globals={"f": f, "fn": fn, "inputs": inputs, "grad": grad, "kwinputs": kwinputs}, + num_threads=torch.get_num_threads(), + ) + m = t.timeit(repeats) + if verbose: + print(m) + return t, m + + +def benchmark_fwd_bwd( + fn, + *inputs, + grad=None, + repeats=10, + desc="", + verbose=True, + amp=False, + amp_dtype=torch.float16, + **kwinputs, +): + """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.""" + return ( + benchmark_forward( + fn, + *inputs, + repeats=repeats, + desc=desc, + verbose=verbose, + amp=amp, + amp_dtype=amp_dtype, + **kwinputs, + ), + benchmark_backward( + fn, + *inputs, + grad=grad, + repeats=repeats, + desc=desc, + verbose=verbose, + amp=amp, + amp_dtype=amp_dtype, + **kwinputs, + ), + ) + + +def benchmark_all( + fn, + *inputs, + grad=None, + repeats=10, + desc="", + verbose=True, + amp=False, + amp_dtype=torch.float16, + **kwinputs, +): + """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.""" + return ( + benchmark_forward( + fn, + *inputs, + repeats=repeats, + desc=desc, + verbose=verbose, + amp=amp, + amp_dtype=amp_dtype, + **kwinputs, + ), + benchmark_backward( + fn, + *inputs, + grad=grad, + repeats=repeats, + desc=desc, + verbose=verbose, + amp=amp, + amp_dtype=amp_dtype, + **kwinputs, + ), + benchmark_combined( + fn, + *inputs, + grad=grad, + repeats=repeats, + desc=desc, + verbose=verbose, + amp=amp, + amp_dtype=amp_dtype, + **kwinputs, + ), + ) + + +def pytorch_profiler( + fn, + *inputs, + trace_filename=None, + backward=False, + amp=False, + amp_dtype=torch.float16, + cpu=False, + verbose=True, + **kwinputs, +): + """Wrap benchmark functions in Pytorch profiler to see CUDA information.""" + if backward: + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + out = fn(*inputs, **kwinputs) + if type(out) is tuple: + out = out[0] + g = torch.randn_like(out) + for _ in range(30): # Warm up + if backward: + for x in inputs: + if isinstance(x, torch.Tensor): + x.grad = None + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + out = fn(*inputs, **kwinputs) + if type(out) is tuple: + out = out[0] + # Backward should be done outside autocast + if backward: + out.backward(g, retain_graph=True) + activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [ + torch.profiler.ProfilerActivity.CUDA + ] + with torch.profiler.profile( + activities=activities, + record_shapes=True, + # profile_memory=True, + with_stack=True, + ) as prof: + if backward: + for x in inputs: + if isinstance(x, torch.Tensor): + x.grad = None + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + out = fn(*inputs, **kwinputs) + if type(out) is tuple: + out = out[0] + if backward: + out.backward(g, retain_graph=True) + if verbose: + # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50)) + print(prof.key_averages().table(row_limit=50)) + if trace_filename is not None: + prof.export_chrome_trace(trace_filename) + + +def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + fn(*inputs, **kwinputs) + torch.cuda.synchronize() + mem = torch.cuda.max_memory_allocated() / ((2**20) * 1000) + if verbose: + print(f"{desc} max memory: {mem}GB") + torch.cuda.empty_cache() + return mem From 53d3a99d2ab33e331330dae4775173de0117f45c Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 27 Oct 2025 20:10:27 -0400 Subject: [PATCH 201/258] [Cute,Bwd,Sm100] Use pipeline class for LSE and dPsum --- flash_attn/cute/flash_bwd_sm100.py | 282 +++++++++++++++++------------ flash_attn/cute/pipeline.py | 7 +- 2 files changed, 170 insertions(+), 119 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 649e85cd747..ef36a77746e 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -162,14 +162,14 @@ def __init__( def _setup_attributes(self): self.Q_stage = 2 self.dO_stage = 1 - self.LSE_stage = 1 - self.dPsum_stage = 1 + # LSE_stage = Q_stage and dPsum_stage = dO_stage self.sdKVaccum_stage = 2 # number of tma reduce adds per dQacc mma self.dQ_reduce_ncol = 32 self.sdQaccum_stage = 64 // self.dQ_reduce_ncol assert self.tile_hdim % self.dQ_reduce_ncol == 0 self.dQaccum_reduce_stage = self.tile_hdim // self.dQ_reduce_ncol + self.cluster_reduce_dQ = False and cute.size(self.cluster_shape_mn) > 1 def _get_tiled_mma(self): cta_group = tcgen05.CtaGroup.ONE @@ -282,11 +282,11 @@ def _setup_smem_layout(self): (self.tile_m * self.dQ_reduce_ncol, self.sdQaccum_stage) ) self.sLSE_layout = cute.make_layout( - shape=(self.tile_m, self.LSE_stage), + shape=(self.tile_m, self.Q_stage), stride=(1, cute.round_up(self.tile_m, 64)), ) self.sdPsum_layout = cute.make_layout( - shape=(self.tile_m, self.dPsum_stage), + shape=(self.tile_m, self.dO_stage), stride=(1, cute.round_up(self.tile_m, 64)), ) self.sdKV_epi_tile = ( @@ -536,15 +536,19 @@ def __call__( class SharedStorage: Q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.Q_stage] dO_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dO_stage] - LSE_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.LSE_stage] - LSE_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.LSE_stage] - dPsum_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.dPsum_stage] - dPsum_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.dPsum_stage] + LSE_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.Q_stage] + dPsum_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dO_stage] S_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 1] dP_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 1] dS_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 1] dKV_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 2] dQ_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2] + dQ_cluster_full_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.dQaccum_reduce_stage // 2 + ] + dQ_cluster_empty_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.dQaccum_reduce_stage // 2 + ] tmem_holding_buf: Int32 tmem_dealloc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1] @@ -708,47 +712,18 @@ def kernel( storage = smem.allocate(self.shared_storage) tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr.data_ptr() - LSE_full_mbar_ptr = storage.LSE_full_mbar_ptr.data_ptr() - LSE_empty_mbar_ptr = storage.LSE_empty_mbar_ptr.data_ptr() - dPsum_full_mbar_ptr = storage.dPsum_full_mbar_ptr.data_ptr() - dPsum_empty_mbar_ptr = storage.dPsum_empty_mbar_ptr.data_ptr() + dQ_cluster_full_mbar_ptr = storage.dQ_cluster_full_mbar_ptr.data_ptr() + dQ_cluster_empty_mbar_ptr = storage.dQ_cluster_empty_mbar_ptr.data_ptr() if warp_idx == 1: cute.arch.mbarrier_init( tmem_dealloc_mbar_ptr, cute.arch.WARP_SIZE * len(self.compute_warp_ids) ) - if warp_idx == 2: - cute.arch.mbarrier_init(LSE_full_mbar_ptr, 1) - cute.arch.mbarrier_init(LSE_empty_mbar_ptr, len(self.compute_warp_ids)) - if warp_idx == 3: - cute.arch.mbarrier_init(dPsum_full_mbar_ptr, 1) - cute.arch.mbarrier_init(dPsum_empty_mbar_ptr, len(self.compute_warp_ids)) - - pipeline_producer_group = cutlass.pipeline.CooperativeGroup( - cutlass.pipeline.Agent.Thread, len([self.load_warp_id]) - ) - # The arrive count is the number of mcast size - pipeline_consumer_group = cutlass.pipeline.CooperativeGroup( - cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) * self.num_mcast_ctas_b - ) - pipeline_Q = pipeline.PipelineTmaUmma.create( - barrier_storage=storage.Q_mbar_ptr.data_ptr(), - num_stages=self.Q_stage, - producer_group=pipeline_producer_group, - consumer_group=pipeline_consumer_group, - tx_count=self.tma_copy_bytes["Q"], - cta_layout_vmnk=cluster_layout_vmnk, - init_wait=False, - ) - pipeline_dO = pipeline.PipelineTmaUmma.create( - barrier_storage=storage.dO_mbar_ptr.data_ptr(), - num_stages=self.dO_stage, - producer_group=pipeline_producer_group, - consumer_group=pipeline_consumer_group, - tx_count=self.tma_copy_bytes["dO"], - cta_layout_vmnk=cluster_layout_vmnk, - init_wait=False, - ) + if const_expr(self.cluster_reduce_dQ): + if warp_idx == 4: + for i in range(self.dQaccum_reduce_stage // 2): + cute.arch.mbarrier_init(dQ_cluster_full_mbar_ptr + i, 1) + cute.arch.mbarrier_init(dQ_cluster_empty_mbar_ptr + i, 1) # UMMA producers and AsyncThread consumers pipeline_producer_group_MMA_AsyncThread = cutlass.pipeline.CooperativeGroup( @@ -795,7 +770,6 @@ def kernel( pipeline_PdS_consumer_group = cutlass.pipeline.CooperativeGroup( cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) ) # MMA - pipeline_dS = cutlass.pipeline.PipelineAsyncUmma.create( num_stages=1, producer_group=pipeline_PdS_producer_group, @@ -803,6 +777,56 @@ def kernel( barrier_storage=storage.dS_mbar_ptr.data_ptr(), ) + # TMA producer and UMMA consumers + pipeline_producer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len([self.load_warp_id]) + ) + # The arrive count is the number of mcast size + pipeline_consumer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) * self.num_mcast_ctas_b + ) + pipeline_consumer_group_compute = cutlass.pipeline.CooperativeGroup( + # cutlass.pipeline.Agent.Thread, len(self.compute_warp_ids) * self.num_mcast_ctas_b + cutlass.pipeline.Agent.Thread, + len(self.compute_warp_ids) * 1, + ) + pipeline_LSE = cutlass.pipeline.PipelineTmaAsync.create( + barrier_storage=storage.LSE_mbar_ptr.data_ptr(), + num_stages=self.Q_stage, + producer_group=pipeline_producer_group, + consumer_group=pipeline_consumer_group_compute, + tx_count=self.tma_copy_bytes["LSE"], + # cta_layout_vmnk=cluster_layout_vmnk, + # init_wait=False, + ) + pipeline_dPsum = cutlass.pipeline.PipelineTmaAsync.create( + barrier_storage=storage.dPsum_mbar_ptr.data_ptr(), + num_stages=self.dO_stage, + producer_group=pipeline_producer_group, + consumer_group=pipeline_consumer_group_compute, + tx_count=self.tma_copy_bytes["dPsum"], + # cta_layout_vmnk=cluster_layout_vmnk, + # init_wait=False, + ) + pipeline_Q = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.Q_mbar_ptr.data_ptr(), + num_stages=self.Q_stage, + producer_group=pipeline_producer_group, + consumer_group=pipeline_consumer_group, + tx_count=self.tma_copy_bytes["Q"], + cta_layout_vmnk=cluster_layout_vmnk, + init_wait=False, + ) + pipeline_dO = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.dO_mbar_ptr.data_ptr(), + num_stages=self.dO_stage, + producer_group=pipeline_producer_group, + consumer_group=pipeline_consumer_group, + tx_count=self.tma_copy_bytes["dO"], + cta_layout_vmnk=cluster_layout_vmnk, + init_wait=True, + ) + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) sQt = cute.make_tensor(cute.recast_ptr(sQ.iterator, sQt_layout.inner), sQt_layout.outer) sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) @@ -922,12 +946,10 @@ def kernel( tma_atom_K, tma_atom_V, tma_atom_dO, - LSE_full_mbar_ptr, - LSE_empty_mbar_ptr, - dPsum_full_mbar_ptr, - dPsum_empty_mbar_ptr, pipeline_Q, pipeline_dO, + pipeline_LSE, + pipeline_dPsum, cluster_layout_vmnk, block_info, SeqlenInfoCls, @@ -1001,10 +1023,8 @@ def kernel( mdK, sdS, tdPtdP, - LSE_full_mbar_ptr, - LSE_empty_mbar_ptr, - dPsum_full_mbar_ptr, - dPsum_empty_mbar_ptr, + pipeline_LSE, + pipeline_dPsum, pipeline_S_P, pipeline_dS, pipeline_dKV, @@ -1065,21 +1085,19 @@ def load( tma_atom_K: cute.CopyAtom, tma_atom_V: cute.CopyAtom, tma_atom_dO: cute.CopyAtom, - LSE_full_mbar_ptr: cute.Pointer, - LSE_empty_mbar_ptr: cute.Pointer, - dPsum_full_mbar_ptr: cute.Pointer, - dPsum_empty_mbar_ptr: cute.Pointer, pipeline_Q: PipelineAsync, pipeline_dO: PipelineAsync, + pipeline_LSE: PipelineAsync, + pipeline_dPsum: PipelineAsync, cluster_layout_vmnk: cute.Layout, block_info: BlockInfo, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, ): - producer_state_Q = cutlass.pipeline.make_pipeline_state( + producer_state_Q_LSE = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Producer, self.Q_stage ) - producer_state_dO = cutlass.pipeline.make_pipeline_state( + producer_state_dO_dPsum = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Producer, self.dO_stage ) @@ -1151,65 +1169,79 @@ def load( load_dO = copy_utils.tma_producer_copy_fn(load_dO, pipeline_dO) copy_atom_stats = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), Float32) copy_stats = partial(cute.copy, copy_atom_stats) + # copy_atom_stats = cute.make_copy_atom(cpasync.CopyBulkG2SMulticastOp(), Float32) + # sLSE = cute.logical_divide(sLSE, (64,))[(None, block_in_cluster_coord_vmnk[1]), None] + # gLSE = cute.logical_divide(gLSE, (64,))[(None, block_in_cluster_coord_vmnk[1]), None] + # sdPsum = cute.logical_divide(sdPsum, (64,))[(None, block_in_cluster_coord_vmnk[1]), None] + # gdPsum = cute.logical_divide(gdPsum, (64,))[(None, block_in_cluster_coord_vmnk[1]), None] + # copy_stats = partial(cute.copy, copy_atom_stats, mcast_mask=q_do_mcast_mask) # First iteration: load K together w Q & LSE, then V together w dO & dPsum # K & Q - pipeline_Q.producer_acquire(producer_state_Q, extra_tx_count=self.tma_copy_bytes["K"]) - load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q)) - load_Q(m_block_min, producer_state=producer_state_Q) - pipeline_Q.producer_commit(producer_state_Q) - producer_state_Q.advance() + pipeline_Q.producer_acquire( + producer_state_Q_LSE, extra_tx_count=self.tma_copy_bytes["K"] + ) + load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_LSE)) + load_Q(m_block_min, producer_state=producer_state_Q_LSE) + pipeline_Q.producer_commit(producer_state_Q_LSE) # LSE + pipeline_LSE.producer_acquire(producer_state_Q_LSE) with cute.arch.elect_one(): - cute.arch.mbarrier_arrive_and_expect_tx( - LSE_full_mbar_ptr, self.tma_copy_bytes["LSE"] + copy_stats( + gLSE[None, m_block_min], + sLSE[None, producer_state_Q_LSE.index], + mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE), ) - copy_stats(gLSE[None, m_block_min], sLSE[None, 0], mbar_ptr=LSE_full_mbar_ptr) + producer_state_Q_LSE.advance() # V & dO - pipeline_dO.producer_acquire(producer_state_dO, extra_tx_count=self.tma_copy_bytes["V"]) - load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO)) - load_dO(m_block_min, producer_state=producer_state_dO) - pipeline_dO.producer_commit(producer_state_dO) - producer_state_dO.advance() + pipeline_dO.producer_acquire( + producer_state_dO_dPsum, extra_tx_count=self.tma_copy_bytes["V"] + ) + load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_dPsum)) + load_dO(m_block_min, producer_state=producer_state_dO_dPsum) + pipeline_dO.producer_commit(producer_state_dO_dPsum) # dPsum + pipeline_dPsum.producer_acquire(producer_state_dO_dPsum) with cute.arch.elect_one(): - cute.arch.mbarrier_arrive_and_expect_tx( - dPsum_full_mbar_ptr, self.tma_copy_bytes["dPsum"] + copy_stats( + gdPsum[None, m_block_min], + sdPsum[None, producer_state_dO_dPsum.index], + mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum), ) - copy_stats(gdPsum[None, m_block_min], sdPsum[None, 0], mbar_ptr=dPsum_full_mbar_ptr) + producer_state_dO_dPsum.advance() - lse_empty_consumer_phase = cute.Int32(0) - dpsum_empty_consumer_phase = cute.Int32(0) for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1): # Q - pipeline_Q.producer_acquire(producer_state_Q) - load_Q(m_block, producer_state=producer_state_Q) - pipeline_Q.producer_commit(producer_state_Q) - producer_state_Q.advance() + pipeline_Q.producer_acquire(producer_state_Q_LSE) + load_Q(m_block, producer_state=producer_state_Q_LSE) + pipeline_Q.producer_commit(producer_state_Q_LSE) # LSE - cute.arch.mbarrier_wait(LSE_empty_mbar_ptr, lse_empty_consumer_phase) - lse_empty_consumer_phase ^= 1 + pipeline_LSE.producer_acquire(producer_state_Q_LSE) with cute.arch.elect_one(): - cute.arch.mbarrier_arrive_and_expect_tx( - LSE_full_mbar_ptr, self.tma_copy_bytes["LSE"] + copy_stats( + gLSE[None, m_block_min], + sLSE[None, producer_state_Q_LSE.index], + mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE), ) - copy_stats(gLSE[None, m_block], sLSE[None, 0], mbar_ptr=LSE_full_mbar_ptr) + producer_state_Q_LSE.advance() # dO - pipeline_dO.producer_acquire(producer_state_dO) - load_dO(m_block, producer_state=producer_state_dO) - pipeline_dO.producer_commit(producer_state_dO) - producer_state_dO.advance() + pipeline_dO.producer_acquire(producer_state_dO_dPsum) + load_dO(m_block, producer_state=producer_state_dO_dPsum) + pipeline_dO.producer_commit(producer_state_dO_dPsum) # dPsum - cute.arch.mbarrier_wait(dPsum_empty_mbar_ptr, dpsum_empty_consumer_phase) - dpsum_empty_consumer_phase ^= 1 + pipeline_dPsum.producer_acquire(producer_state_dO_dPsum) with cute.arch.elect_one(): - cute.arch.mbarrier_arrive_and_expect_tx( - dPsum_full_mbar_ptr, self.tma_copy_bytes["dPsum"] + copy_stats( + gdPsum[None, m_block_min], + sdPsum[None, producer_state_dO_dPsum.index], + mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum), ) - copy_stats(gdPsum[None, m_block], sdPsum[None, 0], mbar_ptr=dPsum_full_mbar_ptr) + producer_state_dO_dPsum.advance() - pipeline_Q.producer_tail(producer_state_Q) - pipeline_dO.producer_tail(producer_state_dO) + pipeline_Q.producer_tail(producer_state_Q_LSE.clone()) # will hand if we don't clone + pipeline_dO.producer_tail(producer_state_dO_dPsum.clone()) + pipeline_LSE.producer_tail(producer_state_Q_LSE) + pipeline_dPsum.producer_tail(producer_state_dO_dPsum) tile_scheduler.prefetch_next_work() tile_scheduler.advance_to_next_work() @@ -1501,10 +1533,8 @@ def compute_loop( mdK: cute.Tensor, sdS: cute.Tensor, tdPtdP: cute.Tensor, - LSE_full_mbar_ptr: cute.Pointer, - LSE_empty_mbar_ptr: cute.Pointer, - dPsum_full_mbar_ptr: cute.Pointer, - dPsum_empty_mbar_ptr: cute.Pointer, + pipeline_LSE: PipelineAsync, + pipeline_dPsum: PipelineAsync, pipeline_S_P: PipelineAsync, pipeline_dS: PipelineAsync, pipeline_dKV: PipelineAsync, @@ -1528,14 +1558,14 @@ def compute_loop( sLSE_2D = cute.make_tensor( sLSE.iterator, cute.make_layout( - (self.tile_m, self.tile_n, self.LSE_stage), + (self.tile_m, self.tile_n, self.Q_stage), stride=(1, 0, cute.round_up(self.tile_m, 64)), ), ) sdPsum_2D = cute.make_tensor( sdPsum.iterator, cute.make_layout( - (self.tile_m, self.tile_n, self.dPsum_stage), + (self.tile_m, self.tile_n, self.dO_stage), stride=(1, 0, cute.round_up(self.tile_m, 64)), ), ) @@ -1605,8 +1635,13 @@ def compute_loop( consumer_state_dKV = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, 2 ) - - consumer_phase_LSE = consumer_phase_dPsum = cute.Int32(0) + consumer_state_LSE = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.Q_stage + ) + # consumer_state_dPsum = cutlass.pipeline.make_pipeline_state( + consumer_state_dPsum = pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage + ) tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() @@ -1628,19 +1663,28 @@ def compute_loop( mask_local=self.is_local, ) + # prefetch_LSE = not self.is_causal + prefetch_LSE = False + # Mainloop for m_block in cutlass.range(m_block_min, m_block_max, unroll=1): + # Prefetch 1 stage of LSE + pipeline_LSE.consumer_wait(consumer_state_LSE) + tSrLSE_s2r = cute.make_fragment(tScS_t2r[None, 0, 0, 0].shape, Float32) + if const_expr(prefetch_LSE and not self.shuffle_LSE): + cute.autovec_copy(tSsLSE[None, 0, 0, 0, consumer_state_LSE.index], tSrLSE_s2r) + pipeline_S_P.consumer_wait(consumer_state_S_P_dP) # pipeline_S_P.sync_object_full.wait(0, consumer_phase_S_P_dP) #### TMEM->RMEM (Load S from TMEM) tSrS_t2r = cute.make_fragment(tScS_t2r.shape, Float32) cute.copy(thr_copy_t2r, tStS_t2r, tSrS_t2r) - cute.arch.mbarrier_wait(LSE_full_mbar_ptr, consumer_phase_LSE) - consumer_phase_LSE ^= 1 #### APPLY MASK mask_fn(tSrS_t2r, m_block=m_block) + num_stages = cute.size(tScS_t2r, mode=[1]) + # --------------------------------------------- #### P = exp(S - LSE) # --------------------------------------------- @@ -1649,10 +1693,11 @@ def compute_loop( tSrP_r2t = cute.recast_tensor(tSrP_r2t_f32, self.q_dtype) for stage in cutlass.range_constexpr(cute.size(tSrS_t2r, mode=[1]), unroll=1): tSrS_cur = tSrS_t2r[None, stage, 0, 0] - tSsLSE_cur = tSsLSE[None, stage, 0, 0, 0] # TODO: have stages + tSsLSE_cur = tSsLSE[None, stage, 0, 0, consumer_state_LSE.index] if const_expr(not self.shuffle_LSE): - tSrLSE = cute.make_fragment_like(tSsLSE_cur, Float32) - cute.autovec_copy(tSsLSE_cur, tSrLSE) + if const_expr(stage > 0 or not prefetch_LSE): + cute.autovec_copy(tSsLSE_cur, tSrLSE_s2r) + tSrLSE = tSrLSE_s2r else: tSrLSE = tSsLSE_cur[lane_idx] for v in cutlass.range_constexpr(cute.size(tSrS_t2r, mode=[0]) // 2): @@ -1688,14 +1733,14 @@ def compute_loop( with cute.arch.elect_one(): pipeline_S_P.consumer_release(consumer_state_S_P_dP) # pipeline_S_P.sync_object_empty.arrive(0, pipeline_S_P.consumer_mask) - cute.arch.mbarrier_arrive(LSE_empty_mbar_ptr) + pipeline_LSE.consumer_release(consumer_state_LSE) # consumer_state_S_P_dP.advance() + consumer_state_LSE.advance() # --------------------------------------------- # dS.T = P.T * (dP.T - D) # --------------------------------------------- - cute.arch.mbarrier_wait(dPsum_full_mbar_ptr, consumer_phase_dPsum) - consumer_phase_dPsum ^= 1 + pipeline_dPsum.consumer_wait(consumer_state_dPsum) pipeline_dP.consumer_wait(consumer_state_S_P_dP) # pipeline_dP.sync_object_full.wait(0, consumer_phase_S_P_dP) @@ -1709,7 +1754,7 @@ def compute_loop( cute.arch.fence_view_async_tmem_load() tdPrdP_cur = tdPrdP_t2r[None, 0, 0] tSrS_cur = tSrS_t2r[None, stage, 0, 0] - tSsdPsum_cur = tSsdPsum[None, stage, 0, 0, 0] # TODO: have stages + tSsdPsum_cur = tSsdPsum[None, stage, 0, 0, consumer_state_dPsum.index] if const_expr(not self.shuffle_dPsum): tSrdPsum = cute.make_fragment_like(tSsdPsum_cur, Float32) cute.autovec_copy(tSsdPsum_cur, tSrdPsum) @@ -1737,10 +1782,11 @@ def compute_loop( cute.autovec_copy(tdPrdP_cvt, tRS_sdS[None, stage]) cute.arch.sync_warp() - with cute.arch.elect_one(): - # The mma warp no longer waits for dP (it waits for dS), so we don't have to arrive - # pipeline_dP.sync_object_empty.arrive(0, pipeline_dP.consumer_mask) - cute.arch.mbarrier_arrive(dPsum_empty_mbar_ptr) + # with cute.arch.elect_one(): + # The mma warp no longer waits for dP (it waits for dS), so we don't have to arrive + # pipeline_dP.sync_object_empty.arrive(0, pipeline_dP.consumer_mask) + pipeline_dPsum.consumer_release(consumer_state_dPsum) + consumer_state_dPsum.advance() cute.arch.fence_proxy( cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta diff --git a/flash_attn/cute/pipeline.py b/flash_attn/cute/pipeline.py index 3fca9c21c9b..7ed7ab06d29 100644 --- a/flash_attn/cute/pipeline.py +++ b/flash_attn/cute/pipeline.py @@ -154,6 +154,7 @@ def create( barrier_storage: cute.Pointer = None, cta_layout_vmnk: Optional[cute.Layout] = None, tidx: Optional[Int32] = None, + mcast_mode_mn: tuple[int, int] = (1, 1), init_wait: cutlass.Constexpr[bool] = True, ): """ @@ -172,6 +173,8 @@ def create( :type cta_layout_vmnk: cute.Layout | None :param tidx: thread index to consumer async threads :type tidx: Int32 | None + :param mcast_mode_mn: Tuple of two integers, specifying whether mcast is enabled for the m and n modes. At least one of the two integers must be 1. + :type mcast_mode_mn: tuple[int, int] """ if not isinstance(barrier_storage, cute.Pointer): raise ValueError( @@ -201,7 +204,9 @@ def create( ( dst_rank, is_signalling_thread, - ) = PipelineTmaAsync.init_empty_barrier_arrive_signal(cta_layout_vmnk, tidx) + ) = PipelineTmaAsync.init_empty_barrier_arrive_signal( + cta_layout_vmnk, tidx, mcast_mode_mn + ) producer_mask = None From a5d545df1ddab7477d6df494d655caedbf789237 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 27 Oct 2025 20:28:03 -0400 Subject: [PATCH 202/258] [Cute,Bwd,Sm100] Remove stage from sK, sV, tP, sdS --- flash_attn/cute/flash_bwd_sm100.py | 49 +++++++++++++++--------------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index ef36a77746e..46ac485e34e 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -214,12 +214,13 @@ def _get_tiled_mma(self): def _setup_smem_layout(self): # S = K @ Q.T - self.sK_layout = sm100_utils_basic.make_smem_layout_a( + sK_layout = sm100_utils_basic.make_smem_layout_a( self.tiled_mma_SdP, self.mma_tiler_kq, self.k_dtype, 1, ) + self.sK_layout = cute.slice_(sK_layout, (None, None, None, 0)) self.sQ_layout = sm100_utils_basic.make_smem_layout_b( self.tiled_mma_SdP, self.mma_tiler_kq, @@ -227,12 +228,13 @@ def _setup_smem_layout(self): self.Q_stage, ) # dP = V @ dO.T - self.sV_layout = sm100_utils_basic.make_smem_layout_a( + sV_layout = sm100_utils_basic.make_smem_layout_a( self.tiled_mma_SdP, self.mma_tiler_vdo, self.v_dtype, 1, ) + self.sV_layout = cute.slice_(sV_layout, (None, None, None, 0)) self.sdOt_layout = sm100_utils_basic.make_smem_layout_b( self.tiled_mma_SdP, self.mma_tiler_vdo, @@ -240,12 +242,13 @@ def _setup_smem_layout(self): self.dO_stage, ) # dV += P @ dO - self.tP_layout = sm100_utils_basic.make_smem_layout_a( + tP_layout = sm100_utils_basic.make_smem_layout_a( self.tiled_mma_dV, self.mma_tiler_pdo, self.do_dtype, 1, ) + self.tP_layout = cute.slice_(tP_layout, (None, None, None, 0)) self.sdO_layout = sm100_utils_basic.make_smem_layout_b( self.tiled_mma_dV, self.mma_tiler_pdo, @@ -253,12 +256,13 @@ def _setup_smem_layout(self): self.dO_stage, ) # dK += dS.T @ Q - self.sdSt_layout = sm100_utils_basic.make_smem_layout_a( + sdSt_layout = sm100_utils_basic.make_smem_layout_a( self.tiled_mma_dK, self.mma_tiler_dsq, self.ds_dtype, 1, ) + self.sdSt_layout = cute.slice_(sdSt_layout, (None, None, None, 0)) self.sQt_layout = sm100_utils_basic.make_smem_layout_b( self.tiled_mma_dK, self.mma_tiler_dsq, @@ -266,18 +270,20 @@ def _setup_smem_layout(self): self.Q_stage, ) # dQ = dS @ K - self.sdS_layout = sm100_utils_basic.make_smem_layout_a( + sdS_layout = sm100_utils_basic.make_smem_layout_a( self.tiled_mma_dQ, self.mma_tiler_dsk, self.ds_dtype, 1, ) - self.sKt_layout = sm100_utils_basic.make_smem_layout_b( + self.sdS_layout = cute.slice_(sdS_layout, (None, None, None, 0)) + sKt_layout = sm100_utils_basic.make_smem_layout_b( self.tiled_mma_dQ, self.mma_tiler_dsk, self.k_dtype, 1, ) + self.sKt_layout = cute.slice_(sKt_layout, (None, None, None, 0)) self.sdQaccum_layout = cute.make_layout( (self.tile_m * self.dQ_reduce_ncol, self.sdQaccum_stage) ) @@ -1138,14 +1144,14 @@ def load( tdPgdO = thr_mma_SdP.partition_B(gdO) load_K, _, _ = copy_utils.tma_get_copy_fn( - tma_atom_K, 0, cute.make_layout(1), tSgK, sK[None, None, None, 0], single_stage=True + tma_atom_K, 0, cute.make_layout(1), tSgK, sK, single_stage=True ) load_V, _, _ = copy_utils.tma_get_copy_fn( tma_atom_V, 0, cute.make_layout(1), tdPgV, - sV[None, None, None, 0], + sV, single_stage=True, ) b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape) @@ -1297,14 +1303,14 @@ def mma( tdQrK = tiled_mma_dQ.make_fragment_B(sKt) # dV = P @ dO.T tdVrdO = tiled_mma_dV.make_fragment_B(sdO) - tdVrP = tiled_mma_dV.make_fragment_A(tP)[None, None, None, 0] + tdVrP = tiled_mma_dV.make_fragment_A(tP) - # mma_qk_fn = partial(gemm_w_idx, tiled_mma_SdP, tStS, tSrK, tSrQ, A_idx=0, zero_init=True) + # mma_qk_fn = partial(gemm_w_idx, tiled_mma_SdP, tStS, tSrK, tSrQ, zero_init=True) mma_qk_fn = partial( - gemm_ptx_w_idx, tiled_mma_SdP, tStS, tSrK, tSrQ, sA=sK, sB=sQ, A_idx=0, zero_init=True + gemm_ptx_w_idx, tiled_mma_SdP, tStS, tSrK, tSrQ, sA=sK, sB=sQ, zero_init=True ) # mma_dov_fn = partial( - # gemm_w_idx, tiled_mma_SdP, tdPtdP, tdPrV, tdPrdOt, A_idx=0, zero_init=True + # gemm_w_idx, tiled_mma_SdP, tdPtdP, tdPrV, tdPrdOt, zero_init=True # ) mma_dov_fn = partial( gemm_ptx_w_idx, @@ -1314,23 +1320,16 @@ def mma( tdPrdOt, sA=sV, sB=sdOt, - A_idx=0, zero_init=True, ) - mma_pdo_fn = partial(gemm_w_idx, tiled_mma_dV, tdVtdV, tdVrP, tdVrdO, A_idx=None) - # mma_pdo_fn = partial( - # gemm_ptx_w_idx, tiled_mma_dV, tdVtdV, tdVrP, tdVrdO, sA=None, sB=sdO, A_idx=None - # ) - mma_dsk_fn = partial( - gemm_w_idx, tiled_mma_dQ, tdQtdQ, tdQrdS, tdQrK, A_idx=0, B_idx=0, zero_init=True - ) + mma_pdo_fn = partial(gemm_w_idx, tiled_mma_dV, tdVtdV, tdVrP, tdVrdO) + # mma_pdo_fn = partial(gemm_ptx_w_idx, tiled_mma_dV, tdVtdV, tdVrP, tdVrdO, sA=None, sB=sdO) + mma_dsk_fn = partial(gemm_w_idx, tiled_mma_dQ, tdQtdQ, tdQrdS, tdQrK, zero_init=True) # mma_dsk_fn = partial( - # gemm_ptx_w_idx, tiled_mma_dQ, tdQtdQ, tdQrdS, tdQrK, sA=sdS, sB=sKt, A_idx=0, B_idx=0, zero_init=True - # ) - mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, tdKtdK, tdKrdS, tdKrQ, A_idx=0) - # mma_dsq_fn = partial( - # gemm_ptx_w_idx, tiled_mma_dK, tdKtdK, tdKrdS, tdKrQ, sA=sdSt, sB=sQt, A_idx=0 + # gemm_ptx_w_idx, tiled_mma_dQ, tdQtdQ, tdQrdS, tdQrK, sA=sdS, sB=sKt, zero_init=True # ) + mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, tdKtdK, tdKrdS, tdKrQ) + # mma_dsq_fn = partial(gemm_ptx_w_idx, tiled_mma_dK, tdKtdK, tdKrdS, tdKrQ, sA=sdSt, sB=sQt) consumer_state_dO = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage From b3f1b6a5bdcce820e74cc0bb6f615165387195cc Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 27 Oct 2025 23:06:25 -0400 Subject: [PATCH 203/258] [Cute,Bwd,Sm100] Fix wrong LSE and dPsum indexing in load --- flash_attn/cute/flash_bwd_sm100.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 46ac485e34e..8eebd457ad9 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -1225,7 +1225,7 @@ def load( pipeline_LSE.producer_acquire(producer_state_Q_LSE) with cute.arch.elect_one(): copy_stats( - gLSE[None, m_block_min], + gLSE[None, m_block], sLSE[None, producer_state_Q_LSE.index], mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE), ) @@ -1238,7 +1238,7 @@ def load( pipeline_dPsum.producer_acquire(producer_state_dO_dPsum) with cute.arch.elect_one(): copy_stats( - gdPsum[None, m_block_min], + gdPsum[None, m_block], sdPsum[None, producer_state_dO_dPsum.index], mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum), ) From 67e88650129371e439342122208ab7bfc01557bf Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Tue, 28 Oct 2025 12:35:27 -0700 Subject: [PATCH 204/258] [Cute] Blocks tweaks (#1964) --- flash_attn/cute/benchmark_mask_mod.py | 58 ++++++------------- flash_attn/cute/block_sparsity.py | 81 ++++++++++++++++++++++++++- flash_attn/cute/flash_fwd.py | 44 +++++---------- flash_attn/cute/flash_fwd_sm100.py | 7 +-- flash_attn/cute/interface.py | 53 +++++------------- tests/cute/test_mask_mod.py | 15 +++-- 6 files changed, 135 insertions(+), 123 deletions(-) diff --git a/flash_attn/cute/benchmark_mask_mod.py b/flash_attn/cute/benchmark_mask_mod.py index b1aadd89395..9b7950ba076 100644 --- a/flash_attn/cute/benchmark_mask_mod.py +++ b/flash_attn/cute/benchmark_mask_mod.py @@ -21,7 +21,11 @@ create_cute_sliding_window_mask, create_flex_sliding_window_mask, ) -from block_sparsity import compute_block_sparsity +from flash_attn.cute.block_sparsity import ( + compute_block_sparsity, + BlockSparseTensorsTorch, + to_cute_block_sparse_tensors, +) @dataclass @@ -265,10 +269,12 @@ def _create_tensors(self) -> Dict[str, torch.Tensor]: ) if all(t is not None for t in [full_cnt, full_idx, mask_cnt, mask_idx]): - tensors["full_block_cnt"] = full_cnt.contiguous() - tensors["full_block_idx"] = full_idx.contiguous() - tensors["mask_block_cnt"] = mask_cnt.contiguous() - tensors["mask_block_idx"] = mask_idx.contiguous() + tensors["block_sparse_tensors"] = BlockSparseTensorsTorch( + mask_block_cnt=mask_cnt.contiguous(), + mask_block_idx=mask_idx.contiguous(), + full_block_cnt=full_cnt.contiguous(), + full_block_idx=full_idx.contiguous(), + ) if config.verbose: total_full = full_cnt.sum().item() @@ -373,33 +379,9 @@ def _compile_kernel(self, tensors: Dict[str, torch.Tensor]) -> Tuple[Any, tuple] else None ) - # Block sparsity tensors - full_block_cnt_cute = ( - from_dlpack(tensors["full_block_cnt"].detach(), assumed_align=4).mark_layout_dynamic( - leading_dim=2 - ) - if "full_block_cnt" in tensors - else None - ) - full_block_idx_cute = ( - from_dlpack(tensors["full_block_idx"].detach(), assumed_align=4).mark_layout_dynamic( - leading_dim=3 - ) - if "full_block_idx" in tensors - else None - ) - mask_block_cnt_cute = ( - from_dlpack(tensors["mask_block_cnt"].detach(), assumed_align=4).mark_layout_dynamic( - leading_dim=2 - ) - if "mask_block_cnt" in tensors - else None - ) - mask_block_idx_cute = ( - from_dlpack(tensors["mask_block_idx"].detach(), assumed_align=4).mark_layout_dynamic( - leading_dim=3 - ) - if "mask_block_idx" in tensors + blocksparse_tensors_cute = ( + to_cute_block_sparse_tensors(tensors["block_sparse_tensors"]) + if "block_sparse_tensors" in tensors else None ) @@ -436,11 +418,8 @@ def _compile_kernel(self, tensors: Dict[str, torch.Tensor]) -> Tuple[Any, tuple] None, # page_table window_left_cute, window_right_cute, - learnable_sink_cute, # learnable_sink - full_block_cnt_cute, - full_block_idx_cute, - mask_block_cnt_cute, - mask_block_idx_cute, + learnable_sink_cute, + blocksparse_tensors_cute, aux_tensors_cute, # None, ) @@ -461,10 +440,7 @@ def _compile_kernel(self, tensors: Dict[str, torch.Tensor]) -> Tuple[Any, tuple] window_left_cute, window_right_cute, learnable_sink_cute, - full_block_cnt_cute, - full_block_idx_cute, - mask_block_cnt_cute, - mask_block_idx_cute, + blocksparse_tensors_cute, aux_tensors_cute, # None, ) diff --git a/flash_attn/cute/block_sparsity.py b/flash_attn/cute/block_sparsity.py index be685dea5d4..c28df4c20d3 100644 --- a/flash_attn/cute/block_sparsity.py +++ b/flash_attn/cute/block_sparsity.py @@ -8,13 +8,92 @@ by a more robust preprocessing kernel in the future. """ -from typing import Tuple, Optional, Callable, List +from typing import Tuple, Optional, Callable, List, NamedTuple import torch +import cutlass.cute as cute +from cutlass.cute.runtime import from_dlpack # placeholder Config = type("Config", (), {}) +class BlockSparseTensors(NamedTuple): + mask_block_cnt: cute.Tensor + mask_block_idx: cute.Tensor + full_block_cnt: Optional[cute.Tensor] + full_block_idx: Optional[cute.Tensor] + + def __new_from_mlir_values__(self, values): + return BlockSparseTensors(*values) + + +class BlockSparseTensorsTorch(NamedTuple): + mask_block_cnt: torch.Tensor + mask_block_idx: torch.Tensor + full_block_cnt: Optional[torch.Tensor] = None + full_block_idx: Optional[torch.Tensor] = None + + +def validate_block_sparse_tensors(tensors: BlockSparseTensorsTorch) -> None: + for name, cnt, idx in ( + ("mask", tensors.mask_block_cnt, tensors.mask_block_idx), + ("full", tensors.full_block_cnt, tensors.full_block_idx), + ): + if (cnt is None) != (idx is None): + raise ValueError( + f"{name}_block_cnt and {name}_block_idx must both be provided or both be None" + ) + if cnt is None: + continue + if cnt.dtype != torch.int32 or idx.dtype != torch.int32: + raise ValueError(f"{name}_block tensors must have dtype torch.int32") + if cnt.device != idx.device: + raise ValueError(f"{name}_block_cnt and {name}_block_idx must be on the same device") + if not cnt.is_cuda or not idx.is_cuda: + raise ValueError(f"{name}_block tensors must live on CUDA") + + if tensors.full_block_cnt is not None and tensors.mask_block_cnt is not None: + if tensors.full_block_cnt.device != tensors.mask_block_cnt.device: + raise ValueError("All block sparse tensors must be on the same device") + + +def is_block_sparsity_enabled(tensors: BlockSparseTensorsTorch) -> bool: + return any(t is not None for t in (tensors.full_block_cnt, tensors.mask_block_cnt)) + + +def to_cute_block_sparse_tensors(tensors: BlockSparseTensorsTorch) -> Optional[BlockSparseTensors]: + if not is_block_sparsity_enabled(tensors): + return None + + mask_block_cnt_tensor = from_dlpack( + tensors.mask_block_cnt.detach(), assumed_align=4 + ).mark_layout_dynamic(leading_dim=2) + mask_block_idx_tensor = from_dlpack( + tensors.mask_block_idx.detach(), assumed_align=4 + ).mark_layout_dynamic(leading_dim=3) + full_block_cnt_tensor = ( + from_dlpack(tensors.full_block_cnt.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=2 + ) + if tensors.full_block_cnt is not None + else None + ) + full_block_idx_tensor = ( + from_dlpack(tensors.full_block_idx.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=3 + ) + if tensors.full_block_idx is not None + else None + ) + + return BlockSparseTensors( + mask_block_cnt_tensor, + mask_block_idx_tensor, + full_block_cnt_tensor, + full_block_idx_tensor, + ) + + def compute_block_sparsity( config: Config, mask_mod_flex: Optional[Callable], diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index b49a693dfcd..16d57991f97 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -29,6 +29,7 @@ from flash_attn.cute.softmax import Softmax, apply_score_mod_inner from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute.block_sparsity import BlockSparseTensors from flash_attn.cute import pipeline from flash_attn.cute.pack_gqa import PackGQA from flash_attn.cute.named_barrier import NamedBarrierFwd @@ -1271,10 +1272,7 @@ def __call__( window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, learnable_sink: Optional[cute.Tensor] = None, - full_block_cnt: Optional[cute.Tensor] = None, # (b, h, m_block) - full_block_idx: Optional[cute.Tensor] = None, # (b, h, m_block, n_block) - mask_block_cnt: Optional[cute.Tensor] = None, # (b, h, m_block) - mask_block_idx: Optional[cute.Tensor] = None, # (b, h, m_block, n_block) + blocksparse_tensors: Optional[BlockSparseTensors] = None, aux_tensors: Optional[list] = None, ): """Configures and launches the flash attention kernel. @@ -1290,6 +1288,7 @@ def __call__( ) ) + # Assume all strides are divisible by 128 bits except the last stride new_stride = lambda t: ( *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), @@ -1325,9 +1324,8 @@ def __call__( ) # self.num_mma_regs = 232 # self.num_producer_regs = 40 - self.use_block_sparsity = const_expr( - mask_block_cnt is not None and full_block_cnt is not None - ) + self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None) + self.use_scheduler_barrier = ( (self.num_mma_warp_groups >= 2 and self.tile_hdim <= 128) if const_expr(self.intra_wg_overlap) @@ -1521,10 +1519,7 @@ def __call__( window_size_left, window_size_right, learnable_sink, - full_block_cnt, - full_block_idx, - mask_block_cnt, - mask_block_idx, + blocksparse_tensors, self.sQ_layout, self.sK_layout, self.sV_layout, @@ -1571,10 +1566,7 @@ def kernel( window_size_left: Optional[Int32], window_size_right: Optional[Int32], learnable_sink: Optional[cute.Tensor], - full_block_cnt: Optional[cute.Tensor], - full_block_idx: Optional[cute.Tensor], - mask_block_cnt: Optional[cute.Tensor], - mask_block_idx: Optional[cute.Tensor], + blocksparse_tensors: Optional[BlockSparseTensors], sQ_layout: cute.ComposedLayout, sK_layout: cute.ComposedLayout, sV_layout: cute.ComposedLayout, @@ -1698,10 +1690,7 @@ def kernel( pipeline_k, pipeline_v, mbar_ptr_Q, - full_block_cnt, - full_block_idx, - mask_block_cnt, - mask_block_idx, + blocksparse_tensors, block_info, SeqlenInfoCls, TileSchedulerCls, @@ -1740,10 +1729,7 @@ def kernel( SeqlenInfoCls, AttentionMaskCls, TileSchedulerCls, - full_block_cnt, - full_block_idx, - mask_block_cnt, - mask_block_idx, + blocksparse_tensors, aux_tensors, fastdiv_mods, ) @@ -1763,10 +1749,7 @@ def load( pipeline_k: cutlass.pipeline.PipelineAsync, pipeline_v: cutlass.pipeline.PipelineAsync, mbar_ptr_Q: cutlass.Pointer, - full_block_cnt: Optional[cute.Tensor], - full_block_idx: Optional[cute.Tensor], - mask_block_cnt: Optional[cute.Tensor], - mask_block_idx: Optional[cute.Tensor], + blocksparse_tensors: Optional[BlockSparseTensors], block_info: BlockInfo, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, @@ -1852,6 +1835,7 @@ def load( # ========================================== # Flex Attention blocksparsity # ========================================== + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block] curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None] curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] @@ -2033,10 +2017,7 @@ def mma( SeqlenInfoCls: Callable, AttentionMaskCls: Callable, TileSchedulerCls: Callable, - full_block_cnt: Optional[cute.Tensor], - full_block_idx: Optional[cute.Tensor], - mask_block_cnt: Optional[cute.Tensor], - mask_block_idx: Optional[cute.Tensor], + blocksparse_tensors: Optional[BlockSparseTensors], aux_tensors: Optional[list], fastdiv_mods=None, ): @@ -2263,6 +2244,7 @@ def mma( # ========================================== # Block sparsity # ========================================== + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block] curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None] curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 9d5a814104d..1ec7dce3a1a 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -33,6 +33,7 @@ from flash_attn.cute.softmax import SoftmaxSm100, apply_score_mod_inner from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute.block_sparsity import BlockSparseTensors from flash_attn.cute.pack_gqa import PackGQA from flash_attn.cute import mma_sm100_desc as sm100_desc from flash_attn.cute import blackwell_helpers as sm100_utils @@ -224,10 +225,7 @@ def __call__( window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, learnable_sink: Optional[cute.Tensor] = None, - full_block_cnt: Optional[cute.Tensor] = None, # (b, h, m_block) - full_block_idx: Optional[cute.Tensor] = None, # (b, h, m_block, n_block) - mask_block_cnt: Optional[cute.Tensor] = None, # (b, h, m_block) - mask_block_idx: Optional[cute.Tensor] = None, # (b, h, m_block, n_block) + blocksparse_tensors: Optional[BlockSparseTensors] = None, aux_tensors: Optional[list] = None, ): """Execute the Fused Multi-Head Attention operation on the provided tensors. @@ -243,7 +241,6 @@ def __call__( 5. Grid and work scheduling computation 6. Kernel launch with appropriate parameters """ - # setup static attributes before smem/grid/tma computation self.q_dtype = mQ.element_type self.k_dtype = mK.element_type diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 55d415c93cc..51fb5baae63 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -42,6 +42,8 @@ from flash_attn.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess from flash_attn.cute.flash_fwd_combine import FlashAttentionForwardCombine +from flash_attn.cute.block_sparsity import BlockSparseTensorsTorch, to_cute_block_sparse_tensors + def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x @@ -79,10 +81,7 @@ def _flash_attn_fwd( _compute_capability: Optional[int] = None, score_mod: Optional[Callable] = None, mask_mod: Optional[Callable] = None, - full_block_cnt: Optional[torch.Tensor] = None, - full_block_idx: Optional[torch.Tensor] = None, - mask_block_cnt: Optional[torch.Tensor] = None, - mask_block_idx: Optional[torch.Tensor] = None, + block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None, return_lse: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, @@ -156,10 +155,7 @@ def _flash_attn_fwd( if learnable_sink is not None: assert learnable_sink.shape == (num_head,) assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16" - for t in [full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx]: - if t is not None: - assert t.dtype == torch.int32, "blocksparse mask tensors must be int32" - # assert t.stride(0) == 1, "blocksparse mask tensors must be contiguous" + assert all( t is None or t.is_cuda for t in ( @@ -172,10 +168,6 @@ def _flash_attn_fwd( seqused_k, page_table, learnable_sink, - full_block_cnt, - full_block_idx, - mask_block_cnt, - mask_block_idx, ) ), "inputs must be on CUDA device" assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" @@ -259,28 +251,13 @@ def _flash_attn_fwd( if page_table is not None else None ) - - full_block_cnt_tensor = ( - from_dlpack(full_block_cnt.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=2) - if full_block_cnt is not None + sparse_tensors = ( + to_cute_block_sparse_tensors(block_sparse_tensors) + if block_sparse_tensors is not None else None ) - full_block_idx_tensor = ( - from_dlpack(full_block_idx.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=3) - if full_block_idx is not None - else None - ) - mask_block_cnt_tensor = ( - from_dlpack(mask_block_cnt.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=2) - if mask_block_cnt is not None - else None - ) - mask_block_idx_tensor = ( - from_dlpack(mask_block_idx.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=3) - if mask_block_idx is not None - else None - ) - use_block_sparsity = full_block_cnt is not None or mask_block_cnt is not None + + use_block_sparsity = sparse_tensors is not None if mask_mod is None: if causal: @@ -416,6 +393,8 @@ def _flash_attn_fwd( assert page_size in [None, 128], ( "Only page_size=128 is supported for paged KV on SM 10.0" ) + if sparse_tensors is not None: + raise NotImplementedError("BlockSparsity not yet supported on SM 10.0") fa_fwd = FlashAttentionForwardSm100( head_dim, head_dim_v, @@ -452,10 +431,7 @@ def _flash_attn_fwd( window_size_left, window_size_right, learnable_sink_tensor, - full_block_cnt_tensor, - full_block_idx_tensor, - mask_block_cnt_tensor, - mask_block_idx_tensor, + sparse_tensors, cute_aux_tensors, ) _flash_attn_fwd.compile_cache[compile_key]( @@ -474,10 +450,7 @@ def _flash_attn_fwd( window_size_left, window_size_right, learnable_sink_tensor, - full_block_cnt_tensor, - full_block_idx_tensor, - mask_block_cnt_tensor, - mask_block_idx_tensor, + sparse_tensors, cute_aux_tensors, ) return out, lse diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py index ce3a28b82c6..033d08f296f 100644 --- a/tests/cute/test_mask_mod.py +++ b/tests/cute/test_mask_mod.py @@ -10,7 +10,7 @@ import torch.nn.functional as F from flash_attn.cute.interface import _flash_attn_fwd -from flash_attn.cute.block_sparsity import compute_block_sparsity +from flash_attn.cute.block_sparsity import compute_block_sparsity, BlockSparseTensorsTorch from flash_attn.cute.mask_definitions import ( MASK_FUNCTIONS, flex_causal_mask, @@ -304,6 +304,14 @@ class Config: # print(f" First Q block - full indices: {full_idx[0,0,0,:full_cnt[0,0,0].item()]}") # if mask_cnt[0,0,0] > 0: # print(f" First Q block - mask indices: {mask_idx[0,0,0,:mask_cnt[0,0,0].item()]}") + block_sparse_mask = None + if use_mask_mod: + block_sparse_mask = BlockSparseTensorsTorch( + mask_block_cnt=mask_cnt, + mask_block_idx=mask_idx, + full_block_cnt=full_cnt, + full_block_idx=full_idx, + ) out_tuple = _flash_attn_fwd( q=tensors["q"], @@ -329,10 +337,7 @@ class Config: _compute_capability=None, score_mod=None, mask_mod=mask_mod_cute, - full_block_cnt=full_cnt, - full_block_idx=full_idx, - mask_block_cnt=mask_cnt, - mask_block_idx=mask_idx, + block_sparse_tensors=block_sparse_mask, return_lse=True, aux_tensors=None, ) From 7f7a497b628d6f4b006c6ec6feb90d0192eddfc3 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 28 Oct 2025 17:49:55 -0400 Subject: [PATCH 205/258] [Cute,Bwd,Sm100] Use TS MMA for dK --- flash_attn/cute/blackwell_helpers.py | 16 +++++- flash_attn/cute/flash_bwd_sm100.py | 86 ++++++++++++++++++++++------ 2 files changed, 81 insertions(+), 21 deletions(-) diff --git a/flash_attn/cute/blackwell_helpers.py b/flash_attn/cute/blackwell_helpers.py index 1cac21f8f38..e2ff2ccc9ae 100644 --- a/flash_attn/cute/blackwell_helpers.py +++ b/flash_attn/cute/blackwell_helpers.py @@ -46,6 +46,7 @@ def gemm_ptx_w_idx( A_idx: Optional[Int32] = None, B_idx: Optional[Int32] = None, zero_init: bool | Boolean = False, + **kwargs, ) -> None: rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] @@ -55,7 +56,9 @@ def gemm_ptx_w_idx( sB_cur = sB if const_expr(B_idx is None) else sB[None, None, None, B_idx] mma_atom = cute.make_mma_atom(tiled_mma.op) acc_tmem_addr = acc.iterator.toint() - gemm_ptx_partial(mma_atom.op, acc_tmem_addr, rA, rB, sA_cur, sB_cur, zero_init=zero_init) + gemm_ptx_partial( + mma_atom.op, acc_tmem_addr, rA, rB, sA_cur, sB_cur, zero_init=zero_init, **kwargs + ) @cute.jit @@ -366,7 +369,11 @@ def gemm_ptx_partial( mbar_ptr: Optional[cutlass.Pointer] = None, mbar_phase: Optional[Int32] = None, zero_init: bool | Boolean = False, + # sA_offset: Int32 = 0, + # acc_offset: Int32 = 0, + tA_addr: Optional[Int32] = None, ) -> None: + # acc_tmem_addr += acc_offset is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM if const_expr(not is_ts): assert sA is not None, "sA must be provided when a_src is not TMEM" @@ -418,6 +425,7 @@ def gemm_ptx_partial( smem_desc_start_a_lo = Int32( smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator) ) + # ) + sA_offset else: smem_desc_start_a_lo = None smem_desc_start_b_lo = Int32( @@ -476,8 +484,12 @@ def gemm_ptx_partial( asm_dialect=llvm.AsmDialect.AD_ATT, ) else: + # For TS gemm, somehow tCrA.iterator.toint() returns 0 no matter what, so we need to + # explicitly pass in the tA_addr for correctness. + tA_addr = tCrA[None, None, 0].iterator.toint() if tA_addr is None else tA_addr input_args = [ - Int32(cute.arch.make_warp_uniform(tCrA[None, None, 0].iterator.toint())).ir_value(), + # Int32(cute.arch.make_warp_uniform(tCrA[None, None, 0].iterator.toint())).ir_value(), + Int32(cute.arch.make_warp_uniform(tA_addr)).ir_value(), Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), Int32(not zero_init).ir_value(), Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(), diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 8eebd457ad9..e32cc64df4b 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -146,6 +146,7 @@ def __init__( self.tmem_dP_offset = self.tmem_dV_offset + self.tile_hdimv self.tmem_dQ_offset = self.tmem_dP_offset # overlap with dP self.tmem_dK_offset = self.tmem_dP_offset + self.tile_m + self.tmem_dS_offset = self.tmem_dP_offset # overlap with dP if not is_causal and not is_local: self.num_regs_reduce = 152 @@ -200,6 +201,7 @@ def _get_tiled_mma(self): self.pdo_acc_dtype, cta_group, self.mma_tiler_dsq[:2], + a_source=tcgen05.OperandSource.TMEM, ) # dQ = dS @ K tiled_mma_dQ = sm100_utils_basic.make_trivial_tiled_mma( @@ -263,6 +265,13 @@ def _setup_smem_layout(self): 1, ) self.sdSt_layout = cute.slice_(sdSt_layout, (None, None, None, 0)) + tdS_layout = sm100_utils_basic.make_smem_layout_a( + self.tiled_mma_dK, + self.mma_tiler_dsq, + self.ds_dtype, + 1, + ) + self.tdS_layout = cute.slice_(tdS_layout, (None, None, None, 0)) self.sQt_layout = sm100_utils_basic.make_smem_layout_b( self.tiled_mma_dK, self.mma_tiler_dsq, @@ -631,6 +640,7 @@ class SharedStorage: self.sdQaccum_layout, self.sdKV_layout, self.tP_layout, + self.tdS_layout, self.tiled_mma_SdP, self.tiled_mma_dV, self.tiled_mma_dK, @@ -685,6 +695,7 @@ def kernel( sdQaccum_layout: cute.Layout, sdKV_layout: cute.ComposedLayout, tP_layout: cute.ComposedLayout, + tdS_layout: cute.ComposedLayout, tiled_mma_SdP: cute.TiledMma, tiled_mma_dV: cute.TiledMma, tiled_mma_dK: cute.TiledMma, @@ -877,13 +888,17 @@ def kernel( dvacc_shape = thr_mma_dV.partition_shape_C(self.mma_tiler_pdo[:2]) tdVtdV = thr_mma_dV.make_fragment_C(dvacc_shape) tdVtdV = cute.make_tensor(tmem_ptr + self.tmem_dV_offset, tdVtdV.layout) - tP_ptr = cute.make_ptr(self.do_dtype, self.tmem_P_offset, cute.AddressSpace.tmem) - tP = cute.make_tensor(tP_ptr, tP_layout.outer) + tP = cute.make_tensor( + cute.recast_ptr(tmem_ptr + self.tmem_P_offset, dtype=self.do_dtype), tP_layout.outer + ) # dK thr_mma_dK = tiled_mma_dK.get_slice(0) dkacc_shape = thr_mma_dK.partition_shape_C(self.mma_tiler_dsq[:2]) tdKtdK = thr_mma_dK.make_fragment_C(dkacc_shape) tdKtdK = cute.make_tensor(tmem_ptr + self.tmem_dK_offset, tdKtdK.layout) + tdS = cute.make_tensor( + cute.recast_ptr(tmem_ptr + self.tmem_dS_offset, dtype=self.ds_dtype), tdS_layout.outer + ) # dQ thr_mma_dQ = tiled_mma_dQ.get_slice(0) dQacc_shape = thr_mma_dQ.partition_shape_C(self.mma_tiler_dsk[:2]) @@ -987,6 +1002,7 @@ def kernel( sdS, sKt, tP, + tdS, tStS, tdPtdP, tdVtdV, @@ -1270,6 +1286,7 @@ def mma( sdS: cute.Tensor, sKt: cute.Tensor, tP: cute.Tensor, + tdS: cute.Tensor, tStS: cute.Tensor, tdPtdP: cute.Tensor, tdVtdV: cute.Tensor, @@ -1296,7 +1313,8 @@ def mma( tdPrV = tiled_mma_SdP.make_fragment_A(sV) tdPrdOt = tiled_mma_SdP.make_fragment_B(sdOt) # dK = dS.T @ Q - tdKrdS = tiled_mma_dK.make_fragment_A(sdSt) + # tdKrdS = tiled_mma_dK.make_fragment_A(sdSt) + tdKrdS = tiled_mma_dK.make_fragment_A(tdS) tdKrQ = tiled_mma_dK.make_fragment_B(sQt) # dQ = dS @ K tdQrdS = tiled_mma_dQ.make_fragment_A(sdS) @@ -1309,9 +1327,7 @@ def mma( mma_qk_fn = partial( gemm_ptx_w_idx, tiled_mma_SdP, tStS, tSrK, tSrQ, sA=sK, sB=sQ, zero_init=True ) - # mma_dov_fn = partial( - # gemm_w_idx, tiled_mma_SdP, tdPtdP, tdPrV, tdPrdOt, zero_init=True - # ) + # mma_dov_fn = partial(gemm_w_idx, tiled_mma_SdP, tdPtdP, tdPrV, tdPrdOt, zero_init=True) mma_dov_fn = partial( gemm_ptx_w_idx, tiled_mma_SdP, @@ -1322,14 +1338,33 @@ def mma( sB=sdOt, zero_init=True, ) - mma_pdo_fn = partial(gemm_w_idx, tiled_mma_dV, tdVtdV, tdVrP, tdVrdO) - # mma_pdo_fn = partial(gemm_ptx_w_idx, tiled_mma_dV, tdVtdV, tdVrP, tdVrdO, sA=None, sB=sdO) + # mma_pdo_fn = partial(gemm_w_idx, tiled_mma_dV, tdVtdV, tdVrP, tdVrdO) + mma_pdo_fn = partial( + gemm_ptx_w_idx, + tiled_mma_dV, + tdVtdV, + tdVrP, + tdVrdO, + sA=None, + sB=sdO, + tA_addr=self.tmem_P_offset, + ) mma_dsk_fn = partial(gemm_w_idx, tiled_mma_dQ, tdQtdQ, tdQrdS, tdQrK, zero_init=True) # mma_dsk_fn = partial( # gemm_ptx_w_idx, tiled_mma_dQ, tdQtdQ, tdQrdS, tdQrK, sA=sdS, sB=sKt, zero_init=True # ) - mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, tdKtdK, tdKrdS, tdKrQ) - # mma_dsq_fn = partial(gemm_ptx_w_idx, tiled_mma_dK, tdKtdK, tdKrdS, tdKrQ, sA=sdSt, sB=sQt) + # mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, tdKtdK, tdKrdS, tdKrQ) + # Need to explicitly pass in tA_addr for correctness + mma_dsq_fn = partial( + gemm_ptx_w_idx, + tiled_mma_dK, + tdKtdK, + tdKrdS, + tdKrQ, + sA=None, + sB=sQt, + tA_addr=self.tmem_dS_offset, + ) consumer_state_dO = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage @@ -1400,18 +1435,18 @@ def mma( mma_qk_fn(B_idx=handle_Q_next.index) pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) - # 2) dQ = dS @ K + # 2) dK = dS.T @ Q pipeline_dS.consumer_wait(consumer_state_dS) + mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK) + accumulate_dK = True + handle_Q.release() + + # 3) dQ = dS @ K # dP uses the same tmem as dQ # However, if dS is ready, then dP must have been ready, so we don't need to wait # pipeline_dP.sync_object_empty.wait(0, producer_phase_acc) mma_dsk_fn() pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group) - - # 3) dK = dS.T @ Q - mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK) - accumulate_dK = True - handle_Q.release() pipeline_dS.consumer_release(consumer_state_dS) consumer_state_dS.advance() @@ -1576,6 +1611,7 @@ def compute_loop( # tix: [128...384] 8 warps warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) # 4-11 tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.compute_warp_ids)) + # tidx = cute.arch.thread_idx()[0] - (cute.arch.WARP_SIZE * self.compute_warp_ids[0]) dp_idx = tidx % 128 num_wg = len(self.compute_warp_ids) // 4 # 2 # wg_idx: @@ -1584,9 +1620,15 @@ def compute_loop( tileP_f32_like = self.mma_tiler_kq[0] // 32 * self.v_dtype.width # (128, 64) # tStS has shape ((128, 128), 1, 1), tStP has shape ((128, 64), 1, 1) + # tP overlap with tS tStP = cute.composition(tStS, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1)) + tStP = cute.make_tensor(tStS.iterator, tStP.layout) # Otherwise the tmem address is wrong tScS = thr_mma_SdP.partition_C(cute.make_identity_tensor(self.mma_tiler_kq[:2])) tScP = cute.composition(tScS, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1)) + # tdS overlap with tdP + tdPtdS = cute.composition(tdPtdP, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1)) + tdPcdP = tScS + tdPcdS = cute.composition(tdPcdP, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1)) tmem_load_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32 @@ -1608,6 +1650,8 @@ def compute_loop( thr_copy_r2t = copy_utils.make_tmem_copy(tmem_store_atom, num_wg).get_slice(tidx) tScP_r2t = thr_copy_r2t.partition_S(tScP) tStP_r2t = thr_copy_r2t.partition_D(tStP) + tdPcdS_r2t = thr_copy_r2t.partition_S(tdPcdS) + tdPtdS_r2t = thr_copy_r2t.partition_D(tdPtdS) # rmem -> smem # This part is a bit iffy, we might be making a lot of assumptions here copy_atom_r2s = sm100_utils_basic.get_smem_store_op( @@ -1774,11 +1818,15 @@ def compute_loop( (tSrS_cur[2 * v], tSrS_cur[2 * v + 1]), (tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1]), ) - tdPrdP_cvt = cute.make_fragment_like(tdPrdP_cur, self.ds_dtype) - utils.cvt_f16(tdPrdP_cur, tdPrdP_cvt) + tdPrdS_cvt = cute.make_fragment_like(tdPrdP_cur, self.ds_dtype) + utils.cvt_f16(tdPrdP_cur, tdPrdS_cvt) if const_expr(stage == 0): pipeline_dS.producer_acquire(producer_state_dS) - cute.autovec_copy(tdPrdP_cvt, tRS_sdS[None, stage]) + cute.autovec_copy(tdPrdS_cvt, tRS_sdS[None, stage]) + tdPrdS_r2t_f32 = cute.recast_tensor(tdPrdS_cvt, Float32) + cute.copy(thr_copy_r2t, tdPrdS_r2t_f32, tdPtdS_r2t[None, stage, 0, 0]) + + cute.arch.fence_view_async_tmem_store() cute.arch.sync_warp() # with cute.arch.elect_one(): From b613d9e2c8475945baff3fd68f2030af1b890acf Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 28 Oct 2025 18:02:04 -0400 Subject: [PATCH 206/258] [Cute,Blocksparse] Group block sparse input torch tensors --- flash_attn/cute/interface.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 51fb5baae63..ea81ab88f34 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -955,10 +955,12 @@ def forward( softcap=softcap, pack_gqa=pack_gqa, mask_mod=mask_mod, - full_block_cnt=full_block_cnt, - full_block_idx=full_block_idx, - mask_block_cnt=mask_block_cnt, - mask_block_idx=mask_block_idx, + block_sparse_tensors=BlockSparseTensorsTorch( + full_block_cnt=full_block_cnt, + full_block_idx=full_block_idx, + mask_block_cnt=mask_block_cnt, + mask_block_idx=mask_block_idx, + ) ) ctx.save_for_backward(q, k, v, out, lse) ctx.softmax_scale = softmax_scale From 11336b7ca822a16f15bf67fe888fff01552462a9 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 29 Oct 2025 17:57:43 -0400 Subject: [PATCH 207/258] [Cute,Bwd,Sm100] Separate mma_S and mma_dP --- flash_attn/cute/flash_bwd_sm100.py | 132 +++++++++++++++++------------ flash_attn/cute/interface.py | 3 +- 2 files changed, 78 insertions(+), 57 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index e32cc64df4b..fe7568be125 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -76,9 +76,7 @@ def __init__( # dQ = dS @ K self.mma_tiler_dsk = (tile_m, self.tile_hdimv, tile_n) - self.kq_acc_dtype = self.vdo_acc_dtype = self.pdo_acc_dtype = self.dsq_acc_dtype = ( - self.dsk_acc_dtype - ) = Float32 + self.acc_dtype = Float32 assert cluster_size in (1, 2), "Only cluster_size=1 or 2 is supported" self.cluster_shape_mn = (cluster_size, 1) @@ -174,21 +172,30 @@ def _setup_attributes(self): def _get_tiled_mma(self): cta_group = tcgen05.CtaGroup.ONE - # S = K @ Q.T, dP = V @ dO.T - tiled_mma_SdP = sm100_utils_basic.make_trivial_tiled_mma( + # S = K @ Q.T + tiled_mma_S = sm100_utils_basic.make_trivial_tiled_mma( self.q_dtype, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.K, - self.kq_acc_dtype, + self.acc_dtype, cta_group, self.mma_tiler_kq[:2], ) + # dP = V @ dO.T + tiled_mma_dP = sm100_utils_basic.make_trivial_tiled_mma( + self.do_dtype, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.K, + self.acc_dtype, + cta_group, + self.mma_tiler_vdo[:2], + ) # dV += P @ dO --> (K, MN) major tiled_mma_dV = sm100_utils_basic.make_trivial_tiled_mma( self.do_dtype, tcgen05.OperandMajorMode.K, # P_major_mode tcgen05.OperandMajorMode.MN, # dO_major_mode - self.pdo_acc_dtype, + self.acc_dtype, cta_group, self.mma_tiler_pdo[:2], a_source=tcgen05.OperandSource.TMEM, @@ -198,7 +205,7 @@ def _get_tiled_mma(self): self.do_dtype, tcgen05.OperandMajorMode.K, # dS_major_mode tcgen05.OperandMajorMode.MN, # Q_major_mode - self.pdo_acc_dtype, + self.acc_dtype, cta_group, self.mma_tiler_dsq[:2], a_source=tcgen05.OperandSource.TMEM, @@ -208,37 +215,37 @@ def _get_tiled_mma(self): self.k_dtype, tcgen05.OperandMajorMode.MN, # dS_major_mode tcgen05.OperandMajorMode.MN, # Kt_major_mode - self.dsk_acc_dtype, + self.acc_dtype, cta_group, self.mma_tiler_dsk[:2], ) - return tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ + return tiled_mma_S, tiled_mma_dP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ def _setup_smem_layout(self): # S = K @ Q.T sK_layout = sm100_utils_basic.make_smem_layout_a( - self.tiled_mma_SdP, + self.tiled_mma_S, self.mma_tiler_kq, self.k_dtype, 1, ) self.sK_layout = cute.slice_(sK_layout, (None, None, None, 0)) self.sQ_layout = sm100_utils_basic.make_smem_layout_b( - self.tiled_mma_SdP, + self.tiled_mma_S, self.mma_tiler_kq, self.q_dtype, self.Q_stage, ) # dP = V @ dO.T sV_layout = sm100_utils_basic.make_smem_layout_a( - self.tiled_mma_SdP, + self.tiled_mma_dP, self.mma_tiler_vdo, self.v_dtype, 1, ) self.sV_layout = cute.slice_(sV_layout, (None, None, None, 0)) self.sdOt_layout = sm100_utils_basic.make_smem_layout_b( - self.tiled_mma_SdP, + self.tiled_mma_dP, self.mma_tiler_vdo, self.do_dtype, self.dO_stage, @@ -399,9 +406,13 @@ def __call__( mdV_semaphore = None self._setup_attributes() - self.tiled_mma_SdP, self.tiled_mma_dK, self.tiled_mma_dV, self.tiled_mma_dQ = ( - self._get_tiled_mma() - ) + ( + self.tiled_mma_S, + self.tiled_mma_dP, + self.tiled_mma_dK, + self.tiled_mma_dV, + self.tiled_mma_dQ, + ) = self._get_tiled_mma() self._setup_smem_layout() cta_group = tcgen05.CtaGroup.ONE @@ -409,7 +420,7 @@ def __call__( self.cluster_shape_mnk = (*self.cluster_shape_mn, 1) self.cluster_layout_vmnk = cute.tiled_divide( cute.make_layout(self.cluster_shape_mnk), - (self.tiled_mma_SdP.thr_id.shape,), + (self.tiled_mma_S.thr_id.shape,), ) self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) self.is_q_do_mcast = self.num_mcast_ctas_b > 1 @@ -472,11 +483,11 @@ def __call__( mK, cute.select(self.sK_layout, mode=[0, 1, 2]), self.mma_tiler_kq, - self.tiled_mma_SdP, + self.tiled_mma_S, self.cluster_layout_vmnk.shape, ) Q_tma_op = sm100_utils_basic.cluster_shape_to_tma_atom_B( - self.cluster_shape_mnk, self.tiled_mma_SdP.thr_id + self.cluster_shape_mnk, self.tiled_mma_S.thr_id ) tma_atom_Q, tma_tensor_Q = cute.nvgpu.make_tiled_tma_atom_B( # tma_load_op if const_expr(self.cluster_shape_mnk[0] == 1) else tma_load_op_multicast, @@ -484,7 +495,7 @@ def __call__( mQ, cute.select(self.sQ_layout, mode=[0, 1, 2]), self.mma_tiler_kq, - self.tiled_mma_SdP, + self.tiled_mma_S, self.cluster_layout_vmnk.shape, ) # dP.T = V @ dO.T @@ -493,11 +504,11 @@ def __call__( mV, cute.select(self.sV_layout, mode=[0, 1, 2]), self.mma_tiler_vdo, - self.tiled_mma_SdP, + self.tiled_mma_dP, self.cluster_layout_vmnk.shape, ) dO_tma_op = sm100_utils_basic.cluster_shape_to_tma_atom_B( - self.cluster_shape_mnk, self.tiled_mma_SdP.thr_id + self.cluster_shape_mnk, self.tiled_mma_dP.thr_id ) tma_atom_dO, tma_tensor_dO = cute.nvgpu.make_tiled_tma_atom_B( # tma_load_op if const_expr(self.cluster_shape_mnk[0] == 1) else tma_load_op_multicast, @@ -505,7 +516,7 @@ def __call__( mdO, cute.select(self.sdO_layout, mode=[0, 1, 2]), self.mma_tiler_vdo, - self.tiled_mma_SdP, + self.tiled_mma_dP, self.cluster_layout_vmnk.shape, ) @@ -641,7 +652,8 @@ class SharedStorage: self.sdKV_layout, self.tP_layout, self.tdS_layout, - self.tiled_mma_SdP, + self.tiled_mma_S, + self.tiled_mma_dP, self.tiled_mma_dV, self.tiled_mma_dK, self.tiled_mma_dQ, @@ -696,7 +708,8 @@ def kernel( sdKV_layout: cute.ComposedLayout, tP_layout: cute.ComposedLayout, tdS_layout: cute.ComposedLayout, - tiled_mma_SdP: cute.TiledMma, + tiled_mma_S: cute.TiledMma, + tiled_mma_dP: cute.TiledMma, tiled_mma_dV: cute.TiledMma, tiled_mma_dK: cute.TiledMma, tiled_mma_dQ: cute.TiledMma, @@ -721,7 +734,7 @@ def kernel( cluster_layout_vmnk = cute.tiled_divide( cute.make_layout(self.cluster_shape_mnk), - (tiled_mma_SdP.thr_id.shape,), + (tiled_mma_S.thr_id.shape,), ) # Alloc @@ -874,14 +887,15 @@ def kernel( # request 512 columns of tmem, so we know that it starts at 0. tmem_ptr = cute.make_ptr(Float32, 0, mem_space=cute.AddressSpace.tmem, assumed_align=16) # S - thr_mma_SdP = tiled_mma_SdP.get_slice(0) - Sacc_shape = thr_mma_SdP.partition_shape_C(self.mma_tiler_kq[:2]) # (M, N) - tStS = thr_mma_SdP.make_fragment_C(Sacc_shape) + thr_mma_S = tiled_mma_S.get_slice(0) + Sacc_shape = thr_mma_S.partition_shape_C(self.mma_tiler_kq[:2]) # (M, N) + tStS = thr_mma_S.make_fragment_C(Sacc_shape) # (MMA, MMA_M, MMA_N) tStS = cute.make_tensor(tmem_ptr + self.tmem_S_offset, tStS.layout) # dP - dPacc_shape = thr_mma_SdP.partition_shape_C(self.mma_tiler_vdo[:2]) - tdPtdP = thr_mma_SdP.make_fragment_C(dPacc_shape) + thr_mma_dP = tiled_mma_dP.get_slice(0) + dPacc_shape = thr_mma_dP.partition_shape_C(self.mma_tiler_vdo[:2]) + tdPtdP = thr_mma_dP.make_fragment_C(dPacc_shape) tdPtdP = cute.make_tensor(tmem_ptr + self.tmem_dP_offset, tdPtdP.layout) # dV thr_mma_dV = tiled_mma_dV.get_slice(0) @@ -950,7 +964,8 @@ def kernel( if warp_idx == self.load_warp_id: cute.arch.warpgroup_reg_dealloc(self.num_regs_other) self.load( - thr_mma_SdP, + thr_mma_S, + thr_mma_dP, mQ, mK, mV, @@ -988,7 +1003,8 @@ def kernel( cute.arch.sync_warp() self.mma( - tiled_mma_SdP, + tiled_mma_S, + tiled_mma_dP, tiled_mma_dV, tiled_mma_dK, tiled_mma_dQ, @@ -1033,7 +1049,8 @@ def kernel( if warp_idx >= self.compute_warp_ids[0] and warp_idx <= self.compute_warp_ids[-1]: cute.arch.warpgroup_reg_alloc(self.num_regs_compute) # 8 warps self.compute_loop( - thr_mma_SdP, + thr_mma_S, + thr_mma_dP, thr_mma_dV, thr_mma_dK, tStS, @@ -1090,7 +1107,8 @@ def kernel( @cute.jit def load( self, - thr_mma_SdP: cute.core.ThrMma, + thr_mma_S: cute.core.ThrMma, + thr_mma_dP: cute.core.ThrMma, mQ: cute.Tensor, mK: cute.Tensor, mV: cute.Tensor, @@ -1149,15 +1167,15 @@ def load( mPsum_cur = mdPsum[None, head_idx, batch_idx] gK = cute.local_tile(mK_cur, cute.select(self.mma_tiler_kq, mode=[0, 2]), (n_block, 0)) - tSgK = thr_mma_SdP.partition_A(gK) + tSgK = thr_mma_S.partition_A(gK) gV = cute.local_tile(mV_cur, cute.select(self.mma_tiler_vdo, mode=[0, 2]), (n_block, 0)) - tdPgV = thr_mma_SdP.partition_A(gV) + tdPgV = thr_mma_dP.partition_A(gV) gQ = cute.local_tile(mQ_cur, cute.select(self.mma_tiler_kq, mode=[1, 2]), (None, 0)) - tSgQ = thr_mma_SdP.partition_B(gQ) + tSgQ = thr_mma_S.partition_B(gQ) gLSE = cute.local_tile(mLSE_cur, (self.tile_n,), (None,)) gdPsum = cute.local_tile(mPsum_cur, (self.tile_n,), (None,)) gdO = cute.local_tile(mdO_cur, cute.select(self.mma_tiler_pdo, mode=[1, 2]), (0, None)) - tdPgdO = thr_mma_SdP.partition_B(gdO) + tdPgdO = thr_mma_dP.partition_B(gdO) load_K, _, _ = copy_utils.tma_get_copy_fn( tma_atom_K, 0, cute.make_layout(1), tSgK, sK, single_stage=True @@ -1272,7 +1290,8 @@ def load( @cute.jit def mma( self, - tiled_mma_SdP: cute.TiledMma, + tiled_mma_S: cute.TiledMma, + tiled_mma_dP: cute.TiledMma, tiled_mma_dV: cute.TiledMma, tiled_mma_dK: cute.TiledMma, tiled_mma_dQ: cute.TiledMma, @@ -1307,11 +1326,11 @@ def mma( # kernel (before warp specialization) is a lot slower tha putting them here. # Partition smem / tmem tensors # S = K @ Q.T - tSrK = tiled_mma_SdP.make_fragment_A(sK) - tSrQ = tiled_mma_SdP.make_fragment_B(sQ) + tSrK = tiled_mma_S.make_fragment_A(sK) + tSrQ = tiled_mma_S.make_fragment_B(sQ) # dP = V @ dO.T - tdPrV = tiled_mma_SdP.make_fragment_A(sV) - tdPrdOt = tiled_mma_SdP.make_fragment_B(sdOt) + tdPrV = tiled_mma_dP.make_fragment_A(sV) + tdPrdOt = tiled_mma_dP.make_fragment_B(sdOt) # dK = dS.T @ Q # tdKrdS = tiled_mma_dK.make_fragment_A(sdSt) tdKrdS = tiled_mma_dK.make_fragment_A(tdS) @@ -1323,14 +1342,14 @@ def mma( tdVrdO = tiled_mma_dV.make_fragment_B(sdO) tdVrP = tiled_mma_dV.make_fragment_A(tP) - # mma_qk_fn = partial(gemm_w_idx, tiled_mma_SdP, tStS, tSrK, tSrQ, zero_init=True) + # mma_qk_fn = partial(gemm_w_idx, tiled_mma_S, tStS, tSrK, tSrQ, zero_init=True) mma_qk_fn = partial( - gemm_ptx_w_idx, tiled_mma_SdP, tStS, tSrK, tSrQ, sA=sK, sB=sQ, zero_init=True + gemm_ptx_w_idx, tiled_mma_S, tStS, tSrK, tSrQ, sA=sK, sB=sQ, zero_init=True ) - # mma_dov_fn = partial(gemm_w_idx, tiled_mma_SdP, tdPtdP, tdPrV, tdPrdOt, zero_init=True) + # mma_dov_fn = partial(gemm_w_idx, tiled_mma_dP, tdPtdP, tdPrV, tdPrdOt, zero_init=True) mma_dov_fn = partial( gemm_ptx_w_idx, - tiled_mma_SdP, + tiled_mma_dP, tdPtdP, tdPrV, tdPrdOt, @@ -1555,7 +1574,8 @@ def split_wg( @cute.jit def compute_loop( self, - thr_mma_SdP: cute.core.ThrMma, + thr_mma_S: cute.core.ThrMma, + thr_mma_dP: cute.core.ThrMma, thr_mma_dV: cute.core.ThrMma, thr_mma_dK: cute.core.ThrMma, tStS: cute.Tensor, @@ -1623,11 +1643,11 @@ def compute_loop( # tP overlap with tS tStP = cute.composition(tStS, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1)) tStP = cute.make_tensor(tStS.iterator, tStP.layout) # Otherwise the tmem address is wrong - tScS = thr_mma_SdP.partition_C(cute.make_identity_tensor(self.mma_tiler_kq[:2])) + tScS = thr_mma_S.partition_C(cute.make_identity_tensor(self.mma_tiler_kq[:2])) tScP = cute.composition(tScS, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1)) # tdS overlap with tdP tdPtdS = cute.composition(tdPtdP, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1)) - tdPcdP = tScS + tdPcdP = thr_mma_dP.partition_C(cute.make_identity_tensor(self.mma_tiler_vdo[:2])) tdPcdS = cute.composition(tdPcdP, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1)) tmem_load_atom = cute.make_copy_atom( @@ -1644,8 +1664,8 @@ def compute_loop( tScS_t2r = thr_copy_t2r.partition_D(tScS) # ((32, 1), 2, 1, 1) t0ScS_t2r = thr_copy_t2r.get_slice(0).partition_D(tScS) # ((32, 1), 2, 1, 1) # ((32, 1), 2, 1, 1, STAGE) - tSsLSE = thr_copy_t2r.partition_D(thr_mma_SdP.partition_C(sLSE_2D)) - tSsdPsum = thr_copy_t2r.partition_D(thr_mma_SdP.partition_C(sdPsum_2D)) + tSsLSE = thr_copy_t2r.partition_D(thr_mma_S.partition_C(sLSE_2D)) + tSsdPsum = thr_copy_t2r.partition_D(thr_mma_dP.partition_C(sdPsum_2D)) # rmem -> tmem thr_copy_r2t = copy_utils.make_tmem_copy(tmem_store_atom, num_wg).get_slice(tidx) tScP_r2t = thr_copy_r2t.partition_S(tScP) @@ -1734,7 +1754,7 @@ def compute_loop( lane_idx = cute.arch.lane_idx() tSrP_r2t_f32 = cute.make_fragment(tScP_r2t.shape, Float32) # 64 tSrP_r2t = cute.recast_tensor(tSrP_r2t_f32, self.q_dtype) - for stage in cutlass.range_constexpr(cute.size(tSrS_t2r, mode=[1]), unroll=1): + for stage in cutlass.range_constexpr(num_stages): tSrS_cur = tSrS_t2r[None, stage, 0, 0] tSsLSE_cur = tSsLSE[None, stage, 0, 0, consumer_state_LSE.index] if const_expr(not self.shuffle_LSE): @@ -1791,7 +1811,7 @@ def compute_loop( # consumer_phase_S_P_dP ^= 1 ##### dS.T = P.T * (dP.T - Psum) - for stage in cutlass.range_constexpr(cute.size(tSrS_t2r, mode=[1]), unroll=1): + for stage in cutlass.range_constexpr(num_stages): tdPrdP_t2r = cute.make_fragment(tScS_t2r[None, 0, None, None].shape, Float32) cute.copy(thr_copy_t2r, tdPtdP_t2r[None, stage, None, None], tdPrdP_t2r) cute.arch.fence_view_async_tmem_load() diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index ea81ab88f34..76d016fde73 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -794,7 +794,8 @@ def _flash_attn_bwd( qhead_per_kvhead=qhead_per_kvhead, # tile_m=m_block_size, # tile_n=n_block_size, - cluster_size=2 if not causal else 2, + cluster_size=2, + # cluster_size=1, ) # TODO: check @can_implement _flash_attn_bwd.compile_cache[compile_key] = cute.compile( From 419bdb7e3ace3811e0710cf0705b5fdd579e3576 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 29 Oct 2025 17:58:58 -0400 Subject: [PATCH 208/258] [Cute,Bwd,Sm100] Try LPTBwdScheduler --- flash_attn/cute/flash_bwd_sm100.py | 2 + flash_attn/cute/tile_scheduler.py | 110 ++++++++++++++++++++++++++++- 2 files changed, 111 insertions(+), 1 deletion(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index fe7568be125..376fc043033 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -23,6 +23,7 @@ from flash_attn.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, + SingleTileLPTBwdScheduler, # noqa ParamsBase, ) @@ -533,6 +534,7 @@ def __call__( self.tma_copy_bytes["dPsum"] = self.tile_m * Float32.width // 8 self.tma_copy_bytes["dQ"] = self.tile_m * self.dQ_reduce_ncol * Float32.width // 8 + # TileScheduler = SingleTileScheduler if not self.is_causal else SingleTileLPTBwdScheduler TileScheduler = SingleTileScheduler # TODO -- optimizer scheduler for causal tile_sched_args = TileSchedulerArguments( diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index f9359556662..517dd8a91a5 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -316,7 +316,115 @@ def __new_from_mlir_values__(self, values): for obj, n_items in zip([self.params, self._tile_idx], self._values_pos): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] - return SingleTileLPTScheduler(*(tuple(obj_list)), loc=self._loc) + return self.__class__(*(tuple(obj_list)), loc=self._loc) + + +class SingleTileLPTBwdScheduler: + @dataclass + class Params(ParamsBase): + total_blocks: Int32 + num_head_divmod: FastDivmod + l2_minor_divmod: FastDivmod + l2_major_divmod: FastDivmod + l2_minor_residual_divmod: FastDivmod + num_hb_quotient: Int32 + cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) + + @staticmethod + @cute.jit + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "SingleTileLPTBwdScheduler.Params": + swizzle = 8 + # If we're in the last section (called residual), we don't want to divide by + # swizzle. Instead we want to divide by the remainder. + num_hb_quotient = (args.num_head * args.num_batch) // swizzle + num_hb_remainder = (args.num_head * args.num_batch) % swizzle + num_block = cute.ceil_div(args.num_block, args.cluster_shape_mn[0]) + return SingleTileLPTBwdScheduler.Params( + total_blocks=(num_block * args.cluster_shape_mn[0]) * args.num_head * args.num_batch, + num_head_divmod=FastDivmod.create(args.num_head), + l2_minor_divmod=FastDivmod.create(swizzle), + l2_major_divmod=FastDivmod.create(swizzle * num_block), + l2_minor_residual_divmod=FastDivmod.create( + max(num_hb_remainder, 1) + ), # don't divide by 0 + num_hb_quotient=Int32(num_hb_quotient), + cluster_shape_mn=args.cluster_shape_mn, + ) + + def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None): + self.params = params + self._tile_idx = tile_idx + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + return SingleTileLPTBwdScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + @cute.jit + def create(params: Params, *, loc=None, ip=None) -> "SingleTileLPTBwdScheduler": + tile_idx = cute.arch.block_idx()[0] + return SingleTileLPTBwdScheduler(params, tile_idx, loc=loc, ip=ip) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + return (params.total_blocks, Int32(1), Int32(1)) + + @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + cluster_idx = self._tile_idx // self.params.cluster_shape_mn[0] + params = self.params + # Implement LPT scheduling coordinate calculation + bidhb, l2_mod = params.l2_major_divmod.divmod(cluster_idx) + # If we're in the last section (called residual), we don't want to divide by + # swizzle. Instead we want to divide by the remainder. + block, bidhb_residual = 0, 0 + if bidhb < params.num_hb_quotient: + block, bidhb_residual = params.l2_minor_divmod.divmod(l2_mod) + else: + block, bidhb_residual = params.l2_minor_residual_divmod.divmod(l2_mod) + bidhb_actual = bidhb * params.l2_minor_divmod.divisor + bidhb_residual + batch_idx, head_idx = params.num_head_divmod.divmod(bidhb_actual) + is_valid = self._tile_idx < params.total_blocks + bidx_in_cluster = cute.arch.block_in_cluster_idx() + block = block * params.cluster_shape_mn[0] + bidx_in_cluster[0] + return cutlass.utils.WorkTileInfo( + (Int32(block), Int32(head_idx), Int32(batch_idx)), is_valid + ) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + # Single tile scheduler - set to invalid tile_idx to indicate no more work + self._tile_idx = self.params.total_blocks + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.params, self._tile_idx]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip([self.params, self._tile_idx], self._values_pos): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return self.__class__(*(tuple(obj_list)), loc=self._loc) class SingleTileVarlenScheduler: From de1584b5328321189a4d7832fe29bbd6813bf6ed Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 29 Oct 2025 17:59:54 -0400 Subject: [PATCH 209/258] [Cute,Bwd,Sm100] Try separating warps loading Q and dO --- flash_attn/cute/flash_bwd_sm100.py | 102 ++++++++++++++++------------- 1 file changed, 57 insertions(+), 45 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 376fc043033..1044a39b453 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -992,6 +992,8 @@ def kernel( block_info, SeqlenInfoCls, TileSchedulerCls, + should_load_Q=True, + should_load_dO=True, ) # MMA @@ -1135,6 +1137,8 @@ def load( block_info: BlockInfo, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, + should_load_Q: bool = True, + should_load_dO: bool = True, ): producer_state_Q_LSE = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Producer, self.Q_stage @@ -1219,71 +1223,79 @@ def load( # copy_stats = partial(cute.copy, copy_atom_stats, mcast_mask=q_do_mcast_mask) # First iteration: load K together w Q & LSE, then V together w dO & dPsum - # K & Q - pipeline_Q.producer_acquire( - producer_state_Q_LSE, extra_tx_count=self.tma_copy_bytes["K"] - ) - load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_LSE)) - load_Q(m_block_min, producer_state=producer_state_Q_LSE) - pipeline_Q.producer_commit(producer_state_Q_LSE) - # LSE - pipeline_LSE.producer_acquire(producer_state_Q_LSE) - with cute.arch.elect_one(): - copy_stats( - gLSE[None, m_block_min], - sLSE[None, producer_state_Q_LSE.index], - mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE), - ) - producer_state_Q_LSE.advance() - # V & dO - pipeline_dO.producer_acquire( - producer_state_dO_dPsum, extra_tx_count=self.tma_copy_bytes["V"] - ) - load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_dPsum)) - load_dO(m_block_min, producer_state=producer_state_dO_dPsum) - pipeline_dO.producer_commit(producer_state_dO_dPsum) - # dPsum - pipeline_dPsum.producer_acquire(producer_state_dO_dPsum) - with cute.arch.elect_one(): - copy_stats( - gdPsum[None, m_block_min], - sdPsum[None, producer_state_dO_dPsum.index], - mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum), + if const_expr(should_load_Q): + # K & Q + pipeline_Q.producer_acquire( + producer_state_Q_LSE, extra_tx_count=self.tma_copy_bytes["K"] ) - producer_state_dO_dPsum.advance() - - for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1): - # Q - pipeline_Q.producer_acquire(producer_state_Q_LSE) - load_Q(m_block, producer_state=producer_state_Q_LSE) + load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_LSE)) + load_Q(m_block_min, producer_state=producer_state_Q_LSE) pipeline_Q.producer_commit(producer_state_Q_LSE) # LSE pipeline_LSE.producer_acquire(producer_state_Q_LSE) with cute.arch.elect_one(): copy_stats( - gLSE[None, m_block], + gLSE[None, m_block_min], sLSE[None, producer_state_Q_LSE.index], mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE), ) producer_state_Q_LSE.advance() - # dO - pipeline_dO.producer_acquire(producer_state_dO_dPsum) - load_dO(m_block, producer_state=producer_state_dO_dPsum) + if const_expr(should_load_dO): + # V & dO + pipeline_dO.producer_acquire( + producer_state_dO_dPsum, extra_tx_count=self.tma_copy_bytes["V"] + ) + load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_dPsum)) + load_dO(m_block_min, producer_state=producer_state_dO_dPsum) pipeline_dO.producer_commit(producer_state_dO_dPsum) # dPsum pipeline_dPsum.producer_acquire(producer_state_dO_dPsum) with cute.arch.elect_one(): copy_stats( - gdPsum[None, m_block], + gdPsum[None, m_block_min], sdPsum[None, producer_state_dO_dPsum.index], mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum), ) producer_state_dO_dPsum.advance() - pipeline_Q.producer_tail(producer_state_Q_LSE.clone()) # will hand if we don't clone - pipeline_dO.producer_tail(producer_state_dO_dPsum.clone()) - pipeline_LSE.producer_tail(producer_state_Q_LSE) - pipeline_dPsum.producer_tail(producer_state_dO_dPsum) + for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1): + if const_expr(should_load_Q): + # Q + pipeline_Q.producer_acquire(producer_state_Q_LSE) + load_Q(m_block, producer_state=producer_state_Q_LSE) + pipeline_Q.producer_commit(producer_state_Q_LSE) + # LSE + pipeline_LSE.producer_acquire(producer_state_Q_LSE) + with cute.arch.elect_one(): + copy_stats( + gLSE[None, m_block], + sLSE[None, producer_state_Q_LSE.index], + mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE), + ) + producer_state_Q_LSE.advance() + if const_expr(should_load_dO): + # dO + pipeline_dO.producer_acquire(producer_state_dO_dPsum) + load_dO(m_block, producer_state=producer_state_dO_dPsum) + pipeline_dO.producer_commit(producer_state_dO_dPsum) + # dPsum + pipeline_dPsum.producer_acquire(producer_state_dO_dPsum) + with cute.arch.elect_one(): + copy_stats( + gdPsum[None, m_block], + sdPsum[None, producer_state_dO_dPsum.index], + mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum), + ) + producer_state_dO_dPsum.advance() + + if const_expr(should_load_Q): + pipeline_Q.producer_tail( + producer_state_Q_LSE.clone() + ) # will hang if we don't clone + pipeline_LSE.producer_tail(producer_state_Q_LSE) + if const_expr(should_load_dO): + pipeline_dO.producer_tail(producer_state_dO_dPsum.clone()) + pipeline_dPsum.producer_tail(producer_state_dO_dPsum) tile_scheduler.prefetch_next_work() tile_scheduler.advance_to_next_work() From 0256114fe2381ab293503219bdd9078de3cd26b3 Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Fri, 31 Oct 2025 08:23:16 -0700 Subject: [PATCH 210/258] BlockSparse Tweaks (#1970) * Tweaks * better errors * Switch to new API --- flash_attn/cute/benchmark_mask_mod.py | 16 +- flash_attn/cute/block_sparsity.py | 99 ++++-- flash_attn/cute/interface.py | 29 +- flash_attn/cute/mask.py | 26 +- flash_attn/cute/mask_definitions.py | 325 ++++++++----------- flash_attn/cute/utils.py | 5 + tests/cute/test_mask_mod.py | 432 +++++++++++--------------- 7 files changed, 445 insertions(+), 487 deletions(-) diff --git a/flash_attn/cute/benchmark_mask_mod.py b/flash_attn/cute/benchmark_mask_mod.py index 9b7950ba076..88db8418abc 100644 --- a/flash_attn/cute/benchmark_mask_mod.py +++ b/flash_attn/cute/benchmark_mask_mod.py @@ -16,10 +16,8 @@ from flash_fwd import FlashAttentionForwardSm90 from mask_definitions import ( - MASK_FUNCTIONS, + get_mask_pair, random_doc_id_tensor, - create_cute_sliding_window_mask, - create_flex_sliding_window_mask, ) from flash_attn.cute.block_sparsity import ( compute_block_sparsity, @@ -99,12 +97,12 @@ def __init__(self, config: BenchmarkConfig): config.use_mask_mod = False if config.use_mask_mod: - if config.mask_mod_name == "sliding_window": - # Use factory function for custom window size - self.mask_mod_cute = create_cute_sliding_window_mask(config.window_size) - self.mask_mod_flex = create_flex_sliding_window_mask(config.window_size) - else: - self.mask_mod_cute, self.mask_mod_flex = MASK_FUNCTIONS[config.mask_mod_name] + self.mask_mod_cute, self.mask_mod_flex = get_mask_pair( + config.mask_mod_name, + seqlen_q=config.seqlen_q, + seqlen_k=config.seqlen_k, + window_size=config.window_size, + ) else: self.mask_mod_cute = None self.mask_mod_flex = None diff --git a/flash_attn/cute/block_sparsity.py b/flash_attn/cute/block_sparsity.py index c28df4c20d3..1a243e74127 100644 --- a/flash_attn/cute/block_sparsity.py +++ b/flash_attn/cute/block_sparsity.py @@ -24,6 +24,8 @@ class BlockSparseTensors(NamedTuple): full_block_idx: Optional[cute.Tensor] def __new_from_mlir_values__(self, values): + if len(values) == 2: + values = (*values, None, None) return BlockSparseTensors(*values) @@ -34,27 +36,82 @@ class BlockSparseTensorsTorch(NamedTuple): full_block_idx: Optional[torch.Tensor] = None -def validate_block_sparse_tensors(tensors: BlockSparseTensorsTorch) -> None: - for name, cnt, idx in ( - ("mask", tensors.mask_block_cnt, tensors.mask_block_idx), - ("full", tensors.full_block_cnt, tensors.full_block_idx), - ): - if (cnt is None) != (idx is None): - raise ValueError( - f"{name}_block_cnt and {name}_block_idx must both be provided or both be None" - ) - if cnt is None: - continue - if cnt.dtype != torch.int32 or idx.dtype != torch.int32: - raise ValueError(f"{name}_block tensors must have dtype torch.int32") - if cnt.device != idx.device: - raise ValueError(f"{name}_block_cnt and {name}_block_idx must be on the same device") - if not cnt.is_cuda or not idx.is_cuda: - raise ValueError(f"{name}_block tensors must live on CUDA") - - if tensors.full_block_cnt is not None and tensors.mask_block_cnt is not None: - if tensors.full_block_cnt.device != tensors.mask_block_cnt.device: - raise ValueError("All block sparse tensors must be on the same device") +def _expand_sparsity_tensor( + tensor: torch.Tensor, + expected_shape: Tuple[int, ...], + tensor_name: str, +) -> torch.Tensor: + """Check if we need to expand the tensor to expected shape, and do so if possible.""" + needs_expand = tensor.shape != expected_shape + if not needs_expand: + return tensor + can_expand = all(map(lambda cur, tgt: cur == tgt or cur == 1, tensor.shape, expected_shape)) + if not can_expand: + raise ValueError( + f"{tensor_name} with shape {tensor.shape} cannot be expanded to expected shape {expected_shape}." + ) + return tensor.expand(*expected_shape).contiguous() + + +def _check_and_expand_block( + name: str, + cnt: Optional[torch.Tensor], + idx: Optional[torch.Tensor], + expected_count_shape: Tuple[int, int, int], + expected_index_shape: Tuple[int, int, int, int], +) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + if (cnt is None) != (idx is None): + raise ValueError( + f"{name}_block_cnt and {name}_block_idx must both be provided or both be None" + ) + if cnt is None or idx is None: + return None, None + if cnt.dtype != torch.int32 or idx.dtype != torch.int32: + raise ValueError(f"{name}_block tensors must have dtype torch.int32") + if cnt.device != idx.device: + raise ValueError(f"{name}_block_cnt and {name}_block_idx must be on the same device") + if not cnt.is_cuda or not idx.is_cuda: + raise ValueError(f"{name}_block tensors must live on CUDA") + expanded_cnt = _expand_sparsity_tensor(cnt, expected_count_shape, f"{name}_block_cnt") + expanded_idx = _expand_sparsity_tensor(idx, expected_index_shape, f"{name}_block_idx") + return expanded_cnt, expanded_idx + + +def normalize_block_sparse_tensors( + tensors: BlockSparseTensorsTorch, + *, + expected_count_shape: Tuple[int, int, int], + expected_index_shape: Tuple[int, int, int, int], +) -> BlockSparseTensorsTorch: + if tensors.mask_block_cnt is None or tensors.mask_block_idx is None: + raise ValueError("mask_block_cnt and mask_block_idx must be provided for block sparsity.") + + mask_cnt, mask_idx = _check_and_expand_block( + "mask", + tensors.mask_block_cnt, + tensors.mask_block_idx, + expected_count_shape, + expected_index_shape, + ) + if mask_cnt is None or mask_idx is None: + raise ValueError("mask_block_cnt and mask_block_idx must be provided for block sparsity.") + + full_cnt, full_idx = _check_and_expand_block( + "full", + tensors.full_block_cnt, + tensors.full_block_idx, + expected_count_shape, + expected_index_shape, + ) + if full_cnt is not None and mask_cnt.device != full_cnt.device: + raise ValueError("All block sparse tensors must be on the same device") + + return BlockSparseTensorsTorch( + mask_block_cnt=mask_cnt, + mask_block_idx=mask_idx, + full_block_cnt=full_cnt, + full_block_idx=full_idx, + ) def is_block_sparsity_enabled(tensors: BlockSparseTensorsTorch) -> bool: diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 76d016fde73..c9685d461c5 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -42,8 +42,11 @@ from flash_attn.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess from flash_attn.cute.flash_fwd_combine import FlashAttentionForwardCombine -from flash_attn.cute.block_sparsity import BlockSparseTensorsTorch, to_cute_block_sparse_tensors - +from flash_attn.cute.block_sparsity import ( + BlockSparseTensorsTorch, + to_cute_block_sparse_tensors, + normalize_block_sparse_tensors, +) def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x @@ -132,6 +135,7 @@ def _flash_attn_fwd( assert cu_seqlens_k.shape == (batch_size + 1,), ( "cu_seqlens_k must have shape (batch_size + 1,)" ) + if cu_seqlens_q is not None: assert cu_seqlens_q.shape == (batch_size + 1,), ( "cu_seqlens_q must have shape (batch_size + 1,)" @@ -251,11 +255,18 @@ def _flash_attn_fwd( if page_table is not None else None ) - sparse_tensors = ( - to_cute_block_sparse_tensors(block_sparse_tensors) - if block_sparse_tensors is not None - else None - ) + sparse_tensors = None + if block_sparse_tensors is not None: + if seqlen_q is None: + raise ValueError("Block sparsity requires fixed-length sequences (seqlen_q must be known).") + expected_m_blocks = (seqlen_q + m_block_size - 1) // m_block_size + expected_n_blocks = (seqlen_k + n_block_size - 1) // n_block_size + block_sparse_tensors = normalize_block_sparse_tensors( + block_sparse_tensors, + expected_count_shape=(batch_size, num_head, expected_m_blocks), + expected_index_shape=(batch_size, num_head, expected_m_blocks, expected_n_blocks), + ) + sparse_tensors = to_cute_block_sparse_tensors(block_sparse_tensors) use_block_sparsity = sparse_tensors is not None @@ -337,7 +348,7 @@ def _flash_attn_fwd( cute_aux_tensors = None if aux_tensors is not None: - cute_aux_tensors = [from_dlpack(buf) for buf in aux_tensors] + cute_aux_tensors = [from_dlpack(buf).mark_layout_dynamic() for buf in aux_tensors] compile_key = ( dtype, @@ -348,7 +359,7 @@ def _flash_attn_fwd( score_mod_hash, mask_mod_hash, use_block_sparsity, - aux_tensors is not None, + len(aux_tensors) if aux_tensors is not None else 0, lse is None, cu_seqlens_q is None, cu_seqlens_k is None, diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index 2d65856d223..6f92d0835ac 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -135,17 +135,23 @@ def apply_mask( # Convert to absolute column index global_col_idx = thr_col_offset + col_idx_local + n_block * self.tile_n - cond = cutlass.Boolean( - mask_mod( - batch_idx, - head_idx, - tScS_mn[r, 0][0] + m_block * self.tile_m, - thr_col_offset + t0ScS_mn[0, col][1] + n_block * self.tile_n, - self.seqlen_q, - self.seqlen_k, - aux_tensors, - ) + batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32) + head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32) + q_idx_ssa = utils.scalar_to_ssa( + tScS_mn[r, 0][0] + m_block * self.tile_m, cutlass.Int32 + ) + kv_idx_ssa = utils.scalar_to_ssa( + thr_col_offset + t0ScS_mn[0, col][1] + n_block * self.tile_n, + cutlass.Int32, + ) + mask_value = mask_mod( + batch_idx_ssa, + head_idx_ssa, + q_idx_ssa, + kv_idx_ssa, + aux_tensors, ) + cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value)) if const_expr(mask_seqlen): out_of_bounds = (global_row_idx >= self.seqlen_q) or ( global_col_idx >= self.seqlen_k diff --git a/flash_attn/cute/mask_definitions.py b/flash_attn/cute/mask_definitions.py index 23c4f026b1c..0bb0d56751a 100644 --- a/flash_attn/cute/mask_definitions.py +++ b/flash_attn/cute/mask_definitions.py @@ -7,247 +7,150 @@ import cutlass.cute as cute import torch +from flash_attn.cute import utils + MaskModCallable = Optional[ Callable[ [ - "cutlass.Int32", - "cutlass.Int32", - "cutlass.Int32", - "cutlass.Int32", - "cutlass.Int32", - "cutlass.Int32", + "cute.TensorSSA", + "cute.TensorSSA", + "cute.TensorSSA", + "cute.TensorSSA", + "Optional[list]", ], - "cutlass.Boolean", + "cute.TensorSSA", ] ] # Flex Attention mask functions (PyTorch signatures for reference implementation) - - -def flex_identity_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None): - if torch.is_tensor(q_idx): - return torch.ones_like(q_idx, dtype=torch.bool) - return True - - -def flex_identity_partial_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None): - if torch.is_tensor(q_idx): - return torch.ones_like(q_idx, dtype=torch.bool) - return True - - -def flex_causal_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None): - # Right-aligned causal masking - if seqlen_q is not None and seqlen_k is not None: - offset = seqlen_k - seqlen_q +def get_flex_causal_mask(offset: int): + def _flex_causal_mask(b, h, q_idx, kv_idx): return kv_idx <= q_idx + offset - return kv_idx <= q_idx - -def flex_block_causal_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None): - # Right-aligned causal masking - if seqlen_q is not None and seqlen_k is not None: - offset = seqlen_k - seqlen_q - return kv_idx <= q_idx + offset - return kv_idx <= q_idx + return _flex_causal_mask -def create_flex_sliding_window_mask(window_size=1024): - """Factory function to create a sliding window mask with configurable window size""" +def get_flex_block_causal_mask(offset: int): + def _flex_block_causal_mask(b, h, q_idx, kv_idx): + return kv_idx <= q_idx + offset - def flex_sliding_window_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None): - # Sliding window: q_idx - window_size <= kv_idx <= q_idx - if seqlen_q is not None and seqlen_k is not None: - offset = seqlen_k - seqlen_q - return (kv_idx <= q_idx + offset) & (kv_idx >= q_idx + offset - window_size) - return (kv_idx <= q_idx) & (kv_idx >= q_idx - window_size) + return _flex_block_causal_mask - return flex_sliding_window_mask +def get_flex_sliding_window_mask(window_left: int, window_right: int, offset: int): + def _flex_sliding_window_mask(b, h, q_idx, kv_idx): + center = q_idx + offset + lower = center - window_left + upper = center + window_right + return (kv_idx >= lower) & (kv_idx <= upper) -# Default sliding window mask with window_size=1024 for backward compatibility -def flex_sliding_window_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None): - window_size = 1024 - if seqlen_q is not None and seqlen_k is not None: - offset = seqlen_k - seqlen_q - # Sliding window: q_pos - window_size < kv_pos <= q_pos - # Note: using strict inequality on the left to match typical sliding window behavior - return (kv_idx <= q_idx + offset) & (kv_idx > q_idx + offset - window_size) - return (kv_idx <= q_idx) & (kv_idx > q_idx - window_size) + return _flex_sliding_window_mask -def flex_block_diagonal_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None, block_size=64): +def flex_block_diagonal_mask(b, h, q_idx, kv_idx): + block_size = 64 return (q_idx // block_size) == (kv_idx // block_size) -def flex_mini_causal_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None): +def flex_mini_causal_mask(b, h, q_idx, kv_idx): return (q_idx % 128) >= (kv_idx % 128) -def flex_half_identity_mask(b, h, q_idx, kv_idx, seqlen_q=None, seqlen_k=None): - """Even k-blocks are full blocks, odd k-blocks are masked blocks (both return True)""" - if torch.is_tensor(kv_idx): - return torch.ones_like(kv_idx, dtype=torch.bool) - return True - - -def flex_document_mask(b, h, q_idx, kv_idx, doc_id: torch.Tensor): +def flex_document_mask(b, h, q_idx, kv_idx, doc_id): return doc_id[b, h, q_idx] == doc_id[b, h, kv_idx] # CuTe versions for kernel compilation +def get_cute_causal_mask(offset: int): + @cute.jit + def _cute_causal_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + aux_tensors: None, + ) -> cute.TensorSSA: + offset_ssa = utils.scalar_to_ssa(offset, cutlass.Int32) + return n_idx <= (m_idx + offset_ssa) + return _cute_causal_mask -@cute.jit -def cute_identity_mask( - batch: cutlass.Int32, - head: cutlass.Int32, - m_idx: cutlass.Int32, - n_idx: cutlass.Int32, - seqlen_q: cutlass.Int32, - seqlen_k: cutlass.Int32, - aux_tensors: None, -) -> cutlass.Boolean: - return cutlass.Boolean(True) - - -@cute.jit -def cute_identity_partial_mask( - batch: cutlass.Int32, - head: cutlass.Int32, - m_idx: cutlass.Int32, - n_idx: cutlass.Int32, - seqlen_q: cutlass.Int32, - seqlen_k: cutlass.Int32, - aux_tensors: None, -) -> cutlass.Boolean: - return cutlass.Boolean(True) - - -@cute.jit -def cute_causal_mask( - batch: cutlass.Int32, - head: cutlass.Int32, - m_idx: cutlass.Int32, - n_idx: cutlass.Int32, - seqlen_q: cutlass.Int32, - seqlen_k: cutlass.Int32, - aux_tensors: None, -) -> cutlass.Boolean: - # Right-aligned causal masking - offset = seqlen_k - seqlen_q - return cutlass.Boolean(n_idx <= m_idx + offset) +def get_cute_block_causal_mask(offset: int): + @cute.jit + def _cute_block_causal_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + aux_tensors: None, + ) -> cute.TensorSSA: + offset_ssa = utils.scalar_to_ssa(offset, cutlass.Int32) + return n_idx <= (m_idx + offset_ssa) -@cute.jit -def cute_block_causal_mask( - batch: cutlass.Int32, - head: cutlass.Int32, - m_idx: cutlass.Int32, - n_idx: cutlass.Int32, - seqlen_q: cutlass.Int32, - seqlen_k: cutlass.Int32, - aux_tensors: None, -) -> cutlass.Boolean: - # Right-aligned causal masking - offset = seqlen_k - seqlen_q - return cutlass.Boolean(n_idx <= m_idx + offset) - + return _cute_block_causal_mask -def create_cute_sliding_window_mask(window_size=1024): - """Factory function to create a CuTe sliding window mask with configurable window size""" +def get_cute_sliding_window_mask(window_left: int, window_right: int, offset: int): @cute.jit - def cute_sliding_window_mask( - batch: cutlass.Int32, - head: cutlass.Int32, - m_idx: cutlass.Int32, - n_idx: cutlass.Int32, - seqlen_q: cutlass.Int32, - seqlen_k: cutlass.Int32, + def _cute_sliding_window_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, aux_tensors, - ) -> cutlass.Boolean: - offset = seqlen_k - seqlen_q - - return cutlass.Boolean( - (n_idx <= m_idx + offset) and (n_idx >= m_idx + offset - window_size) - ) + ) -> cute.TensorSSA: + offset_ssa = utils.scalar_to_ssa(offset, cutlass.Int32) + window_left_ssa = utils.scalar_to_ssa(window_left, cutlass.Int32) + window_right_ssa = utils.scalar_to_ssa(window_right, cutlass.Int32) + center = m_idx + offset_ssa + lower = center - window_left_ssa + upper = center + window_right_ssa + return (n_idx >= lower) & (n_idx <= upper) - return cute_sliding_window_mask - - -# Default sliding window mask with window_size=1024 for backward compatibility -@cute.jit -def cute_sliding_window_mask( - batch: cutlass.Int32, - head: cutlass.Int32, - m_idx: cutlass.Int32, - n_idx: cutlass.Int32, - seqlen_q: cutlass.Int32, - seqlen_k: cutlass.Int32, - aux_tensors, -) -> cutlass.Boolean: - window_size = 1024 - # offset = seqlen_k - seqlen_q - offset = 0 - return cutlass.Boolean((n_idx <= m_idx + offset) and (n_idx >= m_idx + offset - window_size)) + return _cute_sliding_window_mask @cute.jit def cute_document_mask( - batch: cutlass.Int32, - head: cutlass.Int32, - m_idx: cutlass.Int32, - n_idx: cutlass.Int32, - seqlen_q: cutlass.Int32, - seqlen_k: cutlass.Int32, + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, aux_tensors: list, -): +) -> cute.TensorSSA: doc_id = aux_tensors[0] - return cutlass.Boolean(doc_id[batch, head, m_idx] == doc_id[batch, head, n_idx]) + m_doc = utils.scalar_to_ssa(doc_id[batch[0], head[0], m_idx[0]], cutlass.Int32) + n_doc = utils.scalar_to_ssa(doc_id[batch[0], head[0], n_idx[0]], cutlass.Int32) + return m_doc == n_doc @cute.jit def cute_block_diagonal_mask( - batch: cutlass.Int32, - head: cutlass.Int32, - m_idx: cutlass.Int32, - n_idx: cutlass.Int32, - seqlen_q: cutlass.Int32, - seqlen_k: cutlass.Int32, + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, aux_tensors, -) -> cutlass.Boolean: - return cutlass.Boolean((m_idx // 64) == (n_idx // 64)) +) -> cute.TensorSSA: + block_size_ssa = utils.scalar_to_ssa(64, cutlass.Int32) + return (m_idx // block_size_ssa) == (n_idx // block_size_ssa) @cute.jit def cute_mini_causal_mask( - batch: cutlass.Int32, - head: cutlass.Int32, - m_idx: cutlass.Int32, - n_idx: cutlass.Int32, - seqlen_q: cutlass.Int32, - seqlen_k: cutlass.Int32, + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, aux_tensors, -) -> cutlass.Boolean: - """Each tile is locally causal-masked""" - m_mod = m_idx % 128 - n_mod = n_idx % 128 - return cutlass.Boolean(m_mod >= n_mod) - - -@cute.jit -def cute_half_identity_mask( - batch: cutlass.Int32, - head: cutlass.Int32, - m_idx: cutlass.Int32, - n_idx: cutlass.Int32, - seqlen_q: cutlass.Int32, - seqlen_k: cutlass.Int32, -) -> cutlass.Boolean: - return cutlass.Boolean(True) +) -> cute.TensorSSA: + tile_size_ssa = utils.scalar_to_ssa(128, cutlass.Int32) + m_mod = m_idx % tile_size_ssa + n_mod = n_idx % tile_size_ssa + return m_mod >= n_mod def random_doc_id_tensor(nheads, batch, seqlen_q, device="cpu"): @@ -255,7 +158,9 @@ def random_doc_id_tensor(nheads, batch, seqlen_q, device="cpu"): for b in range(batch): for h in range(nheads): N = seqlen_q - n = random.randint(1, math.ceil(math.sqrt(N // 4))) + max_segments = max(1, math.ceil(math.sqrt(max(N // 4, 1)))) + n = random.randint(1, max_segments) + n = min(n, N) cuts = sorted(random.sample(range(1, N), n - 1)) lengths = [b - a for a, b in zip((0, *cuts), (*cuts, N))] @@ -264,22 +169,52 @@ def random_doc_id_tensor(nheads, batch, seqlen_q, device="cpu"): doc_ids += [i for _ in range(length)] doc_ids_tensor[b, h, :] = torch.tensor(doc_ids, dtype=torch.int32, device=device) - print(f"{doc_ids_tensor.shape = }") return doc_ids_tensor -MASK_FUNCTIONS = { - "identity": (cute_identity_mask, flex_identity_mask), - "identity_partial": (cute_identity_partial_mask, flex_identity_partial_mask), - "causal": (cute_causal_mask, flex_causal_mask), - "block_causal": (cute_block_causal_mask, flex_block_causal_mask), - "sliding_window": (cute_sliding_window_mask, flex_sliding_window_mask), +STATIC_MASKS = { "block_diagonal": (cute_block_diagonal_mask, flex_block_diagonal_mask), "mini_causal": (cute_mini_causal_mask, flex_mini_causal_mask), - "half_identity": (cute_half_identity_mask, flex_half_identity_mask), "document": (cute_document_mask, flex_document_mask), } +PARAMETERIZED_MASK_FACTORIES = { + "causal": (get_cute_causal_mask, get_flex_causal_mask), + "block_causal": (get_cute_block_causal_mask, get_flex_block_causal_mask), + "sliding_window": (get_cute_sliding_window_mask, get_flex_sliding_window_mask), +} + + +def get_mask_pair(mask_name, seqlen_q=None, seqlen_k=None, window_size=None): + """Get (cute_mask, flex_mask) pair for the given mask name. + + For static masks, seqlen info is not needed. + For parameterized masks, seqlen_q and seqlen_k are required. + """ + if mask_name in STATIC_MASKS: + return STATIC_MASKS[mask_name] + + if mask_name not in PARAMETERIZED_MASK_FACTORIES: + raise ValueError(f"Unknown mask: {mask_name}") + + if seqlen_q is None or seqlen_k is None: + raise ValueError(f"Parameterized mask '{mask_name}' requires seqlen_q and seqlen_k") + + cute_factory, flex_factory = PARAMETERIZED_MASK_FACTORIES[mask_name] + offset = seqlen_k - seqlen_q + + if mask_name == "sliding_window": + if window_size is None: + raise ValueError("sliding_window mask requires window_size parameter") + cute_mask = cute_factory(window_size, window_size, offset) + flex_mask = flex_factory(window_size, window_size, offset) + else: + cute_mask = cute_factory(offset) + flex_mask = flex_factory(offset) + + return cute_mask, flex_mask + + if __name__ == "__main__": doc_ids = random_doc_id_tensor(1, 2, 128) print(f"{doc_ids = }") diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 6bd5123f100..51a017e71a1 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -781,3 +781,8 @@ def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA: vec = cute.make_fragment(1, dtype) vec[0] = a return vec.load() + + +def ssa_to_scalar(val): + """ Could inline but nice for reflecting the above api """ + return val[0] \ No newline at end of file diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py index 033d08f296f..07e63e2bc7f 100644 --- a/tests/cute/test_mask_mod.py +++ b/tests/cute/test_mask_mod.py @@ -1,8 +1,19 @@ # mask mod test script # REFACTORED to use _flash_attn_fwd as the kernel entrypoint +# +# Test Organization: +# - test_static_masks: Fast tests for masks that don't need per-seqlen compilation +# (identity, document, block_diagonal, etc.) with comprehensive seqlen coverage +# - test_parameterized_masks: Slower tests for masks that require recompilation per +# seqlen pair (causal, block_causal, sliding_window) with reduced seqlen coverage +# +# Usage: +# pytest test_mask_mod.py::test_static_masks # Run only fast tests +# pytest test_mask_mod.py::test_parameterized_masks # Run only slow tests +# pytest test_mask_mod.py # Run all tests import math -from typing import Optional, Callable +from typing import Optional import pytest import torch @@ -10,12 +21,11 @@ import torch.nn.functional as F from flash_attn.cute.interface import _flash_attn_fwd -from flash_attn.cute.block_sparsity import compute_block_sparsity, BlockSparseTensorsTorch +from flash_attn.cute.block_sparsity import BlockSparseTensorsTorch from flash_attn.cute.mask_definitions import ( - MASK_FUNCTIONS, - flex_causal_mask, - create_flex_sliding_window_mask, - create_cute_sliding_window_mask, + get_mask_pair, + STATIC_MASKS, + random_doc_id_tensor, ) from flash_attn.cute.testing import attention_ref @@ -66,7 +76,7 @@ def compute_reference_flash_attn(tensors, causal, window_size, dtype_ref, upcast return out_ref -def compute_reference_flex_attn(tensors, mask_mod_flex, mask_mod_name, tile_m, tile_n): +def compute_reference_flex_attn(tensors, mask_mod_flex, block_size: Optional[tuple[int, int]] = None): """Compute reference using flex_attention for custom mask_mods""" batch_size, seqlen_q, nheads, headdim = tensors["q"].shape _, seqlen_k, nheads_kv, _ = tensors["k"].shape @@ -87,101 +97,61 @@ def compute_reference_flex_attn(tensors, mask_mod_flex, mask_mod_name, tile_m, t out_ref = F.scaled_dot_product_attention(q, k, v, scale=scale) return out_ref.transpose(1, 2).contiguous() - # Wrap mask_mod_flex to pass seqlen_q and seqlen_k - def mask_fn(b, h, q_idx, kv_idx): - return mask_mod_flex(b, h, q_idx, kv_idx, seqlen_q, seqlen_k) - - if mask_mod_name == "block_causal": - n_blocks_q = (seqlen_q + tile_m - 1) // tile_m - n_blocks_k = (seqlen_k + tile_n - 1) // tile_n - - mask = torch.zeros(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device) - - for q_block in range(n_blocks_q): - q_start = q_block * tile_m - q_end = min((q_block + 1) * tile_m, seqlen_q) - for k_block in range(n_blocks_k): - if k_block <= q_block: - k_start = k_block * tile_n - k_end = min((k_block + 1) * tile_n, seqlen_k) - mask[q_start:q_end, k_start:k_end] = True - - attn_mask = mask.unsqueeze(0).unsqueeze(0).expand(batch_size, nheads, -1, -1) - out_ref = F.scaled_dot_product_attention( - q, k, v, attn_mask=attn_mask, scale=scale - ) - else: - block_mask = create_block_mask( - mask_fn, - B=batch_size, - H=nheads, - Q_LEN=seqlen_q, - KV_LEN=seqlen_k, - ).to(q.device) - out_ref = flex_attention(q, k, v, block_mask=block_mask, scale=scale) - + block_mask_kwargs = {} + if block_size is not None: + block_mask_kwargs["BLOCK_SIZE"] = block_size + + block_mask = create_block_mask( + mask_mod_flex, + B=batch_size, + H=nheads, + Q_LEN=seqlen_q, + KV_LEN=seqlen_k, + device=q.device, + **block_mask_kwargs, + ) + out_ref = flex_attention(q, k, v, block_mask=block_mask, scale=scale) return out_ref.transpose(1, 2).contiguous() -@pytest.mark.parametrize( - "seqlen_q,seqlen_k", - [ - (1, 1), - (64, 128), - (128, 192), - (256, 256), - (239, 1), - (799, 3), - (113, 203), - (113, 128), - (128, 217), - (113, 211), - (108, 256), - (256, 512), - (384, 256), - (640, 128), - (512, 256), - (1024, 1024), - (1023, 1024), - (1024, 1023), - (4096, 4096), - (4224, 4224), - ], -) -# @pytest.mark.parametrize("nheads", [4, 16, 32]) -@pytest.mark.parametrize("nheads", [16]) -@pytest.mark.parametrize("kv_mode", ["mha", "gqa", "mqa"]) -# @pytest.mark.parametrize("headdim", [64, 128]) -@pytest.mark.parametrize("headdim", [128]) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.parametrize( - "use_mask_mod,is_local,mask_name,window_size,window_left,window_right", - [ - # (False, False, "identity", None, None, None), - # (False, False, "causal", None, None, None), - (True, False, "identity", None, None, None), - (True, False, "causal", None, None, None), - (True, False, "block_causal", None, None, None), - # Mask mod sliding window - (True, False, "sliding_window", 128, None, None), - (True, False, "sliding_window", 256, None, None), - (True, False, "sliding_window", 512, None, None), - # Base local attention - # (False, True, None, None, 128, 0), - # (False, True, None, None, 256, 0), - # (False, True, None, None, 512, 0), - ], -) -@pytest.mark.parametrize("tile_m,tile_n", [(128, 128), (128, 112)]) -def test_mask_mod_output( +SEQLEN_PAIRS_COMPREHENSIVE = [ + (1, 1), + (64, 128), + (128, 192), + (256, 256), + (239, 1), + (799, 3), + (113, 203), + (113, 128), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (384, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (4096, 4096), + (4224, 4224), +] + +SEQLEN_PAIRS_SMOKE = [ + (128, 128), + (256, 256), + (113, 203), + (1024, 1024), +] + + +def _run_mask_test( seqlen_q, seqlen_k, nheads, kv_mode, headdim, dtype, - use_mask_mod, - is_local, mask_name, window_size, window_left, @@ -191,14 +161,7 @@ def test_mask_mod_output( ): torch.manual_seed(42) - # Validate configuration - if is_local: - assert not use_mask_mod, "Cannot use both is_local and use_mask_mod" - assert window_left is not None or window_right is not None, ( - "Must specify window_left or window_right for is_local" - ) - - if use_mask_mod and mask_name == "sliding_window": + if mask_name == "sliding_window": assert window_size is not None, ( "window_size must be specified for sliding_window" ) @@ -207,12 +170,6 @@ def test_mask_mod_output( f"seqlen_q={seqlen_q} > seqlen_k={seqlen_k} not supported for sliding_window" ) - if is_local: - if seqlen_q > seqlen_k: - pytest.skip( - f"seqlen_q={seqlen_q} > seqlen_k={seqlen_k} not supported for is_local" - ) - # Determine nheads_kv based on mode if kv_mode == "mha": nheads_kv = nheads @@ -226,24 +183,22 @@ def test_mask_mod_output( batch_size = 1 headdim_v = headdim - # Determine mask_mod functions and causal flag - if use_mask_mod: - if mask_name == "sliding_window": - # Use factory function for custom window size - mask_mod_cute = create_cute_sliding_window_mask(window_size) - mask_mod_flex = create_flex_sliding_window_mask(window_size) - else: - mask_mod_cute, mask_mod_flex = MASK_FUNCTIONS[mask_name] - causal = False - elif is_local: - # Base local attention - no mask_mod - mask_mod_cute = None - mask_mod_flex = None - causal = False - else: - mask_mod_cute = None - mask_mod_flex = None - causal = (mask_name == "causal") if mask_name else False + aux_tensors_arg = None + mask_mod_cute, mask_mod_flex = get_mask_pair( + mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=window_size + ) + if mask_name == "document": + doc_len = max(seqlen_q, seqlen_k) + doc_ids = random_doc_id_tensor(nheads, batch_size, doc_len, device="cuda").to( + dtype=torch.int32, device="cuda" + ) + original_flex_mask = mask_mod_flex + + def mask_mod_flex(b, h, q_idx, kv_idx, doc_ids=doc_ids): + return original_flex_mask(b, h, q_idx, kv_idx, doc_ids) + + aux_tensors_arg = [doc_ids] + causal = False if causal and seqlen_k < seqlen_q: pytest.skip("causal masking requires seqlen_k >= seqlen_q") @@ -253,40 +208,16 @@ def test_mask_mod_output( ) # Compute block sparsity for mask_mod - full_cnt, full_idx, mask_cnt, mask_idx = None, None, None, None - if use_mask_mod: - from dataclasses import dataclass - - @dataclass - class Config: - seqlen_q: int - seqlen_k: int - nheads: int - nheads_kv: int - batch_size: int - tile_m: int - tile_n: int - use_mask_mod: bool - mask_mod_name: str - window_size: int = 1024 - verbose: bool = False - - config = Config( - seqlen_q=seqlen_q, - seqlen_k=seqlen_k, - nheads=nheads, - nheads_kv=nheads_kv, - batch_size=batch_size, - tile_m=tile_m, - tile_n=tile_n, - use_mask_mod=True, - mask_mod_name=mask_name, - window_size=window_size if window_size is not None else 1024, - ) - - full_cnt, full_idx, mask_cnt, mask_idx = compute_block_sparsity( - config=config, mask_mod_flex=mask_mod_flex, device="cuda" - ) + bm = create_block_mask( + mask_mod_flex, + batch_size, + nheads, + seqlen_q, + seqlen_k, + device="cuda", + BLOCK_SIZE=(tile_m, tile_n), + ) + _, _, mask_cnt, mask_idx, full_cnt, full_idx, *_ = bm.as_tuple() softmax_scale = 1.0 / math.sqrt(headdim) @@ -304,14 +235,12 @@ class Config: # print(f" First Q block - full indices: {full_idx[0,0,0,:full_cnt[0,0,0].item()]}") # if mask_cnt[0,0,0] > 0: # print(f" First Q block - mask indices: {mask_idx[0,0,0,:mask_cnt[0,0,0].item()]}") - block_sparse_mask = None - if use_mask_mod: - block_sparse_mask = BlockSparseTensorsTorch( - mask_block_cnt=mask_cnt, - mask_block_idx=mask_idx, - full_block_cnt=full_cnt, - full_block_idx=full_idx, - ) + block_sparse_mask = BlockSparseTensorsTorch( + mask_block_cnt=mask_cnt, + mask_block_idx=mask_idx, + full_block_cnt=full_cnt, + full_block_idx=full_idx, + ) out_tuple = _flash_attn_fwd( q=tensors["q"], @@ -339,74 +268,19 @@ class Config: mask_mod=mask_mod_cute, block_sparse_tensors=block_sparse_mask, return_lse=True, - aux_tensors=None, + aux_tensors=aux_tensors_arg, ) out_cute = out_tuple[0] + tensors_fp32 = { + k: v.float() if v.dtype in [torch.float16, torch.bfloat16] else v + for k, v in tensors.items() + } - # Determine which reference implementation to use - dtype_ref = torch.bfloat16 - use_flash_attn_ref = False - - # Use FlashAttention reference for causal and local window cases - if mask_name == "causal" and not use_mask_mod: - use_flash_attn_ref = True - window_size_ref = (None, None) # attention_ref handles causal internally - elif mask_name == "identity" and not use_mask_mod and not is_local: - use_flash_attn_ref = True - window_size_ref = (None, None) # No window for identity - elif is_local: - use_flash_attn_ref = True - window_size_ref = (window_left, window_right) - if window_right == 0: - causal = True # Override causal flag for reference computation - elif use_mask_mod and mask_name == "sliding_window": - use_flash_attn_ref = True - # For sliding window mask_mod, window_size corresponds directly to window_left - # in attention_ref (number of previous tokens that can be attended to) - # Sliding window with window_right=0 is inherently causal - window_size_ref = (window_size, 0) - causal = True # Override causal flag for reference computation - - if use_flash_attn_ref: - # Compute reference using FlashAttention's attention_ref - out_ref_fp32 = compute_reference_flash_attn( - tensors, - causal=causal, - window_size=window_size_ref, - dtype_ref=torch.float32, - upcast=True, - ) - out_ref = compute_reference_flash_attn( - tensors, - causal=causal, - window_size=window_size_ref, - dtype_ref=dtype_ref, - upcast=False, - ) - - # Also compute PyTorch reference for comparison (with reorder_ops for better accuracy) - out_pt = compute_reference_flash_attn( - tensors, - causal=causal, - window_size=window_size_ref, - dtype_ref=dtype, - upcast=False, - ) - else: - # Use flex_attention for custom mask_mods - tensors_fp32 = { - k: v.float() if v.dtype in [torch.float16, torch.bfloat16] else v - for k, v in tensors.items() - } - - out_ref_fp32 = compute_reference_flex_attn( - tensors_fp32, mask_mod_flex, mask_name, tile_m, tile_n - ) - out_ref = compute_reference_flex_attn( - tensors, mask_mod_flex, mask_name, tile_m, tile_n - ) - out_pt = out_ref.clone() + block_size = (tile_m, tile_n) + out_ref_fp32 = compute_reference_flex_attn(tensors_fp32, mask_mod_flex, block_size) + out_ref = compute_reference_flex_attn(tensors, mask_mod_flex, block_size) + out_pt = out_ref.clone() # Check for invalid values assert out_cute.shape == out_ref_fp32.shape == out_ref.shape @@ -423,23 +297,15 @@ class Config: pt_error = (out_pt - out_ref_fp32).abs().max().item() cute_error = (out_cute - out_ref_fp32).abs().max().item() - # Build description string - if is_local: - mask_desc = f"is_local(L={window_left},R={window_right})" - elif use_mask_mod: - mask_desc = f"mask_mod={mask_name}" - if mask_name == "sliding_window" and window_size is not None: - mask_desc += f"(w={window_size})" - else: - mask_desc = mask_name if mask_name else "identity" + mask_desc = f"mask_mod={mask_name}" + if mask_name == "sliding_window" and window_size is not None: + mask_desc += f"(w={window_size})" print( f"\n{mask_desc} @ Q={seqlen_q}, K={seqlen_k}, H={nheads}/{nheads_kv} ({kv_mode}), " f"D={headdim}, M={tile_m}, N={tile_n}" ) - print( - f" Reference implementation: {'FlashAttention' if use_flash_attn_ref else 'FlexAttention'}" - ) + print(" Reference implementation: FlexAttention") print(f" Reference vs FP32: {ref_error:.2e}") print(f" PyTorch vs FP32: {pt_error:.2e}") print(f" Kernel vs FP32: {cute_error:.2e}") @@ -463,5 +329,85 @@ class Config: ) +@pytest.mark.parametrize("seqlen_q,seqlen_k", SEQLEN_PAIRS_COMPREHENSIVE) +@pytest.mark.parametrize("nheads", [16]) +@pytest.mark.parametrize("kv_mode", ["mha", "gqa", "mqa"]) +@pytest.mark.parametrize("headdim", [128]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize( + "mask_name", + ["block_diagonal", "mini_causal"], +) +@pytest.mark.parametrize("tile_m,tile_n", [(128, 128), (128, 112)]) +def test_static_masks( + seqlen_q, seqlen_k, nheads, kv_mode, headdim, dtype, mask_name, tile_m, tile_n +): + """Test static masks that don't require recompilation per seqlen pair. + + Known good masks: + - block_diagonal: Masks by 64-element diagonal blocks + - mini_causal: Local causal within 128-element tiles + """ + _run_mask_test( + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + nheads=nheads, + kv_mode=kv_mode, + headdim=headdim, + dtype=dtype, + mask_name=mask_name, + window_size=None, + window_left=None, + window_right=None, + tile_m=tile_m, + tile_n=tile_n, + ) + + +@pytest.mark.parametrize("seqlen_q,seqlen_k", SEQLEN_PAIRS_SMOKE) +@pytest.mark.parametrize("nheads", [16]) +@pytest.mark.parametrize("kv_mode", ["mha"]) +@pytest.mark.parametrize("headdim", [128]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize( + "mask_name,window_size", + [ + ("causal", None), + ("block_causal", None), + ("sliding_window", 128), + ("sliding_window", 256), + ("sliding_window", 512), + ("document", None), + ], +) +@pytest.mark.parametrize("tile_m,tile_n", [(128, 128), (128, 112), (64, 128)]) +def test_parameterized_masks( + seqlen_q, seqlen_k, nheads, kv_mode, headdim, dtype, mask_name, window_size, tile_m, tile_n +): + """Test parameterized masks that require recompilation per seqlen pair. + + Uses fewer seqlen combinations to reduce test time. + + Masks tested: + - causal, block_causal: Require offset = seqlen_k - seqlen_q + - sliding_window: Requires window size and offset parameters + - document: Slower to check + """ + _run_mask_test( + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + nheads=nheads, + kv_mode=kv_mode, + headdim=headdim, + dtype=dtype, + mask_name=mask_name, + window_size=window_size, + window_left=None, + window_right=None, + tile_m=tile_m, + tile_n=tile_n, + ) + + if __name__ == "__main__": pytest.main([__file__, "-v", "-s"]) From 6c9eef9e2f93246bcb7d03e07c642a1c103e53d2 Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Mon, 3 Nov 2025 10:44:21 -0800 Subject: [PATCH 211/258] [Cute] Fix main (#1982) --- flash_attn/cute/interface.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index c9685d461c5..71e4339619e 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -955,6 +955,15 @@ def forward( mask_block_cnt: Optional[torch.Tensor] = None, mask_block_idx: Optional[torch.Tensor] = None, ): + # Only create block sparse tensors if at least one block sparse parameter is provided + block_sparse_tensors = None + if any(t is not None for t in [full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx]): + block_sparse_tensors = BlockSparseTensorsTorch( + full_block_cnt=full_block_cnt, + full_block_idx=full_block_idx, + mask_block_cnt=mask_block_cnt, + mask_block_idx=mask_block_idx, + ) out, lse = _flash_attn_fwd( q, k, @@ -967,12 +976,7 @@ def forward( softcap=softcap, pack_gqa=pack_gqa, mask_mod=mask_mod, - block_sparse_tensors=BlockSparseTensorsTorch( - full_block_cnt=full_block_cnt, - full_block_idx=full_block_idx, - mask_block_cnt=mask_block_cnt, - mask_block_idx=mask_block_idx, - ) + block_sparse_tensors=block_sparse_tensors ) ctx.save_for_backward(q, k, v, out, lse) ctx.softmax_scale = softmax_scale From e724e2588cbe754beb97cf7c011b5e7e34119e62 Mon Sep 17 00:00:00 2001 From: timmy-feng <70349932+timmy-feng@users.noreply.github.com> Date: Wed, 5 Nov 2025 02:13:26 +0100 Subject: [PATCH 212/258] [Cute,Fwd,Sm100] Implement SplitKV (#1940) * Implement split KV * Remove modal bench harness * Fixes --- flash_attn/cute/block_info.py | 17 +- flash_attn/cute/flash_bwd.py | 5 +- flash_attn/cute/flash_bwd_postprocess.py | 5 +- flash_attn/cute/flash_bwd_preprocess.py | 5 +- flash_attn/cute/flash_bwd_sm100.py | 12 +- flash_attn/cute/flash_bwd_sm90.py | 10 +- flash_attn/cute/flash_fwd.py | 11 +- flash_attn/cute/flash_fwd_combine.py | 4 +- flash_attn/cute/flash_fwd_sm100.py | 922 ++++++++++++----------- flash_attn/cute/interface.py | 96 ++- flash_attn/cute/seqlen_info.py | 53 +- flash_attn/cute/tile_scheduler.py | 110 ++- tests/cute/test_flash_attn.py | 28 +- 13 files changed, 755 insertions(+), 523 deletions(-) diff --git a/flash_attn/cute/block_info.py b/flash_attn/cute/block_info.py index 6382700bf16..eeaa0e3e740 100644 --- a/flash_attn/cute/block_info.py +++ b/flash_attn/cute/block_info.py @@ -15,12 +15,19 @@ class BlockInfo: tile_n: cutlass.Constexpr[int] is_causal: cutlass.Constexpr[bool] is_local: cutlass.Constexpr[bool] = False + is_split_kv: cutlass.Constexpr[bool] = False window_size_left: Optional[Int32] = None window_size_right: Optional[Int32] = None qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 @cute.jit - def get_n_block_min_max(self, seqlen_info: SeqlenInfoQK, m_block: Int32) -> Tuple[Int32, Int32]: + def get_n_block_min_max( + self, + seqlen_info: SeqlenInfoQK, + m_block: Int32, + split_idx: cutlass.Int32 = 0, + num_splits: cutlass.Int32 = 1, + ) -> Tuple[Int32, Int32]: n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n) if const_expr(self.is_causal or (self.is_local and self.window_size_right is not None)): m_idx_max = (m_block + 1) * self.tile_m @@ -37,6 +44,14 @@ def get_n_block_min_max(self, seqlen_info: SeqlenInfoQK, m_block: Int32) -> Tupl n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q n_idx_left = n_idx - self.window_size_left n_block_min = cutlass.max(n_idx_left // self.tile_n, 0) + if cutlass.const_expr(self.is_split_kv): + num_n_blocks_per_split = ( + cutlass.Int32(0) + if n_block_max <= n_block_min + else (n_block_max - n_block_min + num_splits - 1) // num_splits + ) + n_block_min = n_block_min + split_idx * num_n_blocks_per_split + n_block_max = cutlass.min(n_block_min + num_n_blocks_per_split, n_block_max) return n_block_min, n_block_max @cute.jit diff --git a/flash_attn/cute/flash_bwd.py b/flash_attn/cute/flash_bwd.py index 12f900b3970..ce0a1b6e5e9 100644 --- a/flash_attn/cute/flash_bwd.py +++ b/flash_attn/cute/flash_bwd.py @@ -405,6 +405,7 @@ def __call__( num_block=cute.ceil_div(mK.shape[1], self.n_block_size), num_head=num_head, num_batch=num_batch, + num_splits=1, seqlen_k=0, headdim=mK.shape[2], headdim_v=mV.shape[2], @@ -505,10 +506,10 @@ def kernel( tile_scheduler = TileScheduler.create(tile_sched_params) work_tile = tile_scheduler.initial_work_tile_info() - n_block, head_idx, batch_idx = work_tile.tile_idx + n_block, head_idx, batch_idx, _ = work_tile.tile_idx if work_tile.is_valid_tile: - seqlen = SeqlenInfoQK(batch_idx, mQ.shape[1], mK.shape[1], mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK) + seqlen = SeqlenInfoQK.create(batch_idx, mQ.shape[1], mK.shape[1], mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK) m_block_max = cute.ceil_div(seqlen.seqlen_q, self.m_block_size) m_block_min = 0 diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index 45a0d102eba..14d746ba346 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -242,6 +242,7 @@ def __call__( num_block=cute.ceil_div(mdQ.shape[1], self.tile_m), num_head=num_head, num_batch=num_batch, + num_splits=1, seqlen_k=0, headdim=mdQ.shape[2], headdim_v=0, @@ -317,14 +318,14 @@ def kernel( tile_scheduler = TileScheduler.create(tile_sched_params) work_tile = tile_scheduler.initial_work_tile_info() - m_block, num_head, batch_size = work_tile.tile_idx + m_block, num_head, batch_size, _ = work_tile.tile_idx if work_tile.is_valid_tile: # /////////////////////////////////////////////////////////////////////////////// # Get the appropriate tiles for this thread block. # /////////////////////////////////////////////////////////////////////////////// - seqlen = SeqlenInfoQK( + seqlen = SeqlenInfoQK.create( batch_size, mdQ.shape[1], 0, diff --git a/flash_attn/cute/flash_bwd_preprocess.py b/flash_attn/cute/flash_bwd_preprocess.py index dd5455b98c4..985391a7898 100644 --- a/flash_attn/cute/flash_bwd_preprocess.py +++ b/flash_attn/cute/flash_bwd_preprocess.py @@ -160,6 +160,7 @@ def __call__( num_block=cute.ceil_div(mO.shape[1], self.m_block_size), num_head=num_head, num_batch=num_batch, + num_splits=1, seqlen_k=0, headdim=0, headdim_v=mO.shape[2], @@ -212,13 +213,13 @@ def kernel( tile_scheduler = TileScheduler.create(tile_sched_params) work_tile = tile_scheduler.initial_work_tile_info() - m_block, num_head, batch_size = work_tile.tile_idx + m_block, num_head, batch_size, _ = work_tile.tile_idx if work_tile.is_valid_tile: # /////////////////////////////////////////////////////////////////////////////// # Get the appropriate tiles for this thread block. # /////////////////////////////////////////////////////////////////////////////// - seqlen = SeqlenInfoQK( + seqlen = SeqlenInfoQK.create( batch_size, mO.shape[1], 0, diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 1044a39b453..5b85c691cd0 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -541,6 +541,7 @@ def __call__( cute.ceil_div(cute.size(mK.shape[0]), self.cta_tiler[0]), cute.size(mQ.shape[2]), # num_heads = num_query_heads cute.size(mK.shape[3]), + 1, # num_splits cute.size(mK.shape[0]), mQ.shape[1], mV.shape[1], @@ -927,12 +928,13 @@ def kernel( self.tile_n * self.cluster_shape_mnk[0], # careful, this case is not very well-tested self.is_causal, self.is_local, + False, # is_split_kv None, None, qhead_per_kvhead_packgqa=1, ) SeqlenInfoCls = partial( - SeqlenInfoQK, + SeqlenInfoQK.create, seqlen_q_static=mQ.shape[0], seqlen_k_static=mK.shape[0], mCuSeqlensQ=None, @@ -1159,7 +1161,7 @@ def load( tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: - n_block, head_idx, batch_idx = work_tile.tile_idx + n_block, head_idx, batch_idx, _ = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) m_block_min, m_block_max = block_info.get_m_block_min_max( seqlen, n_block // self.cluster_shape_mnk[0] @@ -1415,7 +1417,7 @@ def mma( tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: - n_block, head_idx, batch_idx = work_tile.tile_idx + n_block, head_idx, batch_idx, _ = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) # must be seqlen_k m_block_min, m_block_max = block_info.get_m_block_min_max( seqlen, n_block // self.cluster_shape_mnk[0] @@ -1723,7 +1725,7 @@ def compute_loop( tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: - n_block, head_idx, batch_idx = work_tile.tile_idx + n_block, head_idx, batch_idx, _ = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) m_block_min, m_block_max = block_info.get_m_block_min_max( seqlen, n_block // self.cluster_shape_mnk[0] @@ -1981,7 +1983,7 @@ def dQacc_reduce( pipeline.PipelineUserType.Producer, self.sdQaccum_stage ) while work_tile.is_valid_tile: - n_block, head_idx, batch_idx = work_tile.tile_idx + n_block, head_idx, batch_idx, _ = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) m_block_min, m_block_max = block_info.get_m_block_min_max( seqlen, n_block // self.cluster_shape_mnk[0] diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index 59d4c2c4680..641adef4846 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -397,6 +397,7 @@ def __call__( cute.ceil_div(cute.size(mK.shape[0]), self.tile_n), cute.size(mK.shape[2]), cute.size(mK.shape[3]), + 1, # num_splits cute.size(mK.shape[0]), mQ.shape[1], mV.shape[1], @@ -551,12 +552,13 @@ def kernel( self.tile_n, self.is_causal, self.is_local, + False, # is_split_kv None, None, qhead_per_kvhead_packgqa=1, ) SeqlenInfoCls = partial( - SeqlenInfoQK, + SeqlenInfoQK.create, seqlen_q_static=mQ.shape[0], seqlen_k_static=mK.shape[0], mCuSeqlensQ=None, @@ -678,7 +680,7 @@ def load( tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: - n_block, head_idx, batch_idx = work_tile.tile_idx + n_block, head_idx, batch_idx, _ = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) mK_cur = mK[None, None, head_idx, batch_idx] gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (n_block, 0)) @@ -932,7 +934,7 @@ def mma( tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: - n_block, head_idx, batch_idx = work_tile.tile_idx + n_block, head_idx, batch_idx, _ = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) mask_fn = partial( @@ -1208,7 +1210,7 @@ def dQaccum_store( tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: - n_block, head_idx, batch_idx = work_tile.tile_idx + n_block, head_idx, batch_idx, _ = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) mdQaccum_cur = mdQaccum[None, head_idx, batch_idx] gdQaccum_ = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (None,)) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 16d57991f97..e7f93056fca 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -759,11 +759,12 @@ def kernel( self.tile_n, self.is_causal, self.is_local, + False, # is_split_kv window_size_left, window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) - seqlen = SeqlenInfoQK(seqlen_q_static=mQ.shape[0], seqlen_k_static=mK.shape[0]) + seqlen = SeqlenInfoQK.create(seqlen_q_static=mQ.shape[0], seqlen_k_static=mK.shape[0]) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) # TODO: return early if n_block_max == 0 # if self.is_causal: @@ -1459,6 +1460,7 @@ def __call__( cute.size(mQ.shape[3]) if const_expr(mCuSeqlensQ is None) else cute.size(mCuSeqlensQ.shape[0] - 1), + 1, # num_splits cute.size(mK.shape[0]), mQ.shape[1], mV.shape[1], @@ -1652,12 +1654,13 @@ def kernel( self.tile_n, self.is_causal, self.is_local, + False, # is_split_kv window_size_left, window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) SeqlenInfoCls = partial( - SeqlenInfoQK, + SeqlenInfoQK.create, seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1], seqlen_k_static=mK.shape[0], mCuSeqlensQ=mCuSeqlensQ, @@ -1764,7 +1767,7 @@ def load( work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: # if work_tile.is_valid_tile: - m_block, head_idx, batch_idx = work_tile.tile_idx + m_block, head_idx, batch_idx, _ = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] head_idx_kv = ( @@ -2106,7 +2109,7 @@ def mma( # if work_tile.is_valid_tile: # shape: (atom_v_m * rest_m) - m_block, head_idx, batch_idx = work_tile.tile_idx + m_block, head_idx, batch_idx, _ = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) mask_fn = partial( diff --git a/flash_attn/cute/flash_fwd_combine.py b/flash_attn/cute/flash_fwd_combine.py index 4c423b80968..b23ab8ba78e 100644 --- a/flash_attn/cute/flash_fwd_combine.py +++ b/flash_attn/cute/flash_fwd_combine.py @@ -255,7 +255,7 @@ class SharedStorage: # Grid dimensions: (ceil_div(seqlen, m_block), ceil_div(head_dim, k_block), num_head * batch) seqlen = mO_partial.shape[0] num_head = mO_partial.shape[3] - batch_size = mO_partial.shape[4] + batch_size = mO_partial.shape[4] if const_expr(cu_seqlens is None) else Int32(cu_seqlens.shape[0] - 1) # Create FastDivmod objects for efficient division seqlen_divmod = FastDivmod.create(seqlen) @@ -341,7 +341,7 @@ def kernel( else mLSE_partial.shape[1] ) # Handle variable length sequences using SeqlenInfo - seqlen_info = SeqlenInfo( + seqlen_info = SeqlenInfo.create( batch_idx=batch_idx, seqlen_static=mO_partial.shape[0], cu_seqlens=cu_seqlens, diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 1ec7dce3a1a..6e030b17615 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -5,8 +5,9 @@ # - hdim 64, 96, 128, (192, 128). # - varlen # - sliding window +# - split-kv # Unsupported features that will be added later: -# - split-kv (optimizing for inference) +# - page size != 128 # - more hdim (192, 256) # Based on the cutlass example and cute-dsl example: # https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha @@ -68,6 +69,7 @@ def __init__( qhead_per_kvhead: cutlass.Constexpr[int] = 1, is_causal: bool = False, is_local: bool = False, + is_split_kv: bool = False, pack_gqa: bool = False, m_block_size: int = 128, n_block_size: int = 128, @@ -101,11 +103,15 @@ def __init__( self.is_causal = is_causal self.is_local = is_local self.qhead_per_kvhead = qhead_per_kvhead + self.is_split_kv = is_split_kv self.pack_gqa = pack_gqa if pack_gqa: assert m_block_size % self.qhead_per_kvhead == 0, ( "For PackGQA, m_block_size must be divisible by qhead_per_kvhead" ) + assert not (self.is_split_kv and self.head_dim_v_padded >= 192), ( + "SplitKV is not supported for hdim >= 192" + ) self.score_mod = score_mod if cutlass.const_expr(has_aux_tensors): self.vec_size: cutlass.Constexpr = 1 @@ -114,9 +120,11 @@ def __init__( # Does S1 need to wait for S0 to finish # self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local) self.s0_s1_barrier = False - self.overlap_sO_sQ = self.head_dim_padded == 192 and self.head_dim_v_padded >= 64 + self.overlap_sO_sQ = ( + (self.head_dim_padded == 192 and self.head_dim_v_padded >= 64) or + (self.head_dim_v_padded >= 128 and self.is_split_kv) + ) if self.overlap_sO_sQ: - assert self.head_dim_padded >= self.head_dim_v_padded # We assume sQ is larger than sO self.is_persistent = False self.softmax0_warp_ids = (0, 1, 2, 3) @@ -255,18 +263,23 @@ def __call__( cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) for t in (mQ, mK, mV, mO) ] - QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] - mQ, mO = [ - cute.make_tensor(t.iterator, cute.select(t.layout, mode=QO_layout_transpose)) - for t in (mQ, mO) - ] + Q_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] + mQ = cute.make_tensor(mQ.iterator, cute.select(mQ.layout, mode=Q_layout_transpose)) # (s_k, d, h_k, b_k) or (total_k, d, h_k) if there's cu_seqlens_k or (page_size, d, h_k, num_pages) if there's page_table KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1] mK, mV = [ cute.make_tensor(t.iterator, cute.select(t.layout, mode=KV_layout_transpose)) for t in (mK, mV) ] - LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] + if const_expr(self.is_split_kv): + O_layout_transpose = [2, 4, 3, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 3, 2, 0] + LSE_layout_transpose = [3, 2, 1, 0] if const_expr(mCuSeqlensQ is None) else [2, 1, 0] + num_splits = mO.shape[0] + else: + O_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] + LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] + num_splits = Int32(1) + mO = cute.make_tensor(mO.iterator, cute.select(mO.layout, mode=O_layout_transpose)) mLSE = ( cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) if const_expr(mLSE is not None) @@ -408,7 +421,7 @@ def __call__( ) shape_O_packed = ( (self.qhead_per_kvhead, mO.shape[0]), - mK.shape[1], + mO.shape[1], mK.shape[2], *mO.shape[3:], ) @@ -528,6 +541,7 @@ def __call__( cute.size(mQ.shape[3]) if const_expr(mCuSeqlensQ is None) else cute.size(mCuSeqlensQ.shape[0] - 1), + num_splits, cute.size(mK.shape[0]) if const_expr(mPageTable is None) else mK.shape[0] * mPageTable.shape[1], @@ -543,6 +557,7 @@ def __call__( element_size=self.k_dtype.width // 8, is_persistent=self.is_persistent, lpt=self.is_causal or self.is_local, + is_split_kv=self.is_split_kv, ) tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) self.tile_scheduler_cls = TileScheduler @@ -565,6 +580,10 @@ def __call__( self.mbar_total = self.mbar_P_full_2_offset + 2 sO_size = cute.cosize(sO_layout) if const_expr(not self.overlap_sO_sQ) else 0 + sQ_size = ( + cute.cosize(sQ_layout) if const_expr(not self.overlap_sO_sQ) else + cutlass.max(cute.cosize(sQ_layout), cute.cosize(sO_layout) * self.o_dtype.width // self.q_dtype.width) + ) @cute.struct class SharedStorage: @@ -580,7 +599,7 @@ class SharedStorage: self.buffer_align_bytes, ] sQ: cute.struct.Align[ - cute.struct.MemRange[self.q_dtype, cute.cosize(sQ_layout)], + cute.struct.MemRange[self.q_dtype, sQ_size], self.buffer_align_bytes, ] sK: cute.struct.Align[ @@ -647,6 +666,7 @@ class SharedStorage: tiled_mma_qk, tiled_mma_pv, tile_sched_params, + num_splits, aux_tensors, fastdiv_mods, ).launch( @@ -690,6 +710,7 @@ def kernel( tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, tile_sched_params: ParamsBase, + num_splits: Int32, aux_tensors: Optional[list] = None, fastdiv_mods=(None, None), ): @@ -801,7 +822,7 @@ def kernel( if const_expr(not self.overlap_sO_sQ): sO = storage.sO.get_tensor(sO_layout.outer, swizzle=sO_layout.inner) else: - sO = cute.make_tensor(cute.recast_ptr(sQ.iterator, sO_layout.inner), sO_layout.outer) + sO = cute.make_tensor(cute.recast_ptr(sQ.iterator, sO_layout.inner, self.o_dtype), sO_layout.outer) sScale = storage.sScale.get_tensor(cute.make_layout(self.q_stage * self.m_block_size * 2)) @@ -845,12 +866,13 @@ def kernel( self.cta_tiler[1], self.is_causal, self.is_local, + self.is_split_kv, window_size_left, window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) SeqlenInfoCls = partial( - SeqlenInfoQK, + SeqlenInfoQK.create, seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1], seqlen_k_static=mK.shape[0] if const_expr(mPageTable is None) @@ -898,6 +920,7 @@ def kernel( pipeline_kv, mbar_ptr, block_info, + num_splits, SeqlenInfoCls, TileSchedulerCls, ) @@ -926,6 +949,7 @@ def kernel( pipeline_kv, mbar_ptr, block_info, + num_splits, SeqlenInfoCls, TileSchedulerCls, ) @@ -949,7 +973,15 @@ def kernel( if warp_idx >= self.epilogue_warp_ids[0] and warp_idx <= self.epilogue_warp_ids[-1]: cute.arch.warpgroup_reg_dealloc(self.num_regs_other) self.epilogue_s2g( - mO, sO, gmem_tiled_copy_O, tma_atom_O, mbar_ptr, SeqlenInfoCls, TileSchedulerCls + mO, + sO, + gmem_tiled_copy_O, + tma_atom_O, + mbar_ptr, + block_info, + num_splits, + SeqlenInfoCls, + TileSchedulerCls, ) # /////////////////////////////////////////////////////////////////////////////// @@ -968,6 +1000,7 @@ def kernel( learnable_sink=learnable_sink, mbar_ptr=mbar_ptr, block_info=block_info, + num_splits=num_splits, SeqlenInfoCls=SeqlenInfoCls, AttentionMaskCls=AttentionMaskCls, TileSchedulerCls=TileSchedulerCls, @@ -1016,6 +1049,7 @@ def kernel( mbar_ptr, softmax_scale_log2, block_info, + num_splits, SeqlenInfoCls, TileSchedulerCls, ) @@ -1041,6 +1075,7 @@ def load( pipeline_kv: cutlass.pipeline.PipelineAsync, mbar_ptr: cute.Pointer, block_info: BlockInfo, + num_splits: Int32, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, ): @@ -1051,7 +1086,7 @@ def load( tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: - m_block, head_idx, batch_idx = work_tile.tile_idx + m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] gQ = cute.local_tile(mQ_cur, cute.select(self.mma_tiler_qk, mode=[0, 2]), (None, 0)) @@ -1125,30 +1160,33 @@ def load( K_or_V="V", ) - n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) - load_Q(block=self.q_stage * m_block + 0, stage=0) # Q0 - page_idx = ( - mPageTable[batch_idx, n_block_max - 1] - if const_expr(mPageTable is not None) - else None - ) - load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # K0 - kv_producer_state.advance() - if const_expr(self.q_stage == 2): - load_Q(block=self.q_stage * m_block + 1, stage=1) # Q1 - q_producer_phase ^= 1 - load_V(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # V0 - kv_producer_state.advance() - for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): - n_block = n_block_max - 2 - i + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) + + if n_block_min < n_block_max: + load_Q(block=self.q_stage * m_block + 0, stage=0) # Q0 page_idx = ( - mPageTable[batch_idx, n_block] if const_expr(mPageTable is not None) else None + mPageTable[batch_idx, n_block_max - 1] + if const_expr(mPageTable is not None) + else None ) - # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("n_block = {}, page_idx = {}", n_block, page_idx) - load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Ki + load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # K0 kv_producer_state.advance() - load_V(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Vi + if const_expr(self.q_stage == 2): + load_Q(block=self.q_stage * m_block + 1, stage=1) # Q1 + q_producer_phase ^= 1 + load_V(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # V0 kv_producer_state.advance() + for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): + n_block = n_block_max - 2 - i + page_idx = ( + mPageTable[batch_idx, n_block] if const_expr(mPageTable is not None) else None + ) + # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("n_block = {}, page_idx = {}", n_block, page_idx) + load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Ki + kv_producer_state.advance() + load_V(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Vi + kv_producer_state.advance() + tile_scheduler.prefetch_next_work() tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() @@ -1168,6 +1206,7 @@ def mma( pipeline_kv: cutlass.pipeline.PipelineAsync, mbar_ptr: cute.Pointer, block_info: BlockInfo, + num_splits: Int32, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, ): @@ -1212,60 +1251,128 @@ def mma( tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: - m_block, head_idx, batch_idx = work_tile.tile_idx + m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) - n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) - for stage in cutlass.range_constexpr(self.q_stage): - # GEMM_QK00 (Q0 * K0 -> S0) or GEMM_QK01 (Q1 * K0 -> S1) - # 1. wait for Q0 / Q1 - cute.arch.mbarrier_wait( + if n_block_min < n_block_max: + for stage in cutlass.range_constexpr(self.q_stage): + # GEMM_QK00 (Q0 * K0 -> S0) or GEMM_QK01 (Q1 * K0 -> S1) + # 1. wait for Q0 / Q1 + cute.arch.mbarrier_wait( mbar_ptr + self.mbar_load_q_full_offset + stage, mma_q_consumer_phase ) - # 2. wait for K0 - if const_expr(stage == 0): + # 2. wait for K0 + if const_expr(stage == 0): + pipeline_kv.consumer_wait(mma_kv_consumer_state) + tSrKi = tSrK[None, None, None, mma_kv_consumer_state.index] + # We don't need to acquire empty S0 / S1. + # For the first iteration, we don't need to wait as we're guaranteed S0 / S1 + # are empty. For subsequent iterations, the wait happened at the end + # of the while loop. + # 3. gemm + # tiled_mma_qk = sm100_utils.gemm(tiled_mma_qk, tStSs[stage], tSrQs[stage], tSrKi, zero_init=True) + sK_cur = sK[None, None, None, mma_kv_consumer_state.index] + if const_expr(self.uneven_kv_smem): + sK_cur = self.offset_kv_smem( + sK_cur, mma_kv_consumer_state.index, mma_kv_consumer_state.phase + ) + gemm_Si[stage](tCrB=tSrKi, sB=sK_cur) + # 4. release S0 / S1 + with cute.arch.elect_one(): + tcgen05.commit(mbar_ptr + self.mbar_S_full_offset + stage) + mma_q_consumer_phase ^= 1 + # 5. release K0 + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + # End of GEMM (Q1 * K0 -> S1) + # Note: Q0 & Q1 are still needed in the seqlen_kv loop + # so we need to release them after the seqlen_kv loop + + # O hasn't been accumulated yet, its first MMA calculation doesn't need to accumulate + O_should_accumulate = False + for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): + # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop + # 1. wait for V0 pipeline_kv.consumer_wait(mma_kv_consumer_state) - tSrKi = tSrK[None, None, None, mma_kv_consumer_state.index] - # We don't need to acquire empty S0 / S1. - # For the first iteration, we don't need to wait as we're guaranteed S0 / S1 - # are empty. For subsequent iterations, the wait happened at the end - # of the while loop. - # 3. gemm - # tiled_mma_qk = sm100_utils.gemm(tiled_mma_qk, tStSs[stage], tSrQs[stage], tSrKi, zero_init=True) - sK_cur = sK[None, None, None, mma_kv_consumer_state.index] - if const_expr(self.uneven_kv_smem): - sK_cur = self.offset_kv_smem( - sK_cur, mma_kv_consumer_state.index, mma_kv_consumer_state.phase - ) - gemm_Si[stage](tCrB=tSrKi, sB=sK_cur) - # 4. release S0 / S1 + mma_kv_release_state = mma_kv_consumer_state.clone() + Vi_index, Vi_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase + tOrVi = tOrV[None, None, None, Vi_index] + for stage in cutlass.range_constexpr(2): + # 2. acquire corrected O0/O1_partial and P0 / P1 + # For the first iteration in this work tile, waiting for O0/O1_partial + # means that the correction warps has finished reading tO during + # the last iteration of the previous work tile has finished. + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, + P_full_O_rescaled_phase, + ) + # 3. gemm + # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True) + # gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate) + sV_cur = sV[None, None, None, Vi_index] + if const_expr(self.uneven_kv_smem): + sV_cur = self.offset_kv_smem(sV_cur, Vi_index, Vi_phase) + gemm_Pi[stage]( + tCrB=tOrVi, + sB=sV_cur, + zero_init=not O_should_accumulate, + mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, + mbar_phase=P_full_O_rescaled_phase, + ) + # 4. release accumulated O0_partial / O1_partial + # Don't need to signal O_full to the correction warps anymore since the + # correction warps wait for the softmax warps anyway. By the time the softmax + # warps finished, S_i for the next iteration must have been done, so O_i-1 + # must have been done as well. + # with cute.arch.elect_one(): + # tcgen05.commit(mbar_ptr + self.mbar_O_full_offset + stage) + # 5. release V(i-1) + if const_expr(stage == 1): + pipeline_kv.consumer_release(mma_kv_release_state) + mma_kv_release_state.advance() + # End of GEMM_PV00 (P0 * V0 -> O0_partial) + + # GEMM_QK0i (Q0 * Ki -> S0) + # 1. wait for Ki + if const_expr(stage == 0): + mma_kv_consumer_state.advance() + pipeline_kv.consumer_wait(mma_kv_consumer_state) + Ki_index, Ki_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase + # 2. gemm + # Don't need to wait for the softmax warp to have finished reading the previous + # Si, since this gemm is scheduled after the PV gemm, which guaranteed that Si + # has been read and Pi has been written. + # tiled_mma_qk = sm100_utils.gemm(tiled_mma_qk, tStSs[stage], tSrQs[stage], tSrK[None, None, None, Ki_index], zero_init=True) + sK_cur = sK[None, None, None, Ki_index] + if const_expr(self.uneven_kv_smem): + sK_cur = self.offset_kv_smem(sK_cur, Ki_index, Ki_phase) + gemm_Si[stage](tCrB=tSrK[None, None, None, Ki_index], sB=sK_cur) + # 3. release S0 + with cute.arch.elect_one(): + tcgen05.commit(mbar_ptr + self.mbar_S_full_offset + stage) + # End of GEMM_QK0i (Q0 * Ki -> S0) + # 4. release Ki + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + P_full_O_rescaled_phase ^= 1 + O_should_accumulate = True + # End of seqlen_kv loop + + # release Q0 & Q1 with cute.arch.elect_one(): - tcgen05.commit(mbar_ptr + self.mbar_S_full_offset + stage) - mma_q_consumer_phase ^= 1 - # 5. release K0 - pipeline_kv.consumer_release(mma_kv_consumer_state) - mma_kv_consumer_state.advance() - # End of GEMM (Q1 * K0 -> S1) - # Note: Q0 & Q1 are still needed in the seqlen_kv loop - # so we need to release them after the seqlen_kv loop - - # O hasn't been accumulated yet, its first MMA calculation doesn't need to accumulate - O_should_accumulate = False - for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): + for stage in cutlass.range_constexpr(self.q_stage): + tcgen05.commit(mbar_ptr + self.mbar_load_q_empty_offset + stage) + # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop # 1. wait for V0 pipeline_kv.consumer_wait(mma_kv_consumer_state) - mma_kv_release_state = mma_kv_consumer_state.clone() Vi_index, Vi_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase tOrVi = tOrV[None, None, None, Vi_index] for stage in cutlass.range_constexpr(2): - # 2. acquire corrected O0/O1_partial and P0 / P1 - # For the first iteration in this work tile, waiting for O0/O1_partial - # means that the correction warps has finished reading tO during - # the last iteration of the previous work tile has finished. + # 2. acquire corrected Oi_partial and Pi cute.arch.mbarrier_wait( - mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, - P_full_O_rescaled_phase, + mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, P_full_O_rescaled_phase ) # 3. gemm # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True) @@ -1280,86 +1387,19 @@ def mma( mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, mbar_phase=P_full_O_rescaled_phase, ) - # 4. release accumulated O0_partial / O1_partial - # Don't need to signal O_full to the correction warps anymore since the - # correction warps wait for the softmax warps anyway. By the time the softmax - # warps finished, S_i for the next iteration must have been done, so O_i-1 - # must have been done as well. - # with cute.arch.elect_one(): - # tcgen05.commit(mbar_ptr + self.mbar_O_full_offset + stage) - # 5. release V(i-1) - if const_expr(stage == 1): - pipeline_kv.consumer_release(mma_kv_release_state) - mma_kv_release_state.advance() - # End of GEMM_PV00 (P0 * V0 -> O0_partial) - - # GEMM_QK0i (Q0 * Ki -> S0) - # 1. wait for Ki - if const_expr(stage == 0): - mma_kv_consumer_state.advance() - pipeline_kv.consumer_wait(mma_kv_consumer_state) - Ki_index, Ki_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase - # 2. gemm - # Don't need to wait for the softmax warp to have finished reading the previous - # Si, since this gemm is scheduled after the PV gemm, which guaranteed that Si - # has been read and Pi has been written. - # tiled_mma_qk = sm100_utils.gemm(tiled_mma_qk, tStSs[stage], tSrQs[stage], tSrK[None, None, None, Ki_index], zero_init=True) - sK_cur = sK[None, None, None, Ki_index] - if const_expr(self.uneven_kv_smem): - sK_cur = self.offset_kv_smem(sK_cur, Ki_index, Ki_phase) - gemm_Si[stage](tCrB=tSrK[None, None, None, Ki_index], sB=sK_cur) - # 3. release S0 + # 4. release accumulated O0_partial + # We do need O_full here since for the last tile, by the time the softmax warp + # has signaled to the correction warp, the softmax warp has just finished compute + # the row sum of the current tile. It does not guarantee that the 1st tile + # of the next work tile has been computed yet. with cute.arch.elect_one(): - tcgen05.commit(mbar_ptr + self.mbar_S_full_offset + stage) - # End of GEMM_QK0i (Q0 * Ki -> S0) - # 4. release Ki + tcgen05.commit(mbar_ptr + self.mbar_O_full_offset + stage) + # End of GEMM_PV00 (P0 * V0 -> O0_partial) + P_full_O_rescaled_phase ^= 1 + # 5. release Vi_end pipeline_kv.consumer_release(mma_kv_consumer_state) mma_kv_consumer_state.advance() - P_full_O_rescaled_phase ^= 1 - O_should_accumulate = True - # End of seqlen_kv loop - - # release Q0 & Q1 - with cute.arch.elect_one(): - for stage in cutlass.range_constexpr(self.q_stage): - tcgen05.commit(mbar_ptr + self.mbar_load_q_empty_offset + stage) - - # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop - # 1. wait for V0 - pipeline_kv.consumer_wait(mma_kv_consumer_state) - Vi_index, Vi_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase - tOrVi = tOrV[None, None, None, Vi_index] - for stage in cutlass.range_constexpr(2): - # 2. acquire corrected Oi_partial and Pi - cute.arch.mbarrier_wait( - mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, P_full_O_rescaled_phase - ) - # 3. gemm - # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True) - # gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate) - sV_cur = sV[None, None, None, Vi_index] - if const_expr(self.uneven_kv_smem): - sV_cur = self.offset_kv_smem(sV_cur, Vi_index, Vi_phase) - gemm_Pi[stage]( - tCrB=tOrVi, - sB=sV_cur, - zero_init=not O_should_accumulate, - mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, - mbar_phase=P_full_O_rescaled_phase, - ) - # 4. release accumulated O0_partial - # We do need O_full here since for the last tile, by the time the softmax warp - # has signaled to the correction warp, the softmax warp has just finished compute - # the row sum of the current tile. It does not guarantee that the 1st tile - # of the next work tile has been computed yet. - with cute.arch.elect_one(): - tcgen05.commit(mbar_ptr + self.mbar_O_full_offset + stage) - # End of GEMM_PV00 (P0 * V0 -> O0_partial) - P_full_O_rescaled_phase ^= 1 - # 5. release Vi_end - pipeline_kv.consumer_release(mma_kv_consumer_state) - mma_kv_consumer_state.advance() - # End of GEMM_PV1(i_end) (P1 * Vi_end -> O1) + # End of GEMM_PV1(i_end) (P1 * Vi_end -> O1) # Advance to next tile tile_scheduler.advance_to_next_work() @@ -1380,6 +1420,7 @@ def softmax_loop( learnable_sink: Optional[cute.Tensor], mbar_ptr: cute.Pointer, block_info: BlockInfo, + num_splits: Int32, SeqlenInfoCls: Callable, AttentionMaskCls: Callable, TileSchedulerCls: Callable, @@ -1448,118 +1489,119 @@ def softmax_loop( tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: - m_block, head_idx, batch_idx = work_tile.tile_idx + m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) - n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) - mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) - mask_fn = partial( - mask.apply_mask_sm100, - m_block=self.q_stage * m_block + stage, - thr_mma=thr_mma_qk, - thr_tmem_load=thr_tmem_load, - mask_causal=self.is_causal, - mask_local=self.is_local, - ) - softmax = SoftmaxSm100.create( - softmax_scale_log2, - rescale_threshold=8.0 if const_expr(self.q_dtype.width == 16) else 0.0, - softmax_scale=softmax_scale, - ) - softmax.reset() - - softmax_step = partial( - self.softmax_step, - softmax=softmax, - mbar_ptr=mbar_ptr, - mbar_s0_s1_sequence_offset=mbar_s0_s1_sequence_offset, - thr_mma_qk=thr_mma_qk, - thr_tmem_load=thr_tmem_load, - thr_tmem_store=thr_tmem_store, - thr_tmem_store_scale=thr_tmem_store_scale, - tStS_t2r=tStS_t2r, - tStScale_r2t=tStScale_r2t, - tStP_r2t=tStP_r2t, - sScale=sScale, - stage=stage, - batch_idx=batch_idx, - head_idx=head_idx, - m_block=self.q_stage * m_block + stage, - seqlen=seqlen, - aux_tensors=aux_tensors, - fastdiv_mods=fastdiv_mods, - ) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) + + if n_block_min < n_block_max: + mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) + mask_fn = partial( + mask.apply_mask_sm100, + m_block=self.q_stage * m_block + stage, + thr_mma=thr_mma_qk, + thr_tmem_load=thr_tmem_load, + mask_causal=self.is_causal, + mask_local=self.is_local, + ) + softmax = SoftmaxSm100.create( + softmax_scale_log2, + rescale_threshold=8.0 if const_expr(self.q_dtype.width == 16) else 0.0, + softmax_scale=softmax_scale, + ) + softmax.reset() + + softmax_step = partial( + self.softmax_step, + softmax=softmax, + mbar_ptr=mbar_ptr, + mbar_s0_s1_sequence_offset=mbar_s0_s1_sequence_offset, + thr_mma_qk=thr_mma_qk, + thr_tmem_load=thr_tmem_load, + thr_tmem_store=thr_tmem_store, + thr_tmem_store_scale=thr_tmem_store_scale, + tStS_t2r=tStS_t2r, + tStScale_r2t=tStScale_r2t, + tStP_r2t=tStP_r2t, + sScale=sScale, + stage=stage, + batch_idx=batch_idx, + head_idx=head_idx, + m_block=self.q_stage * m_block + stage, + seqlen=seqlen, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + ) - cute.arch.mbarrier_wait( - mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase - ) - si_corr_producer_phase ^= 1 - - # 1 masking iter - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step( - mma_si_consumer_phase, - si_corr_producer_phase, - s0_s1_sequence_phase, - n_block_max - 1, - is_first=True, - mask_fn=partial(mask_fn, mask_seqlen=True), - ) - n_block_max -= 1 - # Next couple of iterations with causal masking - if const_expr(self.is_causal or self.is_local): - n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( - seqlen, m_block, n_block_min + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase ) - for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1): - n_block = n_block_max - 1 - n_tile - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = ( - softmax_step( - mma_si_consumer_phase, - si_corr_producer_phase, - s0_s1_sequence_phase, - n_block, - mask_fn=partial(mask_fn, mask_seqlen=False), - ) - ) - n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) - # The remaining iterations have no masking - n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( - seqlen, m_block, n_block_min - ) - for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): - n_block = n_block_max - n_tile - 1 + si_corr_producer_phase ^= 1 + # 1 masking iter mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + n_block_max - 1, + is_first=True, + mask_fn=partial(mask_fn, mask_seqlen=True), + ) + n_block_max -= 1 + # Next couple of iterations with causal masking + if const_expr(self.is_causal or self.is_local): + n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( + seqlen, m_block, n_block_min + ) + for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1): + n_block = n_block_max - 1 - n_tile + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = ( + softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + n_block, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) + ) + n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) + # The remaining iterations have no masking + n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( + seqlen, m_block, n_block_min + ) + for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): + n_block = n_block_max - n_tile - 1 + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step( mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block ) - # Separate iterations with local masking on the left - if const_expr(self.is_local and block_info.window_size_left is not None): - n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) - for n_tile in cutlass.range(0, n_block_max - n_block_min, unroll=1): - n_block = n_block_max - 1 - n_tile - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = ( - softmax_step( - mma_si_consumer_phase, - si_corr_producer_phase, - s0_s1_sequence_phase, - n_block, - mask_fn=partial(mask_fn, mask_seqlen=False), + # Separate iterations with local masking on the left + if const_expr(self.is_local and block_info.window_size_left is not None): + n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) + for n_tile in cutlass.range(0, n_block_max - n_block_min, unroll=1): + n_block = n_block_max - 1 - n_tile + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = ( + softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + n_block, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) ) - ) - # Now that we no longer already have the 1st iteration, need mask_seqlen=True here - - # tSrScale_r2t_shape = thr_tmem_store_scale.partition_S(tScScale).shape - # tSrScale_r2t = cute.make_fragment(tSrScale_r2t_shape, Float32) - # tSrScale_r2t[0] = softmax.row_sum[0] - # cute.copy(thr_tmem_store_scale, tSrScale_r2t, tStScale_r2t) - # cute.arch.fence_view_async_tmem_store() - sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0] - if const_expr(mLSE is not None or learnable_sink is not None): - sScale[tidx + stage * self.m_block_size + self.m_block_size * 2] = softmax.row_max[ - 0 - ] - # if tidx == 0: - # cute.printf("softmax row sum stage %d: %f, row_max = %f\n", stage, softmax.row_sum[0], softmax.row_max[0]) - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage) - # if tidx == 0: cute.printf("softmax row sum stage %d: %f\n", stage, softmax.row_sum[0]) + # Now that we no longer already have the 1st iteration, need mask_seqlen=True here + + # tSrScale_r2t_shape = thr_tmem_store_scale.partition_S(tScScale).shape + # tSrScale_r2t = cute.make_fragment(tSrScale_r2t_shape, Float32) + # tSrScale_r2t[0] = softmax.row_sum[0] + # cute.copy(thr_tmem_store_scale, tSrScale_r2t, tStScale_r2t) + # cute.arch.fence_view_async_tmem_store() + sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0] + if const_expr(mLSE is not None or learnable_sink is not None): + sScale[tidx + stage * self.m_block_size + self.m_block_size * 2] = softmax.row_max[ + 0 + ] + # if tidx == 0: + # cute.printf("softmax row sum stage %d: %f, row_max = %f\n", stage, softmax.row_sum[0], softmax.row_max[0]) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage) + # if tidx == 0: cute.printf("softmax row sum stage %d: %f\n", stage, softmax.row_sum[0]) # # Write LSE to gmem # if const_expr(mLSE is not None): @@ -1726,6 +1768,7 @@ def correction_loop( mbar_ptr: cute.Pointer, softmax_scale_log2: Float32, block_info: BlockInfo, + num_splits: Int32, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, ): @@ -1757,24 +1800,70 @@ def correction_loop( tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: - m_block, head_idx, batch_idx = work_tile.tile_idx + m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) - n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) - # Ignore first signal from softmax as no correction is required - cute.arch.mbarrier_wait( - mbar_ptr + self.mbar_softmax_corr_full_offset + 0, softmax_corr_consumer_phase - ) - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + 0) - cute.arch.mbarrier_wait( - mbar_ptr + self.mbar_softmax_corr_full_offset + 1, softmax_corr_consumer_phase - ) - softmax_corr_consumer_phase ^= 1 + # Default LSE to -inf for invalid split_idx tiles + stats = [(0.0, -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None, True)] * self.q_stage - tSrScale_t2r = cute.make_fragment(tSrScale_t2r_shape, Float32) - for i in cutlass.range(n_block_max - n_block_min - 1, unroll=1): - for stage in cutlass.range_constexpr(2): - # wait for S0 / S1 + if n_block_min < n_block_max: + # Ignore first signal from softmax as no correction is required + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_full_offset + 0, softmax_corr_consumer_phase + ) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + 0) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_full_offset + 1, softmax_corr_consumer_phase + ) + softmax_corr_consumer_phase ^= 1 + + tSrScale_t2r = cute.make_fragment(tSrScale_t2r_shape, Float32) + for i in cutlass.range(n_block_max - n_block_min - 1, unroll=1): + for stage in cutlass.range_constexpr(2): + # wait for S0 / S1 + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_full_offset + stage, + softmax_corr_consumer_phase, + ) + # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r) + # cute.arch.fence_view_async_tmem_load() + # scale = tSrScale_t2r[0] + scale = sScale[tidx + stage * self.m_block_size] + should_rescale = cute.arch.vote_ballot_sync(scale < 1.0) != 0 + # should_rescale = True + # if tidx == 0: cute.printf("Correction scale i = %d, for stage %d: %f, should_rescale = %d\n", i, stage, scale, should_rescale) + # Don't need O_full anymore, since by the time softmax has signaled the correction + # warps, S_i must have been done, so O_i-1 must have been done as well. + # cute.arch.mbarrier_wait(mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase) + if should_rescale: + self.correction_rescale( + thr_mma_pv, tOtOs[stage if self.q_stage == 2 else 0], tidx, scale + ) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) + cute.arch.mbarrier_arrive( + mbar_ptr + self.mbar_softmax_corr_empty_offset + (1 - stage) + ) + softmax_corr_consumer_phase ^= 1 + # o_corr_consumer_phase ^= 1 + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + 1) + # End of seqlen_corr_loop_steps + + # Even in the case of self.overlap_sO_sQ, we can write to stage 0 of sO without + # additional sync because the MMA in the top half must have been done. + # Similarly we can write to stage 1 of sO without additional sync. + learnable_sink_val = [None] * self.q_stage + if const_expr(learnable_sink is not None): + if const_expr(not self.pack_gqa): + sink_val = Float32(learnable_sink[head_idx]) + learnable_sink_val = [sink_val] * self.q_stage + else: # Each thread might have a different sink value due to different q_head + for stage in cutlass.range_constexpr(self.q_stage): + q_head_idx = ( + (self.q_stage * m_block + stage) * self.m_block_size + tidx + ) % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead + learnable_sink_val[stage] = Float32(learnable_sink[q_head_idx]) + for stage in cutlass.range_constexpr(self.q_stage): cute.arch.mbarrier_wait( mbar_ptr + self.mbar_softmax_corr_full_offset + stage, softmax_corr_consumer_phase, @@ -1782,90 +1871,64 @@ def correction_loop( # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r) # cute.arch.fence_view_async_tmem_load() # scale = tSrScale_t2r[0] - scale = sScale[tidx + stage * self.m_block_size] - should_rescale = cute.arch.vote_ballot_sync(scale < 1.0) != 0 - # should_rescale = True - # if tidx == 0: cute.printf("Correction scale i = %d, for stage %d: %f, should_rescale = %d\n", i, stage, scale, should_rescale) - # Don't need O_full anymore, since by the time softmax has signaled the correction - # warps, S_i must have been done, so O_i-1 must have been done as well. - # cute.arch.mbarrier_wait(mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase) - if should_rescale: - self.correction_rescale( - thr_mma_pv, tOtOs[stage if self.q_stage == 2 else 0], tidx, scale - ) - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) - cute.arch.mbarrier_arrive( - mbar_ptr + self.mbar_softmax_corr_empty_offset + (1 - stage) + row_sum = sScale[tidx + stage * self.m_block_size] + if const_expr(mLSE is not None or learnable_sink is not None): + row_max = sScale[tidx + stage * self.m_block_size + self.m_block_size * 2] + else: + row_max = None + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage) + if const_expr(learnable_sink is not None): + LOG2_E = math.log2(math.e) + sink_val = learnable_sink_val[stage] + if const_expr(not self.is_split_kv) or split_idx == 0: + if row_max == -Float32.inf: + # It's possible to have an empty row with splitKV. + row_max = sink_val * (LOG2_E / softmax_scale_log2) + row_sum = Float32(1.0) + else: + row_sum += utils.exp2f( + sink_val * LOG2_E - row_max * softmax_scale_log2 + ) + acc_O_mn_row_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum + stats[stage] = (row_sum, row_max, acc_O_mn_row_is_zero_or_nan) + scale = cute.arch.rcp_approx(row_sum if not acc_O_mn_row_is_zero_or_nan else 1.0) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase ) - softmax_corr_consumer_phase ^= 1 - # o_corr_consumer_phase ^= 1 - # End of seqlen_corr_loop_steps - - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + 1) - - # Even in the case of self.overlap_sO_sQ, we can write to stage 0 of sO without - # additional sync because the MMA in the top half must have been done. - # Similarly we can write to stage 1 of sO without additional sync. - stats = [None] * self.q_stage - learnable_sink_val = [None] * self.q_stage - if const_expr(learnable_sink is not None): - if const_expr(not self.pack_gqa): - sink_val = Float32(learnable_sink[head_idx]) - learnable_sink_val = [sink_val] * self.q_stage - else: # Each thread might have a different sink value due to different q_head - for stage in cutlass.range_constexpr(self.q_stage): - q_head_idx = ( - (self.q_stage * m_block + stage) * self.m_block_size + tidx - ) % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead - learnable_sink_val[stage] = Float32(learnable_sink[q_head_idx]) - for stage in cutlass.range_constexpr(self.q_stage): - cute.arch.mbarrier_wait( - mbar_ptr + self.mbar_softmax_corr_full_offset + stage, - softmax_corr_consumer_phase, - ) - # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r) - # cute.arch.fence_view_async_tmem_load() - # scale = tSrScale_t2r[0] - row_sum = sScale[tidx + stage * self.m_block_size] - if const_expr(mLSE is not None or learnable_sink is not None): - row_max = sScale[tidx + stage * self.m_block_size + self.m_block_size * 2] - else: - row_max = None - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage) - if const_expr(learnable_sink is not None): - LOG2_E = math.log2(math.e) - row_sum += utils.exp2f( - learnable_sink_val[stage] * LOG2_E - row_max * softmax_scale_log2 + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_corr_epi_empty_offset + stage, corr_epi_producer_phase ) - acc_O_mn_row_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum - stats[stage] = (row_sum, row_max, acc_O_mn_row_is_zero_or_nan) - scale = cute.arch.rcp_approx(row_sum if not acc_O_mn_row_is_zero_or_nan else 1.0) - cute.arch.mbarrier_wait( - mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase - ) - cute.arch.mbarrier_wait( - mbar_ptr + self.mbar_corr_epi_empty_offset + stage, corr_epi_producer_phase - ) - self.correction_epilogue( - thr_mma_pv, - tOtOs[stage], - tidx, - scale, - sO[None, None, stage], - ) - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_full_offset + stage) - # Signal for the next work tile that O buffers in tmem are already read, so - # mma warp can write to them - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) - # if tidx == 0: cute.printf("Correction final scale for stage %d: %f\n", stage, scale) + self.correction_epilogue( + thr_mma_pv, + tOtOs[stage], + tidx, + scale, + sO[None, None, stage], + ) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_full_offset + stage) + # Signal for the next work tile that O buffers in tmem are already read, so + # mma warp can write to them + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) + # if tidx == 0: cute.printf("Correction final scale for stage %d: %f\n", stage, scale) + + o_corr_consumer_phase ^= 1 + softmax_corr_consumer_phase ^= 1 + corr_epi_producer_phase ^= 1 + if const_expr(mLSE is not None): if const_expr(not seqlen.has_cu_seqlens_q): - mLSE_cur = mLSE[None, head_idx, batch_idx] + if const_expr(self.is_split_kv): + mLSE_cur = mLSE[None, head_idx, batch_idx, split_idx] + else: + mLSE_cur = mLSE[None, head_idx, batch_idx] else: offset = ( seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) ) - mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx]) + if const_expr(self.is_split_kv): + mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx, split_idx]) + else: + mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx]) for stage in cutlass.range_constexpr(self.q_stage): gLSE = cute.local_tile( mLSE_cur, (self.m_block_size,), (self.q_stage * m_block + stage,) @@ -1888,10 +1951,6 @@ def correction_loop( # This actually just works with PackGQA too gLSE[tidx] = lse - o_corr_consumer_phase ^= 1 - softmax_corr_consumer_phase ^= 1 - corr_epi_producer_phase ^= 1 - # gO_qdhb = cute.local_tile(mO, cute.select(self.mma_tiler_pv, mode=[0, 1]), (None, 0, None, None)) # gO = gO_qdhb[None, None, None, head_idx, batch_idx] # tOsO, tOgO = cpasync.tma_partition( @@ -2060,6 +2119,8 @@ def epilogue_s2g( gmem_tiled_copy_O: cute.TiledCopy, tma_atom_O: Optional[cute.CopyAtom], mbar_ptr: cute.Pointer, + block_info: BlockInfo, + num_splits: int, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, ): @@ -2067,86 +2128,93 @@ def epilogue_s2g( tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: - m_block, head_idx, batch_idx = work_tile.tile_idx + m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) - mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx] - gO = cute.local_tile(mO_cur, (self.m_block_size, self.head_dim_v_padded), (None, 0)) - if const_expr(self.use_tma_O): - store_O, _, _ = copy_utils.tma_get_copy_fn( - tma_atom_O, 0, cute.make_layout(1), sO, gO - ) - for stage in cutlass.range_constexpr(self.q_stage): - # wait from corr, issue tma store on smem - # 1. wait for O0 / O1 final - cute.arch.mbarrier_wait( - mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) + + if n_block_min < n_block_max: + if const_expr(self.is_split_kv): + mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx] + else: + mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx] + gO = cute.local_tile(mO_cur, (self.m_block_size, self.head_dim_v_padded), (None, 0)) + if const_expr(self.use_tma_O): + store_O, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_O, 0, cute.make_layout(1), sO, gO ) - # 2. copy O0 / O1 to gmem - store_O(src_idx=stage, dst_idx=self.q_stage * m_block + stage) - cute.arch.cp_async_bulk_commit_group() - for stage in cutlass.range_constexpr(self.q_stage): - # Ensure O0 / O1 buffer is ready to be released - cute.arch.cp_async_bulk_wait_group(1 - stage, read=True) - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) - else: - tidx = cute.arch.thread_idx()[0] % ( - cute.arch.WARP_SIZE * len(self.epilogue_warp_ids) - ) - gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) - tOsO = gmem_thr_copy_O.partition_S(sO) - cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded)) - tOgO = gmem_thr_copy_O.partition_D(gO) - tOcO = gmem_thr_copy_O.partition_S(cO) - t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO) - tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) - # TODO: the packgqa case isn't correct rn (sometimes IMA), disabling it - assert not self.pack_gqa - pack_gqa = PackGQA( - self.m_block_size, - self.head_dim_v_padded, - self.check_hdim_v_oob, - self.qhead_per_kvhead, - ) - for stage in cutlass.range_constexpr(self.q_stage): - # wait from corr, issue tma store on smem - # 1. wait for O0 / O1 final - cute.arch.mbarrier_wait( - mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase + for stage in cutlass.range_constexpr(self.q_stage): + # wait from corr, issue tma store on smem + # 1. wait for O0 / O1 final + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase + ) + # 2. copy O0 / O1 to gmem + store_O(src_idx=stage, dst_idx=self.q_stage * m_block + stage) + cute.arch.cp_async_bulk_commit_group() + for stage in cutlass.range_constexpr(self.q_stage): + # Ensure O0 / O1 buffer is ready to be released + cute.arch.cp_async_bulk_wait_group(1 - stage, read=True) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) + else: + tidx = cute.arch.thread_idx()[0] % ( + cute.arch.WARP_SIZE * len(self.epilogue_warp_ids) ) - # 2. copy O0 / O1 to gmem - # load acc O from smem to rmem for wider vectorization - tOrO = cute.make_fragment_like(tOsO[None, None, None, 0], self.o_dtype) - cute.autovec_copy(tOsO[None, None, None, stage], tOrO) - # copy acc O from rmem to gmem - if const_expr(not self.pack_gqa): - for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): - if ( - t0OcO[0, rest_m, 0][0] - < seqlen.seqlen_q - - (self.q_stage * m_block + stage) * self.m_block_size - - tOcO[0][0] - ): - cute.copy( - gmem_tiled_copy_O, - tOrO[None, rest_m, None], - tOgO[None, rest_m, None, self.q_stage * m_block + stage], - pred=tOpO[None, rest_m, None] - if self.check_hdim_v_oob - else None, - ) - else: - pack_gqa.store_O( - mO_cur, - tOrO, - gmem_tiled_copy_O, - tidx, - self.q_stage * m_block + stage, - seqlen.seqlen_q, + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + tOsO = gmem_thr_copy_O.partition_S(sO) + cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded)) + tOgO = gmem_thr_copy_O.partition_D(gO) + tOcO = gmem_thr_copy_O.partition_S(cO) + t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO) + tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) + # TODO: the packgqa case isn't correct rn (sometimes IMA), disabling it + assert not self.pack_gqa + pack_gqa = PackGQA( + self.m_block_size, + self.head_dim_v_padded, + self.check_hdim_v_oob, + self.qhead_per_kvhead, + ) + for stage in cutlass.range_constexpr(self.q_stage): + # wait from corr, issue tma store on smem + # 1. wait for O0 / O1 final + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase ) - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) + # 2. copy O0 / O1 to gmem + # load acc O from smem to rmem for wider vectorization + tOrO = cute.make_fragment_like(tOsO[None, None, None, 0], self.o_dtype) + cute.autovec_copy(tOsO[None, None, None, stage], tOrO) + # copy acc O from rmem to gmem + if const_expr(not self.pack_gqa): + for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): + if ( + t0OcO[0, rest_m, 0][0] + < seqlen.seqlen_q + - (self.q_stage * m_block + stage) * self.m_block_size + - tOcO[0][0] + ): + cute.copy( + gmem_tiled_copy_O, + tOrO[None, rest_m, None], + tOgO[None, rest_m, None, self.q_stage * m_block + stage], + pred=tOpO[None, rest_m, None] + if self.check_hdim_v_oob + else None, + ) + else: + pack_gqa.store_O( + mO_cur, + tOrO, + gmem_tiled_copy_O, + tidx, + self.q_stage * m_block + stage, + seqlen.seqlen_q, + ) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) + + epi_consumer_phase ^= 1 # Advance to next tile - epi_consumer_phase ^= 1 tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 71e4339619e..2158cb51933 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -59,6 +59,16 @@ def maybe_contiguous(x): } +def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, max_splits): + # If num_n_blocks is too small, use 1 split. For example, we never split for hdim = 128 and seqlen_k = 512. + if num_n_blocks <= 4: + return 1 + + # NOTE: We should revisit this heuristic after persistence is supported for split KV. + # Sometimes, it's ideal to over-schedule splits for better efficiency. + return min(num_SMs // total_mblocks, max_splits, num_n_blocks) + + def _flash_attn_fwd( q: torch.Tensor, k: torch.Tensor, @@ -80,6 +90,7 @@ def _flash_attn_fwd( m_block_size: int = 128, n_block_size: int = 128, num_threads: int = 384, + num_splits: int = 1, pack_gqa: Optional[bool] = None, _compute_capability: Optional[int] = None, score_mod: Optional[Callable] = None, @@ -229,15 +240,6 @@ def _flash_attn_fwd( assert lse.is_cuda, "lse tensor must be on CUDA device" dtype = torch2cute_dtype_map[q.dtype] - q_tensor, k_tensor, v_tensor, o_tensor = [ - from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) - for t in (q, k, v, out) - ] - lse_tensor = ( - from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 1) - if lse is not None - else None - ) ( cu_seqlens_q_tensor, cu_seqlens_k_tensor, @@ -301,6 +303,40 @@ def _flash_attn_fwd( or (cu_seqlens_q is not None or seqused_q is not None) ): pack_gqa = False + # TODO: fix GQA + SplitKV + non-varlen + if pack_gqa and num_splits != 1 and cu_seqlens_q is None: + pack_gqa = False + + if num_splits < 1: + max_seqlen_k = seqlen_k if cu_seqlens_k is None else (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item() + max_seqlen_q = seqlen_q if cu_seqlens_q is None else (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item() + seqlen_q_packgqa = max_seqlen_q * qhead_per_kvhead + seqlen_k_loaded = max_seqlen_k if not local else max(0, min(max_seqlen_k, window_size_right + window_size_left + 1 + m_block_size)) + num_n_blocks = (seqlen_k_loaded + n_block_size - 1) // n_block_size + num_m_blocks = (seqlen_q_packgqa + m_block_size - 1) // m_block_size + total_mblocks = batch_size * num_head_kv * num_m_blocks + num_splits = num_splits_heuristic( + total_mblocks, + torch.cuda.get_device_properties(device).multi_processor_count, + num_n_blocks, + 128, + ) + + is_split_kv = num_splits > 1 + if is_split_kv: + out_partial = torch.empty(num_splits, *q_batch_seqlen_shape, num_head, head_dim_v, dtype=torch.float32, device=device) + lse_partial = torch.empty(num_splits, *lse_shape, dtype=torch.float32, device=device) + + q_tensor, k_tensor, v_tensor, o_tensor = [ + from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) + for t in (q, k, v, out if not is_split_kv else out_partial) + ] + if is_split_kv: + lse_tensor = from_dlpack(lse_partial.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse_partial.ndim - 1) + elif lse is not None: + lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 1) + else: + lse_tensor = None # hash score and mask mods for compile cache score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False @@ -372,6 +408,7 @@ def _flash_attn_fwd( m_block_size, n_block_size, num_threads, + is_split_kv, pack_gqa, compute_capability, ) @@ -379,6 +416,7 @@ def _flash_attn_fwd( if compile_key not in _flash_attn_fwd.compile_cache: if compute_capability == 9: assert page_table is None, "paged KV not supported on SM 9.0" + assert not is_split_kv, "SplitKV not supported on SM 9.0" # fa_fwd = FlashAttentionForwardSm80( fa_fwd = FlashAttentionForwardSm90( dtype, @@ -412,11 +450,13 @@ def _flash_attn_fwd( qhead_per_kvhead=qhead_per_kvhead, is_causal=causal, is_local=local, + is_split_kv=is_split_kv, pack_gqa=pack_gqa, is_persistent=not causal and not local and cu_seqlens_q is None - and seqused_q is None, + and seqused_q is None + and not is_split_kv, score_mod=score_mod, has_aux_tensors=aux_tensors is not None, ) @@ -464,6 +504,15 @@ def _flash_attn_fwd( sparse_tensors, cute_aux_tensors, ) + if is_split_kv: + _flash_attn_fwd_combine( + out_partial, + lse_partial.transpose(-1, -2), + out, + lse.transpose(-1, -2) if lse is not None else None, + cu_seqlens_q, + seqused_q, + ) return out, lse @@ -948,6 +997,7 @@ def forward( window_size: Tuple[Optional[int], Optional[int]] = (None, None), learnable_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, + num_splits: int = 1, pack_gqa: Optional[bool] = None, mask_mod: Optional[Callable] = None, full_block_cnt: Optional[torch.Tensor] = None, @@ -974,6 +1024,7 @@ def forward( window_size_right=window_size[1], learnable_sink=learnable_sink, softcap=softcap, + num_splits=num_splits, pack_gqa=pack_gqa, mask_mod=mask_mod, block_sparse_tensors=block_sparse_tensors @@ -1019,6 +1070,7 @@ def forward( window_size: Tuple[Optional[int], Optional[int]] = (None, None), learnable_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, + num_splits: int = 1, pack_gqa: Optional[bool] = None, ): out, lse = _flash_attn_fwd( @@ -1036,6 +1088,7 @@ def forward( window_size_right=window_size[1], learnable_sink=learnable_sink, softcap=softcap, + num_splits=num_splits, pack_gqa=pack_gqa, ) ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) @@ -1078,6 +1131,7 @@ def flash_attn_func( window_size: Tuple[Optional[int], Optional[int]] = (None, None), learnable_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, + num_splits: int = 1, pack_gqa: Optional[bool] = None, mask_mod: Optional[Callable] = None, full_block_cnt: Optional[torch.Tensor] = None, @@ -1094,6 +1148,7 @@ def flash_attn_func( window_size, learnable_sink, softcap, + num_splits, pack_gqa, mask_mod, full_block_cnt, @@ -1117,6 +1172,7 @@ def flash_attn_varlen_func( window_size: Tuple[Optional[int], Optional[int]] = (None, None), learnable_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, + num_splits: int = 1, pack_gqa: Optional[bool] = None, ): return FlashAttnVarlenFunc.apply( @@ -1133,6 +1189,7 @@ def flash_attn_varlen_func( window_size, learnable_sink, softcap, + num_splits, pack_gqa, ) @@ -1217,12 +1274,12 @@ def _flash_attn_fwd_combine( # Convert to cute tensors (using kernel-formatted tensors) out_partial_tensor = from_dlpack(out_partial.detach(), assumed_align=16).mark_layout_dynamic( - leading_dim=4 + leading_dim=4 if not is_varlen else 3 ) lse_partial_tensor = from_dlpack(lse_partial.detach(), assumed_align=4).mark_layout_dynamic( leading_dim=lse_partial.ndim - 2 ) - out_tensor = from_dlpack(out.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=3) + out_tensor = from_dlpack(out.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=3 if not is_varlen else 2) lse_tensor = ( from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 2) if lse is not None @@ -1278,7 +1335,7 @@ def _flash_attn_fwd_combine( num_threads=256, ): raise RuntimeError( - f"FlashAttention combine kernel cannot be implemented with given parameters" + "FlashAttention combine kernel cannot be implemented with given parameters" ) _flash_attn_fwd_combine.compile_cache[compile_key] = cute.compile( @@ -1315,6 +1372,8 @@ def flash_attn_combine( lse_partial: torch.Tensor, out: Optional[torch.Tensor] = None, out_dtype: Optional[torch.dtype] = None, + cu_seqlens: Optional[torch.Tensor] = None, + seqused: Optional[torch.Tensor] = None, return_lse: bool = True, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Flash Attention combine function for split attention computation. @@ -1332,6 +1391,8 @@ def flash_attn_combine( - (num_splits, total_q, num_heads) for variable length input out: Optional output tensor. If None, will be created automatically. out_dtype: Optional output dtype. If None, will use fp16/bf16 based on input. + cu_seqlens: Cumulative sequence lengths for variable length sequences + seqused: Used sequence lengths for each batch return_lse: Whether to return the combined LSE tensor. Default is True. Returns: @@ -1397,5 +1458,12 @@ def flash_attn_combine( else: lse = None - _flash_attn_fwd_combine(out_partial, lse_partial, out, lse) + _flash_attn_fwd_combine( + out_partial, + lse_partial, + out, + lse, + cu_seqlens, + seqused, + ) return out, lse diff --git a/flash_attn/cute/seqlen_info.py b/flash_attn/cute/seqlen_info.py index 792da01bd90..0851ddd0522 100644 --- a/flash_attn/cute/seqlen_info.py +++ b/flash_attn/cute/seqlen_info.py @@ -1,4 +1,5 @@ from typing import Optional +from dataclasses import dataclass import cutlass import cutlass.cute as cute @@ -11,26 +12,39 @@ """ +@dataclass(frozen=True) class SeqlenInfo: - def __init__( - self, + offset: cutlass.Int32 + seqlen: cutlass.Int32 + + @staticmethod + def create( batch_idx: cutlass.Int32, seqlen_static: cutlass.Int32, cu_seqlens: Optional[cute.Tensor] = None, seqused: Optional[cute.Tensor] = None, ): - self.offset = 0 if const_expr(cu_seqlens is None) else cu_seqlens[batch_idx] + offset = 0 if const_expr(cu_seqlens is None) else cu_seqlens[batch_idx] if const_expr(seqused is not None): - self.seqlen = seqused[batch_idx] + seqlen = seqused[batch_idx] elif const_expr(cu_seqlens is not None): - self.seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx] + seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx] else: - self.seqlen = seqlen_static + seqlen = seqlen_static + return SeqlenInfo(offset, seqlen) +@dataclass(frozen=True) class SeqlenInfoQK: - def __init__( - self, + offset_q: cutlass.Int32 + offset_k: cutlass.Int32 + seqlen_q: cutlass.Int32 + seqlen_k: cutlass.Int32 + has_cu_seqlens_q: cutlass.Constexpr[bool] + has_cu_seqlens_k: cutlass.Constexpr[bool] + + @staticmethod + def create( batch_idx: cutlass.Int32, seqlen_q_static: cutlass.Int32, seqlen_k_static: cutlass.Int32, @@ -39,26 +53,29 @@ def __init__( mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, ): - self.offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx] - self.offset_k = 0 if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx] + offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx] + offset_k = 0 if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx] if const_expr(mSeqUsedQ is not None): - self.seqlen_q = mSeqUsedQ[batch_idx] + seqlen_q = mSeqUsedQ[batch_idx] else: - self.seqlen_q = ( + seqlen_q = ( seqlen_q_static if const_expr(mCuSeqlensQ is None) - else mCuSeqlensQ[batch_idx + 1] - self.offset_q + else mCuSeqlensQ[batch_idx + 1] - offset_q ) if const_expr(mSeqUsedK is not None): - self.seqlen_k = mSeqUsedK[batch_idx] + seqlen_k = mSeqUsedK[batch_idx] else: - self.seqlen_k = ( + seqlen_k = ( seqlen_k_static if const_expr(mCuSeqlensK is None) - else mCuSeqlensK[batch_idx + 1] - self.offset_k + else mCuSeqlensK[batch_idx + 1] - offset_k ) - self.has_cu_seqlens_q: int = mCuSeqlensQ is not None - self.has_cu_seqlens_k: int = mCuSeqlensK is not None + has_cu_seqlens_q: int = mCuSeqlensQ is not None + has_cu_seqlens_k: int = mCuSeqlensK is not None + return SeqlenInfoQK( + offset_q, offset_k, seqlen_q, seqlen_k, has_cu_seqlens_q, has_cu_seqlens_k + ) def offset_batch_Q(self, mQ: cute.Tensor, batch_idx: Int32, dim: int) -> cute.Tensor: """Seqlen must be the first dimension of mQ""" diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index 517dd8a91a5..1ee11f6d11c 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -2,15 +2,28 @@ from typing import Optional, Tuple from dataclasses import dataclass, fields +from typing import override import cutlass +from cutlass._mlir import ir import cutlass.cute as cute -from cutlass import Int32 +from cutlass import Int32, const_expr import flash_attn.cute.utils as utils from flash_attn.cute.fast_math import FastDivmod, clz +class WorkTileInfo(cutlass.utils.WorkTileInfo): + """Altered WorkTileInfo which includes four axes: (block, head, batch, split)""" + + @override + def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTileInfo": + assert len(values) == 5 + new_tile_idx = cutlass.new_from_mlir_values(self._tile_idx, values[:-1]) + new_is_valid_tile = cutlass.new_from_mlir_values(self._is_valid_tile, [values[-1]]) + return WorkTileInfo(new_tile_idx, new_is_valid_tile) + + @dataclass class ParamsBase: def __extract_mlir_values__(self): @@ -40,6 +53,7 @@ class TileSchedulerArguments(ParamsBase): num_block: Int32 num_head: Int32 num_batch: Int32 + num_splits: Int32 seqlen_k: Int32 headdim: Int32 headdim_v: Int32 @@ -52,6 +66,7 @@ class TileSchedulerArguments(ParamsBase): element_size: cutlass.Constexpr[int] = 2 is_persistent: cutlass.Constexpr[bool] = False lpt: cutlass.Constexpr[bool] = False + is_split_kv: cutlass.Constexpr[bool] = False class SingleTileScheduler: @@ -60,15 +75,27 @@ class Params(ParamsBase): num_block: Int32 num_head: Int32 num_batch: Int32 + num_splits: Int32 + num_splits_divmod: FastDivmod + is_split_kv: cutlass.Constexpr[bool] = False cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) @staticmethod def create( args: TileSchedulerArguments, *, loc=None, ip=None ) -> "SingleTileScheduler.Params": - return SingleTileScheduler.Params(args.num_block, args.num_head, args.num_batch, args.cluster_shape_mn) + return SingleTileScheduler.Params( + args.num_block, + args.num_head, + args.num_batch, + args.num_splits, + FastDivmod.create(args.num_splits), + args.is_split_kv, + args.cluster_shape_mn, + ) - def __init__(self, blk_coord: cute.Coord, *, loc=None, ip=None): + def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None, ip=None): + self.params = params self._blk_coord = blk_coord self._is_first_block = True self._loc = loc @@ -81,7 +108,7 @@ def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) @staticmethod def create(params: Params, *, loc=None, ip=None) -> "SingleTileScheduler": blk_coord = cute.arch.block_idx() - return SingleTileScheduler(blk_coord, loc=loc, ip=ip) + return SingleTileScheduler(params, blk_coord, loc=loc, ip=ip) # called by host @staticmethod @@ -93,10 +120,18 @@ def get_grid_shape( ) -> Tuple[Int32, Int32, Int32]: # TODO: this hard-codes the fact that we only use cluster = (1, 1) or (2, 1) assert params.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported" - return cute.round_up(params.num_block, params.cluster_shape_mn[0]), params.num_head, params.num_batch + return cute.round_up(params.num_block, params.cluster_shape_mn[0]), params.num_head * params.num_splits, params.num_batch - def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: - return cutlass.utils.WorkTileInfo(self._blk_coord, self._is_first_block) + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + block_idx, head_idx, batch_idx = self._blk_coord + if const_expr(self.params.is_split_kv): + head_idx, split_idx = self.params.num_splits_divmod.divmod(head_idx) + else: + split_idx = Int32(0) + return WorkTileInfo( + (block_idx, head_idx, batch_idx, split_idx), + self._is_first_block, + ) def initial_work_tile_info(self, *, loc=None, ip=None): return self.get_current_work(loc=loc, ip=ip) @@ -109,7 +144,7 @@ def advance_to_next_work(self, *, loc=None, ip=None): def __extract_mlir_values__(self): values, self._values_pos = [], [] - for obj in [self._blk_coord]: + for obj in [self.params, self._blk_coord]: obj_values = cutlass.extract_mlir_values(obj) values += obj_values self._values_pos.append(len(obj_values)) @@ -117,7 +152,7 @@ def __extract_mlir_values__(self): def __new_from_mlir_values__(self, values): obj_list = [] - for obj, n_items in zip([self._blk_coord], self._values_pos): + for obj, n_items in zip([self.params, self._blk_coord], self._values_pos): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] return SingleTileScheduler(*(tuple(obj_list)), loc=self._loc) @@ -167,14 +202,14 @@ def get_grid_shape( return (cutlass.min(sm_count, params.total_blocks), Int32(1), Int32(1)) # @cute.jit - def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: hn_idx, block_idx = self.params.num_block_divmod.divmod(self._tile_idx) batch_idx, head_idx = self.params.num_head_divmod.divmod(hn_idx) is_valid = self._tile_idx < self.params.total_blocks # if cute.arch.thread_idx()[0] == 0: # cute.printf("TileScheduler: tile_idx=%d, hn_idx=%d, block_idx=%d, batch_idx=%d, head_idx=%d, is_valid=%d", self._tile_idx, hn_idx, block_idx, batch_idx, head_idx, is_valid) - return cutlass.utils.WorkTileInfo( - (Int32(block_idx), Int32(head_idx), Int32(batch_idx)), is_valid + return WorkTileInfo( + (Int32(block_idx), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid ) def initial_work_tile_info(self, *, loc=None, ip=None): @@ -206,12 +241,14 @@ class SingleTileLPTScheduler: @dataclass class Params(ParamsBase): total_blocks: Int32 + num_splits: Int32 num_block_divmod: FastDivmod num_head_divmod: FastDivmod l2_minor_divmod: FastDivmod l2_major_divmod: FastDivmod l2_minor_residual_divmod: FastDivmod num_hb_quotient: Int32 + is_split_kv: cutlass.Constexpr[bool] = False @staticmethod @cute.jit @@ -244,11 +281,14 @@ def create( max(num_hb_remainder, 1) ), # don't divide by 0 num_hb_quotient=Int32(num_hb_quotient), + num_splits=args.num_splits, + is_split_kv=args.is_split_kv, ) - def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None): + def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=None, ip=None): self.params = params self._tile_idx = tile_idx + self._split_idx = split_idx self._loc = loc self._ip = ip @@ -259,8 +299,8 @@ def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) @staticmethod @cute.jit def create(params: Params, *, loc=None, ip=None) -> "SingleTileLPTScheduler": - tile_idx = cute.arch.block_idx()[0] - return SingleTileLPTScheduler(params, tile_idx, loc=loc, ip=ip) + tile_idx, split_idx, _ = cute.arch.block_idx() + return SingleTileLPTScheduler(params, tile_idx, split_idx, loc=loc, ip=ip) # called by host @staticmethod @@ -270,10 +310,10 @@ def get_grid_shape( loc=None, ip=None, ) -> Tuple[Int32, Int32, Int32]: - return (params.total_blocks, Int32(1), Int32(1)) + return (params.total_blocks, params.num_splits, Int32(1)) @cute.jit - def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: params = self.params # Implement LPT scheduling coordinate calculation bidhb, l2_mod = params.l2_major_divmod.divmod(self._tile_idx) @@ -289,8 +329,8 @@ def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: # Longest-processing-time-first block = params.num_block_divmod.divisor - 1 - block is_valid = self._tile_idx < params.total_blocks - return cutlass.utils.WorkTileInfo( - (Int32(block), Int32(head_idx), Int32(batch_idx)), is_valid + return WorkTileInfo( + (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(self._split_idx)), is_valid ) def initial_work_tile_info(self, *, loc=None, ip=None): @@ -305,7 +345,7 @@ def advance_to_next_work(self, *, loc=None, ip=None): def __extract_mlir_values__(self): values, self._values_pos = [], [] - for obj in [self.params, self._tile_idx]: + for obj in [self.params, self._tile_idx, self._split_idx]: obj_values = cutlass.extract_mlir_values(obj) values += obj_values self._values_pos.append(len(obj_values)) @@ -313,7 +353,7 @@ def __extract_mlir_values__(self): def __new_from_mlir_values__(self, values): obj_list = [] - for obj, n_items in zip([self.params, self._tile_idx], self._values_pos): + for obj, n_items in zip([self.params, self._tile_idx, self._split_idx], self._values_pos): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] return self.__class__(*(tuple(obj_list)), loc=self._loc) @@ -397,8 +437,8 @@ def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: is_valid = self._tile_idx < params.total_blocks bidx_in_cluster = cute.arch.block_in_cluster_idx() block = block * params.cluster_shape_mn[0] + bidx_in_cluster[0] - return cutlass.utils.WorkTileInfo( - (Int32(block), Int32(head_idx), Int32(batch_idx)), is_valid + return WorkTileInfo( + (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid ) def initial_work_tile_info(self, *, loc=None, ip=None): @@ -433,12 +473,14 @@ class Params(ParamsBase): num_head: Int32 num_batch: Int32 total_q: Int32 + num_splits: Int32 max_kvblock_in_l2: Int32 tile_shape_mn: cutlass.Constexpr[Tuple[int, int]] mCuSeqlensQ: Optional[cute.Tensor] = None mSeqUsedQ: Optional[cute.Tensor] = None qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 lpt: cutlass.Constexpr[bool] = False + is_split_kv: cutlass.Constexpr[bool] = False @staticmethod @cute.jit @@ -454,17 +496,20 @@ def create( num_head=args.num_head, num_batch=args.num_batch, total_q=args.total_q, + num_splits=args.num_splits, max_kvblock_in_l2=max_kvblock_in_l2, tile_shape_mn=args.tile_shape_mn, mCuSeqlensQ=args.mCuSeqlensQ, mSeqUsedQ=args.mSeqUsedQ, qhead_per_kvhead_packgqa=args.qhead_per_kvhead_packgqa, lpt=args.lpt, + is_split_kv=args.is_split_kv, ) - def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None): + def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=None, ip=None): self.params = params self._tile_idx = tile_idx + self._split_idx = split_idx self._is_first_block = True self._loc = loc self._ip = ip @@ -475,8 +520,8 @@ def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) @staticmethod def create(params: Params, *, loc=None, ip=None) -> "SingleTileVarlenScheduler": - tile_idx = cute.arch.block_idx()[0] - return SingleTileVarlenScheduler(params, tile_idx, loc=loc, ip=ip) + tile_idx, split_idx, _ = cute.arch.block_idx() + return SingleTileVarlenScheduler(params, tile_idx, split_idx, loc=loc, ip=ip) # called by host @staticmethod @@ -489,7 +534,7 @@ def get_grid_shape( total_blocks_max = ( params.total_q + params.num_batch * (params.tile_shape_mn[0] - 1) ) // params.tile_shape_mn[0] - return (total_blocks_max * params.num_head, Int32(1), Int32(1)) + return (total_blocks_max * params.num_head, params.num_splits, Int32(1)) @cute.jit def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32: @@ -515,7 +560,7 @@ def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32: ) @cute.jit - def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: params = self.params lane_idx = cute.arch.lane_idx() num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=0) @@ -584,8 +629,9 @@ def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: block = mh_block - head_idx * num_m_blocks is_valid = self._is_first_block and batch_idx < params.num_batch # if cute.arch.thread_idx()[0] == 128: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, head_idx=%d, block=%d, is_valid = %d", self._tile_idx, batch_idx, head_idx, block, is_valid) - return cutlass.utils.WorkTileInfo( - (Int32(block), Int32(head_idx), Int32(batch_idx)), is_valid + split_idx = self._split_idx if const_expr(params.is_split_kv) else Int32(0) + return WorkTileInfo( + (Int32(block), Int32(head_idx), Int32(batch_idx), split_idx), is_valid ) def initial_work_tile_info(self, *, loc=None, ip=None): @@ -600,7 +646,7 @@ def advance_to_next_work(self, *, loc=None, ip=None): def __extract_mlir_values__(self): values, self._values_pos = [], [] - for obj in [self.params, self._tile_idx]: + for obj in [self.params, self._tile_idx, self._split_idx]: obj_values = cutlass.extract_mlir_values(obj) values += obj_values self._values_pos.append(len(obj_values)) @@ -608,7 +654,7 @@ def __extract_mlir_values__(self): def __new_from_mlir_values__(self, values): obj_list = [] - for obj, n_items in zip([self.params, self._tile_idx], self._values_pos, + for obj, n_items in zip([self.params, self._tile_idx, self._split_idx], self._values_pos, ): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 7dc132e4f7e..481e22f731b 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -2,6 +2,7 @@ import math import itertools +import os import pytest import torch @@ -27,20 +28,23 @@ ) +DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE" + + # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) -# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) -@pytest.mark.parametrize("mha_type", ["mha"]) -# @pytest.mark.parametrize("has_learnable_sink", [False, True]) -@pytest.mark.parametrize("has_learnable_sink", [False]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +@pytest.mark.parametrize("has_learnable_sink", [False, True]) +# @pytest.mark.parametrize("has_learnable_sink", [False]) # @pytest.mark.parametrize("has_qv", [False, True]) @pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [False]) # @pytest.mark.parametrize("softcap", [0.0, 15.0]) @pytest.mark.parametrize("softcap", [0.0]) -# @pytest.mark.parametrize("local", [False, True]) -@pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [True]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) @@ -222,8 +226,9 @@ def test_flash_attn_output( print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") # num_splits_vals = [1, 3] # pack_gqa_vals = [False, True, None] + # SplitKV is not supported for hdim >= 192 pack_gqa_vals = [False] - num_splits_vals = [1] + num_splits_vals = [1] # [1, 3] if d < 192 and not DISABLE_SPLIT else [1] for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): out, lse = flash_attn_func( q, @@ -237,7 +242,7 @@ def test_flash_attn_output( softcap=softcap, learnable_sink=learnable_sink, # pack_gqa=pack_gqa, - # num_splits=num_splits + num_splits=num_splits, ) print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") @@ -260,6 +265,7 @@ def test_flash_attn_output( and not local and dv == d and learnable_sink is None + and mha_type == "mha" # and False ): g = torch.randn_like(out) @@ -568,7 +574,8 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): pack_gqa_vals = [False, True, None] # num_splits_vals = [1, 3] - num_splits_vals = [1] + # SplitKV is not supported for hdim >= 192 + num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT else [1] for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): out_unpad, lse = flash_attn_varlen_func( q_unpad, @@ -587,6 +594,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # attention_chunk=attention_chunk, learnable_sink=learnable_sink, softcap=softcap, + num_splits=num_splits, pack_gqa=pack_gqa, ) out = output_pad_fn(out_unpad) @@ -1097,7 +1105,7 @@ def test_flash_attn_kvcache( k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone() v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone() # num_splits_vals = [1, 0] - num_splits_vals = [1] + num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT else [1] # precompute_metadata_vals = [False, True] precompute_metadata_vals = [False] for num_splits, precompute_metadata in itertools.product( From ad70a007e6287d4f7e766f94bcf2f9a813f20f6b Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Wed, 5 Nov 2025 08:45:59 -0800 Subject: [PATCH 213/258] [Cute] Extract block-sparse utilities from SM80/90 (#1984) - Create block_sparse_utils.py with SM80/90 block-sparse logic - Refactor flash_fwd.py to use extracted utilities - Clean up whitespace in block_sparsity.py This extracts the block-sparse consumer loop and related utilities from flash_fwd.py into a reusable module for SM80/90 architectures. --- flash_attn/cute/block_sparse_utils.py | 419 ++++++++++++++++++++++++++ flash_attn/cute/block_sparsity.py | 1 + flash_attn/cute/flash_fwd.py | 327 +++----------------- 3 files changed, 461 insertions(+), 286 deletions(-) create mode 100644 flash_attn/cute/block_sparse_utils.py diff --git a/flash_attn/cute/block_sparse_utils.py b/flash_attn/cute/block_sparse_utils.py new file mode 100644 index 00000000000..d1cb95e18ed --- /dev/null +++ b/flash_attn/cute/block_sparse_utils.py @@ -0,0 +1,419 @@ +""" +Block-sparse runtime utilities for CUTE DSL kernels. + +This module contains runtime execution functions for block-sparse attention kernels. +These utilities are used by CUTE DSL kernels to produce and consume block-sparse loads. +""" + +from typing import Callable +from functools import partial +import cutlass +import cutlass.cute as cute +from cutlass import const_expr + +# Import data structures from block_sparsity +from flash_attn.cute.block_sparsity import BlockSparseTensors + + +@cute.jit +def load_block_list( + block_indices: cute.Tensor, + block_count, + load_q_with_first: cutlass.Constexpr, + first_block_preloaded: cutlass.Constexpr, + kv_producer_state, + load_Q, + load_K, + load_V, + pipeline_k, + pipeline_v, + use_tma_q: cutlass.Constexpr, + tma_q_bytes: cutlass.Constexpr, + intra_wg_overlap: cutlass.Constexpr, +): + """Iterate over the sparse blocks and load K, V (and Q) into the pipeline. + for the intra_wg_overlap case, we overlap the loads of K and V. And this + means we need to pipeline the last V load from the partial block case, + with the loads for the full blocks. Set first_block_preloaded when the + caller has already issued the first K load for the list. + + Note: + we iterate along the block_n indices in reverse. + + Returns: + Updated kv_producer_state after processing the block list. + + """ + if block_count > 0: + if const_expr(not intra_wg_overlap): + # Peel first iteration: the first block may need to load Q alongside K, + # Parameters are already Constexpr, so no need to wrap in const_expr() + n_block_first = block_indices[block_count - 1] + extra_tx = tma_q_bytes if const_expr(load_q_with_first) and const_expr(use_tma_q) else 0 + pipeline_k.producer_acquire(kv_producer_state, extra_tx_count=extra_tx) + + if const_expr(load_q_with_first and use_tma_q): + load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) + + load_K(src_idx=n_block_first, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block_first, producer_state=kv_producer_state) + kv_producer_state.advance() + + for offset in cutlass.range(1, block_count): + n_block = block_indices[block_count - 1 - offset] + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block, producer_state=kv_producer_state) + kv_producer_state.advance() + else: + n_block_first = block_indices[block_count - 1] + if const_expr(not first_block_preloaded): + extra_tx = ( + tma_q_bytes if const_expr(load_q_with_first) and const_expr(use_tma_q) else 0 + ) + pipeline_k.producer_acquire(kv_producer_state, extra_tx_count=extra_tx) + + if const_expr(load_q_with_first and use_tma_q): + load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) + + load_K(src_idx=n_block_first, producer_state=kv_producer_state) + + for idx in cutlass.range(block_count - 1, unroll=1): + n_block_prev = block_indices[block_count - 1 - idx] + n_block = block_indices[block_count - 2 - idx] + kv_producer_state_prev = kv_producer_state.clone() + kv_producer_state.advance() + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state_prev) + load_V(src_idx=n_block_prev, producer_state=kv_producer_state_prev) + + return kv_producer_state + + +@cute.jit +def finish_overlap_v_load( + block_indices: cute.Tensor, + block_count, + load_V, + pipeline_v, + kv_producer_state, +): + """Load the final V block after overlapped K/V loads.""" + if block_count > 0: + n_block_last = block_indices[0] + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block_last, producer_state=kv_producer_state) + kv_producer_state.advance() + + return kv_producer_state + + +@cute.jit +def produce_block_sparse_loads( + blocksparse_tensors: BlockSparseTensors, + batch_idx, + head_idx, + m_block, + kv_producer_state, + load_Q, + load_K, + load_V, + pipeline_k, + pipeline_v, + use_tma_q: cutlass.Constexpr, + tma_q_bytes: cutlass.Constexpr, + intra_wg_overlap: cutlass.Constexpr, +): + """Iterate over the mask and full block lists for a single tile. + + The masked (partial) list may leave the last V load pending when intra-warp-group + overlap is enabled. The first full block must consume that pending V while + issuing its own K load on the next pipeline stage. + + In the intra-wg-overlap path, the last masked block leaves its V copy in flight + while we advance the producer state to start the next full K. Either the full list + overlaps that pending V load, or, if no full blocks exist, we explicitly drain it. + + """ + + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + + curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block] + curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None] + curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] + curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None] + + mask_empty = curr_mask_block_cnt == 0 + full_empty = curr_full_block_cnt == 0 + + if mask_empty: + # No masked blocks: the full list owns the initial Q+K load. + kv_producer_state = load_block_list( + curr_full_block_idx, + curr_full_block_cnt, + load_q_with_first=True, + first_block_preloaded=False, + kv_producer_state=kv_producer_state, + load_Q=load_Q, + load_K=load_K, + load_V=load_V, + pipeline_k=pipeline_k, + pipeline_v=pipeline_v, + use_tma_q=use_tma_q, + tma_q_bytes=tma_q_bytes, + intra_wg_overlap=intra_wg_overlap, + ) + + if const_expr(intra_wg_overlap) and curr_full_block_cnt > 0: + kv_producer_state = finish_overlap_v_load( + curr_full_block_idx, + curr_full_block_cnt, + load_V, + pipeline_v, + kv_producer_state, + ) + else: + # Masked blocks present: load Q together with the first masked K so consumers can + # start immediately. When overlap is disabled this fully drains the list. + kv_producer_state = load_block_list( + curr_mask_block_idx, + curr_mask_block_cnt, + load_q_with_first=True, + first_block_preloaded=False, + kv_producer_state=kv_producer_state, + load_Q=load_Q, + load_K=load_K, + load_V=load_V, + pipeline_k=pipeline_k, + pipeline_v=pipeline_v, + use_tma_q=use_tma_q, + tma_q_bytes=tma_q_bytes, + intra_wg_overlap=intra_wg_overlap, + ) + + if full_empty: + if const_expr(intra_wg_overlap): + kv_producer_state = finish_overlap_v_load( + curr_mask_block_idx, + curr_mask_block_cnt, + load_V, + pipeline_v, + kv_producer_state, + ) + else: + if const_expr(intra_wg_overlap): + # Bridge the masked list to the full list by overlapping the pending masked V + # with the first full K load. + n_block_mask_last = curr_mask_block_idx[0] + n_block_full_first = curr_full_block_idx[curr_full_block_cnt - 1] + kv_producer_state_prev = kv_producer_state.clone() + kv_producer_state.advance() + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block_full_first, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state_prev) + load_V(src_idx=n_block_mask_last, producer_state=kv_producer_state_prev) + + kv_producer_state = load_block_list( + curr_full_block_idx, + curr_full_block_cnt, + load_q_with_first=False, + first_block_preloaded=True, + kv_producer_state=kv_producer_state, + load_Q=load_Q, + load_K=load_K, + load_V=load_V, + pipeline_k=pipeline_k, + pipeline_v=pipeline_v, + use_tma_q=use_tma_q, + tma_q_bytes=tma_q_bytes, + intra_wg_overlap=intra_wg_overlap, + ) + + kv_producer_state = finish_overlap_v_load( + curr_full_block_idx, + curr_full_block_cnt, + load_V, + pipeline_v, + kv_producer_state, + ) + else: + # Non-overlap path with both lists: run the full list normally (skipping the Q + # reload because the masked list already issued it). + kv_producer_state = load_block_list( + curr_full_block_idx, + curr_full_block_cnt, + load_q_with_first=False, + first_block_preloaded=False, + kv_producer_state=kv_producer_state, + load_Q=load_Q, + load_K=load_K, + load_V=load_V, + pipeline_k=pipeline_k, + pipeline_v=pipeline_v, + use_tma_q=use_tma_q, + tma_q_bytes=tma_q_bytes, + intra_wg_overlap=intra_wg_overlap, + ) + + return kv_producer_state + + +@cute.jit +def consume_block_sparse_loads( + blocksparse_tensors: BlockSparseTensors, + batch_idx, + head_idx, + m_block, + kv_consumer_state, + mma_pv_fn, + mma_one_n_block, + process_first_half_block, + process_last_half_block, + mask_fn, + score_mod_fn, + O_should_accumulate, + mask_mod, + intra_wg_overlap: cutlass.Constexpr, + warp_scheduler_barrier_sync: Callable, + warp_scheduler_barrier_arrive: Callable, +): + """Consume the mask and full block lists for a single tile on the consumer side. + + Mirrors `produce_block_sparse_loads` so that the consumer pipeline + """ + + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + + curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block] + curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None] + curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] + curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None] + + processed_any = curr_mask_block_cnt + curr_full_block_cnt > 0 + + if const_expr(not intra_wg_overlap): + if curr_mask_block_cnt > 0: + mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1] + warp_scheduler_barrier_sync() + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=mask_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=True), + is_first_n_block=True, + ) + O_should_accumulate = True + for i in cutlass.range(1, curr_mask_block_cnt): + mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i] + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=mask_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=False), + is_first_n_block=False, + ) + O_should_accumulate = True + if curr_full_block_cnt == 0: + warp_scheduler_barrier_arrive() + + if curr_full_block_cnt > 0: + full_n_block = curr_full_block_idx[curr_full_block_cnt - 1] + if curr_mask_block_cnt == 0: + warp_scheduler_barrier_sync() + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=full_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_seqlen=True), + is_first_n_block=True, + ) + O_should_accumulate = True + for i in cutlass.range(1, curr_full_block_cnt): + full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i] + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=full_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_seqlen=False), + is_first_n_block=False, + ) + O_should_accumulate = True + else: + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=full_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True), + is_first_n_block=False, + ) + O_should_accumulate = True + for i in cutlass.range(1, curr_full_block_cnt): + full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i] + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=full_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False), + is_first_n_block=False, + ) + O_should_accumulate = True + warp_scheduler_barrier_arrive() + else: + if curr_mask_block_cnt > 0: + mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1] + kv_consumer_state = process_first_half_block( + n_block=mask_n_block, + kv_consumer_state=kv_consumer_state, + mask_fn=partial(mask_fn, mask_mod=mask_mod), + score_mod_fn=score_mod_fn, + is_first_block=True, + ) + for i in cutlass.range(1, curr_mask_block_cnt): + mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i] + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=mask_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=False), + ) + O_should_accumulate = True + + if curr_full_block_cnt > 0: + full_n_block = curr_full_block_idx[curr_full_block_cnt - 1] + if curr_mask_block_cnt == 0: + kv_consumer_state = process_first_half_block( + n_block=full_n_block, + kv_consumer_state=kv_consumer_state, + mask_fn=partial(mask_fn, mask_mod=None), + score_mod_fn=score_mod_fn, + is_first_block=True, + ) + else: + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=full_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True), + ) + O_should_accumulate = True + for i in cutlass.range(1, curr_full_block_cnt): + full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i] + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=full_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False), + ) + O_should_accumulate = True + + if curr_mask_block_cnt + curr_full_block_cnt > 0: + kv_consumer_state = process_last_half_block( + kv_consumer_state=kv_consumer_state, + zero_init=not O_should_accumulate, + ) + O_should_accumulate = True + + return kv_consumer_state, O_should_accumulate, processed_any diff --git a/flash_attn/cute/block_sparsity.py b/flash_attn/cute/block_sparsity.py index 1a243e74127..cefb48e7e24 100644 --- a/flash_attn/cute/block_sparsity.py +++ b/flash_attn/cute/block_sparsity.py @@ -13,6 +13,7 @@ import cutlass.cute as cute from cutlass.cute.runtime import from_dlpack + # placeholder Config = type("Config", (), {}) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index e7f93056fca..369bd1c81e6 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -30,6 +30,10 @@ from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.block_sparse_utils import ( + produce_block_sparse_loads, + consume_block_sparse_loads, +) from flash_attn.cute import pipeline from flash_attn.cute.pack_gqa import PackGQA from flash_attn.cute.named_barrier import NamedBarrierFwd @@ -1835,155 +1839,21 @@ def load( load_V(src_idx=n_block, producer_state=kv_producer_state) kv_producer_state.advance() else: - # ========================================== - # Flex Attention blocksparsity - # ========================================== - mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors - curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block] - curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None] - curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] - curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None] - - if const_expr(not self.intra_wg_overlap): - if curr_mask_block_cnt > 0: - # First mask block - load with Q - n_block_mask = curr_mask_block_idx[curr_mask_block_cnt - 1] - pipeline_k.producer_acquire( - kv_producer_state, - extra_tx_count=self.tma_copy_bytes["Q"] - if const_expr(self.use_tma_Q) - else 0, - ) - if const_expr(self.use_tma_Q): - load_Q( - tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state) - ) - load_K(src_idx=n_block_mask, producer_state=kv_producer_state) - pipeline_v.producer_acquire(kv_producer_state) - load_V(src_idx=n_block_mask, producer_state=kv_producer_state) - kv_producer_state.advance() - - # Remaining mask blocks - for i in cutlass.range(1, curr_mask_block_cnt): - n_block_mask = curr_mask_block_idx[curr_mask_block_cnt - 1 - i] - pipeline_k.producer_acquire(kv_producer_state) - load_K(src_idx=n_block_mask, producer_state=kv_producer_state) - pipeline_v.producer_acquire(kv_producer_state) - load_V(src_idx=n_block_mask, producer_state=kv_producer_state) - kv_producer_state.advance() - - if curr_full_block_cnt > 0: - n_block_full = curr_full_block_idx[curr_full_block_cnt - 1] - if curr_mask_block_cnt == 0: - # must load Q if not loaded in mask loop - pipeline_k.producer_acquire( - kv_producer_state, - extra_tx_count=self.tma_copy_bytes["Q"] - if const_expr(self.use_tma_Q) - else 0, - ) - if const_expr(self.use_tma_Q): - load_Q( - tma_bar_ptr=pipeline_k.producer_get_barrier( - kv_producer_state - ) - ) - load_K(src_idx=n_block_full, producer_state=kv_producer_state) - pipeline_v.producer_acquire(kv_producer_state) - load_V(src_idx=n_block_full, producer_state=kv_producer_state) - kv_producer_state.advance() - else: - pipeline_k.producer_acquire(kv_producer_state) - load_K(src_idx=n_block_full, producer_state=kv_producer_state) - pipeline_v.producer_acquire(kv_producer_state) - load_V(src_idx=n_block_full, producer_state=kv_producer_state) - kv_producer_state.advance() - for j in cutlass.range(1, curr_full_block_cnt): - n_block_full = curr_full_block_idx[curr_full_block_cnt - 1 - j] - pipeline_k.producer_acquire(kv_producer_state) - load_K(src_idx=n_block_full, producer_state=kv_producer_state) - pipeline_v.producer_acquire(kv_producer_state) - load_V(src_idx=n_block_full, producer_state=kv_producer_state) - kv_producer_state.advance() - - else: - # ========================================== - # Overlap path - # ========================================== - - # Load Q with the first K block (whether mask or full) - n_block_first = -1 - if curr_mask_block_cnt > 0: - n_block_first = curr_mask_block_idx[curr_mask_block_cnt - 1] - elif curr_full_block_cnt > 0: - n_block_first = curr_full_block_idx[curr_full_block_cnt - 1] - - if n_block_first >= 0: - pipeline_k.producer_acquire( - kv_producer_state, - extra_tx_count=self.tma_copy_bytes["Q"] - if const_expr(self.use_tma_Q) - else 0, - ) - if const_expr(self.use_tma_Q): - load_Q( - tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state) - ) - load_K(src_idx=n_block_first, producer_state=kv_producer_state) - - if curr_mask_block_cnt > 0: - # Staggered loading for remaining mask blocks - for i in cutlass.range(1, curr_mask_block_cnt): - n_block_mask_prev = curr_mask_block_idx[curr_mask_block_cnt - i] - n_block_mask = curr_mask_block_idx[curr_mask_block_cnt - 1 - i] - kv_producer_state_prev = kv_producer_state.clone() - kv_producer_state.advance() - pipeline_k.producer_acquire(kv_producer_state) - load_K(src_idx=n_block_mask, producer_state=kv_producer_state) - pipeline_v.producer_acquire(kv_producer_state_prev) - load_V( - src_idx=n_block_mask_prev, producer_state=kv_producer_state_prev - ) - - # Handle transition from mask to full blocks - if curr_full_block_cnt > 0: - # Load first full block K, last mask block V - n_block_mask_last = curr_mask_block_idx[0] - n_block_full = curr_full_block_idx[curr_full_block_cnt - 1] - kv_producer_state_prev = kv_producer_state.clone() - kv_producer_state.advance() - pipeline_k.producer_acquire(kv_producer_state) - load_K(src_idx=n_block_full, producer_state=kv_producer_state) - pipeline_v.producer_acquire(kv_producer_state_prev) - load_V( - src_idx=n_block_mask_last, producer_state=kv_producer_state_prev - ) - else: - # No full blocks, just load last mask block V - n_block_mask_last = curr_mask_block_idx[0] - pipeline_v.producer_acquire(kv_producer_state) - load_V(src_idx=n_block_mask_last, producer_state=kv_producer_state) - kv_producer_state.advance() - - if curr_full_block_cnt > 0: - # Staggered loading for remaining full blocks ( - for j in cutlass.range(1, curr_full_block_cnt): - n_block_full_prev = curr_full_block_idx[curr_full_block_cnt - j] - n_block_full = curr_full_block_idx[curr_full_block_cnt - 1 - j] - kv_producer_state_prev = kv_producer_state.clone() - kv_producer_state.advance() - pipeline_k.producer_acquire(kv_producer_state) - load_K(src_idx=n_block_full, producer_state=kv_producer_state) - pipeline_v.producer_acquire(kv_producer_state_prev) - load_V( - src_idx=n_block_full_prev, producer_state=kv_producer_state_prev - ) - - # Load last full block V - n_block_full_last = curr_full_block_idx[0] - pipeline_v.producer_acquire(kv_producer_state) - load_V(src_idx=n_block_full_last, producer_state=kv_producer_state) - kv_producer_state.advance() + kv_producer_state = produce_block_sparse_loads( + blocksparse_tensors, + batch_idx, + head_idx, + m_block, + kv_producer_state, + load_Q, + load_K, + load_V, + pipeline_k, + pipeline_v, + self.use_tma_Q, + self.tma_copy_bytes["Q"], + self.intra_wg_overlap, + ) tile_scheduler.prefetch_next_work() tile_scheduler.advance_to_next_work() @@ -2247,143 +2117,27 @@ def mma( # ========================================== # Block sparsity # ========================================== - mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors - curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block] - curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None] - curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] - curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None] - - # first masked and full blocks - mask_n_block = 0 - full_n_block = 0 - if curr_mask_block_cnt > 0: - mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1] - if curr_full_block_cnt > 0: - full_n_block = curr_full_block_idx[curr_full_block_cnt - 1] - - if const_expr(not self.intra_wg_overlap): - # ========================================== - # Non-overlap path - # ========================================== - if curr_mask_block_cnt > 0: - self.warp_scheduler_barrier_sync() - kv_consumer_state = mma_one_n_block( - kv_consumer_state, - n_block=mask_n_block, - mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), - mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True), - is_first_n_block=True, - ) - O_should_accumulate = True - for i in cutlass.range(1, curr_mask_block_cnt): - mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i] - kv_consumer_state = mma_one_n_block( - kv_consumer_state, - n_block=mask_n_block, - mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), - mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), - is_first_n_block=False, - ) - if curr_full_block_cnt == 0: - self.warp_scheduler_barrier_arrive() - - if curr_full_block_cnt > 0: - if curr_mask_block_cnt == 0: - self.warp_scheduler_barrier_sync() - kv_consumer_state = mma_one_n_block( - kv_consumer_state, - n_block=full_n_block, - mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), - mask_fn=partial(mask_fn, mask_seqlen=True), - is_first_n_block=True, - ) - O_should_accumulate = True - else: - kv_consumer_state = mma_one_n_block( - kv_consumer_state, - n_block=full_n_block, - mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), - mask_fn=partial(mask_fn, mask_seqlen=True), - is_first_n_block=False, - ) - O_should_accumulate = True - for i in cutlass.range(1, curr_full_block_cnt): - full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i] - kv_consumer_state = mma_one_n_block( - kv_consumer_state, - n_block=full_n_block, - mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), - mask_fn=partial(mask_fn, mask_seqlen=False), - is_first_n_block=False, - ) - self.warp_scheduler_barrier_arrive() - else: - # ========================================== - # Overlap path - # ========================================== - - # Process first block - if curr_mask_block_cnt > 0: - kv_consumer_state = process_first_half_block( - n_block=mask_n_block, - kv_consumer_state=kv_consumer_state, - mask_fn=partial(mask_fn, mask_mod=self.mask_mod), - score_mod_fn=score_mod_fn, - is_first_block=True, - ) - - # Process remaining mask blocks - for i in cutlass.range(1, curr_mask_block_cnt): - mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i] - kv_consumer_state = mma_one_n_block( - kv_consumer_state, - n_block=mask_n_block, - mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), - mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), - ) - O_should_accumulate = True - - # Process full blocks - if curr_full_block_cnt > 0: - # If no mask blocks, first full block is the overall first - if curr_mask_block_cnt == 0: - kv_consumer_state = process_first_half_block( - n_block=full_n_block, - kv_consumer_state=kv_consumer_state, - mask_fn=partial(mask_fn, mask_mod=None), - score_mod_fn=score_mod_fn, - is_first_block=True, - ) - - else: - kv_consumer_state = mma_one_n_block( - kv_consumer_state, - n_block=full_n_block, - mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), - mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True), - ) - O_should_accumulate = True - - # Process remaining full blocks - for i in cutlass.range(1, curr_full_block_cnt): - full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i] - kv_consumer_state = mma_one_n_block( - kv_consumer_state, - n_block=full_n_block, - mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), - mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False), - ) - O_should_accumulate = True - - # Final PV gemm for last block - if curr_mask_block_cnt > 0 or curr_full_block_cnt > 0: - kv_consumer_state = process_last_half_block( - kv_consumer_state=kv_consumer_state, - zero_init=not O_should_accumulate, - ) - O_should_accumulate = True - - if curr_mask_block_cnt + curr_full_block_cnt == 0: + kv_consumer_state, O_should_accumulate, processed_any = consume_block_sparse_loads( + blocksparse_tensors, + batch_idx, + head_idx, + m_block, + kv_consumer_state, + mma_pv_fn, + mma_one_n_block, + process_first_half_block, + process_last_half_block, + mask_fn, + score_mod_fn, + O_should_accumulate, + self.mask_mod, + self.intra_wg_overlap, + self.warp_scheduler_barrier_sync, + self.warp_scheduler_barrier_arrive, + ) + + # Handle empty case (when no blocks to process) + if not processed_any: softmax.reset() acc_O.fill(0.0) @@ -2426,6 +2180,7 @@ def mma( tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() + @cute.jit def first_half_block_overlap( self, From c8abdd432d3b020aad750f9f93f054cb438ec08a Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Sun, 9 Nov 2025 13:12:13 -0800 Subject: [PATCH 214/258] Enable python-3.10+ (#1998) --- .pre-commit-config.yaml | 1 - flash_attn/cute/pyproject.toml | 5 ++- flash_attn/cute/tile_scheduler.py | 64 +++++++++++++++++++++++-------- 3 files changed, 53 insertions(+), 17 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0bdc9b1b35b..67dcf8ba868 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,7 +19,6 @@ repos: interface| pack_gqa| testing| - tile_scheduler| utils )\.py$ - id: ruff-format diff --git a/flash_attn/cute/pyproject.toml b/flash_attn/cute/pyproject.toml index a5d829a908b..1b21df4b227 100644 --- a/flash_attn/cute/pyproject.toml +++ b/flash_attn/cute/pyproject.toml @@ -7,7 +7,7 @@ name = "flash-attn-cute" version = "0.1.0" description = "Flash Attention CUTE (CUDA Template Engine) implementation" readme = "README.md" -requires-python = ">=3.12" +requires-python = ">=3.10" license = {text = "BSD 3-Clause License"} authors = [ {name = "Tri Dao"}, @@ -16,6 +16,8 @@ classifiers = [ "Development Status :: 3 - Alpha", "License :: OSI Approved :: BSD License", "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", ] @@ -23,6 +25,7 @@ dependencies = [ "nvidia-cutlass-dsl==4.3.0.dev0", "torch", "einops", + "typing_extensions", ] [project.optional-dependencies] diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index 1ee11f6d11c..f3a06c186e7 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -2,7 +2,11 @@ from typing import Optional, Tuple from dataclasses import dataclass, fields -from typing import override + +try: + from typing import override +except ImportError: # Python < 3.12 + from typing_extensions import override import cutlass from cutlass._mlir import ir @@ -120,7 +124,11 @@ def get_grid_shape( ) -> Tuple[Int32, Int32, Int32]: # TODO: this hard-codes the fact that we only use cluster = (1, 1) or (2, 1) assert params.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported" - return cute.round_up(params.num_block, params.cluster_shape_mn[0]), params.num_head * params.num_splits, params.num_batch + return ( + cute.round_up(params.num_block, params.cluster_shape_mn[0]), + params.num_head * params.num_splits, + params.num_batch, + ) def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: block_idx, head_idx, batch_idx = self._blk_coord @@ -231,7 +239,10 @@ def __extract_mlir_values__(self): def __new_from_mlir_values__(self, values): obj_list = [] - for obj, n_items in zip([self.params, self._tile_idx], self._values_pos,): + for obj, n_items in zip( + [self.params, self._tile_idx], + self._values_pos, + ): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] return StaticPersistentTileScheduler(*(tuple(obj_list)), loc=self._loc) @@ -382,7 +393,9 @@ def create( num_hb_remainder = (args.num_head * args.num_batch) % swizzle num_block = cute.ceil_div(args.num_block, args.cluster_shape_mn[0]) return SingleTileLPTBwdScheduler.Params( - total_blocks=(num_block * args.cluster_shape_mn[0]) * args.num_head * args.num_batch, + total_blocks=(num_block * args.cluster_shape_mn[0]) + * args.num_head + * args.num_batch, num_head_divmod=FastDivmod.create(args.num_head), l2_minor_divmod=FastDivmod.create(swizzle), l2_major_divmod=FastDivmod.create(swizzle * num_block), @@ -437,9 +450,7 @@ def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: is_valid = self._tile_idx < params.total_blocks bidx_in_cluster = cute.arch.block_in_cluster_idx() block = block * params.cluster_shape_mn[0] + bidx_in_cluster[0] - return WorkTileInfo( - (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid - ) + return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid) def initial_work_tile_info(self, *, loc=None, ip=None): return self.get_current_work(loc=loc, ip=ip) @@ -488,7 +499,9 @@ def create( args: TileSchedulerArguments, *, loc=None, ip=None ) -> "SingleTileVarlenScheduler.Params": size_l2 = 50 * 1024 * 1024 # 50 MB for K & V - max_kvblock_in_l2 = size_l2 // ((args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1]) + max_kvblock_in_l2 = size_l2 // ( + (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1] + ) assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, ( "At least one of mCuSeqlensQ or mSeqUsedQ must be provided" ) @@ -610,16 +623,37 @@ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: # the seqlen can vary per batch. # TODO: is there any case where num_m_blocks is 0? # TODO: by right we should read the seqlen_kv but we're assuming seqlen_q == seqlen_k here - num_n_blocks = num_m_blocks * params.tile_shape_mn[0] // params.qhead_per_kvhead_packgqa // params.tile_shape_mn[1] + num_n_blocks = ( + num_m_blocks + * params.tile_shape_mn[0] + // params.qhead_per_kvhead_packgqa + // params.tile_shape_mn[1] + ) # nheads_in_l2 = min(max(self.max_kvblock_in_l2 // num_n_blocks, 1), self.num_head) # Seems faster to have this be a power of 2 - nheads_in_l2 = 16 if num_n_blocks * 16 <= params.max_kvblock_in_l2 else (8 if num_n_blocks * 8 <= params.max_kvblock_in_l2 else (4 if num_n_blocks * 4 <= params.max_kvblock_in_l2 else (2 if num_n_blocks * 2 <= params.max_kvblock_in_l2 else 1))) + nheads_in_l2 = ( + 16 + if num_n_blocks * 16 <= params.max_kvblock_in_l2 + else ( + 8 + if num_n_blocks * 8 <= params.max_kvblock_in_l2 + else ( + 4 + if num_n_blocks * 4 <= params.max_kvblock_in_l2 + else (2 if num_n_blocks * 2 <= params.max_kvblock_in_l2 else 1) + ) + ) + ) nheads_in_l2 = min(nheads_in_l2, params.num_head) mh_in_l2 = nheads_in_l2 * num_m_blocks section_idx = mh_block // mh_in_l2 l2_mod = mh_block - section_idx * mh_in_l2 # Deal with tail section - nheads_in_this_section = nheads_in_l2 if nheads_in_l2 * (section_idx + 1) <= params.num_head else params.num_head - section_idx * nheads_in_l2 + nheads_in_this_section = ( + nheads_in_l2 + if nheads_in_l2 * (section_idx + 1) <= params.num_head + else params.num_head - section_idx * nheads_in_l2 + ) block = l2_mod // nheads_in_this_section head_idx_residual = l2_mod - block * nheads_in_this_section head_idx = section_idx * nheads_in_l2 + head_idx_residual @@ -630,9 +664,7 @@ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: is_valid = self._is_first_block and batch_idx < params.num_batch # if cute.arch.thread_idx()[0] == 128: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, head_idx=%d, block=%d, is_valid = %d", self._tile_idx, batch_idx, head_idx, block, is_valid) split_idx = self._split_idx if const_expr(params.is_split_kv) else Int32(0) - return WorkTileInfo( - (Int32(block), Int32(head_idx), Int32(batch_idx), split_idx), is_valid - ) + return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), split_idx), is_valid) def initial_work_tile_info(self, *, loc=None, ip=None): return self.get_current_work(loc=loc, ip=ip) @@ -654,7 +686,9 @@ def __extract_mlir_values__(self): def __new_from_mlir_values__(self, values): obj_list = [] - for obj, n_items in zip([self.params, self._tile_idx, self._split_idx], self._values_pos, + for obj, n_items in zip( + [self.params, self._tile_idx, self._split_idx], + self._values_pos, ): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] From 2ef346bd74357adacbbfb4470d20e5768195e45b Mon Sep 17 00:00:00 2001 From: jayhshah Date: Tue, 11 Nov 2025 22:19:00 -0800 Subject: [PATCH 215/258] [Cute, Bwd, Sm100] Add GQA support (#2004) * add gqa for sm100 bwd * remove mha guard for test * change to cluster size 1 --- flash_attn/cute/flash_bwd_sm100.py | 220 +++++++++++++++++------------ flash_attn/cute/interface.py | 16 ++- tests/cute/test_flash_attn.py | 2 +- 3 files changed, 142 insertions(+), 96 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 5b85c691cd0..3b9aa00cb33 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -47,7 +47,6 @@ def __init__( deterministic: bool = False, cluster_size: int = 1, ): - assert qhead_per_kvhead == 1, "GQA is not supported yet in FlashAttentionBackwardSm100" # padding head_dim to a multiple of 16 as k_block_size hdim_multiple_of = 16 self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) @@ -163,13 +162,15 @@ def _setup_attributes(self): self.Q_stage = 2 self.dO_stage = 1 # LSE_stage = Q_stage and dPsum_stage = dO_stage - self.sdKVaccum_stage = 2 + # self.sdKVaccum_stage = 2 # number of tma reduce adds per dQacc mma self.dQ_reduce_ncol = 32 self.sdQaccum_stage = 64 // self.dQ_reduce_ncol assert self.tile_hdim % self.dQ_reduce_ncol == 0 self.dQaccum_reduce_stage = self.tile_hdim // self.dQ_reduce_ncol self.cluster_reduce_dQ = False and cute.size(self.cluster_shape_mn) > 1 + # number of tma reduce adds for dKacc and dVacc epilogue + self.dK_reduce_ncol = 32 def _get_tiled_mma(self): cta_group = tcgen05.CtaGroup.ONE @@ -314,15 +315,23 @@ def _setup_smem_layout(self): ) self.sdKV_epi_tile = ( self.tile_n, - 128 // (self.dk_dtype.width // 8), + 128 // (self.dk_dtype.width // 8), # 64 or 32 ) # subtiles mma_tiler_dsq[:2] = mma_tiler_pdo[:2] + self.num_epi_stages = (self.tile_hdim // 2) // self.sdKV_epi_tile[1] + self.sdKV_flat_epi_tile = self.tile_n * (self.tile_hdim // 2) // self.num_epi_stages + # TODO: dK and dV could have different shapes - self.sdKV_layout = sm100_utils_basic.make_smem_layout_epi( - self.dk_dtype, - LayoutEnum.ROW_MAJOR, - self.sdKV_epi_tile, - self.sdKVaccum_stage, - ) + if const_expr(self.qhead_per_kvhead == 1): + self.sdKV_layout = sm100_utils_basic.make_smem_layout_epi( + self.dk_dtype, + LayoutEnum.ROW_MAJOR, + self.sdKV_epi_tile, + 2, # num compute wgs + ) + else: + self.sdKV_layout = cute.make_layout( + (self.tile_n * self.dK_reduce_ncol, 2) + ) @cute.jit def __call__( @@ -380,14 +389,21 @@ def __call__( ] layout_transpose = [1, 3, 2, 0] # (b, s, n, h) --> (s, h, n, b) - mQ, mK, mV, mdO, mdK, mdV = [ - utils.select(t, mode=layout_transpose) for t in (mQ, mK, mV, mdO, mdK, mdV) + mQ, mK, mV, mdO = [ + utils.select(t, mode=layout_transpose) for t in (mQ, mK, mV, mdO) ] LSE_dPsum_dQaccum_transpose = [2, 1, 0] # (b, n, s) --> (s, n, b) mLSE, mdPsum, mdQaccum = [ utils.select(t, mode=LSE_dPsum_dQaccum_transpose) for t in (mLSE, mdPsum, mdQaccum) ] - dO_transpose = [1, 0, 2, 3] + if const_expr(self.qhead_per_kvhead == 1): + layout_dKV_transpose = layout_transpose + else: + layout_dKV_transpose = LSE_dPsum_dQaccum_transpose + mdK, mdV = [ + utils.select(t, mode=layout_dKV_transpose) for t in (mdK, mdV) + ] + dO_transpose = [1, 0, 2, 3] # (s, h, n, b) --> (h, s, n, h) mdO = utils.select(mdO, mode=dO_transpose) semaphore_transpose = [2, 3, 1, 0] # (b, n, block, stage) -> (block, stage, n, b) @@ -426,21 +442,18 @@ def __call__( self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) self.is_q_do_mcast = self.num_mcast_ctas_b > 1 - self.mdK_layout_enum = LayoutEnum.from_tensor(mdK) - self.mdV_layout_enum = LayoutEnum.from_tensor(mdV) - dK_major_mode = self.mdK_layout_enum.mma_major_mode() - dV_major_mode = self.mdV_layout_enum.mma_major_mode() - if const_expr(dK_major_mode != tcgen05.OperandMajorMode.K): - raise RuntimeError("The layout of mdK is wrong") - if const_expr(dV_major_mode != tcgen05.OperandMajorMode.K): - raise RuntimeError("The layout of mdV is wrong") - - if const_expr(self.use_tma_store): - if const_expr(self.dk_dtype.width == 32): - tma_copy_op_dKV = cpasync.CopyReduceBulkTensorTileS2GOp() - else: - tma_copy_op_dKV = cpasync.CopyBulkTensorTileS2GOp() - + if const_expr(self.qhead_per_kvhead == 1): + self.mdK_layout_enum = LayoutEnum.from_tensor(mdK) + self.mdV_layout_enum = LayoutEnum.from_tensor(mdV) + dK_major_mode = self.mdK_layout_enum.mma_major_mode() + dV_major_mode = self.mdV_layout_enum.mma_major_mode() + if const_expr(dK_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of mdK is wrong") + if const_expr(dV_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of mdV is wrong") + + if const_expr(self.use_tma_store and self.qhead_per_kvhead == 1): + tma_copy_op_dKV = cpasync.CopyBulkTensorTileS2GOp() tma_atom_dK, mdK_tma_tensor = cpasync.make_tiled_tma_atom( tma_copy_op_dKV, mdK, @@ -456,24 +469,28 @@ def __call__( 1, # no mcast ) else: - assert self.qhead_per_kvhead == 1, "Must use TMA reduce add for GQA" mdV_tma_tensor = mdV mdK_tma_tensor = mdK tma_atom_dV = None tma_atom_dK = None - thr_layout_r2s_dKV = cute.make_ordered_layout((self.tile_n, 1), order=(1, 0)) # 128 threads - val_layout_r2s_dKV = cute.make_ordered_layout( - (1, 128 // self.dk_dtype.width), order=(1, 0) - ) # 4 or 8 vals for 16 byte store - copy_atom_r2s_dKV = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), - self.dk_dtype, - num_bits_per_copy=128, - ) - tiled_copy_r2s_dKV = cute.make_tiled_copy_tv( - copy_atom_r2s_dKV, thr_layout_r2s_dKV, val_layout_r2s_dKV - ) + if const_expr(self.qhead_per_kvhead == 1): + thr_layout_r2s_dKV = cute.make_ordered_layout((128, 1), order=(1, 0)) # 128 threads + val_layout_r2s_dKV = cute.make_ordered_layout( + (1, 128 // self.dk_dtype.width), order=(1, 0) + ) # 4 or 8 vals for 16 byte store + copy_atom_r2s_dKV = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dk_dtype, + num_bits_per_copy=128, + ) + tiled_copy_r2s_dKV = cute.make_tiled_copy_tv( + copy_atom_r2s_dKV, thr_layout_r2s_dKV, val_layout_r2s_dKV + ) + else: + tiled_copy_r2s_dKV = copy_utils.tiled_copy_1d( + Float32, 128, num_copy_elems=128 // Float32.width + ) tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) tma_load_op_multicast = cpasync.CopyBulkTensorTileG2SMulticastOp(cta_group) @@ -533,6 +550,7 @@ def __call__( self.tma_copy_bytes["LSE"] = self.tile_m * Float32.width // 8 self.tma_copy_bytes["dPsum"] = self.tile_m * Float32.width // 8 self.tma_copy_bytes["dQ"] = self.tile_m * self.dQ_reduce_ncol * Float32.width // 8 + self.tma_copy_bytes["dKacc"] = self.tile_n * self.dK_reduce_ncol * Float32.width // 8 # TileScheduler = SingleTileScheduler if not self.is_causal else SingleTileLPTBwdScheduler TileScheduler = SingleTileScheduler @@ -708,7 +726,7 @@ def kernel( sdS_layout: cute.ComposedLayout, sKt_layout: cute.ComposedLayout, sdQaccum_layout: cute.Layout, - sdKV_layout: cute.ComposedLayout, + sdKV_layout: cute.ComposedLayout | cute.Layout, tP_layout: cute.ComposedLayout, tdS_layout: cute.ComposedLayout, tiled_mma_S: cute.TiledMma, @@ -871,12 +889,16 @@ def kernel( sdOt = cute.make_tensor(cute.recast_ptr(sdO.iterator, sdOt_layout.inner), sdOt_layout.outer) sLSE = storage.sLSE.get_tensor(sLSE_layout) sdPsum = storage.sdPsum.get_tensor(sdPsum_layout) - sdV = storage.sdO.get_tensor( - sdKV_layout.outer, swizzle=sdKV_layout.inner, dtype=self.dv_dtype - ) - sdK = storage.sQ.get_tensor( - sdKV_layout.outer, swizzle=sdKV_layout.inner, dtype=self.dk_dtype - ) + if const_expr(self.qhead_per_kvhead == 1): + sdV = storage.sdO.get_tensor( + sdKV_layout.outer, swizzle=sdKV_layout.inner, dtype=self.dv_dtype + ) + sdK = storage.sQ.get_tensor( + sdKV_layout.outer, swizzle=sdKV_layout.inner, dtype=self.dk_dtype + ) + else: + sdV = storage.sdO.get_tensor(sdKV_layout, dtype=self.dv_dtype) + sdK = storage.sQ.get_tensor(sdKV_layout, dtype=self.dk_dtype) assert cute.size_in_bytes(self.do_dtype, sdO_layout) >= cute.size_in_bytes( self.dv_dtype, sdKV_layout ), "Not enough space for sdV" @@ -1930,7 +1952,7 @@ def compute_loop( thr_copy_r2s_dKV, pipeline_dKV, consumer_state_dKV, - softmax_scale, + softmax_scale if const_expr(self.qhead_per_kvhead == 1) else None, int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id mdK_semaphore, ) @@ -2228,32 +2250,53 @@ def epilogue_dK_or_dV_tma( num_wg = num_compute_threads // 128 leader_warp = (cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4) == 0 - sdKV = sdKV[None, None, wg_idx] + if const_expr(self.qhead_per_kvhead == 1): + sdKV = sdKV[None, None, wg_idx] # (tile_n, 64) for bf16 + else: + sdKV = sdKV[None, wg_idx] # (tile_n * 32) for fp32 + + # (8, tile_n / 128, 64 / 8) = (8, 1, 8) or (4, tile_n * 32 / (128 * 4)) = (4, 8) + tdKVsdKV_r2s = thr_copy_r2s_dKV.partition_D(sdKV) head_idx_kv = head_idx // self.qhead_per_kvhead - mdKV_cur = mdKV[None, None, head_idx_kv, batch_idx] - - gdKV_p = cute.local_tile(mdKV_cur, (self.tile_m, self.tile_hdimv), (n_block, 0)) - gdKV = self.split_wg(gdKV_p, wg_idx, num_wg) - gdKV_epi = cute.local_tile(gdKV, self.sdKV_epi_tile, (0, None)) + if const_expr(self.qhead_per_kvhead == 1): + mdKV_cur = mdKV[None, None, head_idx_kv, batch_idx] # (seqlen, hdim) + gdKV_p = cute.local_tile( + mdKV_cur, (self.tile_n, self.tile_hdim), (n_block, 0) + ) # (tile_n, hdim) + gdKV = self.split_wg(gdKV_p, wg_idx, num_wg) # (tile_n, hdim / 2) + gdKV_epi = cute.local_tile( + gdKV, self.sdKV_epi_tile, (0, None) + ) # (tile_n, 64, epi_stage = (hdim / 2) / 64) + else: + mdKV_cur = mdKV[None, head_idx_kv, batch_idx] # (seqlen * hdim) + gdKV_p = cute.local_tile( + mdKV_cur, (self.tile_n * self.tile_hdim, ), (n_block, ) + ) # (tile_n * hdim) + gdKV = cute.logical_divide( + gdKV_p, (self.tile_n * self.tile_hdim // num_wg, ) + )[((None, wg_idx), )] # (tile_n * hdim / 2) + gdKV_epi = cute.flat_divide( + gdKV, (self.sdKV_flat_epi_tile, ) + ) # (tile_n * hdim / 2 / epi_stage, epi_stage) if const_expr(self.deterministic and self.qhead_per_kvhead > 1): mdKV_semaphore_cur = mdKV_semaphore[n_block, None, head_idx_kv, batch_idx] - # (TMA) and (TMA, EPI_STAGE) - tdKVsdKV, tdKVgdKV = cpasync.tma_partition( - tma_atom_dKV, - 0, # no multicast - cute.make_layout(1), - cute.group_modes(sdKV, 0, 2), - cute.group_modes(gdKV_epi, 0, 2), - ) - - assert len(tdKVsdKV.shape) == 1, "Wrong rank for SMEM fragment tdKVsdKV" - assert len(tdKVgdKV.shape) == 2, "Wrong rank for GMEM fragment tdKVgdKV" - - num_epi_stages = cute.size(tdKVgdKV.shape[1]) - assert num_epi_stages == 1 or num_epi_stages == 2, "Wrong number of epi stages" + if const_expr(self.qhead_per_kvhead == 1): + tdKVsdKV, tdKVgdKV = cpasync.tma_partition( + tma_atom_dKV, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sdKV, 0, 2), + cute.group_modes(gdKV_epi, 0, 2), + ) # (TMA) and (TMA, EPI_STAGE) + assert len(tdKVsdKV.shape) == 1, "Wrong rank for SMEM fragment tdKVsdKV" + assert len(tdKVgdKV.shape) == 2, "Wrong rank for GMEM fragment tdKVgdKV" + num_epi_stages = cute.size(tdKVgdKV.shape[1]) + assert num_epi_stages == self.num_epi_stages, "Epi stage calculation is wrong" + else: + num_epi_stages = self.num_epi_stages tmem_load_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32 @@ -2270,20 +2313,20 @@ def epilogue_dK_or_dV_tma( ) cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128) - for s in cutlass.range_constexpr(num_epi_stages): + for epi_stage in cutlass.range_constexpr(num_epi_stages): # TMEM -> RMEM -- setup thr_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tdKVtdKV).get_slice(tidx) tdKVtdKV_t2r_p = thr_copy_t2r.partition_S(tdKVtdKV) tdKVtdKV_t2r = self.split_wg(tdKVtdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0] if const_expr(num_epi_stages > 1): - tdKVtdKV_t2r = tdKVtdKV_t2r[None, s] + tdKVtdKV_t2r = tdKVtdKV_t2r[None, epi_stage] cdKV = cute.make_identity_tensor((self.tile_n, self.tile_hdim)) tdKVcdKV = thr_mma.partition_C(cdKV) tdKVcdKV_t2r_p = thr_copy_t2r.partition_D(tdKVcdKV) tdKVcdKV_t2r = self.split_wg(tdKVcdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0] if const_expr(num_epi_stages > 1): - tdKVcdKV_t2r = tdKVcdKV_t2r[None, s] + tdKVcdKV_t2r = tdKVcdKV_t2r[None, epi_stage] tdKVrdKV_t2r = cute.make_fragment(tdKVcdKV_t2r.shape, Float32) @@ -2301,30 +2344,11 @@ def epilogue_dK_or_dV_tma( tdKVrdKV_t2r[2 * i], tdKVrdKV_t2r[2 * i + 1] = utils.mul_packed_f32x2( (tdKVrdKV_t2r[2 * i], tdKVrdKV_t2r[2 * i + 1]), (scale, scale) ) - tdKVrdKV = cute.make_fragment(tdKVrdKV_t2r.shape, self.dv_dtype) + tdKVrdKV = cute.make_fragment(tdKVrdKV_t2r.shape, self.dv_dtype) # (32 columns) tdKVrdKV.store(tdKVrdKV_t2r.load().to(self.dv_dtype)) - # RMEM -> SMEM -- setup - tdKVcdKV_r2s_p = thr_copy_r2s_dKV.partition_S(cdKV) - tdKVcdKV_r2s = self.split_wg(tdKVcdKV_r2s_p, wg_idx, num_wg) - tdKVcdKV_r2s = cute.logical_divide( - tdKVcdKV_r2s, - ( - tdKVcdKV_r2s.shape[0], - tdKVcdKV_r2s.shape[1], - tdKVcdKV_r2s.shape[2] // num_epi_stages, - ), - )[((None, 0), (None, 0), (None, s))] - - tdKVrdKV_r2s = cute.make_tensor(tdKVrdKV.iterator, tdKVcdKV_r2s.shape) - - tdKVsdKV_r2s = thr_copy_r2s_dKV.partition_D(sdKV) - - assert cute.size(tdKVrdKV_r2s) == cute.size(tdKVsdKV_r2s), ( - "RMEM<->SMEM fragment size mismatch" - ) - # RMEM -> SMEM -- copy, fence and barrier + tdKVrdKV_r2s = cute.make_tensor(tdKVrdKV.iterator, tdKVsdKV_r2s.shape) cute.copy(thr_copy_r2s_dKV, tdKVrdKV_r2s, tdKVsdKV_r2s) cute.arch.fence_proxy( cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta @@ -2333,8 +2357,16 @@ def epilogue_dK_or_dV_tma( # SMEM -> GMEM if leader_warp: - cute.copy(tma_atom_dKV, tdKVsdKV, tdKVgdKV[None, s]) - if s < num_epi_stages - 1: + if const_expr(self.qhead_per_kvhead == 1): + cute.copy(tma_atom_dKV, tdKVsdKV, tdKVgdKV[None, epi_stage]) + else: + with cute.arch.elect_one(): + copy_utils.cpasync_reduce_bulk_add_f32( + sdKV.iterator, + gdKV_epi[None, epi_stage].iterator, + self.tma_copy_bytes["dKacc"], + ) + if const_expr(epi_stage < num_epi_stages - 1): cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=read_flag) cute.arch.barrier_arrive( diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 2158cb51933..ce32f567e97 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -562,11 +562,16 @@ def _flash_attn_bwd( AtomLayoutMSdP = 1 AtomLayoutNdKV = 2 AtomLayoutMdQ = 1 + cluster_size = 1 else: m_block_size = 128 n_block_size = 128 dQ_swapAB = False + dKV_swapAB = False AtomLayoutMdQ = 1 + AtomLayoutNdKV = 1 + # TODO: support cluster size 2 + cluster_size = 1 q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = [ maybe_contiguous(t) for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) @@ -637,6 +642,8 @@ def _flash_attn_bwd( qhead_per_kvhead = num_head // num_head_kv if pack_gqa is None: pack_gqa = qhead_per_kvhead > 1 + if compute_capability == 10: + pack_gqa = False # override for now device = q.device # TODO: check if this is the right rounding @@ -675,6 +682,9 @@ def _flash_attn_bwd( head_dim_v_rounded = (head_dim_v + 32 - 1) // 32 * 32 if cu_seqlens_k is None: seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size + num_n_blocks = seqlen_k_rounded // n_block_size + if cluster_size == 2 and num_n_blocks % cluster_size != 0: + seqlen_k_rounded = seqlen_k_rounded + n_block_size dk_accum = torch.zeros( batch_size, num_head_kv, @@ -693,6 +703,9 @@ def _flash_attn_bwd( total_k_rounded_padded = ( (total_k + cu_seqlens_k.shape[0] * n_block_size - 1) // n_block_size * n_block_size ) + num_n_blocks = total_k_rounded_padded // n_block_size + if cluster_size == 2 and num_n_blocks % cluster_size != 0: + total_k_rounded_padded = total_k_rounded_padded + n_block_size dk_accum = torch.zeros( num_head_kv, total_k_rounded_padded * head_dim_rounded, @@ -802,6 +815,7 @@ def _flash_attn_bwd( n_block_size, num_threads, pack_gqa, + cluster_size, ) num_threads = 384 if compile_key not in _flash_attn_bwd.compile_cache: @@ -854,7 +868,7 @@ def _flash_attn_bwd( qhead_per_kvhead=qhead_per_kvhead, # tile_m=m_block_size, # tile_n=n_block_size, - cluster_size=2, + cluster_size=cluster_size, # cluster_size=1, ) # TODO: check @can_implement diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 481e22f731b..6c264c30f55 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -265,7 +265,7 @@ def test_flash_attn_output( and not local and dv == d and learnable_sink is None - and mha_type == "mha" + # and mha_type == "mha" # and False ): g = torch.randn_like(out) From 13380067063e1861f6bd355efec2b8d369c01ecf Mon Sep 17 00:00:00 2001 From: jayhshah Date: Tue, 11 Nov 2025 23:04:25 -0800 Subject: [PATCH 216/258] [Cute,Fwd,Sm100] fix major regression with split kv (#2006) --- flash_attn/cute/flash_fwd_sm100.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 6e030b17615..c4a569fa0d1 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -1162,7 +1162,7 @@ def load( n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) - if n_block_min < n_block_max: + if const_expr(not self.is_split_kv) or n_block_min < n_block_max: load_Q(block=self.q_stage * m_block + 0, stage=0) # Q0 page_idx = ( mPageTable[batch_idx, n_block_max - 1] @@ -1255,7 +1255,7 @@ def mma( seqlen = SeqlenInfoCls(batch_idx) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) - if n_block_min < n_block_max: + if const_expr(not self.is_split_kv) or n_block_min < n_block_max: for stage in cutlass.range_constexpr(self.q_stage): # GEMM_QK00 (Q0 * K0 -> S0) or GEMM_QK01 (Q1 * K0 -> S1) # 1. wait for Q0 / Q1 @@ -1493,7 +1493,7 @@ def softmax_loop( seqlen = SeqlenInfoCls(batch_idx) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) - if n_block_min < n_block_max: + if const_expr(not self.is_split_kv) or n_block_min < n_block_max: mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) mask_fn = partial( mask.apply_mask_sm100, @@ -1807,7 +1807,7 @@ def correction_loop( # Default LSE to -inf for invalid split_idx tiles stats = [(0.0, -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None, True)] * self.q_stage - if n_block_min < n_block_max: + if const_expr(not self.is_split_kv) or n_block_min < n_block_max: # Ignore first signal from softmax as no correction is required cute.arch.mbarrier_wait( mbar_ptr + self.mbar_softmax_corr_full_offset + 0, softmax_corr_consumer_phase @@ -2132,7 +2132,7 @@ def epilogue_s2g( seqlen = SeqlenInfoCls(batch_idx) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) - if n_block_min < n_block_max: + if const_expr(not self.is_split_kv) or n_block_min < n_block_max: if const_expr(self.is_split_kv): mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx] else: From 16d78bb2e32fc805238b4eddc7085aa79c941ffe Mon Sep 17 00:00:00 2001 From: Reuben Stern <107093092+reubenconducts@users.noreply.github.com> Date: Wed, 12 Nov 2025 18:07:30 -0500 Subject: [PATCH 217/258] [CuTe DSL] Block sparsity computation kernel (#1983) * begin block sparsity computation kernel * block sparsity computation kernel and benchmark working * loop range_constexpr * add fast kernel * merge fast and regular kernel * use TensorSSA approach to mask mod * update with OOB check * tests and benchmarks for block sparsity working * remove extraneous files * Revert mask.py to previous state - removing unintended changes from block sparsity work * remove flex attn test stub * add sleeps to benchmark * correct block sparsity benchmark to use torch.compile * Restore missing mask definitions and fix benchmark window_size handling * move benchmarks into new directory * compute_block_sparsity docstring * streamline compute block sparsity benchmark script --- benchmarks/cute/benchmark_block_sparsity.py | 363 +++++++++++++++ .../cute/benchmark_mask_mod.py | 16 +- flash_attn/cute/compute_block_sparsity.py | 403 +++++++++++++++++ flash_attn/cute/interface.py | 2 + flash_attn/cute/mask_definitions.py | 50 +++ tests/cute/test_block_sparsity.py | 422 ++++++++++++++++++ 6 files changed, 1248 insertions(+), 8 deletions(-) create mode 100644 benchmarks/cute/benchmark_block_sparsity.py rename {flash_attn => benchmarks}/cute/benchmark_mask_mod.py (98%) create mode 100644 flash_attn/cute/compute_block_sparsity.py create mode 100644 tests/cute/test_block_sparsity.py diff --git a/benchmarks/cute/benchmark_block_sparsity.py b/benchmarks/cute/benchmark_block_sparsity.py new file mode 100644 index 00000000000..74f220e8795 --- /dev/null +++ b/benchmarks/cute/benchmark_block_sparsity.py @@ -0,0 +1,363 @@ +""" +Comparative benchmark: CuTe DSL vs Native PyTorch block sparsity computation. +""" + +import torch +from dataclasses import dataclass +from typing import Callable, Optional, List +from tabulate import tabulate +from tqdm import tqdm +import itertools + +from cutlass.cute.runtime import from_dlpack +from cutlass.cute.testing import benchmark as cute_benchmark +import cutlass.cute as cute +from flash_attn.cute.compute_block_sparsity import BlockSparsityKernel +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.mask_definitions import ( + get_mask_pair, + random_doc_id_tensor, + flex_document_mask, + cute_document_mask, +) + +from torch.nn.attention.flex_attention import create_block_mask +from triton.testing import do_bench + +# Configure torch.compile cache to prevent memory buildup +torch._dynamo.config.cache_size_limit = 1000 + + +@dataclass(frozen=True) +class BenchmarkConfig: + """Configuration for a benchmark run.""" + + batch_size: int + num_heads: int + seqlen_q: int + seqlen_k: int + mask_name: str + tile_m: int = 128 + tile_n: int = 128 + use_fast_sampling: bool = False + aux_tensors_cute: Optional[list] = None + + +@dataclass(frozen=True) +class BenchmarkResult: + """Result of a single benchmark run.""" + + config: BenchmarkConfig + cute_time_ms: Optional[float] + pytorch_time_ms: Optional[float] + error_message: Optional[str] = None + + +def benchmark_pytorch_block_sparsity( + config: BenchmarkConfig, + mask_fn: Callable, +) -> Optional[float]: + """ + Benchmark PyTorch block mask creation (compiled). + Returns: creation_time_ms + """ + device = "cuda" + + try: + cbm = torch.compile(create_block_mask) + + def run_benchmark(): + return cbm( + mask_fn, + config.batch_size, + config.num_heads, + config.seqlen_q, + config.seqlen_k, + device=device, + ) + + creation_time_ms = do_bench(run_benchmark, warmup=10, rep=100) + + return creation_time_ms + + except Exception as e: + print(f"PyTorch benchmark failed ({config.mask_name}): {e}") + import traceback + traceback.print_exc() + return None + + +def benchmark_cute_block_sparsity( + config: BenchmarkConfig, + mask_fn: Callable, +) -> Optional[float]: + """ + Benchmark CuTe block sparsity kernel. + Returns: creation_time_ms + """ + device = "cuda" + + try: + num_m_blocks = (config.seqlen_q + config.tile_m - 1) // config.tile_m + num_n_blocks = (config.seqlen_k + config.tile_n - 1) // config.tile_n + + mask_block_cnt = torch.zeros( + (config.batch_size, config.num_heads, num_m_blocks), device=device, dtype=torch.int32 + ) + mask_block_idx = torch.zeros( + (config.batch_size, config.num_heads, num_m_blocks, num_n_blocks), + device=device, + dtype=torch.int32, + ) + full_block_cnt = torch.zeros( + (config.batch_size, config.num_heads, num_m_blocks), device=device, dtype=torch.int32 + ) + full_block_idx = torch.zeros( + (config.batch_size, config.num_heads, num_m_blocks, num_n_blocks), + device=device, + dtype=torch.int32, + ) + + # Convert to CuTe tensors + mask_cnt_cute = from_dlpack(mask_block_cnt.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=2 + ) + mask_idx_cute = from_dlpack(mask_block_idx.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=3 + ) + full_cnt_cute = from_dlpack(full_block_cnt.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=2 + ) + full_idx_cute = from_dlpack(full_block_idx.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=3 + ) + + blocksparse_tensors = BlockSparseTensors( + mask_block_cnt=mask_cnt_cute, + mask_block_idx=mask_idx_cute, + full_block_cnt=full_cnt_cute, + full_block_idx=full_idx_cute, + ) + + # Create kernel + use_aux = config.aux_tensors_cute is not None and len(config.aux_tensors_cute) > 0 + kernel = BlockSparsityKernel( + mask_mod=mask_fn, + tile_mn=(config.tile_m, config.tile_n), + compute_full_blocks=True, + use_aux_tensors=use_aux, + use_fast_sampling=config.use_fast_sampling, + ) + + # Compile kernel + compiled_kernel = cute.compile( + kernel, + blocksparse_tensors, + config.seqlen_q, + config.seqlen_k, + config.aux_tensors_cute, + ) + + def generate_tensors(): + from cutlass.cute.testing import JitArguments + + return JitArguments( + blocksparse_tensors, config.seqlen_q, config.seqlen_k, config.aux_tensors_cute + ) + + creation_time_us = cute_benchmark( + compiled_kernel, + workspace_generator=generate_tensors, + warmup_iterations=10, + iterations=100, + ) + + torch.cuda.synchronize(device) + creation_time_ms = creation_time_us / 1000.0 + + return creation_time_ms + + except Exception as e: + print(f"CuTe benchmark failed: {e}") + return None + + +def run_benchmark( + config: BenchmarkConfig, + pytorch_mask_fn: Callable, + cute_mask_fn: Callable, +) -> BenchmarkResult: + """Run benchmarks for both implementations.""" + + print( + f"Benchmarking {config.mask_name} - B={config.batch_size}, H={config.num_heads}, " + f"M={config.seqlen_q}, N={config.seqlen_k}" + ) + + # Benchmark PyTorch + pytorch_time = benchmark_pytorch_block_sparsity(config, pytorch_mask_fn) + + # Benchmark CuTe + cute_time = benchmark_cute_block_sparsity(config, cute_mask_fn) + + return BenchmarkResult( + config=config, + cute_time_ms=cute_time, + pytorch_time_ms=pytorch_time, + ) + + +def generate_configs( + batch_sizes: List[int], + num_heads: List[int], + seqlens: List[int], + mask_names: List[str], +) -> List[BenchmarkConfig]: + """Generate all benchmark configurations.""" + configs = [] + for B, H, S, mask_name in itertools.product(batch_sizes, num_heads, seqlens, mask_names): + configs.append( + BenchmarkConfig( + batch_size=B, + num_heads=H, + seqlen_q=S, + seqlen_k=S, + mask_name=mask_name, + ) + ) + return configs + + +def print_results(results: List[BenchmarkResult]): + successful_results = [ + r for r in results if r.cute_time_ms is not None and r.pytorch_time_ms is not None + ] + + if not successful_results: + print("No successful benchmark results to display") + return + + headers = ["B", "H", "M", "N", "Mask Type", "CuTe Time (ms)", "PyTorch Time (ms)", "Speedup"] + + rows = [] + for result in successful_results: + speedup = result.pytorch_time_ms / result.cute_time_ms if result.cute_time_ms > 0 else 0 + + rows.append( + [ + result.config.batch_size, + result.config.num_heads, + result.config.seqlen_q, + result.config.seqlen_k, + result.config.mask_name, + f"{result.cute_time_ms:.4f}", + f"{result.pytorch_time_ms:.4f}", + f"{speedup:.2f}x", + ] + ) + + # Sort by batch, head, seqlen, then mask type + rows.sort(key=lambda x: (x[0], x[1], x[2], x[4])) + + print("\n" + "=" * 100) + print("CuTe DSL vs PyTorch Block Sparsity Benchmark Results") + print("=" * 100) + print(tabulate(rows, headers=headers, tablefmt="github")) + print("=" * 100) + + +def main(): + """Run the comparative benchmark.""" + + # Configuration + batch_sizes = [1, 4, 8] + num_heads = [8, 16] + seqlens = [1024, 2048, 4096, 8192] + mask_names = [ + "causal", + "sliding_window", + "prefix_lm", + "dilated_sliding_window", + "document", + ] + + device = "cuda" + max_seqlen = max(seqlens) + max_batch = max(batch_sizes) + max_heads = max(num_heads) + + # Create document IDs using the helper from mask_definitions + doc_ids = random_doc_id_tensor(max_heads, max_batch, max_seqlen, device=device) + doc_ids_cute = from_dlpack(doc_ids.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=2) + + # Generate base configurations + base_configs = generate_configs(batch_sizes, num_heads, seqlens, mask_names) + + # Update configs with aux tensors for document masking + configs = [] + for config in base_configs: + if config.mask_name == "document": + # Add aux tensors for document masking + configs.append( + BenchmarkConfig( + batch_size=config.batch_size, + num_heads=config.num_heads, + seqlen_q=config.seqlen_q, + seqlen_k=config.seqlen_k, + mask_name=config.mask_name, + tile_m=config.tile_m, + tile_n=config.tile_n, + use_fast_sampling=False, + aux_tensors_cute=[doc_ids_cute], + ) + ) + else: + configs.append(config) + + # Run benchmarks + results = [] + print(f"Running {len(configs)} benchmark configurations...") + for config in tqdm(configs, desc="Benchmarking"): + try: + # Get mask pair from mask_definitions + mask_kwargs = {} + if config.mask_name == "sliding_window": + mask_kwargs["window_size"] = 128 # Default window size + + cute_mask_fn, pytorch_mask_fn = get_mask_pair( + config.mask_name, + seqlen_q=config.seqlen_q, + seqlen_k=config.seqlen_k, + **mask_kwargs, + ) + + # For document masking, create wrapper that captures doc_ids + if config.mask_name == "document": + # PyTorch wrapper + def pytorch_mask_fn(b, h, q, kv): + return flex_document_mask(b, h, q, kv, doc_ids) + # CuTe wrapper - reuse cute_document_mask with aux_tensors + cute_mask_fn = cute_document_mask + + result = run_benchmark(config, pytorch_mask_fn, cute_mask_fn) + results.append(result) + + except Exception as e: + print(f"Failed to run config {config}: {e}") + results.append( + BenchmarkResult( + config=config, + cute_time_ms=None, + pytorch_time_ms=None, + error_message=str(e), + ) + ) + finally: + torch.cuda.empty_cache() + torch._dynamo.reset() + + print_results(results) + + +if __name__ == "__main__": + main() diff --git a/flash_attn/cute/benchmark_mask_mod.py b/benchmarks/cute/benchmark_mask_mod.py similarity index 98% rename from flash_attn/cute/benchmark_mask_mod.py rename to benchmarks/cute/benchmark_mask_mod.py index 88db8418abc..348d2ee485d 100644 --- a/flash_attn/cute/benchmark_mask_mod.py +++ b/benchmarks/cute/benchmark_mask_mod.py @@ -14,8 +14,8 @@ import numpy as np import torch -from flash_fwd import FlashAttentionForwardSm90 -from mask_definitions import ( +from flash_attn.cute.flash_fwd import FlashAttentionForwardSm90 +from flash_attn.cute.mask_definitions import ( get_mask_pair, random_doc_id_tensor, ) @@ -74,8 +74,8 @@ class BenchmarkConfig: mma_pv_is_rs: bool = True # Benchmark parameters - warmup_iters: int = 5 - benchmark_iters: int = 20 + warmup_iters: int = 10 + benchmark_iters: int = 25 verbose: bool = False seed: int = 42 @@ -649,16 +649,16 @@ def _print_results(self, results: Dict[str, Any]): dtype=torch.bfloat16, batch_size=B, # batch_size=1, - seqlen_q=16384 // B, + seqlen_q=8192, # seqlen_q=128, - seqlen_k=16384 // B, + seqlen_k=8192, # seqlen_k=192, use_varlen=False, - use_mask_mod=True, + use_mask_mod=False, mask_mod_name="causal", window_size=128, # Configurable window size for mask_mod use_learnable_sink=False, - causal=False, + causal=True, is_local=False, verbose=True, ) diff --git a/flash_attn/cute/compute_block_sparsity.py b/flash_attn/cute/compute_block_sparsity.py new file mode 100644 index 00000000000..bec6fe5701f --- /dev/null +++ b/flash_attn/cute/compute_block_sparsity.py @@ -0,0 +1,403 @@ +from functools import partial +import math +import operator +from typing import Callable, Optional, Tuple, Type + +import cuda.bindings.driver as cuda +import cutlass +from cutlass import Boolean, Constexpr, Float32, Int32, Int8, const_expr +import cutlass.cute as cute +from cutlass.cute.runtime import from_dlpack +import torch + +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar + + +class BlockSparsityKernel: + """Block sparsity kernel for FlexAttention. + + This kernel computes `mask_mod` for every token of each block + to determine if an n block is full, masked, or neither. + + Writes block counts and indices to a BlockSparseTensors object. + + When use_fast_sampling=True, uses 5-point sampling (4 corners + center) + which is much faster but only suitable for masks where this is sufficient. + """ + + def __init__( + self, + mask_mod: Callable, + tile_mn: Tuple[int, int], + compute_full_blocks: bool = True, + use_aux_tensors: bool = False, + use_fast_sampling: bool = False, + ): + self.mask_mod = mask_mod + self.tile_mn = tile_mn + self.compute_full_blocks = compute_full_blocks + self.use_aux_tensors = use_aux_tensors + self.use_fast_sampling = use_fast_sampling + + @cute.jit + def __call__( + self, + blocksparse_tensors: BlockSparseTensors, + seqlen_q: Int32, + seqlen_k: Int32, + aux_tensors: Optional[list] = None, + ): + self.mask_cnt, self.mask_idx, self.full_cnt, self.full_idx = blocksparse_tensors + self.seqlen_q = seqlen_q + self.seqlen_k = seqlen_k + + if const_expr(self.compute_full_blocks): + assert self.full_cnt is not None and self.full_idx is not None, ( + "full block tensors must be provided when computing full blocks" + ) + + batch_size, num_heads, num_m_blocks, num_n_blocks = list(self.mask_idx.shape) + grid = [num_m_blocks, num_heads, batch_size] + + # Fast sampling uses only 5 threads (4 corners + center), full sampling uses 1 thread per row + if const_expr(self.use_fast_sampling): + num_threads = 5 + self.num_warps = 1 + else: + num_threads = self.tile_mn[0] + self.num_warps = (num_threads + 32 - 1) // 32 + + self.kernel( + self.mask_cnt, + self.mask_idx, + self.full_cnt, + self.full_idx, + num_n_blocks, + seqlen_q, + seqlen_k, + aux_tensors, + ).launch(grid=grid, block=[num_threads, 1, 1]) + + @cute.kernel + def kernel( + self, + mask_cnt: cute.Tensor, + mask_idx: cute.Tensor, + full_cnt: cute.Tensor, + full_idx: cute.Tensor, + num_n_blocks: Int32, + seqlen_q: Int32, + seqlen_k: Int32, + aux_tensors: Optional[list] = None, + ): + # Store seqlens as instance variables for use in the kernel + self.seqlen_q = seqlen_q + self.seqlen_k = seqlen_k + tidx, _, _ = cute.arch.thread_idx() + warp_idx = cute.arch.warp_idx() + m_block, head_idx, batch_idx = cute.arch.block_idx() + + ssa = partial(scalar_to_ssa, dtype=Int32) + + @cute.struct + class SharedStorage: + reduction_buffer_smem: cute.struct.Align[ + cute.struct.MemRange[cutlass.Int8, 2 * self.num_warps], 1024 + ] + + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage, 16) + + reduction_buffer = storage.reduction_buffer_smem.get_tensor( + cute.make_layout((self.num_warps, 2)) + ) + + num_mask_blocks = Int32(0) + num_full_blocks = Int32(0) + + for n_block in cutlass.range(num_n_blocks, unroll_full=True): + m_base = m_block * self.tile_mn[0] + n_base = n_block * self.tile_mn[1] + + if const_expr(self.use_fast_sampling): + # Fast path: 5-point sampling (4 corners + center) + # Out-of-bounds indices are treated as masked (False) + thread_result = Boolean(False) + thread_is_valid = Boolean(False) + q_idx = Int32(0) + kv_idx = Int32(0) + + if tidx == 0: + # Top-left corner (0, 0) + q_idx = m_base + kv_idx = n_base + elif tidx == 1: + # Top-right corner + q_idx = m_base + kv_idx = n_base + self.tile_mn[1] - 1 + elif tidx == 2: + # Bottom-left corner + q_idx = m_base + self.tile_mn[0] - 1 + kv_idx = n_base + elif tidx == 3: + # Bottom-right corner + q_idx = m_base + self.tile_mn[0] - 1 + kv_idx = n_base + self.tile_mn[1] - 1 + elif tidx == 4: + # Center point + q_idx = m_base + self.tile_mn[0] // 2 + kv_idx = n_base + self.tile_mn[1] // 2 + + # Check bounds and determine if this thread has a valid index pair + if q_idx < self.seqlen_q and kv_idx < self.seqlen_k: + thread_is_valid = Boolean(True) + q_idx_ssa = ssa(q_idx) + kv_idx_ssa = ssa(kv_idx) + thread_result = ssa_to_scalar( + self.mask_mod( + ssa(batch_idx), ssa(head_idx), q_idx_ssa, kv_idx_ssa, aux_tensors + ) + ) + else: + thread_is_valid = Boolean(False) + + # Use vote_any_sync to see if any valid thread found unmasked or masked + # Only count results from threads that checked valid indices + has_unmasked = cute.arch.vote_any_sync(thread_result & thread_is_valid) + has_masked = cute.arch.vote_any_sync((Boolean(not thread_result)) & thread_is_valid) + + else: + # Full path: check all elements in the block + # Track if this thread's row has any masked or unmasked elements + thread_has_unmasked = Boolean(False) + thread_has_masked = Boolean(False) + thread_is_valid = Boolean(False) + + # Each thread handles 1 row + q_idx = m_base + tidx + kv_idx = Int32(0) + if tidx < self.tile_mn[0] and q_idx < self.seqlen_q: + thread_is_valid = Boolean(True) + q_idx_ssa = ssa(q_idx) + + # Loop over all columns in this row + for c in cutlass.range(self.tile_mn[1], unroll_full=True): + kv_idx = n_base + c + kv_idx_ssa = ssa(kv_idx) + + # Only check elements within valid sequence bounds + if kv_idx < self.seqlen_k: + # Direct scalar call + mask_val = ssa_to_scalar( + self.mask_mod( + ssa(batch_idx), + ssa(head_idx), + q_idx_ssa, + kv_idx_ssa, + aux_tensors, + ) + ) + + # Update tracking flags + if mask_val: + thread_has_unmasked = Boolean(True) + else: + thread_has_masked = Boolean(True) + + # Block-level reduction to combine results across all threads + # Only count votes from threads that checked valid indices + warp_has_unmasked_mask = cute.arch.vote_any_sync( + thread_has_unmasked & thread_is_valid + ) + warp_has_masked_mask = cute.arch.vote_any_sync(thread_has_masked & thread_is_valid) + + # lane 0 writes the ballot mask to shared memory + lane_id = tidx % 32 + if lane_id == 0: + # Store as Int8 + reduction_buffer[warp_idx, 0] = Int8(1) if warp_has_unmasked_mask else Int8(0) + reduction_buffer[warp_idx, 1] = Int8(1) if warp_has_masked_mask else Int8(0) + + cute.arch.sync_threads() + + # Thread 0 ORs all warp results together + has_unmasked = Boolean(False) + has_masked = Boolean(False) + if tidx == 0: + for w in cutlass.range(self.num_warps): + if reduction_buffer[w, 0]: + has_unmasked = Boolean(True) + if reduction_buffer[w, 1]: + has_masked = Boolean(True) + + # Only thread 0 updates the output arrays (common to both paths) + if tidx == 0: + # Block classification based on what we found: + # - If has_masked and has_unmasked: partial block (needs masking) + # - If only has_unmasked: full block (no masking needed) + # - If only has_masked: skip this block entirely + is_partial = Boolean(has_masked and has_unmasked) + is_full = Boolean(has_unmasked and (not has_masked)) + + if is_partial: + mask_idx[batch_idx, head_idx, m_block, num_mask_blocks] = n_block + num_mask_blocks += 1 + elif is_full and const_expr(self.compute_full_blocks): + full_idx[batch_idx, head_idx, m_block, num_full_blocks] = n_block + num_full_blocks += 1 + + # Only thread 0 writes back the counts + if tidx == 0: + mask_cnt[batch_idx, head_idx, m_block] = num_mask_blocks + if const_expr(self.compute_full_blocks): + full_cnt[batch_idx, head_idx, m_block] = num_full_blocks + + +def compute_block_sparsity( + tile_m, + tile_n, + batch_size, + num_heads, + seqlen_q, + seqlen_k, + mask_mod: Callable, + aux_tensors: Optional[list], # list[cute.Tensor] + device, + compute_full_blocks: bool = True, + use_fast_sampling: bool = False, +) -> Tuple[BlockSparseTensors, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: + """ + Computes block sparsity for a given `mask_mod`. + + Args: + tile_m: The tile size for the m dimension. + tile_n: The tile size for the n dimension. + batch_size: The batch size. + num_heads: The number of heads. + seqlen_q: The sequence length for the query. + seqlen_k: The sequence length for the key. + mask_mod: The `mask_mod` callable to use. + aux_tensors: A list of auxiliary tensors. + device: The device to use. + compute_full_blocks: Whether to compute full blocks. If False, only partially-masked blocks are computed. + use_fast_sampling: Whether to use 5-point sampling (4 corners + center). This is much faster, but only suitable for masks where this check is sufficient. + + Returns: + A tuple of `BlockSparseTensors` and the underlying torch tensors. + """ + num_m_blocks = (seqlen_q + tile_m - 1) // tile_m + num_n_blocks = (seqlen_k + tile_n - 1) // tile_n + + mask_block_cnt = torch.zeros( + (batch_size, num_heads, num_m_blocks), device=device, dtype=torch.int32 + ) + mask_block_idx = torch.zeros( + (batch_size, num_heads, num_m_blocks, num_n_blocks), device=device, dtype=torch.int32 + ) + full_block_cnt = torch.zeros( + (batch_size, num_heads, num_m_blocks), device=device, dtype=torch.int32 + ) + full_block_idx = torch.zeros( + (batch_size, num_heads, num_m_blocks, num_n_blocks), device=device, dtype=torch.int32 + ) + + # Convert to cute tensors + mask_cnt_cute = from_dlpack(mask_block_cnt.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=2 + ) + mask_idx_cute = from_dlpack(mask_block_idx.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=3 + ) + full_cnt_cute = from_dlpack(full_block_cnt.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=2 + ) + full_idx_cute = from_dlpack(full_block_idx.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=3 + ) + + blocksparse_tensors = BlockSparseTensors( + mask_block_cnt=mask_cnt_cute, + mask_block_idx=mask_idx_cute, + full_block_cnt=full_cnt_cute, + full_block_idx=full_idx_cute, + ) + + mask_mod_hash = hash_callable(mask_mod) + + compile_key = ( + tile_m, + tile_n, + mask_mod_hash, + compute_full_blocks, + aux_tensors is not None, + use_fast_sampling, + ) + if compile_key not in compute_block_sparsity.compile_cache: + kernel = BlockSparsityKernel( + mask_mod, + tile_mn=(tile_m, tile_n), + compute_full_blocks=True, + use_aux_tensors=aux_tensors is not None, + use_fast_sampling=use_fast_sampling, + ) + + compute_block_sparsity.compile_cache[compile_key] = cute.compile( + kernel, + blocksparse_tensors, + seqlen_q, + seqlen_k, + aux_tensors, + ) + + compute_block_sparsity.compile_cache[compile_key]( + blocksparse_tensors, + seqlen_q, + seqlen_k, + aux_tensors, + ) + + # Return both the BlockSparseTensors (cute) and the underlying torch tensors + return blocksparse_tensors, (full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx) + + +compute_block_sparsity.compile_cache = {} + + +def run(): + """Test the BlockSparsityKernel with a simple causal mask.""" + + print("Testing BlockSparsityKernel...") + + # Configuration + batch_size = 2 + num_heads = 2 + seqlen_q = 16384 + seqlen_k = 16384 + tile_m, tile_n = 128, 128 # Use very small tiles for initial testing + + # Define a simple causal mask function + @cute.jit + def causal_mask(batch_idx, head_idx, q_idx, kv_idx, aux_tensors): + """Simple causal mask: only attend to positions <= current position.""" + return q_idx >= kv_idx + + try: + compute_block_sparsity( + tile_m, + tile_n, + batch_size, + num_heads, + seqlen_q, + seqlen_k, + causal_mask, + None, + device="cuda", + ) + print("Kernel execution completed!") + except Exception as e: + print(f"Kernel execution failed: {e}") + + +if __name__ == "__main__": + run() diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index ce32f567e97..4989067b8c1 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -106,6 +106,8 @@ def _flash_attn_fwd( Args: ... score_mod: A callable that takes the attention scores and applies a modification. + mask_mod: A callable that takes token position information and selectively masks + block_sparse_tensors: A tuple of tensors used for block sparsity. return_lse: Whether to return the log softmax of the attention scores. If set to True will always calculate out: Optional pre-allocated output tensor. If None, will be allocated internally. lse: Optional pre-allocated log-sum-exp tensor. If None, will be allocated when needed. diff --git a/flash_attn/cute/mask_definitions.py b/flash_attn/cute/mask_definitions.py index 0bb0d56751a..bbf2d212c0c 100644 --- a/flash_attn/cute/mask_definitions.py +++ b/flash_attn/cute/mask_definitions.py @@ -153,6 +153,54 @@ def cute_mini_causal_mask( return m_mod >= n_mod +@cute.jit +def cute_prefix_lm_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + aux_tensors, +) -> cute.TensorSSA: + """Prefix LM mask: first 512 tokens attend bidirectionally, rest use causal masking.""" + prefix_size_ssa = utils.scalar_to_ssa(512, cutlass.Int32) + both_in_prefix = (m_idx < prefix_size_ssa) & (n_idx < prefix_size_ssa) + causal_part = m_idx >= n_idx + return both_in_prefix | causal_part + + +def flex_prefix_lm_mask(b, h, q_idx, kv_idx): + """Prefix LM mask: first 512 tokens attend bidirectionally, rest use causal masking.""" + prefix_size = 512 + both_in_prefix = (q_idx < prefix_size) & (kv_idx < prefix_size) + causal_part = q_idx >= kv_idx + return both_in_prefix | causal_part + + +@cute.jit +def cute_dilated_sliding_window_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + aux_tensors, +) -> cute.TensorSSA: + """Dilated sliding window: every other position in a 256-position window.""" + window_size_ssa = utils.scalar_to_ssa(256, cutlass.Int32) + dilation_ssa = utils.scalar_to_ssa(2, cutlass.Int32) + in_window = (m_idx >= n_idx) & (m_idx - n_idx < window_size_ssa) + dilated = ((m_idx - n_idx) % dilation_ssa) == utils.scalar_to_ssa(0, cutlass.Int32) + return in_window & dilated + + +def flex_dilated_sliding_window_mask(b, h, q_idx, kv_idx): + """Dilated sliding window: every other position in a 256-position window.""" + window_size = 256 + dilation = 2 + in_window = (q_idx >= kv_idx) & (q_idx - kv_idx < window_size) + dilated = ((q_idx - kv_idx) % dilation) == 0 + return in_window & dilated + + def random_doc_id_tensor(nheads, batch, seqlen_q, device="cpu"): doc_ids_tensor = torch.zeros(batch, nheads, seqlen_q, dtype=torch.int32, device=device) for b in range(batch): @@ -175,6 +223,8 @@ def random_doc_id_tensor(nheads, batch, seqlen_q, device="cpu"): STATIC_MASKS = { "block_diagonal": (cute_block_diagonal_mask, flex_block_diagonal_mask), "mini_causal": (cute_mini_causal_mask, flex_mini_causal_mask), + "prefix_lm": (cute_prefix_lm_mask, flex_prefix_lm_mask), + "dilated_sliding_window": (cute_dilated_sliding_window_mask, flex_dilated_sliding_window_mask), "document": (cute_document_mask, flex_document_mask), } diff --git a/tests/cute/test_block_sparsity.py b/tests/cute/test_block_sparsity.py new file mode 100644 index 00000000000..d1ac5318004 --- /dev/null +++ b/tests/cute/test_block_sparsity.py @@ -0,0 +1,422 @@ +"""Tests for block sparsity computation in flash attention.""" + +import pytest +import torch +from torch.nn.attention.flex_attention import create_block_mask + +from flash_attn.cute.mask_definitions import get_mask_pair +from flash_attn.cute.compute_block_sparsity import compute_block_sparsity + + +def _call_compute_block_sparsity( + batch_size, + nheads, + seqlen_q, + seqlen_k, + tile_m, + tile_n, + mask_name, + window_size=None, + aux_tensors=None, + use_fast_sampling=False, +): + """Call compute_block_sparsity and return torch tensors.""" + cute_mask, _ = get_mask_pair( + mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=window_size + ) + blocksparse_tensors, torch_tensors = compute_block_sparsity( + tile_m=tile_m, + tile_n=tile_n, + batch_size=batch_size, + num_heads=nheads, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + mask_mod=cute_mask, + aux_tensors=aux_tensors, + device="cuda", + use_fast_sampling=use_fast_sampling, + ) + full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx = torch_tensors + return full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx + + +def _compare_block_sparsity( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + batch_size, + nheads, +): + """Compare block sparsity against reference. Returns (all_match, error_msg).""" + if not isinstance(mask_block_cnt, torch.Tensor): + return False, f"mask_block_cnt is not a tensor: {type(mask_block_cnt)}" + + n_blocks_q = mask_block_cnt.shape[2] + mask_cnt_match = torch.all(mask_block_cnt == mask_block_cnt_ref).item() + full_cnt_match = torch.all(full_block_cnt == full_block_cnt_ref).item() + + if not mask_cnt_match or not full_cnt_match: + error_msg = [] + if not mask_cnt_match: + error_msg.append("Mask counts mismatch") + diff = (mask_block_cnt != mask_block_cnt_ref).nonzero(as_tuple=False) + if len(diff) > 0: + b, h, m = diff[0].tolist() + error_msg.append( + f" First mismatch at [{b},{h},{m}]: " + f"got {mask_block_cnt[b, h, m].item()}, " + f"expected {mask_block_cnt_ref[b, h, m].item()}" + ) + if not full_cnt_match: + error_msg.append("Full counts mismatch") + diff = (full_block_cnt != full_block_cnt_ref).nonzero(as_tuple=False) + if len(diff) > 0: + b, h, m = diff[0].tolist() + error_msg.append( + f" First mismatch at [{b},{h},{m}]: " + f"got {full_block_cnt[b, h, m].item()}, " + f"expected {full_block_cnt_ref[b, h, m].item()}" + ) + return False, "\n".join(error_msg) + + # Compare indices + for b in range(batch_size): + for h in range(nheads): + for m in range(n_blocks_q): + num_mask = mask_block_cnt[b, h, m].item() + num_full = full_block_cnt[b, h, m].item() + + if num_mask > 0: + mask_indices = mask_block_idx[b, h, m, :num_mask].sort()[0] + mask_indices_ref = mask_block_idx_ref[b, h, m, :num_mask].sort()[0] + if not (mask_indices == mask_indices_ref).all(): + return False, f"Mask indices mismatch at [{b},{h},{m}]" + + if num_full > 0: + full_indices = full_block_idx[b, h, m, :num_full].sort()[0] + full_indices_ref = full_block_idx_ref[b, h, m, :num_full].sort()[0] + if not (full_indices == full_indices_ref).all(): + return False, f"Full indices mismatch at [{b},{h},{m}]" + + return True, "" + + +# Test configurations +SEQLEN_PAIRS = [ + # Small aligned + (64, 64), + (128, 128), + (256, 256), + (512, 512), + # Rectangular + (128, 256), + (256, 128), + (512, 256), + (256, 512), + # Large aligned + (1024, 1024), + (2048, 2048), + (4096, 4096), + # Large unaligned + (1000, 1000), + (2000, 2000), + (4000, 4000), + # Edge cases with unaligned seqlens + (113, 203), + (127, 127), + (129, 129), + (255, 255), + (257, 257), + (1023, 1023), + (1025, 1025), + (2047, 2047), + (2049, 2049), +] +TILE_SIZES = [ + # Standard powers of 2 + (32, 32), + (64, 64), + (128, 128), + (256, 256), + # Rectangular + (32, 64), + (64, 32), + (64, 128), + (128, 64), + (128, 256), + (256, 128), + # Unusual sizes + (40, 40), + (48, 48), + (96, 96), + (112, 112), + (32, 128), + (128, 32), + (40, 96), + (96, 40), +] + + +@pytest.mark.parametrize("seqlen_q,seqlen_k", SEQLEN_PAIRS) +@pytest.mark.parametrize("tile_m,tile_n", TILE_SIZES) +@pytest.mark.parametrize("batch_size", [1, 2]) +@pytest.mark.parametrize("nheads", [1, 4]) +@pytest.mark.parametrize("mask_name", ["block_diagonal", "mini_causal"]) +def test_fixed_length_masks( + seqlen_q, seqlen_k, tile_m, tile_n, batch_size, nheads, mask_name +): + """Test fixed-length masks.""" + seqlen_unaligned = (seqlen_q % tile_m != 0) or (seqlen_k % tile_n != 0) + + full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx = ( + _call_compute_block_sparsity( + batch_size, + nheads, + seqlen_q, + seqlen_k, + tile_m, + tile_n, + mask_name, + ) + ) + + _, mask_mod_flex = get_mask_pair(mask_name) + block_mask = create_block_mask( + mask_mod_flex, + B=batch_size, + H=nheads, + Q_LEN=seqlen_q, + KV_LEN=seqlen_k, + device="cuda", + BLOCK_SIZE=(tile_m, tile_n), + ) + ( + _, + _, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + *_, + ) = block_mask.as_tuple() + + all_match, error_msg = _compare_block_sparsity( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + batch_size, + nheads, + ) + + if seqlen_unaligned and not all_match: + pytest.skip(f"Skipping at seqlen extreme: {error_msg}") + assert all_match, f"Mismatch: {error_msg}" + + +@pytest.mark.parametrize("seqlen_q,seqlen_k", SEQLEN_PAIRS) +@pytest.mark.parametrize( + "tile_m,tile_n", [(64, 64), (128, 128), (64, 128), (128, 64), (256, 256)] +) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("nheads", [1, 4]) +@pytest.mark.parametrize( + "mask_name,window_size", + [("causal", None), ("sliding_window", 64), ("sliding_window", 256)], +) +def test_parameterized_masks( + seqlen_q, seqlen_k, tile_m, tile_n, batch_size, nheads, mask_name, window_size +): + """Test parameterized masks.""" + if mask_name == "sliding_window" and seqlen_q > seqlen_k: + pytest.skip("Sliding window not supported for seqlen_q > seqlen_k") + + seqlen_unaligned = (seqlen_q % tile_m != 0) or (seqlen_k % tile_n != 0) + + full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx = ( + _call_compute_block_sparsity( + batch_size, + nheads, + seqlen_q, + seqlen_k, + tile_m, + tile_n, + mask_name, + window_size=window_size, + ) + ) + + _, mask_mod_flex = get_mask_pair( + mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=window_size + ) + block_mask = create_block_mask( + mask_mod_flex, + B=batch_size, + H=nheads, + Q_LEN=seqlen_q, + KV_LEN=seqlen_k, + device="cuda", + BLOCK_SIZE=(tile_m, tile_n), + ) + ( + _, + _, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + *_, + ) = block_mask.as_tuple() + + all_match, error_msg = _compare_block_sparsity( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + batch_size, + nheads, + ) + + if seqlen_unaligned and not all_match: + pytest.skip(f"Skipping at seqlen extreme: {error_msg}") + assert all_match, f"Mismatch: {error_msg}" + + +@pytest.mark.parametrize( + "seqlen_q,seqlen_k,tile_m,tile_n", + [ + (1, 1, 64, 64), + (63, 63, 64, 64), + (65, 65, 64, 64), + (129, 129, 128, 128), + (100, 200, 64, 128), + ], +) +def test_edge_cases(seqlen_q, seqlen_k, tile_m, tile_n): + """Test edge cases with unaligned dimensions.""" + batch_size, nheads = 1, 1 + seqlen_unaligned = (seqlen_q % tile_m != 0) or (seqlen_k % tile_n != 0) + + full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx = ( + _call_compute_block_sparsity( + batch_size, + nheads, + seqlen_q, + seqlen_k, + tile_m, + tile_n, + "causal", + ) + ) + + _, mask_mod_flex = get_mask_pair("causal", seqlen_q=seqlen_q, seqlen_k=seqlen_k) + block_mask = create_block_mask( + mask_mod_flex, + B=batch_size, + H=nheads, + Q_LEN=seqlen_q, + KV_LEN=seqlen_k, + device="cuda", + BLOCK_SIZE=(tile_m, tile_n), + ) + ( + _, + _, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + *_, + ) = block_mask.as_tuple() + + all_match, error_msg = _compare_block_sparsity( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + batch_size, + nheads, + ) + + if seqlen_unaligned and not all_match: + pytest.skip(f"Skipping at seqlen extreme: {error_msg}") + assert all_match, f"Mismatch: {error_msg}" + + +@pytest.mark.parametrize("seqlen_q,seqlen_k", SEQLEN_PAIRS) +@pytest.mark.parametrize( + "tile_m,tile_n", [(64, 64), (128, 128), (64, 128), (128, 64), (256, 256)] +) +@pytest.mark.parametrize("nheads", [1, 4]) +@pytest.mark.parametrize("mask_name", ["causal", "block_diagonal"]) +def test_fast_sampling(seqlen_q, seqlen_k, tile_m, tile_n, nheads, mask_name): + """Test fast sampling mode (5-point sampling).""" + batch_size = 1 + seqlen_unaligned = (seqlen_q % tile_m != 0) or (seqlen_k % tile_n != 0) + + full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx = ( + _call_compute_block_sparsity( + batch_size, + nheads, + seqlen_q, + seqlen_k, + tile_m, + tile_n, + mask_name, + use_fast_sampling=True, + ) + ) + + _, mask_mod_flex = get_mask_pair(mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k) + block_mask = create_block_mask( + mask_mod_flex, + B=batch_size, + H=nheads, + Q_LEN=seqlen_q, + KV_LEN=seqlen_k, + device="cuda", + BLOCK_SIZE=(tile_m, tile_n), + ) + ( + _, + _, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + *_, + ) = block_mask.as_tuple() + + all_match, error_msg = _compare_block_sparsity( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + batch_size, + nheads, + ) + + if seqlen_unaligned and not all_match: + pytest.skip(f"Skipping at seqlen extreme: {error_msg}") + assert all_match, f"Mismatch: {error_msg}" From fbf24f67cf7f6442c5cfb2c1057f4bfc57e72d89 Mon Sep 17 00:00:00 2001 From: Johnny Date: Thu, 13 Nov 2025 07:38:39 +0100 Subject: [PATCH 218/258] [NVIDIA] bump github actions (#1996) * Update GitHub Actions to use checkout@v5 and setup-python@v6; enhance compute capability support * revert changes * revert * Update publish.yml * Update publish.yml * Update publish.yml * Update publish.yml * cuda-toolkit@v0.2.29 --- .github/workflows/_build.yml | 4 ++-- .github/workflows/pre-commit.yaml | 4 ++-- .github/workflows/publish.yml | 15 +++++++++------ 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/.github/workflows/_build.yml b/.github/workflows/_build.yml index 3bbd5f0a4f5..8c529583c72 100644 --- a/.github/workflows/_build.yml +++ b/.github/workflows/_build.yml @@ -43,7 +43,7 @@ jobs: name: Build wheel (${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}) steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: ref: ${{ inputs.release-version }} submodules: recursive @@ -77,7 +77,7 @@ jobs: - name: Install CUDA ${{ inputs.cuda-version }} if: ${{ inputs.cuda-version != 'cpu' }} - uses: Jimver/cuda-toolkit@v0.2.27 + uses: Jimver/cuda-toolkit@v0.2.29 id: cuda-toolkit with: cuda: ${{ inputs.cuda-version }} diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml index 1613bb365bd..bc304a5641a 100644 --- a/.github/workflows/pre-commit.yaml +++ b/.github/workflows/pre-commit.yaml @@ -22,10 +22,10 @@ jobs: pre-commit: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.11' diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 26013ad5d67..47f374ade99 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -41,8 +41,8 @@ jobs: # Using ubuntu-22.04 instead of 24.04 for more compatibility (glibc). Ideally we'd use the # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. os: [ubuntu-22.04, ubuntu-22.04-arm] - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] - torch-version: ["2.4.0", "2.5.1", "2.6.0", "2.7.1", "2.8.0"] + python-version: ["3.10", "3.11", "3.12", "3.13"] + torch-version: ["2.5.1", "2.6.0", "2.7.1", "2.8.0", "2.9.1"] cuda-version: ["12.9.1"] # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. @@ -50,8 +50,11 @@ jobs: # when building without C++11 ABI and using it on nvcr images. cxx11_abi: ["FALSE", "TRUE"] include: - - torch-version: "2.9.0.dev20250904" - cuda-version: "13.0.0" + - torch-version: "2.9.1" + cuda-version: "13.0.2" + python-version: "3.14" + - torch-version: "2.10.0.dev20251108" + cuda-version: "13.0.2" exclude: # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix # Pytorch < 2.5 does not support Python 3.13 @@ -72,8 +75,8 @@ jobs: needs: [build_wheels] runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 + - uses: actions/checkout@v5 + - uses: actions/setup-python@v6 with: python-version: "3.10" - name: Install dependencies From 5d2cd3bcbaeff6fe1bfc5d0ff489451b0d4827a6 Mon Sep 17 00:00:00 2001 From: timmy-feng <70349932+timmy-feng@users.noreply.github.com> Date: Fri, 14 Nov 2025 08:43:37 -0800 Subject: [PATCH 219/258] [Cute,Fwd,Sm100] Support paged attention (#1999) * modal bench and correctness * implement for one thread per row * coalesced(?) gmem loads * use cp async * use 64 threads to load * fill in smem for V * pass tests * fixes * removed extra files * handle V loading for n_block < 0 --- flash_attn/cute/flash_fwd_sm100.py | 246 +++++++++++++++++++---------- flash_attn/cute/interface.py | 5 +- flash_attn/cute/mask.py | 10 ++ flash_attn/cute/paged_kv.py | 176 +++++++++++++++++++++ tests/cute/test_flash_attn.py | 6 +- 5 files changed, 354 insertions(+), 89 deletions(-) create mode 100644 flash_attn/cute/paged_kv.py diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index c4a569fa0d1..915315d461b 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -27,6 +27,7 @@ import cutlass.cute.nvgpu.tcgen05 as tcgen05 import cutlass.utils.blackwell_helpers as sm100_utils_basic +from flash_attn.cute.paged_kv import PagedKVManager import flash_attn.cute.utils as utils from flash_attn.cute import copy_utils import flash_attn.cute.pipeline as pipeline @@ -76,7 +77,9 @@ def __init__( is_persistent: bool = True, score_mod: cutlass.Constexpr | None = None, has_aux_tensors: cutlass.Constexpr = False, + paged_kv_non_tma: bool = False, ): + self.use_tma_KV = not paged_kv_non_tma # self.dtype = dtype # padding head_dim to a multiple of 16 as k_block_size hdim_multiple_of = 16 @@ -127,11 +130,15 @@ def __init__( if self.overlap_sO_sQ: self.is_persistent = False + assert self.use_tma_KV or not (self.check_hdim_oob or self.check_hdim_v_oob), ( + "Paged KV does not support irregular head dim" + ) + self.softmax0_warp_ids = (0, 1, 2, 3) self.softmax1_warp_ids = (4, 5, 6, 7) self.correction_warp_ids = (8, 9, 10, 11) self.mma_warp_id = 12 - self.load_warp_id = 13 + self.load_warp_ids = (13,) self.epilogue_warp_ids = (14,) self.empty_warp_ids = (15,) SM100_TMEM_CAPACITY_COLUMNS = 512 @@ -143,7 +150,7 @@ def __init__( *self.softmax1_warp_ids, *self.correction_warp_ids, self.mma_warp_id, - self.load_warp_id, + *self.load_warp_ids, *self.epilogue_warp_ids, *self.empty_warp_ids, ) @@ -449,11 +456,20 @@ def __call__( mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed) ) + self.tma_copy_bytes = { + name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2])) + for name, mX, layout in [ + ("Q", mQ, sQ_layout), + ("K", mK, sK_layout), + ("V", mV, sV_layout), + ] + } + # TMA load for Q tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) tma_store_op = cpasync.CopyBulkTensorTileS2GOp() - tma_atom_Q, tma_tensor_Q = cute.nvgpu.make_tiled_tma_atom_A( + tma_atom_Q, mQ = cute.nvgpu.make_tiled_tma_atom_A( tma_load_op, mQ, cute.select(sQ_layout, mode=[0, 1, 2]), @@ -462,24 +478,32 @@ def __call__( self.cluster_layout_vmnk.shape, ) - # TMA load for K - tma_atom_K, tma_tensor_K = cute.nvgpu.make_tiled_tma_atom_B( - tma_load_op, - mK, - cute.select(sK_layout, mode=[0, 1, 2]), - self.mma_tiler_qk, - tiled_mma_qk, - self.cluster_layout_vmnk.shape, - ) - # TMA load for V - tma_atom_V, tma_tensor_V = cute.nvgpu.make_tiled_tma_atom_B( - tma_load_op, - mV, - cute.select(sV_layout, mode=[0, 1, 2]), - self.mma_tiler_pv, - tiled_mma_pv, - self.cluster_layout_vmnk.shape, - ) + if const_expr(self.use_tma_KV): + # TMA load for K + tma_atom_K, mK = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + mK, + cute.select(sK_layout, mode=[0, 1, 2]), + self.mma_tiler_qk, + tiled_mma_qk, + self.cluster_layout_vmnk.shape, + ) + # TMA load for V + tma_atom_V, mV = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + mV, + cute.select(sV_layout, mode=[0, 1, 2]), + self.mma_tiler_pv, + tiled_mma_pv, + self.cluster_layout_vmnk.shape, + ) + else: + assert self.use_tma_O, "Loading O and K/V will contend for the empty warp." + self.epilogue_warp_ids = (13,) + self.load_warp_ids = (14, 15) + self.empty_warp_ids = () + tma_atom_K = None + tma_atom_V = None o_cta_v_layout = cute.composition(cute.make_identity_layout(mO.shape), self.epi_tile) @@ -514,15 +538,7 @@ def __call__( assert self.m_block_size % tO_layout.shape[0] == 0 vO_layout = cute.make_layout((1, async_copy_elems)) gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout) - - self.tma_copy_bytes = { - name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2])) - for name, mX, layout in [ - ("Q", mQ, sQ_layout), - ("K", mK, sK_layout), - ("V", mV, sV_layout), - ] - } + print("gmem_tiled_copy_O: ", gmem_tiled_copy_O) if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): TileScheduler = SingleTileVarlenScheduler @@ -638,9 +654,9 @@ class SharedStorage: # Launch the kernel synchronously self.kernel( - tma_tensor_Q, - tma_tensor_K, - tma_tensor_V, + mQ, + mK, + mV, mO, mLSE, mCuSeqlensQ, @@ -693,8 +709,8 @@ def kernel( mSeqUsedK: Optional[cute.Tensor], mPageTable: Optional[cute.Tensor], tma_atom_Q: cute.CopyAtom, - tma_atom_K: cute.CopyAtom, - tma_atom_V: cute.CopyAtom, + tma_atom_K: Optional[cute.CopyAtom], + tma_atom_V: Optional[cute.CopyAtom], tma_atom_O: Optional[cute.CopyAtom], softmax_scale_log2: Float32, softmax_scale: Float32 | None, @@ -733,8 +749,10 @@ def kernel( # Prefetch tma descriptor if warp_idx == 0: cpasync.prefetch_descriptor(tma_atom_Q) - cpasync.prefetch_descriptor(tma_atom_K) - cpasync.prefetch_descriptor(tma_atom_V) + if const_expr(tma_atom_K is not None): + cpasync.prefetch_descriptor(tma_atom_K) + if const_expr(tma_atom_V is not None): + cpasync.prefetch_descriptor(tma_atom_V) if const_expr(tma_atom_O is not None): cpasync.prefetch_descriptor(tma_atom_O) @@ -748,7 +766,7 @@ def kernel( # Init "full" barrier with number of producers, "empty" barrier with number of consumers for i in cutlass.range_constexpr(self.q_stage): cute.arch.mbarrier_init( - mbar_ptr + self.mbar_load_q_full_offset + i, len([self.load_warp_id]) + mbar_ptr + self.mbar_load_q_full_offset + i, 1 ) cute.arch.mbarrier_init( mbar_ptr + self.mbar_load_q_empty_offset + i, len([self.mma_warp_id]) @@ -902,7 +920,7 @@ def kernel( # /////////////////////////////////////////////////////////////////////////////// # LOAD # /////////////////////////////////////////////////////////////////////////////// - if warp_idx == self.load_warp_id: + if warp_idx >= self.load_warp_ids[0] and warp_idx <= self.load_warp_ids[-1]: cute.arch.warpgroup_reg_dealloc(self.num_regs_other) self.load( thr_mma_qk, @@ -1070,8 +1088,8 @@ def load( sV: cute.Tensor, mPageTable: Optional[cute.Tensor], tma_atom_Q: cute.CopyAtom, - tma_atom_K: cute.CopyAtom, - tma_atom_V: cute.CopyAtom, + tma_atom_K: Optional[cute.CopyAtom], + tma_atom_V: Optional[cute.CopyAtom], pipeline_kv: cutlass.pipeline.PipelineAsync, mbar_ptr: cute.Pointer, block_info: BlockInfo, @@ -1079,6 +1097,8 @@ def load( SeqlenInfoCls: Callable, TileSchedulerCls: Callable, ): + num_load_threads = len(self.load_warp_ids) * cute.arch.WARP_SIZE + tidx = cute.arch.thread_idx()[0] % num_load_threads q_producer_phase = Int32(1) kv_producer_state = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Producer, self.kv_stage @@ -1117,20 +1137,43 @@ def load( load_Q_fn, _, _ = copy_utils.tma_get_copy_fn( tma_atom_Q, 0, cute.make_layout(1), tSgQ, sQ ) - tKsK, tKgK = cpasync.tma_partition( - tma_atom_K, - 0, # no multicast - cute.make_layout(1), - cute.group_modes(sK, 0, 3), - cute.group_modes(tSgK, 0, 3), - ) - tVsV, tVgV = cpasync.tma_partition( - tma_atom_V, - 0, # no multicast - cute.make_layout(1), - cute.group_modes(sV, 0, 3), - cute.group_modes(tOgV, 0, 3), - ) + + if const_expr(self.use_tma_KV): + tKsK, tKgK = cpasync.tma_partition( + tma_atom_K, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK, 0, 3), + ) + tVsV, tVgV = cpasync.tma_partition( + tma_atom_V, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sV, 0, 3), + cute.group_modes(tOgV, 0, 3), + ) + paged_kv_manager = None + else: + page_size = mK.shape[0] + paged_kv_manager = PagedKVManager.create( + mPageTable, + mK, + mV, + FastDivmod.create(page_size), + batch_idx, + head_idx_kv, + tidx, + seqlen.seqlen_k, + 0, # leftpad_k + self.n_block_size, + self.head_dim_padded, + self.head_dim_v_padded, + num_load_threads, + mK.element_type, + ) + tKsK, tKgK = None, None + tVsV, tVgV = None, None load_Q = partial( self.load_Q, @@ -1146,6 +1189,8 @@ def load( tma_atom_K, tKgK, tKsK, + paged_kv_manager, + sK, mbar_ptr + self.mbar_load_kv_full_offset, mbar_ptr + self.mbar_load_kv_empty_offset, K_or_V="K", @@ -1155,6 +1200,8 @@ def load( tma_atom_V, tVgV, tVsV, + paged_kv_manager, + sV, mbar_ptr + self.mbar_load_kv_full_offset, mbar_ptr + self.mbar_load_kv_empty_offset, K_or_V="V", @@ -1163,15 +1210,19 @@ def load( n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) if const_expr(not self.is_split_kv) or n_block_min < n_block_max: - load_Q(block=self.q_stage * m_block + 0, stage=0) # Q0 + if const_expr(self.use_tma_KV) or tidx < cute.arch.WARP_SIZE: + load_Q(block=self.q_stage * m_block + 0, stage=0) # Q0 + n_block_first = n_block_max - 1 if n_block_max > 0 else 0 page_idx = ( - mPageTable[batch_idx, n_block_max - 1] - if const_expr(mPageTable is not None) + mPageTable[batch_idx, n_block_first] + if const_expr(mPageTable is not None and self.use_tma_KV) else None ) + if const_expr(not self.use_tma_KV): + paged_kv_manager.load_page_table(n_block_first) load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # K0 kv_producer_state.advance() - if const_expr(self.q_stage == 2): + if const_expr(self.q_stage == 2) and (const_expr(self.use_tma_KV) or tidx < cute.arch.WARP_SIZE): load_Q(block=self.q_stage * m_block + 1, stage=1) # Q1 q_producer_phase ^= 1 load_V(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # V0 @@ -1179,8 +1230,12 @@ def load( for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): n_block = n_block_max - 2 - i page_idx = ( - mPageTable[batch_idx, n_block] if const_expr(mPageTable is not None) else None + mPageTable[batch_idx, n_block] + if const_expr(mPageTable is not None and self.use_tma_KV) + else None ) + if const_expr(not self.use_tma_KV): + paged_kv_manager.load_page_table(n_block) # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("n_block = {}, page_idx = {}", n_block, page_idx) load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Ki kv_producer_state.advance() @@ -2235,9 +2290,11 @@ def load_Q( @cute.jit def load_KV( self, - tma_atom: cute.CopyAtom, - tXgX: cute.Tensor, - tXsX: cute.Tensor, + tma_atom: Optional[cute.CopyAtom], + tXgX: Optional[cute.Tensor], + tXsX: Optional[cute.Tensor], + paged_kv_manager: Optional[PagedKVManager], + sX: cute.Tensor, mbar_full_ptr: cute.Pointer, mbar_empty_ptr: cute.Pointer, block: Int32, @@ -2253,17 +2310,29 @@ def load_KV( # K. So we need to wait for the stage after that (stage 1) to be empty as well. if stage == 0: cute.arch.mbarrier_wait(mbar_empty_ptr + 1, phase) - with cute.arch.elect_one(): - cute.arch.mbarrier_arrive_and_expect_tx( - mbar_full_ptr + stage, self.tma_copy_bytes[K_or_V] + + if const_expr(self.use_tma_KV): + assert ( + tXgX is not None and + tXsX is not None and + tma_atom is not None ) - tXsX_cur = tXsX[None, stage] - if const_expr(self.uneven_kv_smem): - # Since this is the producer_state, the phase starts at 1, so we have to invert it - tXsX_cur = self.offset_kv_smem(tXsX_cur, stage, phase ^ 1) - # Currently we assume that page_size == n_block_size so we index into tXgX with block = 0 - tXgX_cur = tXgX[None, block] if const_expr(page_idx is None) else tXgX[None, 0, page_idx] - cute.copy(tma_atom, tXgX_cur, tXsX_cur, tma_bar_ptr=mbar_full_ptr + stage) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx( + mbar_full_ptr + stage, self.tma_copy_bytes[K_or_V], + ) + tXsX_cur = tXsX[None, stage] + if const_expr(self.uneven_kv_smem): + # Since this is the producer_state, the phase starts at 1, so we have to invert it + tXsX_cur = self.offset_kv_smem(tXsX_cur, stage, phase ^ 1) + # Currently we assume that page_size == n_block_size so we index into tXgX with block = 0 + tXgX_cur = tXgX[None, block] if const_expr(page_idx is None) else tXgX[None, 0, page_idx] + cute.copy(tma_atom, tXgX_cur, tXsX_cur, tma_bar_ptr=mbar_full_ptr + stage) + else: + assert paged_kv_manager is not None + paged_kv_manager.load_KV(block, sX[None, None, None, stage], K_or_V) + cute.arch.cp_async_commit_group() + cute.arch.cp_async_mbarrier_arrive_noinc(mbar_full_ptr + stage) @cute.jit def offset_kv_smem(self, sX: cute.Tensor, stage: Int32, phase: Int32): @@ -2277,19 +2346,30 @@ def offset_kv_smem(self, sX: cute.Tensor, stage: Int32, phase: Int32): return sX def make_and_init_load_kv_pipeline(self, load_kv_mbar_ptr): - load_kv_producer_group = cutlass.pipeline.CooperativeGroup( - cutlass.pipeline.Agent.Thread, len([self.load_warp_id]) - ) load_kv_consumer_group = cutlass.pipeline.CooperativeGroup( cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) ) - return cutlass.pipeline.PipelineTmaUmma.create( - barrier_storage=load_kv_mbar_ptr, - num_stages=self.kv_stage, - producer_group=load_kv_producer_group, - consumer_group=load_kv_consumer_group, - tx_count=self.tma_copy_bytes["K"], - ) + if self.use_tma_KV: + load_kv_producer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len(self.load_warp_ids) + ) + return cutlass.pipeline.PipelineTmaUmma.create( + barrier_storage=load_kv_mbar_ptr, + num_stages=self.kv_stage, + producer_group=load_kv_producer_group, + consumer_group=load_kv_consumer_group, + tx_count=self.tma_copy_bytes["K"], + ) + else: + load_kv_producer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len(self.load_warp_ids) * cute.arch.WARP_SIZE + ) + return cutlass.pipeline.PipelineAsyncUmma.create( + num_stages=self.kv_stage, + producer_group=load_kv_producer_group, + consumer_group=load_kv_consumer_group, + barrier_storage=load_kv_mbar_ptr, + ) # @cute.jit # def warp_scheduler_barrier_init(self): diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 4989067b8c1..fb36bfd492b 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -413,6 +413,7 @@ def _flash_attn_fwd( is_split_kv, pack_gqa, compute_capability, + page_size not in [None, 128], # paged KV non-TMA ) if compile_key not in _flash_attn_fwd.compile_cache: @@ -441,9 +442,6 @@ def _flash_attn_fwd( has_aux_tensors=aux_tensors is not None, ) elif compute_capability == 10: - assert page_size in [None, 128], ( - "Only page_size=128 is supported for paged KV on SM 10.0" - ) if sparse_tensors is not None: raise NotImplementedError("BlockSparsity not yet supported on SM 10.0") fa_fwd = FlashAttentionForwardSm100( @@ -461,6 +459,7 @@ def _flash_attn_fwd( and not is_split_kv, score_mod=score_mod, has_aux_tensors=aux_tensors is not None, + paged_kv_non_tma=page_size not in [None, 128], ) else: raise ValueError( diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index 6f92d0835ac..aa18566cb23 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -106,6 +106,11 @@ def apply_mask( ROW = 0 if const_expr(not self.swap_AB) else 1 COL = 1 if const_expr(not self.swap_AB) else 0 thr_col_offset = tScS_mn[0][COL] + # To handle edge cases of completely masked out rows where n_block_max = 0, + # we treat negative n_blocks as 0th n_block + # TODO: find more transparent solution + if n_block < 0: + n_block = 0 seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset if const_expr(not mask_causal and not mask_local and mask_mod is None): if const_expr(mask_seqlen): @@ -299,6 +304,11 @@ def apply_mask_sm100( cS = cute.make_identity_tensor(acc_shape if not self.swap_AB else acc_shape[::-1]) tScS = thr_mma.partition_C(cS) tScS_t2r = thr_tmem_load.partition_D(tScS) + # To handle edge cases of completely masked out rows where n_block_max = 0, + # we treat negative n_blocks as 0th n_block + # TODO: find more transparent solution + if n_block < 0: + n_block = 0 seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n r2p = True if const_expr(not mask_causal and not mask_local): diff --git a/flash_attn/cute/paged_kv.py b/flash_attn/cute/paged_kv.py new file mode 100644 index 00000000000..ccb2296b4a7 --- /dev/null +++ b/flash_attn/cute/paged_kv.py @@ -0,0 +1,176 @@ +from typing import Type +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync +from cutlass import Int32, const_expr + +from flash_attn.cute import utils +from flash_attn.cute.fast_math import FastDivmod +from flash_attn.cute.cute_dsl_utils import ParamsBase + + +@dataclass +class PagedKVManager(ParamsBase): + mPageTable: cute.Tensor + mK_paged: cute.Tensor + mV_paged: cute.Tensor + thread_idx: Int32 + + page_size_divmod: FastDivmod + seqlen_k: Int32 + leftpad_k: Int32 + n_block_size: Int32 + num_threads: cutlass.Constexpr[Int32] + head_dim_padded: cutlass.Constexpr[Int32] + head_dim_v_padded: cutlass.Constexpr[Int32] + + gmem_threads_per_row: cutlass.Constexpr[Int32] + page_entry_per_thread: Int32 + async_copy_elems: Int32 + + gmem_tiled_copy_KV: cute.TiledCopy + gmem_thr_copy_KV: cute.TiledCopy + tPrPage: cute.Tensor + tPrPageOffset: cute.Tensor + tKpK: cute.Tensor + tVpV: cute.Tensor + + @staticmethod + def create( + mPageTable: cute.Tensor, + mK_paged: cute.Tensor, + mV_paged: cute.Tensor, + page_size_divmod: FastDivmod, + bidb: Int32, + bidh: Int32, + thread_idx: Int32, + seqlen_k: Int32, + leftpad_k: Int32, + n_block_size: cutlass.Constexpr[Int32], + head_dim_padded: cutlass.Constexpr[Int32], + head_dim_v_padded: cutlass.Constexpr[Int32], + num_threads: cutlass.Constexpr[Int32], + dtype: Type[cutlass.Numeric], + ): + universal_copy_bits = 128 + gmem_threads_per_row = 8 # 8 threads loading 128 bits = 128 bytes = 1 cache line + async_copy_elems = universal_copy_bits // dtype.width + atom_async_copy = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + dtype, + num_bits_per_copy=universal_copy_bits, + ) + thr_layout = cute.make_ordered_layout( + (num_threads // gmem_threads_per_row, gmem_threads_per_row), + order=(1, 0), + ) + val_layout = cute.make_layout((1, async_copy_elems)) + gmem_tiled_copy_KV = cute.make_tiled_copy_tv(atom_async_copy, thr_layout, val_layout) + gmem_thr_copy_KV = gmem_tiled_copy_KV.get_slice(thread_idx) + page_entry_per_thread = n_block_size * gmem_threads_per_row // num_threads + + tPrPage = cute.make_rmem_tensor((page_entry_per_thread,), Int32) + tPrPageOffset = cute.make_rmem_tensor((page_entry_per_thread,), Int32) + + mPageTable = mPageTable[bidb, None] + mK_paged = mK_paged[None, None, bidh, None] + mV_paged = mV_paged[None, None, bidh, None] + + cK = cute.make_identity_tensor((n_block_size, head_dim_padded)) + tKcK = gmem_thr_copy_KV.partition_S(cK) + tKpK = utils.predicate_k(tKcK, limit=mK_paged.shape[1]) + + if const_expr(head_dim_padded == head_dim_v_padded): + tVpV = tKpK + else: + cV = cute.make_identity_tensor((n_block_size, head_dim_v_padded)) + tVcV = gmem_thr_copy_KV.partition_S(cV) + tVpV = utils.predicate_k(tVcV, limit=mV_paged.shape[0]) + + return PagedKVManager( + mPageTable, + mK_paged, + mV_paged, + thread_idx, + page_size_divmod, + seqlen_k, + leftpad_k, + n_block_size, + num_threads, + head_dim_padded, + head_dim_v_padded, + gmem_threads_per_row, + page_entry_per_thread, + async_copy_elems, + gmem_tiled_copy_KV, + gmem_thr_copy_KV, + tPrPage, + tPrPageOffset, + tKpK, + tVpV, + ) + + @cute.jit + def load_page_table(self, n_block: Int32): + for i in cutlass.range(self.page_entry_per_thread, unroll=1): + row = (i * self.num_threads + self.thread_idx) // self.gmem_threads_per_row + row_idx = n_block * self.n_block_size + row + + page_idx, page_offset = self.page_size_divmod.divmod(row_idx + self.leftpad_k) + + is_valid = ( + (i + 1) * self.num_threads <= self.n_block_size or row < self.n_block_size + ) and row_idx < self.seqlen_k + page = self.mPageTable[page_idx] if is_valid else 0 + + self.tPrPage[i] = page + self.tPrPageOffset[i] = page_offset + + @cute.jit + def load_KV(self, n_block: Int32, sX: cute.Tensor, K_or_V: str): + assert K_or_V in ("K", "V") + + # Finesse sX layout to be (M, N). + sX_pi = cute.make_tensor( + sX.iterator, + cute.make_layout( + (sX.shape[0][0], (sX.shape[0][1], sX.shape[2])), + stride=(sX.stride[0][0], (sX.stride[0][1], sX.stride[2])), + ), + ) + + if const_expr(K_or_V == "V"): + # Need to transpose V + sX_pi = cute.make_tensor(sX_pi.iterator, cute.select(sX_pi.layout, mode=[1, 0])) + + head_dim = self.head_dim_v_padded if const_expr(K_or_V == "V") else self.head_dim_padded + cX = cute.make_identity_tensor((self.n_block_size, head_dim)) + tXsX = self.gmem_thr_copy_KV.partition_D(sX_pi) + tXcX = self.gmem_thr_copy_KV.partition_S(cX) + + seqlenk_row_limit = self.seqlen_k - n_block * self.n_block_size if n_block >= 0 else 0 + for m in cutlass.range(cute.size(tXsX, mode=[1]), unroll=1): + should_load = tXcX[0, m, 0][0] < seqlenk_row_limit + + page = self.tPrPage[m] + page_offset = self.tPrPageOffset[m] + mX_paged_cur = ( + self.mK_paged[page_offset, None, page] + if const_expr(K_or_V == "K") + else self.mV_paged[None, page_offset, page] + ) + mX_paged_cur_copy = cute.tiled_divide(mX_paged_cur, (self.async_copy_elems,)) + + if should_load: + for k in cutlass.range(cute.size(tXsX, mode=[2]), unroll=1): + ki = tXcX[0, 0, k][1] // self.async_copy_elems + cute.copy( + self.gmem_tiled_copy_KV, + mX_paged_cur_copy[None, ki], + tXsX[None, m, k], + ) + elif const_expr(K_or_V == "V"): + # Don't need to clear out the rest of the smem for K since we'll mask out the scores anyway. + tXsX[None, m, None].fill(0) diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 6c264c30f55..14034fa9fd2 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -731,8 +731,8 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): @pytest.mark.parametrize("rotary_interleaved", [True]) # @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) @pytest.mark.parametrize("rotary_fraction", [0.0]) -# @pytest.mark.parametrize("page_size", [None] + ([1, 4, 128])) -@pytest.mark.parametrize("page_size", [None, 128]) +@pytest.mark.parametrize("page_size", [None] + ([1, 4, 128])) +# @pytest.mark.parametrize("page_size", [None, 128]) # @pytest.mark.parametrize("page_size", [128]) # @pytest.mark.parametrize("has_leftpad", [False, True]) @pytest.mark.parametrize("has_leftpad", [False]) @@ -1154,7 +1154,7 @@ def test_flash_attn_kvcache( # attention_chunk=attention_chunk, # rotary_interleaved=rotary_interleaved, # scheduler_metadata=scheduler_metadata, - # num_splits=num_splits, + num_splits=num_splits, # return_softmax_lse=True ) if varlen_q: From c7697bbf3ec350c9bff9c81d3d94ee282d9d11c9 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Wed, 16 Jul 2025 13:13:12 -0300 Subject: [PATCH 220/258] Add torch.compile support to flash attention 3 --- .gitignore | 2 + hopper/build.sh | 38 ++++ hopper/flash_api.cpp | 2 +- hopper/flash_attn_interface.py | 392 +++++++++++++++++++++++++++------ hopper/setup.py | 37 +++- hopper/test_flash_attn.py | 14 ++ 6 files changed, 414 insertions(+), 71 deletions(-) create mode 100644 hopper/build.sh diff --git a/.gitignore b/.gitignore index 060470d3c6f..39b997512e4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ *.ncu-rep .DS_store .vscode +flash_attn_config.py # Byte-compiled / optimized / DLL files __pycache__/ @@ -27,6 +28,7 @@ var/ # IDE-related .idea/ +.vscode/ # Dev venv diff --git a/hopper/build.sh b/hopper/build.sh new file mode 100644 index 00000000000..6a343c3e858 --- /dev/null +++ b/hopper/build.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +set -e + +# Flash Attention Minimal Build Script for PHI-1 Reproducer +# Uses subshell to automatically clean up environment variables + +# Run in subshell - variables are automatically cleaned up when it exits +( + # Set minimal build flags for PHI-1 reproducer + export PYTHONBREAKPOINT="pdbp.set_trace" + export FLASH_ATTENTION_DISABLE_BACKWARD=FALSE + export FLASH_ATTENTION_DISABLE_SPLIT=FALSE + export FLASH_ATTENTION_DISABLE_SOFTCAP=TRUE + export FLASH_ATTENTION_DISABLE_LOCAL=FALSE + export FLASH_ATTENTION_DISABLE_CLUSTER=TRUE + export FLASH_ATTENTION_DISABLE_VARLEN=FALSE + export FLASH_ATTENTION_DISABLE_PACKGQA=FALSE + export FLASH_ATTENTION_DISABLE_PAGEDKV=TRUE + export FLASH_ATTENTION_DISABLE_APPENDKV=FALSE + export FLASH_ATTENTION_DISABLE_FP8=FALSE + export FLASH_ATTENTION_DISABLE_FP16=TRUE + export FLASH_ATTENTION_DISABLE_FP32=TRUE + + # Keep only 64-dim heads for PHI-1 + export FLASH_ATTENTION_DISABLE_HDIM96=TRUE + export FLASH_ATTENTION_DISABLE_HDIM128=TRUE + export FLASH_ATTENTION_DISABLE_HDIM192=TRUE + export FLASH_ATTENTION_DISABLE_HDIM256=FALSE + + echo "Environment variables set for minimal build..." + + # Install flash-attention + # python setup.py install + # python -m pytest test_flash_attn_torch_compile.py --tb=line -x -rs -sv + python -m pytest test_flash_attn.py --tb=line + +) \ No newline at end of file diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 0233da799f2..f1502390593 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -1563,7 +1563,7 @@ std::tuple diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 1158ee02ad2..83706b42a3f 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -1,6 +1,6 @@ # Copyright (c) 2023, Tri Dao. -from typing import Optional, Union +from typing import Optional, Union, List, Tuple import torch import torch.nn as nn @@ -17,41 +17,90 @@ def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x +def round_multiple(x, m): + return (x + m - 1) // m * m + + +def round_up_headdim(head_size: int) -> int: + from flash_attn_config import CONFIG + + if CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM64"]: + if head_size <= 64: + return 64 + if CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM96"]: + if head_size <= 96: + return 96 + if CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM128"]: + if head_size <= 128: + return 128 + if CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM192"]: + if head_size <= 192: + return 192 + if CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM256"]: + if head_size <= 256: + return 256 + return 256 + +# torch.compile() support is only enabled for pytorch >= 2.4 +# The reason for this is that we are using the new custom_op and register_fake +# APIs, which support inplace modification of inputs in the function itself +if torch.__version__ >= "2.4.0": + _torch_custom_op_wrapper = torch.library.custom_op + _torch_register_fake_wrapper = torch.library.register_fake +else: + def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None): + def wrap(func): + return func + if fn is None: + return wrap + return fn + def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1): + def wrap(func): + return func + if fn is None: + return wrap + return fn + _torch_custom_op_wrapper = noop_custom_op_wrapper + _torch_register_fake_wrapper = noop_register_fake_wrapper + + +@_torch_custom_op_wrapper("flash_attn::_flash_attn_forward", mutates_args=(), device_types="cuda") def _flash_attn_forward( - q, - k, - v, - k_new, - v_new, - qv, - out, - cu_seqlens_q, - cu_seqlens_k, - cu_seqlens_k_new, - seqused_q, - seqused_k, - max_seqlen_q, - max_seqlen_k, - page_table, - kv_batch_idx, - leftpad_k, - rotary_cos, - rotary_sin, - seqlens_rotary, - q_descale, - k_descale, - v_descale, - softmax_scale, - causal, - window_size=(-1, -1), - attention_chunk=0, - softcap=0.0, - rotary_interleaved=True, - scheduler_metadata=None, - num_splits=1, - pack_gqa=None, - sm_margin=0, - ): + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_new: Optional[torch.Tensor], + v_new: Optional[torch.Tensor], + qv: Optional[torch.Tensor], + out: Optional[torch.Tensor], + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + cu_seqlens_k_new: Optional[torch.Tensor], + seqused_q: Optional[torch.Tensor], + seqused_k: Optional[torch.Tensor], + max_seqlen_q: Optional[int], + max_seqlen_k: Optional[int], + page_table: Optional[torch.Tensor], + kv_batch_idx: Optional[torch.Tensor], + leftpad_k: Optional[torch.Tensor], + rotary_cos: Optional[torch.Tensor], + rotary_sin: Optional[torch.Tensor], + seqlens_rotary: Optional[torch.Tensor], + q_descale: Optional[torch.Tensor], + k_descale: Optional[torch.Tensor], + v_descale: Optional[torch.Tensor], + softmax_scale: Optional[float], + causal: bool, + window_size_left: int = -1, + window_size_right: int = -1, + attention_chunk: int = 0, + softcap: float = 0.0, + rotary_interleaved: bool = True, + scheduler_metadata: Optional[torch.Tensor] = None, + num_splits: int = 1, + pack_gqa: Optional[bool] = None, + sm_margin: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)] v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [ @@ -89,8 +138,8 @@ def _flash_attn_forward( v_descale, softmax_scale, causal, - window_size[0], - window_size[1], + window_size_left, + window_size_right, attention_chunk, softcap, rotary_interleaved, @@ -102,29 +151,134 @@ def _flash_attn_forward( return out, softmax_lse, *rest +@_torch_register_fake_wrapper("flash_attn::_flash_attn_forward") +def _flash_attn_forward_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_new: Optional[torch.Tensor], + v_new: Optional[torch.Tensor], + qv: Optional[torch.Tensor], + out: Optional[torch.Tensor], + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + cu_seqlens_k_new: Optional[torch.Tensor], + seqused_q: Optional[torch.Tensor], + seqused_k: Optional[torch.Tensor], + max_seqlen_q: Optional[int], + max_seqlen_k: Optional[int], + page_table: Optional[torch.Tensor], + kv_batch_idx: Optional[torch.Tensor], + leftpad_k: Optional[torch.Tensor], + rotary_cos: Optional[torch.Tensor], + rotary_sin: Optional[torch.Tensor], + seqlens_rotary: Optional[torch.Tensor], + q_descale: Optional[torch.Tensor], + k_descale: Optional[torch.Tensor], + v_descale: Optional[torch.Tensor], + softmax_scale: Optional[float], + causal: bool, + window_size_left: int = -1, + window_size_right: int = -1, + attention_chunk: int = 0, + softcap: float = 0.0, + rotary_interleaved: bool = True, + scheduler_metadata: Optional[torch.Tensor] = None, + num_splits: int = 1, + pack_gqa: Optional[bool] = None, + sm_margin: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Symbolic fake implementation of flash attention forward. + Returns tensors with the correct shapes and dtypes without actual computation. + """ + + # Determine if we're in varlen mode + is_varlen_q = cu_seqlens_q is not None + is_varlen_k = cu_seqlens_k is not None + + # Get dimensions from query tensor + if is_varlen_q: + # varlen mode: q is (total_q, num_heads, head_size) + total_q, num_heads, head_size = q.shape + batch_size = cu_seqlens_q.shape[0] - 1 + + if max_seqlen_q is None: + raise ValueError("max_seqlen_q must be provided if cu_seqlens_q is provided") + seqlen_q = max_seqlen_q + else: + # batch mode: q is (batch_size, seqlen_q, num_heads, head_size) + batch_size, seqlen_q, num_heads, head_size = q.shape + total_q = batch_size * q.shape[1] + # Get value head dimension + head_size_v = v.shape[-1] + + # Determine output dtype (FP8 inputs produce BF16 outputs) + q_type = q.dtype + if q_type == torch.float8_e4m3fn: + out_dtype = torch.bfloat16 + else: + out_dtype = q_type + + # Create output tensor + if out is None: + if is_varlen_q: + out = torch.empty((total_q, num_heads, head_size_v), dtype=out_dtype, device=q.device) + else: + out = torch.empty((batch_size, seqlen_q, num_heads, head_size_v), dtype=out_dtype, device=q.device) + + # Create softmax_lse tensor + if is_varlen_q: + softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device) + else: + softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device) + + # TODO(guilhermeleobas): Implement "get_num_splits" + # There's an heuristic to compute num_splits when "num_splits <= 0" + # assert that num_splits is > 0 for now + if num_splits <= 0: + raise ValueError(f"{num_splits=} is not supported yet. Please set a value greater than 0") + + if num_splits > 1: + if is_varlen_q: + out_accum = torch.empty((num_splits, num_heads, total_q, head_size_v), dtype=torch.float32, device=q.device) + softmax_lse_accum = torch.empty((num_splits, num_heads, total_q), dtype=torch.float32, device=q.device) + else: + out_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q, head_size_v), dtype=torch.float32, device=q.device) + softmax_lse_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device) + else: + # Tensors are not set when num_splits < 1 + out_accum = None + softmax_lse_accum = None + + return out, softmax_lse, out_accum, softmax_lse_accum + + +@_torch_custom_op_wrapper("flash_attn::_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda") def _flash_attn_backward( - dout, - q, - k, - v, - out, - softmax_lse, - cu_seqlens_q, - cu_seqlens_k, - sequed_q, - sequed_k, - max_seqlen_q, - max_seqlen_k, - dq, - dk, - dv, - softmax_scale, - causal, - window_size=(-1, -1), - softcap=0.0, - deterministic=False, - sm_margin=0, -): + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + sequed_q: Optional[torch.Tensor], + sequed_k: Optional[torch.Tensor], + max_seqlen_q: Optional[int], + max_seqlen_k: Optional[int], + dq: Optional[torch.Tensor], + dk: Optional[torch.Tensor], + dv: Optional[torch.Tensor], + softmax_scale: Optional[float], + is_causal: bool, + window_size_left: int = -1, + window_size_right: int = -1, + softcap: float = 0.0, + deterministic: bool = False, + sm_margin: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: # dq, dk, dv are allocated by us so they should already be contiguous dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd( @@ -144,9 +298,9 @@ def _flash_attn_backward( max_seqlen_q, max_seqlen_k, softmax_scale, - causal, - window_size[0], - window_size[1], + is_causal, + window_size_left, + window_size_right, softcap, deterministic, sm_margin, @@ -154,6 +308,99 @@ def _flash_attn_backward( return dq, dk, dv, softmax_d +@_torch_register_fake_wrapper("flash_attn::_flash_attn_backward") +def _flash_attn_backward_fake( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + sequed_q: Optional[torch.Tensor], + sequed_k: Optional[torch.Tensor], + max_seqlen_q: Optional[int], + max_seqlen_k: Optional[int], + dq: Optional[torch.Tensor], + dk: Optional[torch.Tensor], + dv: Optional[torch.Tensor], + softmax_scale: Optional[float], + is_causal: bool, + window_size_left: int = -1, + window_size_right: int = -1, + softcap: float = 0.0, + deterministic: bool = False, + sm_margin: int = 0, +): + + is_varlen_q = bool(cu_seqlens_q) + is_varlen_k = bool(cu_seqlens_k) + is_varlen = is_varlen_q or is_varlen_k or bool(sequed_q) or bool(sequed_k) + + if not is_varlen_q: + batch_size = q.size()[0] + seqlen_q = q.size()[1] + seqlen_k = k.size()[1] + total_q = batch_size * q.size()[1] + else: + batch_size = cu_seqlens_q.size(0) - 1 + total_q = q.size()[0] + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + + if window_size_left >= seqlen_k - 1: + window_size_left = -1 + + if window_size_right >= seqlen_q - 1: + window_size_right = -1 + + if is_causal: + window_size_right = 0 + + is_causal = window_size_left < 0 and window_size_right == 0 + + head_size = q.size(-1) + head_size_v = v.size(-1) + head_size_rounded = round_up_headdim(max(head_size, head_size_v)) + + # Hopper gpus uses cuda compute capabilities 9.0 + cap = torch.cuda.get_device_capability(q.device) + arch = cap[0] * 10 + cap[1] + + is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal + + if arch < 90: + raise ValueError(f"Only cuda compute capabilities 9.0 or newer are supported. Got {arch=}") + + if head_size_rounded <= 64: + kBlockM_sm90 = 96 if (is_causal and softcap > 0.0) else 128 + elif head_size_rounded <= 96: + kBlockM_sm90 = 64 + elif head_size_rounded <= 128: + kBlockM_sm90 = 64 if (is_causal or is_local or softcap > 0.0) else 80 + else: + kBlockM_sm90 = 64 + + kBlockM = kBlockM_sm90 + + num_heads = q.shape[-2] + seqlen_q_rounded = round_multiple(seqlen_q, kBlockM) + + total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM) + + dq = torch.empty_like(q) if dq is None else dq + dk = torch.empty_like(k) if dk is None else dk + dv = torch.empty_like(v) if dv is None else dv + + if not is_varlen: + softmax_d = torch.empty((batch_size, num_heads, seqlen_q_rounded), dtype=torch.float32, device=q.device) + else: + softmax_d = torch.empty((num_heads, total_q_padded_rounded), dtype=torch.float32, device=q.device) + + return dq, dk, dv, softmax_d + + class FlashAttnQKVPackedFunc(torch.autograd.Function): @staticmethod def forward( @@ -196,7 +443,8 @@ def forward( q_descale, k_descale, v_descale, softmax_scale, causal=causal, - window_size=window_size, + window_size_left=window_size[0], + window_size_right=window_size[1], attention_chunk=attention_chunk, softcap=softcap, sm_margin=sm_margin, @@ -242,7 +490,8 @@ def backward(ctx, dout, *args): dv, ctx.softmax_scale, ctx.causal, - ctx.window_size, + ctx.window_size[0], + ctx.window_size[1], ctx.softcap, ctx.deterministic, ctx.sm_margin, @@ -290,7 +539,8 @@ def forward( q_descale, k_descale, v_descale, softmax_scale, causal=causal, - window_size=window_size, + window_size_left=window_size[0], + window_size_right=window_size[1], attention_chunk=attention_chunk, softcap=softcap, num_splits=num_splits, @@ -328,7 +578,8 @@ def backward(ctx, dout, *args): dv, ctx.softmax_scale, ctx.causal, - ctx.window_size, + ctx.window_size[0], + ctx.window_size[1], ctx.softcap, ctx.deterministic, ctx.sm_margin, @@ -388,7 +639,8 @@ def forward( q_descale, k_descale, v_descale, softmax_scale, causal=causal, - window_size=window_size, + window_size_left=window_size[0], + window_size_right=window_size[1], attention_chunk=attention_chunk, softcap=softcap, num_splits=num_splits, @@ -431,7 +683,8 @@ def backward(ctx, dout, *args): dv, ctx.softmax_scale, ctx.causal, - ctx.window_size, + ctx.window_size[0], + ctx.window_size[1], ctx.softcap, ctx.deterministic, ctx.sm_margin, @@ -787,7 +1040,8 @@ def flash_attn_with_kvcache( q_descale, k_descale, v_descale, softmax_scale, causal=causal, - window_size=window_size, + window_size_left=window_size[0], + window_size_right=window_size[1], attention_chunk=attention_chunk, softcap=softcap, rotary_interleaved=rotary_interleaved, diff --git a/hopper/setup.py b/hopper/setup.py index 519d1c04f42..6ccb126c174 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -82,6 +82,40 @@ _maybe_write, ) +def create_build_config_file(): + CONFIG = { + "build_flags": { + "FLASHATTENTION_DISABLE_BACKWARD": DISABLE_BACKWARD, + "FLASHATTENTION_DISABLE_SPLIT": DISABLE_SPLIT, + "FLASHATTENTION_DISABLE_PAGEDKV": DISABLE_PAGEDKV, + "FLASHATTENTION_DISABLE_APPENDKV": DISABLE_APPENDKV, + "FLASHATTENTION_DISABLE_LOCAL": DISABLE_LOCAL, + "FLASHATTENTION_DISABLE_SOFTCAP": DISABLE_SOFTCAP, + "FLASHATTENTION_DISABLE_PACKGQA": DISABLE_PACKGQA, + "FLASHATTENTION_DISABLE_FP16": DISABLE_FP16, + "FLASHATTENTION_DISABLE_FP8": DISABLE_FP8, + "FLASHATTENTION_DISABLE_VARLEN": DISABLE_VARLEN, + "FLASHATTENTION_DISABLE_CLUSTER": DISABLE_CLUSTER, + "FLASHATTENTION_DISABLE_HDIM64": DISABLE_HDIM64, + "FLASHATTENTION_DISABLE_HDIM96": DISABLE_HDIM96, + "FLASHATTENTION_DISABLE_HDIM128": DISABLE_HDIM128, + "FLASHATTENTION_DISABLE_HDIM192": DISABLE_HDIM192, + "FLASHATTENTION_DISABLE_HDIM256": DISABLE_HDIM256, + "FLASHATTENTION_DISABLE_SM8x": DISABLE_SM8x, + "FLASHATTENTION_ENABLE_VCOLMAJOR": ENABLE_VCOLMAJOR, + } + } + + with open("flash_attn_config.py", "w") as f: + f.write("# Auto-generated by flash attention 3 setup.py\n") + f.write(f"CONFIG = {repr(CONFIG)}\n") + f.write("\n") + + f.write("def show():\n") + f.write(" from pprint import pprint\n") + f.write(" pprint(CONFIG)\n") + f.write("\n") + def _write_ninja_file(path, cflags, post_cflags, @@ -395,6 +429,7 @@ def nvcc_threads_args(): TORCH_MAJOR = int(torch.__version__.split(".")[0]) TORCH_MINOR = int(torch.__version__.split(".")[1]) + create_build_config_file() check_if_cuda_home_none(PACKAGE_NAME) _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) if bare_metal_version < Version("12.3"): @@ -676,7 +711,7 @@ def run(self): "benchmarks", ) ), - py_modules=["flash_attn_interface"], + py_modules=["flash_attn_interface", "flash_attn_config"], description="FlashAttention-3", long_description=long_description, long_description_content_type="text/markdown", diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 0b5a0e2af98..3b066505159 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -6,6 +6,7 @@ import torch import torch.nn.functional as F from torch._C import parse_schema +from torch.testing._internal.optests import fake_check from einops import rearrange, repeat try: @@ -38,6 +39,7 @@ DISABLE_HDIM128 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM128", "FALSE") == "TRUE" DISABLE_HDIM192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM192", "FALSE") == "TRUE" DISABLE_HDIM256 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM256", "FALSE") == "TRUE" +DISABLE_FAKE_CHECK = os.getenv("FLASH_ATTENTION_DISABLE_FAKE_CHECK", "FALSE") == "TRUE" COMPILED_HDIMS = ( [] @@ -49,6 +51,18 @@ ) +def run_fake_check(fn): + def wrapper(*args, **kwargs): + fake_check(fn, args, kwargs) + return fn(*args, **kwargs) + return wrapper + + +if not DISABLE_FAKE_CHECK: + flash_attn_func = run_fake_check(flash_attn_func) + flash_attn_varlen_func = run_fake_check(flash_attn_varlen_func) + + # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) From e1944ba9cb4436e4d357e0b9c983bd742b3aa5e7 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Thu, 24 Jul 2025 20:30:00 +0000 Subject: [PATCH 221/258] Don't return mutated variables in mha_bwd --- hopper/build.sh | 8 ++------ hopper/flash_api.cpp | 6 +++--- hopper/flash_attn_interface.py | 10 ++++++---- 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/hopper/build.sh b/hopper/build.sh index 6a343c3e858..bb5042b1119 100644 --- a/hopper/build.sh +++ b/hopper/build.sh @@ -2,12 +2,8 @@ set -e -# Flash Attention Minimal Build Script for PHI-1 Reproducer -# Uses subshell to automatically clean up environment variables - # Run in subshell - variables are automatically cleaned up when it exits ( - # Set minimal build flags for PHI-1 reproducer export PYTHONBREAKPOINT="pdbp.set_trace" export FLASH_ATTENTION_DISABLE_BACKWARD=FALSE export FLASH_ATTENTION_DISABLE_SPLIT=FALSE @@ -31,8 +27,8 @@ set -e echo "Environment variables set for minimal build..." # Install flash-attention - # python setup.py install + python setup.py install # python -m pytest test_flash_attn_torch_compile.py --tb=line -x -rs -sv - python -m pytest test_flash_attn.py --tb=line + python -m pytest test_flash_attn.py --tb=line -x ) \ No newline at end of file diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index f1502390593..7ab4352984e 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -1264,7 +1264,7 @@ void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { // h: num_heads // h_k: num_heads_k // d: head_size -std::tuple mha_bwd( +std::tuple mha_bwd( at::Tensor dout, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q at::Tensor k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k @@ -1563,7 +1563,7 @@ std::tuple @@ -1727,7 +1727,7 @@ TORCH_LIBRARY(flash_attn_3, m) { "int window_size_right = -1," "float softcap = 0.0," "bool deterministic = False," - "int sm_margin = 0) -> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)"); + "int sm_margin = 0) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"); m.def("fwd_combine(" "Tensor out_partial," "Tensor lse_partial," diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 83706b42a3f..940d11420cf 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -278,10 +278,11 @@ def _flash_attn_backward( softcap: float = 0.0, deterministic: bool = False, sm_margin: int = 0, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +) -> torch.Tensor: # dq, dk, dv are allocated by us so they should already be contiguous dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] - dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd( + print('aqui2') + softmax_d, *rest = flash_attn_3_cuda.bwd( dout, q, k, @@ -305,7 +306,7 @@ def _flash_attn_backward( deterministic, sm_margin, ) - return dq, dk, dv, softmax_d + return softmax_d @_torch_register_fake_wrapper("flash_attn::_flash_attn_backward") @@ -398,7 +399,7 @@ def _flash_attn_backward_fake( else: softmax_d = torch.empty((num_heads, total_q_padded_rounded), dtype=torch.float32, device=q.device) - return dq, dk, dv, softmax_d + return softmax_d class FlashAttnQKVPackedFunc(torch.autograd.Function): @@ -563,6 +564,7 @@ def backward(ctx, dout, *args): q, k, v, out, softmax_lse = ctx.saved_tensors assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk" dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + print('aqui1') _flash_attn_backward( dout, q, From a760ca3e1776e2135c931a90ea33ec3f214a0b43 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 25 Jul 2025 19:39:24 +0000 Subject: [PATCH 222/258] Change fake_check flag to be opt-in; Remove build.sh and remove if-else around `torch.library.custom_op` usage --- hopper/build.sh | 34 ---------------------------------- hopper/flash_attn_interface.py | 32 +++++--------------------------- hopper/test_flash_attn.py | 4 ++-- 3 files changed, 7 insertions(+), 63 deletions(-) delete mode 100644 hopper/build.sh diff --git a/hopper/build.sh b/hopper/build.sh deleted file mode 100644 index bb5042b1119..00000000000 --- a/hopper/build.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/bin/bash - -set -e - -# Run in subshell - variables are automatically cleaned up when it exits -( - export PYTHONBREAKPOINT="pdbp.set_trace" - export FLASH_ATTENTION_DISABLE_BACKWARD=FALSE - export FLASH_ATTENTION_DISABLE_SPLIT=FALSE - export FLASH_ATTENTION_DISABLE_SOFTCAP=TRUE - export FLASH_ATTENTION_DISABLE_LOCAL=FALSE - export FLASH_ATTENTION_DISABLE_CLUSTER=TRUE - export FLASH_ATTENTION_DISABLE_VARLEN=FALSE - export FLASH_ATTENTION_DISABLE_PACKGQA=FALSE - export FLASH_ATTENTION_DISABLE_PAGEDKV=TRUE - export FLASH_ATTENTION_DISABLE_APPENDKV=FALSE - export FLASH_ATTENTION_DISABLE_FP8=FALSE - export FLASH_ATTENTION_DISABLE_FP16=TRUE - export FLASH_ATTENTION_DISABLE_FP32=TRUE - - # Keep only 64-dim heads for PHI-1 - export FLASH_ATTENTION_DISABLE_HDIM96=TRUE - export FLASH_ATTENTION_DISABLE_HDIM128=TRUE - export FLASH_ATTENTION_DISABLE_HDIM192=TRUE - export FLASH_ATTENTION_DISABLE_HDIM256=FALSE - - echo "Environment variables set for minimal build..." - - # Install flash-attention - python setup.py install - # python -m pytest test_flash_attn_torch_compile.py --tb=line -x -rs -sv - python -m pytest test_flash_attn.py --tb=line -x - -) \ No newline at end of file diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 940d11420cf..aaefa14ca63 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -41,30 +41,8 @@ def round_up_headdim(head_size: int) -> int: return 256 return 256 -# torch.compile() support is only enabled for pytorch >= 2.4 -# The reason for this is that we are using the new custom_op and register_fake -# APIs, which support inplace modification of inputs in the function itself -if torch.__version__ >= "2.4.0": - _torch_custom_op_wrapper = torch.library.custom_op - _torch_register_fake_wrapper = torch.library.register_fake -else: - def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None): - def wrap(func): - return func - if fn is None: - return wrap - return fn - def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1): - def wrap(func): - return func - if fn is None: - return wrap - return fn - _torch_custom_op_wrapper = noop_custom_op_wrapper - _torch_register_fake_wrapper = noop_register_fake_wrapper - - -@_torch_custom_op_wrapper("flash_attn::_flash_attn_forward", mutates_args=(), device_types="cuda") + +@torch.library.custom_op("flash_attn::_flash_attn_forward", mutates_args=(), device_types="cuda") def _flash_attn_forward( q: torch.Tensor, k: torch.Tensor, @@ -151,7 +129,7 @@ def _flash_attn_forward( return out, softmax_lse, *rest -@_torch_register_fake_wrapper("flash_attn::_flash_attn_forward") +@torch.library.register_fake("flash_attn::_flash_attn_forward") def _flash_attn_forward_fake( q: torch.Tensor, k: torch.Tensor, @@ -254,7 +232,7 @@ def _flash_attn_forward_fake( return out, softmax_lse, out_accum, softmax_lse_accum -@_torch_custom_op_wrapper("flash_attn::_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda") +@torch.library.custom_op("flash_attn::_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda") def _flash_attn_backward( dout: torch.Tensor, q: torch.Tensor, @@ -309,7 +287,7 @@ def _flash_attn_backward( return softmax_d -@_torch_register_fake_wrapper("flash_attn::_flash_attn_backward") +@torch.library.register_fake("flash_attn::_flash_attn_backward") def _flash_attn_backward_fake( dout: torch.Tensor, q: torch.Tensor, diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 3b066505159..87b409c1170 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -39,7 +39,7 @@ DISABLE_HDIM128 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM128", "FALSE") == "TRUE" DISABLE_HDIM192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM192", "FALSE") == "TRUE" DISABLE_HDIM256 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM256", "FALSE") == "TRUE" -DISABLE_FAKE_CHECK = os.getenv("FLASH_ATTENTION_DISABLE_FAKE_CHECK", "FALSE") == "TRUE" +ENABLE_FAKE_CHECK = os.getenv("FLASH_ATTENTION_ENABLE_FAKE_CHECK", "FALSE") == "TRUE" COMPILED_HDIMS = ( [] @@ -58,7 +58,7 @@ def wrapper(*args, **kwargs): return wrapper -if not DISABLE_FAKE_CHECK: +if ENABLE_FAKE_CHECK: flash_attn_func = run_fake_check(flash_attn_func) flash_attn_varlen_func = run_fake_check(flash_attn_varlen_func) From 24cc2b25e3a890101fee392ee9ae10d0af237f33 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Wed, 30 Jul 2025 13:01:00 -0300 Subject: [PATCH 223/258] Remove print statements and update exception message --- hopper/flash_attn_interface.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index aaefa14ca63..64f3c7c92bc 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -173,8 +173,7 @@ def _flash_attn_forward_fake( # Determine if we're in varlen mode is_varlen_q = cu_seqlens_q is not None - is_varlen_k = cu_seqlens_k is not None - + # Get dimensions from query tensor if is_varlen_q: # varlen mode: q is (total_q, num_heads, head_size) @@ -190,7 +189,7 @@ def _flash_attn_forward_fake( total_q = batch_size * q.shape[1] # Get value head dimension head_size_v = v.shape[-1] - + # Determine output dtype (FP8 inputs produce BF16 outputs) q_type = q.dtype if q_type == torch.float8_e4m3fn: @@ -215,7 +214,7 @@ def _flash_attn_forward_fake( # There's an heuristic to compute num_splits when "num_splits <= 0" # assert that num_splits is > 0 for now if num_splits <= 0: - raise ValueError(f"{num_splits=} is not supported yet. Please set a value greater than 0") + raise ValueError(f"tracing (torch.compile/torch.export) with num_splits <= 0 not supported. Got {num_splits=}") if num_splits > 1: if is_varlen_q: @@ -259,7 +258,6 @@ def _flash_attn_backward( ) -> torch.Tensor: # dq, dk, dv are allocated by us so they should already be contiguous dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] - print('aqui2') softmax_d, *rest = flash_attn_3_cuda.bwd( dout, q, @@ -542,7 +540,6 @@ def backward(ctx, dout, *args): q, k, v, out, softmax_lse = ctx.saved_tensors assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk" dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) - print('aqui1') _flash_attn_backward( dout, q, From 5e114d53ff3a5527e8c1f62bce735c2b5301b78a Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Wed, 6 Aug 2025 21:24:00 +0000 Subject: [PATCH 224/258] Fix flash_attn_backward_fake --- hopper/flash_attn_interface.py | 198 ++++++++++++++++----------------- hopper/test_flash_attn.py | 11 +- 2 files changed, 105 insertions(+), 104 deletions(-) diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 64f3c7c92bc..77c03ebc043 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -24,19 +24,19 @@ def round_multiple(x, m): def round_up_headdim(head_size: int) -> int: from flash_attn_config import CONFIG - if CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM64"]: + if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM64"]: if head_size <= 64: return 64 - if CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM96"]: + if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM96"]: if head_size <= 96: return 96 - if CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM128"]: + if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM128"]: if head_size <= 128: return 128 - if CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM192"]: + if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM192"]: if head_size <= 192: return 192 - if CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM256"]: + if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM256"]: if head_size <= 256: return 256 return 256 @@ -47,28 +47,28 @@ def _flash_attn_forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - k_new: Optional[torch.Tensor], - v_new: Optional[torch.Tensor], - qv: Optional[torch.Tensor], - out: Optional[torch.Tensor], - cu_seqlens_q: Optional[torch.Tensor], - cu_seqlens_k: Optional[torch.Tensor], - cu_seqlens_k_new: Optional[torch.Tensor], - seqused_q: Optional[torch.Tensor], - seqused_k: Optional[torch.Tensor], - max_seqlen_q: Optional[int], - max_seqlen_k: Optional[int], - page_table: Optional[torch.Tensor], - kv_batch_idx: Optional[torch.Tensor], - leftpad_k: Optional[torch.Tensor], - rotary_cos: Optional[torch.Tensor], - rotary_sin: Optional[torch.Tensor], - seqlens_rotary: Optional[torch.Tensor], - q_descale: Optional[torch.Tensor], - k_descale: Optional[torch.Tensor], - v_descale: Optional[torch.Tensor], - softmax_scale: Optional[float], - causal: bool, + k_new: Optional[torch.Tensor] = None, + v_new: Optional[torch.Tensor] = None, + qv: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + page_table: Optional[torch.Tensor] = None, + kv_batch_idx: Optional[torch.Tensor] = None, + leftpad_k: Optional[torch.Tensor] = None, + rotary_cos: Optional[torch.Tensor] = None, + rotary_sin: Optional[torch.Tensor] = None, + seqlens_rotary: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + causal: bool = False, window_size_left: int = -1, window_size_right: int = -1, attention_chunk: int = 0, @@ -134,28 +134,28 @@ def _flash_attn_forward_fake( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - k_new: Optional[torch.Tensor], - v_new: Optional[torch.Tensor], - qv: Optional[torch.Tensor], - out: Optional[torch.Tensor], - cu_seqlens_q: Optional[torch.Tensor], - cu_seqlens_k: Optional[torch.Tensor], - cu_seqlens_k_new: Optional[torch.Tensor], - seqused_q: Optional[torch.Tensor], - seqused_k: Optional[torch.Tensor], - max_seqlen_q: Optional[int], - max_seqlen_k: Optional[int], - page_table: Optional[torch.Tensor], - kv_batch_idx: Optional[torch.Tensor], - leftpad_k: Optional[torch.Tensor], - rotary_cos: Optional[torch.Tensor], - rotary_sin: Optional[torch.Tensor], - seqlens_rotary: Optional[torch.Tensor], - q_descale: Optional[torch.Tensor], - k_descale: Optional[torch.Tensor], - v_descale: Optional[torch.Tensor], - softmax_scale: Optional[float], - causal: bool, + k_new: Optional[torch.Tensor] = None, + v_new: Optional[torch.Tensor] = None, + qv: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + page_table: Optional[torch.Tensor] = None, + kv_batch_idx: Optional[torch.Tensor] = None, + leftpad_k: Optional[torch.Tensor] = None, + rotary_cos: Optional[torch.Tensor] = None, + rotary_sin: Optional[torch.Tensor] = None, + seqlens_rotary: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + causal: bool = False, window_size_left: int = -1, window_size_right: int = -1, attention_chunk: int = 0, @@ -233,28 +233,28 @@ def _flash_attn_forward_fake( @torch.library.custom_op("flash_attn::_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda") def _flash_attn_backward( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - cu_seqlens_q: Optional[torch.Tensor], - cu_seqlens_k: Optional[torch.Tensor], - sequed_q: Optional[torch.Tensor], - sequed_k: Optional[torch.Tensor], - max_seqlen_q: Optional[int], - max_seqlen_k: Optional[int], - dq: Optional[torch.Tensor], - dk: Optional[torch.Tensor], - dv: Optional[torch.Tensor], - softmax_scale: Optional[float], - is_causal: bool, - window_size_left: int = -1, - window_size_right: int = -1, - softcap: float = 0.0, - deterministic: bool = False, - sm_margin: int = 0, + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + sequed_q: Optional[torch.Tensor] = None, + sequed_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + dq: Optional[torch.Tensor] = None, + dk: Optional[torch.Tensor] = None, + dv: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + is_causal: bool = False, + window_size_left: int = -1, + window_size_right: int = -1, + softcap: float = 0.0, + deterministic: bool = False, + sm_margin: int = 0, ) -> torch.Tensor: # dq, dk, dv are allocated by us so they should already be contiguous dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] @@ -287,42 +287,42 @@ def _flash_attn_backward( @torch.library.register_fake("flash_attn::_flash_attn_backward") def _flash_attn_backward_fake( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - cu_seqlens_q: Optional[torch.Tensor], - cu_seqlens_k: Optional[torch.Tensor], - sequed_q: Optional[torch.Tensor], - sequed_k: Optional[torch.Tensor], - max_seqlen_q: Optional[int], - max_seqlen_k: Optional[int], - dq: Optional[torch.Tensor], - dk: Optional[torch.Tensor], - dv: Optional[torch.Tensor], - softmax_scale: Optional[float], - is_causal: bool, - window_size_left: int = -1, - window_size_right: int = -1, - softcap: float = 0.0, - deterministic: bool = False, - sm_margin: int = 0, -): + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + sequed_q: Optional[torch.Tensor] = None, + sequed_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + dq: Optional[torch.Tensor] = None, + dk: Optional[torch.Tensor] = None, + dv: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + is_causal: bool = False, + window_size_left: int = -1, + window_size_right: int = -1, + softcap: float = 0.0, + deterministic: bool = False, + sm_margin: int = 0, +) -> torch.Tensor: is_varlen_q = bool(cu_seqlens_q) is_varlen_k = bool(cu_seqlens_k) is_varlen = is_varlen_q or is_varlen_k or bool(sequed_q) or bool(sequed_k) if not is_varlen_q: - batch_size = q.size()[0] - seqlen_q = q.size()[1] - seqlen_k = k.size()[1] - total_q = batch_size * q.size()[1] + batch_size = q.size(0) + seqlen_q = q.size(1) + seqlen_k = k.size(1) + total_q = batch_size * q.size(1) else: batch_size = cu_seqlens_q.size(0) - 1 - total_q = q.size()[0] + total_q = q.size(0) seqlen_q = max_seqlen_q seqlen_k = max_seqlen_k diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 87b409c1170..323894a16cb 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -6,7 +6,7 @@ import torch import torch.nn.functional as F from torch._C import parse_schema -from torch.testing._internal.optests import fake_check +from torch.testing._internal.optests.generate_tests import safe_fake_check, safe_schema_check from einops import rearrange, repeat try: @@ -51,16 +51,17 @@ ) -def run_fake_check(fn): +def run_opcheck(fn): def wrapper(*args, **kwargs): - fake_check(fn, args, kwargs) + safe_schema_check(fn, args, kwargs) + safe_fake_check(fn, args, kwargs) return fn(*args, **kwargs) return wrapper if ENABLE_FAKE_CHECK: - flash_attn_func = run_fake_check(flash_attn_func) - flash_attn_varlen_func = run_fake_check(flash_attn_varlen_func) + flash_attn_func = run_opcheck(flash_attn_func) + flash_attn_varlen_func = run_opcheck(flash_attn_varlen_func) # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) From 734bc437bd1040be01ac941a13bdf36fe40aad0f Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Thu, 7 Aug 2025 19:11:14 +0000 Subject: [PATCH 225/258] Add `safe_aot_autograd_check` --- hopper/test_flash_attn.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 323894a16cb..efa13afb3fb 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -6,7 +6,11 @@ import torch import torch.nn.functional as F from torch._C import parse_schema -from torch.testing._internal.optests.generate_tests import safe_fake_check, safe_schema_check +from torch.testing._internal.optests.generate_tests import ( + safe_fake_check, + safe_schema_check, + safe_aot_autograd_check, +) from einops import rearrange, repeat try: @@ -40,6 +44,7 @@ DISABLE_HDIM192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM192", "FALSE") == "TRUE" DISABLE_HDIM256 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM256", "FALSE") == "TRUE" ENABLE_FAKE_CHECK = os.getenv("FLASH_ATTENTION_ENABLE_FAKE_CHECK", "FALSE") == "TRUE" +ENABLE_AUTOGRAD_CHECK = os.getenv("FLASH_ATTENTION_ENABLE_FAKE_CHECK", "FALSE") == "TRUE" COMPILED_HDIMS = ( [] @@ -50,11 +55,35 @@ + ([256] if not DISABLE_HDIM256 else []) ) +def should_test_backward(args, kwargs): + v = args[2] + dtype = v.dtype + has_qv = V_colmajor = False # no test runs this with V_colmajor or has_qv == True + attention_chunk = kwargs.get("attention_chunk") + dv = v.size(-1) + + if ( + ENABLE_AUTOGRAD_CHECK + and not DISABLE_BACKWARD + and dtype != torch.float8_e4m3fn + and not V_colmajor + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + ): + return True + return False + def run_opcheck(fn): def wrapper(*args, **kwargs): safe_schema_check(fn, args, kwargs) safe_fake_check(fn, args, kwargs) + + if should_test_backward(args, kwargs): + # Expensive check + safe_aot_autograd_check(fn, args, kwargs, dynamic=False) + safe_aot_autograd_check(fn, args, kwargs, dynamic=True) return fn(*args, **kwargs) return wrapper From fde4bc0cd4218a031a40d87ca1259e7dfce19220 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Tue, 19 Aug 2025 20:13:34 +0000 Subject: [PATCH 226/258] Update namespace to flash_attn_3 --- hopper/flash_attn_interface.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 77c03ebc043..143bd11b6c7 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -42,7 +42,7 @@ def round_up_headdim(head_size: int) -> int: return 256 -@torch.library.custom_op("flash_attn::_flash_attn_forward", mutates_args=(), device_types="cuda") +@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda") def _flash_attn_forward( q: torch.Tensor, k: torch.Tensor, @@ -129,7 +129,7 @@ def _flash_attn_forward( return out, softmax_lse, *rest -@torch.library.register_fake("flash_attn::_flash_attn_forward") +@torch.library.register_fake("flash_attn_3::_flash_attn_forward") def _flash_attn_forward_fake( q: torch.Tensor, k: torch.Tensor, @@ -231,7 +231,7 @@ def _flash_attn_forward_fake( return out, softmax_lse, out_accum, softmax_lse_accum -@torch.library.custom_op("flash_attn::_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda") +@torch.library.custom_op("flash_attn_3::_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda") def _flash_attn_backward( dout: torch.Tensor, q: torch.Tensor, @@ -285,7 +285,7 @@ def _flash_attn_backward( return softmax_d -@torch.library.register_fake("flash_attn::_flash_attn_backward") +@torch.library.register_fake("flash_attn_3::_flash_attn_backward") def _flash_attn_backward_fake( dout: torch.Tensor, q: torch.Tensor, From ab79ae25a077fb30a9963e4fa52157d8fc1c6145 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 22 Aug 2025 00:28:07 +0000 Subject: [PATCH 227/258] Add `flash_attn_forward.register_autograd` --- hopper/flash_attn_interface.py | 43 ++++++++++++++++++++++++++++++++++ hopper/test_flash_attn.py | 6 ++--- 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 143bd11b6c7..7820a3e29d3 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -378,6 +378,49 @@ def _flash_attn_backward_fake( return softmax_d +def setup_context(ctx, inputs, output): + q, k, v = inputs[:3] + out, softmax_lse, _, _ = output + ctx.save_for_backward(q, k, v, out, softmax_lse) + ctx.softmax_scale = inputs[-11] + ctx.causal = inputs[-10] + ctx.window_size = [inputs[-9], inputs[-8]] + ctx.attention_chunk = inputs[-7] + ctx.softcap = inputs[-6] + ctx.sm_margin = inputs[-1] + + +def _backward(ctx, dout, *grads): + q, k, v, out, softmax_lse = ctx.saved_tensors + dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + _flash_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + None, None, # cu_seqlens_q, cu_seqlens_k, + None, None, # sequed_q, sequed_k, + None, None, # max_seqlen_q, max_seqlen_k, + dq, + dk, + dv, + ctx.softmax_scale, + ctx.causal, + ctx.window_size[0], + ctx.window_size[1], + ctx.softcap, + False, # deterministic + ctx.sm_margin, + ) + return dq, dk, dv, *((None,) * 21) + + +_flash_attn_forward.register_autograd(_backward, setup_context=setup_context) + + + class FlashAttnQKVPackedFunc(torch.autograd.Function): @staticmethod def forward( diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index efa13afb3fb..4f81dcb1df6 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -43,8 +43,8 @@ DISABLE_HDIM128 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM128", "FALSE") == "TRUE" DISABLE_HDIM192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM192", "FALSE") == "TRUE" DISABLE_HDIM256 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM256", "FALSE") == "TRUE" -ENABLE_FAKE_CHECK = os.getenv("FLASH_ATTENTION_ENABLE_FAKE_CHECK", "FALSE") == "TRUE" -ENABLE_AUTOGRAD_CHECK = os.getenv("FLASH_ATTENTION_ENABLE_FAKE_CHECK", "FALSE") == "TRUE" +ENABLE_OPCHECK = os.getenv("FLASH_ATTENTION_ENABLE_OPCHECK", "FALSE") == "TRUE" +ENABLE_AUTOGRAD_CHECK = os.getenv("FLASH_ATTENTION_ENABLE_AUTOGRAD_CHECK", "FALSE") == "TRUE" COMPILED_HDIMS = ( [] @@ -88,7 +88,7 @@ def wrapper(*args, **kwargs): return wrapper -if ENABLE_FAKE_CHECK: +if ENABLE_OPCHECK: flash_attn_func = run_opcheck(flash_attn_func) flash_attn_varlen_func = run_opcheck(flash_attn_varlen_func) From 6250fbecbc5a101185e7d0677a650d3a029dd3eb Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 22 Aug 2025 17:25:23 -0300 Subject: [PATCH 228/258] Fix bug in `flash_attn_backward_fake` --- hopper/flash_attn_interface.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 7820a3e29d3..438ccbaae81 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -311,9 +311,9 @@ def _flash_attn_backward_fake( sm_margin: int = 0, ) -> torch.Tensor: - is_varlen_q = bool(cu_seqlens_q) - is_varlen_k = bool(cu_seqlens_k) - is_varlen = is_varlen_q or is_varlen_k or bool(sequed_q) or bool(sequed_k) + is_varlen_q = cu_seqlens_q is not None + is_varlen_k = cu_seqlens_q is not None + is_varlen = is_varlen_q or is_varlen_k or sequed_q is not None or sequed_k is not None if not is_varlen_q: batch_size = q.size(0) From 1e3539e457f90a1579780f4495dd9abd88336737 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Tue, 2 Sep 2025 18:13:05 +0000 Subject: [PATCH 229/258] Add support and tests for torch.export and aoti_compile_and_package --- hopper/flash_attn_interface.py | 17 ++++-- hopper/test_torch_compile_and_export.py | 73 +++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 5 deletions(-) create mode 100644 hopper/test_torch_compile_and_export.py diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 438ccbaae81..4896a08e626 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -90,7 +90,7 @@ def _flash_attn_forward( ] rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)] seqlens_rotary = maybe_contiguous(seqlens_rotary) - out, softmax_lse, *rest = flash_attn_3_cuda.fwd( + out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_cuda.fwd( q, k, v, @@ -126,7 +126,14 @@ def _flash_attn_forward( pack_gqa, sm_margin, ) - return out, softmax_lse, *rest + + if out_accum is None: + out_accum = torch.tensor([], device=out.device) + + if softmax_lse_accum is None: + softmax_lse_accum = torch.tensor([], device=out.device) + + return out, softmax_lse, out_accum, softmax_lse_accum @torch.library.register_fake("flash_attn_3::_flash_attn_forward") @@ -225,8 +232,8 @@ def _flash_attn_forward_fake( softmax_lse_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device) else: # Tensors are not set when num_splits < 1 - out_accum = None - softmax_lse_accum = None + out_accum = torch.tensor([], device=out.device) + softmax_lse_accum = torch.tensor([], device=out.device) return out, softmax_lse, out_accum, softmax_lse_accum @@ -253,7 +260,7 @@ def _flash_attn_backward( window_size_left: int = -1, window_size_right: int = -1, softcap: float = 0.0, - deterministic: bool = False, + deterministic: bool= False, sm_margin: int = 0, ) -> torch.Tensor: # dq, dk, dv are allocated by us so they should already be contiguous diff --git a/hopper/test_torch_compile_and_export.py b/hopper/test_torch_compile_and_export.py new file mode 100644 index 00000000000..53beef46340 --- /dev/null +++ b/hopper/test_torch_compile_and_export.py @@ -0,0 +1,73 @@ +import torch +from flash_attn_interface import flash_attn_func +from torch import nn + + +class EfficienctMultiHeadAttention(nn.Module): + def __init__(self, embed_size, num_heads, dropout=0.0, use_flash_attn=True): + super().__init__() + assert embed_size % num_heads == 0, f"{embed_size=} {num_heads=}" + + self.embed_size = embed_size + self.num_heads = num_heads + self.head_dim = embed_size // num_heads + self.use_flash_attn = use_flash_attn and (flash_attn_func is not None) + + self.qkv_proj = nn.Linear(embed_size, 3 * embed_size) + self.out_proj = nn.Linear(embed_size, embed_size) + self.dropout = dropout + + def forward(self, x, attention_mask=None): + N, seq_length, _ = x.shape + + qkv = self.qkv_proj(x) + q, k, v = qkv.chunk(3, dim=-1) + + q = q.view(N, seq_length, self.num_heads, self.head_dim) + k = k.view(N, seq_length, self.num_heads, self.head_dim) + v = v.view(N, seq_length, self.num_heads, self.head_dim) + + if self.use_flash_attn and attention_mask is None: + out = flash_attn_func( + q, k, v + ) + out = out.reshape(N, seq_length, self.embed_size) + out = self.out_proj(out) + return out + + +def create_model(batch_size=16, sequence_length=256, embedding_dim=2048, num_heads=16): + model = EfficienctMultiHeadAttention(embedding_dim, num_heads).cuda().bfloat16() + input_tensor = torch.randn(batch_size, sequence_length, embedding_dim).cuda().bfloat16() + return model, input_tensor + + +def test_export_model(): + model, input_tensor = create_model() + expected = torch.compile(model, backend="aot_eager")(input_tensor) + loss = expected.sum() + loss.backward() + + ep = torch.export.export(model, (input_tensor,)) + got = ep.module()(input_tensor,) + assert torch.equal(expected, got) + + loss_2 = got.sum() + loss_2.backward() + + assert torch.equal(loss, loss_2) + + +def test_compile_and_package_model(): + model, input_tensor = create_model() + expected = torch.compile(model, backend="aot_eager")(input_tensor) + + exported = torch.export.export(model, (input_tensor,)) + torch._inductor.aoti_compile_and_package( + exported, + package_path="model.pt2", + ) + + compiled_model = torch._inductor.package.load_package("model.pt2") + out = compiled_model(input_tensor,) + assert torch.equal(expected, out) From f174bd6f464eca35139de1402c3c885a6db5f123 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Wed, 3 Sep 2025 21:24:36 +0000 Subject: [PATCH 230/258] format code --- hopper/flash_attn_interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 4896a08e626..6ec8b260569 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -260,7 +260,7 @@ def _flash_attn_backward( window_size_left: int = -1, window_size_right: int = -1, softcap: float = 0.0, - deterministic: bool= False, + deterministic: bool = False, sm_margin: int = 0, ) -> torch.Tensor: # dq, dk, dv are allocated by us so they should already be contiguous From 6fe1c8c728d7e7e377ad6a7b47f49fc20037f692 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 19 Sep 2025 18:25:31 +0000 Subject: [PATCH 231/258] update flash_api_stable.cpp --- hopper/flash_api_stable.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/hopper/flash_api_stable.cpp b/hopper/flash_api_stable.cpp index 6de5c5ac380..15f0254e204 100644 --- a/hopper/flash_api_stable.cpp +++ b/hopper/flash_api_stable.cpp @@ -1335,7 +1335,7 @@ void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { // h: num_heads // h_k: num_heads_k // d: head_size -std::tuple mha_bwd( +std::tuple mha_bwd( Tensor dout, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q Tensor k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k @@ -1641,7 +1641,7 @@ std::tuple mha_b torch::stable::zero_(softmax_d); } - return { dq, dk, dv, softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum }; + return { softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum }; } std::tuple @@ -1949,7 +1949,7 @@ STABLE_TORCH_LIBRARY(flash_attn_3, m) { "int window_size_right = -1," "float softcap = 0.0," "bool deterministic = False," - "int sm_margin = 0) -> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)"); + "int sm_margin = 0) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"); m.def("fwd_combine(" "Tensor out_partial," "Tensor lse_partial," From b555ac7137aaf4e40075f1dd89a3a103d4ed1c72 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Mon, 13 Oct 2025 14:23:05 +0000 Subject: [PATCH 232/258] Fix flash_api_stable.cpp build --- .gitignore | 4 +++- hopper/flash_api_stable.cpp | 8 ++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 39b997512e4..dc508654045 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,6 @@ *.ncu-rep .DS_store .vscode -flash_attn_config.py # Byte-compiled / optimized / DLL files __pycache__/ @@ -32,3 +31,6 @@ var/ # Dev venv + +# compile-time generated file +flash_attn_config.py \ No newline at end of file diff --git a/hopper/flash_api_stable.cpp b/hopper/flash_api_stable.cpp index 15f0254e204..66e6fe78192 100644 --- a/hopper/flash_api_stable.cpp +++ b/hopper/flash_api_stable.cpp @@ -1828,11 +1828,11 @@ void boxed_mha_bwd( auto deterministic = to(stack[20]); auto sm_margin = to(stack[21]); - auto [dq_, dk_, dv_, softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum] = mha_bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, max_seqlen_q, max_seqlen_k, softmax_scale, is_causal, window_size_left, window_size_right, softcap, deterministic, sm_margin); + auto [softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum] = mha_bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, max_seqlen_q, max_seqlen_k, softmax_scale, is_causal, window_size_left, window_size_right, softcap, deterministic, sm_margin); - stack[0] = from(dq_); - stack[1] = from(dk_); - stack[2] = from(dv_); + stack[0] = from(dq); + stack[1] = from(dk); + stack[2] = from(dv); stack[3] = from(softmax_d); stack[4] = from(softmax_lse_log2); stack[5] = from(dq_accum); From 0aa4fa10ae9079be6d92c14a8d6247edffefdeb8 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Mon, 13 Oct 2025 17:36:29 +0000 Subject: [PATCH 233/258] Only run schema_check if dtype is not float8_e4m3fn --- hopper/test_flash_attn.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 4f81dcb1df6..042c6d440c9 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -75,9 +75,17 @@ def should_test_backward(args, kwargs): return False +def should_run_schema_check(args, kwargs): + v = args[2] + if v.dtype == torch.float8_e4m3fn: + return False + return True + + def run_opcheck(fn): def wrapper(*args, **kwargs): - safe_schema_check(fn, args, kwargs) + if should_run_schema_check(args, kwargs): + safe_schema_check(fn, args, kwargs) safe_fake_check(fn, args, kwargs) if should_test_backward(args, kwargs): From 47d7137ba3e5b5e6bdf7bf5cdff667938d0a0ef0 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Mon, 13 Oct 2025 15:13:29 -0300 Subject: [PATCH 234/258] Correctly compute kBlockM for sm88/86/80 --- hopper/flash_attn_interface.py | 17 +++++++++++------ hopper/test_flash_attn.py | 6 +++--- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 6ec8b260569..d985eae51a6 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -354,9 +354,6 @@ def _flash_attn_backward_fake( is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal - if arch < 90: - raise ValueError(f"Only cuda compute capabilities 9.0 or newer are supported. Got {arch=}") - if head_size_rounded <= 64: kBlockM_sm90 = 96 if (is_causal and softcap > 0.0) else 128 elif head_size_rounded <= 96: @@ -366,7 +363,15 @@ def _flash_attn_backward_fake( else: kBlockM_sm90 = 64 - kBlockM = kBlockM_sm90 + kBlockM_sm80 = 128 if head_size_rounded <= 64 else 64 + kBlockM_sm86 = 64 if head_size_rounded <= 192 else 32 + + if arch >= 90: + kBlockM = kBlockM_sm90 + elif arch == 86 or arch == 89: + kBlockM = kBlockM_sm86 + else: + kBlockM = kBlockM_sm80 num_heads = q.shape[-2] seqlen_q_rounded = round_multiple(seqlen_q, kBlockM) @@ -374,7 +379,7 @@ def _flash_attn_backward_fake( total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM) dq = torch.empty_like(q) if dq is None else dq - dk = torch.empty_like(k) if dk is None else dk + dk = torch.empty_like(k) if dk is None else dk dv = torch.empty_like(v) if dv is None else dv if not is_varlen: @@ -396,7 +401,7 @@ def setup_context(ctx, inputs, output): ctx.softcap = inputs[-6] ctx.sm_margin = inputs[-1] - + def _backward(ctx, dout, *grads): q, k, v, out, softmax_lse = ctx.saved_tensors dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 042c6d440c9..9aef059f2d0 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -64,9 +64,9 @@ def should_test_backward(args, kwargs): if ( ENABLE_AUTOGRAD_CHECK - and not DISABLE_BACKWARD - and dtype != torch.float8_e4m3fn - and not V_colmajor + and not DISABLE_BACKWARD + and dtype != torch.float8_e4m3fn + and not V_colmajor and not has_qv and not dv > 256 and not attention_chunk != 0 From 49fb7752e75bd874d80f5a93813a6e24cf7e0ea5 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Mon, 13 Oct 2025 19:32:32 +0000 Subject: [PATCH 235/258] Fix bug in boxed_mha_bwd --- hopper/flash_api_stable.cpp | 13 +++++-------- hopper/test_flash_attn.py | 10 +++++++++- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/hopper/flash_api_stable.cpp b/hopper/flash_api_stable.cpp index 66e6fe78192..5ae58bdd129 100644 --- a/hopper/flash_api_stable.cpp +++ b/hopper/flash_api_stable.cpp @@ -1830,14 +1830,11 @@ void boxed_mha_bwd( auto [softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum] = mha_bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, max_seqlen_q, max_seqlen_k, softmax_scale, is_causal, window_size_left, window_size_right, softcap, deterministic, sm_margin); - stack[0] = from(dq); - stack[1] = from(dk); - stack[2] = from(dv); - stack[3] = from(softmax_d); - stack[4] = from(softmax_lse_log2); - stack[5] = from(dq_accum); - stack[6] = from(dk_accum); - stack[7] = from(dv_accum); + stack[0] = from(softmax_d); + stack[1] = from(softmax_lse_log2); + stack[2] = from(dq_accum); + stack[3] = from(dk_accum); + stack[4] = from(dv_accum); } void boxed_mha_combine( diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 9aef059f2d0..8cfa30c08ae 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -82,11 +82,19 @@ def should_run_schema_check(args, kwargs): return True +def should_run_fake_check(args, kwargs): + if 'num_splits' in kwargs: + return kwargs['num_splits'] > 0 + return True + + def run_opcheck(fn): def wrapper(*args, **kwargs): if should_run_schema_check(args, kwargs): safe_schema_check(fn, args, kwargs) - safe_fake_check(fn, args, kwargs) + + if should_run_fake_check(args, kwargs): + safe_fake_check(fn, args, kwargs) if should_test_backward(args, kwargs): # Expensive check From 65dd5806228447dee7053ea628d56ba3285c7051 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Wed, 12 Nov 2025 21:32:37 +0000 Subject: [PATCH 236/258] don't run autograd_check when num_splits > 0 --- hopper/test_flash_attn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 8cfa30c08ae..78a8e7c2cc4 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -57,6 +57,7 @@ def should_test_backward(args, kwargs): v = args[2] + num_splits = kwargs.get("num_splits", 1) dtype = v.dtype has_qv = V_colmajor = False # no test runs this with V_colmajor or has_qv == True attention_chunk = kwargs.get("attention_chunk") @@ -70,6 +71,7 @@ def should_test_backward(args, kwargs): and not has_qv and not dv > 256 and not attention_chunk != 0 + and num_splits > 0 # we don't support num_split == 0 on torch.compile yet ): return True return False From b4555bfc3244a7607ea499158d3ef0b3a9ea2860 Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Mon, 17 Nov 2025 16:58:03 -0800 Subject: [PATCH 237/258] [Cute] Add block-sparsity support to SM100 (#1985) - Implement block-sparse attention in flash_fwd_sm100.py - Update interface.py to handle SM100 block size calculations (2x multiplier for m_block_size since 1 CTA handles 2*tile_m rows) - Add mask_mod parameter support in mask.py for block-sparse masking - Add SM100 test fixtures and tile size handling in test_mask_mod.py This enables block-sparsity on SM 10.0 architecture, including mask_mod support and proper block size accounting. --- flash_attn/cute/block_sparse_utils.py | 381 ++++++++++++++++++- flash_attn/cute/compute_block_sparsity.py | 11 +- flash_attn/cute/flash_bwd_sm100.py | 52 ++- flash_attn/cute/flash_fwd_sm100.py | 438 ++++++++++++++-------- flash_attn/cute/interface.py | 32 +- flash_attn/cute/mask.py | 36 +- tests/cute/test_mask_mod.py | 71 +++- 7 files changed, 819 insertions(+), 202 deletions(-) diff --git a/flash_attn/cute/block_sparse_utils.py b/flash_attn/cute/block_sparse_utils.py index d1cb95e18ed..f117498fd2c 100644 --- a/flash_attn/cute/block_sparse_utils.py +++ b/flash_attn/cute/block_sparse_utils.py @@ -7,12 +7,14 @@ from typing import Callable from functools import partial +import math import cutlass import cutlass.cute as cute -from cutlass import const_expr +from cutlass import Float32, Int32, const_expr # Import data structures from block_sparsity from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute import utils @cute.jit @@ -143,8 +145,13 @@ def produce_block_sparse_loads( curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block] curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None] - curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] - curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None] + + if const_expr(full_block_cnt is not None): + curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] + curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None] + else: + curr_full_block_cnt = Int32(0) + curr_full_block_idx = None mask_empty = curr_mask_block_cnt == 0 full_empty = curr_full_block_cnt == 0 @@ -417,3 +424,371 @@ def consume_block_sparse_loads( O_should_accumulate = True return kv_consumer_state, O_should_accumulate, processed_any + + +@cute.jit +def load_block_list_sm100( + block_indices: cute.Tensor, + block_count, + load_q_with_first: cutlass.Constexpr, + m_block, + q_stage: cutlass.Constexpr, + kv_producer_state, + load_Q, + load_K, + load_V, + pipeline_kv, +): + """SM100 version of load_block_list (no intra_wg_overlap, no extra_tx_count).""" + if block_count > 0: + # First iteration: load Q alongside K if requested + n_block_first = block_indices[block_count - 1] + + if const_expr(load_q_with_first): + # SM100 loads Q0 and optionally Q1 + load_Q(block=q_stage * m_block + 0, stage=0) + if const_expr(q_stage == 2): + load_Q(block=q_stage * m_block + 1, stage=1) + + # SM100 doesn't use producer_acquire for pipeline_kv in load path + # The pipeline barriers are handled inside load_KV + load_K(block=n_block_first, producer_state=kv_producer_state, page_idx=None) + kv_producer_state.advance() + load_V(block=n_block_first, producer_state=kv_producer_state, page_idx=None) + kv_producer_state.advance() + + # Remaining blocks + for offset in cutlass.range(1, block_count): + n_block = block_indices[block_count - 1 - offset] + load_K(block=n_block, producer_state=kv_producer_state, page_idx=None) + kv_producer_state.advance() + load_V(block=n_block, producer_state=kv_producer_state, page_idx=None) + kv_producer_state.advance() + + return kv_producer_state + + +# SM100-specific tile processor using SM100 helpers +@cute.jit +def produce_block_sparse_loads_sm100( + blocksparse_tensors: BlockSparseTensors, + batch_idx, + head_idx, + m_block, + kv_producer_state, + load_Q, + load_K, + load_V, + pipeline_kv, + q_stage: cutlass.Constexpr, + q_producer_phase: Int32, +): + """SM100 entry point for sparse block iteration. + + SM100 uses PipelineTmaUmma which doesn't support extra_tx_count, so we use + simplified block processing that just calls producer_acquire without extras. + """ + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + + curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block] + curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None] + + if const_expr(full_block_cnt is not None): + curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] + curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None] + else: + curr_full_block_cnt = Int32(0) + curr_full_block_idx = None + + mask_empty = curr_mask_block_cnt == 0 + full_empty = curr_full_block_cnt == 0 + + q_phase_flipped = False + + if mask_empty: + # No masked blocks: process full list with Q loading + kv_producer_state = load_block_list_sm100( + curr_full_block_idx, + curr_full_block_cnt, + load_q_with_first=True, + m_block=m_block, + q_stage=q_stage, + kv_producer_state=kv_producer_state, + load_Q=load_Q, + load_K=load_K, + load_V=load_V, + pipeline_kv=pipeline_kv, + ) + q_phase_flipped = not full_empty + else: + # Process masked blocks with Q loading + kv_producer_state = load_block_list_sm100( + curr_mask_block_idx, + curr_mask_block_cnt, + load_q_with_first=True, + m_block=m_block, + q_stage=q_stage, + kv_producer_state=kv_producer_state, + load_Q=load_Q, + load_K=load_K, + load_V=load_V, + pipeline_kv=pipeline_kv, + ) + q_phase_flipped = True + + if not full_empty: + # Process full blocks without Q loading + kv_producer_state = load_block_list_sm100( + curr_full_block_idx, + curr_full_block_cnt, + load_q_with_first=False, + m_block=m_block, + q_stage=q_stage, + kv_producer_state=kv_producer_state, + load_Q=load_Q, + load_K=load_K, + load_V=load_V, + pipeline_kv=pipeline_kv, + ) + + if q_phase_flipped: + q_producer_phase ^= 1 + + return kv_producer_state, q_producer_phase + + +@cute.jit +def get_total_block_count( + blocksparse_tensors: BlockSparseTensors, + batch_idx, + head_idx, + m_block, +): + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + if const_expr(full_block_cnt is not None): + return ( + mask_block_cnt[batch_idx, head_idx, m_block] + + full_block_cnt[batch_idx, head_idx, m_block] + ) + else: + return mask_block_cnt[batch_idx, head_idx, m_block] + + +@cute.jit +def handle_block_sparse_empty_tile_correction_sm100( + tidx: Int32, + q_stage: cutlass.Constexpr, + m_block_size: cutlass.Constexpr, + qhead_per_kvhead, + pack_gqa: cutlass.Constexpr, + is_split_kv: cutlass.Constexpr, + learnable_sink, + mLSE, + seqlen, + m_block: Int32, + head_idx: Int32, + batch_idx: Int32, + split_idx: Int32, + sScale: cute.Tensor, + stats: list, + correction_epilogue: Callable, + thr_mma_pv: cute.core.ThrMma, + tOtOs: tuple[cute.Tensor], + sO: cute.Tensor, + mbar_ptr, + mbar_softmax_corr_full_offset: Int32, + mbar_softmax_corr_empty_offset: Int32, + mbar_P_full_O_rescaled_offset: Int32, + mbar_P_full_2_offset: Int32, + mbar_corr_epi_full_offset: Int32, + mbar_corr_epi_empty_offset: Int32, + softmax_corr_consumer_phase: Int32, + o_corr_consumer_phase: Int32, + corr_epi_producer_phase: Int32, + softmax_scale_log2: Float32, +): + """Handle the block-sparse case where a tile is fully masked: + * zero staged results + * seed stats + * satisfy the usual barrier protocol so downstream warps continue to make progress. + """ + LOG2_E = Float32(math.log2(math.e)) + + for stage in cutlass.range_constexpr(q_stage): + row_sum_value = Float32(1.0) + row_max_value = ( + -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None + ) + if const_expr(learnable_sink is not None): + sink_val = -Float32.inf + if const_expr(not pack_gqa): + sink_val = Float32(learnable_sink[head_idx]) + elif tidx < m_block_size: + q_head_idx = ( + (q_stage * m_block + stage) * m_block_size + tidx + ) % qhead_per_kvhead + head_idx * qhead_per_kvhead + sink_val = Float32(learnable_sink[q_head_idx]) + if sink_val != -Float32.inf and (const_expr(not is_split_kv) or split_idx == 0): + if row_max_value == -Float32.inf: + row_max_value = sink_val * (LOG2_E / softmax_scale_log2) + row_sum_value = Float32(1.0) + else: + row_sum_value = row_sum_value + utils.exp2f( + sink_val * LOG2_E - row_max_value * softmax_scale_log2 + ) + if tidx < m_block_size: + scale_row_idx = tidx + stage * m_block_size + sScale[scale_row_idx] = row_sum_value + if const_expr(mLSE is not None or learnable_sink is not None): + sScale[scale_row_idx + m_block_size * 2] = row_max_value + acc_flag = row_sum_value == Float32(0.0) or row_sum_value != row_sum_value + stats[stage] = (row_sum_value, row_max_value, acc_flag) + + cute.arch.mbarrier_wait( + mbar_ptr + mbar_softmax_corr_full_offset + stage, + softmax_corr_consumer_phase, + ) + cute.arch.mbarrier_arrive(mbar_ptr + mbar_softmax_corr_empty_offset + stage) + + cute.arch.mbarrier_wait( + mbar_ptr + mbar_corr_epi_empty_offset + stage, + corr_epi_producer_phase, + ) + correction_epilogue( + thr_mma_pv, + tOtOs[stage], + tidx, + Float32(0.0), # zero scale ensures empty tile writes zeros into staged outputs + sO[None, None, stage], + ) + cute.arch.mbarrier_arrive(mbar_ptr + mbar_corr_epi_full_offset + stage) + cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_O_rescaled_offset + stage) + cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_2_offset + stage) + + softmax_corr_consumer_phase ^= 1 + o_corr_consumer_phase ^= 1 + corr_epi_producer_phase ^= 1 + + return ( + softmax_corr_consumer_phase, + o_corr_consumer_phase, + corr_epi_producer_phase, + ) + + +@cute.jit +def softmax_block_sparse_sm100( + blocksparse_tensors: BlockSparseTensors, + batch_idx, + head_idx, + m_block, + softmax_step: Callable, + mask_fn: Callable, + mask_fn_none: Callable, + mma_si_consumer_phase: Int32, + si_corr_producer_phase: Int32, + s0_s1_sequence_phase: Int32, + mbar_ptr, + mbar_softmax_corr_full_offset: Int32, + mbar_softmax_corr_empty_offset: Int32, + mbar_P_full_O_rescaled_offset: Int32, + mbar_P_full_2_offset: Int32, + q_stage: cutlass.Constexpr, + stage_idx: Int32, +): + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + + curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block] + curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None] + + if const_expr(full_block_cnt is not None): + curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] + curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None] + else: + curr_full_block_cnt = Int32(0) + curr_full_block_idx = None + + total_block_cnt = curr_mask_block_cnt + curr_full_block_cnt + + if total_block_cnt == 0: + cute.arch.mbarrier_arrive(mbar_ptr + mbar_softmax_corr_full_offset + stage_idx) + cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_O_rescaled_offset + stage_idx) + cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_2_offset + stage_idx) + cute.arch.mbarrier_arrive(mbar_ptr + mbar_softmax_corr_empty_offset + stage_idx) + else: + if curr_mask_block_cnt > 0: + mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1] + ( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + ) = softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + mask_n_block, + is_first=True, + mask_fn=partial(mask_fn, mask_seqlen=True), # last block could oob + ) + for i in cutlass.range(1, curr_mask_block_cnt): + mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i] + ( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + ) = softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + mask_n_block, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) + + if curr_full_block_cnt > 0: + full_n_block = curr_full_block_idx[curr_full_block_cnt - 1] + if curr_mask_block_cnt == 0: + ( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + ) = softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + full_n_block, + is_first=True, + mask_fn=partial(mask_fn_none, mask_seqlen=True), + ) + else: + ( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + ) = softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + full_n_block, + is_first=False, + mask_fn=partial(mask_fn_none, mask_seqlen=False), + ) + for i in cutlass.range(1, curr_full_block_cnt): + full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i] + ( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + ) = softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + full_n_block, + mask_fn=partial(mask_fn_none, mask_seqlen=False), + ) + + return ( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + total_block_cnt == 0, + ) diff --git a/flash_attn/cute/compute_block_sparsity.py b/flash_attn/cute/compute_block_sparsity.py index bec6fe5701f..acaeac794c5 100644 --- a/flash_attn/cute/compute_block_sparsity.py +++ b/flash_attn/cute/compute_block_sparsity.py @@ -1,11 +1,8 @@ from functools import partial -import math -import operator -from typing import Callable, Optional, Tuple, Type +from typing import Callable, Optional, Tuple -import cuda.bindings.driver as cuda import cutlass -from cutlass import Boolean, Constexpr, Float32, Int32, Int8, const_expr +from cutlass import Boolean, Int32, Int8, const_expr import cutlass.cute as cute from cutlass.cute.runtime import from_dlpack import torch @@ -276,11 +273,11 @@ def compute_block_sparsity( batch_size: The batch size. num_heads: The number of heads. seqlen_q: The sequence length for the query. - seqlen_k: The sequence length for the key. + seqlen_k: The sequence length for the key. mask_mod: The `mask_mod` callable to use. aux_tensors: A list of auxiliary tensors. device: The device to use. - compute_full_blocks: Whether to compute full blocks. If False, only partially-masked blocks are computed. + compute_full_blocks: Whether to compute full blocks. If False, only partially-masked blocks are computed. use_fast_sampling: Whether to use 5-point sampling (4 corners + center). This is much faster, but only suitable for masks where this check is sufficient. Returns: diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 3b9aa00cb33..0a29ce462a8 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -315,7 +315,7 @@ def _setup_smem_layout(self): ) self.sdKV_epi_tile = ( self.tile_n, - 128 // (self.dk_dtype.width // 8), # 64 or 32 + 128 // (self.dk_dtype.width // 8), # 64 or 32 ) # subtiles mma_tiler_dsq[:2] = mma_tiler_pdo[:2] self.num_epi_stages = (self.tile_hdim // 2) // self.sdKV_epi_tile[1] self.sdKV_flat_epi_tile = self.tile_n * (self.tile_hdim // 2) // self.num_epi_stages @@ -326,12 +326,10 @@ def _setup_smem_layout(self): self.dk_dtype, LayoutEnum.ROW_MAJOR, self.sdKV_epi_tile, - 2, # num compute wgs + 2, # num compute wgs ) else: - self.sdKV_layout = cute.make_layout( - (self.tile_n * self.dK_reduce_ncol, 2) - ) + self.sdKV_layout = cute.make_layout((self.tile_n * self.dK_reduce_ncol, 2)) @cute.jit def __call__( @@ -389,9 +387,7 @@ def __call__( ] layout_transpose = [1, 3, 2, 0] # (b, s, n, h) --> (s, h, n, b) - mQ, mK, mV, mdO = [ - utils.select(t, mode=layout_transpose) for t in (mQ, mK, mV, mdO) - ] + mQ, mK, mV, mdO = [utils.select(t, mode=layout_transpose) for t in (mQ, mK, mV, mdO)] LSE_dPsum_dQaccum_transpose = [2, 1, 0] # (b, n, s) --> (s, n, b) mLSE, mdPsum, mdQaccum = [ utils.select(t, mode=LSE_dPsum_dQaccum_transpose) for t in (mLSE, mdPsum, mdQaccum) @@ -400,10 +396,8 @@ def __call__( layout_dKV_transpose = layout_transpose else: layout_dKV_transpose = LSE_dPsum_dQaccum_transpose - mdK, mdV = [ - utils.select(t, mode=layout_dKV_transpose) for t in (mdK, mdV) - ] - dO_transpose = [1, 0, 2, 3] # (s, h, n, b) --> (h, s, n, h) + mdK, mdV = [utils.select(t, mode=layout_dKV_transpose) for t in (mdK, mdV)] + dO_transpose = [1, 0, 2, 3] # (s, h, n, b) --> (h, s, n, h) mdO = utils.select(mdO, mode=dO_transpose) semaphore_transpose = [2, 3, 1, 0] # (b, n, block, stage) -> (block, stage, n, b) @@ -451,7 +445,7 @@ def __call__( raise RuntimeError("The layout of mdK is wrong") if const_expr(dV_major_mode != tcgen05.OperandMajorMode.K): raise RuntimeError("The layout of mdV is wrong") - + if const_expr(self.use_tma_store and self.qhead_per_kvhead == 1): tma_copy_op_dKV = cpasync.CopyBulkTensorTileS2GOp() tma_atom_dK, mdK_tma_tensor = cpasync.make_tiled_tma_atom( @@ -2253,32 +2247,32 @@ def epilogue_dK_or_dV_tma( if const_expr(self.qhead_per_kvhead == 1): sdKV = sdKV[None, None, wg_idx] # (tile_n, 64) for bf16 else: - sdKV = sdKV[None, wg_idx] # (tile_n * 32) for fp32 - + sdKV = sdKV[None, wg_idx] # (tile_n * 32) for fp32 + # (8, tile_n / 128, 64 / 8) = (8, 1, 8) or (4, tile_n * 32 / (128 * 4)) = (4, 8) tdKVsdKV_r2s = thr_copy_r2s_dKV.partition_D(sdKV) head_idx_kv = head_idx // self.qhead_per_kvhead if const_expr(self.qhead_per_kvhead == 1): - mdKV_cur = mdKV[None, None, head_idx_kv, batch_idx] # (seqlen, hdim) + mdKV_cur = mdKV[None, None, head_idx_kv, batch_idx] # (seqlen, hdim) gdKV_p = cute.local_tile( mdKV_cur, (self.tile_n, self.tile_hdim), (n_block, 0) - ) # (tile_n, hdim) - gdKV = self.split_wg(gdKV_p, wg_idx, num_wg) # (tile_n, hdim / 2) + ) # (tile_n, hdim) + gdKV = self.split_wg(gdKV_p, wg_idx, num_wg) # (tile_n, hdim / 2) gdKV_epi = cute.local_tile( gdKV, self.sdKV_epi_tile, (0, None) - ) # (tile_n, 64, epi_stage = (hdim / 2) / 64) + ) # (tile_n, 64, epi_stage = (hdim / 2) / 64) else: - mdKV_cur = mdKV[None, head_idx_kv, batch_idx] # (seqlen * hdim) + mdKV_cur = mdKV[None, head_idx_kv, batch_idx] # (seqlen * hdim) gdKV_p = cute.local_tile( - mdKV_cur, (self.tile_n * self.tile_hdim, ), (n_block, ) - ) # (tile_n * hdim) - gdKV = cute.logical_divide( - gdKV_p, (self.tile_n * self.tile_hdim // num_wg, ) - )[((None, wg_idx), )] # (tile_n * hdim / 2) + mdKV_cur, (self.tile_n * self.tile_hdim,), (n_block,) + ) # (tile_n * hdim) + gdKV = cute.logical_divide(gdKV_p, (self.tile_n * self.tile_hdim // num_wg,))[ + ((None, wg_idx),) + ] # (tile_n * hdim / 2) gdKV_epi = cute.flat_divide( - gdKV, (self.sdKV_flat_epi_tile, ) - ) # (tile_n * hdim / 2 / epi_stage, epi_stage) + gdKV, (self.sdKV_flat_epi_tile,) + ) # (tile_n * hdim / 2 / epi_stage, epi_stage) if const_expr(self.deterministic and self.qhead_per_kvhead > 1): mdKV_semaphore_cur = mdKV_semaphore[n_block, None, head_idx_kv, batch_idx] @@ -2290,7 +2284,7 @@ def epilogue_dK_or_dV_tma( cute.make_layout(1), cute.group_modes(sdKV, 0, 2), cute.group_modes(gdKV_epi, 0, 2), - ) # (TMA) and (TMA, EPI_STAGE) + ) # (TMA) and (TMA, EPI_STAGE) assert len(tdKVsdKV.shape) == 1, "Wrong rank for SMEM fragment tdKVsdKV" assert len(tdKVgdKV.shape) == 2, "Wrong rank for GMEM fragment tdKVgdKV" num_epi_stages = cute.size(tdKVgdKV.shape[1]) @@ -2344,7 +2338,7 @@ def epilogue_dK_or_dV_tma( tdKVrdKV_t2r[2 * i], tdKVrdKV_t2r[2 * i + 1] = utils.mul_packed_f32x2( (tdKVrdKV_t2r[2 * i], tdKVrdKV_t2r[2 * i + 1]), (scale, scale) ) - tdKVrdKV = cute.make_fragment(tdKVrdKV_t2r.shape, self.dv_dtype) # (32 columns) + tdKVrdKV = cute.make_fragment(tdKVrdKV_t2r.shape, self.dv_dtype) # (32 columns) tdKVrdKV.store(tdKVrdKV_t2r.load().to(self.dv_dtype)) # RMEM -> SMEM -- copy, fence and barrier diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 915315d461b..521e1325a8f 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -36,6 +36,12 @@ from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.block_sparse_utils import ( + get_total_block_count, + produce_block_sparse_loads_sm100, + softmax_block_sparse_sm100, + handle_block_sparse_empty_tile_correction_sm100, +) from flash_attn.cute.pack_gqa import PackGQA from flash_attn.cute import mma_sm100_desc as sm100_desc from flash_attn.cute import blackwell_helpers as sm100_utils @@ -76,6 +82,7 @@ def __init__( n_block_size: int = 128, is_persistent: bool = True, score_mod: cutlass.Constexpr | None = None, + mask_mod: cutlass.Constexpr | None = None, has_aux_tensors: cutlass.Constexpr = False, paged_kv_non_tma: bool = False, ): @@ -116,6 +123,7 @@ def __init__( "SplitKV is not supported for hdim >= 192" ) self.score_mod = score_mod + self.mask_mod = mask_mod if cutlass.const_expr(has_aux_tensors): self.vec_size: cutlass.Constexpr = 1 else: @@ -652,6 +660,10 @@ class SharedStorage: seqlen_k_divmod = FastDivmod.create(seqlen_k) fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) + self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None) + if cutlass.const_expr(self.use_block_sparsity and mPageTable is not None): + raise NotImplementedError("Block sparsity + paged KV not supported on SM100") + # Launch the kernel synchronously self.kernel( mQ, @@ -673,6 +685,7 @@ class SharedStorage: window_size_left, window_size_right, learnable_sink, + blocksparse_tensors, sQ_layout, sK_layout, tP_layout, @@ -717,6 +730,7 @@ def kernel( window_size_left: Optional[Int32], window_size_right: Optional[Int32], learnable_sink: Optional[cute.Tensor], + blocksparse_tensors: Optional[BlockSparseTensors], sQ_layout: cute.ComposedLayout, sK_layout: cute.ComposedLayout, tP_layout: cute.ComposedLayout, @@ -941,6 +955,7 @@ def kernel( num_splits, SeqlenInfoCls, TileSchedulerCls, + blocksparse_tensors, ) # /////////////////////////////////////////////////////////////////////////////// @@ -970,6 +985,7 @@ def kernel( num_splits, SeqlenInfoCls, TileSchedulerCls, + blocksparse_tensors, ) # if warp_idx == self.mma_warp_id: @@ -1024,6 +1040,7 @@ def kernel( TileSchedulerCls=TileSchedulerCls, aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods, + blocksparse_tensors=blocksparse_tensors, ) if const_expr(not self.s0_s1_barrier): @@ -1070,6 +1087,7 @@ def kernel( num_splits, SeqlenInfoCls, TileSchedulerCls, + blocksparse_tensors, ) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) @@ -1096,6 +1114,7 @@ def load( num_splits: Int32, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, + blocksparse_tensors: Optional[BlockSparseTensors], ): num_load_threads = len(self.load_warp_ids) * cute.arch.WARP_SIZE tidx = cute.arch.thread_idx()[0] % num_load_threads @@ -1207,40 +1226,58 @@ def load( K_or_V="V", ) - n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) - - if const_expr(not self.is_split_kv) or n_block_min < n_block_max: - if const_expr(self.use_tma_KV) or tidx < cute.arch.WARP_SIZE: - load_Q(block=self.q_stage * m_block + 0, stage=0) # Q0 - n_block_first = n_block_max - 1 if n_block_max > 0 else 0 - page_idx = ( - mPageTable[batch_idx, n_block_first] - if const_expr(mPageTable is not None and self.use_tma_KV) - else None + if const_expr(not self.use_block_sparsity): + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, m_block, split_idx, num_splits ) - if const_expr(not self.use_tma_KV): - paged_kv_manager.load_page_table(n_block_first) - load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # K0 - kv_producer_state.advance() - if const_expr(self.q_stage == 2) and (const_expr(self.use_tma_KV) or tidx < cute.arch.WARP_SIZE): - load_Q(block=self.q_stage * m_block + 1, stage=1) # Q1 - q_producer_phase ^= 1 - load_V(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # V0 - kv_producer_state.advance() - for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): - n_block = n_block_max - 2 - i + if const_expr(not self.is_split_kv) or n_block_min < n_block_max: + if const_expr(self.use_tma_KV) or tidx < cute.arch.WARP_SIZE: + load_Q(block=self.q_stage * m_block + 0, stage=0) # Q0 + n_block_first = n_block_max - 1 if n_block_max > 0 else 0 page_idx = ( - mPageTable[batch_idx, n_block] + mPageTable[batch_idx, n_block_first] if const_expr(mPageTable is not None and self.use_tma_KV) else None ) if const_expr(not self.use_tma_KV): - paged_kv_manager.load_page_table(n_block) - # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("n_block = {}, page_idx = {}", n_block, page_idx) - load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Ki + paged_kv_manager.load_page_table(n_block_first) + load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # K0 kv_producer_state.advance() - load_V(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Vi + if const_expr(self.q_stage == 2) and (const_expr(self.use_tma_KV) or tidx < cute.arch.WARP_SIZE): + load_Q(block=self.q_stage * m_block + 1, stage=1) # Q1 + q_producer_phase ^= 1 + load_V(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # V0 kv_producer_state.advance() + for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): + n_block = n_block_max - 2 - i + page_idx = ( + mPageTable[batch_idx, n_block] + if const_expr(mPageTable is not None and self.use_tma_KV) + else None + ) + if const_expr(not self.use_tma_KV): + paged_kv_manager.load_page_table(n_block) + # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("n_block = {}, page_idx = {}", n_block, page_idx) + load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Ki + kv_producer_state.advance() + load_V(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Vi + kv_producer_state.advance() + + else: + kv_producer_state, q_producer_phase = produce_block_sparse_loads_sm100( + blocksparse_tensors, + batch_idx, + head_idx, + m_block, + kv_producer_state, + load_Q, + load_K, + load_V, + pipeline_kv, + self.q_stage, + q_producer_phase, + ) + tile_scheduler.prefetch_next_work() tile_scheduler.advance_to_next_work() @@ -1264,6 +1301,7 @@ def mma( num_splits: Int32, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, + blocksparse_tensors: Optional[BlockSparseTensors], ): tSrQ = tiled_mma_qk.make_fragment_A(sQ) tSrK = tiled_mma_qk.make_fragment_B(sK) @@ -1308,15 +1346,28 @@ def mma( while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) - n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) - if const_expr(not self.is_split_kv) or n_block_min < n_block_max: + block_iter_count = Int32(0) + process_tile = False + + if const_expr(self.use_block_sparsity): + block_iter_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block) + process_tile = block_iter_count > Int32(0) + else: + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) + block_iter_count = n_block_max - n_block_min + if const_expr(not self.is_split_kv): + process_tile = True + else: + process_tile = n_block_min < n_block_max + + if process_tile: for stage in cutlass.range_constexpr(self.q_stage): # GEMM_QK00 (Q0 * K0 -> S0) or GEMM_QK01 (Q1 * K0 -> S1) # 1. wait for Q0 / Q1 cute.arch.mbarrier_wait( - mbar_ptr + self.mbar_load_q_full_offset + stage, mma_q_consumer_phase - ) + mbar_ptr + self.mbar_load_q_full_offset + stage, mma_q_consumer_phase + ) # 2. wait for K0 if const_expr(stage == 0): pipeline_kv.consumer_wait(mma_kv_consumer_state) @@ -1345,8 +1396,9 @@ def mma( # so we need to release them after the seqlen_kv loop # O hasn't been accumulated yet, its first MMA calculation doesn't need to accumulate + block_loop_count = block_iter_count - 1 O_should_accumulate = False - for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): + for i in cutlass.range(block_loop_count, unroll=1): # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop # 1. wait for V0 pipeline_kv.consumer_wait(mma_kv_consumer_state) @@ -1444,7 +1496,7 @@ def mma( ) # 4. release accumulated O0_partial # We do need O_full here since for the last tile, by the time the softmax warp - # has signaled to the correction warp, the softmax warp has just finished compute + # has signaled to the correction warps, the softmax warp has just finished compute # the row sum of the current tile. It does not guarantee that the 1st tile # of the next work tile has been computed yet. with cute.arch.elect_one(): @@ -1461,6 +1513,7 @@ def mma( work_tile = tile_scheduler.get_current_work() # End of persistent scheduler loop + # for both softmax0 and softmax1 warp group @cute.jit def softmax_loop( @@ -1481,6 +1534,7 @@ def softmax_loop( TileSchedulerCls: Callable, aux_tensors: Optional[list] = None, fastdiv_mods=(None, None), + blocksparse_tensors: Optional[BlockSparseTensors] = None, ): """Compute softmax on attention scores from QK matrix multiplication. @@ -1548,115 +1602,173 @@ def softmax_loop( seqlen = SeqlenInfoCls(batch_idx) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) - if const_expr(not self.is_split_kv) or n_block_min < n_block_max: - mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) - mask_fn = partial( + mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) + shared_mask_kwargs = dict( + m_block=self.q_stage * m_block + stage, + thr_mma=thr_mma_qk, + thr_tmem_load=thr_tmem_load, + mask_causal=self.is_causal, + mask_local=self.is_local, + batch_idx=batch_idx, + head_idx=head_idx, + aux_tensors=aux_tensors, + ) + block_mask_mod = self.mask_mod if const_expr(self.use_block_sparsity) else None + mask_fn = partial( + mask.apply_mask_sm100, + mask_mod=block_mask_mod, + **shared_mask_kwargs, + ) + if const_expr(self.use_block_sparsity): + # Full blocks dont need mask_mod + mask_fn_none = partial( mask.apply_mask_sm100, - m_block=self.q_stage * m_block + stage, - thr_mma=thr_mma_qk, - thr_tmem_load=thr_tmem_load, - mask_causal=self.is_causal, - mask_local=self.is_local, - ) - softmax = SoftmaxSm100.create( - softmax_scale_log2, - rescale_threshold=8.0 if const_expr(self.q_dtype.width == 16) else 0.0, - softmax_scale=softmax_scale, - ) - softmax.reset() - - softmax_step = partial( - self.softmax_step, - softmax=softmax, - mbar_ptr=mbar_ptr, - mbar_s0_s1_sequence_offset=mbar_s0_s1_sequence_offset, - thr_mma_qk=thr_mma_qk, - thr_tmem_load=thr_tmem_load, - thr_tmem_store=thr_tmem_store, - thr_tmem_store_scale=thr_tmem_store_scale, - tStS_t2r=tStS_t2r, - tStScale_r2t=tStScale_r2t, - tStP_r2t=tStP_r2t, - sScale=sScale, - stage=stage, - batch_idx=batch_idx, - head_idx=head_idx, - m_block=self.q_stage * m_block + stage, - seqlen=seqlen, - aux_tensors=aux_tensors, - fastdiv_mods=fastdiv_mods, + mask_mod=None, + **shared_mask_kwargs, ) + else: + mask_fn_none = None + + softmax = SoftmaxSm100.create( + softmax_scale_log2, + rescale_threshold=8.0 if const_expr(self.q_dtype.width == 16) else 0.0, + softmax_scale=softmax_scale, + ) + softmax.reset() + + if const_expr(self.use_block_sparsity): + tile_block_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block) + has_work = tile_block_count > Int32(0) + else: + tile_block_count = n_block_max - n_block_min + has_work = const_expr(not self.is_split_kv) or tile_block_count > Int32(0) + + softmax_step = partial( + self.softmax_step, + softmax=softmax, + mbar_ptr=mbar_ptr, + mbar_s0_s1_sequence_offset=mbar_s0_s1_sequence_offset, + thr_mma_qk=thr_mma_qk, + thr_tmem_load=thr_tmem_load, + thr_tmem_store=thr_tmem_store, + thr_tmem_store_scale=thr_tmem_store_scale, + tStS_t2r=tStS_t2r, + tStScale_r2t=tStScale_r2t, + tStP_r2t=tStP_r2t, + sScale=sScale, + stage=stage, + batch_idx=batch_idx, + head_idx=head_idx, + m_block=self.q_stage * m_block + stage, + seqlen=seqlen, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) + if has_work: + # Softmax acts as the producer: wait until correction signals the stage is empty cute.arch.mbarrier_wait( mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase ) si_corr_producer_phase ^= 1 - # 1 masking iter - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step( + + # Block sparse or dense iteration + if const_expr(self.use_block_sparsity): + ( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + empty_tile, + ) = softmax_block_sparse_sm100( + blocksparse_tensors, + batch_idx, + head_idx, + m_block, + softmax_step, + mask_fn, + mask_fn_none, mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, - n_block_max - 1, - is_first=True, - mask_fn=partial(mask_fn, mask_seqlen=True), + mbar_ptr, + self.mbar_softmax_corr_full_offset, + self.mbar_softmax_corr_empty_offset, + self.mbar_P_full_O_rescaled_offset, + self.mbar_P_full_2_offset, + self.q_stage, + Int32(stage), ) - n_block_max -= 1 - # Next couple of iterations with causal masking - if const_expr(self.is_causal or self.is_local): - n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( - seqlen, m_block, n_block_min + if not empty_tile: + sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0] + if const_expr(mLSE is not None or learnable_sink is not None): + sScale[ + tidx + stage * self.m_block_size + self.m_block_size * 2 + ] = softmax.row_max[0] + # if tidx == 0: + # cute.printf("softmax row sum stage %d: %f, row_max = %f\n", stage, softmax.row_sum[0], softmax.row_max[0]) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage) + # if tidx == 0: cute.printf("softmax row sum stage %d: %f\n", stage, softmax.row_sum[0]) + else: + if const_expr(not self.is_split_kv) or tile_block_count > Int32(0): + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + n_block_max - 1, + is_first=True, + mask_fn=partial(mask_fn, mask_seqlen=True), ) - for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1): - n_block = n_block_max - 1 - n_tile - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = ( - softmax_step( - mma_si_consumer_phase, - si_corr_producer_phase, - s0_s1_sequence_phase, - n_block, - mask_fn=partial(mask_fn, mask_seqlen=False), - ) + n_block_max -= 1 + # Next couple of iterations with causal masking + if const_expr(self.is_causal or self.is_local): + n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( + seqlen, m_block, n_block_min ) - n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) - # The remaining iterations have no masking - n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( - seqlen, m_block, n_block_min - ) - for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): - n_block = n_block_max - n_tile - 1 - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step( - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block - ) - # Separate iterations with local masking on the left - if const_expr(self.is_local and block_info.window_size_left is not None): - n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) - for n_tile in cutlass.range(0, n_block_max - n_block_min, unroll=1): - n_block = n_block_max - 1 - n_tile - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = ( - softmax_step( - mma_si_consumer_phase, - si_corr_producer_phase, - s0_s1_sequence_phase, - n_block, - mask_fn=partial(mask_fn, mask_seqlen=False), + for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1): + n_block = n_block_max - 1 - n_tile + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = ( + softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + n_block, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) ) - ) - # Now that we no longer already have the 1st iteration, need mask_seqlen=True here - - # tSrScale_r2t_shape = thr_tmem_store_scale.partition_S(tScScale).shape - # tSrScale_r2t = cute.make_fragment(tSrScale_r2t_shape, Float32) - # tSrScale_r2t[0] = softmax.row_sum[0] - # cute.copy(thr_tmem_store_scale, tSrScale_r2t, tStScale_r2t) - # cute.arch.fence_view_async_tmem_store() - sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0] - if const_expr(mLSE is not None or learnable_sink is not None): - sScale[tidx + stage * self.m_block_size + self.m_block_size * 2] = softmax.row_max[ - 0 - ] - # if tidx == 0: - # cute.printf("softmax row sum stage %d: %f, row_max = %f\n", stage, softmax.row_sum[0], softmax.row_max[0]) - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage) - # if tidx == 0: cute.printf("softmax row sum stage %d: %f\n", stage, softmax.row_sum[0]) + n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) + # The remaining iterations have no masking + n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( + seqlen, m_block, n_block_min + ) + for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): + n_block = n_block_max - n_tile - 1 + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step( + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block + ) + # Separate iterations with local masking on the left + if const_expr(self.is_local and block_info.window_size_left is not None): + n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) + for n_tile in cutlass.range(0, n_block_max - n_block_min, unroll=1): + n_block = n_block_max - 1 - n_tile + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = ( + softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + n_block, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) + ) + # Now that we no longer already have the 1st iteration, need mask_seqlen=True here + + # Dense path always writes scale / signals + sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0] + if const_expr(mLSE is not None or learnable_sink is not None): + sScale[ + tidx + stage * self.m_block_size + self.m_block_size * 2 + ] = softmax.row_max[0] + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage) # # Write LSE to gmem # if const_expr(mLSE is not None): @@ -1826,6 +1938,7 @@ def correction_loop( num_splits: Int32, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, + blocksparse_tensors: Optional[BlockSparseTensors] = None, ): tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.correction_warp_ids)) tScS = thr_mma_qk.partition_C(cute.make_identity_tensor(self.mma_tiler_qk[:2])) @@ -1862,7 +1975,14 @@ def correction_loop( # Default LSE to -inf for invalid split_idx tiles stats = [(0.0, -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None, True)] * self.q_stage - if const_expr(not self.is_split_kv) or n_block_min < n_block_max: + if const_expr(self.use_block_sparsity): + total_block_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block) + has_work = total_block_count > Int32(0) + else: + total_block_count = n_block_max - n_block_min + has_work = const_expr(not self.is_split_kv) or total_block_count > Int32(0) + + if has_work: # Ignore first signal from softmax as no correction is required cute.arch.mbarrier_wait( mbar_ptr + self.mbar_softmax_corr_full_offset + 0, softmax_corr_consumer_phase @@ -1874,7 +1994,7 @@ def correction_loop( softmax_corr_consumer_phase ^= 1 tSrScale_t2r = cute.make_fragment(tSrScale_t2r_shape, Float32) - for i in cutlass.range(n_block_max - n_block_min - 1, unroll=1): + for i in cutlass.range(total_block_count - 1, unroll=1): for stage in cutlass.range_constexpr(2): # wait for S0 / S1 cute.arch.mbarrier_wait( @@ -1969,6 +2089,44 @@ def correction_loop( o_corr_consumer_phase ^= 1 softmax_corr_consumer_phase ^= 1 corr_epi_producer_phase ^= 1 + else: + if const_expr(self.use_block_sparsity): + ( + softmax_corr_consumer_phase, + o_corr_consumer_phase, + corr_epi_producer_phase, + ) = handle_block_sparse_empty_tile_correction_sm100( + tidx, + self.q_stage, + self.m_block_size, + self.qhead_per_kvhead, + self.pack_gqa, + self.is_split_kv, + learnable_sink, + mLSE, + seqlen, + m_block, + head_idx, + batch_idx, + split_idx, + sScale, + stats, + self.correction_epilogue, + thr_mma_pv, + tOtOs, + sO, + mbar_ptr, + self.mbar_softmax_corr_full_offset, + self.mbar_softmax_corr_empty_offset, + self.mbar_P_full_O_rescaled_offset, + self.mbar_P_full_2_offset, + self.mbar_corr_epi_full_offset, + self.mbar_corr_epi_empty_offset, + softmax_corr_consumer_phase, + o_corr_consumer_phase, + corr_epi_producer_phase, + softmax_scale_log2, + ) if const_expr(mLSE is not None): if const_expr(not seqlen.has_cu_seqlens_q): @@ -2006,28 +2164,6 @@ def correction_loop( # This actually just works with PackGQA too gLSE[tidx] = lse - # gO_qdhb = cute.local_tile(mO, cute.select(self.mma_tiler_pv, mode=[0, 1]), (None, 0, None, None)) - # gO = gO_qdhb[None, None, None, head_idx, batch_idx] - # tOsO, tOgO = cpasync.tma_partition( - # tma_atom_O, - # 0, - # cute.make_layout(1), - # cute.group_modes(sO, 0, 2), - # cute.group_modes(gO, 0, 2), - # ) - # warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 - # stage = warp_idx_in_wg - # if stage < self.q_stage: - # # wait from corr, issue tma store on smem - # # 1. wait for O0 / O1 final - # cute.arch.mbarrier_wait(mbar_ptr + self.mbar_corr_epi_full_offset + stage, corr_epi_producer_phase) - # # 2. copy O0 / O1 to gmem - # cute.copy(tma_atom_O, tOsO[None, stage], tOgO[None, self.q_stage * m_block + stage]) - # cute.arch.cp_async_bulk_commit_group() - # # Ensure O0 / O1 buffer is ready to be released - # cute.arch.cp_async_bulk_wait_group(0, read=True) - # cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) - # Advance to next tile tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index fb36bfd492b..db7930de537 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -259,11 +259,25 @@ def _flash_attn_fwd( if page_table is not None else None ) + compute_capability = ( + torch.cuda.get_device_capability()[0] + if _compute_capability is None + else _compute_capability + ) + + assert compute_capability in [9, 10], "Unsupported compute capability. Supported: 9.x, 10.x" + + sparse_tensors = None if block_sparse_tensors is not None: if seqlen_q is None: raise ValueError("Block sparsity requires fixed-length sequences (seqlen_q must be known).") - expected_m_blocks = (seqlen_q + m_block_size - 1) // m_block_size + m_block_size_block = m_block_size + if compute_capability == 10: + # TODO: This multiplier should really be q_stage, wire up in later PR + # 1 cta handles 2*tile_m row + m_block_size_block = 2 * m_block_size + expected_m_blocks = (seqlen_q + m_block_size_block - 1) // m_block_size_block expected_n_blocks = (seqlen_k + n_block_size - 1) // n_block_size block_sparse_tensors = normalize_block_sparse_tensors( block_sparse_tensors, @@ -286,12 +300,6 @@ def _flash_attn_fwd( else: causal, local = False, False - compute_capability = ( - torch.cuda.get_device_capability()[0] - if _compute_capability is None - else _compute_capability - ) - assert compute_capability in [9, 10], "Unsupported compute capability. Supported: 9.x, 10.x" current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) if compute_capability == 9: # TODO: tune block size according to hdim. @@ -383,6 +391,10 @@ def _flash_attn_fwd( raise NotImplementedError( "Block sparsity is not yet supported with pack_gqa=True. This will be fixed in a future PR." ) + if is_split_kv: + raise NotImplementedError( + "Block sparsity is not yet supported with SplitKV. TODO: partition sparse block lists per split." + ) cute_aux_tensors = None if aux_tensors is not None: @@ -415,7 +427,6 @@ def _flash_attn_fwd( compute_capability, page_size not in [None, 128], # paged KV non-TMA ) - if compile_key not in _flash_attn_fwd.compile_cache: if compute_capability == 9: assert page_table is None, "paged KV not supported on SM 9.0" @@ -442,8 +453,6 @@ def _flash_attn_fwd( has_aux_tensors=aux_tensors is not None, ) elif compute_capability == 10: - if sparse_tensors is not None: - raise NotImplementedError("BlockSparsity not yet supported on SM 10.0") fa_fwd = FlashAttentionForwardSm100( head_dim, head_dim_v, @@ -452,12 +461,15 @@ def _flash_attn_fwd( is_local=local, is_split_kv=is_split_kv, pack_gqa=pack_gqa, + m_block_size=m_block_size, + n_block_size=n_block_size, is_persistent=not causal and not local and cu_seqlens_q is None and seqused_q is None and not is_split_kv, score_mod=score_mod, + mask_mod=mask_mod, has_aux_tensors=aux_tensors is not None, paged_kv_non_tma=page_size not in [None, 128], ) diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index aa18566cb23..c5e0a7fe2bf 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -298,6 +298,10 @@ def apply_mask_sm100( mask_seqlen: cutlass.Constexpr[bool], mask_causal: cutlass.Constexpr[bool], mask_local: cutlass.Constexpr[bool] = False, + mask_mod: cutlass.Constexpr[Optional[Callable]] = None, + batch_idx: Int32 = None, + head_idx: Int32 = None, + aux_tensors: Optional[list] = None, ) -> None: assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" acc_shape = (self.tile_m, self.tile_n) @@ -311,7 +315,7 @@ def apply_mask_sm100( n_block = 0 seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n r2p = True - if const_expr(not mask_causal and not mask_local): + if const_expr(not mask_causal and not mask_local and mask_mod is None): if const_expr(mask_seqlen): if const_expr(not r2p): for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True): @@ -321,6 +325,36 @@ def apply_mask_sm100( acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= seqlenk_col_limit else acc_S[i] else: mask_r2p(acc_S, seqlenk_col_limit, arch=100, rank1=True) + + elif const_expr(not mask_causal and not mask_local and mask_mod is not None): + # Block sparse case w/ mask_mod + batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32) + head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32) + row_coord_first = tScS_t2r[0][0] + global_row = row_coord_first + m_block * self.tile_m + if const_expr(self.qhead_per_kvhead_packgqa != 1): + mask_row = global_row // self.qhead_per_kvhead_packgqa + else: + mask_row = global_row + mask_row_ssa = utils.scalar_to_ssa(mask_row, cutlass.Int32) + + ncol = const_expr(cute.size(tScS_t2r.shape)) + for i in cutlass.range_constexpr(ncol): + col_coord = tScS_t2r[i][1] if not self.swap_AB else tScS_t2r[i][0] + global_col = col_coord + n_block * self.tile_n + mask_value = mask_mod( + batch_idx_ssa, + head_idx_ssa, + mask_row_ssa, + utils.scalar_to_ssa(global_col, cutlass.Int32), + aux_tensors, + ) + cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value)) + acc_S[i] = acc_S[i] if cond else -Float32.inf + if const_expr(mask_seqlen): + out_of_bounds = (global_row >= self.seqlen_q) or (global_col >= self.seqlen_k) + acc_S[i] = -Float32.inf if out_of_bounds else acc_S[i] + else: # Causal or local causal_row_offset = 1 + self.seqlen_k - n_block * self.tile_n - self.seqlen_q row_idx = tScS_t2r[0][0] + m_block * self.tile_m diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py index 07e63e2bc7f..4c68fad0eba 100644 --- a/tests/cute/test_mask_mod.py +++ b/tests/cute/test_mask_mod.py @@ -28,8 +28,20 @@ random_doc_id_tensor, ) from flash_attn.cute.testing import attention_ref +COMPUTE_CAPABILITY = torch.cuda.get_device_capability()[0] +@pytest.fixture(autouse=True) +def reset_torch_state(): + """Reset torch dynamo/compile state between tests to avoid state pollution.""" + torch._dynamo.reset() + torch.cuda.empty_cache() + + yield + + torch._dynamo.reset() + torch.cuda.empty_cache() + def create_tensors( batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, headdim, headdim_v, dtype ): @@ -142,6 +154,7 @@ def compute_reference_flex_attn(tensors, mask_mod_flex, block_size: Optional[tup (256, 256), (113, 203), (1024, 1024), + (128, 8192) ] @@ -208,6 +221,11 @@ def mask_mod_flex(b, h, q_idx, kv_idx, doc_ids=doc_ids): ) # Compute block sparsity for mask_mod + if COMPUTE_CAPABILITY == 10: + sparse_tile_m = 2 * tile_m + else: + sparse_tile_m = tile_m + bm = create_block_mask( mask_mod_flex, batch_size, @@ -215,7 +233,7 @@ def mask_mod_flex(b, h, q_idx, kv_idx, doc_ids=doc_ids): seqlen_q, seqlen_k, device="cuda", - BLOCK_SIZE=(tile_m, tile_n), + BLOCK_SIZE=(sparse_tile_m, tile_n), ) _, _, mask_cnt, mask_idx, full_cnt, full_idx, *_ = bm.as_tuple() @@ -348,6 +366,9 @@ def test_static_masks( - block_diagonal: Masks by 64-element diagonal blocks - mini_causal: Local causal within 128-element tiles """ + if COMPUTE_CAPABILITY == 10 and (tile_m, tile_n) != (128, 128): + pytest.skip("TODO: Non-128x128 tiles currently not supported on SM 10.0. due to TMEM") + _run_mask_test( seqlen_q=seqlen_q, seqlen_k=seqlen_k, @@ -393,6 +414,9 @@ def test_parameterized_masks( - sliding_window: Requires window size and offset parameters - document: Slower to check """ + if COMPUTE_CAPABILITY == 10 and (tile_m, tile_n) != (128, 128): + pytest.skip("TODO: Non-128x128 tiles currently not supported on SM 10.0. due to TMEM") + _run_mask_test( seqlen_q=seqlen_q, seqlen_k=seqlen_k, @@ -409,5 +433,50 @@ def test_parameterized_masks( ) +def test_sm100_block_sparse_sink_all_masked(): + """Block-sparse regression for the sink path""" + if torch.cuda.get_device_capability()[0] != 10: + pytest.skip("SM100-only test") + device = "cuda" + dtype = torch.bfloat16 + batch_size = 1 + seqlen_q = 256 + seqlen_k = 128 + nheads = 8 + headdim = 128 + q = torch.randn(batch_size, seqlen_q, nheads, headdim, dtype=dtype, device=device) + k = torch.randn(batch_size, seqlen_k, nheads, headdim, dtype=dtype, device=device) + v = torch.randn(batch_size, seqlen_k, nheads, headdim, dtype=dtype, device=device) + learnable_sink = torch.full((nheads,), 0.5, dtype=torch.bfloat16, device=device) + zero_cnt = torch.zeros((batch_size, nheads, 1), dtype=torch.int32, device=device) + zero_idx = torch.zeros((batch_size, nheads, 1, 1), dtype=torch.int32, device=device) + sparse = BlockSparseTensorsTorch( + mask_block_cnt=zero_cnt, + mask_block_idx=zero_idx, + full_block_cnt=zero_cnt, + full_block_idx=zero_idx, + ) + softmax_scale = 1.0 / math.sqrt(headdim) + _, lse = _flash_attn_fwd( + q=q, + k=k, + v=v, + softmax_scale=softmax_scale, + causal=False, + window_size_left=None, + window_size_right=None, + learnable_sink=learnable_sink, + m_block_size=128, + n_block_size=128, + num_threads=384, + pack_gqa=False, + block_sparse_tensors=sparse, + return_lse=True, + ) + # Fully masked tile ⇒ probability mass sits entirely on the sink, so LSE equals sink logit. + expected = learnable_sink.float()[None, :, None].expand_as(lse) + assert torch.allclose(lse, expected, atol=0.0, rtol=0.0) + + if __name__ == "__main__": pytest.main([__file__, "-v", "-s"]) From 43375aab2893018dfb7950db1cfa623c14946ad6 Mon Sep 17 00:00:00 2001 From: jayhshah Date: Tue, 18 Nov 2025 16:10:00 -0800 Subject: [PATCH 238/258] [Cute,Sm100,Fwd] use correction warps for epi when not using TMA (#2014) * use correction warps for epi when varlen (non tma O) * properly enable fallback epilogue for varlen q * fix rebase errors * update tests --- flash_attn/cute/block_sparse_utils.py | 23 +++- flash_attn/cute/flash_fwd_sm100.py | 155 ++++++++++++++++++++------ flash_attn/cute/interface.py | 10 +- tests/cute/test_flash_attn.py | 22 ++-- 4 files changed, 158 insertions(+), 52 deletions(-) diff --git a/flash_attn/cute/block_sparse_utils.py b/flash_attn/cute/block_sparse_utils.py index f117498fd2c..96a5dc2da84 100644 --- a/flash_attn/cute/block_sparse_utils.py +++ b/flash_attn/cute/block_sparse_utils.py @@ -5,7 +5,7 @@ These utilities are used by CUTE DSL kernels to produce and consume block-sparse loads. """ -from typing import Callable +from typing import Callable, Optional from functools import partial import math import cutlass @@ -606,6 +606,9 @@ def handle_block_sparse_empty_tile_correction_sm100( o_corr_consumer_phase: Int32, corr_epi_producer_phase: Int32, softmax_scale_log2: Float32, + mO_cur: Optional[cute.Tensor] = None, + gO: Optional[cute.Tensor] = None, + gmem_tiled_copy_O: Optional[cute.TiledCopy] = None, ): """Handle the block-sparse case where a tile is fully masked: * zero staged results @@ -650,18 +653,26 @@ def handle_block_sparse_empty_tile_correction_sm100( ) cute.arch.mbarrier_arrive(mbar_ptr + mbar_softmax_corr_empty_offset + stage) - cute.arch.mbarrier_wait( - mbar_ptr + mbar_corr_epi_empty_offset + stage, - corr_epi_producer_phase, - ) + if const_expr(gmem_tiled_copy_O is None): + cute.arch.mbarrier_wait( + mbar_ptr + mbar_corr_epi_empty_offset + stage, + corr_epi_producer_phase, + ) correction_epilogue( thr_mma_pv, tOtOs[stage], tidx, + stage, + m_block, + seqlen.seqlen_q, Float32(0.0), # zero scale ensures empty tile writes zeros into staged outputs sO[None, None, stage], + mO_cur, + gO, + gmem_tiled_copy_O, ) - cute.arch.mbarrier_arrive(mbar_ptr + mbar_corr_epi_full_offset + stage) + if const_expr(gmem_tiled_copy_O is None): + cute.arch.mbarrier_arrive(mbar_ptr + mbar_corr_epi_full_offset + stage) cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_O_rescaled_offset + stage) cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_2_offset + stage) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 521e1325a8f..05520fca25d 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -56,8 +56,8 @@ ) -# class NamedBarrierFwd(enum.IntEnum): -# Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads() +class NamedBarrierFwd(enum.IntEnum): + Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads() # WarpSchedulerWG1 = enum.auto() # WarpSchedulerWG2 = enum.auto() # WarpSchedulerWG3 = enum.auto() @@ -85,6 +85,7 @@ def __init__( mask_mod: cutlass.Constexpr | None = None, has_aux_tensors: cutlass.Constexpr = False, paged_kv_non_tma: bool = False, + is_varlen_q: bool = False, ): self.use_tma_KV = not paged_kv_non_tma # self.dtype = dtype @@ -112,6 +113,8 @@ def __init__( self.is_persistent = is_persistent self.is_causal = is_causal self.is_local = is_local + self.is_varlen_q = is_varlen_q + self.use_correction_warps_for_epi = is_varlen_q self.qhead_per_kvhead = qhead_per_kvhead self.is_split_kv = is_split_kv self.pack_gqa = pack_gqa @@ -146,8 +149,8 @@ def __init__( self.softmax1_warp_ids = (4, 5, 6, 7) self.correction_warp_ids = (8, 9, 10, 11) self.mma_warp_id = 12 - self.load_warp_ids = (13,) - self.epilogue_warp_ids = (14,) + self.epilogue_warp_ids = (13,) + self.load_warp_ids = (14,) self.empty_warp_ids = (15,) SM100_TMEM_CAPACITY_COLUMNS = 512 self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS @@ -164,6 +167,15 @@ def __init__( ) ) + if not self.use_tma_KV: + self.load_warp_ids = (14, 15) + self.empty_warp_ids = () + if self.use_correction_warps_for_epi: + self.empty_warp_ids = self.empty_warp_ids + self.epilogue_warp_ids + self.epilogue_warp_ids = self.correction_warp_ids + elif self.is_varlen_q: # fallback + self.epilogue_warp_ids = (13, 14) + self.tmem_s_offset = [0, self.n_block_size] # e.g., 0, 128 self.tmem_o_offset = [ self.tmem_s_offset[-1] + self.n_block_size + i * self.head_dim_v_padded @@ -506,19 +518,11 @@ def __call__( self.cluster_layout_vmnk.shape, ) else: - assert self.use_tma_O, "Loading O and K/V will contend for the empty warp." - self.epilogue_warp_ids = (13,) - self.load_warp_ids = (14, 15) - self.empty_warp_ids = () tma_atom_K = None tma_atom_V = None o_cta_v_layout = cute.composition(cute.make_identity_layout(mO.shape), self.epi_tile) - # print(sO_layout.outer) - if const_expr(not self.use_tma_O): - self.epilogue_warp_ids = (14, 15) - self.empty_warp_ids = () self.num_epilogue_threads = cute.arch.WARP_SIZE * len(self.epilogue_warp_ids) if const_expr(self.use_tma_O): tma_atom_O, mO = cpasync.make_tiled_tma_atom( @@ -546,7 +550,6 @@ def __call__( assert self.m_block_size % tO_layout.shape[0] == 0 vO_layout = cute.make_layout((1, async_copy_elems)) gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout) - print("gmem_tiled_copy_O: ", gmem_tiled_copy_O) if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): TileScheduler = SingleTileVarlenScheduler @@ -799,7 +802,7 @@ def kernel( cute.arch.mbarrier_init( mbar_ptr + self.mbar_s0_s1_sequence_offset + i, cute.arch.WARP_SIZE ) - if warp_idx == 4: + if const_expr(not self.use_correction_warps_for_epi) and warp_idx == 4: for i in cutlass.range_constexpr(self.q_stage): cute.arch.mbarrier_init( mbar_ptr + self.mbar_corr_epi_full_offset + i, @@ -931,6 +934,12 @@ def kernel( if warp_idx == self.empty_warp_ids[0]: cute.arch.warpgroup_reg_dealloc(self.num_regs_empty) + if const_expr(len(self.empty_warp_ids) > 1): + if warp_idx == self.empty_warp_ids[1]: + cute.arch.warpgroup_reg_dealloc(self.num_regs_empty) + + assert len(self.empty_warp_ids) <= 2 + # /////////////////////////////////////////////////////////////////////////////// # LOAD # /////////////////////////////////////////////////////////////////////////////// @@ -1004,19 +1013,20 @@ def kernel( # /////////////////////////////////////////////////////////////////////////////// # Epilogue # /////////////////////////////////////////////////////////////////////////////// - if warp_idx >= self.epilogue_warp_ids[0] and warp_idx <= self.epilogue_warp_ids[-1]: - cute.arch.warpgroup_reg_dealloc(self.num_regs_other) - self.epilogue_s2g( - mO, - sO, - gmem_tiled_copy_O, - tma_atom_O, - mbar_ptr, - block_info, - num_splits, - SeqlenInfoCls, - TileSchedulerCls, - ) + if const_expr(not self.use_correction_warps_for_epi): + if warp_idx >= self.epilogue_warp_ids[0] and warp_idx <= self.epilogue_warp_ids[-1]: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + self.epilogue_s2g( + mO, + sO, + gmem_tiled_copy_O, + tma_atom_O, + mbar_ptr, + block_info, + num_splits, + SeqlenInfoCls, + TileSchedulerCls, + ) # /////////////////////////////////////////////////////////////////////////////// # Softmax @@ -1080,6 +1090,7 @@ def kernel( mLSE, sO, learnable_sink, + gmem_tiled_copy_O, tma_atom_O, mbar_ptr, softmax_scale_log2, @@ -1931,6 +1942,7 @@ def correction_loop( mLSE: cute.Tensor, sO: cute.Tensor, learnable_sink: Optional[cute.Tensor], + gmem_tiled_copy_O: cute.TiledCopy, tma_atom_O: cute.CopyAtom, mbar_ptr: cute.Pointer, softmax_scale_log2: Float32, @@ -1972,6 +1984,12 @@ def correction_loop( seqlen = SeqlenInfoCls(batch_idx) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) + if const_expr(self.is_split_kv): + mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx] + else: + mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx] + gO = cute.local_tile(mO_cur, (self.m_block_size, self.head_dim_v_padded), (None, 0)) + # Default LSE to -inf for invalid split_idx tiles stats = [(0.0, -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None, True)] * self.q_stage @@ -2070,17 +2088,25 @@ def correction_loop( cute.arch.mbarrier_wait( mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase ) - cute.arch.mbarrier_wait( - mbar_ptr + self.mbar_corr_epi_empty_offset + stage, corr_epi_producer_phase - ) + if const_expr(not self.use_correction_warps_for_epi): + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_corr_epi_empty_offset + stage, corr_epi_producer_phase + ) self.correction_epilogue( thr_mma_pv, tOtOs[stage], tidx, + stage, + m_block, + seqlen.seqlen_q, scale, sO[None, None, stage], + mO_cur, + gO, + gmem_tiled_copy_O, ) - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_full_offset + stage) + if const_expr(not self.use_correction_warps_for_epi): + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_full_offset + stage) # Signal for the next work tile that O buffers in tmem are already read, so # mma warp can write to them cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) @@ -2090,6 +2116,11 @@ def correction_loop( softmax_corr_consumer_phase ^= 1 corr_epi_producer_phase ^= 1 else: + # WARNING: we need some code before the const_expr, see https://github.com/NVIDIA/cutlass/issues/2781 + if const_expr(self.use_correction_warps_for_epi): + gmem_tiled_copy_O_for_empty_tile = gmem_tiled_copy_O + else: + gmem_tiled_copy_O_for_empty_tile = None if const_expr(self.use_block_sparsity): ( softmax_corr_consumer_phase, @@ -2126,6 +2157,9 @@ def correction_loop( o_corr_consumer_phase, corr_epi_producer_phase, softmax_scale_log2, + mO_cur, + gO, + gmem_tiled_copy_O_for_empty_tile, ) if const_expr(mLSE is not None): @@ -2228,8 +2262,14 @@ def correction_epilogue( thr_mma: cute.core.ThrMma, tOtO: cute.Tensor, tidx: Int32, + stage: Int32, + m_block: Int32, + seqlen_q: Int32, scale: Float32, sO: cute.Tensor, + mO_cur: Optional[cute.Tensor] = None, + gO: Optional[cute.Tensor] = None, + gmem_tiled_copy_O: Optional[cute.TiledCopy] = None, ): """Apply final scaling and transformation to attention output before writing to global memory. @@ -2302,6 +2342,57 @@ def correction_epilogue( space=cute.arch.SharedSpace.shared_cta, ) + if const_expr(self.use_correction_warps_for_epi): + assert(not self.use_tma_O) + assert(gmem_tiled_copy_O is not None) + cute.arch.barrier(barrier_id=int(NamedBarrierFwd.Epilogue), + number_of_threads=len(self.epilogue_warp_ids) * cute.arch.WARP_SIZE) + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + tOsO = gmem_thr_copy_O.partition_S(sO) + cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded)) + tOgO = gmem_thr_copy_O.partition_D(gO) + tOcO = gmem_thr_copy_O.partition_S(cO) + t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO) + tOpO = utils.predicate_k(tOcO, limit=mO_cur.shape[1]) + # TODO: the packgqa case isn't correct rn (sometimes IMA), disabling it + assert not self.pack_gqa + pack_gqa = PackGQA( + self.m_block_size, + self.head_dim_v_padded, + self.check_hdim_v_oob, + self.qhead_per_kvhead, + ) + + # load acc O from smem to rmem for wider vectorization + tOrO = cute.make_fragment_like(tOsO, self.o_dtype) + cute.autovec_copy(tOsO, tOrO) + # copy acc O from rmem to gmem + if const_expr(not self.pack_gqa): + for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): + if ( + t0OcO[0, rest_m, 0][0] + < seqlen_q + - (self.q_stage * m_block + stage) * self.m_block_size + - tOcO[0][0] + ): + cute.copy( + gmem_tiled_copy_O, + tOrO[None, rest_m, None], + tOgO[None, rest_m, None, self.q_stage * m_block + stage], + pred=tOpO[None, rest_m, None] + if const_expr(self.check_hdim_v_oob) + else None, + ) + else: + pack_gqa.store_O( + mO_cur, + tOrO, + gmem_tiled_copy_O, + tidx, + self.q_stage * m_block + stage, + seqlen_q, + ) + @cute.jit def epilogue_s2g( self, @@ -2389,7 +2480,7 @@ def epilogue_s2g( tOrO[None, rest_m, None], tOgO[None, rest_m, None, self.q_stage * m_block + stage], pred=tOpO[None, rest_m, None] - if self.check_hdim_v_oob + if const_expr(self.check_hdim_v_oob) else None, ) else: diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index db7930de537..28bcb994ee7 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -464,14 +464,16 @@ def _flash_attn_fwd( m_block_size=m_block_size, n_block_size=n_block_size, is_persistent=not causal - and not local - and cu_seqlens_q is None - and seqused_q is None - and not is_split_kv, + and not local + and cu_seqlens_q is None + and seqused_q is None + and not is_split_kv, score_mod=score_mod, mask_mod=mask_mod, has_aux_tensors=aux_tensors is not None, paged_kv_non_tma=page_size not in [None, 128], + is_varlen_q=cu_seqlens_q is not None + or seqused_q is not None, ) else: raise ValueError( diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 14034fa9fd2..4b3398dd479 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -100,8 +100,8 @@ def test_flash_attn_output( mha_type, dtype, ): - if (causal or local) and seqlen_k < seqlen_q: - pytest.skip("Causal attention requires seqlen_k >= seqlen_q") + # if (causal or local) and seqlen_k < seqlen_q: + # pytest.skip("Causal attention requires seqlen_k >= seqlen_q") device = "cuda" # set seed torch.random.manual_seed(0) @@ -228,7 +228,7 @@ def test_flash_attn_output( # pack_gqa_vals = [False, True, None] # SplitKV is not supported for hdim >= 192 pack_gqa_vals = [False] - num_splits_vals = [1] # [1, 3] if d < 192 and not DISABLE_SPLIT else [1] + num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT else [1] for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): out, lse = flash_attn_func( q, @@ -267,6 +267,7 @@ def test_flash_attn_output( and learnable_sink is None # and mha_type == "mha" # and False + and not ((causal or local) and seqlen_k < seqlen_q) ): g = torch.randn_like(out) # do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) @@ -388,7 +389,7 @@ def test_flash_attn_varlen_output( ): if ( causal or local - ): # Right now we only support causal attention with seqlen_k == seqlen_q + ): # Right now reference only supports causal attention with seqlen_k == seqlen_q seqlen_k = seqlen_q device = "cuda" # set seed @@ -572,7 +573,8 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() rtol = 2 if softcap == 0.0 else 3 - pack_gqa_vals = [False, True, None] + # pack_gqa_vals = [False, True, None] + pack_gqa_vals = [False] # num_splits_vals = [1, 3] # SplitKV is not supported for hdim >= 192 num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT else [1] @@ -721,8 +723,8 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): @pytest.mark.parametrize("new_kv", [False]) @pytest.mark.parametrize("local", [False, True]) # @pytest.mark.parametrize("local", [False]) -# @pytest.mark.parametrize("causal", [False, True]) -@pytest.mark.parametrize("causal", [True]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [True]) # @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [False]) # @pytest.mark.parametrize("has_rotary_seqlens", [False, True]) @@ -738,14 +740,14 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): @pytest.mark.parametrize("has_leftpad", [False]) # @pytest.mark.parametrize("has_batch_idx", [False, True]) @pytest.mark.parametrize("has_batch_idx", [False]) -# @pytest.mark.parametrize("varlen_q", [False, True]) -@pytest.mark.parametrize("varlen_q", [False]) +@pytest.mark.parametrize("varlen_q", [False, True]) +# @pytest.mark.parametrize("varlen_q", [False]) # @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize("d", [128]) -@pytest.mark.parametrize("d", [64]) +@pytest.mark.parametrize("d", [64, 128]) # @pytest.mark.parametrize("d", [192]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", From 3fcde4b345e37295c7a76a8d1e3dcb334cdff8c5 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 21 Nov 2025 17:19:08 +0000 Subject: [PATCH 239/258] Raise TypeError if out is specified when compiling _flash_attn_forward --- hopper/flash_attn_interface.py | 19 +++++++++++-------- hopper/setup.py | 2 ++ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index d985eae51a6..44d1f027cb0 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -50,7 +50,7 @@ def _flash_attn_forward( k_new: Optional[torch.Tensor] = None, v_new: Optional[torch.Tensor] = None, qv: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, + out_: Optional[torch.Tensor] = None, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_k: Optional[torch.Tensor] = None, cu_seqlens_k_new: Optional[torch.Tensor] = None, @@ -97,7 +97,7 @@ def _flash_attn_forward( k_new, v_new, qv, - out, + out_, cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new, @@ -144,7 +144,7 @@ def _flash_attn_forward_fake( k_new: Optional[torch.Tensor] = None, v_new: Optional[torch.Tensor] = None, qv: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, + out_: Optional[torch.Tensor] = None, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_k: Optional[torch.Tensor] = None, cu_seqlens_k_new: Optional[torch.Tensor] = None, @@ -205,11 +205,14 @@ def _flash_attn_forward_fake( out_dtype = q_type # Create output tensor - if out is None: - if is_varlen_q: - out = torch.empty((total_q, num_heads, head_size_v), dtype=out_dtype, device=q.device) - else: - out = torch.empty((batch_size, seqlen_q, num_heads, head_size_v), dtype=out_dtype, device=q.device) + if out_ is not None: + # If out_ is provided, _flash_attn_forward becomes non-functional + raise TypeError("Tracing (torch.compile/torch.export) with pre-allocated output tensor is not supported.") + + if is_varlen_q: + out = torch.empty((total_q, num_heads, head_size_v), dtype=out_dtype, device=q.device) + else: + out = torch.empty((batch_size, seqlen_q, num_heads, head_size_v), dtype=out_dtype, device=q.device) # Create softmax_lse tensor if is_varlen_q: diff --git a/hopper/setup.py b/hopper/setup.py index 6ccb126c174..95729edabe2 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -103,6 +103,8 @@ def create_build_config_file(): "FLASHATTENTION_DISABLE_HDIM256": DISABLE_HDIM256, "FLASHATTENTION_DISABLE_SM8x": DISABLE_SM8x, "FLASHATTENTION_ENABLE_VCOLMAJOR": ENABLE_VCOLMAJOR, + "FLASH_ATTENTION_DISABLE_HDIMDIFF64": DISABLE_HDIMDIFF64, + "FLASH_ATTENTION_DISABLE_HDIMDIFF192": DISABLE_HDIMDIFF192, } } From 052015a43fe9419f2ff5e30d6df5160b2b305c63 Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Fri, 21 Nov 2025 12:38:09 -0800 Subject: [PATCH 240/258] add fastdivmod for oob reads in mask_mods (#2020) * add fastdivmod for oob reads in mask_mods * Updates for h100 --- flash_attn/cute/block_sparse_utils.py | 17 +++++++++-- flash_attn/cute/flash_fwd.py | 5 ++- flash_attn/cute/flash_fwd_sm100.py | 2 ++ flash_attn/cute/mask.py | 44 +++++++++++++++++++++------ flash_attn/cute/mask_definitions.py | 18 +++++++++++ tests/cute/test_mask_mod.py | 26 ++++++++++++++++ 6 files changed, 99 insertions(+), 13 deletions(-) diff --git a/flash_attn/cute/block_sparse_utils.py b/flash_attn/cute/block_sparse_utils.py index 96a5dc2da84..e814d6aa458 100644 --- a/flash_attn/cute/block_sparse_utils.py +++ b/flash_attn/cute/block_sparse_utils.py @@ -283,6 +283,7 @@ def consume_block_sparse_loads( score_mod_fn, O_should_accumulate, mask_mod, + fastdiv_mods, intra_wg_overlap: cutlass.Constexpr, warp_scheduler_barrier_sync: Callable, warp_scheduler_barrier_arrive: Callable, @@ -309,7 +310,12 @@ def consume_block_sparse_loads( kv_consumer_state, n_block=mask_n_block, mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), - mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=True), + mask_fn=partial( + mask_fn, + mask_mod=mask_mod, + mask_seqlen=True, + fastdiv_mods=fastdiv_mods if cutlass.const_expr(mask_mod is not None) else None, + ), is_first_n_block=True, ) O_should_accumulate = True @@ -374,7 +380,12 @@ def consume_block_sparse_loads( kv_consumer_state = process_first_half_block( n_block=mask_n_block, kv_consumer_state=kv_consumer_state, - mask_fn=partial(mask_fn, mask_mod=mask_mod), + mask_fn=partial( + mask_fn, + mask_mod=mask_mod, + mask_seqlen=True, + fastdiv_mods=fastdiv_mods if cutlass.const_expr(mask_mod is not None) else None, + ), score_mod_fn=score_mod_fn, is_first_block=True, ) @@ -394,7 +405,7 @@ def consume_block_sparse_loads( kv_consumer_state = process_first_half_block( n_block=full_n_block, kv_consumer_state=kv_consumer_state, - mask_fn=partial(mask_fn, mask_mod=None), + mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True), score_mod_fn=score_mod_fn, is_first_block=True, ) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 369bd1c81e6..0a4ded55d61 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -969,6 +969,7 @@ def preprocess_Q(): thr_mma=thr_mma_qk, mask_causal=self.is_causal, mask_local=self.is_local, + fastdiv_mods=fastdiv_mods if const_expr(self.mask_mod is not None) else None, ) # First iteration with seqlen masking @@ -1991,6 +1992,7 @@ def mma( mask_causal=self.is_causal, mask_local=self.is_local, aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, ) score_mod_fn = None if const_expr(self.score_mod is not None): @@ -2131,11 +2133,12 @@ def mma( score_mod_fn, O_should_accumulate, self.mask_mod, + fastdiv_mods, self.intra_wg_overlap, self.warp_scheduler_barrier_sync, self.warp_scheduler_barrier_arrive, ) - + # Handle empty case (when no blocks to process) if not processed_any: softmax.reset() diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 05520fca25d..625f4b3d14c 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -1628,6 +1628,7 @@ def softmax_loop( mask_fn = partial( mask.apply_mask_sm100, mask_mod=block_mask_mod, + fastdiv_mods=fastdiv_mods, **shared_mask_kwargs, ) if const_expr(self.use_block_sparsity): @@ -1635,6 +1636,7 @@ def softmax_loop( mask_fn_none = partial( mask.apply_mask_sm100, mask_mod=None, + fastdiv_mods=fastdiv_mods, **shared_mask_kwargs, ) else: diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index c5e0a7fe2bf..aa3d1bba099 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -92,6 +92,7 @@ def apply_mask( mask_local: cutlass.Constexpr[bool] = False, mask_mod: cutlass.Constexpr[Optional[Callable]] = None, aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), ) -> None: assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" acc_S_mn = utils.make_acc_tensor_mn_view(acc_S, transpose=self.swap_AB) @@ -131,24 +132,33 @@ def apply_mask( nrow = const_expr(cute.size(tScS_mn.shape[0])) ncol = const_expr(cute.size(tScS_mn.shape[1])) thr_col_offset = tScS_mn[0, 0][1] + has_fastdiv = const_expr( + fastdiv_mods is not None + and fastdiv_mods[0] is not None + and fastdiv_mods[1] is not None + ) + wrap_aux_indices = const_expr( + has_fastdiv and mask_seqlen and const_expr(aux_tensors is not None) + ) for r in cutlass.range_constexpr(nrow): global_row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m + row_for_mod = global_row_idx + if const_expr(wrap_aux_indices): + _, row_for_mod = fastdiv_mods[0].divmod(global_row_idx) for col in cutlass.range_constexpr(ncol): col_idx_local = t0ScS_mn[0, col][1] # Convert to absolute column index global_col_idx = thr_col_offset + col_idx_local + n_block * self.tile_n + col_for_mod = global_col_idx + if const_expr(wrap_aux_indices): + _, col_for_mod = fastdiv_mods[1].divmod(global_col_idx) batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32) head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32) - q_idx_ssa = utils.scalar_to_ssa( - tScS_mn[r, 0][0] + m_block * self.tile_m, cutlass.Int32 - ) - kv_idx_ssa = utils.scalar_to_ssa( - thr_col_offset + t0ScS_mn[0, col][1] + n_block * self.tile_n, - cutlass.Int32, - ) + q_idx_ssa = utils.scalar_to_ssa(row_for_mod, cutlass.Int32) + kv_idx_ssa = utils.scalar_to_ssa(col_for_mod, cutlass.Int32) mask_value = mask_mod( batch_idx_ssa, head_idx_ssa, @@ -302,6 +312,7 @@ def apply_mask_sm100( batch_idx: Int32 = None, head_idx: Int32 = None, aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), ) -> None: assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" acc_shape = (self.tile_m, self.tile_n) @@ -328,6 +339,14 @@ def apply_mask_sm100( elif const_expr(not mask_causal and not mask_local and mask_mod is not None): # Block sparse case w/ mask_mod + has_fastdiv = const_expr( + fastdiv_mods is not None + and fastdiv_mods[0] is not None + and fastdiv_mods[1] is not None + ) + wrap_aux_indices = const_expr( + has_fastdiv and mask_seqlen and const_expr(aux_tensors is not None) + ) batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32) head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32) row_coord_first = tScS_t2r[0][0] @@ -336,17 +355,24 @@ def apply_mask_sm100( mask_row = global_row // self.qhead_per_kvhead_packgqa else: mask_row = global_row - mask_row_ssa = utils.scalar_to_ssa(mask_row, cutlass.Int32) + mask_row_for_mod = mask_row + if const_expr(wrap_aux_indices): + _, mask_row_for_mod = fastdiv_mods[0].divmod(mask_row) + mask_row_ssa = utils.scalar_to_ssa(mask_row_for_mod, cutlass.Int32) ncol = const_expr(cute.size(tScS_t2r.shape)) for i in cutlass.range_constexpr(ncol): col_coord = tScS_t2r[i][1] if not self.swap_AB else tScS_t2r[i][0] global_col = col_coord + n_block * self.tile_n + global_col_for_mod = global_col + if const_expr(wrap_aux_indices): + _, global_col_for_mod = fastdiv_mods[1].divmod(global_col) + kv_idx_ssa = utils.scalar_to_ssa(global_col_for_mod, cutlass.Int32) mask_value = mask_mod( batch_idx_ssa, head_idx_ssa, mask_row_ssa, - utils.scalar_to_ssa(global_col, cutlass.Int32), + kv_idx_ssa, aux_tensors, ) cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value)) diff --git a/flash_attn/cute/mask_definitions.py b/flash_attn/cute/mask_definitions.py index bbf2d212c0c..546adf17f37 100644 --- a/flash_attn/cute/mask_definitions.py +++ b/flash_attn/cute/mask_definitions.py @@ -201,6 +201,23 @@ def flex_dilated_sliding_window_mask(b, h, q_idx, kv_idx): return in_window & dilated +def flex_ima_mask(b, h, q_idx, kv_idx, bias): + return kv_idx >= bias[kv_idx] + + +@cute.jit +def cute_ima_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + aux_tensors, +) -> cute.TensorSSA: + bias = aux_tensors[0] + threshold = utils.scalar_to_ssa(bias[n_idx[0]], cutlass.Int32) + return n_idx >= threshold + + def random_doc_id_tensor(nheads, batch, seqlen_q, device="cpu"): doc_ids_tensor = torch.zeros(batch, nheads, seqlen_q, dtype=torch.int32, device=device) for b in range(batch): @@ -226,6 +243,7 @@ def random_doc_id_tensor(nheads, batch, seqlen_q, device="cpu"): "prefix_lm": (cute_prefix_lm_mask, flex_prefix_lm_mask), "dilated_sliding_window": (cute_dilated_sliding_window_mask, flex_dilated_sliding_window_mask), "document": (cute_document_mask, flex_document_mask), + "ima": (cute_ima_mask, flex_ima_mask), } PARAMETERIZED_MASK_FACTORIES = { diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py index 4c68fad0eba..52c09d03be9 100644 --- a/tests/cute/test_mask_mod.py +++ b/tests/cute/test_mask_mod.py @@ -211,6 +211,15 @@ def mask_mod_flex(b, h, q_idx, kv_idx, doc_ids=doc_ids): return original_flex_mask(b, h, q_idx, kv_idx, doc_ids) aux_tensors_arg = [doc_ids] + elif mask_name == "ima": + bias_threshold = (seqlen_k // 4) * 3 + bias = torch.full((seqlen_k,), bias_threshold, dtype=torch.int32, device="cuda") + original_flex_mask = mask_mod_flex + + def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias): + return original_flex_mask(b, h, q_idx, kv_idx, bias) + + aux_tensors_arg = [bias] causal = False if causal and seqlen_k < seqlen_q: @@ -347,6 +356,23 @@ def mask_mod_flex(b, h, q_idx, kv_idx, doc_ids=doc_ids): ) +def test_mask_mod_ima_partial_block(): + _run_mask_test( + seqlen_q=257, + seqlen_k=257, + nheads=1, + kv_mode="mha", + headdim=128, + dtype=torch.bfloat16, + mask_name="ima", + window_size=None, + window_left=None, + window_right=None, + tile_m=128, + tile_n=128, + ) + + @pytest.mark.parametrize("seqlen_q,seqlen_k", SEQLEN_PAIRS_COMPREHENSIVE) @pytest.mark.parametrize("nheads", [16]) @pytest.mark.parametrize("kv_mode", ["mha", "gqa", "mqa"]) From d063b333baae9c6066fe003be18c426eb602cbf3 Mon Sep 17 00:00:00 2001 From: jayhshah Date: Fri, 21 Nov 2025 18:33:53 -0800 Subject: [PATCH 241/258] don't pass mask_fn to softmax_step generically (#2026) --- flash_attn/cute/flash_fwd_sm100.py | 1 - 1 file changed, 1 deletion(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 625f4b3d14c..6ce6c6d9e98 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -1676,7 +1676,6 @@ def softmax_loop( seqlen=seqlen, aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods, - mask_fn=partial(mask_fn, mask_seqlen=False), ) if has_work: From a986d0190ea33938c8495eb6641758c504e67be6 Mon Sep 17 00:00:00 2001 From: "Anakin(Yancheng) Zheng" <103552181+anakinxc@users.noreply.github.com> Date: Mon, 24 Nov 2025 09:51:17 +0800 Subject: [PATCH 242/258] swap order of decorators (#2029) --- flash_attn/cute/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 51a017e71a1..aa50c89c5bf 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -586,8 +586,8 @@ def cvt_f16(src: cute.Tensor, dst_or_dtype): dst_i32[i] = cvt_f16x2_f32(src[2 * i], src[2 * i + 1], dst.element_type) -@cute.jit @dsl_user_op +@cute.jit def evaluate_polynomial(x: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None) -> Float32: deg = len(poly) - 1 out = poly[deg] @@ -596,8 +596,8 @@ def evaluate_polynomial(x: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=N return out -@cute.jit @dsl_user_op +@cute.jit def evaluate_polynomial_2(x: Float32, y: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None) -> Tuple[Float32, Float32]: deg = len(poly) - 1 out = (poly[deg], poly[deg]) From 20cda05e6bfb4c266319065f6e38181878c9d02e Mon Sep 17 00:00:00 2001 From: jayhshah Date: Mon, 24 Nov 2025 17:33:08 -0800 Subject: [PATCH 243/258] [Cute,Bwd,Sm100] enable deterministic mode for sm100 bwd and fix race conditions (#2033) * enable deterministic mode for sm100 bwd and fix race conditions * turn off lpt scheduler for causal * use more regs for reduce when deterministic * make a src for tiled mma dK toggleable parameter, remove smem async fence for lse release * use 100k iterations for default --- flash_attn/cute/flash_bwd_sm100.py | 148 +++++--- flash_attn/cute/interface.py | 37 ++ flash_attn/cute/tile_scheduler.py | 15 +- flash_attn/cute/utils.py | 8 + tests/cute/test_flash_attn_race_condition.py | 341 +++++++++++++++++++ 5 files changed, 494 insertions(+), 55 deletions(-) create mode 100644 tests/cute/test_flash_attn_race_condition.py diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 0a29ce462a8..fb0e2e9b778 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -91,6 +91,7 @@ def __init__( # Speed optimizations, does not affect correctness self.shuffle_LSE = False self.shuffle_dPsum = False + self.use_smem_dS_for_mma_dK = self.deterministic and self.is_causal self.reduce_warp_ids = (0, 1, 2, 3) self.compute_warp_ids = (4, 5, 6, 7, 8, 9, 10, 11) @@ -146,7 +147,7 @@ def __init__( self.tmem_dK_offset = self.tmem_dP_offset + self.tile_m self.tmem_dS_offset = self.tmem_dP_offset # overlap with dP - if not is_causal and not is_local: + if (not is_causal and not is_local) or deterministic: self.num_regs_reduce = 152 self.num_regs_compute = 136 else: @@ -203,6 +204,10 @@ def _get_tiled_mma(self): a_source=tcgen05.OperandSource.TMEM, ) # dK += dS.T @ Q + if const_expr(self.use_smem_dS_for_mma_dK): + mma_dK_a_src = tcgen05.OperandSource.SMEM + else: + mma_dK_a_src = tcgen05.OperandSource.TMEM tiled_mma_dK = sm100_utils_basic.make_trivial_tiled_mma( self.do_dtype, tcgen05.OperandMajorMode.K, # dS_major_mode @@ -210,7 +215,7 @@ def _get_tiled_mma(self): self.acc_dtype, cta_group, self.mma_tiler_dsq[:2], - a_source=tcgen05.OperandSource.TMEM, + a_source=mma_dK_a_src, ) # dQ = dS @ K tiled_mma_dQ = sm100_utils_basic.make_trivial_tiled_mma( @@ -403,13 +408,13 @@ def __call__( semaphore_transpose = [2, 3, 1, 0] # (b, n, block, stage) -> (block, stage, n, b) if const_expr(self.deterministic): assert mdQ_semaphore is not None - mdQ_semaphore = utils.select(mdQ_semaphore.layout, mode=semaphore_transpose) + mdQ_semaphore = utils.select(mdQ_semaphore, mode=semaphore_transpose) if const_expr(self.deterministic and self.qhead_per_kvhead > 1): assert mdK_semaphore is not None assert mdV_semaphore is not None mdK_semaphore, mdV_semaphore = [ - utils.select(t.layout, mode=semaphore_transpose) + utils.select(t, mode=semaphore_transpose) for t in (mdK_semaphore, mdV_semaphore) ] else: @@ -546,15 +551,18 @@ def __call__( self.tma_copy_bytes["dQ"] = self.tile_m * self.dQ_reduce_ncol * Float32.width // 8 self.tma_copy_bytes["dKacc"] = self.tile_n * self.dK_reduce_ncol * Float32.width // 8 - # TileScheduler = SingleTileScheduler if not self.is_causal else SingleTileLPTBwdScheduler - TileScheduler = SingleTileScheduler - # TODO -- optimizer scheduler for causal + # TileScheduler = SingleTileScheduler + if const_expr(self.deterministic): + TileScheduler = SingleTileLPTBwdScheduler + else: + TileScheduler = SingleTileScheduler + self.spt = self.is_causal and self.deterministic tile_sched_args = TileSchedulerArguments( cute.ceil_div(cute.size(mK.shape[0]), self.cta_tiler[0]), cute.size(mQ.shape[2]), # num_heads = num_query_heads cute.size(mK.shape[3]), 1, # num_splits - cute.size(mK.shape[0]), + cute.size(mQ.shape[0]), # pass seqlen_q for seqlen_k mQ.shape[1], mV.shape[1], total_q=cute.size(mQ.shape[0]), @@ -565,7 +573,7 @@ def __call__( qhead_per_kvhead_packgqa=1, element_size=self.k_dtype.width // 8, is_persistent=self.is_persistent, - lpt=False, + lpt=self.spt, ) tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) @@ -1364,8 +1372,10 @@ def mma( tdPrV = tiled_mma_dP.make_fragment_A(sV) tdPrdOt = tiled_mma_dP.make_fragment_B(sdOt) # dK = dS.T @ Q - # tdKrdS = tiled_mma_dK.make_fragment_A(sdSt) - tdKrdS = tiled_mma_dK.make_fragment_A(tdS) + if const_expr(self.use_smem_dS_for_mma_dK): + tdKrdS = tiled_mma_dK.make_fragment_A(sdSt) + else: + tdKrdS = tiled_mma_dK.make_fragment_A(tdS) tdKrQ = tiled_mma_dK.make_fragment_B(sQt) # dQ = dS @ K tdQrdS = tiled_mma_dQ.make_fragment_A(sdS) @@ -1404,18 +1414,20 @@ def mma( # mma_dsk_fn = partial( # gemm_ptx_w_idx, tiled_mma_dQ, tdQtdQ, tdQrdS, tdQrK, sA=sdS, sB=sKt, zero_init=True # ) - # mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, tdKtdK, tdKrdS, tdKrQ) - # Need to explicitly pass in tA_addr for correctness - mma_dsq_fn = partial( - gemm_ptx_w_idx, - tiled_mma_dK, - tdKtdK, - tdKrdS, - tdKrQ, - sA=None, - sB=sQt, - tA_addr=self.tmem_dS_offset, - ) + if const_expr(self.use_smem_dS_for_mma_dK): + mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, tdKtdK, tdKrdS, tdKrQ) + else: + # Need to explicitly pass in tA_addr for correctness + mma_dsq_fn = partial( + gemm_ptx_w_idx, + tiled_mma_dK, + tdKtdK, + tdKrdS, + tdKrQ, + sA=None, + sB=sQt, + tA_addr=self.tmem_dS_offset, + ) consumer_state_dO = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage @@ -1486,18 +1498,29 @@ def mma( mma_qk_fn(B_idx=handle_Q_next.index) pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) - # 2) dK = dS.T @ Q + # 2-3) + # Do dK = dS.T @ Q, then dQ = dS @ K if dS in tmem for first mma + # Otherwise, reverse order pipeline_dS.consumer_wait(consumer_state_dS) - mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK) - accumulate_dK = True - handle_Q.release() - # 3) dQ = dS @ K + if const_expr(self.use_smem_dS_for_mma_dK): + mma_dsk_fn() + pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group) + mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK) + accumulate_dK = True + handle_Q.release() + else: + mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK) + accumulate_dK = True + handle_Q.release() + mma_dsk_fn() + pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group) + # dP uses the same tmem as dQ - # However, if dS is ready, then dP must have been ready, so we don't need to wait + # However, if dS is ready, then dP must have been ready, + # so we don't need this wait before mma_dsk_fn() # pipeline_dP.sync_object_empty.wait(0, producer_phase_acc) - mma_dsk_fn() - pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group) + pipeline_dS.consumer_release(consumer_state_dS) consumer_state_dS.advance() @@ -1823,8 +1846,8 @@ def compute_loop( ) cute.arch.fence_view_async_tmem_store() + self.compute_sync_barrier.arrive_and_wait() - cute.arch.sync_warp() with cute.arch.elect_one(): pipeline_S_P.consumer_release(consumer_state_S_P_dP) # pipeline_S_P.sync_object_empty.arrive(0, pipeline_S_P.consumer_mask) @@ -1847,6 +1870,7 @@ def compute_loop( tdPrdP_t2r = cute.make_fragment(tScS_t2r[None, 0, None, None].shape, Float32) cute.copy(thr_copy_t2r, tdPtdP_t2r[None, stage, None, None], tdPrdP_t2r) cute.arch.fence_view_async_tmem_load() + self.compute_sync_barrier.arrive_and_wait() tdPrdP_cur = tdPrdP_t2r[None, 0, 0] tSrS_cur = tSrS_t2r[None, stage, 0, 0] tSsdPsum_cur = tSsdPsum[None, stage, 0, 0, consumer_state_dPsum.index] @@ -1875,22 +1899,20 @@ def compute_loop( if const_expr(stage == 0): pipeline_dS.producer_acquire(producer_state_dS) cute.autovec_copy(tdPrdS_cvt, tRS_sdS[None, stage]) - tdPrdS_r2t_f32 = cute.recast_tensor(tdPrdS_cvt, Float32) - cute.copy(thr_copy_r2t, tdPrdS_r2t_f32, tdPtdS_r2t[None, stage, 0, 0]) + if const_expr(not self.use_smem_dS_for_mma_dK): + tdPrdS_r2t_f32 = cute.recast_tensor(tdPrdS_cvt, Float32) + cute.copy(thr_copy_r2t, tdPrdS_r2t_f32, tdPtdS_r2t[None, stage, 0, 0]) - cute.arch.fence_view_async_tmem_store() + if const_expr(not self.use_smem_dS_for_mma_dK): + cute.arch.fence_view_async_tmem_store() + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + self.compute_sync_barrier.arrive_and_wait() - cute.arch.sync_warp() # with cute.arch.elect_one(): # The mma warp no longer waits for dP (it waits for dS), so we don't have to arrive # pipeline_dP.sync_object_empty.arrive(0, pipeline_dP.consumer_mask) pipeline_dPsum.consumer_release(consumer_state_dPsum) consumer_state_dPsum.advance() - - cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta - ) - cute.arch.sync_warp() with cute.arch.elect_one(): pipeline_dS.producer_commit(producer_state_dS) producer_state_dS.advance() @@ -2010,10 +2032,13 @@ def dQacc_reduce( gdQaccum = cute.flat_divide( gdQaccum_, (self.tile_m * self.tile_hdim // self.dQaccum_reduce_stage,) ) - mdQ_semaphore_cur = None + if const_expr(self.deterministic): mdQ_semaphore_cur = mdQ_semaphore[None, None, head_idx, batch_idx] + delay_semaphore_release = self.is_causal + n_block_global_max = cute.ceil_div(seqlen.seqlen_k, self.tile_n) + for m_block in cutlass.range(m_block_min, m_block_max, unroll=1): pipeline_dQ.consumer_wait(dQ_consumer_state) # TMEM -> RMEM @@ -2025,11 +2050,6 @@ def dQacc_reduce( pipeline_dQ.consumer_release(dQ_consumer_state) dQ_consumer_state.advance() - # semaphore acquire - if const_expr(self.deterministic): - barrier.wait_eq(mdQ_semaphore_cur[m_block, None].iterator, tidx, 0, n_block) - self.reduce_sync_barrier.arrive_and_wait() - gdQaccum_cur = gdQaccum[None, None, m_block] for stage in cutlass.range_constexpr(cute.size(tdQrdQ_t2r, mode=[1])): # 4 @@ -2043,6 +2063,17 @@ def dQacc_reduce( cute.arch.fence_proxy( cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta ) + # semaphore acquire + if const_expr(self.deterministic and stage == 0): + if const_expr(self.spt): + n_block_max_for_m_block = min( + n_block_global_max, + cute.ceil_div((m_block + 1) * self.tile_m + seqlen.seqlen_k - seqlen.seqlen_q, self.tile_n) + ) + lock_value = n_block_max_for_m_block - 1 - n_block + else: + lock_value = n_block + barrier.wait_eq(mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, lock_value) self.reduce_sync_barrier.arrive_and_wait() # Copy from shared memory to global memory if is_tma_warp: @@ -2067,17 +2098,25 @@ def dQacc_reduce( # tdQrdQ_r2s[4 * i + 3], # utils.elem_pointer(tdQgdQ, 4 * i), # ) + # semaphore release for prior m_block + if const_expr(self.deterministic and stage == 0 and delay_semaphore_release): + if m_block > m_block_min: + barrier.arrive_inc(mdQ_semaphore_cur[(m_block - 1, None)].iterator, tidx, 0, 1) # semaphore release # NOTE: arrive_inc calls red_release which issues membar - if const_expr(self.deterministic): - if tidx == 0: + if const_expr(self.deterministic and not delay_semaphore_release): + if is_tma_warp: cute.arch.cp_async_bulk_wait_group(0, read=read_flag) self.reduce_sync_barrier.arrive_and_wait() barrier.arrive_inc(mdQ_semaphore_cur[m_block, None].iterator, tidx, 0, 1) - if warp_idx == 0: + if is_tma_warp: cute.arch.cp_async_bulk_wait_group(0, read=read_flag) + self.reduce_sync_barrier.arrive_and_wait() + # final semaphore release + if const_expr(self.deterministic and delay_semaphore_release): + barrier.arrive_inc(mdQ_semaphore_cur[(m_block_max - 1, None)].iterator, tidx, 0, 1) tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() @@ -2274,7 +2313,8 @@ def epilogue_dK_or_dV_tma( gdKV, (self.sdKV_flat_epi_tile,) ) # (tile_n * hdim / 2 / epi_stage, epi_stage) - if const_expr(self.deterministic and self.qhead_per_kvhead > 1): + deterministic_KV = self.deterministic and self.qhead_per_kvhead > 1 + if const_expr(deterministic_KV): mdKV_semaphore_cur = mdKV_semaphore[n_block, None, head_idx_kv, batch_idx] if const_expr(self.qhead_per_kvhead == 1): @@ -2296,12 +2336,12 @@ def epilogue_dK_or_dV_tma( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32 ) - read_flag = const_expr(not self.deterministic) + read_flag = const_expr(not deterministic_KV) pipeline_dKV.consumer_wait(consumer_state_dKV) # semaphore acquire - if const_expr(self.deterministic): + if const_expr(deterministic_KV): barrier.wait_eq( mdKV_semaphore_cur.iterator, tidx, wg_idx, head_idx % self.qhead_per_kvhead ) @@ -2377,7 +2417,7 @@ def epilogue_dK_or_dV_tma( # semaphore release # NOTE: arrive_inc calls red_release which issues membar - if const_expr(self.deterministic): + if const_expr(deterministic_KV): if leader_warp: cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=read_flag) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 28bcb994ee7..1e94453252e 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -561,6 +561,7 @@ def _flash_attn_bwd( cu_seqlens_k: Optional[torch.Tensor] = None, seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, + deterministic: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: compute_capability = torch.cuda.get_device_capability()[0] assert compute_capability in [9, 10], "Unsupported compute capability. Supported: 9.x, 10.x" @@ -659,6 +660,8 @@ def _flash_attn_bwd( pack_gqa = qhead_per_kvhead > 1 if compute_capability == 10: pack_gqa = False # override for now + if compute_capability != 10: + assert deterministic is False, "bwd deterministic only supported for sm100 for now" device = q.device # TODO: check if this is the right rounding @@ -757,6 +760,22 @@ def _flash_attn_bwd( else None for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) ] + if deterministic: + dQ_semaphore = torch.zeros(batch_size, num_head, seqlen_q_rounded // m_block_size, 1, dtype=torch.int32, device="cuda") + else: + dQ_semaphore = None + + if deterministic and qhead_per_kvhead > 1: + dK_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device="cuda") + dV_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device="cuda") + else: + dK_semaphore = None + dV_semaphore = None + dQ_semaphore_tensor, dK_semaphore_tensor, dV_semaphore_tensor = [ + utils.convert_from_dlpack_leading_static(t.detach(), leading_dim=3, alignment=4, stride_order=t.dim_order()) + if t is not None else None + for t in (dQ_semaphore, dK_semaphore, dV_semaphore) + ] current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) # Preprocess kernel: compute (o * dout).sum(dim=-1), lse * log2_e, and zero out dq_accum. @@ -831,6 +850,7 @@ def _flash_attn_bwd( num_threads, pack_gqa, cluster_size, + deterministic, ) num_threads = 384 if compile_key not in _flash_attn_bwd.compile_cache: @@ -885,6 +905,7 @@ def _flash_attn_bwd( # tile_n=n_block_size, cluster_size=cluster_size, # cluster_size=1, + deterministic=deterministic, ) # TODO: check @can_implement _flash_attn_bwd.compile_cache[compile_key] = cute.compile( @@ -904,6 +925,9 @@ def _flash_attn_bwd( cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, + mdQ_semaphore=dQ_semaphore_tensor, + mdK_semaphore=dK_semaphore_tensor, + mdV_semaphore=dV_semaphore_tensor, ) _flash_attn_bwd.compile_cache[compile_key]( q_tensor, @@ -921,6 +945,9 @@ def _flash_attn_bwd( cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, + mdQ_semaphore=dQ_semaphore_tensor, + mdK_semaphore=dK_semaphore_tensor, + mdV_semaphore=dV_semaphore_tensor, ) num_threads = 256 if compute_capability == 9 else 128 @@ -1028,6 +1055,7 @@ def forward( softcap: float = 0.0, num_splits: int = 1, pack_gqa: Optional[bool] = None, + deterministic: bool = False, mask_mod: Optional[Callable] = None, full_block_cnt: Optional[torch.Tensor] = None, full_block_idx: Optional[torch.Tensor] = None, @@ -1063,6 +1091,7 @@ def forward( ctx.causal = causal ctx.window_size = window_size ctx.softcap = softcap + ctx.deterministic = deterministic return out, lse @staticmethod @@ -1078,6 +1107,7 @@ def backward(ctx, dout, *args): ctx.softmax_scale, ctx.causal, ctx.softcap, + deterministic=ctx.deterministic, ) return dq, dk, dv, *((None,) * 20) # Extra Nones is fine @@ -1101,6 +1131,7 @@ def forward( softcap: float = 0.0, num_splits: int = 1, pack_gqa: Optional[bool] = None, + deterministic: bool = False, ): out, lse = _flash_attn_fwd( q, @@ -1125,6 +1156,7 @@ def forward( ctx.causal = causal ctx.window_size = window_size ctx.softcap = softcap + ctx.deterministic = deterministic return out, lse @staticmethod @@ -1146,6 +1178,7 @@ def backward(ctx, dout, *args): cu_seqlens_k=cu_seqlens_k, seqused_q=seqused_q, seqused_k=seqused_k, + deterministic=ctx.deterministic, ) return dq, dk, dv, *((None,) * 20) @@ -1162,6 +1195,7 @@ def flash_attn_func( softcap: float = 0.0, num_splits: int = 1, pack_gqa: Optional[bool] = None, + deterministic: bool = False, mask_mod: Optional[Callable] = None, full_block_cnt: Optional[torch.Tensor] = None, full_block_idx: Optional[torch.Tensor] = None, @@ -1179,6 +1213,7 @@ def flash_attn_func( softcap, num_splits, pack_gqa, + deterministic, mask_mod, full_block_cnt, full_block_idx, @@ -1203,6 +1238,7 @@ def flash_attn_varlen_func( softcap: float = 0.0, num_splits: int = 1, pack_gqa: Optional[bool] = None, + deterministic: bool = False, ): return FlashAttnVarlenFunc.apply( q, @@ -1220,6 +1256,7 @@ def flash_attn_varlen_func( softcap, num_splits, pack_gqa, + deterministic, ) diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index f3a06c186e7..ad6ab099b0a 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -374,19 +374,28 @@ class SingleTileLPTBwdScheduler: @dataclass class Params(ParamsBase): total_blocks: Int32 + num_block: Int32 num_head_divmod: FastDivmod l2_minor_divmod: FastDivmod l2_major_divmod: FastDivmod l2_minor_residual_divmod: FastDivmod num_hb_quotient: Int32 cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) + spt: cutlass.Constexpr[bool] = True @staticmethod @cute.jit def create( args: TileSchedulerArguments, *, loc=None, ip=None ) -> "SingleTileLPTBwdScheduler.Params": - swizzle = 8 + size_l2 = 50 * 1024 * 1024 + size_one_qdo_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size + # size_one_dqaccum_head = args.seqlen_k * (args.headdim) * 4 + size_one_dqaccum_head = 0 + size_one_head = size_one_qdo_head + size_one_dqaccum_head + log2_floor = lambda n: 31 - clz(n) + swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head)) + # swizzle = 8 # If we're in the last section (called residual), we don't want to divide by # swizzle. Instead we want to divide by the remainder. num_hb_quotient = (args.num_head * args.num_batch) // swizzle @@ -396,6 +405,7 @@ def create( total_blocks=(num_block * args.cluster_shape_mn[0]) * args.num_head * args.num_batch, + num_block=num_block, num_head_divmod=FastDivmod.create(args.num_head), l2_minor_divmod=FastDivmod.create(swizzle), l2_major_divmod=FastDivmod.create(swizzle * num_block), @@ -404,6 +414,7 @@ def create( ), # don't divide by 0 num_hb_quotient=Int32(num_hb_quotient), cluster_shape_mn=args.cluster_shape_mn, + spt=args.lpt, ) def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None): @@ -450,6 +461,8 @@ def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: is_valid = self._tile_idx < params.total_blocks bidx_in_cluster = cute.arch.block_in_cluster_idx() block = block * params.cluster_shape_mn[0] + bidx_in_cluster[0] + if cutlass.const_expr(params.spt): + block = params.num_block - 1 - block return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid) def initial_work_tile_info(self, *, loc=None, ip=None): diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index aa50c89c5bf..eb8b86cbe0b 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -71,6 +71,14 @@ def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Te ) ) +def convert_from_dlpack_leading_static(x, leading_dim, alignment=16, static_modes=None, stride_order=None) -> cute.Tensor: + if stride_order is None: + stride_order = x.dim_order() + x_ = from_dlpack(x, assumed_align=alignment) + for i in range(x.ndim): + if i != leading_dim and (static_modes is None or i not in static_modes): + x_ = x_.mark_compact_shape_dynamic(mode=i, stride_order=stride_order) + return x_ def make_tiled_copy_A( copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False diff --git a/tests/cute/test_flash_attn_race_condition.py b/tests/cute/test_flash_attn_race_condition.py new file mode 100644 index 00000000000..5cedc49d3c4 --- /dev/null +++ b/tests/cute/test_flash_attn_race_condition.py @@ -0,0 +1,341 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + +import math +import itertools +import os + +import pytest +import torch + +from einops import rearrange, repeat + +try: + from flash_attn.layers.rotary import apply_rotary_emb +except ImportError: + apply_rotary_emb = None + +from flash_attn.cute.testing import ( + attention_ref, + generate_qkv, + generate_random_padding_mask, + pad_input, + unpad_input, +) +from flash_attn.cute.interface import ( + flash_attn_func, + flash_attn_varlen_func, + flash_attn_combine, + _flash_attn_bwd, +) + + +DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE" + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +@pytest.mark.parametrize("mha_type", ["gqa"]) +# @pytest.mark.parametrize("has_learnable_sink", [False, True]) +@pytest.mark.parametrize("has_learnable_sink", [False]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) +# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("deterministic", [True]) +# @pytest.mark.parametrize("softcap", [0.0, 15.0]) +@pytest.mark.parametrize("softcap", [0.0]) +# @pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [64, 128, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [64, 96, 128, 192]) +# @pytest.mark.parametrize("d", [64, 128]) +# @pytest.mark.parametrize("d", [128, 192]) +@pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (4224, 4224), + (2048, 4096), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) +def test_flash_attn_output( + seqlen_q, + seqlen_k, + d, + causal, + local, + softcap, + deterministic, + has_qv, + has_learnable_sink, + mha_type, + dtype, +): + if (causal or local) and seqlen_k < seqlen_q: + pytest.skip("Causal attention requires seqlen_k >= seqlen_q") + device = "cuda" + # set seed + torch.random.manual_seed(0) + torch.cuda.empty_cache() + torch.cuda.synchronize() + batch_size = 9 if seqlen_k <= 2048 else 2 + # batch_size = 1 + nheads = 6 + # nheads = 1 + nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + dv_vals = [d] + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] + attention_chunk_vals = [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + q_ref = torch.randn( + batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref + ) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. + q_ref = q_ref * softcap / 4 + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + v_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + if has_qv: + qv_ref = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = ( + (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + ) + # window_size = (-1, -1) if not local else (16, 0) + if has_learnable_sink: + learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) + else: + learnable_sink = None + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [ + torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) + * 2 + for _ in range(3) + ] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + + # k_extended = repeat(k_ref, "b s h d -> b s (h k) d", k=nheads // nheads_kv) + # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_extended).float() + # # if qv is not None: + # # qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float() + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # exp_sum = s_tmp.sum(-1) + # # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float()) + # # lse_ref = torch.logsumexp(qk, dim=-1) + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # num_splits_vals = [1, 3] + # pack_gqa_vals = [False, True, None] + # SplitKV is not supported for hdim >= 192 + pack_gqa_vals = [False] + # num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT else [1] + num_splits_vals = [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + out, lse = flash_attn_func( + q, + k, + v, + causal=causal, + # qv=qv, + # q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + # attention_chunk=attention_chunk, + softcap=softcap, + learnable_sink=learnable_sink, + # pack_gqa=pack_gqa, + num_splits=num_splits, + deterministic=deterministic, + ) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * ( + out_pt - out_ref + ).abs().max().item() + fwd_atol + + if ( + dtype != torch.float8_e4m3fn + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + and softcap == 0.0 + and not local + and dv == d + and learnable_sink is None + # and mha_type == "mha" + # and False + and not ((causal or local) and seqlen_k < seqlen_q) + ): + g = torch.randn_like(out) + # do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) + dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + # breakpoint() + + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad( + out_ref, (q_ref, k_ref, v_ref), g + ) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + # breakpoint() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dq - dq_ref).abs().max().item() <= rtol * ( + dq_pt - dq_ref + ).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dk - dk_ref).abs().max().item() <= rtol * ( + dk_pt - dk_ref + ).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dv - dv_ref).abs().max().item() <= rtol * ( + dv_pt - dv_ref + ).abs().max().item() + dv_atol + + num_iters = 100_000 + for i in range(num_iters): + dq2, dk2, dv2, = _flash_attn_bwd( + q, k, v, out, g, lse, + causal=causal, + deterministic=True, + ) + + diff_dq = (dq - dq2).abs() + max_idx = diff_dq.argmax() + print(f"dQ max diff: {diff_dq.max().item()}") + print(f" at index {max_idx.item()}: dQ={dq.flatten()[max_idx].item()}, dQ2={dq2.flatten()[max_idx].item()}") + + diff_dk = (dk - dk2).abs() + max_idx = diff_dk.argmax() + print(f"dK max diff: {diff_dk.max().item()}") + print(f" at index {max_idx.item()}: dK={dk.flatten()[max_idx].item()}, dK2={dk2.flatten()[max_idx].item()}") + + diff_dv = (dv - dv2).abs() + max_idx = diff_dv.argmax() + print(f"dV max diff: {diff_dv.max().item()}") + print(f" at index {max_idx.item()}: dV={dv.flatten()[max_idx].item()}, dV2={dv2.flatten()[max_idx].item()}") + + # print(f"dQ max diff with myself: {(dq - dq2).abs().max().item()}") + # print(f"dK max diff with myself: {(dk - dk2).abs().max().item()}") + # print(f"dV max diff with myself: {(dv - dv2).abs().max().item()}") + # print(f"dQ mean diff with myself: {(dq - dq2).abs().mean().item()}") + # print(f"dK mean diff with myself: {(dk - dk2).abs().mean().item()}") + # print(f"dV mean diff with myself: {(dv - dv2).abs().mean().item()}") + + assert torch.equal(dq, dq2) + assert torch.equal(dk, dk2) + assert torch.equal(dv, dv2) + + print(f"✅ Iteration {i} passed!") + From 91942973d56c2cdcdbbc32fe7ecad6a274a0abde Mon Sep 17 00:00:00 2001 From: Jean-Luc Duprat Date: Mon, 24 Nov 2025 20:41:20 -0800 Subject: [PATCH 244/258] [NFC] Trivial fix to silence linter (#1928) Not much to see here, but this causes linter noise --- csrc/flash_attn/flash_api.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index a7b5d36835d..c0c0e42176c 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -1340,7 +1340,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he /*cu_seqlens_q_d=*/nullptr, /*cu_seqlens_k_d=*/nullptr, /*seqused_k=*/nullptr, - /*p_ptr=*/nullptr, + /*p_d=*/nullptr, softmax_lse.data_ptr(), /*p_dropout=*/0.f, softmax_scale, From 5cc6fa48f93a1562d46c3abfd90192cd32c11775 Mon Sep 17 00:00:00 2001 From: Jean-Luc Duprat Date: Mon, 24 Nov 2025 20:42:02 -0800 Subject: [PATCH 245/258] Add LICENSE and AUTHORS to flash_attn/cute (#2032) --- flash_attn/cute/AUTHORS | 1 + flash_attn/cute/LICENSE | 29 +++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+) create mode 100644 flash_attn/cute/AUTHORS create mode 100644 flash_attn/cute/LICENSE diff --git a/flash_attn/cute/AUTHORS b/flash_attn/cute/AUTHORS new file mode 100644 index 00000000000..e35a781665e --- /dev/null +++ b/flash_attn/cute/AUTHORS @@ -0,0 +1 @@ +Tri Dao, trid@cs.stanford.edu \ No newline at end of file diff --git a/flash_attn/cute/LICENSE b/flash_attn/cute/LICENSE new file mode 100644 index 00000000000..5860e4b33f3 --- /dev/null +++ b/flash_attn/cute/LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. From 63b66f2cd988213d6a18c322a274c0045f1cf29c Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 24 Nov 2025 23:45:34 -0500 Subject: [PATCH 246/258] [Cute] Add authors --- flash_attn/cute/AUTHORS | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/flash_attn/cute/AUTHORS b/flash_attn/cute/AUTHORS index e35a781665e..bc3991c676d 100644 --- a/flash_attn/cute/AUTHORS +++ b/flash_attn/cute/AUTHORS @@ -1 +1,5 @@ -Tri Dao, trid@cs.stanford.edu \ No newline at end of file +Tri Dao, tri@tridao.me +Jay Shah +Ted Zadouri +Markus Hoehnerbach +Vijay Thakkar \ No newline at end of file From 92ca9da8d66f7b34ff50dc080ec0fef9661260d6 Mon Sep 17 00:00:00 2001 From: Reuben Stern <107093092+reubenconducts@users.noreply.github.com> Date: Tue, 25 Nov 2025 00:43:48 -0500 Subject: [PATCH 247/258] [Cute,Fwd] enable mask mod without blocksparsity (#2031) --- flash_attn/cute/flash_fwd.py | 11 ++++++----- flash_attn/cute/flash_fwd_sm100.py | 18 ++++++++++++------ flash_attn/cute/interface.py | 4 ---- tests/cute/test_mask_mod.py | 13 ++++++++++--- 4 files changed, 28 insertions(+), 18 deletions(-) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 0a4ded55d61..e341ac4feee 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -2047,7 +2047,7 @@ def mma( kv_consumer_state = process_first_half_block( n_block=n_block_max - 1, kv_consumer_state=kv_consumer_state, - mask_fn=mask_fn, + mask_fn=partial(mask_fn, mask_mod=self.mask_mod), score_mod_fn=score_mod_fn, is_first_block=True, ) @@ -2060,7 +2060,7 @@ def mma( n_block=n_block_max - 1, mma_pv_fn=partial(mma_pv_fn, zero_init=True), is_first_n_block=True, - mask_fn=partial(mask_fn, mask_seqlen=True), + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True), ) O_should_accumulate = True # if cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block_max = {}, n_block_min = {}", m_block, n_block_max, n_block_min) @@ -2078,7 +2078,7 @@ def mma( kv_consumer_state, n_block=n_block_max - 1 - n_tile, mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), - mask_fn=partial(mask_fn, mask_seqlen=False), + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), ) O_should_accumulate = True n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) @@ -2092,6 +2092,7 @@ def mma( kv_consumer_state, n_block=n_block_max - 1 - n_tile, mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), ) O_should_accumulate = True # Separate iterations with local masking on the left @@ -2102,7 +2103,7 @@ def mma( kv_consumer_state, n_block=n_block_max - 1 - n_tile, mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), - mask_fn=partial(mask_fn, mask_seqlen=False), + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), ) O_should_accumulate = True # Last "half" iteration @@ -2435,4 +2436,4 @@ def warp_scheduler_barrier_arrive(self): cute.arch.barrier_arrive( barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg, number_of_threads=2 * self.num_threads_per_warp_group, - ) + ) \ No newline at end of file diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 6ce6c6d9e98..2234d69ca99 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -1624,10 +1624,10 @@ def softmax_loop( head_idx=head_idx, aux_tensors=aux_tensors, ) - block_mask_mod = self.mask_mod if const_expr(self.use_block_sparsity) else None + mask_mod = self.mask_mod if const_expr(self.mask_mod is not None) else None mask_fn = partial( mask.apply_mask_sm100, - mask_mod=block_mask_mod, + mask_mod=mask_mod, fastdiv_mods=fastdiv_mods, **shared_mask_kwargs, ) @@ -1749,15 +1749,21 @@ def softmax_loop( ) ) n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) - # The remaining iterations have no masking + # The remaining iterations have no masking (but may still need mask_mod) n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( seqlen, m_block, n_block_min ) for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): n_block = n_block_max - n_tile - 1 - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step( - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block - ) + if const_expr(self.mask_mod is not None): + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step( + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) + else: + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step( + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, + ) # Separate iterations with local masking on the left if const_expr(self.is_local and block_info.window_size_left is not None): n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 1e94453252e..4c3e52f46d5 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -369,10 +369,6 @@ def _flash_attn_fwd( ) if mask_mod is not None: - if not use_block_sparsity: - raise NotImplementedError( - "mask_mod requires the use of block sparsity. This will be fixed in a future PR." - ) if is_varlen: raise NotImplementedError( "mask_mod with aux_tensors is not yet supported for varlen sequences. This will be fixed in a future PR." diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py index 52c09d03be9..9c2db48f22b 100644 --- a/tests/cute/test_mask_mod.py +++ b/tests/cute/test_mask_mod.py @@ -171,6 +171,7 @@ def _run_mask_test( window_right, tile_m, tile_n, + use_block_sparsity, ): torch.manual_seed(42) @@ -267,7 +268,7 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias): mask_block_idx=mask_idx, full_block_cnt=full_cnt, full_block_idx=full_idx, - ) + ) if use_block_sparsity else None out_tuple = _flash_attn_fwd( q=tensors["q"], @@ -370,6 +371,7 @@ def test_mask_mod_ima_partial_block(): window_right=None, tile_m=128, tile_n=128, + use_block_sparsity=True, ) @@ -378,13 +380,14 @@ def test_mask_mod_ima_partial_block(): @pytest.mark.parametrize("kv_mode", ["mha", "gqa", "mqa"]) @pytest.mark.parametrize("headdim", [128]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("use_block_sparsity", [True, False]) @pytest.mark.parametrize( "mask_name", ["block_diagonal", "mini_causal"], ) @pytest.mark.parametrize("tile_m,tile_n", [(128, 128), (128, 112)]) def test_static_masks( - seqlen_q, seqlen_k, nheads, kv_mode, headdim, dtype, mask_name, tile_m, tile_n + seqlen_q, seqlen_k, nheads, kv_mode, headdim, dtype, use_block_sparsity, mask_name, tile_m, tile_n ): """Test static masks that don't require recompilation per seqlen pair. @@ -408,6 +411,7 @@ def test_static_masks( window_right=None, tile_m=tile_m, tile_n=tile_n, + use_block_sparsity=use_block_sparsity, ) @@ -416,6 +420,7 @@ def test_static_masks( @pytest.mark.parametrize("kv_mode", ["mha"]) @pytest.mark.parametrize("headdim", [128]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("use_block_sparsity", [True, False]) @pytest.mark.parametrize( "mask_name,window_size", [ @@ -429,7 +434,7 @@ def test_static_masks( ) @pytest.mark.parametrize("tile_m,tile_n", [(128, 128), (128, 112), (64, 128)]) def test_parameterized_masks( - seqlen_q, seqlen_k, nheads, kv_mode, headdim, dtype, mask_name, window_size, tile_m, tile_n + seqlen_q, seqlen_k, nheads, kv_mode, headdim, dtype, use_block_sparsity, mask_name, window_size, tile_m, tile_n ): """Test parameterized masks that require recompilation per seqlen pair. @@ -456,6 +461,7 @@ def test_parameterized_masks( window_right=None, tile_m=tile_m, tile_n=tile_n, + use_block_sparsity=use_block_sparsity, ) @@ -506,3 +512,4 @@ def test_sm100_block_sparse_sink_all_masked(): if __name__ == "__main__": pytest.main([__file__, "-v", "-s"]) + \ No newline at end of file From 672381f72c927a4b4a92f30755dc5829c3d0eaa3 Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Tue, 25 Nov 2025 12:38:30 -0800 Subject: [PATCH 248/258] Bump pin (#2025) * Bump pin * Swtich to new fastdivmod * cleanup varlen on blackwell * Allow for only cute install --- benchmarks/benchmark_attn.py | 7 ++- flash_attn/cute/fast_math.py | 78 +------------------------- flash_attn/cute/flash_bwd_sm100.py | 24 +++++--- flash_attn/cute/flash_fwd.py | 10 ++-- flash_attn/cute/flash_fwd_combine.py | 20 +++---- flash_attn/cute/flash_fwd_sm100.py | 10 ++-- flash_attn/cute/mask.py | 8 +-- flash_attn/cute/paged_kv.py | 22 ++++++-- flash_attn/cute/pyproject.toml | 2 +- flash_attn/cute/softmax.py | 4 +- flash_attn/cute/tile_scheduler.py | 83 +++++++++++++++------------- tests/cute/test_flash_attn_varlen.py | 71 ++++++++++++++---------- 12 files changed, 155 insertions(+), 184 deletions(-) diff --git a/benchmarks/benchmark_attn.py b/benchmarks/benchmark_attn.py index 1a868e0a286..cb6bc44eae2 100644 --- a/benchmarks/benchmark_attn.py +++ b/benchmarks/benchmark_attn.py @@ -22,7 +22,12 @@ # from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler from flash_attn.cute.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler -from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func + +try: + from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func +except ImportError: + flash_attn_func = None + flash_attn_varlen_func = None from flash_attn.cute.interface import flash_attn_func as flash_attn_func_python from flash_attn.cute.interface import flash_attn_varlen_func as flash_attn_varlen_func_python try: diff --git a/flash_attn/cute/fast_math.py b/flash_attn/cute/fast_math.py index 943388fd291..c56ea89e798 100644 --- a/flash_attn/cute/fast_math.py +++ b/flash_attn/cute/fast_math.py @@ -1,12 +1,8 @@ # Copyright (c) 2025, Tri Dao. -from typing import Tuple - import cutlass import cutlass.cute as cute -from cutlass import Int32, Uint32 -from cutlass.cutlass_dsl import T, dsl_user_op -from cutlass._mlir.dialects import llvm +from cutlass import Int32 @cute.jit @@ -23,75 +19,3 @@ def clz(x: Int32) -> Int32: res = Int32(i) done = True return res - - -def find_log2(x: Int32) -> Int32: - a: Int32 = Int32(31 - clz(x)) - return a + ((x & (x - 1)) != 0) # Round up, add 1 if not a power of 2. - - -@dsl_user_op -def umulhi(a: Int32, b: Int32, *, loc=None, ip=None) -> Uint32: - return Uint32( - llvm.inline_asm( - T.i32(), - [Int32(a).ir_value(loc=loc, ip=ip), Int32(b).ir_value(loc=loc, ip=ip)], - "mul.hi.u32 $0, $1, $2;", - "=r,r,r", - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - ) - - -class FastDivmod: - def __init__( - self, divisor: Int32, multipler: Uint32, shift_right: Uint32, *, loc=None, ip=None - ): - self.divisor = divisor - self.multiplier = multipler - self.shift_right = shift_right - self._loc = loc - - # called by host - @staticmethod - def create(divisor: Int32, *, loc=None, ip=None) -> "FastDivmod": - """Construct the FastDivmod object, in host code. - This precomputes some values based on the divisor and is computationally expensive. - """ - p = Uint32(31 + find_log2(divisor)) - divisor_u32 = Uint32(divisor) - multiplier = Uint32(((cutlass.Uint64(1) << p) + divisor_u32 - 1) // divisor_u32) - shift_right = Uint32(p - 32) - return FastDivmod(divisor, multiplier, shift_right, loc=loc, ip=ip) - - @cute.jit - def div(self, dividend: Int32) -> Int32: - return ( - Int32(umulhi(dividend, self.multiplier) >> self.shift_right) - if self.divisor != 1 - else dividend - ) - - def divmod(self, dividend: Int32) -> Tuple[Int32, Int32]: - quotient = self.div(dividend) - remainder = dividend - quotient * self.divisor - return quotient, remainder - - def __extract_mlir_values__(self): - values, self._values_pos = [], [] - for obj in [self.divisor, self.multiplier, self.shift_right]: - obj_values = cutlass.extract_mlir_values(obj) - values += obj_values - self._values_pos.append(len(obj_values)) - return values - - def __new_from_mlir_values__(self, values): - obj_list = [] - for obj, n_items in zip( - [self.divisor, self.multiplier, self.shift_right], self._values_pos - ): - obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) - values = values[n_items:] - return FastDivmod(*(tuple(obj_list)), loc=self._loc) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index fb0e2e9b778..7fc45666638 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -414,8 +414,7 @@ def __call__( assert mdK_semaphore is not None assert mdV_semaphore is not None mdK_semaphore, mdV_semaphore = [ - utils.select(t, mode=semaphore_transpose) - for t in (mdK_semaphore, mdV_semaphore) + utils.select(t, mode=semaphore_transpose) for t in (mdK_semaphore, mdV_semaphore) ] else: mdK_semaphore = None @@ -562,7 +561,7 @@ def __call__( cute.size(mQ.shape[2]), # num_heads = num_query_heads cute.size(mK.shape[3]), 1, # num_splits - cute.size(mQ.shape[0]), # pass seqlen_q for seqlen_k + cute.size(mQ.shape[0]), # pass seqlen_q for seqlen_k mQ.shape[1], mV.shape[1], total_q=cute.size(mQ.shape[0]), @@ -1905,7 +1904,9 @@ def compute_loop( if const_expr(not self.use_smem_dS_for_mma_dK): cute.arch.fence_view_async_tmem_store() - cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) self.compute_sync_barrier.arrive_and_wait() # with cute.arch.elect_one(): @@ -2032,7 +2033,7 @@ def dQacc_reduce( gdQaccum = cute.flat_divide( gdQaccum_, (self.tile_m * self.tile_hdim // self.dQaccum_reduce_stage,) ) - + if const_expr(self.deterministic): mdQ_semaphore_cur = mdQ_semaphore[None, None, head_idx, batch_idx] @@ -2068,12 +2069,17 @@ def dQacc_reduce( if const_expr(self.spt): n_block_max_for_m_block = min( n_block_global_max, - cute.ceil_div((m_block + 1) * self.tile_m + seqlen.seqlen_k - seqlen.seqlen_q, self.tile_n) + cute.ceil_div( + (m_block + 1) * self.tile_m + seqlen.seqlen_k - seqlen.seqlen_q, + self.tile_n, + ), ) lock_value = n_block_max_for_m_block - 1 - n_block else: lock_value = n_block - barrier.wait_eq(mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, lock_value) + barrier.wait_eq( + mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, lock_value + ) self.reduce_sync_barrier.arrive_and_wait() # Copy from shared memory to global memory if is_tma_warp: @@ -2101,7 +2107,9 @@ def dQacc_reduce( # semaphore release for prior m_block if const_expr(self.deterministic and stage == 0 and delay_semaphore_release): if m_block > m_block_min: - barrier.arrive_inc(mdQ_semaphore_cur[(m_block - 1, None)].iterator, tidx, 0, 1) + barrier.arrive_inc( + mdQ_semaphore_cur[(m_block - 1, None)].iterator, tidx, 0, 1 + ) # semaphore release # NOTE: arrive_inc calls red_release which issues membar diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index e341ac4feee..57874f6559f 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -44,7 +44,7 @@ SingleTileVarlenScheduler, ParamsBase, ) -from flash_attn.cute.fast_math import FastDivmod +from cutlass.cute import FastDivmodDivisor class FlashAttentionForwardBase: @@ -692,8 +692,8 @@ def __call__( self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 ) seqlen_k = cute.size(mK.shape[0]) - seqlen_q_divmod = FastDivmod.create(seqlen_q) - seqlen_k_divmod = FastDivmod.create(seqlen_k) + seqlen_q_divmod = FastDivmodDivisor(seqlen_q) + seqlen_k_divmod = FastDivmodDivisor(seqlen_k) fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) self.kernel( @@ -1503,8 +1503,8 @@ def __call__( self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 ) seqlen_k = cute.size(mK.shape[0]) - seqlen_q_divmod = FastDivmod.create(seqlen_q) - seqlen_k_divmod = FastDivmod.create(seqlen_k) + seqlen_q_divmod = FastDivmodDivisor(seqlen_q) + seqlen_k_divmod = FastDivmodDivisor(seqlen_k) fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) self.kernel( diff --git a/flash_attn/cute/flash_fwd_combine.py b/flash_attn/cute/flash_fwd_combine.py index b23ab8ba78e..02672e319de 100644 --- a/flash_attn/cute/flash_fwd_combine.py +++ b/flash_attn/cute/flash_fwd_combine.py @@ -14,8 +14,8 @@ from cutlass import Float32, Int32, const_expr from flash_attn.cute import utils -from flash_attn.cute.fast_math import FastDivmod from flash_attn.cute.seqlen_info import SeqlenInfo +from cutlass.cute import FastDivmodDivisor class FlashAttentionForwardCombine: @@ -257,9 +257,9 @@ class SharedStorage: num_head = mO_partial.shape[3] batch_size = mO_partial.shape[4] if const_expr(cu_seqlens is None) else Int32(cu_seqlens.shape[0] - 1) - # Create FastDivmod objects for efficient division - seqlen_divmod = FastDivmod.create(seqlen) - head_divmod = FastDivmod.create(num_head) + # Create FastDivmodDivisor objects for efficient division + seqlen_divmod = FastDivmodDivisor(seqlen) + head_divmod = FastDivmodDivisor(num_head) grid_dim = ( cute.ceil_div(seqlen * num_head, self.m_block_size), @@ -311,8 +311,8 @@ def kernel( gmem_tiled_copy_O: cute.TiledCopy, gmem_tiled_copy_LSE: cute.TiledCopy, s2r_tiled_copy_LSE: cute.TiledCopy, - seqlen_divmod: FastDivmod, - head_divmod: FastDivmod, + seqlen_divmod: FastDivmodDivisor, + head_divmod: FastDivmodDivisor, varlen: cutlass.Constexpr[bool], ): # Thread and block indices @@ -380,9 +380,9 @@ def kernel( mi = tLSEcLSE[0, 0, m][1] # Get m coordinate idx = m_block * self.m_block_size + mi if idx < max_idx: - # Calculate actual sequence position and head using FastDivmod + # Calculate actual sequence position and head using FastDivmodDivisor if const_expr(not varlen): - head_idx, m_idx = seqlen_divmod.divmod(idx) + head_idx, m_idx = divmod(idx, seqlen_divmod) else: head_idx = idx // seqlen m_idx = idx - head_idx * seqlen @@ -420,7 +420,7 @@ def kernel( mi = tOcO[0, m, 0][0] # m coordinate idx = m_block * self.m_block_size + mi if const_expr(not varlen): - tOhidx[m], tOmidx[m] = seqlen_divmod.divmod(idx) + tOhidx[m], tOmidx[m] = divmod(idx, seqlen_divmod) else: tOhidx[m] = idx // seqlen tOmidx[m] = idx - tOhidx[m] * seqlen @@ -536,7 +536,7 @@ def kernel( idx = m_block * self.m_block_size + mi if idx < max_idx: if const_expr(not varlen): - head_idx, m_idx = seqlen_divmod.divmod(idx) + head_idx, m_idx = divmod(idx, seqlen_divmod) else: head_idx = idx // seqlen m_idx = idx - head_idx * seqlen diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 2234d69ca99..645ad97b003 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -45,7 +45,7 @@ from flash_attn.cute.pack_gqa import PackGQA from flash_attn.cute import mma_sm100_desc as sm100_desc from flash_attn.cute import blackwell_helpers as sm100_utils -from flash_attn.cute.fast_math import FastDivmod +from cutlass.cute import FastDivmodDivisor from flash_attn.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, @@ -659,8 +659,8 @@ class SharedStorage: self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 ) seqlen_k = cute.size(mK.shape[0]) - seqlen_q_divmod = FastDivmod.create(seqlen_q) - seqlen_k_divmod = FastDivmod.create(seqlen_k) + seqlen_q_divmod = FastDivmodDivisor(seqlen_q) + seqlen_k_divmod = FastDivmodDivisor(seqlen_k) fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None) @@ -1190,7 +1190,7 @@ def load( mPageTable, mK, mV, - FastDivmod.create(page_size), + FastDivmodDivisor(page_size), batch_idx, head_idx_kv, tidx, @@ -2660,7 +2660,7 @@ def apply_score_mod( if cutlass.const_expr(aux_tensors is not None): seqlen_q_divmod, _ = fastdiv_mods - _, q_idx_logical = seqlen_q_divmod.divmod(q_idx_logical) + _, q_idx_logical = divmod(q_idx_logical, seqlen_q_divmod) apply_score_mod_inner( tSrS_t2r, diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index aa3d1bba099..da3ed8fb2d3 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -145,7 +145,7 @@ def apply_mask( global_row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m row_for_mod = global_row_idx if const_expr(wrap_aux_indices): - _, row_for_mod = fastdiv_mods[0].divmod(global_row_idx) + _, row_for_mod = divmod(global_row_idx, fastdiv_mods[0]) for col in cutlass.range_constexpr(ncol): col_idx_local = t0ScS_mn[0, col][1] @@ -153,7 +153,7 @@ def apply_mask( global_col_idx = thr_col_offset + col_idx_local + n_block * self.tile_n col_for_mod = global_col_idx if const_expr(wrap_aux_indices): - _, col_for_mod = fastdiv_mods[1].divmod(global_col_idx) + _, col_for_mod = divmod(global_col_idx, fastdiv_mods[1]) batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32) head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32) @@ -357,7 +357,7 @@ def apply_mask_sm100( mask_row = global_row mask_row_for_mod = mask_row if const_expr(wrap_aux_indices): - _, mask_row_for_mod = fastdiv_mods[0].divmod(mask_row) + _, mask_row_for_mod = divmod(mask_row, fastdiv_mods[0]) mask_row_ssa = utils.scalar_to_ssa(mask_row_for_mod, cutlass.Int32) ncol = const_expr(cute.size(tScS_t2r.shape)) @@ -366,7 +366,7 @@ def apply_mask_sm100( global_col = col_coord + n_block * self.tile_n global_col_for_mod = global_col if const_expr(wrap_aux_indices): - _, global_col_for_mod = fastdiv_mods[1].divmod(global_col) + _, global_col_for_mod = divmod(global_col, fastdiv_mods[1]) kv_idx_ssa = utils.scalar_to_ssa(global_col_for_mod, cutlass.Int32) mask_value = mask_mod( batch_idx_ssa, diff --git a/flash_attn/cute/paged_kv.py b/flash_attn/cute/paged_kv.py index ccb2296b4a7..8b0949d1404 100644 --- a/flash_attn/cute/paged_kv.py +++ b/flash_attn/cute/paged_kv.py @@ -7,8 +7,8 @@ from cutlass import Int32, const_expr from flash_attn.cute import utils -from flash_attn.cute.fast_math import FastDivmod from flash_attn.cute.cute_dsl_utils import ParamsBase +from cutlass.cute import FastDivmodDivisor @dataclass @@ -18,7 +18,7 @@ class PagedKVManager(ParamsBase): mV_paged: cute.Tensor thread_idx: Int32 - page_size_divmod: FastDivmod + page_size_divmod: FastDivmodDivisor seqlen_k: Int32 leftpad_k: Int32 n_block_size: Int32 @@ -42,7 +42,7 @@ def create( mPageTable: cute.Tensor, mK_paged: cute.Tensor, mV_paged: cute.Tensor, - page_size_divmod: FastDivmod, + page_size_divmod: FastDivmodDivisor, bidb: Int32, bidh: Int32, thread_idx: Int32, @@ -118,7 +118,7 @@ def load_page_table(self, n_block: Int32): row = (i * self.num_threads + self.thread_idx) // self.gmem_threads_per_row row_idx = n_block * self.n_block_size + row - page_idx, page_offset = self.page_size_divmod.divmod(row_idx + self.leftpad_k) + page_idx, page_offset = divmod(row_idx + self.leftpad_k, self.page_size_divmod) is_valid = ( (i + 1) * self.num_threads <= self.n_block_size or row < self.n_block_size @@ -173,4 +173,16 @@ def load_KV(self, n_block: Int32, sX: cute.Tensor, K_or_V: str): ) elif const_expr(K_or_V == "V"): # Don't need to clear out the rest of the smem for K since we'll mask out the scores anyway. - tXsX[None, m, None].fill(0) + fill_swizzled(tXsX[None, m, None], 0) + + +@cutlass.dsl_user_op +def fill_swizzled(tensor, value: cutlass.Numeric, *, loc=None, ip=None) -> None: + """Fill tensor with a constant value. + + Fills all elements of the tensor with the specified value, assuming static size + and supported memory space. + """ + rTmp = cute.make_rmem_tensor_like(tensor, tensor.element_type) + rTmp.fill(value) + cute.autovec_copy(rTmp, tensor) diff --git a/flash_attn/cute/pyproject.toml b/flash_attn/cute/pyproject.toml index 1b21df4b227..8b5942b10d0 100644 --- a/flash_attn/cute/pyproject.toml +++ b/flash_attn/cute/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ ] dependencies = [ - "nvidia-cutlass-dsl==4.3.0.dev0", + "nvidia-cutlass-dsl==4.3.0", "torch", "einops", "typing_extensions", diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index 0ca08f3f2e3..658934ce753 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -392,12 +392,12 @@ def apply_score_mod_inner( if cutlass.const_expr(constant_q_idx is None): seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods q_idx_floored = floor_if_packed(index_tensor[i + j][0], qhead_per_kvhead) - _, q_idx_wrapped = seqlen_q_divmod.divmod(q_idx_floored) + _, q_idx_wrapped = divmod(q_idx_floored, seqlen_q_divmod) q_idx_vec[j] = q_idx_wrapped else: _, seqlen_k_divmod = fastdiv_mods - _, kv_idx_wrapped = seqlen_k_divmod.divmod(index_tensor[i + j][1]) + _, kv_idx_wrapped = divmod(index_tensor[i + j][1], seqlen_k_divmod) kv_idx_vec[j] = kv_idx_wrapped else: # No bounds checking - direct indexing diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index ad6ab099b0a..ef47cedecdf 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -14,7 +14,8 @@ from cutlass import Int32, const_expr import flash_attn.cute.utils as utils -from flash_attn.cute.fast_math import FastDivmod, clz +from flash_attn.cute.fast_math import clz +from cutlass.cute import FastDivmodDivisor class WorkTileInfo(cutlass.utils.WorkTileInfo): @@ -80,7 +81,7 @@ class Params(ParamsBase): num_head: Int32 num_batch: Int32 num_splits: Int32 - num_splits_divmod: FastDivmod + num_splits_divmod: FastDivmodDivisor is_split_kv: cutlass.Constexpr[bool] = False cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) @@ -93,7 +94,7 @@ def create( args.num_head, args.num_batch, args.num_splits, - FastDivmod.create(args.num_splits), + FastDivmodDivisor(args.num_splits), args.is_split_kv, args.cluster_shape_mn, ) @@ -133,7 +134,7 @@ def get_grid_shape( def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: block_idx, head_idx, batch_idx = self._blk_coord if const_expr(self.params.is_split_kv): - head_idx, split_idx = self.params.num_splits_divmod.divmod(head_idx) + head_idx, split_idx = divmod(head_idx, self.params.num_splits_divmod) else: split_idx = Int32(0) return WorkTileInfo( @@ -169,8 +170,8 @@ def __new_from_mlir_values__(self, values): class StaticPersistentTileScheduler: @dataclass class Params(ParamsBase): - num_block_divmod: FastDivmod - num_head_divmod: FastDivmod + num_block_divmod: FastDivmodDivisor + num_head_divmod: FastDivmodDivisor total_blocks: Int32 @staticmethod @@ -179,7 +180,7 @@ def create( ) -> "StaticPersistentTileScheduler.Params": total_blocks = args.num_block * args.num_head * args.num_batch return StaticPersistentTileScheduler.Params( - FastDivmod.create(args.num_block), FastDivmod.create(args.num_head), total_blocks + FastDivmodDivisor(args.num_block), FastDivmodDivisor(args.num_head), total_blocks ) def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None): @@ -211,8 +212,8 @@ def get_grid_shape( # @cute.jit def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: - hn_idx, block_idx = self.params.num_block_divmod.divmod(self._tile_idx) - batch_idx, head_idx = self.params.num_head_divmod.divmod(hn_idx) + hn_idx, block_idx = divmod(self._tile_idx, self.params.num_block_divmod) + batch_idx, head_idx = divmod(hn_idx, self.params.num_head_divmod) is_valid = self._tile_idx < self.params.total_blocks # if cute.arch.thread_idx()[0] == 0: # cute.printf("TileScheduler: tile_idx=%d, hn_idx=%d, block_idx=%d, batch_idx=%d, head_idx=%d, is_valid=%d", self._tile_idx, hn_idx, block_idx, batch_idx, head_idx, is_valid) @@ -253,11 +254,13 @@ class SingleTileLPTScheduler: class Params(ParamsBase): total_blocks: Int32 num_splits: Int32 - num_block_divmod: FastDivmod - num_head_divmod: FastDivmod - l2_minor_divmod: FastDivmod - l2_major_divmod: FastDivmod - l2_minor_residual_divmod: FastDivmod + num_block: Int32 + l2_minor: Int32 + num_block_divmod: FastDivmodDivisor + num_head_divmod: FastDivmodDivisor + l2_minor_divmod: FastDivmodDivisor + l2_major_divmod: FastDivmodDivisor + l2_minor_residual_divmod: FastDivmodDivisor num_hb_quotient: Int32 is_split_kv: cutlass.Constexpr[bool] = False @@ -284,11 +287,13 @@ def create( num_hb_remainder = (args.num_head * args.num_batch) % swizzle return SingleTileLPTScheduler.Params( total_blocks=args.num_block * args.num_head * args.num_batch, - num_block_divmod=FastDivmod.create(args.num_block), - num_head_divmod=FastDivmod.create(args.num_head), - l2_minor_divmod=FastDivmod.create(swizzle), - l2_major_divmod=FastDivmod.create(swizzle * args.num_block), - l2_minor_residual_divmod=FastDivmod.create( + num_block=args.num_block, + l2_minor=Int32(swizzle), + num_block_divmod=FastDivmodDivisor(args.num_block), + num_head_divmod=FastDivmodDivisor(args.num_head), + l2_minor_divmod=FastDivmodDivisor(swizzle), + l2_major_divmod=FastDivmodDivisor(swizzle * args.num_block), + l2_minor_residual_divmod=FastDivmodDivisor( max(num_hb_remainder, 1) ), # don't divide by 0 num_hb_quotient=Int32(num_hb_quotient), @@ -327,18 +332,18 @@ def get_grid_shape( def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: params = self.params # Implement LPT scheduling coordinate calculation - bidhb, l2_mod = params.l2_major_divmod.divmod(self._tile_idx) + bidhb, l2_mod = divmod(self._tile_idx, params.l2_major_divmod) # If we're in the last section (called residual), we don't want to divide by # swizzle. Instead we want to divide by the remainder. block, bidhb_residual = 0, 0 if bidhb < params.num_hb_quotient: - block, bidhb_residual = params.l2_minor_divmod.divmod(l2_mod) + block, bidhb_residual = divmod(l2_mod, params.l2_minor_divmod) else: - block, bidhb_residual = params.l2_minor_residual_divmod.divmod(l2_mod) - bidhb_actual = bidhb * params.l2_minor_divmod.divisor + bidhb_residual - batch_idx, head_idx = params.num_head_divmod.divmod(bidhb_actual) + block, bidhb_residual = divmod(l2_mod, params.l2_minor_residual_divmod) + bidhb_actual = bidhb * params.l2_minor + bidhb_residual + batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod) # Longest-processing-time-first - block = params.num_block_divmod.divisor - 1 - block + block = params.num_block - 1 - block is_valid = self._tile_idx < params.total_blocks return WorkTileInfo( (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(self._split_idx)), is_valid @@ -375,10 +380,11 @@ class SingleTileLPTBwdScheduler: class Params(ParamsBase): total_blocks: Int32 num_block: Int32 - num_head_divmod: FastDivmod - l2_minor_divmod: FastDivmod - l2_major_divmod: FastDivmod - l2_minor_residual_divmod: FastDivmod + l2_minor: Int32 + num_head_divmod: FastDivmodDivisor + l2_minor_divmod: FastDivmodDivisor + l2_major_divmod: FastDivmodDivisor + l2_minor_residual_divmod: FastDivmodDivisor num_hb_quotient: Int32 cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) spt: cutlass.Constexpr[bool] = True @@ -406,10 +412,11 @@ def create( * args.num_head * args.num_batch, num_block=num_block, - num_head_divmod=FastDivmod.create(args.num_head), - l2_minor_divmod=FastDivmod.create(swizzle), - l2_major_divmod=FastDivmod.create(swizzle * num_block), - l2_minor_residual_divmod=FastDivmod.create( + l2_minor=Int32(swizzle), + num_head_divmod=FastDivmodDivisor(args.num_head), + l2_minor_divmod=FastDivmodDivisor(swizzle), + l2_major_divmod=FastDivmodDivisor(swizzle * num_block), + l2_minor_residual_divmod=FastDivmodDivisor( max(num_hb_remainder, 1) ), # don't divide by 0 num_hb_quotient=Int32(num_hb_quotient), @@ -448,16 +455,16 @@ def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: cluster_idx = self._tile_idx // self.params.cluster_shape_mn[0] params = self.params # Implement LPT scheduling coordinate calculation - bidhb, l2_mod = params.l2_major_divmod.divmod(cluster_idx) + bidhb, l2_mod = divmod(cluster_idx, params.l2_major_divmod) # If we're in the last section (called residual), we don't want to divide by # swizzle. Instead we want to divide by the remainder. block, bidhb_residual = 0, 0 if bidhb < params.num_hb_quotient: - block, bidhb_residual = params.l2_minor_divmod.divmod(l2_mod) + block, bidhb_residual = divmod(l2_mod, params.l2_minor_divmod) else: - block, bidhb_residual = params.l2_minor_residual_divmod.divmod(l2_mod) - bidhb_actual = bidhb * params.l2_minor_divmod.divisor + bidhb_residual - batch_idx, head_idx = params.num_head_divmod.divmod(bidhb_actual) + block, bidhb_residual = divmod(l2_mod, params.l2_minor_residual_divmod) + bidhb_actual = bidhb * params.l2_minor + bidhb_residual + batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod) is_valid = self._tile_idx < params.total_blocks bidx_in_cluster = cute.arch.block_in_cluster_idx() block = block * params.cluster_shape_mn[0] + bidx_in_cluster[0] diff --git a/tests/cute/test_flash_attn_varlen.py b/tests/cute/test_flash_attn_varlen.py index 3a514664449..53d907eed94 100644 --- a/tests/cute/test_flash_attn_varlen.py +++ b/tests/cute/test_flash_attn_varlen.py @@ -29,7 +29,7 @@ def test_varlen( ): if min_seq_len > max_seq_len: pytest.skip("Skipping min_seq_len > max_seq_len") - + q, k, v, cu_seqlens_q, cu_seqlens_k, total_q, total_k = generate_varlen_args( batch_size=B, n_heads=H, @@ -40,30 +40,36 @@ def test_varlen( dtype=dtype ) - ok = check_backward_vs_torch_flash( - q, k, v, - cu_seqlens_q, cu_seqlens_k, - total_q=total_q, total_k=total_k, - softmax_scale=softmax_scale, + # SM100 (Blackwell) backward pass doesn't support varlen yet + compute_capability = torch.cuda.get_device_capability()[0] + skip_backward = (compute_capability == 10) + + ok = check_varlen_vs_torch_flash( + q, k, v, + cu_seqlens_q, cu_seqlens_k, + total_q=total_q, total_k=total_k, + softmax_scale=softmax_scale, causal=causal, mha_type=mha_type, + skip_backward=skip_backward, ) assert ok -def check_backward_vs_torch_flash( - q, k, v, - cu_seqlens_q=None, - cu_seqlens_k=None, - seqused_q=None, - seqused_k=None, +def check_varlen_vs_torch_flash( + q, k, v, + cu_seqlens_q=None, + cu_seqlens_k=None, + seqused_q=None, + seqused_k=None, total_q=None, total_k=None, - softmax_scale=None, + softmax_scale=None, causal=True, mha_type='mha', softcap=0.0, - atol=3e-2, + atol=3e-2, rtol=3e-2, + skip_backward=False, ): assert q.requires_grad and k.requires_grad and v.requires_grad, "Set requires_grad=True on inputs" @@ -103,18 +109,27 @@ def clone_like(t): ) out_t = torch_flash_ref( - q_t, k_t, v_t, - cu_seqlens_q=cu_seqlens_q_t, - cu_seqlens_k=cu_seqlens_k_t, + q_t, k_t, v_t, + cu_seqlens_q=cu_seqlens_q_t, + cu_seqlens_k=cu_seqlens_k_t, seqused_q=seqused_q, seqused_k=seqused_k, total_q=total_q, total_k=total_k, - softmax_scale=softmax_scale, + softmax_scale=softmax_scale, causal=causal, mha_type=mha_type, ) + + ok_fwd = torch.allclose(out_fa.float(), out_t.float(), atol=atol, rtol=rtol) + if not ok_fwd: + return False + + # Skip backward if not supported (e.g., SM100 varlen) + if skip_backward: + return True + # Use the same upstream gradient to compare backward paths grad_out = torch.randn_like(out_fa) @@ -164,7 +179,7 @@ def generate_varlen_args( total_q = cu_seqlens_q[-1] total_k = cu_seqlens_k[-1] - + cu_seqlens_q = cu_seqlens_q.contiguous().to(dtype=torch.int32, device=device) cu_seqlens_k = cu_seqlens_k.contiguous().to(dtype=torch.int32, device=device) @@ -187,15 +202,15 @@ def generate_varlen_args( # Simple for loop over batch dim implementation def torch_flash_ref( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens_q: torch.Tensor = None, - cu_seqlens_k: torch.Tensor = None, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor = None, + cu_seqlens_k: torch.Tensor = None, total_q: int = 0, total_k: int = 0, - softmax_scale: Optional[float] = None, - causal: bool = False, + softmax_scale: Optional[float] = None, + causal: bool = False, **kwargs ): @@ -255,7 +270,7 @@ def torch_flash_ref( for b in range(B): if hcseq_q is not None: q_start, q_end = int(hcseq_q[b]), int(hcseq_q[b+1]) - qb = q[q_start:q_end] + qb = q[q_start:q_end] else: qb = q[b] @@ -266,7 +281,7 @@ def torch_flash_ref( else: kb = k[b] vb = v[b] - + qb = qb.permute(1, 0, 2).unsqueeze(0) kb = kb.permute(1, 0, 2).unsqueeze(0) vb = vb.permute(1, 0, 2).unsqueeze(0) From 91ba87d759fd0282eb67f11fbdfe60b4d5317bcc Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Tue, 2 Dec 2025 13:43:24 -0800 Subject: [PATCH 249/258] ruff all the smaller files (#2040) --- .pre-commit-config.yaml | 9 -- flash_attn/cute/copy_utils.py | 6 +- flash_attn/cute/flash_fwd_combine.py | 154 +++++++++++++++++++-------- flash_attn/cute/hopper_helpers.py | 1 - flash_attn/cute/pack_gqa.py | 2 - flash_attn/cute/testing.py | 20 +++- flash_attn/cute/utils.py | 91 ++++++++++++---- 7 files changed, 193 insertions(+), 90 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 67dcf8ba868..6118dfa2283 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,19 +7,10 @@ repos: files: ^flash_attn/cute/.*\.py$ exclude: &cute_exclude | (?x)^flash_attn/cute/( - __init__| - copy_utils| - cute_dsl_utils| - fast_math| flash_bwd| flash_fwd| - flash_fwd_combine| flash_fwd_sm100| - hopper_helpers| interface| - pack_gqa| - testing| - utils )\.py$ - id: ruff-format files: ^flash_attn/cute/.*\.py$ diff --git a/flash_attn/cute/copy_utils.py b/flash_attn/cute/copy_utils.py index 45ec493aaa3..cfdcbdb80a0 100644 --- a/flash_attn/cute/copy_utils.py +++ b/flash_attn/cute/copy_utils.py @@ -1,11 +1,11 @@ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. import math -from typing import Optional, Type, Tuple, Callable +from typing import Optional, Type, Callable import cutlass import cutlass.cute as cute -from cutlass import Float32, Int32, Boolean, const_expr +from cutlass import Float32, Int32, const_expr from cutlass.cute.nvgpu import cpasync import cutlass.utils.blackwell_helpers as sm100_utils from cutlass.cutlass_dsl import T, dsl_user_op @@ -279,7 +279,7 @@ def copy_bulk(src_idx, dst_idx, **new_kwargs): dst[None, dst_idx].iterator, size=size, **new_kwargs, - **kwargs + **kwargs, ) def copy_bulk_single_stage(**new_kwargs): diff --git a/flash_attn/cute/flash_fwd_combine.py b/flash_attn/cute/flash_fwd_combine.py index 02672e319de..f97e127175d 100644 --- a/flash_attn/cute/flash_fwd_combine.py +++ b/flash_attn/cute/flash_fwd_combine.py @@ -55,8 +55,13 @@ def __init__( @staticmethod def can_implement( - dtype, dtype_partial, head_dim, m_block_size, k_block_size, - log_max_splits, num_threads, + dtype, + dtype_partial, + head_dim, + m_block_size, + k_block_size, + log_max_splits, + num_threads, ) -> bool: """Check if the kernel can be implemented with the given parameters.""" if dtype not in [cutlass.Float16, cutlass.BFloat16, cutlass.Float32]: @@ -83,8 +88,7 @@ def _setup_attributes(self): assert self.k_block_size % async_copy_elems == 0 k_block_gmem = ( - 128 if self.k_block_size % 128 == 0 else - (64 if self.k_block_size % 64 == 0 else 32) + 128 if self.k_block_size % 128 == 0 else (64 if self.k_block_size % 64 == 0 else 32) ) gmem_threads_per_row = k_block_gmem // async_copy_elems assert self.num_threads % gmem_threads_per_row == 0 @@ -111,16 +115,25 @@ def _setup_attributes(self): num_bits_per_copy=async_copy_elems * self.dtype.width, ) self.gmem_tiled_copy_O = cute.make_tiled_copy_tv( - atom_universal_copy, tOpartial_layout, vOpartial_layout # 4 vals per store + atom_universal_copy, + tOpartial_layout, + vOpartial_layout, # 4 vals per store ) # LSE copy setup with async copy (alignment = 1) lse_copy_bits = Float32.width # 1 element per copy, width is in bits m_block_smem = ( - 128 if self.m_block_size % 128 == 0 else - (64 if self.m_block_size % 64 == 0 else - (32 if self.m_block_size % 32 == 0 else - (16 if self.m_block_size % 16 == 0 else 8))) + 128 + if self.m_block_size % 128 == 0 + else ( + 64 + if self.m_block_size % 64 == 0 + else ( + 32 + if self.m_block_size % 32 == 0 + else (16 if self.m_block_size % 16 == 0 else 8) + ) + ) ) gmem_threads_per_row_lse = m_block_smem assert self.num_threads % gmem_threads_per_row_lse == 0 @@ -167,9 +180,7 @@ def _setup_attributes(self): else: smem_lse_swizzle = cute.make_swizzle(3, 2, 3) smem_layout_atom_lse = cute.make_composed_layout( - smem_lse_swizzle, - 0, - cute.make_ordered_layout((8, m_block_smem), order=(1, 0)) + smem_lse_swizzle, 0, cute.make_ordered_layout((8, m_block_smem), order=(1, 0)) ) self.smem_layout_lse = cute.tile_to_shape( smem_layout_atom_lse, (self.max_splits, self.m_block_size), (0, 1) @@ -177,11 +188,9 @@ def _setup_attributes(self): # O partial shared memory layout (simple layout for pipeline stages) self.smem_layout_o = cute.make_ordered_layout( - (self.m_block_size, self.k_block_size, self.stages), - order=(1, 0, 2) + (self.m_block_size, self.k_block_size, self.stages), order=(1, 0, 2) ) - @cute.jit def __call__( self, @@ -200,38 +209,63 @@ def __call__( raise TypeError("O partial tensor must match dtype_partial") if const_expr(not (mO.element_type == self.dtype)): raise TypeError("O tensor must match dtype") - if const_expr(not mLSE_partial.element_type in [Float32]): + if const_expr(mLSE_partial.element_type not in [Float32]): raise TypeError("LSE partial tensor must be Float32") - if const_expr(mLSE is not None and not mLSE.element_type in [Float32]): + if const_expr(mLSE is not None and mLSE.element_type not in [Float32]): raise TypeError("LSE tensor must be Float32") # Shape validation - input tensors are in user format, need to be converted to kernel format if const_expr(len(mO_partial.shape) not in [4, 5]): - raise ValueError("O partial tensor must have 4 or 5 dimensions: (num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim)") + raise ValueError( + "O partial tensor must have 4 or 5 dimensions: (num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim)" + ) if const_expr(len(mLSE_partial.shape) not in [3, 4]): - raise ValueError("LSE partial tensor must have 3 or 4 dimensions: (num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)") + raise ValueError( + "LSE partial tensor must have 3 or 4 dimensions: (num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)" + ) if const_expr(len(mO.shape) not in [3, 4]): - raise ValueError("O tensor must have 3 or 4 dimensions: (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim)") + raise ValueError( + "O tensor must have 3 or 4 dimensions: (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim)" + ) if const_expr(mLSE is not None and len(mLSE.shape) not in [2, 3]): - raise ValueError("LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)") + raise ValueError( + "LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)" + ) # Assume all strides are divisible by 128 bits except the last stride - new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1]) - mO_partial, mO = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) for t in (mO_partial, mO)] + new_stride = lambda t: ( + *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + t.stride[-1], + ) + mO_partial, mO = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + for t in (mO_partial, mO) + ] # (num_splits, b, seqlen, h, d) -> (seqlen, d, num_splits, h, b) # or (num_splits, total_q, h, d) -> (total_q, d, num_splits, h) - O_partial_layout_transpose = [2, 4, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 3, 0, 2] + O_partial_layout_transpose = ( + [2, 4, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 3, 0, 2] + ) # (b, seqlen, h, d) -> (seqlen, d, h, b) or (total_q, h, d) -> (total_q, d, h) - mO_partial = cute.make_tensor(mO_partial.iterator, cute.select(mO_partial.layout, mode=O_partial_layout_transpose)) + mO_partial = cute.make_tensor( + mO_partial.iterator, cute.select(mO_partial.layout, mode=O_partial_layout_transpose) + ) O_layout_transpose = [1, 3, 2, 0] if const_expr(cu_seqlens is None) else [0, 2, 1] mO = cute.make_tensor(mO.iterator, cute.select(mO.layout, mode=O_layout_transpose)) # (num_splits, b, seqlen, h) -> (seqlen, num_splits, h, b) # or (num_splits, total_q, h) -> (total_q, num_splits, h) LSE_partial_layout_transpose = [2, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 0, 2] - mLSE_partial = cute.make_tensor(mLSE_partial.iterator, cute.select(mLSE_partial.layout, mode=LSE_partial_layout_transpose)) + mLSE_partial = cute.make_tensor( + mLSE_partial.iterator, + cute.select(mLSE_partial.layout, mode=LSE_partial_layout_transpose), + ) # (b, seqlen, h) -> (seqlen, h, b) or (total_q, h) -> (total_q, h) LSE_layout_transpose = [1, 2, 0] if const_expr(cu_seqlens is None) else [0, 1] - mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) if mLSE is not None else None + mLSE = ( + cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) + if mLSE is not None + else None + ) # Determine if we have variable length sequences varlen = const_expr(cu_seqlens is not None or seqused is not None) @@ -243,9 +277,7 @@ class SharedStorage: sLSE: cute.struct.Align[ cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128 ] - sMaxValidSplit: cute.struct.Align[ - cute.struct.MemRange[Int32, self.m_block_size], 128 - ] + sMaxValidSplit: cute.struct.Align[cute.struct.MemRange[Int32, self.m_block_size], 128] sO: cute.struct.Align[ cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128 ] @@ -255,7 +287,11 @@ class SharedStorage: # Grid dimensions: (ceil_div(seqlen, m_block), ceil_div(head_dim, k_block), num_head * batch) seqlen = mO_partial.shape[0] num_head = mO_partial.shape[3] - batch_size = mO_partial.shape[4] if const_expr(cu_seqlens is None) else Int32(cu_seqlens.shape[0] - 1) + batch_size = ( + mO_partial.shape[4] + if const_expr(cu_seqlens is None) + else Int32(cu_seqlens.shape[0] - 1) + ) # Create FastDivmodDivisor objects for efficient division seqlen_divmod = FastDivmodDivisor(seqlen) @@ -330,14 +366,18 @@ def kernel( # Handle semaphore reset if const_expr(semaphore_to_reset is not None): - if (tidx == 0 and m_block == cute.arch.grid_dim()[0] - 1 and - k_block == cute.arch.grid_dim()[1] - 1 and - batch_idx == cute.arch.grid_dim()[2] - 1): + if ( + tidx == 0 + and m_block == cute.arch.grid_dim()[0] - 1 + and k_block == cute.arch.grid_dim()[1] - 1 + and batch_idx == cute.arch.grid_dim()[2] - 1 + ): semaphore_to_reset[0] = 0 # Get number of splits num_splits = ( - num_splits_dynamic_ptr[batch_idx] if const_expr(num_splits_dynamic_ptr is not None) + num_splits_dynamic_ptr[batch_idx] + if const_expr(num_splits_dynamic_ptr is not None) else mLSE_partial.shape[1] ) # Handle variable length sequences using SeqlenInfo @@ -345,7 +385,7 @@ def kernel( batch_idx=batch_idx, seqlen_static=mO_partial.shape[0], cu_seqlens=cu_seqlens, - seqused=seqused + seqused=seqused, ) seqlen, offset = seqlen_info.seqlen, seqlen_info.offset @@ -354,8 +394,9 @@ def kernel( max_idx = seqlen * num_head # Early exit for single split if dynamic - if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and (const_expr(not varlen) or m_block * self.m_block_size < max_idx): - + if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and ( + const_expr(not varlen) or m_block * self.m_block_size < max_idx + ): # =============================== # Step 1: Load LSE_partial from gmem to shared memory # =============================== @@ -390,7 +431,11 @@ def kernel( for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): si = tLSEcLSE[0, s, 0][0] # Get split coordinate if si < num_splits: - cute.copy(gmem_thr_copy_LSE, mLSE_partial_cur_copy[None, si], tLSEsLSE[None, s, m]) + cute.copy( + gmem_thr_copy_LSE, + mLSE_partial_cur_copy[None, si], + tLSEsLSE[None, s, m], + ) else: tLSEsLSE[None, s, m].fill(-Float32.inf) # Don't need to zero out the rest of the LSEs, as we will not write the output to gmem @@ -424,7 +469,9 @@ def kernel( else: tOhidx[m] = idx // seqlen tOmidx[m] = idx - tOhidx[m] * seqlen - tOrOptr[m] = utils.elem_pointer_i64(mO_partial_cur, (tOmidx[m], k_block * self.k_block_size, 0, tOhidx[m])).toint() + tOrOptr[m] = utils.elem_pointer_i64( + mO_partial_cur, (tOmidx[m], k_block * self.k_block_size, 0, tOhidx[m]) + ).toint() if idx >= max_idx: tOhidx[m] = -1 @@ -483,7 +530,9 @@ def kernel( # Find max LSE value across splits threads_per_col = const_expr(self.smem_threads_per_col_lse) lse_max = utils.warp_reduce( - ts2rrLSE[None, None, m].load().reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0), + ts2rrLSE[None, None, m] + .load() + .reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0), op=cute.arch.fmax, width=threads_per_col, ) @@ -496,7 +545,9 @@ def kernel( # if cute.arch.thread_idx()[0] < 32: cute.printf(max_valid_idx) max_valid_split[m] = utils.warp_reduce(max_valid_idx, max, width=threads_per_col) # Compute exp scales and sum - lse_max_cur = 0.0 if lse_max == -Float32.inf else lse_max # In case all local LSEs are -inf + lse_max_cur = ( + 0.0 if lse_max == -Float32.inf else lse_max + ) # In case all local LSEs are -inf LOG2_E = math.log2(math.e) lse_sum_cur = 0.0 for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True): @@ -506,7 +557,9 @@ def kernel( lse_sum_cur = utils.warp_reduce(lse_sum_cur, operator.add, width=threads_per_col) lse_sum[m] = utils.logf(lse_sum_cur) + lse_max # Normalize scales - inv_sum = 0.0 if (lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur) else 1.0 / lse_sum_cur + inv_sum = ( + 0.0 if (lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur) else 1.0 / lse_sum_cur + ) ts2rrLSE[None, None, m].store(ts2rrLSE[None, None, m].load() * inv_sum) # Store the scales exp(lse - lse_logsum) back to smem cute.copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE) @@ -584,7 +637,10 @@ def kernel( # Accumulate scaled partial results for m in cutlass.range(num_rows, unroll_full=True): if tOhidx[m] >= 0 and scale[m] > 0.0: - tOrO[None, m, None].store(tOrO[None, m, None].load() + scale[m] * tOrO_partial[None, m, None].load().to(Float32)) + tOrO[None, m, None].store( + tOrO[None, m, None].load() + + scale[m] * tOrO_partial[None, m, None].load().to(Float32) + ) # =============================== # Step 7: Write final O to gmem @@ -605,7 +661,9 @@ def kernel( # Write final results for m in cutlass.range(num_rows, unroll_full=True): if tOhidx[m] >= 0: - mO_cur_copy = cute.tiled_divide(mO_cur[tOmidx[m], None, tOhidx[m]], (elems_per_store,)) + mO_cur_copy = cute.tiled_divide( + mO_cur[tOmidx[m], None, tOhidx[m]], (elems_per_store,) + ) for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): k_idx = tOcO[0, 0, k][1] // elems_per_store if const_expr(self.is_even_k) or tOpO[k]: @@ -631,7 +689,9 @@ def load_O_partial( o_gmem_ptr = cute.make_ptr( tOsO_partial.element_type, tOrOptr[m], cute.AddressSpace.gmem, assumed_align=16 ) - mO_partial_cur = cute.make_tensor(o_gmem_ptr, cute.slice_(mO_cur_partial_layout, (0, None, None, 0))) + mO_partial_cur = cute.make_tensor( + o_gmem_ptr, cute.slice_(mO_cur_partial_layout, (0, None, None, 0)) + ) mO_partial_cur_copy = cute.tiled_divide(mO_partial_cur, (elems_per_load,)) for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): k_idx = tOcO[0, 0, k][1] // elems_per_load @@ -640,5 +700,5 @@ def load_O_partial( gmem_tiled_copy_O_partial, # mO_partial_cur_copy[None, k_idx, split], utils.coord_offset_i64(mO_partial_cur_copy, split, dim=2)[None, k_idx], - tOsO_partial_cur[None, m, k] + tOsO_partial_cur[None, m, k], ) diff --git a/flash_attn/cute/hopper_helpers.py b/flash_attn/cute/hopper_helpers.py index c98f85b568e..c6a1c301904 100644 --- a/flash_attn/cute/hopper_helpers.py +++ b/flash_attn/cute/hopper_helpers.py @@ -4,7 +4,6 @@ import cutlass.cute as cute from cutlass import Int32, Float32, Boolean, const_expr from cutlass.cute.nvgpu import warpgroup -from cutlass._mlir.dialects import llvm from cutlass.cutlass_dsl import Numeric, dsl_user_op from cutlass.utils import LayoutEnum import cutlass.utils.hopper_helpers as sm90_utils_og diff --git a/flash_attn/cute/pack_gqa.py b/flash_attn/cute/pack_gqa.py index 46d8dd38798..765e71307ad 100644 --- a/flash_attn/cute/pack_gqa.py +++ b/flash_attn/cute/pack_gqa.py @@ -1,7 +1,5 @@ # Copyright (c) 2025, Tri Dao. -import math -import operator import cutlass import cutlass.cute as cute diff --git a/flash_attn/cute/testing.py b/flash_attn/cute/testing.py index 690d0145479..214ed09bc9e 100644 --- a/flash_attn/cute/testing.py +++ b/flash_attn/cute/testing.py @@ -99,7 +99,9 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", if i % 5 == 0: lengths[i] = 0 lengths[-1] = 0 - padding_mask = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths + padding_mask = ( + repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths + ) return padding_mask @@ -129,7 +131,9 @@ def generate_qkv( q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input( q, query_padding_mask, query_unused_mask ) - output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if qv is not None else None else: q_unpad = rearrange(q, "b s h d -> (b s) h d") @@ -138,7 +142,9 @@ def generate_qkv( ) seqused_q = None max_seqlen_q = seqlen_q - output_pad_fn = lambda output_unpad: rearrange(output_unpad, "(b s) h d -> b s h d", b=batch_size) + output_pad_fn = lambda output_unpad: rearrange( + output_unpad, "(b s) h d -> b s h d", b=batch_size + ) qv_unpad = rearrange(qv, "b s ... -> (b s) ...") if qv is not None else None if key_padding_mask is not None: @@ -256,7 +262,9 @@ def construct_local_mask( sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk return torch.logical_or( col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), - torch.logical_and(col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length), + torch.logical_and( + col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length + ), ) @@ -368,7 +376,9 @@ def attention_ref( key_leftpad=key_leftpad, device=q.device, ) - local_mask = torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask + local_mask = ( + torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask + ) if local_mask is not None: scores.masked_fill_(local_mask, float("-inf")) if attn_bias is not None: diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index eb8b86cbe0b..f73f66cfccf 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -10,7 +10,7 @@ import cutlass import cutlass.cute as cute -from cutlass import Float32, Int32, const_expr +from cutlass import Float32, const_expr from cutlass.cutlass_dsl import T, dsl_user_op from cutlass._mlir.dialects import nvvm, llvm from cutlass.cute.runtime import from_dlpack @@ -24,9 +24,10 @@ cute.arch.calc_packed_f32x2_op, src_c=None, calc_func=nvvm.sub_packed_f32x2, - rnd=nvvm.RoundingModeKind.RN + rnd=nvvm.RoundingModeKind.RN, ) + def hash_callable(func: Callable) -> str: """Hash a callable based on the source code or bytecode and closure values.""" if hasattr(func, "__wrapped__"): @@ -62,6 +63,7 @@ def scoremod_premask_fn(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, buffers): return scoremod_premask_fn + def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor: return ( from_dlpack(x, assumed_align=alignment) @@ -71,7 +73,10 @@ def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Te ) ) -def convert_from_dlpack_leading_static(x, leading_dim, alignment=16, static_modes=None, stride_order=None) -> cute.Tensor: + +def convert_from_dlpack_leading_static( + x, leading_dim, alignment=16, static_modes=None, stride_order=None +) -> cute.Tensor: if stride_order is None: stride_order = x.dim_order() x_ = from_dlpack(x, assumed_align=alignment) @@ -80,6 +85,7 @@ def convert_from_dlpack_leading_static(x, leading_dim, alignment=16, static_mode x_ = x_.mark_compact_shape_dynamic(mode=i, stride_order=stride_order) return x_ + def make_tiled_copy_A( copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False ) -> cute.TiledCopy: @@ -258,7 +264,7 @@ def parse_swizzle_from_pointer(ptr: cute.Pointer) -> cute.Swizzle: # the string here. swizzle_str = str(ptr.type.swizzle_type) # Extract the inner part "S" - match = re.search(r'S<(\d+),(\d+),(\d+)>', swizzle_str) + match = re.search(r"S<(\d+),(\d+),(\d+)>", swizzle_str) if match: b, m, s = int(match.group(1)), int(match.group(2)), int(match.group(3)) return cute.make_swizzle(b, m, s) @@ -298,6 +304,7 @@ def log2f(a: float | Float32, *, loc=None, ip=None) -> Float32: ) ) + @dsl_user_op def logf(a: float | Float32, *, loc=None, ip=None) -> Float32: return log2f(a, loc=loc, ip=ip) * math.log(2.0) @@ -350,7 +357,11 @@ def fmax_reduce( # We instead force the 3-input max. res = cute.make_fragment(x.shape, Float32) res.store(x) - local_max_0 = fmax(init_val, res[0], res[1]) if const_expr(init_val is not None) else fmax(res[0], res[1]) + local_max_0 = ( + fmax(init_val, res[0], res[1]) + if const_expr(init_val is not None) + else fmax(res[0], res[1]) + ) local_max = [ local_max_0, fmax(res[2], res[3]), @@ -438,7 +449,9 @@ def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cut def elem_pointer_i64(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer: flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord)) flat_stride = cute.flatten_to_tuple(x.stride) - assert len(flat_coord_i64) == len(flat_stride), "Coordinate and stride must have the same length" + assert len(flat_coord_i64) == len(flat_stride), ( + "Coordinate and stride must have the same length" + ) offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride)) # HACK: we assume that applying the offset does not change the pointer alignment byte_offset = offset * x.element_type.width // 8 @@ -517,7 +530,10 @@ def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> return cutlass.Uint32( llvm.inline_asm( T.i32(), - [cutlass.Uint32(val).ir_value(loc=loc, ip=ip), cutlass.Uint32(shift).ir_value(loc=loc, ip=ip)], + [ + cutlass.Uint32(val).ir_value(loc=loc, ip=ip), + cutlass.Uint32(shift).ir_value(loc=loc, ip=ip), + ], "shr.s32 $0, $1, $2;", "=r,r,r", has_side_effects=False, @@ -543,7 +559,9 @@ def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> @dsl_user_op -def cvt_f16x2_f32(a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None) -> cutlass.Int32: +def cvt_f16x2_f32( + a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None +) -> cutlass.Int32: assert to_dtype in [cutlass.BFloat16, cutlass.Float16], "to_dtype must be BFloat16 or Float16" return cutlass.Int32( llvm.inline_asm( @@ -561,9 +579,11 @@ def cvt_f16x2_f32(a: float | Float32, b: float | Float32, to_dtype: Type, *, loc @overload def cvt_f16(src: cute.Tensor, dst: cute.Tensor) -> None: ... + @overload def cvt_f16(src: cute.Tensor, dtype: Type[cute.Numeric]) -> cute.Tensor: ... + @cute.jit def cvt_f16(src: cute.Tensor, dst_or_dtype): """Convert Float32 tensor to Float16/BFloat16. @@ -586,7 +606,9 @@ def cvt_f16(src: cute.Tensor, dst_or_dtype): dst = dst_or_dtype assert cute.size(dst.shape) == cute.size(src.shape), "dst and src must have the same size" assert cute.size(src.shape) % 2 == 0, "src must have an even number of elements" - assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], "dst must be BFloat16 or Float16" + assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], ( + "dst must be BFloat16 or Float16" + ) assert src.element_type is Float32, "src must be Float32" dst_i32 = cute.recast_tensor(dst, cutlass.Int32) assert cute.size(dst_i32.shape) * 2 == cute.size(src.shape) @@ -606,7 +628,9 @@ def evaluate_polynomial(x: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=N @dsl_user_op @cute.jit -def evaluate_polynomial_2(x: Float32, y: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None) -> Tuple[Float32, Float32]: +def evaluate_polynomial_2( + x: Float32, y: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None +) -> Tuple[Float32, Float32]: deg = len(poly) - 1 out = (poly[deg], poly[deg]) for i in cutlass.range_constexpr(deg - 1, -1, -1): @@ -621,7 +645,7 @@ def add_round_down(x: float | Float32, y: float | Float32, *, loc=None, ip=None) llvm.inline_asm( T.f32(), [Float32(x).ir_value(loc=loc, ip=ip), Float32(y).ir_value(loc=loc, ip=ip)], - f"add.rm.ftz.f32 $0, $1, $2;", + "add.rm.ftz.f32 $0, $1, $2;", "=f,f,f", has_side_effects=False, is_align_stack=False, @@ -635,7 +659,10 @@ def combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, loc=None, ip= return cutlass.Float32( llvm.inline_asm( T.f32(), - [Float32(x_rounded).ir_value(loc=loc, ip=ip), Float32(frac_ex2).ir_value(loc=loc, ip=ip)], + [ + Float32(x_rounded).ir_value(loc=loc, ip=ip), + Float32(frac_ex2).ir_value(loc=loc, ip=ip), + ], "{\n\t" ".reg .s32 x_rounded_i, frac_ex_i, x_rounded_e, out_i;\n\t" "mov.b32 x_rounded_i, $1;\n\t" @@ -657,7 +684,12 @@ def combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, loc=None, ip= @dsl_user_op def ex2_emulation(x: Float32, *, loc=None, ip=None) -> Float32: # We assume x <= 127.0 - poly_ex2_deg3 = (1.0, 0.695146143436431884765625, 0.227564394474029541015625, 0.077119089663028717041015625) + poly_ex2_deg3 = ( + 1.0, + 0.695146143436431884765625, + 0.227564394474029541015625, + 0.077119089663028717041015625, + ) fp32_round_int = float(2**23 + 2**22) x_clamped = cute.arch.fmax(x, -127.0) # We want to round down here, so that the fractional part is in [0, 1) @@ -674,11 +706,18 @@ def ex2_emulation(x: Float32, *, loc=None, ip=None) -> Float32: @dsl_user_op def ex2_emulation_2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]: # We assume x <= 127.0 and y <= 127.0 - poly_ex2_deg3 = (1.0, 0.695146143436431884765625, 0.227564394474029541015625, 0.077119089663028717041015625) + poly_ex2_deg3 = ( + 1.0, + 0.695146143436431884765625, + 0.227564394474029541015625, + 0.077119089663028717041015625, + ) fp32_round_int = float(2**23 + 2**22) xy_clamped = (cute.arch.fmax(x, -127.0), cute.arch.fmax(y, -127.0)) # We want to round down here, so that the fractional part is in [0, 1) - xy_rounded = cute.arch.add_packed_f32x2(xy_clamped, (fp32_round_int, fp32_round_int), rnd=nvvm.RoundingModeKind.RM) + xy_rounded = cute.arch.add_packed_f32x2( + xy_clamped, (fp32_round_int, fp32_round_int), rnd=nvvm.RoundingModeKind.RM + ) # The integer floor of x & y are now in the last 8 bits of xy_rounded # We want the next 2 ops to round to nearest even. The rounding mode is important. xy_rounded_back = sub_packed_f32x2(xy_rounded, (fp32_round_int, fp32_round_int)) @@ -734,8 +773,12 @@ def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Flo out0 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [0], loc=loc, ip=ip)) out1 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [1], loc=loc, ip=ip)) return out0, out1 + + @dsl_user_op -def domain_offset_aligned(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: +def domain_offset_aligned( + coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None +) -> cute.Tensor: assert isinstance(tensor.iterator, cute.Pointer) # We assume that applying the offset does not change the pointer alignment new_ptr = cute.make_ptr( @@ -751,9 +794,9 @@ def domain_offset_aligned(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, i def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord)) flat_stride = cute.flatten_to_tuple(tensor.stride) - assert len(flat_coord_i64) == len( - flat_stride - ), "Coordinate and stride must have the same length" + assert len(flat_coord_i64) == len(flat_stride), ( + "Coordinate and stride must have the same length" + ) offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride)) assert isinstance(tensor.iterator, cute.Pointer) # HACK: we assume that applying the offset does not change the pointer alignment @@ -779,18 +822,20 @@ def coord_offset_i64( tensor.memspace, assumed_align=tensor.iterator.max_alignment, ) - new_layout = cute.slice_(tensor.layout, (*[None] * dim, 0, *[None] * (cute.rank(tensor) - dim - 1))) + new_layout = cute.slice_( + tensor.layout, (*[None] * dim, 0, *[None] * (cute.rank(tensor) - dim - 1)) + ) return cute.make_tensor(new_ptr, new_layout) @cute.jit def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA: - """ Convert a scalar to a cute TensorSSA of shape (1,) and given dtype """ + """Convert a scalar to a cute TensorSSA of shape (1,) and given dtype""" vec = cute.make_fragment(1, dtype) vec[0] = a return vec.load() def ssa_to_scalar(val): - """ Could inline but nice for reflecting the above api """ - return val[0] \ No newline at end of file + """Could inline but nice for reflecting the above api""" + return val[0] From de6a6ad08b3d63a5f1acc7bf4dd7e248018a43d3 Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Tue, 2 Dec 2025 13:58:14 -0800 Subject: [PATCH 250/258] [Flash] Fix head dim 64 bwd (#2035) --- flash_attn/cute/flash_bwd_sm100.py | 65 +++++++++++++++++++++--------- tests/cute/test_flash_attn.py | 2 +- 2 files changed, 46 insertions(+), 21 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 7fc45666638..78506b77dba 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -320,11 +320,11 @@ def _setup_smem_layout(self): ) self.sdKV_epi_tile = ( self.tile_n, - 128 // (self.dk_dtype.width // 8), # 64 or 32 + min(128 // (self.dk_dtype.width // 8), self.tile_hdim // 2), # 64 or 32 ) # subtiles mma_tiler_dsq[:2] = mma_tiler_pdo[:2] - self.num_epi_stages = (self.tile_hdim // 2) // self.sdKV_epi_tile[1] + # headdim_64 gets 1 stage + self.num_epi_stages = max(1, (self.tile_hdim // 2) // self.sdKV_epi_tile[1]) self.sdKV_flat_epi_tile = self.tile_n * (self.tile_hdim // 2) // self.num_epi_stages - # TODO: dK and dV could have different shapes if const_expr(self.qhead_per_kvhead == 1): self.sdKV_layout = sm100_utils_basic.make_smem_layout_epi( @@ -402,7 +402,7 @@ def __call__( else: layout_dKV_transpose = LSE_dPsum_dQaccum_transpose mdK, mdV = [utils.select(t, mode=layout_dKV_transpose) for t in (mdK, mdV)] - dO_transpose = [1, 0, 2, 3] # (s, h, n, b) --> (h, s, n, h) + dO_transpose = [1, 0, 2, 3] # (s, h, n, b) --> (h, s, n, b) mdO = utils.select(mdO, mode=dO_transpose) semaphore_transpose = [2, 3, 1, 0] # (b, n, block, stage) -> (block, stage, n, b) @@ -524,15 +524,15 @@ def __call__( self.cluster_layout_vmnk.shape, ) dO_tma_op = sm100_utils_basic.cluster_shape_to_tma_atom_B( - self.cluster_shape_mnk, self.tiled_mma_dP.thr_id + self.cluster_shape_mnk, self.tiled_mma_dV.thr_id ) tma_atom_dO, tma_tensor_dO = cute.nvgpu.make_tiled_tma_atom_B( # tma_load_op if const_expr(self.cluster_shape_mnk[0] == 1) else tma_load_op_multicast, dO_tma_op, mdO, cute.select(self.sdO_layout, mode=[0, 1, 2]), - self.mma_tiler_vdo, - self.tiled_mma_dP, + self.mma_tiler_pdo, + self.tiled_mma_dV, self.cluster_layout_vmnk.shape, ) @@ -580,6 +580,22 @@ def __call__( grid_dim = TileScheduler.get_grid_shape(tile_sched_params) # cute.printf("grid_dim = {}", grid_dim) + # Compute allocation sizes for shared buffers that are reused + # sQ is reused for sdK, sdO is reused for sdV + sQ_alloc_bytes = max( + cute.size_in_bytes(self.q_dtype, self.sQ_layout), + cute.size_in_bytes(self.dk_dtype, self.sdKV_layout), + ) + sdO_alloc_bytes = max( + cute.size_in_bytes(self.dv_dtype, self.sdKV_layout), + cute.size_in_bytes(self.do_dtype, self.sdO_layout), + ) + # Sanity check that layouts fit in allocation + sdV_bytes = cute.size_in_bytes(self.dv_dtype, self.sdKV_layout) + sdK_bytes = cute.size_in_bytes(self.dk_dtype, self.sdKV_layout) + assert sdV_bytes <= sdO_alloc_bytes, "sdV doesn't fit in sdO storage allocation" + assert sdK_bytes <= sQ_alloc_bytes, "sdK doesn't fit in sQ storage allocation" + @cute.struct class SharedStorage: Q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.Q_stage] @@ -601,8 +617,10 @@ class SharedStorage: tmem_dealloc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1] # Smem tensors + + # sQ is reused for sdK which in the non-MHA case needs float32 sQ: cute.struct.Align[ - cute.struct.MemRange[self.q_dtype, cute.cosize(self.sQ_layout)], + cute.struct.MemRange[cute.Uint8, sQ_alloc_bytes], self.buffer_align_bytes, ] sK: cute.struct.Align[ @@ -613,8 +631,9 @@ class SharedStorage: cute.struct.MemRange[self.v_dtype, cute.cosize(self.sV_layout)], self.buffer_align_bytes, ] + # sdO is reused for sdV which in the non-MHA case needs float32 sdO: cute.struct.Align[ - cute.struct.MemRange[self.do_dtype, cute.cosize(self.sdO_layout)], + cute.struct.MemRange[cute.Uint8, sdO_alloc_bytes], self.buffer_align_bytes, ] sdS: cute.struct.Align[ @@ -879,15 +898,21 @@ def kernel( init_wait=True, ) - sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) - sQt = cute.make_tensor(cute.recast_ptr(sQ.iterator, sQt_layout.inner), sQt_layout.outer) + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner, dtype=self.q_dtype) + sQt = cute.make_tensor( + cute.recast_ptr(sQ.iterator, sQt_layout.inner, dtype=self.q_dtype), sQt_layout.outer + ) sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) sKt = cute.make_tensor(cute.recast_ptr(sK.iterator, sKt_layout.inner), sKt_layout.outer) sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) sdSt = storage.sdS.get_tensor(sdSt_layout.outer, swizzle=sdSt_layout.inner) sdS = cute.make_tensor(cute.recast_ptr(sdSt.iterator, sdS_layout.inner), sdS_layout.outer) - sdO = storage.sdO.get_tensor(sdO_layout.outer, swizzle=sdO_layout.inner) - sdOt = cute.make_tensor(cute.recast_ptr(sdO.iterator, sdOt_layout.inner), sdOt_layout.outer) + sdO = storage.sdO.get_tensor( + sdO_layout.outer, swizzle=sdO_layout.inner, dtype=self.do_dtype + ) + sdOt = cute.make_tensor( + cute.recast_ptr(sdO.iterator, sdOt_layout.inner, dtype=self.do_dtype), sdOt_layout.outer + ) sLSE = storage.sLSE.get_tensor(sLSE_layout) sdPsum = storage.sdPsum.get_tensor(sdPsum_layout) if const_expr(self.qhead_per_kvhead == 1): @@ -900,12 +925,10 @@ def kernel( else: sdV = storage.sdO.get_tensor(sdKV_layout, dtype=self.dv_dtype) sdK = storage.sQ.get_tensor(sdKV_layout, dtype=self.dk_dtype) - assert cute.size_in_bytes(self.do_dtype, sdO_layout) >= cute.size_in_bytes( - self.dv_dtype, sdKV_layout - ), "Not enough space for sdV" - assert cute.size_in_bytes(self.q_dtype, sQ_layout) >= cute.size_in_bytes( - self.dk_dtype, sdKV_layout - ), "Not enough space for sdK" + + # Buffer sizing is guaranteed by max(...) in SharedStorage declarations + # for both sQ (reused as sdK) and sdO (reused as sdV) + sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout) # TMEM @@ -993,6 +1016,7 @@ def kernel( self.load( thr_mma_S, thr_mma_dP, + thr_mma_dV, mQ, mK, mV, @@ -1138,6 +1162,7 @@ def load( self, thr_mma_S: cute.core.ThrMma, thr_mma_dP: cute.core.ThrMma, + thr_mma_dV: cute.core.ThrMma, mQ: cute.Tensor, mK: cute.Tensor, mV: cute.Tensor, @@ -1206,7 +1231,7 @@ def load( gLSE = cute.local_tile(mLSE_cur, (self.tile_n,), (None,)) gdPsum = cute.local_tile(mPsum_cur, (self.tile_n,), (None,)) gdO = cute.local_tile(mdO_cur, cute.select(self.mma_tiler_pdo, mode=[1, 2]), (0, None)) - tdPgdO = thr_mma_dP.partition_B(gdO) + tdPgdO = thr_mma_dV.partition_B(gdO) load_K, _, _ = copy_utils.tma_get_copy_fn( tma_atom_K, 0, cute.make_layout(1), tSgK, sK, single_stage=True diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 4b3398dd479..fc26fb34af8 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -56,7 +56,7 @@ # @pytest.mark.parametrize("d", [64, 96, 128, 192]) # @pytest.mark.parametrize("d", [64, 128]) # @pytest.mark.parametrize("d", [128, 192]) -@pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("d", [64, 128]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ From 26ba559ee1a618724c618198986101ae60258fde Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Tue, 2 Dec 2025 15:18:59 -0800 Subject: [PATCH 251/258] Add headdim64 tests (#2041) --- tests/cute/test_flash_attn_race_condition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cute/test_flash_attn_race_condition.py b/tests/cute/test_flash_attn_race_condition.py index 5cedc49d3c4..101e058d60e 100644 --- a/tests/cute/test_flash_attn_race_condition.py +++ b/tests/cute/test_flash_attn_race_condition.py @@ -57,7 +57,7 @@ # @pytest.mark.parametrize("d", [64, 96, 128, 192]) # @pytest.mark.parametrize("d", [64, 128]) # @pytest.mark.parametrize("d", [128, 192]) -@pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("d", [64, 128]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ From 56fdf3e232731535a4fa420a6cce53f72f3c10ba Mon Sep 17 00:00:00 2001 From: jayhshah Date: Fri, 5 Dec 2025 16:11:10 -0800 Subject: [PATCH 252/258] [Cute,Bwd,Sm100] Add local for sm100 bwd (#2046) * add local for sm100 bwd * add deterministic * update tests * ruff files * remove old code * move comment * override window_size = None for causal * revert to fwd test defaults --- flash_attn/cute/block_info.py | 16 +- flash_attn/cute/flash_bwd_sm100.py | 558 +++++++++++-------- flash_attn/cute/interface.py | 23 + flash_attn/cute/mask.py | 39 +- flash_attn/cute/testing.py | 6 +- tests/cute/test_flash_attn.py | 50 +- tests/cute/test_flash_attn_race_condition.py | 42 +- 7 files changed, 438 insertions(+), 296 deletions(-) diff --git a/flash_attn/cute/block_info.py b/flash_attn/cute/block_info.py index eeaa0e3e740..be13e70f892 100644 --- a/flash_attn/cute/block_info.py +++ b/flash_attn/cute/block_info.py @@ -58,12 +58,16 @@ def get_n_block_min_max( def get_m_block_min_max(self, seqlen_info: SeqlenInfoQK, n_block: Int32) -> Tuple[Int32, Int32]: m_block_max = cute.ceil_div(seqlen_info.seqlen_q, self.tile_m) m_block_min = 0 - if const_expr(self.is_causal): - m_block_min = max( - m_block_min, - (n_block * self.tile_n + seqlen_info.seqlen_q - seqlen_info.seqlen_k) - // self.tile_m, - ) + if const_expr(self.is_causal or (self.is_local and self.window_size_right is not None)): + n_idx_min = n_block * self.tile_n + m_idx = n_idx_min + seqlen_info.seqlen_q - seqlen_info.seqlen_k + m_idx_right = m_idx if const_expr(self.is_causal) else m_idx - self.window_size_right + m_block_min = max(m_block_min, m_idx_right // self.tile_m) + if const_expr(self.is_local and self.window_size_left is not None): + n_idx_max = (n_block + 1) * self.tile_n + m_idx = n_idx_max + seqlen_info.seqlen_q - seqlen_info.seqlen_k + m_idx_left = m_idx + self.window_size_left + m_block_max = min(m_block_max, cute.ceil_div(m_idx_left, self.tile_m)) return m_block_min, m_block_max @cute.jit diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 78506b77dba..00c8cbf66d7 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -82,7 +82,7 @@ def __init__( self.cluster_shape_mn = (cluster_size, 1) self.is_persistent = is_persistent self.is_causal = is_causal - self.is_local = False + self.is_local = is_local self.qhead_per_kvhead = qhead_per_kvhead self.pack_gqa = False self.use_tma_store = True @@ -384,11 +384,19 @@ def __call__( *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1], ) - (mdQaccum,) = [ + ( + mdQaccum, + mdK, + mdV, + ) = [ cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) if t is not None else None - for t in (mdQaccum,) + for t in ( + mdQaccum, + mdK, + mdV, + ) ] layout_transpose = [1, 3, 2, 0] # (b, s, n, h) --> (s, h, n, b) @@ -555,7 +563,8 @@ def __call__( TileScheduler = SingleTileLPTBwdScheduler else: TileScheduler = SingleTileScheduler - self.spt = self.is_causal and self.deterministic + # reads n_blocks right-to-left + self.spt = (self.is_causal or self.is_local) and self.deterministic tile_sched_args = TileSchedulerArguments( cute.ceil_div(cute.size(mK.shape[0]), self.cta_tiler[0]), cute.size(mQ.shape[2]), # num_heads = num_query_heads @@ -657,6 +666,12 @@ class SharedStorage: LOG2_E = math.log2(math.e) softmax_scale_log2 = softmax_scale * LOG2_E + + if const_expr(window_size_left is not None): + window_size_left = Int32(window_size_left) + if const_expr(window_size_right is not None): + window_size_right = Int32(window_size_right) + self.kernel( tma_tensor_Q, tma_tensor_K, @@ -701,6 +716,8 @@ class SharedStorage: tiled_copy_r2s_dKV, softmax_scale, softmax_scale_log2, + window_size_left, + window_size_right, tile_sched_params, ).launch( grid=grid_dim, @@ -757,6 +774,8 @@ def kernel( tiled_copy_r2s_dKV: cute.TiledCopy, softmax_scale: cutlass.Float32, softmax_scale_log2: cutlass.Float32, + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], tile_sched_params: ParamsBase, ): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) @@ -975,8 +994,8 @@ def kernel( self.is_causal, self.is_local, False, # is_split_kv - None, - None, + window_size_left, + window_size_right, qhead_per_kvhead_packgqa=1, ) SeqlenInfoCls = partial( @@ -990,12 +1009,13 @@ def kernel( ) TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params) - # TODO: support local AttentionMaskCls = partial( AttentionMask, self.tile_m, self.tile_n, swap_AB=True, + window_size_left=window_size_left, + window_size_right=window_size_right, ) # EMPTY @@ -1228,8 +1248,8 @@ def load( tdPgV = thr_mma_dP.partition_A(gV) gQ = cute.local_tile(mQ_cur, cute.select(self.mma_tiler_kq, mode=[1, 2]), (None, 0)) tSgQ = thr_mma_S.partition_B(gQ) - gLSE = cute.local_tile(mLSE_cur, (self.tile_n,), (None,)) - gdPsum = cute.local_tile(mPsum_cur, (self.tile_n,), (None,)) + gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (None,)) + gdPsum = cute.local_tile(mPsum_cur, (self.tile_m,), (None,)) gdO = cute.local_tile(mdO_cur, cute.select(self.mma_tiler_pdo, mode=[1, 2]), (0, None)) tdPgdO = thr_mma_dV.partition_B(gdO) @@ -1272,80 +1292,83 @@ def load( # gdPsum = cute.logical_divide(gdPsum, (64,))[(None, block_in_cluster_coord_vmnk[1]), None] # copy_stats = partial(cute.copy, copy_atom_stats, mcast_mask=q_do_mcast_mask) - # First iteration: load K together w Q & LSE, then V together w dO & dPsum - if const_expr(should_load_Q): - # K & Q - pipeline_Q.producer_acquire( - producer_state_Q_LSE, extra_tx_count=self.tma_copy_bytes["K"] - ) - load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_LSE)) - load_Q(m_block_min, producer_state=producer_state_Q_LSE) - pipeline_Q.producer_commit(producer_state_Q_LSE) - # LSE - pipeline_LSE.producer_acquire(producer_state_Q_LSE) - with cute.arch.elect_one(): - copy_stats( - gLSE[None, m_block_min], - sLSE[None, producer_state_Q_LSE.index], - mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE), - ) - producer_state_Q_LSE.advance() - if const_expr(should_load_dO): - # V & dO - pipeline_dO.producer_acquire( - producer_state_dO_dPsum, extra_tx_count=self.tma_copy_bytes["V"] - ) - load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_dPsum)) - load_dO(m_block_min, producer_state=producer_state_dO_dPsum) - pipeline_dO.producer_commit(producer_state_dO_dPsum) - # dPsum - pipeline_dPsum.producer_acquire(producer_state_dO_dPsum) - with cute.arch.elect_one(): - copy_stats( - gdPsum[None, m_block_min], - sdPsum[None, producer_state_dO_dPsum.index], - mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum), - ) - producer_state_dO_dPsum.advance() - - for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1): + if const_expr(not self.is_local) or m_block_min < m_block_max: + # First iteration: load K together w Q & LSE, then V together w dO & dPsum if const_expr(should_load_Q): - # Q - pipeline_Q.producer_acquire(producer_state_Q_LSE) - load_Q(m_block, producer_state=producer_state_Q_LSE) + # K & Q + pipeline_Q.producer_acquire( + producer_state_Q_LSE, extra_tx_count=self.tma_copy_bytes["K"] + ) + load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_LSE)) + load_Q(m_block_min, producer_state=producer_state_Q_LSE) pipeline_Q.producer_commit(producer_state_Q_LSE) # LSE pipeline_LSE.producer_acquire(producer_state_Q_LSE) with cute.arch.elect_one(): copy_stats( - gLSE[None, m_block], + gLSE[None, m_block_min], sLSE[None, producer_state_Q_LSE.index], mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE), ) producer_state_Q_LSE.advance() if const_expr(should_load_dO): - # dO - pipeline_dO.producer_acquire(producer_state_dO_dPsum) - load_dO(m_block, producer_state=producer_state_dO_dPsum) + # V & dO + pipeline_dO.producer_acquire( + producer_state_dO_dPsum, extra_tx_count=self.tma_copy_bytes["V"] + ) + load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_dPsum)) + load_dO(m_block_min, producer_state=producer_state_dO_dPsum) pipeline_dO.producer_commit(producer_state_dO_dPsum) # dPsum pipeline_dPsum.producer_acquire(producer_state_dO_dPsum) with cute.arch.elect_one(): copy_stats( - gdPsum[None, m_block], + gdPsum[None, m_block_min], sdPsum[None, producer_state_dO_dPsum.index], mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum), ) producer_state_dO_dPsum.advance() - if const_expr(should_load_Q): - pipeline_Q.producer_tail( - producer_state_Q_LSE.clone() - ) # will hang if we don't clone - pipeline_LSE.producer_tail(producer_state_Q_LSE) - if const_expr(should_load_dO): - pipeline_dO.producer_tail(producer_state_dO_dPsum.clone()) - pipeline_dPsum.producer_tail(producer_state_dO_dPsum) + for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1): + if const_expr(should_load_Q): + # Q + pipeline_Q.producer_acquire(producer_state_Q_LSE) + load_Q(m_block, producer_state=producer_state_Q_LSE) + pipeline_Q.producer_commit(producer_state_Q_LSE) + # LSE + pipeline_LSE.producer_acquire(producer_state_Q_LSE) + with cute.arch.elect_one(): + copy_stats( + gLSE[None, m_block], + sLSE[None, producer_state_Q_LSE.index], + mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE), + ) + producer_state_Q_LSE.advance() + if const_expr(should_load_dO): + # dO + pipeline_dO.producer_acquire(producer_state_dO_dPsum) + load_dO(m_block, producer_state=producer_state_dO_dPsum) + pipeline_dO.producer_commit(producer_state_dO_dPsum) + # dPsum + pipeline_dPsum.producer_acquire(producer_state_dO_dPsum) + with cute.arch.elect_one(): + copy_stats( + gdPsum[None, m_block], + sdPsum[None, producer_state_dO_dPsum.index], + mbar_ptr=pipeline_dPsum.producer_get_barrier( + producer_state_dO_dPsum + ), + ) + producer_state_dO_dPsum.advance() + + if const_expr(should_load_Q): + pipeline_Q.producer_tail( + producer_state_Q_LSE.clone() + ) # will hang if we don't clone + pipeline_LSE.producer_tail(producer_state_Q_LSE) + if const_expr(should_load_dO): + pipeline_dO.producer_tail(producer_state_dO_dPsum.clone()) + pipeline_dPsum.producer_tail(producer_state_dO_dPsum) tile_scheduler.prefetch_next_work() tile_scheduler.advance_to_next_work() @@ -1474,130 +1497,129 @@ def mma( m_block_min, m_block_max = block_info.get_m_block_min_max( seqlen, n_block // self.cluster_shape_mnk[0] ) - - accumulate_dK = False - # ----------------------------------------------------------- - ###### Prologue - # ----------------------------------------------------------- - # 1. S = Q0 @ K.T - # 2. dP = V @ dO.T - # 3. dV = P @ dO - - # 1) S = Q0 @ K.T - handle_Q = pipeline_Q_consumer.wait_and_advance() - pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc) - mma_qk_fn(B_idx=handle_Q.index) - # Don't release Q yet - pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) - - # 2) dP = V @ dO.T - pipeline_dO.consumer_wait(consumer_state_dO) - pipeline_dP.sync_object_empty.wait(0, producer_phase_acc) - # dQ uses the same tmem as dP - pipeline_dQ.sync_object_empty.wait(0, producer_phase_acc) - mma_dov_fn(B_idx=consumer_state_dO.index) - # Don't release dO yet - pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group) - - producer_phase_acc ^= 1 - # 3) dV = P.T @ dO - # wait for P to be ready, which uses the same tmem as S - pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc) - mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=True) - pipeline_dO.consumer_release(consumer_state_dO) - consumer_state_dO.advance() - # ----------------------------------------------------------- - ###### MAIN LOOP - # ----------------------------------------------------------- - # 1. S = K @ Q.T - # 2. dQ = dS @ K - # 3. dK = dS.T @ Q - # 4. dP = V @ dO.T - # 5. dV = P.T @ dO - - for _ in cutlass.range(m_block_min + 1, m_block_max, unroll=1): - # 1) S = K @ Q_i - handle_Q_next = pipeline_Q_consumer.wait_and_advance() - # Don't need to wait for S, as P must have been ready ealier, i.e., S is ready - mma_qk_fn(B_idx=handle_Q_next.index) + if const_expr(not self.is_local) or m_block_min < m_block_max: + accumulate_dK = False + # ----------------------------------------------------------- + ###### Prologue + # ----------------------------------------------------------- + # 1. S = Q0 @ K.T + # 2. dP = V @ dO.T + # 3. dV = P @ dO + # 1) S = Q0 @ K.T + handle_Q = pipeline_Q_consumer.wait_and_advance() + pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc) + mma_qk_fn(B_idx=handle_Q.index) + # Don't release Q yet pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) - # 2-3) - # Do dK = dS.T @ Q, then dQ = dS @ K if dS in tmem for first mma - # Otherwise, reverse order - pipeline_dS.consumer_wait(consumer_state_dS) - - if const_expr(self.use_smem_dS_for_mma_dK): - mma_dsk_fn() - pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group) - mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK) - accumulate_dK = True - handle_Q.release() - else: - mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK) - accumulate_dK = True - handle_Q.release() - mma_dsk_fn() - pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group) - - # dP uses the same tmem as dQ - # However, if dS is ready, then dP must have been ready, - # so we don't need this wait before mma_dsk_fn() - # pipeline_dP.sync_object_empty.wait(0, producer_phase_acc) - - pipeline_dS.consumer_release(consumer_state_dS) - consumer_state_dS.advance() - - # 4) dP = V @ dO.T + # 2) dP = V @ dO.T pipeline_dO.consumer_wait(consumer_state_dO) + pipeline_dP.sync_object_empty.wait(0, producer_phase_acc) # dQ uses the same tmem as dP pipeline_dQ.sync_object_empty.wait(0, producer_phase_acc) mma_dov_fn(B_idx=consumer_state_dO.index) + # Don't release dO yet pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group) producer_phase_acc ^= 1 - # 5) dV += P @ dO + # 3) dV = P.T @ dO # wait for P to be ready, which uses the same tmem as S pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc) - mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=False) + mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=True) pipeline_dO.consumer_release(consumer_state_dO) consumer_state_dO.advance() + # ----------------------------------------------------------- + ###### MAIN LOOP + # ----------------------------------------------------------- + # 1. S = K @ Q.T + # 2. dQ = dS @ K + # 3. dK = dS.T @ Q + # 4. dP = V @ dO.T + # 5. dV = P.T @ dO + + for _ in cutlass.range(m_block_min + 1, m_block_max, unroll=1): + # 1) S = K @ Q_i + handle_Q_next = pipeline_Q_consumer.wait_and_advance() + # Don't need to wait for S, as P must have been ready ealier, i.e., S is ready + mma_qk_fn(B_idx=handle_Q_next.index) + pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) + + # 2-3) + # Do dK = dS.T @ Q, then dQ = dS @ K if dS in tmem for first mma + # Otherwise, reverse order + pipeline_dS.consumer_wait(consumer_state_dS) + + if const_expr(self.use_smem_dS_for_mma_dK): + mma_dsk_fn() + pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group) + mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK) + accumulate_dK = True + handle_Q.release() + else: + mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK) + accumulate_dK = True + handle_Q.release() + mma_dsk_fn() + pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group) + + # dP uses the same tmem as dQ + # However, if dS is ready, then dP must have been ready, + # so we don't need this wait before mma_dsk_fn() + # pipeline_dP.sync_object_empty.wait(0, producer_phase_acc) + + pipeline_dS.consumer_release(consumer_state_dS) + consumer_state_dS.advance() + + # 4) dP = V @ dO.T + pipeline_dO.consumer_wait(consumer_state_dO) + # dQ uses the same tmem as dP + pipeline_dQ.sync_object_empty.wait(0, producer_phase_acc) + mma_dov_fn(B_idx=consumer_state_dO.index) + pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group) + + producer_phase_acc ^= 1 + # 5) dV += P @ dO + # wait for P to be ready, which uses the same tmem as S + pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc) + mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=False) + pipeline_dO.consumer_release(consumer_state_dO) + consumer_state_dO.advance() + + handle_Q = handle_Q_next - handle_Q = handle_Q_next - - pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) - - # signal to the epilogue that dV is ready - # pipeline_dKV.producer_acquire(producer_state_dKV) - pipeline_dKV.sync_object_empty.wait(0, producer_phase_dKV) - # pipeline_dKV.producer_commit(producer_state_dKV) - pipeline_dKV.sync_object_full.arrive(0, pipeline_dKV.producer_mask, cta_group) - # producer_state_dKV.advance() - # pipeline_dKV.producer_acquire(producer_state_dKV) - pipeline_dKV.sync_object_empty.wait(1, producer_phase_dKV) - - # ----------------------------------------------------------- - ###### Remaining 2 - # ----------------------------------------------------------- - # 1) dK += dS.T @ Q - pipeline_dS.consumer_wait(consumer_state_dS) - mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK) - # signal to the epilogue that dK is ready - # pipeline_dKV.producer_commit(producer_state_dKV) - pipeline_dKV.sync_object_full.arrive(1, pipeline_dKV.producer_mask, cta_group) - # producer_state_dKV.advance() - producer_phase_dKV ^= 1 - - # 2) dQ = dS @ K - # dS is done, so dP must have been ready, we don't need to wait - mma_dsk_fn() - pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group) - # Wait until dQ is done before releasing Q, since K and Q0 uses the same mbarrier - handle_Q.release() - pipeline_dS.consumer_release(consumer_state_dS) - consumer_state_dS.advance() - - producer_phase_acc ^= 1 + pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) + + # signal to the epilogue that dV is ready + # pipeline_dKV.producer_acquire(producer_state_dKV) + pipeline_dKV.sync_object_empty.wait(0, producer_phase_dKV) + # pipeline_dKV.producer_commit(producer_state_dKV) + pipeline_dKV.sync_object_full.arrive(0, pipeline_dKV.producer_mask, cta_group) + # producer_state_dKV.advance() + # pipeline_dKV.producer_acquire(producer_state_dKV) + pipeline_dKV.sync_object_empty.wait(1, producer_phase_dKV) + + # ----------------------------------------------------------- + ###### Remaining 2 + # ----------------------------------------------------------- + # 1) dK += dS.T @ Q + pipeline_dS.consumer_wait(consumer_state_dS) + mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK) + # signal to the epilogue that dK is ready + # pipeline_dKV.producer_commit(producer_state_dKV) + pipeline_dKV.sync_object_full.arrive(1, pipeline_dKV.producer_mask, cta_group) + # producer_state_dKV.advance() + producer_phase_dKV ^= 1 + + # 2) dQ = dS @ K + # dS is done, so dP must have been ready, we don't need to wait + mma_dsk_fn() + pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group) + # Wait until dQ is done before releasing Q, since K and Q0 uses the same mbarrier + handle_Q.release() + pipeline_dS.consumer_release(consumer_state_dS) + consumer_state_dS.advance() + + producer_phase_acc ^= 1 tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() @@ -1717,7 +1739,7 @@ def compute_loop( # 0: [256...384] # 1: [128...256] - tileP_f32_like = self.mma_tiler_kq[0] // 32 * self.v_dtype.width # (128, 64) + tileP_f32_like = self.mma_tiler_kq[0] // 32 * self.v_dtype.width # 64 for tile_n = 128 # tStS has shape ((128, 128), 1, 1), tStP has shape ((128, 64), 1, 1) # tP overlap with tS tStP = cute.composition(tStS, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1)) @@ -1943,61 +1965,96 @@ def compute_loop( pipeline_dS.producer_commit(producer_state_dS) producer_state_dS.advance() - if const_expr(not self.use_tma_store): - consumer_state_dKV = self.epilogue_dKV( - dp_idx, - warp_idx, - batch_idx, - head_idx, - n_block, - thr_mma_dV, - thr_mma_dK, - tdVtdV, - tdKtdK, - mdV, - mdK, - pipeline_dKV, - consumer_state_dKV, - softmax_scale, - ) - else: - thr_copy_r2s_dKV = tiled_copy_r2s_dKV.get_slice(dp_idx) - #### STORE dV - consumer_state_dKV = self.epilogue_dK_or_dV_tma( - dp_idx, - batch_idx, - head_idx, - n_block, - thr_mma_dV, - tdVtdV, - mdV_tma_tensor, - sdV, - tma_atom_dV, - thr_copy_r2s_dKV, - pipeline_dKV, - consumer_state_dKV, - None, # Don't scale - int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id - mdV_semaphore, - ) - #### STORE dK - consumer_state_dKV = self.epilogue_dK_or_dV_tma( - dp_idx, - batch_idx, - head_idx, - n_block, - thr_mma_dK, - tdKtdK, - mdK_tma_tensor, - sdK, - tma_atom_dK, - thr_copy_r2s_dKV, - pipeline_dKV, - consumer_state_dKV, - softmax_scale if const_expr(self.qhead_per_kvhead == 1) else None, - int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id - mdK_semaphore, - ) + # Epilogue + if const_expr(not self.is_local) or m_block_min < m_block_max: + if const_expr(not self.use_tma_store): + consumer_state_dKV = self.epilogue_dKV( + dp_idx, + warp_idx, + batch_idx, + head_idx, + n_block, + thr_mma_dV, + thr_mma_dK, + tdVtdV, + tdKtdK, + mdV, + mdK, + pipeline_dKV, + consumer_state_dKV, + softmax_scale, + ) + else: + thr_copy_r2s_dKV = tiled_copy_r2s_dKV.get_slice(dp_idx) + #### STORE dV + consumer_state_dKV = self.epilogue_dK_or_dV_tma( + dp_idx, + batch_idx, + head_idx, + n_block, + thr_mma_dV, + tdVtdV, + mdV_tma_tensor, + sdV, + tma_atom_dV, + thr_copy_r2s_dKV, + pipeline_dKV, + consumer_state_dKV, + None, # Don't scale + int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id + mdV_semaphore, + ) + #### STORE dK + consumer_state_dKV = self.epilogue_dK_or_dV_tma( + dp_idx, + batch_idx, + head_idx, + n_block, + thr_mma_dK, + tdKtdK, + mdK_tma_tensor, + sdK, + tma_atom_dK, + thr_copy_r2s_dKV, + pipeline_dKV, + consumer_state_dKV, + softmax_scale if const_expr(self.qhead_per_kvhead == 1) else None, + int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id + mdK_semaphore, + ) + if const_expr(self.qhead_per_kvhead == 1 and self.is_local): + if m_block_min >= m_block_max: + # if tidx == 0: + # cute.printf("m_block_min = {}, m_block_max = {}", m_block_min, m_block_max) + # like other epis, currently assumes hdim == hdimv + gmem_tiled_copy_zero_dKV = copy_utils.tiled_copy_2d( + self.dk_dtype, + self.tile_hdim, + 128, # num_threads + ) + gmem_thr_copy_zero_dKV = gmem_tiled_copy_zero_dKV.get_slice(dp_idx) + mdV_cur = mdV[None, None, head_idx, batch_idx] + mdK_cur = mdK[None, None, head_idx, batch_idx] + gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdim), (n_block, 0)) + gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0)) + tdKgdK = gmem_thr_copy_zero_dKV.partition_D(gdK) + tdVgdV = gmem_thr_copy_zero_dKV.partition_D(gdV) + assert tdKgdK.shape[2] == 1 + assert tdVgdV.shape[2] == 1 + cdKV = cute.make_identity_tensor((self.tile_n, self.tile_hdim)) + tdKVcdKV = gmem_thr_copy_zero_dKV.partition_D(cdKV) + zero = cute.make_fragment_like(tdKgdK[None, 0, 0]) + zero.fill(0.0) + if tidx < 128: + for i in cutlass.range_constexpr(tdKgdK.shape[1]): + row_idx = tdKVcdKV[0, i, 0][0] + if row_idx < seqlen.seqlen_k - self.tile_n * n_block: + cute.copy(gmem_tiled_copy_zero_dKV, zero, tdKgdK[None, i, 0]) + else: + for i in cutlass.range_constexpr(tdVgdV.shape[1]): + row_idx = tdKVcdKV[0, i, 0][0] + if row_idx < seqlen.seqlen_k - self.tile_n * n_block: + cute.copy(gmem_tiled_copy_zero_dKV, zero, tdVgdV[None, i, 0]) tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() @@ -2092,13 +2149,20 @@ def dQacc_reduce( # semaphore acquire if const_expr(self.deterministic and stage == 0): if const_expr(self.spt): - n_block_max_for_m_block = min( - n_block_global_max, - cute.ceil_div( - (m_block + 1) * self.tile_m + seqlen.seqlen_k - seqlen.seqlen_q, - self.tile_n, - ), - ) + if const_expr( + self.is_causal or block_info.window_size_right is not None + ): + n_idx_right = ( + (m_block + 1) * self.tile_m + seqlen.seqlen_k - seqlen.seqlen_q + ) + if const_expr(block_info.window_size_right is not None): + n_idx_right += block_info.window_size_right + n_block_max_for_m_block = min( + n_block_global_max, + cute.ceil_div(n_idx_right, self.tile_n), + ) + else: + n_block_max_for_m_block = n_block_global_max lock_value = n_block_max_for_m_block - 1 - n_block else: lock_value = n_block @@ -2144,12 +2208,22 @@ def dQacc_reduce( self.reduce_sync_barrier.arrive_and_wait() barrier.arrive_inc(mdQ_semaphore_cur[m_block, None].iterator, tidx, 0, 1) - if is_tma_warp: - cute.arch.cp_async_bulk_wait_group(0, read=read_flag) - self.reduce_sync_barrier.arrive_and_wait() - # final semaphore release - if const_expr(self.deterministic and delay_semaphore_release): - barrier.arrive_inc(mdQ_semaphore_cur[(m_block_max - 1, None)].iterator, tidx, 0, 1) + if const_expr(not self.is_local) or m_block_min < m_block_max: + if is_tma_warp: + cute.arch.cp_async_bulk_wait_group(0, read=read_flag) + self.reduce_sync_barrier.arrive_and_wait() + # final semaphore release + if const_expr(self.deterministic and delay_semaphore_release): + barrier.arrive_inc( + mdQ_semaphore_cur[(m_block_max - 1, None)].iterator, tidx, 0, 1 + ) + + if const_expr( + self.deterministic and not self.spt and block_info.window_size_left is not None + ): + m_block_global_max = cute.ceil_div(seqlen.seqlen_q, self.tile_m) + for m_block in cutlass.range(m_block_max, m_block_global_max, unroll=1): + barrier.arrive_inc(mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, 1) tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() @@ -2222,7 +2296,7 @@ def epilogue_dKV( dV_vec = tdVrdV_t2r[(None, i, 0, 0)].load() tdVrdV_r2s[(None, i, 0, 0)].store(dV_vec.to(self.dv_dtype)) - gdV = cute.local_tile(mdV_cur, (self.tile_m, self.tile_hdimv), (None, 0)) + gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (None, 0)) gdV_tile = gdV[None, None, n_block] tdVgdV = thr_mma_dV.partition_C(gdV_tile) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 4c3e52f46d5..651e9393135 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -295,6 +295,7 @@ def _flash_attn_fwd( if window_size_left is not None or window_size_right is not None: if window_size_left is None and window_size_right == 0: causal, local = True, False + window_size_right = None else: causal, local = False, True else: @@ -540,6 +541,8 @@ def _flash_attn_bwd( softmax_scale: Optional[float] = None, causal: bool = False, softcap: float = 0.0, + window_size_left: Optional[int] = None, + window_size_right: Optional[int] = None, m_block_size: int = 64, n_block_size: int = 128, num_threads: int = 256, @@ -575,6 +578,7 @@ def _flash_attn_bwd( AtomLayoutNdKV = 2 AtomLayoutMdQ = 1 cluster_size = 1 + assert window_size_left is None and window_size_right is None, "local not supported yet on 9.x" else: m_block_size = 128 n_block_size = 128 @@ -608,6 +612,16 @@ def _flash_attn_bwd( num_head_kv = k.shape[-2] head_dim_v = v.shape[-1] + if causal: + window_size_right = 0 + local = window_size_left is not None or window_size_right is not None + if local: + if window_size_left is None and window_size_right == 0: + causal, local = True, False + window_size_right = None + else: + causal, local = False, True + if cu_seqlens_k is None: assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim) assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v) @@ -840,6 +854,8 @@ def _flash_attn_bwd( head_dim_v, qhead_per_kvhead, causal, + window_size_left is not None, + window_size_right is not None, softcap != 0.0, m_block_size, n_block_size, @@ -896,6 +912,7 @@ def _flash_attn_bwd( head_dim, head_dim_v, is_causal=causal, + is_local=local, qhead_per_kvhead=qhead_per_kvhead, # tile_m=m_block_size, # tile_n=n_block_size, @@ -921,6 +938,8 @@ def _flash_attn_bwd( cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, + window_size_left=window_size_left, + window_size_right=window_size_right, mdQ_semaphore=dQ_semaphore_tensor, mdK_semaphore=dK_semaphore_tensor, mdV_semaphore=dV_semaphore_tensor, @@ -941,6 +960,8 @@ def _flash_attn_bwd( cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, + window_size_left=window_size_left, + window_size_right=window_size_right, mdQ_semaphore=dQ_semaphore_tensor, mdK_semaphore=dK_semaphore_tensor, mdV_semaphore=dV_semaphore_tensor, @@ -1103,6 +1124,8 @@ def backward(ctx, dout, *args): ctx.softmax_scale, ctx.causal, ctx.softcap, + window_size_left=ctx.window_size[0], + window_size_right=ctx.window_size[1], deterministic=ctx.deterministic, ) return dq, dk, dv, *((None,) * 20) # Extra Nones is fine diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index da3ed8fb2d3..430c7d26fc5 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -239,10 +239,10 @@ def apply_mask( ) if const_expr(self.window_size_right is not None): col_limit_right = row_idx + local_row_offset_right - if const_expr(mask_seqlen): - col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) else: col_limit_right = self.tile_n + if const_expr(mask_seqlen): + col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) col_limit_left = ( row_idx + local_row_offset_left if const_expr(self.window_size_left is not None) @@ -411,10 +411,10 @@ def apply_mask_sm100( ) if const_expr(self.window_size_right is not None): col_limit_right = row_idx + local_row_offset_right - if const_expr(mask_seqlen): - col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) else: col_limit_right = self.tile_n + if const_expr(mask_seqlen): + col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) col_limit_left = ( row_idx + local_row_offset_left if const_expr(self.window_size_left is not None) @@ -447,28 +447,27 @@ def apply_mask_sm100_transposed( assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" ROW = 0 if const_expr(not self.swap_AB) else 1 COL = 1 if const_expr(not self.swap_AB) else 0 + assert t0ScS_t2r[0][COL] == 0, "col0 == 0" thr_col_offset = tScS_t2r[0][COL] seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset if const_expr(not mask_causal and not mask_local): if const_expr(mask_seqlen): - if t0ScS_t2r[0][COL] >= seqlenk_col_limit: + if seqlenk_col_limit <= 0: for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): acc_S[i] = -cutlass.Float32.inf else: # Causal or local thr_row_offset = tScS_t2r[0][ROW] - causal_row_offset = ( - seqlenk_col_limit - self.seqlen_q + m_block * self.tile_m + thr_row_offset - ) + seqlenq_row_limit = self.seqlen_q - m_block * self.tile_m - thr_row_offset + causal_offset = seqlenq_row_limit - seqlenk_col_limit if const_expr(mask_causal): - col0 = t0ScS_t2r[0][COL] - row_limit_top = col0 - causal_row_offset # tidx = cute.arch.thread_idx()[0] % 256 # if tidx < 32: - # cute.printf("tidx = {}, {} {}, {} {}, col0 = {}", tidx, tScS_t2r[0][0], tScS_t2r[0][1], tScS_t2r[1][0], tScS_t2r[1][1], col0) + # cute.printf("tidx = {}, {} {}, {} {}", tidx, tScS_t2r[0][0], tScS_t2r[0][1], tScS_t2r[1][0], tScS_t2r[1][1]) + row_limit_top = causal_offset if const_expr(mask_seqlen): # If col is beyond the column limit, we want to mask out the entire # column, by setting row limit to be self.tile_m. - if t0ScS_t2r[0][COL] >= seqlenk_col_limit: + if seqlenk_col_limit <= 0: row_limit_top = self.tile_m r2p = True if const_expr(not r2p): @@ -480,4 +479,18 @@ def apply_mask_sm100_transposed( num_rep = cute.size(tScS_t2r, mode=[0]) # 16 or 32 mask_r2p_transposed(acc_S, row_limit_top, num_rep) else: - assert False, "Local masking isn't supported yet" + if const_expr(self.window_size_right is not None): + row_limit_top = causal_offset - self.window_size_right + else: + row_limit_top = 0 + if const_expr(self.window_size_left is not None): + row_limit_bot = causal_offset + self.window_size_left + if const_expr(mask_seqlen): + if seqlenk_col_limit <= 0: + row_limit_top = self.tile_m + for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): + row_idx = t0ScS_t2r[i][ROW] + local_mask = row_idx < row_limit_top + if const_expr(self.window_size_left is not None): + local_mask |= row_idx > row_limit_bot + acc_S[i] = -cutlass.Float32.inf if local_mask else acc_S[i] diff --git a/flash_attn/cute/testing.py b/flash_attn/cute/testing.py index 214ed09bc9e..a23a624d059 100644 --- a/flash_attn/cute/testing.py +++ b/flash_attn/cute/testing.py @@ -260,8 +260,12 @@ def construct_local_mask( return col_idx > row_idx + sk - sq + window_size[1] else: sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + if window_size[1] is None: + local_mask_left = col_idx > sk + else: + local_mask_left = col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk) return torch.logical_or( - col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + local_mask_left, torch.logical_and( col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length ), diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index fc26fb34af8..fe1d18afb6d 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -29,7 +29,8 @@ DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE" - +TEST_BWD_ONLY = False +VERBOSE = True # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @@ -43,8 +44,8 @@ @pytest.mark.parametrize("deterministic", [False]) # @pytest.mark.parametrize("softcap", [0.0, 15.0]) @pytest.mark.parametrize("softcap", [0.0]) -@pytest.mark.parametrize("local", [False, True]) -# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) +# @pytest.mark.parametrize("local_enum", [0]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [True]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) @@ -92,7 +93,7 @@ def test_flash_attn_output( seqlen_k, d, causal, - local, + local_enum, softcap, deterministic, has_qv, @@ -100,8 +101,9 @@ def test_flash_attn_output( mha_type, dtype, ): - # if (causal or local) and seqlen_k < seqlen_q: - # pytest.skip("Causal attention requires seqlen_k >= seqlen_q") + local = local_enum > 0 + if local and causal: + pytest.skip() device = "cuda" # set seed torch.random.manual_seed(0) @@ -115,7 +117,7 @@ def test_flash_attn_output( dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d]) - if dtype == torch.float8_e4m3fn: + if dtype == torch.float8_e4m3fn or TEST_BWD_ONLY: dv_vals = [d] # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] attention_chunk_vals = [0] @@ -157,6 +159,12 @@ def test_flash_attn_output( window_size = ( (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() ) + if local_enum == 2: + window_size = (None, -window_size[1]) + elif local_enum == 3: + window_size = (-window_size[0], None) + if local: + print("window size = ", window_size) # window_size = (-1, -1) if not local else (16, 0) if has_learnable_sink: learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) @@ -228,7 +236,7 @@ def test_flash_attn_output( # pack_gqa_vals = [False, True, None] # SplitKV is not supported for hdim >= 192 pack_gqa_vals = [False] - num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT else [1] + num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY else [1] for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): out, lse = flash_attn_func( q, @@ -241,8 +249,9 @@ def test_flash_attn_output( # attention_chunk=attention_chunk, softcap=softcap, learnable_sink=learnable_sink, - # pack_gqa=pack_gqa, + pack_gqa=pack_gqa, num_splits=num_splits, + deterministic=deterministic, ) print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") @@ -262,12 +271,9 @@ def test_flash_attn_output( and not dv > 256 and not attention_chunk != 0 and softcap == 0.0 - and not local and dv == d and learnable_sink is None - # and mha_type == "mha" # and False - and not ((causal or local) and seqlen_k < seqlen_q) ): g = torch.randn_like(out) # do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) @@ -301,6 +307,26 @@ def test_flash_attn_output( print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + + if VERBOSE: + diff_dq = (dq - dq_ref).abs() + max_idx = diff_dq.argmax() + coords = torch.unravel_index(max_idx, diff_dq.shape) + print(f"dQ max diff: {diff_dq.max().item()}") + print(f" at coordinates {tuple(c.item() for c in coords)}: dQ={dq[coords].item()}, dQ_ref={dq_ref[coords].item()}") + + diff_dk = (dk - dk_ref).abs() + max_idx = diff_dk.argmax() + coords = torch.unravel_index(max_idx, diff_dk.shape) + print(f"dK max diff: {diff_dk.max().item()}") + print(f" at coordinates {tuple(c.item() for c in coords)}: dK={dk[coords].item()}, dK_ref={dk_ref[coords].item()}") + + diff_dv = (dv - dv_ref).abs() + max_idx = diff_dv.argmax() + coords = torch.unravel_index(max_idx, diff_dv.shape) + print(f"dV max diff: {diff_dv.max().item()}") + print(f" at coordinates {tuple(c.item() for c in coords)}: dV={dv[coords].item()}, dV_ref={dv_ref[coords].item()}") + # breakpoint() dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + ( 0 if softcap == 0 else 3e-4 diff --git a/tests/cute/test_flash_attn_race_condition.py b/tests/cute/test_flash_attn_race_condition.py index 101e058d60e..520cf6466a7 100644 --- a/tests/cute/test_flash_attn_race_condition.py +++ b/tests/cute/test_flash_attn_race_condition.py @@ -44,25 +44,17 @@ @pytest.mark.parametrize("deterministic", [True]) # @pytest.mark.parametrize("softcap", [0.0, 15.0]) @pytest.mark.parametrize("softcap", [0.0]) -# @pytest.mark.parametrize("local", [False, True]) -@pytest.mark.parametrize("local", [False]) -@pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize("causal", [False]) -# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) -# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) -# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) -# @pytest.mark.parametrize('d', [56, 80]) -# @pytest.mark.parametrize("d", [64, 128, 256]) -# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) -# @pytest.mark.parametrize("d", [64, 96, 128, 192]) -# @pytest.mark.parametrize("d", [64, 128]) -# @pytest.mark.parametrize("d", [128, 192]) +@pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) +# @pytest.mark.parametrize("local_enum", [0]) +# @pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("causal", [False]) @pytest.mark.parametrize("d", [64, 128]) +# @pytest.mark.parametrize("d", [128]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ (4224, 4224), - (2048, 4096), + (2000, 4000), ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) @@ -71,7 +63,7 @@ def test_flash_attn_output( seqlen_k, d, causal, - local, + local_enum, softcap, deterministic, has_qv, @@ -79,8 +71,9 @@ def test_flash_attn_output( mha_type, dtype, ): - if (causal or local) and seqlen_k < seqlen_q: - pytest.skip("Causal attention requires seqlen_k >= seqlen_q") + local = local_enum > 0 + if local and causal: + pytest.skip() device = "cuda" # set seed torch.random.manual_seed(0) @@ -137,6 +130,12 @@ def test_flash_attn_output( window_size = ( (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() ) + if local_enum == 2: + window_size = (None, -window_size[1]) + elif local_enum == 3: + window_size = (-window_size[0], None) + if local: + print("window size = ", window_size) # window_size = (-1, -1) if not local else (16, 0) if has_learnable_sink: learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) @@ -222,7 +221,7 @@ def test_flash_attn_output( # attention_chunk=attention_chunk, softcap=softcap, learnable_sink=learnable_sink, - # pack_gqa=pack_gqa, + pack_gqa=pack_gqa, num_splits=num_splits, deterministic=deterministic, ) @@ -244,12 +243,9 @@ def test_flash_attn_output( and not dv > 256 and not attention_chunk != 0 and softcap == 0.0 - and not local and dv == d and learnable_sink is None - # and mha_type == "mha" # and False - and not ((causal or local) and seqlen_k < seqlen_q) ): g = torch.randn_like(out) # do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) @@ -303,11 +299,13 @@ def test_flash_attn_output( dv_pt - dv_ref ).abs().max().item() + dv_atol - num_iters = 100_000 + num_iters = 20_000 for i in range(num_iters): dq2, dk2, dv2, = _flash_attn_bwd( q, k, v, out, g, lse, causal=causal, + window_size_left=window_size[0], + window_size_right=window_size[1], deterministic=True, ) From 4f1d8bb50c337e296468ba4e30a93fd3f29882b4 Mon Sep 17 00:00:00 2001 From: Tianwei Yang Date: Tue, 6 Jan 2026 02:35:07 +0000 Subject: [PATCH 253/258] [Navi]add more triton config --- .../flash_attn_triton_amd/fwd_decode.py | 27 ++++++++++++++- .../flash_attn_triton_amd/fwd_prefill.py | 34 ++++++++++--------- 2 files changed, 44 insertions(+), 17 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/fwd_decode.py b/flash_attn/flash_attn_triton_amd/fwd_decode.py index 3f2d92c22d6..5c16cf4c552 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_decode.py +++ b/flash_attn/flash_attn_triton_amd/fwd_decode.py @@ -2,7 +2,7 @@ import triton import triton.language as tl from typing import Literal, Optional, Union -from .utils import AUTOTUNE, DEBUG, get_padded_headsize, get_shape_and_strides_from_layout, is_cdna +from .utils import AUTOTUNE, DEBUG, get_padded_headsize, get_shape_and_strides_from_layout, is_cdna, is_rdna def get_cdna_autotune_configs(): return [ @@ -23,6 +23,26 @@ def get_cdna_autotune_configs(): num_warps=4), ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'] +def get_rdna_autotune_configs(): + return [ + # Most aggressive - 128x128 (best for large sequences) + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + # Large blocks + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + # Medium blocks + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, num_warps=2), + # Fall-back config. + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'IS_VARLEN', 'HQ', 'HK'] + def get_autotune_configs(): if AUTOTUNE: if is_cdna(): @@ -30,6 +50,11 @@ def get_autotune_configs(): fwd_auto_tune_configs, fwd_autotune_keys= autotune_configs, autotune_keys reduce_auto_tune_configs, reduce_autotune_keys = autotune_configs, autotune_keys return (fwd_auto_tune_configs, fwd_autotune_keys), (reduce_auto_tune_configs, reduce_autotune_keys) + elif is_rdna(): + autotune_configs, autotune_keys = get_rdna_autotune_configs() + fwd_auto_tune_configs, fwd_autotune_keys= autotune_configs, autotune_keys + reduce_auto_tune_configs, reduce_autotune_keys = autotune_configs, autotune_keys + return (fwd_auto_tune_configs, fwd_autotune_keys), (reduce_auto_tune_configs, reduce_autotune_keys) else: raise ValueError("Unknown Device Type") else: diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index 6f69cd02813..7e814be4118 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -186,24 +186,25 @@ def get_cdna_autotune_configs(): def get_rdna_autotune_configs(): return [ - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - # Fall-back config. - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), + # === Configs for head_dim=64 (optimal) === + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + # === Configs for head_dim=128 (Wan2.2) - smaller blocks to reduce register pressure === + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + # === General fallback configs === + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=2), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=2), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=2), ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'IS_VARLEN', 'HQ', 'HK'] + def get_autotune_configs(): if AUTOTUNE: if is_rdna(): @@ -214,8 +215,9 @@ def get_autotune_configs(): raise ValueError("Unknown Device Type") else: return [ + # Use BLOCK_N=32 to avoid register spilling on gfx1100 triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False}, + {"BLOCK_M": 64, "BLOCK_N": 32, "waves_per_eu": 1, "PRE_LOAD_V": True}, num_stages=1, num_warps=4, ), From 929a4bbc89cb52856b0400567849344e5a64793e Mon Sep 17 00:00:00 2001 From: Tianwei Yang Date: Tue, 6 Jan 2026 02:35:44 +0000 Subject: [PATCH 254/258] [Navi]enable exp2 by default --- flash_attn/flash_attn_triton_amd/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index 5d3bf02e1f8..167b99e0d81 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -48,7 +48,7 @@ class MetaData(): philox_seed: Optional[int] = None philox_offset : Optional[int]= None # if dropout_p > 0.0 seed the RNG so we get reproducible results for testing. # NOTE: scale sm_scale by log_2(e) and use 2^x in the loop as we do not have native e^x support in HW. - use_exp2: bool = False + use_exp2: bool = True rotary_sin: Optional[torch.Tensor] = None rotary_cos: Optional[torch.Tensor] = None rotary_interleaved: bool = False From dc8b05da38ebed93060beff071676b0d326d3f3f Mon Sep 17 00:00:00 2001 From: Tianwei Yang Date: Tue, 6 Jan 2026 02:36:39 +0000 Subject: [PATCH 255/258] [Navi]Add support for arch gfx1100 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index f0b476255ba..9b1bd10088a 100644 --- a/setup.py +++ b/setup.py @@ -197,7 +197,7 @@ def rename_cpp_to_cu(cpp_files): def validate_and_update_archs(archs): # List of allowed architectures - allowed_archs = ["native", "gfx90a", "gfx950", "gfx942"] + allowed_archs = ["native", "gfx90a", "gfx950", "gfx942", "gfx1100"] # Validate if each element in archs is in allowed_archs assert all( From 64c5924c5ac7a73fa5747e9759d410b325b95a2d Mon Sep 17 00:00:00 2001 From: Tianwei Yang Date: Tue, 6 Jan 2026 02:38:03 +0000 Subject: [PATCH 256/258] [ROCM]warp fa to support L2 cache aware to improve performance --- .../flash_attn_triton_amd/interface_fa.py | 233 ++++++++++++++- .../flash_attn_triton_amd/l2_cache_aware.py | 276 ++++++++++++++++++ 2 files changed, 507 insertions(+), 2 deletions(-) create mode 100644 flash_attn/flash_attn_triton_amd/l2_cache_aware.py diff --git a/flash_attn/flash_attn_triton_amd/interface_fa.py b/flash_attn/flash_attn_triton_amd/interface_fa.py index 06ab7d24d56..c223ee93b6c 100644 --- a/flash_attn/flash_attn_triton_amd/interface_fa.py +++ b/flash_attn/flash_attn_triton_amd/interface_fa.py @@ -9,11 +9,17 @@ from .fwd_ref import attention_forward_pytorch_ref_impl from .bwd_ref import attention_backward_pytorch_ref_impl from .utils import DEBUG, USE_REF, MetaData, get_shapes_from_layout, is_fp8 +from .l2_cache_aware import is_head_grouping_beneficial, print_head_grouping_info from einops import rearrange, repeat from flash_attn.layers.rotary import apply_rotary_emb from typing import Literal, Optional, Union -def fwd(q: torch.Tensor, +# Environment variable to enable verbose head grouping output +L2_HEAD_GROUPING_DEBUG = os.environ.get('FLASH_ATTN_HEAD_GROUPING_DEBUG', '0') == '1' + + + +def _fwd_single_group(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: Optional[torch.Tensor], @@ -31,6 +37,7 @@ def fwd(q: torch.Tensor, descale_v: Optional[torch.Tensor] = None, descale_o: Optional[torch.Tensor] = None ): + """Original fwd implementation for a single head group.""" if DEBUG: print() @@ -145,6 +152,112 @@ def fwd(q: torch.Tensor, return out, softmax_lse, sd_mask, rng_state + +def fwd(q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: Optional[torch.Tensor], + alibi_slopes: Optional[torch.Tensor], + dropout_p: float, + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + return_softmax: bool, + gen_: Optional[torch.Tensor] = None, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_o: Optional[torch.Tensor] = None + ): + """ + Flash attention forward with L2 cache-aware head grouping. + + For consumer AMD GPUs (e.g., gfx1100 with 96MB L2), when K,V tensors + exceed L2 cache capacity, processing heads in groups that fit in L2 + can provide up to 2x speedup. + + Layout: bshd (batch, seqlen, heads, head_dim) + """ + # Get shapes for head grouping decision + # Layout is bshd: [batch, seqlen, heads, head_dim] + batch, seqlen_q, nheads_q, head_dim = q.shape + seqlen_k = k.shape[1] + nheads_k = k.shape[2] + + # Check if head grouping is beneficial + should_group, group_size = is_head_grouping_beneficial( + nheads_q, seqlen_k, head_dim, q.dtype, q.device.index or 0 + ) + + if L2_HEAD_GROUPING_DEBUG: + print_head_grouping_info(nheads_q, seqlen_k, head_dim, q.dtype, q.device.index or 0) + + if not should_group or group_size >= nheads_q: + # No grouping needed - use original implementation + return _fwd_single_group( + q, k, v, out, alibi_slopes, dropout_p, softmax_scale, causal, + window_size_left, window_size_right, softcap, return_softmax, + gen_, descale_q, descale_k, descale_v, descale_o + ) + + # Process heads in groups for L2 cache efficiency + if L2_HEAD_GROUPING_DEBUG: + print(f"[L2 Head Grouping] Processing {nheads_q} heads in groups of {group_size}") + + # Prepare output tensor + if is_fp8(q): + assert out is not None, "fp8 output tensor should be passed in." + else: + out = torch.zeros_like(q) if out is None else out.zero_() + + # Collect outputs for each group + softmax_lse_list = [] + rng_state = None + + n_groups = (nheads_q + group_size - 1) // group_size + + for g in range(n_groups): + start_h = g * group_size + end_h = min((g + 1) * group_size, nheads_q) + + # Slice heads: bshd layout -> select heads on dim 2 + q_group = q[:, :, start_h:end_h, :].contiguous() + k_group = k[:, :, start_h:end_h, :].contiguous() + v_group = v[:, :, start_h:end_h, :].contiguous() + out_group = out[:, :, start_h:end_h, :].contiguous() + + # Handle alibi slopes if present + alibi_group = None + if alibi_slopes is not None: + alibi_group = alibi_slopes[start_h:end_h] if alibi_slopes.dim() == 1 else alibi_slopes[:, start_h:end_h] + + # Handle descale tensors for fp8 + descale_q_g = descale_q[:, start_h:end_h] if descale_q is not None else None + descale_k_g = descale_k[:, start_h:end_h] if descale_k is not None else None + descale_v_g = descale_v[:, start_h:end_h] if descale_v is not None else None + descale_o_g = descale_o[:, start_h:end_h] if descale_o is not None else None + + # Call the original implementation for this group + out_g, softmax_lse_g, sd_mask_g, rng_state = _fwd_single_group( + q_group, k_group, v_group, out_group, alibi_group, + dropout_p, softmax_scale, causal, + window_size_left, window_size_right, softcap, return_softmax, + gen_, descale_q_g, descale_k_g, descale_v_g, descale_o_g + ) + + # Copy output back to the main tensor + out[:, :, start_h:end_h, :] = out_g + softmax_lse_list.append(softmax_lse_g) + + # Concatenate softmax_lse across heads + softmax_lse = torch.cat(softmax_lse_list, dim=1) # Assuming lse is [batch, heads, ...] + + return out, softmax_lse, None, rng_state + + + BWD_MODE = os.environ.get('BWD_MODE', 'split').lower() def bwd( dout: torch.Tensor, @@ -349,7 +462,7 @@ def bwd( print("descale_dq:", descale_dq, descale_dq.shape if descale_dq is not None else None) return dq, dk, dv, delta -def varlen_fwd( +def _varlen_fwd_single_group( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -376,6 +489,7 @@ def varlen_fwd( descale_v: Optional[torch.Tensor] = None, descale_o: Optional[torch.Tensor] = None ): + """Original varlen_fwd implementation for a single head group.""" if DEBUG: print() @@ -490,6 +604,121 @@ def varlen_fwd( return out, softmax_lse, sd_mask, rng_state + +def varlen_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + seqused_k: Optional[torch.Tensor], + leftpad_k: Optional[torch.Tensor], + block_table_: Optional[torch.Tensor], + alibi_slopes: Optional[torch.Tensor], + max_seqlen_q: int, + max_seqlen_k: int, + dropout_p: float, + softmax_scale: float, + zero_tensors: bool , + causal: bool , + window_size_left: int, + window_size_right: int, + softcap: float, + return_softmax: bool, + gen_: Optional[torch.Tensor] = None, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_o: Optional[torch.Tensor] = None + ): + """ + Variable-length flash attention forward with L2 cache-aware head grouping. + + For consumer AMD GPUs (e.g., gfx1100 with 96MB L2), when K,V tensors + exceed L2 cache capacity, processing heads in groups that fit in L2 + can provide up to 2x speedup. + + Layout: thd (total_seqlen, heads, head_dim) + """ + # Get shapes for head grouping decision + # Layout is thd: [total_seqlen, heads, head_dim] + total_seqlen, nheads_q, head_dim = q.shape + nheads_k = k.shape[1] + + # Check if head grouping is beneficial + should_group, group_size = is_head_grouping_beneficial( + nheads_q, max_seqlen_k, head_dim, q.dtype, q.device.index or 0 + ) + + if L2_HEAD_GROUPING_DEBUG: + print_head_grouping_info(nheads_q, max_seqlen_k, head_dim, q.dtype, q.device.index or 0) + + if not should_group or group_size >= nheads_q: + # No grouping needed - use original implementation + return _varlen_fwd_single_group( + q, k, v, out, cu_seqlens_q, cu_seqlens_k, seqused_k, leftpad_k, + block_table_, alibi_slopes, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, zero_tensors, causal, + window_size_left, window_size_right, softcap, return_softmax, + gen_, descale_q, descale_k, descale_v, descale_o + ) + + # Process heads in groups for L2 cache efficiency + if L2_HEAD_GROUPING_DEBUG: + print(f"[L2 Head Grouping varlen] Processing {nheads_q} heads in groups of {group_size}") + + # Prepare output tensor + if is_fp8(q): + assert out is not None, "fp8 output tensor should be passed in." + else: + out = torch.zeros_like(q) if out is None else out.zero_() + + # Collect outputs for each group + softmax_lse_list = [] + rng_state = None + + n_groups = (nheads_q + group_size - 1) // group_size + + for g in range(n_groups): + start_h = g * group_size + end_h = min((g + 1) * group_size, nheads_q) + + # Slice heads: thd layout -> select heads on dim 1 + q_group = q[:, start_h:end_h, :].contiguous() + k_group = k[:, start_h:end_h, :].contiguous() + v_group = v[:, start_h:end_h, :].contiguous() + out_group = out[:, start_h:end_h, :].contiguous() + + # Handle alibi slopes if present + alibi_group = None + if alibi_slopes is not None: + alibi_group = alibi_slopes[start_h:end_h] if alibi_slopes.dim() == 1 else alibi_slopes[:, start_h:end_h] + + # Handle descale tensors for fp8 + descale_q_g = descale_q[:, start_h:end_h] if descale_q is not None else None + descale_k_g = descale_k[:, start_h:end_h] if descale_k is not None else None + descale_v_g = descale_v[:, start_h:end_h] if descale_v is not None else None + descale_o_g = descale_o[:, start_h:end_h] if descale_o is not None else None + + # Call the original implementation for this group + out_g, softmax_lse_g, sd_mask_g, rng_state = _varlen_fwd_single_group( + q_group, k_group, v_group, out_group, cu_seqlens_q, cu_seqlens_k, + seqused_k, leftpad_k, block_table_, alibi_group, + max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, + zero_tensors, causal, window_size_left, window_size_right, + softcap, return_softmax, gen_, descale_q_g, descale_k_g, descale_v_g, descale_o_g + ) + + # Copy output back to the main tensor + out[:, start_h:end_h, :] = out_g + softmax_lse_list.append(softmax_lse_g) + + # Concatenate softmax_lse across heads + softmax_lse = torch.cat(softmax_lse_list, dim=0) # varlen lse is [heads, total_seqlen] + + return out, softmax_lse, None, rng_state + def varlen_bwd( dout: torch.Tensor, q: torch.Tensor, diff --git a/flash_attn/flash_attn_triton_amd/l2_cache_aware.py b/flash_attn/flash_attn_triton_amd/l2_cache_aware.py new file mode 100644 index 00000000000..7dd851b41e4 --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/l2_cache_aware.py @@ -0,0 +1,276 @@ +""" +L2 Cache-Aware Head Grouping for Flash Attention + +This module provides functionality to optimize flash attention by processing +heads in groups that fit in the L2 cache. This is particularly important for +consumer AMD GPUs like gfx1100 (RX 7900 XTX) where the L2 cache is smaller +than datacenter GPUs. + +The key insight is that for large sequence lengths, the K and V tensors for +all heads may exceed L2 cache capacity, causing cache thrashing. By processing +heads in groups that fit in L2, we can achieve up to 2x speedup. + +Example: gfx1100 with 96MB L2, 40 heads, seqlen=17160, head_dim=128 +- K,V for all 40 heads = 352 MB (exceeds 96 MB L2) +- K,V for 10 heads = 88 MB (fits in 96 MB L2) +- Processing 10 heads at a time gives 1.95x speedup +""" + +import os +import functools +from typing import Optional, Tuple, Dict +import torch + +# L2 cache sizes for AMD GPUs in bytes +# Source: AMD documentation and hardware specs +AMD_L2_CACHE_SIZES: Dict[str, int] = { + # RDNA3 workstaion + "gfx1100": 96 * 1024 * 1024, # RX 7900 XTX/XT - 96 MB +} + +# Environment variable to override L2 cache size (in MB) +L2_CACHE_OVERRIDE_ENV = "FLASH_ATTN_L2_CACHE_MB" +# Environment variable to disable head grouping +DISABLE_HEAD_GROUPING_ENV = "FLASH_ATTN_DISABLE_HEAD_GROUPING" + +# Cached L2 size per device +_l2_cache_size_cache: Dict[int, int] = {} + + +@functools.lru_cache(maxsize=None) +def get_gcn_arch_name(device_index: int = 0) -> str: + """Get the GCN architecture name for an AMD GPU.""" + try: + props = torch.cuda.get_device_properties(device_index) + if hasattr(props, 'gcnArchName'): + return props.gcnArchName + # Fallback: try to get from name + name = props.name.lower() + if 'gfx' in name: + # Extract gfxXXXX from name + import re + match = re.search(r'gfx\d+', name) + if match: + return match.group() + except Exception: + pass + return "unknown" + + +def get_l2_cache_size(device_index: int = 0) -> int: + """ + Get L2 cache size for the specified GPU device. + + Returns: + L2 cache size in bytes + """ + global _l2_cache_size_cache + + if device_index in _l2_cache_size_cache: + return _l2_cache_size_cache[device_index] + + # Check for environment override + if L2_CACHE_OVERRIDE_ENV in os.environ: + try: + size_mb = int(os.environ[L2_CACHE_OVERRIDE_ENV]) + size_bytes = size_mb * 1024 * 1024 + _l2_cache_size_cache[device_index] = size_bytes + return size_bytes + except ValueError: + pass + + # Get architecture and look up cache size + arch = get_gcn_arch_name(device_index) + + # Check exact match first + if arch in AMD_L2_CACHE_SIZES: + size = AMD_L2_CACHE_SIZES[arch] + _l2_cache_size_cache[device_index] = size + return size + + # Check prefix match (e.g., gfx1100 matches gfx1100) + for known_arch, size in AMD_L2_CACHE_SIZES.items(): + if arch.startswith(known_arch): + _l2_cache_size_cache[device_index] = size + return size + + # Default: assume 96 MB (conservative for RDNA3) + default_size = 96 * 1024 * 1024 + _l2_cache_size_cache[device_index] = default_size + return default_size + + +def calculate_optimal_head_group_size( + seqlen_k: int, + head_dim: int, + dtype: torch.dtype, + device_index: int = 0, + l2_utilization: float = 1.0 #use higher utilization by default to improve 1280x720 performance +) -> int: + """ + Calculate the optimal number of heads to process together to fit K,V in L2. + + The calculation is: + K,V memory for N heads = N * seqlen_k * head_dim * dtype_size * 2 (for K and V) + + We want: K,V memory <= L2_cache * utilization + So: N <= (L2_cache * utilization) / (seqlen_k * head_dim * dtype_size * 2) + + Args: + seqlen_k: Sequence length of K/V + head_dim: Head dimension + dtype: Data type of tensors + device_index: GPU device index + l2_utilization: Fraction of L2 to target (default 0.9 to leave room for Q) + + Returns: + Optimal number of heads to process together (minimum 1) + """ + l2_size = get_l2_cache_size(device_index) + + # Get element size in bytes + if dtype in (torch.float16, torch.bfloat16): + elem_size = 2 + elif dtype == torch.float32: + elem_size = 4 + elif hasattr(torch, 'float8_e4m3fnuz') and dtype == torch.float8_e4m3fnuz: + elem_size = 1 + elif hasattr(torch, 'float8_e5m2fnuz') and dtype == torch.float8_e5m2fnuz: + elem_size = 1 + elif hasattr(torch, 'float8_e4m3fn') and dtype == torch.float8_e4m3fn: + elem_size = 1 + elif hasattr(torch, 'float8_e5m2') and dtype == torch.float8_e5m2: + elem_size = 1 + elif 'float8' in str(dtype).lower(): + elem_size = 1 + else: + elem_size = 2 # Default to fp16 + + # Memory for K and V per head + kv_per_head = seqlen_k * head_dim * elem_size * 2 # *2 for K and V + + # Target L2 usage (leave some room for Q and other data) + target_l2 = int(l2_size * l2_utilization) + + # Calculate number of heads that fit + if kv_per_head == 0: + return 1 + + head_group_size = max(1, target_l2 // kv_per_head) + + return head_group_size + + +def is_head_grouping_beneficial( + nheads: int, + seqlen_k: int, + head_dim: int, + dtype: torch.dtype, + device_index: int = 0, + threshold_ratio: float = 1.5 +) -> Tuple[bool, int]: + """ + Determine if head grouping would be beneficial and return optimal group size. + + Head grouping is beneficial when: + 1. Total K,V memory exceeds L2 cache size by a significant margin + 2. Processing in groups allows K,V to fit in L2 + 3. The overhead of multiple kernel launches is worth the cache benefit + + Args: + nheads: Number of attention heads + seqlen_k: Sequence length of K/V + head_dim: Head dimension + dtype: Data type + device_index: GPU device index + threshold_ratio: K,V must exceed L2 by this ratio to enable grouping + + Returns: + (should_group, group_size): Whether to group and the optimal group size + """ + # Check if disabled via environment + if os.environ.get(DISABLE_HEAD_GROUPING_ENV, "0") == "1": + return False, nheads + + l2_size = get_l2_cache_size(device_index) + + # Get element size + if dtype in (torch.float16, torch.bfloat16): + elem_size = 2 + elif dtype == torch.float32: + elem_size = 4 + elif hasattr(torch, 'float8_e4m3fnuz') and dtype == torch.float8_e4m3fnuz: + elem_size = 1 + elif hasattr(torch, 'float8_e5m2fnuz') and dtype == torch.float8_e5m2fnuz: + elem_size = 1 + elif hasattr(torch, 'float8_e4m3fn') and dtype == torch.float8_e4m3fn: + elem_size = 1 + elif hasattr(torch, 'float8_e5m2') and dtype == torch.float8_e5m2: + elem_size = 1 + elif 'float8' in str(dtype).lower(): + elem_size = 1 + else: + elem_size = 2 + + # Total K,V memory for all heads + total_kv = nheads * seqlen_k * head_dim * elem_size * 2 + + # Only group if K,V significantly exceeds L2 + if total_kv < l2_size * threshold_ratio: + return False, nheads + + # Calculate optimal group size + group_size = calculate_optimal_head_group_size( + seqlen_k, head_dim, dtype, device_index + ) + + # Only group if we'd have at least 2 groups + # (otherwise grouping adds overhead with no benefit) + if group_size >= nheads: + return False, nheads + + # Minimum group size to avoid excessive kernel launches + min_group_size = max(1, nheads // 16) # At most 16 groups + group_size = max(group_size, min_group_size) + + return True, min(group_size, nheads) + + +def print_head_grouping_info( + nheads: int, + seqlen_k: int, + head_dim: int, + dtype: torch.dtype, + device_index: int = 0 +): + """Print diagnostic information about head grouping.""" + l2_size = get_l2_cache_size(device_index) + arch = get_gcn_arch_name(device_index) + + if dtype in (torch.float16, torch.bfloat16): + elem_size = 2 + elif dtype == torch.float32: + elem_size = 4 + elif 'float8' in str(dtype).lower(): + elem_size = 1 + else: + elem_size = 2 + + total_kv = nheads * seqlen_k * head_dim * elem_size * 2 + should_group, group_size = is_head_grouping_beneficial( + nheads, seqlen_k, head_dim, dtype, device_index + ) + + print(f"\n=== L2 Cache-Aware Head Grouping ===") + print(f"GPU: {arch}") + print(f"L2 Cache: {l2_size / (1024*1024):.1f} MB") + print(f"Heads: {nheads}, SeqLen: {seqlen_k}, HeadDim: {head_dim}") + print(f"Total K,V Memory: {total_kv / (1024*1024):.1f} MB") + print(f"L2 Ratio: {total_kv / l2_size:.2f}x") + print(f"Should Group: {should_group}") + if should_group: + kv_per_group = group_size * seqlen_k * head_dim * elem_size * 2 + num_groups = (nheads + group_size - 1) // group_size + print(f"Group Size: {group_size} heads ({num_groups} groups)") + print(f"K,V per Group: {kv_per_group / (1024*1024):.1f} MB") + print("=" * 40 + "\n") From e935f3b5e7b7f36a8b8a017ddae807e03fe561a9 Mon Sep 17 00:00:00 2001 From: Tianwei Yang Date: Tue, 6 Jan 2026 06:43:43 +0000 Subject: [PATCH 257/258] [Navi]renaming L2 cache to Infinity Cache (LLC) to avoid confusion --- .../flash_attn_triton_amd/l2_cache_aware.py | 203 +++++++++--------- 1 file changed, 99 insertions(+), 104 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/l2_cache_aware.py b/flash_attn/flash_attn_triton_amd/l2_cache_aware.py index 7dd851b41e4..f506e330aae 100644 --- a/flash_attn/flash_attn_triton_amd/l2_cache_aware.py +++ b/flash_attn/flash_attn_triton_amd/l2_cache_aware.py @@ -1,18 +1,19 @@ """ -L2 Cache-Aware Head Grouping for Flash Attention +Infinity Cache (LLC) Aware Head Grouping for Flash Attention This module provides functionality to optimize flash attention by processing -heads in groups that fit in the L2 cache. This is particularly important for -consumer AMD GPUs like gfx1100 (RX 7900 XTX) where the L2 cache is smaller -than datacenter GPUs. +heads in groups that fit in the Last Level Cache (LLC / Infinity Cache). -The key insight is that for large sequence lengths, the K and V tensors for -all heads may exceed L2 cache capacity, causing cache thrashing. By processing -heads in groups that fit in L2, we can achieve up to 2x speedup. +AMD RDNA3 cache hierarchy: +- L2 Cache: 6 MB (per-die, fast) +- Infinity Cache (L3/LLC): 96 MB (acts as memory-side cache) -Example: gfx1100 with 96MB L2, 40 heads, seqlen=17160, head_dim=128 -- K,V for all 40 heads = 352 MB (exceeds 96 MB L2) -- K,V for 10 heads = 88 MB (fits in 96 MB L2) +For large sequence lengths, we want K,V to fit in the 96 MB Infinity Cache. +By processing heads in groups that fit, we achieve up to 2x speedup. + +Example: gfx1100 with 96MB Infinity Cache, 40 heads, seqlen=17160, head_dim=128 +- K,V for all 40 heads = 352 MB (exceeds 96 MB LLC) +- K,V for 10 heads = 88 MB (fits in 96 MB LLC) - Processing 10 heads at a time gives 1.95x speedup """ @@ -21,20 +22,28 @@ from typing import Optional, Tuple, Dict import torch -# L2 cache sizes for AMD GPUs in bytes -# Source: AMD documentation and hardware specs -AMD_L2_CACHE_SIZES: Dict[str, int] = { - # RDNA3 workstaion - "gfx1100": 96 * 1024 * 1024, # RX 7900 XTX/XT - 96 MB +# Infinity Cache (LLC) sizes for AMD GPUs in bytes +# Note: This is the L3/Infinity Cache, NOT the L2 cache +# RDNA3: L2=6MB, Infinity Cache (LLC)=96MB +AMD_LLC_CACHE_SIZES: Dict[str, int] = { + # RDNA3 consumer + "gfx1100": 96 * 1024 * 1024, # RX 7900 XTX/XT - 96 MB Infinity Cache + "gfx1101": 64 * 1024 * 1024, # RX 7800 XT - 64 MB Infinity Cache + "gfx1102": 32 * 1024 * 1024, # RX 7600 - 32 MB Infinity Cache } -# Environment variable to override L2 cache size (in MB) -L2_CACHE_OVERRIDE_ENV = "FLASH_ATTN_L2_CACHE_MB" +# Legacy alias for backwards compatibility +AMD_L2_CACHE_SIZES = AMD_LLC_CACHE_SIZES + +# Environment variable to override LLC cache size (in MB) +LLC_CACHE_OVERRIDE_ENV = "FLASH_ATTN_LLC_CACHE_MB" +L2_CACHE_OVERRIDE_ENV = "FLASH_ATTN_L2_CACHE_MB" # Legacy alias + # Environment variable to disable head grouping DISABLE_HEAD_GROUPING_ENV = "FLASH_ATTN_DISABLE_HEAD_GROUPING" -# Cached L2 size per device -_l2_cache_size_cache: Dict[int, int] = {} +# Cached LLC size per device +_llc_cache_size_cache: Dict[int, int] = {} @functools.lru_cache(maxsize=None) @@ -57,90 +66,100 @@ def get_gcn_arch_name(device_index: int = 0) -> str: return "unknown" -def get_l2_cache_size(device_index: int = 0) -> int: +def get_num_cus(device_index: int = 0) -> int: """ - Get L2 cache size for the specified GPU device. + Get the number of Compute Units for an AMD GPU. + + Note: PyTorch's multi_processor_count may be incorrect for some AMD GPUs. + We use known values for common architectures. + """ + arch = get_gcn_arch_name(device_index) + + # Known CU counts for common GPUs + known_cus = { + "gfx1100": 96, # RX 7900 XTX + "gfx1101": 60, # RX 7800 XT + "gfx1102": 32, # RX 7600 + } + + if arch in known_cus: + return known_cus[arch] + + # Fallback to PyTorch (may be incorrect) + try: + props = torch.cuda.get_device_properties(device_index) + return props.multi_processor_count + except Exception: + return 96 # Default + + +def get_llc_cache_size(device_index: int = 0) -> int: + """ + Get Infinity Cache (LLC) size for the specified GPU device. + + For RDNA3, this is the 96 MB Infinity Cache, not the 6 MB L2. Returns: - L2 cache size in bytes + LLC cache size in bytes """ - global _l2_cache_size_cache + global _llc_cache_size_cache - if device_index in _l2_cache_size_cache: - return _l2_cache_size_cache[device_index] + if device_index in _llc_cache_size_cache: + return _llc_cache_size_cache[device_index] - # Check for environment override - if L2_CACHE_OVERRIDE_ENV in os.environ: - try: - size_mb = int(os.environ[L2_CACHE_OVERRIDE_ENV]) - size_bytes = size_mb * 1024 * 1024 - _l2_cache_size_cache[device_index] = size_bytes - return size_bytes - except ValueError: - pass + # Check for environment override (new name first, then legacy) + for env_var in [LLC_CACHE_OVERRIDE_ENV, L2_CACHE_OVERRIDE_ENV]: + if env_var in os.environ: + try: + size_mb = int(os.environ[env_var]) + size_bytes = size_mb * 1024 * 1024 + _llc_cache_size_cache[device_index] = size_bytes + return size_bytes + except ValueError: + pass # Get architecture and look up cache size arch = get_gcn_arch_name(device_index) # Check exact match first - if arch in AMD_L2_CACHE_SIZES: - size = AMD_L2_CACHE_SIZES[arch] - _l2_cache_size_cache[device_index] = size + if arch in AMD_LLC_CACHE_SIZES: + size = AMD_LLC_CACHE_SIZES[arch] + _llc_cache_size_cache[device_index] = size return size # Check prefix match (e.g., gfx1100 matches gfx1100) - for known_arch, size in AMD_L2_CACHE_SIZES.items(): + for known_arch, size in AMD_LLC_CACHE_SIZES.items(): if arch.startswith(known_arch): - _l2_cache_size_cache[device_index] = size + _llc_cache_size_cache[device_index] = size return size # Default: assume 96 MB (conservative for RDNA3) default_size = 96 * 1024 * 1024 - _l2_cache_size_cache[device_index] = default_size + _llc_cache_size_cache[device_index] = default_size return default_size +# Legacy alias +get_l2_cache_size = get_llc_cache_size + + def calculate_optimal_head_group_size( seqlen_k: int, head_dim: int, dtype: torch.dtype, device_index: int = 0, - l2_utilization: float = 1.0 #use higher utilization by default to improve 1280x720 performance + llc_utilization: float = 1.0 # Use higher utilization by default ) -> int: """ - Calculate the optimal number of heads to process together to fit K,V in L2. - - The calculation is: - K,V memory for N heads = N * seqlen_k * head_dim * dtype_size * 2 (for K and V) - - We want: K,V memory <= L2_cache * utilization - So: N <= (L2_cache * utilization) / (seqlen_k * head_dim * dtype_size * 2) - - Args: - seqlen_k: Sequence length of K/V - head_dim: Head dimension - dtype: Data type of tensors - device_index: GPU device index - l2_utilization: Fraction of L2 to target (default 0.9 to leave room for Q) - - Returns: - Optimal number of heads to process together (minimum 1) + Calculate the optimal number of heads to process together to fit K,V in LLC. """ - l2_size = get_l2_cache_size(device_index) + llc_size = get_llc_cache_size(device_index) # Get element size in bytes if dtype in (torch.float16, torch.bfloat16): elem_size = 2 elif dtype == torch.float32: elem_size = 4 - elif hasattr(torch, 'float8_e4m3fnuz') and dtype == torch.float8_e4m3fnuz: - elem_size = 1 - elif hasattr(torch, 'float8_e5m2fnuz') and dtype == torch.float8_e5m2fnuz: - elem_size = 1 - elif hasattr(torch, 'float8_e4m3fn') and dtype == torch.float8_e4m3fn: - elem_size = 1 - elif hasattr(torch, 'float8_e5m2') and dtype == torch.float8_e5m2: - elem_size = 1 elif 'float8' in str(dtype).lower(): elem_size = 1 else: @@ -149,14 +168,14 @@ def calculate_optimal_head_group_size( # Memory for K and V per head kv_per_head = seqlen_k * head_dim * elem_size * 2 # *2 for K and V - # Target L2 usage (leave some room for Q and other data) - target_l2 = int(l2_size * l2_utilization) + # Target LLC usage + target_llc = int(llc_size * llc_utilization) # Calculate number of heads that fit if kv_per_head == 0: return 1 - head_group_size = max(1, target_l2 // kv_per_head) + head_group_size = max(1, target_llc // kv_per_head) return head_group_size @@ -171,42 +190,18 @@ def is_head_grouping_beneficial( ) -> Tuple[bool, int]: """ Determine if head grouping would be beneficial and return optimal group size. - - Head grouping is beneficial when: - 1. Total K,V memory exceeds L2 cache size by a significant margin - 2. Processing in groups allows K,V to fit in L2 - 3. The overhead of multiple kernel launches is worth the cache benefit - - Args: - nheads: Number of attention heads - seqlen_k: Sequence length of K/V - head_dim: Head dimension - dtype: Data type - device_index: GPU device index - threshold_ratio: K,V must exceed L2 by this ratio to enable grouping - - Returns: - (should_group, group_size): Whether to group and the optimal group size """ # Check if disabled via environment if os.environ.get(DISABLE_HEAD_GROUPING_ENV, "0") == "1": return False, nheads - l2_size = get_l2_cache_size(device_index) + llc_size = get_llc_cache_size(device_index) # Get element size if dtype in (torch.float16, torch.bfloat16): elem_size = 2 elif dtype == torch.float32: elem_size = 4 - elif hasattr(torch, 'float8_e4m3fnuz') and dtype == torch.float8_e4m3fnuz: - elem_size = 1 - elif hasattr(torch, 'float8_e5m2fnuz') and dtype == torch.float8_e5m2fnuz: - elem_size = 1 - elif hasattr(torch, 'float8_e4m3fn') and dtype == torch.float8_e4m3fn: - elem_size = 1 - elif hasattr(torch, 'float8_e5m2') and dtype == torch.float8_e5m2: - elem_size = 1 elif 'float8' in str(dtype).lower(): elem_size = 1 else: @@ -215,8 +210,8 @@ def is_head_grouping_beneficial( # Total K,V memory for all heads total_kv = nheads * seqlen_k * head_dim * elem_size * 2 - # Only group if K,V significantly exceeds L2 - if total_kv < l2_size * threshold_ratio: + # Only group if K,V significantly exceeds LLC + if total_kv < llc_size * threshold_ratio: return False, nheads # Calculate optimal group size @@ -225,7 +220,6 @@ def is_head_grouping_beneficial( ) # Only group if we'd have at least 2 groups - # (otherwise grouping adds overhead with no benefit) if group_size >= nheads: return False, nheads @@ -244,8 +238,9 @@ def print_head_grouping_info( device_index: int = 0 ): """Print diagnostic information about head grouping.""" - l2_size = get_l2_cache_size(device_index) + llc_size = get_llc_cache_size(device_index) arch = get_gcn_arch_name(device_index) + num_cus = get_num_cus(device_index) if dtype in (torch.float16, torch.bfloat16): elem_size = 2 @@ -261,16 +256,16 @@ def print_head_grouping_info( nheads, seqlen_k, head_dim, dtype, device_index ) - print(f"\n=== L2 Cache-Aware Head Grouping ===") - print(f"GPU: {arch}") - print(f"L2 Cache: {l2_size / (1024*1024):.1f} MB") + print(f"\n=== Infinity Cache (LLC) Aware Head Grouping ===") + print(f"GPU: {arch} ({num_cus} CUs)") + print(f"Infinity Cache (LLC): {llc_size / (1024*1024):.1f} MB") print(f"Heads: {nheads}, SeqLen: {seqlen_k}, HeadDim: {head_dim}") print(f"Total K,V Memory: {total_kv / (1024*1024):.1f} MB") - print(f"L2 Ratio: {total_kv / l2_size:.2f}x") + print(f"LLC Ratio: {total_kv / llc_size:.2f}x") print(f"Should Group: {should_group}") if should_group: kv_per_group = group_size * seqlen_k * head_dim * elem_size * 2 num_groups = (nheads + group_size - 1) // group_size print(f"Group Size: {group_size} heads ({num_groups} groups)") print(f"K,V per Group: {kv_per_group / (1024*1024):.1f} MB") - print("=" * 40 + "\n") + print("=" * 48 + "\n") From 92cc73ac2bb27a044d83f99f6fdf30e7f7e036a1 Mon Sep 17 00:00:00 2001 From: Tianwei Yang Date: Fri, 9 Jan 2026 01:23:17 +0000 Subject: [PATCH 258/258] [ROCM]Optimized for gfx1100 (RDNA3) with LLC-aware head grouping for long seqlen --- .../flash_attn_triton_amd/fwd_prefill.py | 27 +++++++++---------- .../flash_attn_triton_amd/l2_cache_aware.py | 2 +- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index 7e814be4118..b0b320321b6 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -186,21 +186,18 @@ def get_cdna_autotune_configs(): def get_rdna_autotune_configs(): return [ - # === Configs for head_dim=64 (optimal) === - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + # Best config from autotune on gfx1100: 32x16, warps=2, PRE_LOAD_V=True + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=2), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=2), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=1), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=1), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=2), + # === Configs for head_dim=128 === + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=2), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=4), + # === Fallback configs === triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - # === Configs for head_dim=128 (Wan2.2) - smaller blocks to reduce register pressure === - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - # === General fallback configs === - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=2), - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=2), - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=2), ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'IS_VARLEN', 'HQ', 'HK'] @@ -215,9 +212,9 @@ def get_autotune_configs(): raise ValueError("Unknown Device Type") else: return [ - # Use BLOCK_N=32 to avoid register spilling on gfx1100 + # Optimized for gfx1100 (RDNA3) with LLC-aware head grouping triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 32, "waves_per_eu": 1, "PRE_LOAD_V": True}, + {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": True}, num_stages=1, num_warps=4, ), diff --git a/flash_attn/flash_attn_triton_amd/l2_cache_aware.py b/flash_attn/flash_attn_triton_amd/l2_cache_aware.py index f506e330aae..981f8ce5702 100644 --- a/flash_attn/flash_attn_triton_amd/l2_cache_aware.py +++ b/flash_attn/flash_attn_triton_amd/l2_cache_aware.py @@ -148,7 +148,7 @@ def calculate_optimal_head_group_size( head_dim: int, dtype: torch.dtype, device_index: int = 0, - llc_utilization: float = 1.0 # Use higher utilization by default + llc_utilization: float = 1.5 # Use 150% of LLC - optimal for long sequences ) -> int: """ Calculate the optimal number of heads to process together to fit K,V in LLC.