From e0e68b28c98c0c949679d0f5c4774cff91c60fc4 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Mon, 25 May 2026 11:07:56 +0800 Subject: [PATCH 1/3] [test] support cuda132 build --- .../sm90_pipeline_no_cluster.hpp | 6 +- .../src/fmha/smem_tile.h | 66 +++++++++++-------- .../flashmask_v2/sm90_pipeline_no_cluster.hpp | 6 +- 3 files changed, 45 insertions(+), 33 deletions(-) diff --git a/csrc/flash_attn_v3/sm90_pipeline_no_cluster.hpp b/csrc/flash_attn_v3/sm90_pipeline_no_cluster.hpp index 65a3d1554b3..c24f6150baa 100644 --- a/csrc/flash_attn_v3/sm90_pipeline_no_cluster.hpp +++ b/csrc/flash_attn_v3/sm90_pipeline_no_cluster.hpp @@ -17,7 +17,8 @@ using namespace cute; // forward pass (especially hdim 128 causal). We instead reimplement the version of // PipelineTmaAsync before v3.6.0 where only 1 out of 128 threads signals the barrier. // -// Assumption: params.num_consumers % NumThreadsPerWarpGroup == 0 +// Count consumers in whole warpgroups. A single consumer warp still needs one +// mbarrier arrival count. template > class PipelineTmaAsyncNoCluster: public Base { public: @@ -39,7 +40,8 @@ class PipelineTmaAsyncNoCluster: public Base { if (is_initializing_warp) { // Barrier FULL and EMPTY init constexpr int producer_arv_cnt = 1; - uint32_t const num_consumer_warpgroups_per_cluster = params.num_consumers / NumThreadsPerWarpGroup; + uint32_t const num_consumer_warpgroups_per_cluster = + (params.num_consumers + NumThreadsPerWarpGroup - 1) / NumThreadsPerWarpGroup; uint32_t const multicast_consumer_arrival_count = num_consumer_warpgroups_per_cluster; cutlass::arch::detail::initialize_barrier_array_pair_aligned( diff --git a/csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h b/csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h index 491253bb999..85900e5a2a0 100644 --- a/csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h +++ b/csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h @@ -1270,7 +1270,9 @@ struct Smem_tile_mma { // fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y); // fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w); // size_t offset = smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint32_t offset = smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; + uint32_t offset = + smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + + ni * WARPS_N * 16 * BYTES_PER_ELT; fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].x); fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].z); offset ^= 4 * BYTES_PER_STS; @@ -1333,7 +1335,9 @@ struct Smem_tile_mma_transposed : public Base { uint4 dst; // fmha::ldsmt(dst, this->smem_ + offset); // size_t offset = smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint32_t offset = smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; + uint32_t offset = + smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + + ni * WARPS_N * 16 * BYTES_PER_ELT; fmha::ldsmt(dst, offset); frag[mi][ni].reg(0) = dst.x; frag[mi][ni].reg(1) = dst.z; // Fragment A regs col major! @@ -1413,7 +1417,8 @@ struct Smem_tile_mma_epilogue : public Base { // fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, y); // fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, w); // size_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; - uint32_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; + uint32_t offset = + (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { // printf("mi = %d, ni = %d, offset - smem_write_ = %d\n", mi, ni, offset - this->smem_write_); // } @@ -1431,7 +1436,8 @@ struct Smem_tile_mma_epilogue : public Base { for( int mi = 0; mi < M; mi++ ) { for( int ni = 0; ni < N; ni++ ) { // size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; - uint32_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; + uint32_t offset = + (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x); fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z); offset ^= 4 * Base::BYTES_PER_STS; @@ -1485,7 +1491,6 @@ struct Smem_tile_transpose { write_col ^= (write_row & 0x07) * 4; write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; - // smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; int read_row, read_col; read_row = (tidx & 0x0f); @@ -1493,20 +1498,34 @@ struct Smem_tile_transpose { read_col ^= (read_row & 0x07); read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; - // smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + } + + inline __device__ void store_fragment(uint32_t base_offset, const Fragment_write &frag) { + uint32_t offset = smem_ + base_offset; + const uint32_t reg0 = frag.reg(0); + const uint32_t reg1 = frag.reg(1); + const uint32_t reg2 = frag.reg(2); + const uint32_t reg3 = frag.reg(3); + fmha::sts(offset + 0 * BYTES_PER_ROW, reg0); + fmha::sts(offset + 8 * BYTES_PER_ROW, reg2); + offset ^= 4 * BYTES_PER_STS; + fmha::sts(offset + 0 * BYTES_PER_ROW, reg1); + fmha::sts(offset + 8 * BYTES_PER_ROW, reg3); + } + + inline __device__ uint4 load_fragment(uint32_t offset) { + uint4 dst; + fmha::ldsmt(dst, smem_ + offset); + return dst; } template inline __device__ void store(const Fragment_write (&frag_w)[M][N], int mi) { #pragma unroll for( int ni = 0; ni < N; ni++ ) { - // size_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint32_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(0)); - fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(2)); - offset ^= 4 * BYTES_PER_STS; - fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(1)); - fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(3)); + const uint32_t base = + write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; + store_fragment(base, frag_w[ni][mi]); } } @@ -1514,10 +1533,8 @@ struct Smem_tile_transpose { inline __device__ void load(Fragment_read (&frag_r)[N]) { #pragma unroll for( int ni = 0; ni < N; ni++ ) { - // size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; uint32_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint4 dst; - fmha::ldsmt(dst, this->smem_ + offset); + const uint4 dst = load_fragment(offset); frag_r[ni].reg(0) = dst.x; frag_r[ni].reg(1) = dst.y; // Fragment B regs col major! frag_r[ni].reg(2) = dst.z; @@ -1530,21 +1547,14 @@ struct Smem_tile_transpose { static_assert(COLS == Cta_tile::N); #pragma unroll for( int ni = 0; ni < N; ni++ ) { - // size_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint32_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(0)); - fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(2)); - offset ^= 4 * BYTES_PER_STS; - fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(1)); - fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(3)); + const uint32_t base = + write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; + store_fragment(base, frag_w[ni][mi]); } #pragma unroll for( int ni = 0; ni < N; ni++ ) { - // size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - // size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; uint32_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint4 dst; - fmha::ldsmt(dst, this->smem_ + offset); + const uint4 dst = load_fragment(offset); frag_r[ni].reg(0) = dst.x; frag_r[ni].reg(1) = dst.y; // Fragment B regs col major! frag_r[ni].reg(2) = dst.z; @@ -1555,8 +1565,6 @@ struct Smem_tile_transpose { uint32_t smem_; uint32_t write_offset_; uint32_t read_offset_; - // uint32_t smem_write_; - // uint32_t smem_read_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/flashmask_v2/sm90_pipeline_no_cluster.hpp b/csrc/flashmask_v2/sm90_pipeline_no_cluster.hpp index 65a3d1554b3..c24f6150baa 100644 --- a/csrc/flashmask_v2/sm90_pipeline_no_cluster.hpp +++ b/csrc/flashmask_v2/sm90_pipeline_no_cluster.hpp @@ -17,7 +17,8 @@ using namespace cute; // forward pass (especially hdim 128 causal). We instead reimplement the version of // PipelineTmaAsync before v3.6.0 where only 1 out of 128 threads signals the barrier. // -// Assumption: params.num_consumers % NumThreadsPerWarpGroup == 0 +// Count consumers in whole warpgroups. A single consumer warp still needs one +// mbarrier arrival count. template > class PipelineTmaAsyncNoCluster: public Base { public: @@ -39,7 +40,8 @@ class PipelineTmaAsyncNoCluster: public Base { if (is_initializing_warp) { // Barrier FULL and EMPTY init constexpr int producer_arv_cnt = 1; - uint32_t const num_consumer_warpgroups_per_cluster = params.num_consumers / NumThreadsPerWarpGroup; + uint32_t const num_consumer_warpgroups_per_cluster = + (params.num_consumers + NumThreadsPerWarpGroup - 1) / NumThreadsPerWarpGroup; uint32_t const multicast_consumer_arrival_count = num_consumer_warpgroups_per_cluster; cutlass::arch::detail::initialize_barrier_array_pair_aligned( From bda9b377eaa261158d8ee2c7c793191e6a7db245 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Mon, 25 May 2026 23:57:41 +0800 Subject: [PATCH 2/3] rollback --- csrc/flash_attn_v3/sm90_pipeline_no_cluster.hpp | 6 ++---- csrc/flashmask_v2/sm90_pipeline_no_cluster.hpp | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/csrc/flash_attn_v3/sm90_pipeline_no_cluster.hpp b/csrc/flash_attn_v3/sm90_pipeline_no_cluster.hpp index c24f6150baa..65a3d1554b3 100644 --- a/csrc/flash_attn_v3/sm90_pipeline_no_cluster.hpp +++ b/csrc/flash_attn_v3/sm90_pipeline_no_cluster.hpp @@ -17,8 +17,7 @@ using namespace cute; // forward pass (especially hdim 128 causal). We instead reimplement the version of // PipelineTmaAsync before v3.6.0 where only 1 out of 128 threads signals the barrier. // -// Count consumers in whole warpgroups. A single consumer warp still needs one -// mbarrier arrival count. +// Assumption: params.num_consumers % NumThreadsPerWarpGroup == 0 template > class PipelineTmaAsyncNoCluster: public Base { public: @@ -40,8 +39,7 @@ class PipelineTmaAsyncNoCluster: public Base { if (is_initializing_warp) { // Barrier FULL and EMPTY init constexpr int producer_arv_cnt = 1; - uint32_t const num_consumer_warpgroups_per_cluster = - (params.num_consumers + NumThreadsPerWarpGroup - 1) / NumThreadsPerWarpGroup; + uint32_t const num_consumer_warpgroups_per_cluster = params.num_consumers / NumThreadsPerWarpGroup; uint32_t const multicast_consumer_arrival_count = num_consumer_warpgroups_per_cluster; cutlass::arch::detail::initialize_barrier_array_pair_aligned( diff --git a/csrc/flashmask_v2/sm90_pipeline_no_cluster.hpp b/csrc/flashmask_v2/sm90_pipeline_no_cluster.hpp index c24f6150baa..65a3d1554b3 100644 --- a/csrc/flashmask_v2/sm90_pipeline_no_cluster.hpp +++ b/csrc/flashmask_v2/sm90_pipeline_no_cluster.hpp @@ -17,8 +17,7 @@ using namespace cute; // forward pass (especially hdim 128 causal). We instead reimplement the version of // PipelineTmaAsync before v3.6.0 where only 1 out of 128 threads signals the barrier. // -// Count consumers in whole warpgroups. A single consumer warp still needs one -// mbarrier arrival count. +// Assumption: params.num_consumers % NumThreadsPerWarpGroup == 0 template > class PipelineTmaAsyncNoCluster: public Base { public: @@ -40,8 +39,7 @@ class PipelineTmaAsyncNoCluster: public Base { if (is_initializing_warp) { // Barrier FULL and EMPTY init constexpr int producer_arv_cnt = 1; - uint32_t const num_consumer_warpgroups_per_cluster = - (params.num_consumers + NumThreadsPerWarpGroup - 1) / NumThreadsPerWarpGroup; + uint32_t const num_consumer_warpgroups_per_cluster = params.num_consumers / NumThreadsPerWarpGroup; uint32_t const multicast_consumer_arrival_count = num_consumer_warpgroups_per_cluster; cutlass::arch::detail::initialize_barrier_array_pair_aligned( From 5518810894de1fbb5fe8f415a9436a4287f316d9 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Tue, 26 May 2026 22:12:31 +0800 Subject: [PATCH 3/3] fix store_fragment The incorrect calculation of the memory offset in the code --- .../flash_attn_with_bias_and_mask/src/fmha/smem_tile.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h b/csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h index 85900e5a2a0..0f9de4aae26 100644 --- a/csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h +++ b/csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h @@ -1501,16 +1501,16 @@ struct Smem_tile_transpose { } inline __device__ void store_fragment(uint32_t base_offset, const Fragment_write &frag) { - uint32_t offset = smem_ + base_offset; + uint32_t offset = base_offset; const uint32_t reg0 = frag.reg(0); const uint32_t reg1 = frag.reg(1); const uint32_t reg2 = frag.reg(2); const uint32_t reg3 = frag.reg(3); - fmha::sts(offset + 0 * BYTES_PER_ROW, reg0); - fmha::sts(offset + 8 * BYTES_PER_ROW, reg2); + fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, reg0); + fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, reg2); offset ^= 4 * BYTES_PER_STS; - fmha::sts(offset + 0 * BYTES_PER_ROW, reg1); - fmha::sts(offset + 8 * BYTES_PER_ROW, reg3); + fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, reg1); + fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, reg3); } inline __device__ uint4 load_fragment(uint32_t offset) {