Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 37 additions & 29 deletions csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h
Original file line number Diff line number Diff line change
Expand Up @@ -1270,7 +1270,9 @@ struct Smem_tile_mma {
// fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y);
// fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w);
// size_t offset = smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset = smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset =
smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW +
ni * WARPS_N * 16 * BYTES_PER_ELT;
fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);
fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);
offset ^= 4 * BYTES_PER_STS;
Expand Down Expand Up @@ -1333,7 +1335,9 @@ struct Smem_tile_mma_transposed : public Base {
uint4 dst;
// fmha::ldsmt(dst, this->smem_ + offset);
// size_t offset = smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset = smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset =
smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW +
ni * WARPS_N * 16 * BYTES_PER_ELT;
fmha::ldsmt(dst, offset);
frag[mi][ni].reg(0) = dst.x;
frag[mi][ni].reg(1) = dst.z; // Fragment A regs col major!
Expand Down Expand Up @@ -1413,7 +1417,8 @@ struct Smem_tile_mma_epilogue : public Base {
// fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, y);
// fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, w);
// size_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
uint32_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
uint32_t offset =
(this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("mi = %d, ni = %d, offset - smem_write_ = %d\n", mi, ni, offset - this->smem_write_);
// }
Expand All @@ -1431,7 +1436,8 @@ struct Smem_tile_mma_epilogue : public Base {
for( int mi = 0; mi < M; mi++ ) {
for( int ni = 0; ni < N; ni++ ) {
// size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
uint32_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
uint32_t offset =
(this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);
fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);
offset ^= 4 * Base::BYTES_PER_STS;
Expand Down Expand Up @@ -1485,39 +1491,50 @@ struct Smem_tile_transpose {
write_col ^= (write_row & 0x07) * 4;

write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
// smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;

int read_row, read_col;
read_row = (tidx & 0x0f);
read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16;

read_col ^= (read_row & 0x07);
read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
// smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
}

inline __device__ void store_fragment(uint32_t base_offset, const Fragment_write &frag) {
uint32_t offset = base_offset;
const uint32_t reg0 = frag.reg(0);
const uint32_t reg1 = frag.reg(1);
const uint32_t reg2 = frag.reg(2);
const uint32_t reg3 = frag.reg(3);
fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, reg0);
fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, reg2);
offset ^= 4 * BYTES_PER_STS;
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个offset的变换是不是和修改前不等价

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, reg1);
fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, reg3);
}

inline __device__ uint4 load_fragment(uint32_t offset) {
uint4 dst;
fmha::ldsmt(dst, smem_ + offset);
return dst;
}

template<int M, int N>
inline __device__ void store(const Fragment_write (&frag_w)[M][N], int mi) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// size_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(0));
fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(2));
offset ^= 4 * BYTES_PER_STS;
fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(1));
fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(3));
const uint32_t base =
write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
store_fragment(base, frag_w[ni][mi]);
}
}

template<int N>
inline __device__ void load(Fragment_read (&frag_r)[N]) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint4 dst;
fmha::ldsmt(dst, this->smem_ + offset);
const uint4 dst = load_fragment(offset);
frag_r[ni].reg(0) = dst.x;
frag_r[ni].reg(1) = dst.y; // Fragment B regs col major!
frag_r[ni].reg(2) = dst.z;
Expand All @@ -1530,21 +1547,14 @@ struct Smem_tile_transpose {
static_assert(COLS == Cta_tile::N);
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// size_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(0));
fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(2));
offset ^= 4 * BYTES_PER_STS;
fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(1));
fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(3));
const uint32_t base =
write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
store_fragment(base, frag_w[ni][mi]);
}
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
// size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint4 dst;
fmha::ldsmt(dst, this->smem_ + offset);
const uint4 dst = load_fragment(offset);
frag_r[ni].reg(0) = dst.x;
frag_r[ni].reg(1) = dst.y; // Fragment B regs col major!
frag_r[ni].reg(2) = dst.z;
Expand All @@ -1555,8 +1565,6 @@ struct Smem_tile_transpose {
uint32_t smem_;
uint32_t write_offset_;
uint32_t read_offset_;
// uint32_t smem_write_;
// uint32_t smem_read_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down