diff --git a/csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h b/csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h index 491253bb999..0f9de4aae26 100644 --- a/csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h +++ b/csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h @@ -1270,7 +1270,9 @@ struct Smem_tile_mma { // fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y); // fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w); // size_t offset = smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint32_t offset = smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; + uint32_t offset = + smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + + ni * WARPS_N * 16 * BYTES_PER_ELT; fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].x); fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].z); offset ^= 4 * BYTES_PER_STS; @@ -1333,7 +1335,9 @@ struct Smem_tile_mma_transposed : public Base { uint4 dst; // fmha::ldsmt(dst, this->smem_ + offset); // size_t offset = smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint32_t offset = smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; + uint32_t offset = + smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + + ni * WARPS_N * 16 * BYTES_PER_ELT; fmha::ldsmt(dst, offset); frag[mi][ni].reg(0) = dst.x; frag[mi][ni].reg(1) = dst.z; // Fragment A regs col major! @@ -1413,7 +1417,8 @@ struct Smem_tile_mma_epilogue : public Base { // fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, y); // fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, w); // size_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; - uint32_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; + uint32_t offset = + (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { // printf("mi = %d, ni = %d, offset - smem_write_ = %d\n", mi, ni, offset - this->smem_write_); // } @@ -1431,7 +1436,8 @@ struct Smem_tile_mma_epilogue : public Base { for( int mi = 0; mi < M; mi++ ) { for( int ni = 0; ni < N; ni++ ) { // size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; - uint32_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; + uint32_t offset = + (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x); fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z); offset ^= 4 * Base::BYTES_PER_STS; @@ -1485,7 +1491,6 @@ struct Smem_tile_transpose { write_col ^= (write_row & 0x07) * 4; write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; - // smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; int read_row, read_col; read_row = (tidx & 0x0f); @@ -1493,20 +1498,34 @@ struct Smem_tile_transpose { read_col ^= (read_row & 0x07); read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; - // smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + } + + inline __device__ void store_fragment(uint32_t base_offset, const Fragment_write &frag) { + uint32_t offset = base_offset; + const uint32_t reg0 = frag.reg(0); + const uint32_t reg1 = frag.reg(1); + const uint32_t reg2 = frag.reg(2); + const uint32_t reg3 = frag.reg(3); + fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, reg0); + fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, reg2); + offset ^= 4 * BYTES_PER_STS; + fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, reg1); + fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, reg3); + } + + inline __device__ uint4 load_fragment(uint32_t offset) { + uint4 dst; + fmha::ldsmt(dst, smem_ + offset); + return dst; } template inline __device__ void store(const Fragment_write (&frag_w)[M][N], int mi) { #pragma unroll for( int ni = 0; ni < N; ni++ ) { - // size_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint32_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(0)); - fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(2)); - offset ^= 4 * BYTES_PER_STS; - fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(1)); - fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(3)); + const uint32_t base = + write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; + store_fragment(base, frag_w[ni][mi]); } } @@ -1514,10 +1533,8 @@ struct Smem_tile_transpose { inline __device__ void load(Fragment_read (&frag_r)[N]) { #pragma unroll for( int ni = 0; ni < N; ni++ ) { - // size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; uint32_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint4 dst; - fmha::ldsmt(dst, this->smem_ + offset); + const uint4 dst = load_fragment(offset); frag_r[ni].reg(0) = dst.x; frag_r[ni].reg(1) = dst.y; // Fragment B regs col major! frag_r[ni].reg(2) = dst.z; @@ -1530,21 +1547,14 @@ struct Smem_tile_transpose { static_assert(COLS == Cta_tile::N); #pragma unroll for( int ni = 0; ni < N; ni++ ) { - // size_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint32_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(0)); - fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(2)); - offset ^= 4 * BYTES_PER_STS; - fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(1)); - fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(3)); + const uint32_t base = + write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; + store_fragment(base, frag_w[ni][mi]); } #pragma unroll for( int ni = 0; ni < N; ni++ ) { - // size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - // size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; uint32_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint4 dst; - fmha::ldsmt(dst, this->smem_ + offset); + const uint4 dst = load_fragment(offset); frag_r[ni].reg(0) = dst.x; frag_r[ni].reg(1) = dst.y; // Fragment B regs col major! frag_r[ni].reg(2) = dst.z; @@ -1555,8 +1565,6 @@ struct Smem_tile_transpose { uint32_t smem_; uint32_t write_offset_; uint32_t read_offset_; - // uint32_t smem_write_; - // uint32_t smem_read_; }; ////////////////////////////////////////////////////////////////////////////////////////////////////