From 598f152cb033be01574cafeb8ad48e33b705b978 Mon Sep 17 00:00:00 2001 From: GuoxiaWang Date: Wed, 28 Aug 2024 23:00:27 +0800 Subject: [PATCH] optimize skip block calculate in bwd --- csrc/flash_attn/src/flash_bwd_kernel.h | 88 +++++++++++++------ .../src/flash_bwd_launch_template.h | 4 - 2 files changed, 62 insertions(+), 30 deletions(-) diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index ae4e1bf35c0..6c1d681a3d5 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -75,7 +75,7 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom const& copy_atom, template -inline __device__ void dot_do_o(Tensor const &do_, Tensor const &o, +__forceinline__ __device__ void dot_do_o(Tensor const &do_, Tensor const &o, Tensor &dP_sum, Tensor &sdPsum, const int gdP_col_stride, const float scale) { static_assert(Layout0::rank == 3, "Only support 3D Tensor"); @@ -425,7 +425,7 @@ inline __device__ void convert_dKV(const Params ¶ms) { //////////////////////////////////////////////////////////////////////////////////////////////////// template -inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const int bidb, const int bidh, const int n_block) { +__forceinline__ __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const int bidb, const int bidh, const int n_block) { const bool Is_sparse_attn_mask = params.flashmask_downstart_ptr != nullptr; int flashmask_startrow = 0; @@ -488,9 +488,32 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in const bool flashmask_has_end = params.flashmask_downend_ptr != nullptr; int flashmask_upendrow = params.seqlen_q; +#define SPARSE_MASKED_DOWN \ + (((m_block * kBlockM) >= flashmask_downstartmax) && (!flashmask_has_end || (m_block + 1) * kBlockM < flashmask_downendmin)) + +#define SPARSE_MASKED_UP \ + (!Is_causal && (m_block + 1) * kBlockM < flashmask_upendmin && (!flashmask_has_end || m_block * kBlockM >= 
flashmask_upstartmax)) + +#define SPARSE_MASKED \ + (SPARSE_MASKED_DOWN || SPARSE_MASKED_UP) + const bool enable_mask_bypass = params.enable_mask_bypass; - if (Is_sparse_attn_mask && enable_mask_bypass) { + int flashmask_downstartmax = std::numeric_limits::max(); + int flashmask_downendmin = 0; + int flashmask_upendmin = 0; + int flashmask_upstartmax = std::numeric_limits::max(); + + if(params.flashmask_downstart_nblockmax != nullptr) + flashmask_downstartmax = gSparseMaskDownMax[n_block]; + if(params.flashmask_downend_nblockmin != nullptr) + flashmask_downendmin = gSparseMaskDownEndMin[n_block]; + if(params.flashmask_upend_nblockmin != nullptr) + flashmask_upendmin = gSparseMaskUpMin[n_block]; + if(params.flashmask_upstart_nblockmax != nullptr) + flashmask_upstartmax = gSparseMaskUpStartMax[n_block]; + + if (Is_sparse_attn_mask && enable_mask_bypass && !flashmask_has_end) { m_block_max = min(m_block_max, cute::ceil_div(gSparseMaskDownMax[n_block], kBlockM)); /* @@ -744,7 +767,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in int m_block = m_block_max - 1; int m_block_min = !Is_causal ? 
0 : (n_block * kBlockN) / kBlockM; - if(Is_sparse_attn_mask && enable_mask_bypass){ + if(Is_sparse_attn_mask && enable_mask_bypass && !flashmask_has_end){ if (!Is_causal) { m_block_min = max(m_block_min, gSparseMaskUpMin[n_block] / kBlockM); } @@ -922,8 +945,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in // cute::copy(smem_tiled_copy_KV, tSsK(_, _, k), tSrK_copy_view(_, _, k)); // } // if (cute::thread0()) { print(tSrK); } - flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp, - smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV); + + if (!SPARSE_MASKED) { + flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp, + smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV); + } // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N)) Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); @@ -1005,7 +1031,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in } // if (cute::thread(32, 0)) { print(scores); } // Compute the exponential value. - flash::scale_apply_exp2(scores, lse, params.scale_softmax_log2); + if (!SPARSE_MASKED) { + flash::scale_apply_exp2(scores, lse, params.scale_softmax_log2); + } if (Is_dropout) { uint32_t warp_id = tidx / 32; uint32_t block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS; @@ -1048,21 +1076,23 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in // if (cute::thread0()) { print(dP_sum); } - flash::gemm( - acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp, - smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV - ); - - // Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N)) Tensor dS = make_tensor(acc_dp.data(), scores.layout()); auto pointwise_mult = [](float p, float dp, float d) { return p * (!Is_dropout || p >= 0 ? 
dp - d : d); - }; - #pragma unroll - for (int mi = 0; mi < size<0>(dS); ++mi) { + if (!SPARSE_MASKED) { + flash::gemm( + acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp, + smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV + ); + + // Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N)) + auto pointwise_mult = [](float p, float dp, float d) { + return p * (!Is_dropout || p >= 0 ? dp - d : d); + }; #pragma unroll - for (int ni = 0; ni < size<1>(dS); ++ni) { - dS(mi, ni) = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi)); + for (int mi = 0; mi < size<0>(dS); ++mi) { + #pragma unroll + for (int ni = 0; ni < size<1>(dS); ++ni) { + dS(mi, ni) = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi)); + } } } // if (cute::thread0()) { print(dS); } @@ -1104,8 +1134,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in // flash::gemm_A_in_regs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt); // Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout()); // flash::gemm_A_in_regs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt); - flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv, - smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt); + if (!SPARSE_MASKED) { + flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv, + smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt); + } // if (cute::thread0() && n_block == 0 && m_block == 0) { print(tdVrPt); } // if (cute::thread0()) { print(acc_dv); } @@ -1124,8 +1156,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in } } - flash::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq, - smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt); + if (!SPARSE_MASKED) { + flash::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq, + smem_tiled_copy_dS, 
smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt); + } // if (cute::thread0()) { print(acc_dq); } if (m_block > m_block_min) { @@ -1163,8 +1197,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ); } - flash::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv, - smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt); + if (!SPARSE_MASKED) { + flash::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv, + smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt); + } // if (cute::thread0()) { print(acc_dk); } if (Double_buffer) { // Double buffer for sQ tdKsQt.data() = tdKsQt.data() + (m_block % 2 == 0 ? size(sQ) : -size(sQ)); diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h index 749c47395b6..1dd48138e4e 100644 --- a/csrc/flash_attn/src/flash_bwd_launch_template.h +++ b/csrc/flash_attn/src/flash_bwd_launch_template.h @@ -64,10 +64,6 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream, const bool is_attn_mask = params.attn_mask_ptr != nullptr; const bool is_deterministic = params.num_splits == 1; // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv); - if (params.flashmask_downend_ptr != nullptr) { - // bypass is not supported for flashmask_downend - params.enable_mask_bypass = false; - } prepare_sparsemask(params, stream); BOOL_SWITCH(params.is_causal, IsCausalConst, [&] { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {