Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 19 additions & 73 deletions flashmask/flash_mask/flash_attn_v4/copy_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.

import math
from typing import Optional, Type, Callable
Expand Down Expand Up @@ -90,6 +90,24 @@ def tiled_copy_1d(


def tiled_copy_2d(
    dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False
) -> cute.TiledCopy:
    """Create a 2D tiled copy whose per-thread vector width is derived from ``major_mode_size``.

    Args:
        dtype: Element type being copied; its ``width`` (in bits) sizes the copy atom.
        major_mode_size: Number of elements along the major mode. Each thread copies
            ``gcd(major_mode_size, 128 // dtype.width)`` elements per copy.
        num_threads: Total number of participating threads. Must be a multiple of
            ``major_mode_size // elements_per_copy``.
        is_async: Use ``cpasync.CopyG2SOp`` when True, ``cute.nvgpu.CopyUniversalOp``
            otherwise.

    Returns:
        The tiled copy produced by ``cute.make_tiled_copy_tv``.

    Raises:
        AssertionError: If ``num_threads`` is not divisible by the number of threads
            needed to cover one row.
    """
    elem_bits = dtype.width
    # Cap each copy at 128 bits while keeping the element count a divisor of the row.
    elems_per_copy = math.gcd(major_mode_size, 128 // elem_bits)
    bits_per_copy = elems_per_copy * elem_bits

    if is_async:
        op = cpasync.CopyG2SOp()
    else:
        op = cute.nvgpu.CopyUniversalOp()
    atom = cute.make_copy_atom(op, dtype, num_bits_per_copy=bits_per_copy)

    # Threads needed to cover one row of the major mode.
    threads_per_row = major_mode_size // elems_per_copy
    assert num_threads % threads_per_row == 0
    num_rows = num_threads // threads_per_row

    thread_layout = cute.make_ordered_layout((num_rows, threads_per_row), order=(1, 0))
    value_layout = cute.make_layout((1, elems_per_copy))
    return cute.make_tiled_copy_tv(atom, thread_layout, value_layout)

# TODO(umiswing): Support calling the Quack kernel directly
# from Paddle rather than duplicating it here.
def quack_tiled_copy_2d(
dtype: Type[cutlass.Numeric],
threads_per_row: int,
num_threads: int,
Expand All @@ -107,7 +125,6 @@ def tiled_copy_2d(
val_layout = cute.make_layout((1, num_copy_elems))
return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)


@dsl_user_op
def atomic_add_fp32x4(
a: Float32, b: Float32, c: Float32, d: Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None
Expand Down Expand Up @@ -372,74 +389,3 @@ def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwarg
)

return copy_fn


# --- Vendored from quack/copy_utils.py (BSD-3, Tri Dao et al.) ---
# Required by SM100 code paths that cannot depend on quack.

# Extent given to the ragged mode by create_ragged_tensor_for_tma.
BIG_INT = 1 << 30
# Extent given to the extra trailing mode(s) appended to a ragged tensor.
MAX_INT = (1 << 31) - 1
# 2**64 / BIG_INT; the non-ptr_shift layout uses (BIG_INT_INV - ragged stride)
# as the stride of its first appended mode.
BIG_INT_INV = (1 << 64) // BIG_INT


@dsl_user_op
def create_ragged_tensor_for_tma(
    T: cute.Tensor,
    ragged_dim: int = 0,
    ptr_shift: bool = False,
    *,
    loc=None,
    ip=None,
) -> cute.Tensor:
    """Return a re-laid-out view of ``T`` whose ragged mode has extent ``BIG_INT``.

    Args:
        T: Source tensor.
        ragged_dim: Index of the ragged mode; negative values count from the end.
        ptr_shift: Selects the layout variant.
            * True: the iterator is offset by ``-BIG_INT`` along the ragged mode and
              ONE trailing mode of extent ``MAX_INT`` (with the ragged stride) is
              appended. Supports rank <= 4.
            * False: the iterator is kept and TWO trailing modes of extent
              ``MAX_INT`` are appended, with strides
              ``(BIG_INT_INV - ragged_stride, ragged_stride)``. Supports rank <= 3.
        loc, ip: Not read in this body; presumably consumed by ``@dsl_user_op``
            (TODO confirm).

    Returns:
        A new tensor built with ``cute.make_tensor`` over the widened layout.

    Raises:
        AssertionError: If the rank of ``T`` exceeds the limit for the chosen variant.
    """
    ndim = cute.rank(T)
    # Normalise a negative mode index.
    ragged_dim = ragged_dim + ndim if ragged_dim < 0 else ragged_dim

    if not ptr_shift:
        assert ndim <= 3, "non-ptr_shift ragged tensor only supports up to 3 dimensions"
        ragged_stride = T.stride[ragged_dim]
        widened_shape = (
            T.shape[:ragged_dim] + (BIG_INT,) + T.shape[ragged_dim + 1 :] + (MAX_INT, MAX_INT)
        )
        widened_stride = (
            T.stride[:ragged_dim]
            + (ragged_stride,)
            + T.stride[ragged_dim + 1 :]
            + (BIG_INT_INV - ragged_stride, ragged_stride)
        )
        widened_layout = cute.make_layout(widened_shape, stride=widened_stride)
        return cute.make_tensor(T.iterator, widened_layout)

    assert ndim <= 4, "ptr_shift ragged tensor only supports up to 4 dimensions"
    widened_shape = T.shape[:ragged_dim] + (BIG_INT,) + T.shape[ragged_dim + 1 :] + (MAX_INT,)
    widened_stride = T.stride + (T.stride[ragged_dim],)
    # Move the base pointer back by BIG_INT along the ragged mode only.
    leading = (None,) * ragged_dim
    trailing = (None,) * (ndim - ragged_dim - 1)
    shifted_iter = cute.domain_offset(leading + (-BIG_INT,) + trailing, T).iterator
    widened_layout = cute.make_layout(widened_shape, stride=widened_stride)
    return cute.make_tensor(shifted_iter, widened_layout)


@dsl_user_op
def offset_ragged_tensor(
    T: cute.Tensor,
    offset: Int32,
    length: Int32,
    ragged_dim: int = 0,
    ptr_shift: bool = False,
    *,
    loc=None,
    ip=None,
) -> cute.Tensor:
    """Select the ``[offset, offset + length)`` window of a ragged tensor.

    ``T`` is expected to carry extra trailing modes: one when ``ptr_shift`` is True,
    two otherwise. The trailing mode(s) are indexed away and the result is
    domain-offset along the ragged mode by ``size(ragged mode) - length``.

    Args:
        T: Ragged tensor with the extra trailing mode(s).
        offset: Start of the window along the ragged mode.
        length: Length of the window.
        ragged_dim: Index of the ragged mode; negative values count from the end.
        ptr_shift: Must match the variant ``T`` was built with (one extra mode if
            True, two if False).
        loc, ip: Not read in this body; presumably consumed by ``@dsl_user_op``
            (TODO confirm).

    Returns:
        The sliced and domain-offset tensor.

    Raises:
        AssertionError: If ``T`` has too few modes for the chosen variant.
    """
    ndim = cute.rank(T)
    if ragged_dim < 0:
        ragged_dim += ndim
    ragged_extent = cute.size(T, mode=[ragged_dim])
    shift = ragged_extent - length

    # Number of trailing bookkeeping modes: rank = original_rank + num_extra.
    num_extra = 1 if ptr_shift else 2
    assert ndim >= ragged_dim + num_extra + 1

    domain_shift = (
        (None,) * ragged_dim + (shift,) + (None,) * (ndim - ragged_dim - num_extra - 1)
    )
    window_end = offset + length
    # With two extra modes, the first of them is indexed at the ragged extent.
    extra_coord = (window_end,) if ptr_shift else (ragged_extent, window_end)
    coord = (None,) * (ndim - num_extra) + extra_coord
    return cute.domain_offset(domain_shift, T[coord])
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def _setup_attributes(self):

num_copy_elems = 128 // self.dtype.width
threads_per_row = math.gcd(128, self.tile_hdim) // num_copy_elems
self.gmem_tiled_copy_dQ = copy_utils.tiled_copy_2d(
self.gmem_tiled_copy_dQ = copy_utils.quack_tiled_copy_2d(
self.dtype, threads_per_row, self.num_threads, num_copy_elems
)
# ///////////////////////////////////////////////////////////////////////////////
Expand Down
2 changes: 1 addition & 1 deletion flashmask/flash_mask/flash_attn_v4/flash_fwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ def __call__(
async_copy_elems = 128 // self.q_dtype.width
num_load_threads = cute.arch.WARP_SIZE * len(self.load_warp_ids)
threads_per_row = math.gcd(self.head_dim_padded // async_copy_elems, num_load_threads)
gmem_tiled_copy_Q = copy_utils.tiled_copy_2d(
gmem_tiled_copy_Q = copy_utils.quack_tiled_copy_2d(
self.q_dtype, threads_per_row, num_load_threads, async_copy_elems, is_async=True
)

Expand Down
5 changes: 4 additions & 1 deletion flashmask/flash_mask/flash_attn_v4/paddle/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2285,6 +2285,9 @@ def flash_attn_varlen_func(
min_seqlen_k: for varlen, specifies the minimum kv sequence length for any batch.
Used with gather_kv_indices to determine if we need oob masking.
"""
print(f"wsm debug {cu_seqlens_q=}")
print(f"wsm debug {cu_seqlens_k=}")
print(f"wsm debug {causal=}")
return FlashAttnVarlenFunc.apply(
q,
k,
Expand Down Expand Up @@ -2552,4 +2555,4 @@ def flash_attn_combine(
seqused,
varlen_batch_idx=varlen_batch_idx,
)
return out, lse
return out, lse
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from flash_mask.flash_attn_v4.sm100_hd256_2cta_fmha_backward_dkdvkernel import (
BlackwellFusedMultiHeadAttentionBackwardDKDVKernel,
)
from flash_mask.flash_attn_v4.cute_dsl_utils import assume_tensor_aligned


def _as_bshkrd_tensor(
tensor: cute.Tensor,
Expand Down Expand Up @@ -252,6 +254,8 @@ def __call__(
else:
b = Q.shape[0]

Q, K, V, dQ, dK, dV, dO = [assume_tensor_aligned(t) for t in (Q, K, V, dQ, dK, dV, dO)]

Q = _as_bshkrd_tensor(Q, h_k, h_r, varlen)
K = _as_bshkrd_tensor(K, h_k, 1, varlen)
V = _as_bshkrd_tensor(V, h_k, 1, varlen)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
Sm100FmhaStaticTileSchedulerParams as FmhaStaticTileSchedulerParams,
)

import flash_mask.flash_attn_v4.copy_utils as fa_copy_utils

LAYOUT_RANK_CONSTANT = 3

Expand Down Expand Up @@ -2811,38 +2812,53 @@ def epilogue_clear(
dK.iterator + mdK_offset,
cute.make_layout((K, self.tile_shape_dQ_K, HB), stride=dK.stride),
)
gdK = cute.local_tile(
mdK, (self.dSQ_mma_tiler[0], self.dSQ_mma_tiler[1]), (None, None, None)
)
gdK = cute.local_tile(mdK, (self.cta_tiler[1], self.cta_tiler[2]), (None, None, None))
gdK = gdK[None, None, blk_coord_k, 0, blk_coord_batch]
cdK = cute.domain_offset(
(blk_coord_k * self.tile_shape_K, 0),
cute.make_identity_tensor((self.dSQ_mma_tiler[0], self.dSQ_mma_tiler[1])),
cute.make_identity_tensor((self.cta_tiler[1], self.cta_tiler[2])),
)

mdV_offset = cute.assume(blk_offset[1] * dV.stride[0], divby=64)
mdV = cute.make_tensor(
dV.iterator + mdV_offset,
cute.make_layout((K, self.tile_shape_dV_dO, HB), stride=dV.stride),
)
gdV = cute.local_tile(
mdV, (self.PdO_mma_tiler[0], self.PdO_mma_tiler[1]), (None, None, None)
)
gdV = cute.local_tile(mdV, (self.cta_tiler[1], self.cta_tiler[2]), (None, None, None))
gdV = gdV[None, None, blk_coord_k, 0, blk_coord_batch]
cdV = cute.domain_offset(
(blk_coord_k * self.tile_shape_K, 0),
cute.make_identity_tensor((self.PdO_mma_tiler[0], self.PdO_mma_tiler[1])),
cute.make_identity_tensor((self.cta_tiler[1], self.cta_tiler[2])),
)

for i in cutlass.range(tidx * 8, cute.size(gdK), block_dim_x * 8):
if cute.elem_less(cdK[i], cute.select(problem_shape, mode=[1, 2])):
gdK_i = cute.make_tensor(gdK.iterator + cute.assume(i, divby=8), (8))
gdK_i.fill(0)
num_zero_epi_threads = 256

tiled_copy_r2g = fa_copy_utils.tiled_copy_2d(
dK.element_type, self.cta_tiler[2], num_zero_epi_threads
)

thr_copy_r2g = tiled_copy_r2g.get_slice(tidx)

tRG_gdK = thr_copy_r2g.partition_D(gdK)
tRG_cdK = thr_copy_r2g.partition_D(cdK)
tRG_gdV = thr_copy_r2g.partition_D(gdV)
tRG_cdV = thr_copy_r2g.partition_D(cdV)

zero_frg = cute.make_rmem_tensor_like(tRG_gdK[None, 0, None])
zero_frg.fill(dK.element_type(0.0))

# check we don't need zero fragment duplication
V_frg_size = cute.size(tRG_gdV[None, 0, None])
assert cute.size(zero_frg) == V_frg_size

if tidx < num_zero_epi_threads:
for n in cutlass.range(cute.size(tRG_gdK.shape[1]), unroll_full=True):
if cute.elem_less(tRG_cdK[0, n, 0][0], problem_shape[1]):
cute.copy(tiled_copy_r2g, zero_frg, tRG_gdK[None, n, None])

for i in cutlass.range(tidx * 8, cute.size(gdV), block_dim_x * 8):
if cute.elem_less(cdV[i], cute.select(problem_shape, mode=[1, 2])):
gdV_i = cute.make_tensor(gdV.iterator + cute.assume(i, divby=8), (8))
gdV_i.fill(0)
for n in cutlass.range(cute.size(tRG_gdV.shape[1]), unroll_full=True):
if cute.elem_less(tRG_cdV[0, n, 0][0], problem_shape[1]):
cute.copy(tiled_copy_r2g, zero_frg, tRG_gdV[None, n, None])

@cute.jit
def epilogue(
Expand Down
Loading