Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion flashmask/flash_mask/flash_attn_v4/copy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,24 @@ def tiled_copy_1d(


def tiled_copy_2d(
    dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False
) -> cute.TiledCopy:
    """Build a 2D tiled copy whose vector width is derived from the tile shape.

    The per-copy vector is the largest element count that divides
    ``major_mode_size`` and fits in 128 bits. Each row of ``major_mode_size``
    elements is then split across just enough threads to cover it with one
    vector per thread, and the remaining threads are stacked along the other
    mode.

    Args:
        dtype: Element type being copied; its ``width`` (in bits) sets how
            many elements fit in one vector.
        major_mode_size: Number of elements along the mode that the vector
            accesses run over.
        num_threads: Total threads taking part in the copy. Must be a
            multiple of the number of threads needed per row.
        is_async: If true, use ``cpasync.CopyG2SOp``; otherwise use
            ``cute.nvgpu.CopyUniversalOp``.

    Returns:
        The tiled copy produced by ``cute.make_tiled_copy_tv``.

    Raises:
        AssertionError: If ``num_threads`` is not divisible by the number of
            threads per row.
    """
    # Elements per vector: capped at 128 bits and forced to divide the row
    # length so that a vector never straddles two rows.
    elems_per_copy = math.gcd(major_mode_size, 128 // dtype.width)
    bits_per_copy = elems_per_copy * dtype.width

    if is_async:
        op = cpasync.CopyG2SOp()
    else:
        op = cute.nvgpu.CopyUniversalOp()
    atom = cute.make_copy_atom(op, dtype, num_bits_per_copy=bits_per_copy)

    threads_per_row = major_mode_size // elems_per_copy
    assert num_threads % threads_per_row == 0
    rows_of_threads = num_threads // threads_per_row

    # NOTE(review): order=(1, 0) presumably makes the second (per-row) mode the
    # fastest-varying one, i.e. consecutive thread ids sit in the same row.
    # Confirm against cute.make_ordered_layout.
    thread_layout = cute.make_ordered_layout(
        (rows_of_threads, threads_per_row),
        order=(1, 0),
    )
    # Each thread owns a single 1 x elems_per_copy vector.
    value_layout = cute.make_layout((1, elems_per_copy))
    return cute.make_tiled_copy_tv(atom, thread_layout, value_layout)

# TODO(umiswing): Support calling the Quack kernel directly
# from Paddle rather than duplicating it here.
def quack_tiled_copy_2d(
dtype: Type[cutlass.Numeric],
threads_per_row: int,
num_threads: int,
Expand All @@ -107,7 +125,6 @@ def tiled_copy_2d(
val_layout = cute.make_layout((1, num_copy_elems))
return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)


@dsl_user_op
def atomic_add_fp32x4(
a: Float32, b: Float32, c: Float32, d: Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def _setup_attributes(self):

num_copy_elems = 128 // self.dtype.width
threads_per_row = math.gcd(128, self.tile_hdim) // num_copy_elems
self.gmem_tiled_copy_dQ = copy_utils.tiled_copy_2d(
self.gmem_tiled_copy_dQ = copy_utils.quack_tiled_copy_2d(
self.dtype, threads_per_row, self.num_threads, num_copy_elems
)
# ///////////////////////////////////////////////////////////////////////////////
Expand Down
4 changes: 2 additions & 2 deletions flashmask/flash_mask/flash_attn_v4/flash_bwd_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def _setup_attributes(self):
)
num_copy_elems = 128 // self.dtype.width
threads_per_row = gmem_k_block_size // num_copy_elems
self.gmem_tiled_copy_O = copy_utils.tiled_copy_2d(
self.gmem_tiled_copy_O = copy_utils.quack_tiled_copy_2d(
self.dtype, threads_per_row, self.num_threads, num_copy_elems
)
universal_copy_bits = 128
Expand Down Expand Up @@ -367,4 +367,4 @@ def kernel(
gLSElog2 = cute.local_tile(mLSElog2_cur, (self.tile_m,), (m_block,))
LOG2_E = math.log2(math.e)
if tidx < seqlen_q_rounded - m_block * self.tile_m:
gLSElog2[tidx] = lse * LOG2_E if lse != -Float32.inf else 0.0
gLSElog2[tidx] = lse * LOG2_E if lse != -Float32.inf else 0.0
2 changes: 1 addition & 1 deletion flashmask/flash_mask/flash_attn_v4/flash_fwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ def __call__(
async_copy_elems = 128 // self.q_dtype.width
num_load_threads = cute.arch.WARP_SIZE * len(self.load_warp_ids)
threads_per_row = math.gcd(self.head_dim_padded // async_copy_elems, num_load_threads)
gmem_tiled_copy_Q = copy_utils.tiled_copy_2d(
gmem_tiled_copy_Q = copy_utils.quack_tiled_copy_2d(
self.q_dtype, threads_per_row, num_load_threads, async_copy_elems, is_async=True
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from flash_mask.flash_attn_v4.sm100_hd256_2cta_fmha_backward_dkdvkernel import (
BlackwellFusedMultiHeadAttentionBackwardDKDVKernel,
)
from flash_mask.flash_attn_v4.cute_dsl_utils import assume_tensor_aligned


def _as_bshkrd_tensor(
tensor: cute.Tensor,
Expand Down Expand Up @@ -252,6 +254,8 @@ def __call__(
else:
b = Q.shape[0]

Q, K, V, dQ, dK, dV, dO = [assume_tensor_aligned(t) for t in (Q, K, V, dQ, dK, dV, dO)]

Q = _as_bshkrd_tensor(Q, h_k, h_r, varlen)
K = _as_bshkrd_tensor(K, h_k, 1, varlen)
V = _as_bshkrd_tensor(V, h_k, 1, varlen)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
Sm100FmhaStaticTileSchedulerParams as FmhaStaticTileSchedulerParams,
)

import flash_mask.flash_attn_v4.copy_utils as fa_copy_utils

LAYOUT_RANK_CONSTANT = 3

Expand Down Expand Up @@ -2811,38 +2812,53 @@ def epilogue_clear(
dK.iterator + mdK_offset,
cute.make_layout((K, self.tile_shape_dQ_K, HB), stride=dK.stride),
)
gdK = cute.local_tile(
mdK, (self.dSQ_mma_tiler[0], self.dSQ_mma_tiler[1]), (None, None, None)
)
gdK = cute.local_tile(mdK, (self.cta_tiler[1], self.cta_tiler[2]), (None, None, None))
gdK = gdK[None, None, blk_coord_k, 0, blk_coord_batch]
cdK = cute.domain_offset(
(blk_coord_k * self.tile_shape_K, 0),
cute.make_identity_tensor((self.dSQ_mma_tiler[0], self.dSQ_mma_tiler[1])),
cute.make_identity_tensor((self.cta_tiler[1], self.cta_tiler[2])),
)

mdV_offset = cute.assume(blk_offset[1] * dV.stride[0], divby=64)
mdV = cute.make_tensor(
dV.iterator + mdV_offset,
cute.make_layout((K, self.tile_shape_dV_dO, HB), stride=dV.stride),
)
gdV = cute.local_tile(
mdV, (self.PdO_mma_tiler[0], self.PdO_mma_tiler[1]), (None, None, None)
)
gdV = cute.local_tile(mdV, (self.cta_tiler[1], self.cta_tiler[2]), (None, None, None))
gdV = gdV[None, None, blk_coord_k, 0, blk_coord_batch]
cdV = cute.domain_offset(
(blk_coord_k * self.tile_shape_K, 0),
cute.make_identity_tensor((self.PdO_mma_tiler[0], self.PdO_mma_tiler[1])),
cute.make_identity_tensor((self.cta_tiler[1], self.cta_tiler[2])),
)

for i in cutlass.range(tidx * 8, cute.size(gdK), block_dim_x * 8):
if cute.elem_less(cdK[i], cute.select(problem_shape, mode=[1, 2])):
gdK_i = cute.make_tensor(gdK.iterator + cute.assume(i, divby=8), (8))
gdK_i.fill(0)
num_zero_epi_threads = 256

tiled_copy_r2g = fa_copy_utils.tiled_copy_2d(
dK.element_type, self.cta_tiler[2], num_zero_epi_threads
)

thr_copy_r2g = tiled_copy_r2g.get_slice(tidx)

tRG_gdK = thr_copy_r2g.partition_D(gdK)
tRG_cdK = thr_copy_r2g.partition_D(cdK)
tRG_gdV = thr_copy_r2g.partition_D(gdV)
tRG_cdV = thr_copy_r2g.partition_D(cdV)

zero_frg = cute.make_rmem_tensor_like(tRG_gdK[None, 0, None])
zero_frg.fill(dK.element_type(0.0))

# check we don't need zero fragment duplication
V_frg_size = cute.size(tRG_gdV[None, 0, None])
assert cute.size(zero_frg) == V_frg_size

if tidx < num_zero_epi_threads:
for n in cutlass.range(cute.size(tRG_gdK.shape[1]), unroll_full=True):
if cute.elem_less(tRG_cdK[0, n, 0][0], problem_shape[1]):
cute.copy(tiled_copy_r2g, zero_frg, tRG_gdK[None, n, None])

for i in cutlass.range(tidx * 8, cute.size(gdV), block_dim_x * 8):
if cute.elem_less(cdV[i], cute.select(problem_shape, mode=[1, 2])):
gdV_i = cute.make_tensor(gdV.iterator + cute.assume(i, divby=8), (8))
gdV_i.fill(0)
for n in cutlass.range(cute.size(tRG_gdV.shape[1]), unroll_full=True):
if cute.elem_less(tRG_cdV[0, n, 0][0], problem_shape[1]):
cute.copy(tiled_copy_r2g, zero_frg, tRG_gdV[None, n, None])

@cute.jit
def epilogue(
Expand Down
Loading