From c2766ac4b01b91bbffbad2b290eab729391a1edf Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Thu, 24 Jul 2025 14:36:25 +0000 Subject: [PATCH 1/9] add aiter code --- flash_attn/flash_attn_triton_amd/mha.py | 2020 +++++++++++++++++ .../flash_attn_triton_amd/mha_fused_bwd.py | 1272 +++++++++++ .../mha_onekernel_bwd.py | 1806 +++++++++++++++ 3 files changed, 5098 insertions(+) create mode 100644 flash_attn/flash_attn_triton_amd/mha.py create mode 100644 flash_attn/flash_attn_triton_amd/mha_fused_bwd.py create mode 100644 flash_attn/flash_attn_triton_amd/mha_onekernel_bwd.py diff --git a/flash_attn/flash_attn_triton_amd/mha.py b/flash_attn/flash_attn_triton_amd/mha.py new file mode 100644 index 00000000000..b425db59351 --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/mha.py @@ -0,0 +1,2020 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +from typing import Optional, Tuple +import functools +import json +import torch +import triton +import triton.language as tl + +import aiter.ops.triton.utils.arch_info as arch_info +from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH +from aiter.ops.triton.utils.pid_preprocessing import remap_xcd +from aiter.ops.triton.mha_onekernel_bwd import flash_attn_onekernel_backward +from aiter.ops.triton.mha_fused_bwd import flash_attn_fused_backward +from aiter.ops.triton.utils.mha_kernel_utils import ( + _compute_fp8_scaling_factors, + _is_fp8, +) + +global _USE_FUSED_BWD_KERNEL +_USE_FUSED_BWD_KERNEL = False + + +def mha_set_use_fused_bwd_kernel(value: bool): + global _USE_FUSED_BWD_KERNEL + _USE_FUSED_BWD_KERNEL = value + + +_USE_INT64_STRIDES = True + + +def mha_set_use_int64_strides(value: bool): + """Use 64-bit integer strides to prevent integer overflows with very large tensors.""" + global _USE_INT64_STRIDES + _USE_INT64_STRIDES = value + + +def _cast_to_fp8( + x: torch.Tensor, + fp8_dtype, + layout, + clamp_val=1e-9, +) -> Tuple[torch.Tensor, 
torch.Tensor]: + """ + Convert a tensor to FP8 format, returning an FP8 tensor and a descale factor. + Args: + - x (torch.Tensor): shape [batch, seq_len, heads, dim] + Returns: + - x_fp8 (torch.Tensor): FP8 tensor with the same shape as x + - descale_factor (torch.Tensor): tensor of shape [batch, 1, heads, 1] + """ + if len(x.shape) != 4: + raise ValueError( + f"'bshd' tensor should have shape [batch, seqlen, heads, dim], got {x.shape}" + ) + reduce_dims = (1, 3) # seq_len and dim dimensions + + # Compute the absolute max along reduce_dims, clamped to avoid 0-scale + x_abs_max = x.abs().amax(dim=reduce_dims) + x_abs_max = torch.maximum(x_abs_max, x.new_tensor(clamp_val)) + + # Unsqueeze back to a shape suitable for broadcast + unsqueeze_dims = sorted(reduce_dims) + for d in unsqueeze_dims: + x_abs_max = x_abs_max.unsqueeze(d) + + # compute scale and descale + fp8_max = torch.finfo(fp8_dtype).max + scale = fp8_max / x_abs_max + descale_factor = x_abs_max / fp8_max + + # scale into the FP8 range and cast to FP8 + x_fp8 = (x * scale).to(fp8_dtype) + + return x_fp8, descale_factor + + +def _cast_varlen_to_fp8( + x: torch.Tensor, + fp8_dtype: torch.dtype, + cu_seqlens, + clamp_val: float = 1e-9, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Convert a tensor of sequences with variable seq_len into fp8. 
+ Args: + - x (torch.Tensor): shape [total_seq_len, heads, dim] + Returns: + - x_fp8 (torch.Tensor): shape [total_seq_len, heads, dim] + - descale_factors (torch.Tensor): shape [batch, heads] + """ + # validate tensor shape + if len(x.shape) != 3: + raise ValueError( + f"tensor should have shape [total_seqlen, heads, dim], got {x.shape}" + ) + num_heads = x.shape[1] + + # Get batch size from cu_seqlens + batch = cu_seqlens.shape[0] - 1 + fp8_max = torch.finfo(fp8_dtype).max + + # Compute scale and descale factors per sequence + x_fp8 = torch.zeros_like(x, dtype=fp8_dtype) + descale_factors = torch.zeros( + (batch, num_heads), device=x.device, dtype=torch.float32 + ) + + for i in range(batch): + start = cu_seqlens[i] + end = cu_seqlens[i + 1] + x_slice = x[start:end] # Slice for current sequence + + # Standard tensor (0: seq_len, 2: head_dim) + x_abs_max = x_slice.abs().amax(dim=(0, 2)) # [heads] + + # apply minimum clamping + x_abs_max = torch.maximum(x_abs_max, x.new_tensor(clamp_val)) + + # compute scale and descale factors + scale_i = fp8_max / x_abs_max + descale_i = x_abs_max / fp8_max + + # store descale factors + descale_factors[i, :] = descale_i + + scale_reshape = scale_i.reshape(1, num_heads, 1) + + # scale and cast to FP8 + x_fp8[start:end] = (x_slice * scale_reshape).to(fp8_dtype) + + return x_fp8, descale_factors + + +@triton.jit +def _cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def _load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second): + if offset_first is not None and offset_second is not None: + mask = (offset_first[:, None] < boundary_first) & ( + offset_second[None, :] < boundary_second + ) + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_first is not None: + mask = offset_first[:, None] < boundary_first + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_second is not None: + mask = offset_second[None, :] < boundary_second + tensor = tl.load(ptrs, mask=mask, other=0.0) + else: + tensor = 
tl.load(ptrs) + return tensor + + +@triton.jit +def _compute_alibi_block( + alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False +): + # when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix + # for causal mask we want something like this where (1 is kept and 0 is masked) + # seqlen_q = 2 and seqlen_k = 5 + # 1 1 1 1 0 + # 1 1 1 1 1 + # seqlen_q = 5 and seqlen_k = 2 + # 0 0 + # 0 0 + # 0 0 + # 1 0 + # 1 1 + # for alibi the diagonal is 0 indicating no penalty for attending to that spot and increasing penalty for attending further from the diagonal + # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False + # 1. offs_m[:,None] = [[0], + # [1], + # 2. offs_m[:,None] + seqlen_k = [[5], + # [6], + # 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3], + # [4], + # 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], + # [4], [ 4, 3, 2, 1, 0]] + # 5. 
-1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1], + # [ -4, -3, -2, -1, 0]], + relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + if transpose: + return alibi_block.T + else: + return alibi_block + + +@triton.jit +def _attn_fwd_inner( + acc, + l_i, + m_i, + q, + k_ptrs, + v_ptrs, + stride_kn, + stride_vk, + stride_sn, + start_m, + seqlen_k, + seqlen_q, + dropout_p, + sd_mask_ptrs, + dropout_mask_ptrs, + philox_seed, + philox_ptrs, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + alibi_slope, + descale_q, + descale_k, + descale_v, + OFFS_M: tl.constexpr, + OFFS_N: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DMODEL_POW2: tl.constexpr, + SM_SCALE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_SCORES: tl.constexpr, + PADDED_HEAD: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + RCP_LN2: tl.constexpr = 1.4426950408889634 + + # loop over k, v, and update accumulator + + for start_n in range(block_min, block_max, BLOCK_N): + # For padded blocks, we will overrun the tensor size if + # we load all BLOCK_N. For others, the blocks are all within range. + if MASK_STEPS: + k_offs_n = start_n + tl.arange(0, BLOCK_N) + else: + k_offs_n = None + k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL_POW2) + k = _load_fn(k_ptrs, k_offs_k, k_offs_n, BLOCK_DMODEL, seqlen_k) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + # We start from end of seqlen_k so only the first iteration would need + # to be checked for padding if it is not a multiple of block_n + # TODO: This can be optimized to only be true for the padded block. 
+ mask = tl.full([BLOCK_M, BLOCK_N], True, dtype=tl.int1) + if MASK_STEPS: + # If this is the last block / iteration, we want to + # mask if the sequence length is not a multiple of block size + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. + # last step might get wasted but that is okay. check if this masking works For + # that case. + + # remove the old if condition + # if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): + # Though this will unconditionally compute mask_partial at runtime, + # the causal for loop does not have the if-else block any more, which + # helps instruction scheduling and register pressure. + bound_cond = (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0) + boundary_m = tl.full([BLOCK_M], seqlen_k, dtype=tl.int32) + size_n = start_n + OFFS_N[None, :] + mask_partial = size_n < boundary_m[:, None] + mask = tl.where(bound_cond, mask_partial, mask) + + # compute masks + q_mask = OFFS_M[:, None] < seqlen_q + k_mask = (start_n + tl.arange(0, BLOCK_N))[None, :] < seqlen_k + p_mask = q_mask & k_mask + + # -- compute qk ---- + if IS_FP8: + qk += tl.dot(q, k) * descale_q * descale_k + else: + qk += tl.dot(q, k) + + if IS_CAUSAL: + causal_boundary = start_n + offs_n_causal + causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] + mask = mask and causal_mask + + qk = tl.where(mask, qk, float("-inf")) + + if alibi_slope is not None: + # Compute the global position of each token within the sequence + global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + global_n_positions = start_n + tl.arange(0, BLOCK_N) + alibi_block = _compute_alibi_block( + alibi_slope, seqlen_q, seqlen_k, global_m_positions, global_n_positions + ) + qk += alibi_block / SM_SCALE + # get max scores so far + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + m_ij_scaled = m_ij * SM_SCALE * RCP_LN2 + + # scale and subtract max + q_shifted = qk * SM_SCALE * RCP_LN2 - m_ij_scaled[:, None] + + # Compute scaled QK and softmax 
probabilities + p = tl.math.exp2(q_shifted) + + # CAVEAT: Must update l_ij before applying dropout + l_ij = tl.sum(p, 1) + if ENABLE_DROPOUT: + rng_output = tl.rand( + philox_seed, philox_ptrs + ) # TODO: use tl.randint for better performance + dropout_mask = rng_output > dropout_p + tl.store(dropout_mask_ptrs, dropout_mask, mask=p_mask) + + # return scores with negative values for dropped vals + sd_mask = tl.where(dropout_mask, p, -p) + tl.store(sd_mask_ptrs, sd_mask, mask=p_mask) + + # apply dropout mask in place + p = tl.where(dropout_mask, p, 0.0) + elif RETURN_SCORES: + # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that + tl.store(sd_mask_ptrs, p, mask=p_mask) + + # -- update output accumulator -- + # alpha is an adjustment factor for acc and li as we loop and find new maxes + # store the diff in maxes to adjust acc and li as we discover new maxes + m_diff_scaled = m_i * SM_SCALE * RCP_LN2 - m_ij_scaled + alpha = tl.math.exp2(m_diff_scaled) + acc = acc * alpha[:, None] + v = _load_fn(v_ptrs, k_offs_n, k_offs_k, seqlen_k, BLOCK_DMODEL) + # -- update m_i and l_i + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + + if IS_FP8: + scale_p, descale_p = _compute_fp8_scaling_factors(p, FP8_MAX) + acc += ( + tl.dot((p * scale_p).to(v.type.element_ty), v) * descale_p * descale_v + ) + else: + acc += tl.dot(p.to(v.type.element_ty), v) + + k_ptrs += BLOCK_N * stride_kn + v_ptrs += BLOCK_N * stride_vk + if RETURN_SCORES: + sd_mask_ptrs += BLOCK_N * stride_sn + + if ENABLE_DROPOUT: + dropout_mask_ptrs += BLOCK_N * stride_sn + philox_ptrs += BLOCK_N * stride_sn + + return acc, l_i, m_i + + +@triton.jit +def _attn_fwd( + q_ptr: torch.Tensor, + k_ptr: torch.Tensor, + v_ptr: torch.Tensor, + descale_q_ptr: torch.Tensor, + descale_k_ptr: torch.Tensor, + descale_v_ptr: torch.Tensor, + out_ptr: torch.Tensor, + alibi_slopes_ptr: torch.Tensor, + s_dmask_ptr: torch.Tensor, + 
dropout_mask_ptr: torch.Tensor, + softmax_lse_ptr: torch.Tensor, + stride_qz_in, + stride_qh_in, + stride_qm_in, + stride_qk_in, + stride_kz_in, + stride_kh_in, + stride_kn_in, + stride_kk_in, + stride_vz_in, + stride_vh_in, + stride_vn_in, + stride_vk_in, + stride_descale_q_z_in, + stride_descale_k_z_in, + stride_descale_v_z_in, + stride_oz_in, + stride_oh_in, + stride_om_in, + stride_on_in, + stride_alibi_z_in, + stride_alibi_h_in, + stride_sd_z_in, + stride_sd_h_in, + stride_sd_m_in, + stride_sd_n_in, + stride_lse_z_in, + stride_lse_h_in, + stride_lse_m_in, + sm_scale, + cu_seqlens_q, + cu_seqlens_k, + dropout_p, + philox_seed, + philox_offset_base_in, + SEQLEN_Q: tl.constexpr, + SEQLEN_K: tl.constexpr, + IS_CAUSAL: tl.constexpr, + NUM_Q_HEADS: tl.constexpr, + NUM_K_HEADS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DMODEL_POW2: tl.constexpr, + RETURN_SCORES: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + VARLEN: tl.constexpr, + BATCH, + NUM_XCD: tl.constexpr, + USE_INT64_STRIDES: tl.constexpr, +): + NUM_BLOCKS = (SEQLEN_Q + BLOCK_M - 1) // BLOCK_M + # calculate offsets + wid = tl.program_id( + 0 + ) # workgroup id ranging: 0,1,2,...., (BATCH * NUM_Q_HEADS * NUM_BLOCKS - 1) + # num blocks along seqlen + + off_q_head = wid % NUM_Q_HEADS + off_q_head = remap_xcd(off_q_head, NUM_Q_HEADS, NUM_XCD) + start_m = (wid // NUM_Q_HEADS) % NUM_BLOCKS + off_z = (wid // (NUM_BLOCKS * NUM_Q_HEADS)) % BATCH + + # offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL_POW2) + + # NOTE: + # Workaround for int64 strides, In the absence of strides being int64, parts of the offset + # computation is done in 32 bit and overflows resulting in segfaults + # If input strides are defined as int64, it disables vectorized loads which drops perf + # If we define new strides as stride_x = 
stride_x_in.to(tl.int64), that does not work + # because strides are tl.constexpr and cannot be upcast + # If we define new strides as stride_x: tl.int64 = stride_x_in, segfault remains + # The permanent solution is to enable upcasting of tl.constexpr + # In the meantime, the following workaround provides correctness and does not drop perf + if USE_INT64_STRIDES: + stride_qz = tl.cast(stride_qz_in, tl.int64) + stride_qh = tl.cast(stride_qh_in, tl.int64) + stride_qm = tl.cast(stride_qm_in, tl.int64) + stride_qk = tl.cast(stride_qk_in, tl.int64) + stride_kz = tl.cast(stride_kz_in, tl.int64) + stride_kh = tl.cast(stride_kh_in, tl.int64) + stride_kn = tl.cast(stride_kn_in, tl.int64) + stride_kk = tl.cast(stride_kk_in, tl.int64) + stride_vz = tl.cast(stride_vz_in, tl.int64) + stride_vh = tl.cast(stride_vh_in, tl.int64) + stride_vn = tl.cast(stride_vn_in, tl.int64) + stride_vk = tl.cast(stride_vk_in, tl.int64) + if IS_FP8: + stride_descale_q_z = tl.cast(stride_descale_q_z_in, tl.int64) + stride_descale_k_z = tl.cast(stride_descale_k_z_in, tl.int64) + stride_descale_v_z = tl.cast(stride_descale_v_z_in, tl.int64) + stride_oz = tl.cast(stride_oz_in, tl.int64) + stride_oh = tl.cast(stride_oh_in, tl.int64) + stride_om = tl.cast(stride_om_in, tl.int64) + stride_on = tl.cast(stride_on_in, tl.int64) + stride_alibi_z = tl.cast(stride_alibi_z_in, tl.int64) + stride_alibi_h = tl.cast(stride_alibi_h_in, tl.int64) + + # NOTE: philox offset is needed in dropout pointer calculations + philox_offset_base = tl.cast(philox_offset_base_in, tl.int64) + stride_sd_z = tl.cast(stride_sd_z_in, tl.int64) + stride_sd_h = tl.cast(stride_sd_h_in, tl.int64) + stride_sd_m = tl.cast(stride_sd_m_in, tl.int64) + stride_sd_n = tl.cast(stride_sd_n_in, tl.int64) + stride_lse_z = tl.cast(stride_lse_z_in, tl.int64) + stride_lse_h = tl.cast(stride_lse_h_in, tl.int64) + stride_lse_m = tl.cast(stride_lse_m_in, tl.int64) + else: + stride_qz = stride_qz_in + stride_qm = stride_qm_in + stride_qk = stride_qk_in + 
stride_qh = stride_qh_in + stride_kz = stride_kz_in + stride_kh = stride_kh_in + stride_kn = stride_kn_in + stride_kk = stride_kk_in + stride_vz = stride_vz_in + stride_vh = stride_vh_in + stride_vn = stride_vn_in + stride_vk = stride_vk_in + stride_descale_q_z = stride_descale_q_z_in + stride_descale_k_z = stride_descale_k_z_in + stride_descale_v_z = stride_descale_v_z_in + stride_oz = stride_oz_in + stride_oh = stride_oh_in + stride_om = stride_om_in + stride_on = stride_on_in + stride_alibi_z = stride_alibi_z_in + stride_alibi_h = stride_alibi_h_in + philox_offset_base = philox_offset_base_in + stride_sd_z = stride_sd_z_in + stride_sd_h = stride_sd_h_in + stride_sd_m = stride_sd_m_in + stride_sd_n = stride_sd_n_in + stride_lse_z = stride_lse_z_in + stride_lse_h = stride_lse_h_in + stride_lse_m = stride_lse_m_in + + if VARLEN: + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + # We have a one-size-fits-all grid in id(0). Some seqlens might be too + # small for all start_m so for those we return early. + if start_m * BLOCK_M > seqlen_q: + return + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + else: + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + seqlen_q = SEQLEN_Q + seqlen_k = SEQLEN_K + + n_blocks = _cdiv_fn(seqlen_k, BLOCK_N) + + # Now we compute whether we need to exit early due to causal masking. + # This is because for seqlen_q > seqlen_k, M rows of the attn scores + # are completely masked, resulting in 0s written to the output, and + # inf written to LSE. We don't need to do any GEMMs in this case. + # This block of code determines what N is, and if this WG is operating + # on those M rows. + if IS_CAUSAL: + # If seqlen_q == seqlen_k, the attn scores are a square matrix. 
+ # If seqlen_q != seqlen_k, attn scores are rectangular which means + # the causal mask boundary is bottom right aligned, and ends at either + # the top edge (seqlen_q < seqlen_k) or left edge. + + # This captures the decrease in n_blocks if we have a rectangular attn matrix + n_blocks_seqlen = _cdiv_fn( + (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N + ) + + # This is what adjusts the block_max for the current WG, only + # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks + n_blocks = min(n_blocks, n_blocks_seqlen) + + # If we have no blocks after adjusting for seqlen deltas, this WG is part of + # the blocks that are all 0. We exit early. + if n_blocks <= 0: + offs_out = ( + off_z * stride_oz + + off_q_head * stride_oh + + cu_seqlens_q_start * stride_om + + offs_m[:, None] * stride_om + + offs_d[None, :] * stride_on + ) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_POW2], dtype=out_ptr.type.element_ty) + out_mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < BLOCK_DMODEL) + tl.store(out_ptr + offs_out, acc, mask=out_mask) + + if softmax_lse_ptr is not None: + offs_lse = ( + off_z * stride_lse_z + + off_q_head * stride_lse_h + + cu_seqlens_q_start * stride_lse_m + + offs_m * stride_lse_m + ) + lse_mask = offs_m < SEQLEN_Q + lse = tl.full([BLOCK_M], value=0.0, dtype=tl.float32) + tl.store(softmax_lse_ptr + offs_lse, lse, mask=lse_mask) + # TODO: Should dropout and return encoded softmax be handled here too? 
+ + return + + grp_sz: tl.constexpr = NUM_Q_HEADS // NUM_K_HEADS + if grp_sz != 1: # Grouped Query Attention + off_k_head = off_q_head // grp_sz + else: + off_k_head = off_q_head + + # q,k,v offsets + q_offs = ( + off_z * stride_qz + + off_q_head * stride_qh + + cu_seqlens_q_start * stride_qm + + offs_m[:, None] * stride_qm + + offs_d[None, :] * stride_qk + ) + q_ptrs = q_ptr + q_offs + + k_offs = ( + off_z * stride_kz + + off_k_head * stride_kh + + cu_seqlens_k_start * stride_kn + + offs_d[:, None] * stride_kk + + offs_n[None, :] * stride_kn + ) + k_ptrs = k_ptr + k_offs + + v_offs = ( + off_z * stride_vz + + off_k_head * stride_vh + + cu_seqlens_k_start * stride_vn + + offs_n[:, None] * stride_vn + + offs_d[None, :] * stride_vk + ) + v_ptrs = v_ptr + v_offs + + # alibi slopes + if alibi_slopes_ptr is not None: + alibi_offs = off_z * stride_alibi_z + off_q_head * stride_alibi_h + alibi_slope = tl.load(alibi_slopes_ptr + alibi_offs) + else: + alibi_slope = None + + # s_dmask (return_scores) + if s_dmask_ptr is not None: + s_dmask_offs = ( + off_z * stride_sd_z + + off_q_head * stride_sd_h + + offs_m[:, None] * stride_sd_m + + offs_n[None, :] * stride_sd_n + ) + s_dmask_ptrs = s_dmask_ptr + s_dmask_offs + else: + s_dmask_ptrs = None + + # dropout + if dropout_mask_ptr is not None: + dropout_mask_offs = ( + off_z * stride_sd_z + + off_q_head * stride_sd_h + + offs_m[:, None] * stride_sd_m + + offs_n[None, :] * stride_sd_n + ) + dropout_mask_ptrs = dropout_mask_ptr + dropout_mask_offs + philox_ptrs = ( + philox_offset_base + + off_z * stride_sd_z + + off_q_head * stride_sd_h + + offs_m[:, None] * stride_sd_m + + offs_n[None, :] * stride_sd_n + ) + else: + dropout_mask_ptrs = None + philox_ptrs = None + + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_POW2], dtype=tl.float32) + if BLOCK_DMODEL == BLOCK_DMODEL_POW2: + q_mask = offs_m[:, None] < seqlen_q + else: + q_mask 
= (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < BLOCK_DMODEL) + q = tl.load(q_ptrs, mask=q_mask, other=0.0) + if IS_FP8: + descale_q = tl.load(descale_q_ptr + off_z * stride_descale_q_z + off_q_head) + descale_k = tl.load(descale_k_ptr + off_z * stride_descale_k_z + off_k_head) + descale_v = tl.load(descale_v_ptr + off_z * stride_descale_v_z + off_k_head) + else: + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 + + n_extra_tokens = 0 + if seqlen_k < BLOCK_N: + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + n_extra_tokens = seqlen_k % BLOCK_N + + # if CAUSAL, then determine masked_blocks and full blocks + # Here we compute how many full and masked blocks we have. + padded_block_k = n_extra_tokens != 0 + is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) + if IS_CAUSAL: + # There are always at least BLOCK_M // BLOCK_N masked blocks. + # Additionally there might be one more due to dissimilar seqlens. + masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) + else: + # Padding on Q does not need to be masked in the FA loop. + masked_blocks = padded_block_k + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block. + # In this case we might exceed n_blocks so pick the min. + masked_blocks = min(masked_blocks, n_blocks) + n_full_blocks = n_blocks - masked_blocks + block_min = 0 + block_max = n_blocks * BLOCK_N + # Compute for full blocks. Here we set causal to false regardless of its actual + # value because there is no masking. Similarly we do not need padding. 
+ if n_full_blocks > 0: + block_max = (n_blocks - masked_blocks) * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + k_ptrs, + v_ptrs, + stride_kn, + stride_vn, + stride_sd_n, + start_m, + seqlen_k, + seqlen_q, + dropout_p, + s_dmask_ptrs, + dropout_mask_ptrs, + philox_seed, + philox_ptrs, + block_min, + block_max, + 0, + 0, + 0, + alibi_slope, + descale_q, + descale_k, + descale_v, + offs_m, + offs_n, + BLOCK_M, + BLOCK_N, + BLOCK_DMODEL, + BLOCK_DMODEL_POW2, + sm_scale, + False, + MASK_STEPS=False, + ENABLE_DROPOUT=ENABLE_DROPOUT, + RETURN_SCORES=RETURN_SCORES, + PADDED_HEAD=BLOCK_DMODEL != BLOCK_DMODEL_POW2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + ) + block_min = block_max + block_max = n_blocks * BLOCK_N + + # Remaining blocks, if any, are the masked blocks (causal and/or padded). + if masked_blocks > 0: + if IS_CAUSAL: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) + else: + offs_n_causal = 0 + k_ptrs += n_full_blocks * BLOCK_N * stride_kn + v_ptrs += n_full_blocks * BLOCK_N * stride_vn + if RETURN_SCORES: + s_dmask_ptrs += n_full_blocks * BLOCK_N * stride_sd_n + if ENABLE_DROPOUT: + dropout_mask_ptrs += n_full_blocks * BLOCK_N * stride_sd_n + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + k_ptrs, + v_ptrs, + stride_kn, + stride_vn, + stride_sd_n, + start_m, + seqlen_k, + seqlen_q, + dropout_p, + s_dmask_ptrs, + dropout_mask_ptrs, + philox_seed, + philox_ptrs, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + alibi_slope, + descale_q, + descale_k, + descale_v, + offs_m, + offs_n, + BLOCK_M, + BLOCK_N, + BLOCK_DMODEL, + BLOCK_DMODEL_POW2, + sm_scale, + IS_CAUSAL, + MASK_STEPS=True, + ENABLE_DROPOUT=ENABLE_DROPOUT, + RETURN_SCORES=RETURN_SCORES, + PADDED_HEAD=BLOCK_DMODEL != BLOCK_DMODEL_POW2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + ) + # epilogue + # This helps the compiler do Newton Raphson on l_i vs on acc which is much larger. 
+ l_recip = 1 / l_i[:, None] + acc = acc * l_recip + if ENABLE_DROPOUT: + dropout_scale = 1 / (1 - dropout_p) + acc = acc * dropout_scale + # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, + # then we have one block with a row of all NaNs which come from computing + # softmax over a row of all -infs (-inf - inf = NaN). We check for that here + # and store 0s where there are NaNs as these rows should've been zeroed out. + end_m_idx = (start_m + 1) * BLOCK_M + start_m_idx = start_m * BLOCK_M + causal_start_idx = seqlen_q - seqlen_k + if IS_CAUSAL: + if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: + out_mask_boundary = tl.full( + (BLOCK_DMODEL_POW2,), causal_start_idx, dtype=tl.int32 + ) + mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) + out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] + z = 0.0 + acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) + + # write back LSE(Log Sum Exponents), the log of the normalization constant + overflow_size = end_m_idx - seqlen_q + if softmax_lse_ptr is not None: + RCP_LN2: tl.constexpr = 1.4426950408889634 + LN2: tl.constexpr = 0.6931471824645996 + # compute log-sum-exp in base 2 units + # mi_base2 = m_i * RCP_LN2 + mi_base2 = m_i * RCP_LN2 * sm_scale + softmax_lse = mi_base2 + tl.math.log2(l_i) + # convert back to natural units + softmax_lse *= LN2 + + if IS_CAUSAL: + # zero out nans caused by -infs when doing causal + lse_causal_mask = (start_m_idx + tl.arange(0, BLOCK_M)) < causal_start_idx + softmax_lse = tl.where(lse_causal_mask, 0.0, softmax_lse) + + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. + # This is only true for the last M block. 
For others, overflow_size will be -ve + offs_lse = ( + off_z * stride_lse_z + + off_q_head * stride_lse_h + + cu_seqlens_q_start * stride_lse_m + + offs_m * stride_lse_m + ) + if overflow_size > 0: + boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) + lse_mask = tl.arange(0, BLOCK_M) < boundary + tl.store( + softmax_lse_ptr + offs_lse, softmax_lse, mask=lse_mask + ) # the log of the normalization constant + else: + tl.store( + softmax_lse_ptr + offs_lse, softmax_lse + ) # the log of the normalization constant + + # write back O + offs_out = ( + off_z * stride_oz + + off_q_head * stride_oh + + cu_seqlens_q_start * stride_om + + offs_m[:, None] * stride_om + + offs_d[None, :] * stride_on + ) + out_mask = tl.full([BLOCK_M, BLOCK_DMODEL_POW2], 1, dtype=tl.int1) + if overflow_size > 0: + out_mask = out_mask & (offs_m[:, None] < seqlen_q) + if BLOCK_DMODEL != BLOCK_DMODEL_POW2: + out_mask = out_mask & (offs_d[None, :] < BLOCK_DMODEL) + op = acc.to(out_ptr.dtype.element_ty) + tl.store(out_ptr + offs_out, op, mask=out_mask) + + +@functools.lru_cache(maxsize=1024) +def _get_config( + enable_dropout: bool, + dtype: torch.dtype, +): + if not hasattr(_get_config, "_config_dict"): + dev = arch_info.get_device() + _get_config._config_dict = {} + fpath = f"{AITER_TRITON_CONFIGS_PATH}/{dev}-MHA-DEFAULT.json" + with open(fpath, "r") as file: + config = json.load(file) + _get_config._config_dict["default"] = config + + if enable_dropout or dtype == torch.float32: + return _get_config._config_dict["default"]["fwd"]["dropout_or_fp32"] + else: + return _get_config._config_dict["default"]["fwd"]["default"] + + +def _flash_attn_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dropout_p: float, + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + bias: Optional[torch.Tensor], + alibi_slopes: Optional[torch.Tensor], + return_lse: bool, + return_softmax: bool, + max_seqlen_q: int, + max_seqlen_k: int, + 
cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + config: Optional[dict[str, any]] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + if bias is not None: + raise ValueError("Bias is not supported yet in the Triton Backend") + if window_size_left != -1 or window_size_right != -1: + raise ValueError("Sliding Window is not supported yet in the Triton Backend") + + # FP8 + IS_FP8 = _is_fp8(q) + FP8_MAX: tl.constexpr = torch.finfo(q.dtype).max + is_varlen = True if cu_seqlens_q is not None else False + + if IS_FP8: + o = torch.zeros_like(q, dtype=torch.float32) + else: + o = torch.zeros_like(q) + if is_varlen: + # Layout for q,k,v is thd ie [total_tokens, num_head, head_dim] + batch, seqlen_q, num_q_heads, head_sz = ( + len(cu_seqlens_q) - 1, + max_seqlen_q, + q.shape[1], + q.shape[2], + ) + seqlen_k, num_k_heads = max_seqlen_k, k.shape[1] + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) + v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) + o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + else: + # Layout for q,k,v is bshd ie [batch, seq_len, num_head, head_dim] + batch, seqlen_q, num_q_heads, head_sz = q.shape + seqlen_k = k.shape[1] + num_k_heads = k.shape[2] + q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) + k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) + v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) + o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + + # padding for head_dim. 
Power of 2 or 16 + BLOCK_DMODEL_POW2 = triton.next_power_of_2(head_sz) + BLOCK_DMODEL_POW2 = max(BLOCK_DMODEL_POW2, 16) + + # softmax_lse [batch, num_q_heads, seqlen_q] + if is_varlen: + softmax_lse = torch.zeros( + (q.shape[0], num_q_heads), device=q.device, dtype=torch.float32 + ) + stride_lse_z, stride_lse_h, stride_lse_m = ( + 0, + softmax_lse.stride(1), + softmax_lse.stride(0), + ) + else: + softmax_lse = torch.zeros( + (batch, num_q_heads, max_seqlen_q), device=q.device, dtype=torch.float32 + ) + stride_lse_z, stride_lse_h, stride_lse_m = softmax_lse.stride() + + # exp_scores [batch, num_q_heads, seqlen_q, seqlen_k] + enable_dropout = dropout_p > 0.0 + if enable_dropout: + philox_seed = torch.randint(0, 0xFFFFFF, (1,))[ + 0 + ].item() # No specific reason to restrict range to 0xffffff + philox_offset = torch.randint(0, 0xFFFFFF, (1,))[ + 0 + ].item() # Pass in an int, not Tensor + else: + philox_seed = 0 + philox_offset = 0 + if return_softmax or enable_dropout: + s_dmask = torch.zeros( + (batch, num_q_heads, max_seqlen_q, max_seqlen_k), + device=q.device, + dtype=torch.float32, + ) + dropout_mask = torch.zeros( + (batch, num_q_heads, max_seqlen_q, max_seqlen_k), + device=q.device, + dtype=torch.float32, + ) + else: + s_dmask = None + dropout_mask = None + + if config is None: + config = _get_config(enable_dropout, q.dtype) + + """ + # Tuned for MI300x + config = { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 2, + "num_warps": 4, + "num_ctas": 1, + "num_stages": 1, + } + # Dropout significantly increases VGPR usage so use small tiles + if enable_dropout or q.dtype == torch.float32: + config = { + "BLOCK_M": 32, + "BLOCK_N": 32, + "waves_per_eu": 1, + "num_warps": 2, + "num_ctas": 1, + "num_stages": 1, + } + """ + + grid = lambda META: ( # noqa: E731 + batch * num_q_heads * triton.cdiv(seqlen_q, META["BLOCK_M"]), + ) + + _attn_fwd[grid]( + q, + k, + v, + descale_q, + descale_k, + descale_v, + o, + alibi_slopes, + s_dmask, + dropout_mask, + softmax_lse, 
class _FlashAttnFunc(torch.autograd.Function):
    """Autograd function for batched (fixed-length) flash attention.

    ``forward`` pads the head dimension of q/k/v up to a multiple of 8, calls
    ``_flash_attn_forward`` and returns the output trimmed back to the original
    head size. ``backward`` dispatches to the fused or the one-kernel backward
    implementation, selected by the module-level ``_USE_FUSED_BWD_KERNEL`` flag.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        bias,
        alibi_slopes,
        deterministic,
        return_lse,
        return_softmax,
        is_grad_enabled,
        config=None,
    ):
        """Run the forward kernel and stash what ``backward`` needs on ``ctx``.

        ``is_grad_enabled`` is passed in explicitly by the caller; state is
        saved for backward only when it is true and at least one of q/k/v
        requires grad.

        Returns ``out`` alone, or a tuple ``(out[, softmax_lse][, S_dmask])``
        depending on ``return_lse`` / ``return_softmax``.
        """
        is_grad = is_grad_enabled and any(x.requires_grad for x in [q, k, v])
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        # Remember the caller-visible head sizes before any padding so that
        # backward can return gradients with the shapes autograd expects.
        head_dims_og = (q.size(-1), k.size(-1), v.size(-1))
        head_size_og = q.size(3)
        if head_size_og % 8 != 0:
            pad = [0, 8 - head_size_og % 8]
            q = torch.nn.functional.pad(q, pad)
            k = torch.nn.functional.pad(k, pad)
            v = torch.nn.functional.pad(v, pad)

        out_padded, softmax_lse, S_dmask, philox_seed, philox_offset = (
            _flash_attn_forward(
                q,
                k,
                v,
                dropout_p,
                softmax_scale,
                causal=causal,
                window_size_left=int(window_size[0]),
                window_size_right=int(window_size[1]),
                bias=bias,
                alibi_slopes=alibi_slopes,
                return_lse=return_lse,
                return_softmax=return_softmax and dropout_p > 0,
                max_seqlen_q=q.shape[1],
                max_seqlen_k=k.shape[1],
                config=config,
            )
        )

        if is_grad:
            ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
            ctx.philox_seed = philox_seed
            ctx.philox_offset = philox_offset
            ctx.dropout_p = dropout_p
            ctx.softmax_scale = softmax_scale
            ctx.causal = causal
            ctx.bias = bias
            ctx.window_size = window_size
            ctx.alibi_slopes = alibi_slopes
            ctx.deterministic = deterministic
            ctx.head_dims_og = head_dims_og

        out = out_padded[..., :head_size_og]
        result = [out]
        if return_lse:
            result.append(softmax_lse)
        if return_softmax:
            result.append(S_dmask)

        return result[0] if len(result) == 1 else tuple(result)

    @staticmethod
    def backward(ctx, do, *args):
        """Compute gradients for q, k, v (and bias, if one was given).

        Returns one entry per ``forward`` input (14 in total); non-tensor
        inputs get ``None``.
        """
        q, k, v, out, softmax_lse = ctx.saved_tensors
        bias = ctx.bias
        dbias = torch.empty_like(bias) if bias is not None else None
        # dq is zero-initialised while dk/dv are left uninitialised, as in the
        # original implementation.
        dq, dk, dv = torch.zeros_like(q), torch.empty_like(k), torch.empty_like(v)

        # Pad the incoming gradient the same way the inputs were padded.
        head_size_v_og = do.size(3)
        do_padded = do
        if head_size_v_og % 8 != 0:
            do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_v_og % 8])

        # Both backends take the same arguments; pick one and call it once.
        backward_fn = (
            flash_attn_fused_backward
            if _USE_FUSED_BWD_KERNEL
            else flash_attn_onekernel_backward
        )
        backward_fn(
            do_padded,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            dbias,
            ctx.softmax_scale,
            ctx.alibi_slopes,
            ctx.causal,
            None,
            None,
            max_seqlen_q=q.shape[1],
            max_seqlen_k=k.shape[1],
            dropout_p=ctx.dropout_p,
            philox_seed=ctx.philox_seed,
            philox_offset=ctx.philox_offset,
            USE_INT64_STRIDES=_USE_INT64_STRIDES,
        )

        # The saved q/k/v are the *padded* tensors, so slicing with their own
        # last dimension would be a no-op; trim to the original head sizes.
        head_dim_q, head_dim_k, head_dim_v = ctx.head_dims_og
        dq = dq[..., :head_dim_q]
        dk = dk[..., :head_dim_k]
        dv = dv[..., :head_dim_v]
        return (
            dq,
            dk,
            dv,
            None,  # dropout_p
            None,  # softmax_scale
            None,  # causal
            None,  # window_size
            dbias,
            None,  # alibi_slopes
            None,  # deterministic
            None,  # return_lse
            None,  # return_softmax
            None,  # is_grad_enabled
            None,  # config
        )
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + bias: (seqlen_q, seqlen_k) + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). 
class _FlashAttnFP8Func(torch.autograd.Function):
    """Autograd function for batched flash attention with FP8 inputs.

    ``forward`` pads the head dimension to a multiple of 8, casts q/k/v to the
    architecture's FP8 e4m3 dtype via ``_cast_to_fp8`` and calls
    ``_flash_attn_forward`` with the resulting descale factors. ``backward``
    casts the incoming gradient to FP8 as well and dispatches to the fused or
    the one-kernel backward, selected by ``_USE_FUSED_BWD_KERNEL``.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_lse,
        return_softmax,
        is_grad_enabled,
        config=None,
    ):
        """Run the FP8 forward kernel and stash what ``backward`` needs.

        Returns ``out`` alone, or a tuple ``(out[, softmax_lse][, S_dmask])``
        depending on ``return_lse`` / ``return_softmax``.
        """
        is_grad = is_grad_enabled and any(x.requires_grad for x in [q, k, v])
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        # Caller-visible head sizes, recorded before padding so backward can
        # return correctly shaped gradients.
        head_dims_og = (q.size(-1), k.size(-1), v.size(-1))
        head_size_og = q.size(3)
        if head_size_og % 8 != 0:
            pad = [0, 8 - head_size_og % 8]
            q = torch.nn.functional.pad(q, pad)
            k = torch.nn.functional.pad(k, pad)
            v = torch.nn.functional.pad(v, pad)

        # cast input to fp8
        fp8_dtype = arch_info.get_fp8_e4m3_dtype()
        q_fp8, descale_q = _cast_to_fp8(q, fp8_dtype, "bshd")
        k_fp8, descale_k = _cast_to_fp8(k, fp8_dtype, "bshd")
        v_fp8, descale_v = _cast_to_fp8(v, fp8_dtype, "bshd")

        out_padded, softmax_lse, S_dmask, philox_seed, philox_offset = (
            _flash_attn_forward(
                q_fp8,
                k_fp8,
                v_fp8,
                dropout_p,
                softmax_scale,
                causal=causal,
                window_size_left=int(window_size[0]),
                window_size_right=int(window_size[1]),
                bias=None,
                alibi_slopes=alibi_slopes,
                return_lse=return_lse,
                return_softmax=return_softmax and dropout_p > 0,
                max_seqlen_q=q.shape[1],
                max_seqlen_k=k.shape[1],
                cu_seqlens_q=None,
                cu_seqlens_k=None,
                descale_q=descale_q,
                descale_k=descale_k,
                descale_v=descale_v,
                config=config,
            )
        )

        if is_grad:
            ctx.save_for_backward(
                q_fp8,
                k_fp8,
                v_fp8,
                out_padded,
                softmax_lse,
                descale_q,
                descale_k,
                descale_v,
            )
            ctx.philox_seed = philox_seed
            ctx.philox_offset = philox_offset
            ctx.dropout_p = dropout_p
            ctx.softmax_scale = softmax_scale
            ctx.causal = causal
            ctx.window_size = window_size
            ctx.alibi_slopes = alibi_slopes
            ctx.head_dims_og = head_dims_og

        out = out_padded[..., :head_size_og]
        result = [out]
        if return_lse:
            result.append(softmax_lse)
        if return_softmax:
            result.append(S_dmask)

        return result[0] if len(result) == 1 else tuple(result)

    @staticmethod
    def backward(ctx, do, *args):
        """Compute float32 gradients for q, k and v.

        Returns one entry per ``forward`` input (13 in total); non-tensor
        inputs get ``None``.
        """
        q_fp8, k_fp8, v_fp8, out, softmax_lse, descale_q, descale_k, descale_v = (
            ctx.saved_tensors
        )
        # Gradients are accumulated in float32 rather than in the FP8 dtype.
        dq, dk, dv = (
            torch.zeros_like(q_fp8, dtype=torch.float32),
            torch.zeros_like(k_fp8, dtype=torch.float32),
            torch.zeros_like(v_fp8, dtype=torch.float32),
        )

        # Pad the incoming gradient the same way the inputs were padded.
        head_size_v_og = do.size(3)
        do_padded = do
        if head_size_v_og % 8 != 0:
            do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_v_og % 8])

        fp8_dtype = arch_info.get_fp8_e4m3_dtype()
        do_padded_fp8, descale_do = _cast_to_fp8(do_padded, fp8_dtype, "bshd")

        # Both backends take the same arguments; pick one and call it once.
        backward_fn = (
            flash_attn_fused_backward
            if _USE_FUSED_BWD_KERNEL
            else flash_attn_onekernel_backward
        )
        backward_fn(
            do_padded_fp8,
            q_fp8,
            k_fp8,
            v_fp8,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            None,
            ctx.softmax_scale,
            ctx.alibi_slopes,
            ctx.causal,
            None,
            None,
            max_seqlen_q=q_fp8.shape[1],
            max_seqlen_k=k_fp8.shape[1],
            dropout_p=ctx.dropout_p,
            philox_seed=ctx.philox_seed,
            philox_offset=ctx.philox_offset,
            descale_q=descale_q,
            descale_k=descale_k,
            descale_v=descale_v,
            descale_do=descale_do,
            USE_INT64_STRIDES=_USE_INT64_STRIDES,
        )

        # Trim away the head-dimension padding added in forward. The saved
        # tensors are already padded, so their own shapes cannot be used here.
        head_dim_q, head_dim_k, head_dim_v = ctx.head_dims_og
        dq = dq[..., :head_dim_q]
        dk = dk[..., :head_dim_k]
        dv = dv[..., :head_dim_v]

        # forward takes 13 inputs after ctx: q, k, v plus 10 non-differentiable.
        return (dq, dk, dv) + (None,) * 10
ctx.max_seqlen_k = max_seqlen_k + ctx.philox_seed = philox_seed + ctx.philox_offset = philox_offset + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.bias = bias + ctx.alibi_slopes = alibi_slopes + out = out_padded[..., :head_size_og] + + result = [out] + if return_lse: + result.append(softmax_lse) + if return_softmax: + result.append(S_dmask) + + return result[0] if len(result) == 1 else tuple(result) + + @staticmethod + def backward(ctx, do, *args): + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors + dq, dk, dv = torch.zeros_like(q), torch.empty_like(k), torch.empty_like(v) + bias = ctx.bias + dbias = torch.empty_like(bias) if bias is not None else None + head_size_og = do.size(2) + do_padded = do + if head_size_og % 8 != 0: + do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_og % 8]) + + if _USE_FUSED_BWD_KERNEL: + flash_attn_fused_backward( + do_padded, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + dbias, + ctx.softmax_scale, + ctx.alibi_slopes, + ctx.causal, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q=ctx.max_seqlen_q, + max_seqlen_k=ctx.max_seqlen_k, + dropout_p=ctx.dropout_p, + philox_seed=ctx.philox_seed, + philox_offset=ctx.philox_offset, + USE_INT64_STRIDES=_USE_INT64_STRIDES, + ) + else: + flash_attn_onekernel_backward( + do_padded, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + dbias, + ctx.softmax_scale, + ctx.alibi_slopes, + ctx.causal, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q=ctx.max_seqlen_q, + max_seqlen_k=ctx.max_seqlen_k, + dropout_p=ctx.dropout_p, + philox_seed=ctx.philox_seed, + philox_offset=ctx.philox_offset, + USE_INT64_STRIDES=_USE_INT64_STRIDES, + ) + + dq = dq[..., : q.shape[-1]] # We could have padded the head dimension + dk = dk[..., : k.shape[-1]] + dv = dv[..., : v.shape[-1]] + return ( + dq, + dk, + dv, + None, + None, + None, + None, + None, + None, + None, + None, + dbias, + None, + None, + None, 
+ None, + None, + None, + None, + None, + None, + ) + + +def flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + bias=None, + alibi_slopes=None, + deterministic=False, + return_lse=False, + return_attn_probs=False, + block_table=None, + out=None, + config: Optional[dict[str, any]] = None, +): + """dropout_p should be set to 0.0 during evaluation + Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + + Arguments: + q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. + k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into q. + cu_seqlens_k: (batch_size + 1,), dtype torch.int32. 
The cumulative sequence lengths + of the sequences in the batch, used to index into kv. + max_seqlen_q: int. Maximum query sequence length in the batch. + max_seqlen_k: int. Maximum key sequence length in the batch. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + bias: (seqlen_q, seqlen_k) + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (total, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). 
class _FlashAttnVarlenFP8Func(torch.autograd.Function):
    """Autograd function for variable-length flash attention with FP8 inputs.

    q/k/v are indexed as 3-D tensors (the head size is read from dimension 2)
    and sequences are delimited by ``cu_seqlens_q`` / ``cu_seqlens_k``.
    ``forward`` pads the head dimension to a multiple of 8, casts the inputs to
    FP8 with ``_cast_varlen_to_fp8`` and calls ``_flash_attn_forward``.
    ``backward`` dispatches to the fused or the one-kernel backward, selected
    by ``_USE_FUSED_BWD_KERNEL``.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_lse,
        return_softmax,
        block_table,
        is_grad_enabled,
        config=None,
    ):
        """Run the varlen FP8 forward kernel and stash state for ``backward``.

        Returns ``out`` alone, or a tuple ``(out[, softmax_lse][, S_dmask])``
        depending on ``return_lse`` / ``return_softmax``.
        """
        is_grad = is_grad_enabled and any(x.requires_grad for x in [q, k, v])
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        # Caller-visible head sizes, recorded before padding so backward can
        # return correctly shaped gradients.
        head_dims_og = (q.size(-1), k.size(-1), v.size(-1))
        head_size_og = q.size(2)
        if head_size_og % 8 != 0:
            pad = [0, 8 - head_size_og % 8]
            q = torch.nn.functional.pad(q, pad)
            k = torch.nn.functional.pad(k, pad)
            v = torch.nn.functional.pad(v, pad)

        # cast input to fp8
        fp8_dtype = arch_info.get_fp8_e4m3_dtype()
        q_fp8, descale_q = _cast_varlen_to_fp8(q, fp8_dtype, cu_seqlens=cu_seqlens_q)
        k_fp8, descale_k = _cast_varlen_to_fp8(k, fp8_dtype, cu_seqlens=cu_seqlens_k)
        v_fp8, descale_v = _cast_varlen_to_fp8(v, fp8_dtype, cu_seqlens=cu_seqlens_k)

        out_padded, softmax_lse, S_dmask, philox_seed, philox_offset = (
            _flash_attn_forward(
                q_fp8,
                k_fp8,
                v_fp8,
                dropout_p,
                softmax_scale,
                causal=causal,
                window_size_left=int(window_size[0]),
                window_size_right=int(window_size[1]),
                bias=None,
                alibi_slopes=alibi_slopes,
                return_lse=return_lse,
                return_softmax=return_softmax and dropout_p > 0,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_k,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                descale_q=descale_q,
                descale_k=descale_k,
                descale_v=descale_v,
                config=config,
            )
        )
        if is_grad:
            ctx.save_for_backward(
                q_fp8,
                k_fp8,
                v_fp8,
                out_padded,
                softmax_lse,
                cu_seqlens_q,
                cu_seqlens_k,
                descale_q,
                descale_k,
                descale_v,
            )
            ctx.max_seqlen_q = max_seqlen_q
            ctx.max_seqlen_k = max_seqlen_k
            ctx.philox_seed = philox_seed
            ctx.philox_offset = philox_offset
            ctx.dropout_p = dropout_p
            ctx.softmax_scale = softmax_scale
            ctx.causal = causal
            ctx.window_size = window_size
            ctx.alibi_slopes = alibi_slopes
            ctx.head_dims_og = head_dims_og

        out = out_padded[..., :head_size_og]
        result = [out]
        if return_lse:
            result.append(softmax_lse)
        if return_softmax:
            result.append(S_dmask)

        return result[0] if len(result) == 1 else tuple(result)

    @staticmethod
    def backward(ctx, do, *args):
        """Compute float32 gradients for q, k and v.

        Returns one entry per ``forward`` input (18 in total); non-tensor
        inputs get ``None``.
        """
        (
            q_fp8,
            k_fp8,
            v_fp8,
            out,
            softmax_lse,
            cu_seqlens_q,
            cu_seqlens_k,
            descale_q,
            descale_k,
            descale_v,
        ) = ctx.saved_tensors
        # Gradients are accumulated in float32 rather than in the FP8 dtype.
        dq, dk, dv = (
            torch.zeros_like(q_fp8, dtype=torch.float32),
            torch.zeros_like(k_fp8, dtype=torch.float32),
            torch.zeros_like(v_fp8, dtype=torch.float32),
        )

        # The head size lives in dimension 2 here (forward reads q.size(2));
        # indexing dimension 3 raises on a 3-D tensor.
        head_size_v_og = do.size(2)
        do_padded = do
        if head_size_v_og % 8 != 0:
            do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_v_og % 8])

        fp8_dtype = arch_info.get_fp8_e4m3_dtype()
        # _cast_varlen_to_fp8 takes (x, fp8_dtype, cu_seqlens, clamp_val); it
        # has no layout argument, so pass cu_seqlens by keyword as forward does.
        do_padded_fp8, descale_do = _cast_varlen_to_fp8(
            do_padded, fp8_dtype, cu_seqlens=cu_seqlens_q
        )

        # Both backends take the same arguments; pick one and call it once.
        backward_fn = (
            flash_attn_fused_backward
            if _USE_FUSED_BWD_KERNEL
            else flash_attn_onekernel_backward
        )
        backward_fn(
            do_padded_fp8,
            q_fp8,
            k_fp8,
            v_fp8,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            None,
            ctx.softmax_scale,
            ctx.alibi_slopes,
            ctx.causal,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q=ctx.max_seqlen_q,
            max_seqlen_k=ctx.max_seqlen_k,
            dropout_p=ctx.dropout_p,
            philox_seed=ctx.philox_seed,
            philox_offset=ctx.philox_offset,
            descale_q=descale_q,
            descale_k=descale_k,
            descale_v=descale_v,
            descale_do=descale_do,
            USE_INT64_STRIDES=_USE_INT64_STRIDES,
        )

        # Trim away the head-dimension padding added in forward. The saved
        # tensors are already padded, so their own shapes cannot be used here.
        head_dim_q, head_dim_k, head_dim_v = ctx.head_dims_og
        dq = dq[..., :head_dim_q]
        dk = dk[..., :head_dim_k]
        dv = dv[..., :head_dim_v]

        # forward takes 18 inputs after ctx: q, k, v plus 15 non-differentiable.
        # Returning fewer entries than inputs makes autograd raise.
        return (dq, dk, dv) + (None,) * 15
# Backward pre-processing kernel: computes delta from the forward output Out
# and the incoming gradient DO, where delta[m] = sum_k Out[m, k] * DO[m, k].
# I/O shapes as documented by the original author:
# Out: (batch, nhead_q, max_seqlens_q, headDim)
# DO: (batch, nhead_q, max_seqlens_q, headDim)
# Delta: (batch, nheads_q, max_seqlens_q) -- presumably the same layout as
#   softmax_lse (the original comment was cut off here; TODO confirm).
# Grid: axis 0 = BLOCK_M-sized tiles along seqlen_q, axis 1 = batch, axis 2 = head.
@triton.jit
def _bwd_preprocess(
    o_ptr,
    do_ptr,  # noqa: E741
    delta_ptr,
    stride_o_b,
    stride_o_h,
    stride_o_m,
    stride_o_k,
    stride_delta_b,
    stride_delta_h,
    stride_delta_m,
    stride_descale_do_z,
    cu_seqlens_q,
    max_seqlen_q,
    descale_do_ptr,
    BLOCK_M: tl.constexpr,
    BLOCK_D_MODEL: tl.constexpr,
    BLOCK_D_MODEL_POW2: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_FP8: tl.constexpr,
):
    pid_m = tl.program_id(0)  # seqlen
    bid = tl.program_id(1)  # batch
    hid = tl.program_id(2)  # head

    # Handle varlen: with IS_VARLEN the sequence start/length for this batch
    # entry come from cu_seqlens_q; otherwise every sequence starts at 0 and
    # has length max_seqlen_q.
    q_start = 0
    seqlen_q = max_seqlen_q
    if IS_VARLEN:
        q_start = tl.load(cu_seqlens_q + bid)
        q_end = tl.load(cu_seqlens_q + bid + 1)
        seqlen_q = q_end - q_start
    else:
        q_start = 0
        seqlen_q = max_seqlen_q

    # Compute offsets
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_k = tl.arange(0, BLOCK_D_MODEL_POW2)

    # Offset O/DO by batch, head and q_start
    # NOTE: the same offsets (and therefore the same strides) are applied to
    # both o_ptr and do_ptr below.
    offs = (
        bid * stride_o_b
        + hid * stride_o_h
        + q_start * stride_o_m
        + offs_m[:, None] * stride_o_m
        + offs_k[None, :] * stride_o_k
    )

    # create masks: rows beyond seqlen_q are masked out, and so are the
    # columns beyond BLOCK_D_MODEL when the head dim is not a power of two.
    mask_m = offs_m < seqlen_q
    mask = mask_m[:, None]
    PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2
    if PADDED_HEAD:
        mask &= offs_k[None, :] < BLOCK_D_MODEL

    # load [BLOCK_M, BLOCK_D_MODEL_POW2]
    o = tl.load(o_ptr + offs, mask=mask, other=0.0)
    do = tl.load(do_ptr + offs, mask=mask, other=0.0)

    # compute and write-back to delta
    if IS_FP8:
        # One descale factor per (batch, head); the head index is added with
        # an implicit stride of 1.
        descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hid)

        # NOTE: do is in the fp8 range and o is not in fp8
        delta = tl.sum(o.to(tl.float32) * (do.to(tl.float32) * descale_do), axis=1)
    else:
        delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1)

    offs_delta = (
        bid * stride_delta_b
        + hid * stride_delta_h
        + q_start * stride_delta_m
        + offs_m * stride_delta_m
    )
    # Only rows inside the sequence are written.
    tl.store(delta_ptr + offs_delta, delta, mask=mask_m)
[BLOCK_M, BLOCK_D_MODEL_POW2] + + do_ptrs_start = DO + offs_m[:, None] * stride_do_m + offs_k[None, :] * stride_do_k + curr_m = start_m + step_m = BLOCK_M + curr_philox_offset = batch_philox_offset + + # Iterate over blocks(BLOCK_M size) of Q while calculating + # a fixed block(BLOCK_N) of dk and dv. Note, during backward + # pass P has to be recomputed. However, this kernel computes + # dV and dK, so we compute we need P^T and S^T. See backward pass + # equations + # + # From Flash Attention Paper: + # ForwardPass: S = QkT, P=softmax(S), O=PV + # + # BackwardPass equations + # dV = P^TdO + # dP = dOV^T + # dS = dsoftmax(dP) + # dQ = dSK + # dK = QdS^T + + for iter in range(num_steps): + # Permute the iteration order to reduce the probability that concurrent workgroups (that share the same q head idx and batch idx) are at the same iteration + blk_idx = (iter + workgroup_id) % num_steps + + curr_m = start_m + blk_idx * step_m + qT_ptrs = qT_ptrs_start + blk_idx * step_m * stride_q_m + dq_ptrs = dq_ptrs_start + blk_idx * step_m * stride_dq_m + do_ptrs = do_ptrs_start + blk_idx * step_m * stride_do_m + + offs_m = curr_m + tl.arange(0, BLOCK_M) + mask_m = offs_m < seqlen_q + mask_qT = mask_m[None, :] + mask_do = mask_m[:, None] + mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) + + if PADDED_HEAD: + mask_qT &= offs_k[:, None] < BLOCK_D_MODEL + mask_do &= offs_k[None, :] < BLOCK_D_MODEL + + # load qT + qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) + + # dropout + if ENABLE_DROPOUT: + # NOTE: dropout is transposed because it is used to mask pT + philox_offs = ( + curr_philox_offset + + offs_m[None, :] * stride_dropout_m + + offs_n[:, None] * stride_dropout_n + ) + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1.0 / (1 - dropout_p) + + # Load M + m = tl.load(M + offs_m * stride_deltam, mask=mask_m, other=0.0) + + # Compute qkT + if IS_FP8: + qkT = tl.dot(k, qT) * descale_q * descale_k + else: + qkT = 
tl.dot(k, qT) + + # Compute pT(use m and also apply sm_scale) + pT = tl.math.exp(qkT * sm_scale - m[None, :]) + + if MASK: + causal_mask = (offs_m[None, :] - delta_qk) >= (offs_n[:, None]) + mask = causal_mask & mask_nm + pT = tl.where(mask, pT, 0.0) + + # load DO + do = tl.load(do_ptrs, mask=mask_do, other=0.0) + + # dV + if ENABLE_DROPOUT: + pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale + if IS_FP8: + scale_p_dropout, descale_p_dropout = _compute_fp8_scaling_factors( + pT_dropout, FP8_MAX + ) + dv += ( + tl.dot((pT_dropout * scale_p_dropout).to(do.type.element_ty), do) + * descale_p_dropout + * descale_do + ) + else: + dv += tl.dot(pT_dropout.to(do.type.element_ty), do) + else: + if IS_FP8: + scale_pT, descale_pT = _compute_fp8_scaling_factors(pT, FP8_MAX) + dv += ( + tl.dot((pT * scale_pT).to(do.type.element_ty), do) + * descale_pT + * descale_do + ) + else: + dv += tl.dot(pT.to(do.type.element_ty), do) + + # Load delta + Di = tl.load(D + offs_m * stride_deltam, mask=mask_m) + + # Compute dP and dS + if IS_FP8: + dpT = tl.dot(v, tl.trans(do)) * descale_v * descale_do + else: + dpT = tl.dot(v, tl.trans(do)) + + if ENABLE_DROPOUT: + dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale + + delta_i = Di[None, :] + dsT = pT * (dpT - delta_i) + + # compute dk + if IS_FP8: + scale_dsT, descale_dsT = _compute_fp8_scaling_factors(dsT, FP8_MAX) + dk += ( + tl.dot((dsT * scale_dsT).to(qT.type.element_ty), tl.trans(qT)) + * descale_dsT + * descale_q + ) + else: + dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) + + # We can compute the dq_partial here and do a atomic add to the correct memory location + # NOTE: Possible problems with the atomic add: contention, is inside a loop which has achieved bad perf before + # (BLOCK_M, BLOCK_N) x (BLOCK_N, D) + if IS_FP8: + dq_partial = ( + tl.dot((dsT * scale_dsT).to(k.dtype).T, k) * descale_dsT * descale_k + ) + else: + dq_partial = tl.dot(dsT.to(k.dtype).T, k) + tl.atomic_add( + dq_ptrs, + dq_partial * 
sm_scale, + mask=mask_m[:, None] & (offs_k[None, :] < BLOCK_D_MODEL), + sem="relaxed", + ) + + return dk, dv + + +@triton.jit +def _bwd_kernel_dkdvdq_causal( + q_ptr, + k_ptr, + v_ptr, + sm_scale, + do_ptr, + dk_ptr, + dv_ptr, + dq_ptr, + m_ptr, + delta_ptr, + stride_q_b_in, + stride_q_h_in, + stride_q_m_in, + stride_q_k_in, + stride_k_b_in, + stride_k_h_in, + stride_k_n_in, + stride_k_k_in, + stride_v_b_in, + stride_v_h_in, + stride_v_n_in, + stride_v_k_in, + stride_dk_b_in, + stride_dk_h_in, + stride_dk_n_in, + stride_dk_k_in, + stride_dq_b_in, + stride_dq_h_in, + stride_dq_m_in, + stride_dq_k_in, + stride_delta_b_in, + stride_delta_h_in, + stride_delta_m_in, + stride_do_b_in, + stride_do_h_in, + stride_do_m_in, + stride_do_k_in, + stride_dropout_b_in, + stride_dropout_h_in, + stride_dropout_m_in, + stride_dropout_n_in, + stride_descale_q_z_in, + stride_descale_k_z_in, + stride_descale_v_z_in, + stride_descale_do_z_in, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset_base_in, + descale_q_ptr, + descale_k_ptr, + descale_v_ptr, + descale_do_ptr, + NUM_Q_HEADS: tl.constexpr, + NUM_K_HEADS: tl.constexpr, + BATCH, + NUM_K_PIDS, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + NUM_SMS: tl.constexpr, + USE_INT64_STRIDES: tl.constexpr, +): + if USE_INT64_STRIDES: + stride_q_b = tl.cast(stride_q_b_in, tl.int64) + stride_q_h = tl.cast(stride_q_h_in, tl.int64) + stride_q_m = tl.cast(stride_q_m_in, tl.int64) + stride_q_k = tl.cast(stride_q_k_in, tl.int64) + stride_k_b = tl.cast(stride_k_b_in, tl.int64) + stride_k_h = tl.cast(stride_k_h_in, tl.int64) + stride_k_n = tl.cast(stride_k_n_in, tl.int64) + stride_k_k = tl.cast(stride_k_k_in, tl.int64) + stride_v_b = tl.cast(stride_v_b_in, tl.int64) + 
stride_v_h = tl.cast(stride_v_h_in, tl.int64) + stride_v_n = tl.cast(stride_v_n_in, tl.int64) + stride_v_k = tl.cast(stride_v_k_in, tl.int64) + stride_dk_b = tl.cast(stride_dk_b_in, tl.int64) + stride_dk_h = tl.cast(stride_dk_h_in, tl.int64) + stride_dk_n = tl.cast(stride_dk_n_in, tl.int64) + stride_dk_k = tl.cast(stride_dk_k_in, tl.int64) + stride_dq_b = tl.cast(stride_dq_b_in, tl.int64) + stride_dq_h = tl.cast(stride_dq_h_in, tl.int64) + stride_dq_m = tl.cast(stride_dq_m_in, tl.int64) + stride_dq_k = tl.cast(stride_dq_k_in, tl.int64) + stride_delta_b = tl.cast(stride_delta_b_in, tl.int64) + stride_delta_h = tl.cast(stride_delta_h_in, tl.int64) + stride_delta_m = tl.cast(stride_delta_m_in, tl.int64) + stride_do_b = tl.cast(stride_do_b_in, tl.int64) + stride_do_h = tl.cast(stride_do_h_in, tl.int64) + stride_do_m = tl.cast(stride_do_m_in, tl.int64) + stride_do_k = tl.cast(stride_do_k_in, tl.int64) + stride_dropout_b = tl.cast(stride_dropout_b_in, tl.int64) + stride_dropout_h = tl.cast(stride_dropout_h_in, tl.int64) + stride_dropout_m = tl.cast(stride_dropout_m_in, tl.int64) + stride_dropout_n = tl.cast(stride_dropout_n_in, tl.int64) + philox_offset_base = tl.cast(philox_offset_base_in, tl.int64) + if IS_FP8: + stride_descale_q_z = tl.cast(stride_descale_q_z_in, tl.int64) + stride_descale_k_z = tl.cast(stride_descale_k_z_in, tl.int64) + stride_descale_v_z = tl.cast(stride_descale_v_z_in, tl.int64) + stride_descale_do_z = tl.cast(stride_descale_do_z_in, tl.int64) + else: + stride_q_b = stride_q_b_in + stride_q_h = stride_q_h_in + stride_q_m = stride_q_m_in + stride_q_k = stride_q_k_in + stride_k_b = stride_k_b_in + stride_k_h = stride_k_h_in + stride_k_n = stride_k_n_in + stride_k_k = stride_k_k_in + stride_v_b = stride_v_b_in + stride_v_h = stride_v_h_in + stride_v_n = stride_v_n_in + stride_v_k = stride_v_k_in + stride_dk_b = stride_dk_b_in + stride_dk_h = stride_dk_h_in + stride_dk_n = stride_dk_n_in + stride_dk_k = stride_dk_k_in + stride_dq_b = stride_dq_b_in + 
stride_dq_h = stride_dq_h_in + stride_dq_m = stride_dq_m_in + stride_dq_k = stride_dq_k_in + stride_delta_b = stride_delta_b_in + stride_delta_h = stride_delta_h_in + stride_delta_m = stride_delta_m_in + stride_do_b = stride_do_b_in + stride_do_h = stride_do_h_in + stride_do_m = stride_do_m_in + stride_do_k = stride_do_k_in + stride_dropout_b = stride_dropout_b_in + stride_dropout_h = stride_dropout_h_in + stride_dropout_m = stride_dropout_m_in + stride_dropout_n = stride_dropout_n_in + philox_offset_base = philox_offset_base_in + stride_descale_q_z = stride_descale_q_z_in + stride_descale_k_z = stride_descale_k_z_in + stride_descale_v_z = stride_descale_v_z_in + stride_descale_do_z = stride_descale_do_z_in + + GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS + wid = tl.program_id(0) # workgoup id: 0, ..., NUM_Q_PIDS * BATCH * NUM_K_HEADS - 1 + + NUM_XCD: tl.constexpr = 8 + head_q_idx = wid % NUM_Q_HEADS + head_q_idx = remap_xcd(head_q_idx, NUM_Q_HEADS, NUM_XCD) + seq_k_blk_idx = (wid // NUM_Q_HEADS) % NUM_K_PIDS + batch_idx = (wid // (NUM_K_PIDS * NUM_Q_HEADS)) % BATCH + + # In the backward we dont want concurrent workgroups to handle consecutive heads or blocks, so remap them to be far apart. + head_q_idx = (head_q_idx * 29) % NUM_Q_HEADS + # seq_k_blk_idx = (seq_k_blk_idx * 29) % NUM_K_PIDS + + head_k_idx = head_q_idx // GROUP_SIZE + + # Determine q and k start along with seqlen_q and seqlen_k + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + batch_idx) + q_end = tl.load(cu_seqlens_q + batch_idx + 1) + k_start = tl.load(cu_seqlens_k + batch_idx) + k_end = tl.load(cu_seqlens_k + batch_idx + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + + # Figure out causal starting block since we have seqlen_q >=< seqlen_k. 
+ # Unlike forward pass where we tile on M dim and iterate on N dim, so that + # we can skip some M blocks, in backward pass, we tile on the N dim for kv + # and iterate over the M. In this way, we cannot skip N blocks, but only to + # determine the starting M blocks to skip some initial blocks masked by + # causal. + delta_qk = seqlen_q - seqlen_k + + # q < k: some blocks will have no Masked block, other needs to re-calc + # starting position + # delta_qk is negative so flip it, only multiple of BLOCK_N can skip the + # masked op + num_blocks_skip = -delta_qk // BLOCK_N + delta_aligned = (num_blocks_skip + 1) * BLOCK_N + delta_qk + start_delta_q_lt_k = delta_aligned // BLOCK_M * BLOCK_M + if delta_qk >= 0: + start_delta = delta_qk + else: + start_delta = start_delta_q_lt_k + + start_n = seq_k_blk_idx * BLOCK_N + + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + offs_n = start_n + tl.arange(0, BLOCK_N) + # Mask for loading K and V + mask_kv = offs_n[:, None] < seqlen_k + PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2 + if PADDED_HEAD: + mask_k = offs_k < BLOCK_D_MODEL + mask_kv &= mask_k[None, :] + + GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS + adj_k = ( + batch_idx * stride_k_b + + head_k_idx * stride_k_h + + k_start * stride_k_n + + offs_n[:, None] * stride_k_n + + offs_k[None, :] * stride_k_k + ) + adj_v = ( + batch_idx * stride_v_b + + head_k_idx * stride_v_h + + k_start * stride_v_n + + offs_n[:, None] * stride_v_n + + offs_k[None, :] * stride_v_k + ) + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(k_ptr + adj_k, mask=mask_kv, other=0.0) + v = tl.load(v_ptr + adj_v, mask=mask_kv, other=0.0) + + # If MQA / GQA, set the K and V head offsets appropriately. 
+ # for head_q_idx in range(head_k_idx * GROUP_SIZE, head_k_idx * GROUP_SIZE + GROUP_SIZE): + if delta_qk >= 0: + start_m = start_n + start_delta + len_m = BLOCK_N + else: + start_m = max(start_n + delta_qk, 0) + start_m = (start_m // BLOCK_M) * BLOCK_M + # because we might shift the masked blocks up, we are deeper into + # the masked out region, so we would potentially increase the total + # steps with masked operation to get out of it + residue_m = max(start_n + delta_qk - start_m, 0) + len_m = BLOCK_N + residue_m + + # offset input and output tensor by batch and Q/K heads + adj_q = batch_idx * stride_q_b + head_q_idx * stride_q_h + q_start * stride_q_m + adj_dq = batch_idx * stride_dq_b + head_q_idx * stride_dq_h + q_start * stride_dq_m + + q_ptr_adj = q_ptr + adj_q + dq_ptr_adj = dq_ptr + adj_dq + + adj_do = batch_idx * stride_do_b + head_q_idx * stride_do_h + q_start * stride_do_m + do_ptr_adj = do_ptr + adj_do + adj_delta = ( + batch_idx * stride_delta_b + + head_q_idx * stride_delta_h + + q_start * stride_delta_m + ) + m_ptr_adj = m_ptr + adj_delta + delta_ptr_adj = delta_ptr + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = ( + philox_offset_base + + batch_idx * stride_dropout_b + + head_q_idx * stride_dropout_h + ) + dropout_offset = ( + dropout_mask + batch_idx * stride_dropout_b + head_q_idx * stride_dropout_h + ) + + MASK_BLOCK_M: tl.constexpr = BLOCK_M // BLK_SLICE_FACTOR + # bound the masked operation to q len so it does not have to wast cycles + len_m = min(len_m, seqlen_q) + num_steps = tl.cdiv(len_m, MASK_BLOCK_M) + + # when q < k, we may skip the initial masked op + if seq_k_blk_idx < num_blocks_skip: + num_steps = 0 + + if IS_FP8: + descale_q = tl.load(descale_q_ptr + batch_idx * stride_descale_q_z + head_q_idx) + descale_k = tl.load(descale_k_ptr + batch_idx * 
stride_descale_k_z + head_k_idx) + descale_v = tl.load(descale_v_ptr + batch_idx * stride_descale_v_z + head_k_idx) + descale_do = tl.load( + descale_do_ptr + batch_idx * stride_descale_do_z + head_q_idx + ) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + # if unaligned start_m is negative, the current N-tile has no block on the + # diagonal of causal mask, so everything have no causal mask + dk, dv = _bwd_dkdvdq_inner( + dk, + dv, # output tensors + q_ptr_adj, + k, + v, + do_ptr_adj, + dq_ptr_adj, + m_ptr_adj, + delta_ptr_adj, + sm_scale, # input tensors + stride_q_m, + stride_q_k, # strides for q + stride_dq_m, + stride_dq_k, # strides for q + stride_do_m, + stride_do_k, # strides for o + stride_dropout_m, + stride_dropout_n, # strides for dropout + stride_delta_m, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, # + seqlen_q, + seqlen_k, # max sequence length for q and k + start_n, + start_m, + num_steps, # iteration numbers + descale_q, + descale_k, + descale_v, + descale_do, # fp8 descale factors from user + MASK_BLOCK_M, + BLOCK_N, # block dim + BLOCK_D_MODEL, + BLOCK_D_MODEL_POW2, # head dim + MASK=True, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + workgroup_id=seq_k_blk_idx, + ) + + start_m += num_steps * MASK_BLOCK_M + num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M) + + dk, dv = _bwd_dkdvdq_inner( + dk, + dv, # output tensors + q_ptr_adj, + k, + v, + do_ptr_adj, + dq_ptr_adj, + m_ptr_adj, + delta_ptr_adj, + sm_scale, # input tensors + stride_q_m, + stride_q_k, # strides for q + stride_dq_m, + stride_dq_k, # strides for dq + stride_do_m, + stride_do_k, # strides for o + stride_dropout_m, + stride_dropout_n, # strides for dropout + stride_delta_m, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, # + seqlen_q, + seqlen_k, # max sequence length for q and k + start_n, + start_m, + num_steps, # iteration numbers + descale_q, + 
descale_k, + descale_v, + descale_do, # fp8 descale factors from user + BLOCK_M, + BLOCK_N, # block dim + BLOCK_D_MODEL, + BLOCK_D_MODEL_POW2, # head dim + MASK=False, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + workgroup_id=seq_k_blk_idx, + ) + + # Write back dV and dK. + offs_dkdv = ( + batch_idx * stride_dk_b + + head_k_idx * stride_dk_h + + k_start * stride_dk_n + + offs_n[:, None] * stride_dk_n + + offs_k[None, :] * stride_dk_k + ) + tl.atomic_add(dv_ptr + offs_dkdv, dv, mask=mask_kv, sem="relaxed") + dk *= sm_scale + tl.atomic_add(dk_ptr + offs_dkdv, dk, mask=mask_kv, sem="relaxed") + + +@triton.jit +def _bwd_kernel_dkdvdq_noncausal( + Q, + K, + V, + sm_scale, + DO, + DK, + DV, + DQ, + M, + Delta, + stride_qb_in, + stride_qh_in, + stride_qm_in, + stride_qk_in, + stride_kb_in, + stride_kh_in, + stride_kn_in, + stride_kk_in, + stride_vb_in, + stride_vh_in, + stride_vn_in, + stride_vk_in, + stride_dkb_in, + stride_dkh_in, + stride_dkn_in, + stride_dkk_in, + stride_dqb_in, + stride_dqh_in, + stride_dqm_in, + stride_dqk_in, + stride_deltab_in, + stride_deltah_in, + stride_deltam_in, + stride_dob_in, + stride_doh_in, + stride_dom_in, + stride_dok_in, + stride_dropoutb_in, + stride_dropouth_in, + stride_dropoutm_in, + stride_dropoutn_in, + stride_descale_q_z_in, + stride_descale_k_z_in, + stride_descale_v_z_in, + stride_descale_do_z_in, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + descale_q_ptr, + descale_k_ptr, + descale_v_ptr, + descale_do_ptr, + NUM_Q_HEADS: tl.constexpr, + NUM_K_HEADS: tl.constexpr, + BATCH, + NUM_K_PIDS, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + NUM_SMS: tl.constexpr, + 
USE_INT64_STRIDES: tl.constexpr, +): + if USE_INT64_STRIDES: + stride_qb = tl.cast(stride_qb_in, tl.int64) + stride_qh = tl.cast(stride_qh_in, tl.int64) + stride_qm = tl.cast(stride_qm_in, tl.int64) + stride_qk = tl.cast(stride_qk_in, tl.int64) + stride_kb = tl.cast(stride_kb_in, tl.int64) + stride_kh = tl.cast(stride_kh_in, tl.int64) + stride_kn = tl.cast(stride_kn_in, tl.int64) + stride_kk = tl.cast(stride_kk_in, tl.int64) + stride_vb = tl.cast(stride_vb_in, tl.int64) + stride_vh = tl.cast(stride_vh_in, tl.int64) + stride_vn = tl.cast(stride_vn_in, tl.int64) + stride_vk = tl.cast(stride_vk_in, tl.int64) + stride_dkb = tl.cast(stride_dkb_in, tl.int64) + stride_dkh = tl.cast(stride_dkh_in, tl.int64) + stride_dkn = tl.cast(stride_dkn_in, tl.int64) + stride_dkk = tl.cast(stride_dkk_in, tl.int64) + stride_dqb = tl.cast(stride_dqb_in, tl.int64) + stride_dqh = tl.cast(stride_dqh_in, tl.int64) + stride_dqm = tl.cast(stride_dqm_in, tl.int64) + stride_dqk = tl.cast(stride_dqk_in, tl.int64) + stride_deltab = tl.cast(stride_deltab_in, tl.int64) + stride_deltah = tl.cast(stride_deltah_in, tl.int64) + stride_deltam = tl.cast(stride_deltam_in, tl.int64) + stride_dob = tl.cast(stride_dob_in, tl.int64) + stride_doh = tl.cast(stride_doh_in, tl.int64) + stride_dom = tl.cast(stride_dom_in, tl.int64) + stride_dok = tl.cast(stride_dok_in, tl.int64) + stride_dropoutb = tl.cast(stride_dropoutb_in, tl.int64) + stride_dropouth = tl.cast(stride_dropouth_in, tl.int64) + stride_dropoutm = tl.cast(stride_dropoutm_in, tl.int64) + stride_dropoutn = tl.cast(stride_dropoutn_in, tl.int64) + if IS_FP8: + stride_descale_q_z = tl.cast(stride_descale_q_z_in, tl.int64) + stride_descale_k_z = tl.cast(stride_descale_k_z_in, tl.int64) + stride_descale_v_z = tl.cast(stride_descale_v_z_in, tl.int64) + stride_descale_do_z = tl.cast(stride_descale_do_z_in, tl.int64) + else: + stride_qb = stride_qb_in + stride_qh = stride_qh_in + stride_qm = stride_qm_in + stride_qk = stride_qk_in + stride_kb = stride_kb_in + 
stride_kh = stride_kh_in + stride_kn = stride_kn_in + stride_kk = stride_kk_in + stride_vb = stride_vb_in + stride_vh = stride_vh_in + stride_vn = stride_vn_in + stride_vk = stride_vk_in + stride_dkb = stride_dkb_in + stride_dkh = stride_dkh_in + stride_dkn = stride_dkn_in + stride_dkk = stride_dkk_in + stride_dqb = stride_dqb_in + stride_dqh = stride_dqh_in + stride_dqm = stride_dqm_in + stride_dqk = stride_dqk_in + stride_deltab = stride_deltab_in + stride_deltah = stride_deltah_in + stride_deltam = stride_deltam_in + stride_dob = stride_dob_in + stride_doh = stride_doh_in + stride_dom = stride_dom_in + stride_dok = stride_dok_in + stride_dropoutb = stride_dropoutb_in + stride_dropouth = stride_dropouth_in + stride_dropoutm = stride_dropoutm_in + stride_dropoutn = stride_dropoutn_in + stride_descale_q_z = stride_descale_q_z_in + stride_descale_k_z = stride_descale_k_z_in + stride_descale_v_z = stride_descale_v_z_in + stride_descale_do_z = stride_descale_do_z_in + + # workgroup id + wid = tl.program_id(0) # 0, ..., NUM_K_PIDS * BATCH * NUM_K_HEADS - 1 + + # Workgroups get launched first along batch dim, then in head_k dim, and then in seq k block dim + # This is in order to avoid contention for the tl.atomic_add (inside _bwd_dkdvdq_inner) that happens between workgroups that share the same batch and head_k. 
+ bid = wid % BATCH + hkid = wid // BATCH % NUM_K_HEADS + pid = wid // (BATCH * NUM_K_HEADS) % NUM_K_PIDS + + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + + start_n = pid * BLOCK_N + + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + offs_n = start_n + tl.arange(0, BLOCK_N) + mask_kv = offs_n[:, None] < seqlen_k + PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2 + if PADDED_HEAD: + mask_kv &= offs_k < BLOCK_D_MODEL + + GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS + adj_k = ( + bid * stride_kb + + hkid * stride_kh + + k_start * stride_kn + + offs_n[:, None] * stride_kn + + offs_k[None, :] * stride_kk + ) + adj_v = ( + bid * stride_vb + + hkid * stride_vh + + k_start * stride_vn + + offs_n[:, None] * stride_vn + + offs_k[None, :] * stride_vk + ) + + k = tl.load(K + adj_k, mask=mask_kv, other=0.0) + v = tl.load(V + adj_v, mask=mask_kv, other=0.0) + + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm + + Q_ptr = Q + adj_q + DQ_ptr = DQ + adj_dq + + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + DO_ptr = DO + adj_do + adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + M_ptr = M + adj_delta + Delta_ptr = Delta + adj_delta + + # dropout + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = ( + philox_offset + bid * stride_dropoutb + hqid * stride_dropouth + ) + dropout_offset = ( + dropout_mask + bid * stride_dropoutb + 
@functools.lru_cache(maxsize=1024)
def _get_config():
    """Return the ``"bkwd_fused"`` section of the device's default MHA config.

    The file ``{AITER_TRITON_CONFIGS_PATH}/{dev}-MHA-DEFAULT.json`` (``dev``
    comes from ``arch_info.get_device()``) is parsed on first use and the
    parsed dictionary is kept on the function object as
    ``_get_config._config_dict``.

    Raises:
        OSError: the config file cannot be opened.
        json.JSONDecodeError: the file is not valid JSON.
        KeyError: the file has no ``"bkwd_fused"`` entry.
    """
    if not hasattr(_get_config, "_config_dict"):
        dev = arch_info.get_device()
        fpath = f"{AITER_TRITON_CONFIGS_PATH}/{dev}-MHA-DEFAULT.json"
        with open(fpath, "r", encoding="utf-8") as file:
            # Publish the cache only after a successful parse. Storing an
            # empty placeholder first would make a failed load look cached and
            # turn every later call into a misleading KeyError.
            _get_config._config_dict = json.load(file)

    return _get_config._config_dict["bkwd_fused"]


def flash_attn_fused_backward(
    do: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    o: torch.Tensor,
    softmax_lse: torch.Tensor,
    dq: torch.Tensor,
    dk: torch.Tensor,
    dv: torch.Tensor,
    dbias: torch.Tensor,
    sm_scale: float,
    alibi_slopes: Optional[torch.Tensor],
    causal: bool,
    cu_seqlens_q: Optional[torch.Tensor],
    cu_seqlens_k: Optional[torch.Tensor],
    max_seqlen_q: int,
    max_seqlen_k: int,
    dropout_p: float,
    philox_seed: Optional[int] = 0,
    philox_offset: Optional[int] = 0,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
    descale_v: Optional[torch.Tensor] = None,
    descale_do: Optional[torch.Tensor] = None,
    USE_INT64_STRIDES: Optional[bool] = False,
    config: Optional[Dict[str, dict]] = None,
):
    """Run the fused attention backward pass and return ``delta``.

    Two launches happen here: ``_bwd_preprocess`` fills ``delta`` (the row sum
    of ``dO * O``), then a single kernel computes ``dk``, ``dv`` and ``dq``
    together, using atomics for the ``dq`` accumulation.

    Layouts:
        * ``cu_seqlens_q is None``: q/k/v/o/do/dq/dk are ``bshd``
          (``[batch, seq_len, num_head, head_dim]``).
        * otherwise (varlen): they are ``thd``
          (``[total tokens, num_head, head_dim]``) and the batch stride passed
          to the kernels is 0.

    Args:
        do, q, k, v, o: upstream gradient and forward tensors.
        softmax_lse: forward log-sum-exp; ``delta`` is allocated with
            ``torch.zeros_like(softmax_lse)`` and both are addressed with the
            same strides inside the kernels.
        dq, dk, dv: output buffers, written in place.
        dbias: must be ``None``; bias gradients are not supported.
        sm_scale: softmax scale.
        alibi_slopes: accepted for signature compatibility; not used here.
        causal: selects the causal or non-causal kernel.
        cu_seqlens_q, cu_seqlens_k: cumulative sequence lengths (varlen only).
        max_seqlen_q, max_seqlen_k: maximum sequence lengths.
        dropout_p: dropout probability; dropout is enabled when ``> 0``.
        philox_seed, philox_offset: RNG state forwarded to the kernels.
        descale_q, descale_k, descale_v, descale_do: FP8 descale tensors,
            required when ``q`` is an FP8 tensor.
        USE_INT64_STRIDES: forwarded to the kernels as a constexpr.
        config: kernel configuration; defaults to ``_get_config()``. It is
            indexed with ``"preprocess_kernel"``, ``"dkdvdq_kernel_N64"`` and
            ``"dkdvdq_kernel_N128"``.

    Returns:
        The ``delta`` tensor produced by the preprocess kernel.

    Raises:
        ValueError: ``dbias`` is given, or ``q`` is FP8 and a descale tensor
            is missing.

    NOTE(review): ``dq`` is accumulated with ``tl.atomic_add`` and the causal
    kernel also adds into ``dk``/``dv`` atomically, but nothing in this
    function zeroes those buffers -- presumably the caller passes them
    zero-initialised; confirm.
    NOTE(review): no ``dv`` strides are passed; the kernels address ``dv``
    with ``dk``'s strides, so the two must share a layout. Likewise the
    preprocess launch only receives ``o``'s strides -- presumably ``do`` must
    match ``o``; confirm against ``_bwd_preprocess``.
    """
    if dbias is not None:
        raise ValueError("Bias is not supported yet in the Triton Backend")

    IS_FP8 = _is_fp8(q)
    if IS_FP8:
        # Fail early with a clear message instead of an AttributeError on
        # ``None.stride`` below.
        if any(t is None for t in (descale_q, descale_k, descale_v, descale_do)):
            raise ValueError(
                "descale_q, descale_k, descale_v and descale_do are required "
                "for FP8 inputs"
            )
        FP8_MAX = torch.finfo(q.dtype).max
        descale_strides = (
            descale_q.stride(0),
            descale_k.stride(0),
            descale_v.stride(0),
            descale_do.stride(0),
        )
    else:
        FP8_MAX = None
        descale_strides = (None, None, None, None)

    IS_VARLEN = cu_seqlens_q is not None

    # Strides are always handed to the kernels in (batch, head, seq, dim) order.
    if IS_VARLEN:
        # Layout for q,k,v is thd ie [total tokens, num_head, head_dim]
        batch = len(cu_seqlens_q) - 1
        num_q_heads, head_sz = q.shape[1], q.shape[2]
        num_k_heads = k.shape[1]
        q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
        k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
        v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
        o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
        dq_strides = (0, dq.stride(1), dq.stride(0), dq.stride(2))
        dk_strides = (0, dk.stride(1), dk.stride(0), dk.stride(2))
        do_strides = (0, do.stride(1), do.stride(0), do.stride(2))
    else:
        # Layout for q,k,v is bshd ie [batch, seq_len, num_head, head_dim]
        batch, _, num_q_heads, head_sz = q.shape
        num_k_heads = k.shape[2]
        q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
        k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
        v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
        o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
        dq_strides = (dq.stride(0), dq.stride(2), dq.stride(1), dq.stride(3))
        dk_strides = (dk.stride(0), dk.stride(2), dk.stride(1), dk.stride(3))
        do_strides = (do.stride(0), do.stride(2), do.stride(1), do.stride(3))

    # Head dim is padded to a power of two, with a floor of 16.
    BLOCK_D_MODEL_POW2 = triton.next_power_of_2(head_sz)
    BLOCK_D_MODEL_POW2 = max(BLOCK_D_MODEL_POW2, 16)

    # init delta
    delta = torch.zeros_like(softmax_lse)
    if IS_VARLEN:
        # (batch stride, head stride, token stride); the batch stride is 0
        # because sequences are packed along the token dimension.
        delta_strides = (0, delta.stride(1), delta.stride(0))
    else:
        # [batch, num_q_heads, seqlen_q]
        delta_strides = delta.stride()

    # preprocess
    # compute D(delta) = rowsum(dO*O). Note, multiplication is element-wise.
    if config is None:
        config = _get_config()

    pre_block = config["preprocess_kernel"]["PRE_BLOCK"]
    pre_grid = (
        triton.cdiv(max_seqlen_q, pre_block),
        batch,
        num_q_heads,
    )

    _bwd_preprocess[pre_grid](
        o,
        do,
        delta,
        *o_strides,
        *delta_strides,
        descale_strides[3],
        cu_seqlens_q,
        max_seqlen_q,
        descale_do,
        BLOCK_M=pre_block,
        BLOCK_D_MODEL=head_sz,
        BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2,
        IS_VARLEN=IS_VARLEN,
        IS_FP8=IS_FP8,
    )

    # dropout_mask: its strides feed the kernels' philox offset computation.
    use_dropout = dropout_p > 0.0
    if use_dropout:
        dropout_mask = torch.zeros(
            (batch, num_q_heads, max_seqlen_q, max_seqlen_k),
            device=q.device,
            dtype=torch.float32,
        )
        dropout_strides = dropout_mask.stride()
    else:
        dropout_mask = None
        dropout_strides = (0, 0, 0, 0)

    # Fuses dk,dv and dq computations into one kernel using atomics
    if BLOCK_D_MODEL_POW2 > 160 or q.dtype == torch.float32:
        config_dkdvdq = config["dkdvdq_kernel_N64"]
    else:
        config_dkdvdq = config["dkdvdq_kernel_N128"]

    num_k_pids = triton.cdiv(max_seqlen_k, config_dkdvdq["BLOCK_N"])
    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count

    # Both kernels take the same argument list; only the kernel and the grid
    # size differ.
    if causal:
        # one program per (batch, q head, K/V block)
        kernel = _bwd_kernel_dkdvdq_causal
        grid_dkdvdq = (batch * num_q_heads * num_k_pids,)
    else:
        # in non causal inner loop over grouped q heads
        kernel = _bwd_kernel_dkdvdq_noncausal
        grid_dkdvdq = (batch * num_k_heads * num_k_pids,)

    kernel[grid_dkdvdq](
        q,
        k,
        v,
        sm_scale,
        do,
        dk,
        dv,
        dq,
        softmax_lse,
        delta,
        *q_strides,
        *k_strides,
        *v_strides,
        *dk_strides,
        *dq_strides,
        *delta_strides,
        *do_strides,
        *dropout_strides,
        *descale_strides,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_mask,
        dropout_p,
        philox_seed,
        philox_offset,
        descale_q,
        descale_k,
        descale_v,
        descale_do,
        NUM_Q_HEADS=num_q_heads,
        NUM_K_HEADS=num_k_heads,
        BATCH=batch,
        NUM_K_PIDS=num_k_pids,
        BLOCK_D_MODEL=head_sz,
        BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2,
        ENABLE_DROPOUT=use_dropout,
        IS_VARLEN=IS_VARLEN,
        IS_FP8=IS_FP8,
        FP8_MAX=FP8_MAX,
        NUM_SMS=NUM_SMS,
        USE_INT64_STRIDES=USE_INT64_STRIDES,
        **config_dkdvdq,
    )

    return delta
from typing import Optional, Dict
import functools
import json
import torch
import triton  # type: ignore
import triton.language as tl  # type: ignore
import aiter.ops.triton.utils.arch_info as arch_info
from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH
from aiter.ops.triton.utils.mha_kernel_utils import (
    _compute_fp8_scaling_factors,
    _is_fp8,
)


# NOTE: triton fails to import tl.constexprs so create them here for the file
DROPOUT_USE_PYTORCH = False
DROPOUT_DUMP = False

tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH)
tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP)


# This function computes delta = rowsum(O * dO) given output Out and gradient DO.
# Here is the I/O shape:
# Out: (batch, nhead_q, max_seqlens_q, headDim)
# DO: (batch, nhead_q, max_seqlens_q, headDim)
# Delta: (batch, nheads_q, max_seqlens_q), same as softmax_lse
#
# Grid: axis 0 = block of BLOCK_M query rows, axis 1 = batch, axis 2 = head.
# Out and DO are addressed with the same stride_o_* values.
# In the FP8 path DO is multiplied by the per-(batch, head) descale factor
# loaded from descale_do_ptr before the reduction.
@triton.jit
def _bwd_preprocess(
    o_ptr,
    do_ptr,  # noqa: E741
    delta_ptr,
    stride_o_b,
    stride_o_h,
    stride_o_m,
    stride_o_k,
    stride_delta_b,
    stride_delta_h,
    stride_delta_m,
    stride_descale_do_z,
    cu_seqlens_q,
    max_seqlen_q,
    descale_do_ptr,
    BLOCK_M: tl.constexpr,
    BLOCK_D_MODEL: tl.constexpr,
    BLOCK_D_MODEL_POW2: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_FP8: tl.constexpr,
):
    pid_m = tl.program_id(0)  # seqlen
    bid = tl.program_id(1)  # batch
    hid = tl.program_id(2)  # head

    # Handle varlen: the dense defaults are overridden from cu_seqlens_q.
    q_start = 0
    seqlen_q = max_seqlen_q
    if IS_VARLEN:
        q_start = tl.load(cu_seqlens_q + bid)
        q_end = tl.load(cu_seqlens_q + bid + 1)
        seqlen_q = q_end - q_start

    # Compute offsets
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_k = tl.arange(0, BLOCK_D_MODEL_POW2)

    # Offset O/DO by batch, head and q_start
    offs = (
        bid * stride_o_b
        + hid * stride_o_h
        + q_start * stride_o_m
        + offs_m[:, None] * stride_o_m
        + offs_k[None, :] * stride_o_k
    )

    # create masks: rows past seqlen_q and, when the head dim is padded to a
    # power of two, the padding columns are excluded.
    mask_m = offs_m < seqlen_q
    mask = mask_m[:, None]
    PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2
    if PADDED_HEAD:
        mask &= offs_k[None, :] < BLOCK_D_MODEL

    # load [BLOCK_M, BLOCK_D_MODEL_POW2]
    o = tl.load(o_ptr + offs, mask=mask, other=0.0)
    do = tl.load(do_ptr + offs, mask=mask, other=0.0)

    # compute and write-back to delta
    if IS_FP8:
        descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hid)

        # NOTE: do is in the fp8 range and o is not in fp8
        delta = tl.sum(o.to(tl.float32) * (do.to(tl.float32) * descale_do), axis=1)
    else:
        delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1)

    offs_delta = (
        bid * stride_delta_b
        + hid * stride_delta_h
        + q_start * stride_delta_m
        + offs_m * stride_delta_m
    )
    tl.store(delta_ptr + offs_delta, delta, mask=mask_m)


# The main inner-loop logic for computing dK and dV.
#
# For one fixed (BLOCK_N, HEAD_DIM) tile of k/v, walks num_steps blocks of
# BLOCK_M query rows starting at start_m. Each step recomputes the transposed
# probabilities pT from k, q and the saved row statistics M, then accumulates
#   dv += pT @ dO
#   dk += dsT @ q        with dsT = pT * (dpT - delta)
# The updated (dk, dv) accumulators are returned. dk is NOT multiplied by
# sm_scale here.
@triton.jit
def _bwd_dkdv_inner(
    dk,
    dv,  # output
    Q,
    k,
    v,
    DO,
    M,
    D,
    sm_scale,  # input tensor
    stride_qm,
    stride_qk,
    stride_dom,
    stride_dok,
    stride_dropoutm,
    stride_dropoutn,
    stride_deltam,
    BLOCK_M: tl.constexpr,  # 16
    BLOCK_N: tl.constexpr,  # 128
    HEAD_DIM: tl.constexpr,  #
    ACTUAL_HEAD_DIM: tl.constexpr,  #
    dropout_p,
    philox_seed,
    batch_philox_offset,
    dropout_offset,
    alibi_slope,
    seqlen_q,
    seqlen_k,  # max sequence length for q and k
    # Filled in by the wrapper.
    start_n,
    start_m,
    num_steps,  # iteration numbers
    descale_q,
    descale_k,
    descale_v,
    descale_do,  # fp8 descale factors from user
    MASK: tl.constexpr,  # causal masking, only apply to tiles on mask diagonal
    ENABLE_DROPOUT: tl.constexpr,  # activate dropout
    USE_ALIBI: tl.constexpr,
    USE_EXP2: tl.constexpr,  # activate exp2
    IS_FP8: tl.constexpr,
    FP8_MAX: tl.constexpr,
    DEBUG_TRITON: tl.constexpr,
    DEBUG_TRITON_DETAIL: tl.constexpr,
):
    # if HEAD_DIM is padded
    PADDED_HEAD: tl.constexpr = ACTUAL_HEAD_DIM != HEAD_DIM
    delta_qk = seqlen_q - seqlen_k
    offs_m = start_m + tl.arange(0, BLOCK_M)  # start_m + (0, 15)
    offs_n = start_n + tl.arange(0, BLOCK_N)  # start_n + (0, 127)
    offs_k = tl.arange(0, HEAD_DIM)
    # mask to make sure not OOB of seqlen_k
    mask_n = offs_n < seqlen_k
    # Q and DO are (seqlen_q, head_dim)
    # qT_ptrs = (1, BLOCK_M) + (HEAD_DIM, 1), transpose of q
    qT_ptrs = Q + offs_m[None, :] * stride_qm + offs_k[:, None] * stride_qk
    # do_ptrs = (BLOCK_M, 1) + (1, HEAD_DIM), NOT transposed
    do_ptrs = DO + offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok
    # BLOCK_N must be a multiple of BLOCK_M, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N % BLOCK_M == 0)
    curr_m = start_m
    step_m = BLOCK_M
    curr_philox_offset = batch_philox_offset
    curr_dropout_offset = dropout_offset
    RCP_LN2: tl.constexpr = 1.4426950408889634  # = 1.0 / ln(2)

    for blk_idx in range(num_steps):
        if DEBUG_TRITON:
            print(f"iter {blk_idx}: curr_m = {curr_m}")  # noqa: E701
        offs_m = curr_m + tl.arange(0, BLOCK_M)
        # update the mask because offs_m advanced
        mask_m = offs_m < seqlen_q
        mask_qT = mask_m[None, :]
        mask_do = mask_m[:, None]
        mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q)
        if PADDED_HEAD:
            mask_qT &= offs_k[:, None] < ACTUAL_HEAD_DIM
            mask_do &= offs_k[None, :] < ACTUAL_HEAD_DIM
        qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0)
        # generate dropout mask
        if ENABLE_DROPOUT:
            # NOTE: dropout is transposed because it is used to mask pT
            philox_offs = (
                curr_philox_offset
                + offs_m[None, :] * stride_dropoutm
                + offs_n[:, None] * stride_dropoutn
            )
            if tl_DROPOUT_USE_PYTORCH:
                dropout_offs = (
                    offs_m[None, :] * stride_dropoutm
                    + offs_n[:, None] * stride_dropoutn
                )
                dropout_mask = tl.load(curr_dropout_offset + dropout_offs, mask=mask_nm)
            else:
                rand_vals = tl.rand(philox_seed, philox_offs)
                dropout_mask = rand_vals > dropout_p
            dropout_scale = 1.0 / (1 - dropout_p)
        # Load m before computing qk to reduce pipeline stall.
        m = tl.load(M + offs_m * stride_deltam, mask=mask_m, other=0.0)
        if IS_FP8:
            qkT = tl.dot(k, qT) * descale_q * descale_k
        else:
            qkT = tl.dot(k, qT)
        qkT_scaled = qkT * sm_scale

        if USE_ALIBI:
            relative_pos_block = offs_n[:, None] + seqlen_q - seqlen_k - offs_m[None, :]
            alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block)
            qkT_scaled += alibi_block

        if DEBUG_TRITON_DETAIL:
            if start_n == 256:
                print(f"qT: {qT.shape}\n", qT)
                print(f"k: {k.shape}\n", k)
                print(f"qkT scaled: {qkT.shape}\n", qkT_scaled)
        # TODO: remove the scaling of m later when we removed re-scaling in fwd
        if USE_EXP2:
            pT = tl.math.exp2(qkT_scaled * RCP_LN2 - m[None, :] * RCP_LN2)
        else:
            pT = tl.math.exp(qkT_scaled - m[None, :])

        # Autoregressive masking.
        if MASK:
            # offset offs_m with delta_qk since the causal mask starts at
            # bottom right of the (seqlen_q, seqlen_k) matrix
            causal_mask = (offs_m[None, :] - delta_qk) >= offs_n[:, None]
            mask = causal_mask & mask_nm
            if DEBUG_TRITON_DETAIL:
                if start_n == 256:
                    print(f"causal_mask: {causal_mask.shape}\n", causal_mask)
                    print(
                        f"qkT after causal: {qkT.shape}\n",
                        tl.where(causal_mask, qkT * sm_scale, 0.0),
                    )
            pT = tl.where(mask, pT, 0.0)
        do = tl.load(do_ptrs, mask=mask_do, other=0.0)
        # Compute dV.
        if ENABLE_DROPOUT:
            pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale
            if IS_FP8:
                scale_p_dropout, descale_p_dropout = _compute_fp8_scaling_factors(
                    pT_dropout, FP8_MAX
                )
                dv += (
                    tl.dot((pT_dropout * scale_p_dropout).to(do.type.element_ty), do)
                    * descale_p_dropout
                    * descale_do
                )
            else:
                dv += tl.dot(pT_dropout.to(do.type.element_ty), do)
        else:
            if IS_FP8:
                scale_pT, descale_pT = _compute_fp8_scaling_factors(pT, FP8_MAX)
                dv += (
                    tl.dot((pT * scale_pT).to(do.type.element_ty), do)
                    * descale_pT
                    * descale_do
                )
            else:
                dv += tl.dot(pT.to(do.type.element_ty), do)

        if DEBUG_TRITON_DETAIL:
            if start_n == 256:
                print(f"pT: {pT.shape}\n", pT)
        # D (= delta) is pre-divided by ds_scale.
        # `other=0.0` pins the rows beyond seqlen_q to a defined value. A
        # masked load without `other` leaves those lanes undefined, and a
        # stray NaN/Inf there would survive the multiplication by the
        # zero-filled q rows in the dk dot product below.
        Di = tl.load(D + offs_m * stride_deltam, mask=mask_m, other=0.0)
        # Compute dP and dS.
        if IS_FP8:
            dpT = tl.dot(v, tl.trans(do)) * descale_v * descale_do
        else:
            dpT = tl.dot(v, tl.trans(do))
        if ENABLE_DROPOUT:
            dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale
        delta_i = Di[None, :]
        dsT = pT * (dpT - delta_i)
        if IS_FP8:
            scale_dsT, descale_dsT = _compute_fp8_scaling_factors(dsT, FP8_MAX)
            dk += (
                tl.dot((dsT * scale_dsT).to(qT.type.element_ty), tl.trans(qT))
                * descale_dsT
                * descale_q
            )
        else:
            dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT))
        # Increment pointers.
        curr_m += step_m
        qT_ptrs += step_m * stride_qm
        do_ptrs += step_m * stride_dom
    return dk, dv
+ stride_qm, + stride_qk, + stride_kn, + stride_kk, + stride_vn, + stride_vk, + stride_dropoutm, + stride_dropoutn, # stride for dropout + stride_deltam, + seqlen_q, + seqlen_k, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, # + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + alibi_slope, + # Filled in by the wrapper. + start_m, + start_n, + end_n, + num_steps, # + descale_q, + descale_k, + descale_v, + descale_do, # fp8 descale factors from user + MASK: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # if HEAD_DIM is padded + PADDED_HEAD: tl.constexpr = ACTUAL_HEAD_DIM != HEAD_DIM + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_n = start_n + tl.arange(0, BLOCK_N2) + offs_k = tl.arange(0, HEAD_DIM) + + # mask to make sure not OOB of seqlen_q + mask_m = offs_m < seqlen_q + + kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k[:, None] * stride_kk + vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k[:, None] * stride_vk + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(Delta + offs_m * stride_deltam, mask=mask_m, other=0.0) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + curr_n = start_n + step_n = BLOCK_N2 + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) + for blk_idx in range(num_steps): + if DEBUG_TRITON: + print(f"iter {blk_idx}: curr_n = {curr_n}") # noqa: E701 + offs_n = curr_n + tl.arange(0, BLOCK_N2) + # end_n is needed because the end of causal True might not be perfectly + # aligned with the end of the block + mask_n = offs_n < end_n + if DEBUG_TRITON_DETAIL: + print( + f"start_n = {start_n}, end_n = {end_n}, offs_n: {offs_n.shape}\n{offs_n}" + ) # noqa: E701 + if DEBUG_TRITON_DETAIL: + print(f"mask_n: {mask_n.shape}\n{mask_n}") # noqa: E701 + mask_kT = mask_n[None, :] + mask_mn = mask_m[:, None] & (offs_n[None, :] < end_n) + if PADDED_HEAD: + mask_kT &= offs_k[:, None] < ACTUAL_HEAD_DIM + + kT = tl.load(kT_ptrs, mask=mask_kT, other=0.0) + vT = tl.load(vT_ptrs, mask=mask_kT, other=0.0) + + if ENABLE_DROPOUT: + # NOTE: dropout is transposed because it is used to mask pT + philox_offs = ( + curr_philox_offset + + offs_m[:, None] * stride_dropoutm + + offs_n[None, :] * stride_dropoutn + ) + if tl_DROPOUT_USE_PYTORCH: + dropout_offs = ( + offs_m[:, None] * stride_dropoutm + + offs_n[None, :] * stride_dropoutn + ) + dropout_mask = tl.load(curr_dropout_offset + dropout_offs, mask=mask_mn) + else: + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1 / (1 - dropout_p) + + if IS_FP8: + qk = tl.dot(q, kT) * descale_q * descale_k + else: + qk = tl.dot(q, kT) + qk_scaled = qk * sm_scale + + if USE_ALIBI: + relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + qk_scaled += alibi_block + + if DEBUG_TRITON_DETAIL: + print(f"qk scaled: {qk.shape}\n", qk_scaled) # noqa: E701 + if USE_EXP2: + p = tl.math.exp2(qk_scaled * RCP_LN2 - m * RCP_LN2) + else: + p = 
tl.math.exp(qk_scaled - m) + + # Autoregressive masking. + if MASK: + causal_mask = (offs_m[:, None] - delta_qk) >= offs_n[None, :] + mask = causal_mask & mask_mn + p = tl.where(mask, p, 0.0) + # Compute dP and dS. + if IS_FP8: + dp = tl.dot(do, vT) * descale_do * descale_v + else: + dp = tl.dot(do, vT) + if ENABLE_DROPOUT: + dp = tl.where(dropout_mask, dp, 0.0) * dropout_scale + delta_i = Di[:, None] + ds = p * (dp - delta_i) + # Compute dQ. + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + if IS_FP8: + scale_ds, descale_ds = _compute_fp8_scaling_factors(ds, FP8_MAX) + dq += ( + tl.dot((ds * scale_ds).to(kT.type.element_ty), tl.trans(kT)) + * descale_ds + * descale_k + ) + else: + dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) + # Increment pointers. + curr_n += step_n + kT_ptrs += step_n * stride_kn + vT_ptrs += step_n * stride_vn + return dq + + +@triton.jit +def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nheads_q) + Q, + K, + V, + sm_scale, + DO, + DQ, + DK, + DV, + M, + Delta, + stride_qb_in, + stride_qh_in, + stride_qm_in, + stride_qd_in, + stride_kb_in, + stride_kh_in, + stride_kn_in, + stride_kd_in, + stride_vb_in, + stride_vh_in, + stride_vn_in, + stride_vd_in, + stride_dqb_in, + stride_dqh_in, + stride_dqm_in, + stride_dqd_in, + stride_dkb_in, + stride_dkh_in, + stride_dkn_in, + stride_dkd_in, + stride_dvb_in, + stride_dvh_in, + stride_dvn_in, + stride_dvd_in, + stride_deltab_in, + stride_deltah_in, + stride_deltam_in, + stride_dob_in, + stride_doh_in, + stride_dom_in, + stride_dod_in, + stride_dropoutb_in, + stride_dropouth_in, + stride_dropoutm_in, + stride_dropoutn_in, + stride_descale_q_z_in, + stride_descale_k_z_in, + stride_descale_v_z_in, + stride_descale_do_z_in, + stride_az_in, + stride_ah_in, + HQ, + HK, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + Dropout_mask, + dropout_p, + philox_seed, + philox_offset_base_in, + Alibi_slopes, + Descale_q, + Descale_k, + 
Descale_v, + Descale_do, + BLOCK_M1: tl.constexpr, + BLOCK_N1: tl.constexpr, + BLOCK_M2: tl.constexpr, + BLOCK_N2: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + FP8_OUTPUT: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, + USE_INT64_STRIDES: tl.constexpr, +): + if USE_INT64_STRIDES: + stride_qb = tl.cast(stride_qb_in, tl.int64) + stride_qh = tl.cast(stride_qh_in, tl.int64) + stride_qm = tl.cast(stride_qm_in, tl.int64) + stride_qd = tl.cast(stride_qd_in, tl.int64) + stride_kb = tl.cast(stride_kb_in, tl.int64) + stride_kh = tl.cast(stride_kh_in, tl.int64) + stride_kn = tl.cast(stride_kn_in, tl.int64) + stride_kd = tl.cast(stride_kd_in, tl.int64) + stride_vb = tl.cast(stride_vb_in, tl.int64) + stride_vh = tl.cast(stride_vh_in, tl.int64) + stride_vn = tl.cast(stride_vn_in, tl.int64) + stride_vd = tl.cast(stride_vd_in, tl.int64) + stride_dqb = tl.cast(stride_dqb_in, tl.int64) + stride_dqh = tl.cast(stride_dqh_in, tl.int64) + stride_dqm = tl.cast(stride_dqm_in, tl.int64) + stride_dqd = tl.cast(stride_dqd_in, tl.int64) + stride_dkb = tl.cast(stride_dkb_in, tl.int64) + stride_dkh = tl.cast(stride_dkh_in, tl.int64) + stride_dkn = tl.cast(stride_dkn_in, tl.int64) + stride_dkd = tl.cast(stride_dkd_in, tl.int64) + stride_dvb = tl.cast(stride_dvb_in, tl.int64) + stride_dvh = tl.cast(stride_dvh_in, tl.int64) + stride_dvn = tl.cast(stride_dvn_in, tl.int64) + stride_dvd = tl.cast(stride_dvd_in, tl.int64) + stride_deltab = tl.cast(stride_deltab_in, tl.int64) + stride_deltah = tl.cast(stride_deltah_in, tl.int64) + stride_deltam = tl.cast(stride_deltam_in, tl.int64) + stride_dob = tl.cast(stride_dob_in, tl.int64) + stride_doh = tl.cast(stride_doh_in, tl.int64) + stride_dom = tl.cast(stride_dom_in, tl.int64) + stride_dod = 
tl.cast(stride_dod_in, tl.int64) + philox_offset_base = tl.cast(philox_offset_base_in, tl.int64) + stride_dropoutb = tl.cast(stride_dropoutb_in, tl.int64) + stride_dropouth = tl.cast(stride_dropouth_in, tl.int64) + stride_dropoutm = tl.cast(stride_dropoutm_in, tl.int64) + stride_dropoutn = tl.cast(stride_dropoutn_in, tl.int64) + if IS_FP8: + stride_descale_q_z = tl.cast(stride_descale_q_z_in, tl.int64) + stride_descale_k_z = tl.cast(stride_descale_k_z_in, tl.int64) + stride_descale_v_z = tl.cast(stride_descale_v_z_in, tl.int64) + stride_descale_do_z = tl.cast(stride_descale_do_z_in, tl.int64) + stride_az = tl.cast(stride_az_in, tl.int64) + stride_ah = tl.cast(stride_ah_in, tl.int64) + else: + stride_qb = stride_qb_in + stride_qh = stride_qh_in + stride_qm = stride_qm_in + stride_qd = stride_qd_in + stride_kb = stride_kb_in + stride_kh = stride_kh_in + stride_kn = stride_kn_in + stride_kd = stride_kd_in + stride_vb = stride_vb_in + stride_vh = stride_vh_in + stride_vn = stride_vn_in + stride_vd = stride_vd_in + stride_dqb = stride_dqb_in + stride_dqh = stride_dqh_in + stride_dqm = stride_dqm_in + stride_dqd = stride_dqd_in + stride_dkb = stride_dkb_in + stride_dkh = stride_dkh_in + stride_dkn = stride_dkn_in + stride_dkd = stride_dkd_in + stride_dvb = stride_dvb_in + stride_dvh = stride_dvh_in + stride_dvn = stride_dvn_in + stride_dvd = stride_dvd_in + stride_deltab = stride_deltab_in + stride_deltah = stride_deltah_in + stride_deltam = stride_deltam_in + stride_dob = stride_dob_in + stride_doh = stride_doh_in + stride_dom = stride_dom_in + stride_dod = stride_dod_in + philox_offset_base = philox_offset_base_in + stride_dropoutb = stride_dropoutb_in + stride_dropouth = stride_dropouth_in + stride_dropoutm = stride_dropoutm_in + stride_dropoutn = stride_dropoutn_in + stride_descale_q_z = stride_descale_q_z_in + stride_descale_k_z = stride_descale_k_z_in + stride_descale_v_z = stride_descale_v_z_in + stride_descale_do_z = stride_descale_do_z_in + stride_az = 
stride_az_in + stride_ah = stride_ah_in + + # program ids + hkid = tl.program_id(0) + pid = tl.program_id(1) + bid = tl.program_id(2) + if DEBUG_TRITON: + print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}") # noqa: E701 + # figure out varlen start and end + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + delta_qk = seqlen_q - seqlen_k + if DEBUG_TRITON: + print(f"delta_qk = {delta_qk}") # noqa: E701 + PADDED_HEAD: tl.constexpr = ACTUAL_HEAD_DIM != HEAD_DIM + offs_d = tl.arange(0, HEAD_DIM) + GROUP_SIZE: tl.constexpr = HQ // HK + + # align the delta_qk + start_n = pid * BLOCK_N1 + if start_n < seqlen_k: + # This section does dk and dv + dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + + # q > k: diretcly skip all the way until the start of causal block + start_delta_q_gt_k = delta_qk + # q < k: some blocks will have no Masked block, other needs to re-calc + # starting position + # delta_qk is negative so flip it, only multiple of BLOCK_N can skip the + # masked op + num_blocks_skip = -delta_qk // BLOCK_N1 + delta_aligned = (num_blocks_skip + 1) * BLOCK_N1 + delta_qk + start_delta_q_lt_k = delta_aligned // BLOCK_M1 * BLOCK_M1 + if delta_qk >= 0: + start_delta = delta_qk + if DEBUG_TRITON: + print( + f"q >= k: start_delta = delta_qk aligned to BLOCK_M = {start_delta_q_gt_k}" + ) # noqa: E701 + else: + start_delta = start_delta_q_lt_k + if DEBUG_TRITON: + print( + f"q < k: start_delta = residue btw multiple BLOCK_N and delta_qk = {delta_aligned} = aligned to BLOCK_M = {start_delta_q_lt_k}" + ) # noqa: E701 + + offs_n = start_n + tl.arange(0, BLOCK_N1) + # Mask for loading K and V + mask_kv = offs_n[:, None] 
< seqlen_k + if PADDED_HEAD: + mask_d = offs_d < ACTUAL_HEAD_DIM + mask_kv &= mask_d[None, :] + + # K/V tensors not changed for the group + adj_k = ( + bid * stride_kb + + hkid * stride_kh + + k_start * stride_kn + + offs_n[:, None] * stride_kn + + offs_d[None, :] * stride_kd + ) + adj_v = ( + bid * stride_vb + + hkid * stride_vh + + k_start * stride_vn + + offs_n[:, None] * stride_vn + + offs_d[None, :] * stride_vd + ) + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + adj_k, mask=mask_kv, other=0.0) + v = tl.load(V + adj_v, mask=mask_kv, other=0.0) + # If MQA / GQA, set the K and V head offsets appropriately. + # hqid = hkid + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + if delta_qk >= 0: + start_m = start_n + start_delta + len_m = BLOCK_N1 + else: + start_m = max(start_n + delta_qk, 0) + start_m = start_m // BLOCK_M1 * BLOCK_M1 + # because we might shift the masked blocks up, we are deeper into + # the masked out region, so we would potentially increase the total + # steps with masked operation to get out of it + residue_m = max(start_n + delta_qk - start_m, 0) + len_m = BLOCK_N1 + residue_m + if DEBUG_TRITON: + print(f"residue_m = {residue_m}") # noqa: E701 + + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + Q_ptr = Q + adj_q + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + DO_ptr = DO + adj_do + adj_delta = ( + bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + ) + M_ptr = M + adj_delta + Delta_ptr = Delta + adj_delta + + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = ( + 
philox_offset_base + bid * stride_dropoutb + hqid * stride_dropouth + ) + dropout_offset = ( + Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + ) + + if IS_FP8: + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR + # bound the masked operation to q len so it does not have to wast cycles + len_m = min(len_m, seqlen_q) + num_steps = tl.cdiv(len_m, MASK_BLOCK_M1) + # when q < k, we may skip the initial masked op + if pid < num_blocks_skip: + num_steps = 0 + + # if start_m is negative, the current N-tile has no block on the + # diagonal of causal mask, so everything have no causal mask + if DEBUG_TRITON: + print( + f"Masked: start_n: {start_n}; start_m: {start_m}, num_steps: {num_steps}" + ) # noqa: E701 + dk, dv = _bwd_dkdv_inner( + dk, + dv, # output tensors + Q_ptr, + k, + v, + DO_ptr, + M_ptr, + Delta_ptr, + sm_scale, # input tensors + stride_qm, + stride_qd, # strides for q + stride_dom, + stride_dod, # strides for o + stride_dropoutm, + stride_dropoutn, # strides for dropout + stride_deltam, + MASK_BLOCK_M1, + BLOCK_N1, # block dim + HEAD_DIM, + ACTUAL_HEAD_DIM, # head dim + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + alibi_slope, + seqlen_q, + seqlen_k, # max sequence length for q and k + start_n, + start_m, + num_steps, # iteration numbers + descale_q, + descale_k, + descale_v, + descale_do, + MASK=True, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + start_m += num_steps * MASK_BLOCK_M1 + 
num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M1) + end_m = start_m + num_steps * BLOCK_M1 + + if DEBUG_TRITON: + print( + f"start_m after Masked step: {start_m}; num_steps: {num_steps}" + ) # noqa: E701 + if DEBUG_TRITON: + print( + f"unMasked: start_n: {start_n}, start_m: {start_m}, end_m: {end_m}, num_steps: {num_steps}" + ) # noqa: E701 + if DEBUG_TRITON: + print("unMasked") # noqa: E701 + dk, dv = _bwd_dkdv_inner( + dk, + dv, # output tensors + Q_ptr, + k, + v, + DO_ptr, + M_ptr, + Delta_ptr, + sm_scale, # input tensors + stride_qm, + stride_qd, # strides for q + stride_dom, + stride_dod, # strides for o + stride_dropoutm, + stride_dropoutn, # strides for dropout + stride_deltam, + BLOCK_M1, + BLOCK_N1, # block dim + HEAD_DIM, + ACTUAL_HEAD_DIM, # head dim + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + alibi_slope, + seqlen_q, + seqlen_k, # max sequence length for q and k + start_n, + start_m, + num_steps, # iteration numbers + descale_q, + descale_k, + descale_v, + descale_do, + MASK=False, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + # end of GQA/MQA of dkdv + # Write back dV + adj_dv = bid * stride_dvb + hkid * stride_dvh + k_start * stride_dvn + offs_dv = offs_n[:, None] * stride_dvn + offs_d[None, :] * stride_dvd + tl.store(DV + adj_dv + offs_dv, dv, mask=mask_kv) + # write back dk + adj_dk = bid * stride_dkb + hkid * stride_dkh + k_start * stride_dkn + offs_dk = offs_n[:, None] * stride_dkn + offs_d[None, :] * stride_dkd + dk *= sm_scale + tl.store(DK + adj_dk + offs_dk, dk, mask=mask_kv) + + # This part does dq + start_m = pid * BLOCK_M2 + if start_m < seqlen_q: + # seqlen_q > seqlen_k, no need to process these tile for dq + if DEBUG_TRITON: + print( + f"end_n = start_m + BLOCK_M = {start_m} + {BLOCK_M2} = {start_m + BLOCK_M2}" + ) # noqa: E701 + if 
start_m + BLOCK_M2 < delta_qk: + if DEBUG_TRITON: + print( + f"start_m + BLOCK_M2 = {start_m} + {BLOCK_M2} = {start_m + BLOCK_M2} < delta_qk of {delta_qk}" + ) # noqa: E701 + return + + offs_m = start_m + tl.arange(0, BLOCK_M2) + # Mask for loading K and V + mask_q = offs_m[:, None] < seqlen_q + if PADDED_HEAD: + mask_d = offs_d < ACTUAL_HEAD_DIM + mask_q &= mask_d[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + offs_do = offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod + # NOTE: don't assume that the strides for k and v are the same! + K += bid * stride_kb + hkid * stride_kh + k_start * stride_kn + V += bid * stride_vb + hkid * stride_vh + k_start * stride_vn + + # If MQA / GQA, set the K and V head offsets appropriately. + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # seqlen_q < seqlen_k: delta_qk more kv tokens are added at the front + # for every M-tile + end_n = start_m + BLOCK_M2 - delta_qk + # clamp end_n at [0, seqlen_k] + end_n = max(min(end_n, seqlen_k), 0) + if DEBUG_TRITON: + print(f"delta_qk: {delta_qk}; end_n: {end_n}") # noqa: E701 + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + adj_delta = ( + bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + ) + Delta_ptr = Delta + adj_delta + + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = ( + philox_offset_base + bid * stride_dropoutb + hqid * stride_dropouth + ) + dropout_offset = ( + Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + ) + q = 
tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) + do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) + m = tl.load(M + adj_delta + offs_m * stride_deltam, mask=offs_m < seqlen_q) + m = m[:, None] + + MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR + # start can only be 0 at minimum + start_n = max(end_n - BLOCK_M2, 0) + num_steps = tl.cdiv(end_n - start_n, MASK_BLOCK_N2) + + if IS_FP8: + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) + dq = _bwd_dq_inner( + dq, + q, + K, + V, + do, + m, + Delta_ptr, + sm_scale, + stride_qm, + stride_qd, + stride_kn, + stride_kd, + stride_vn, + stride_vd, + stride_dropoutm, + stride_dropoutn, + stride_deltam, + seqlen_q, + seqlen_k, + BLOCK_M2, + MASK_BLOCK_N2, + HEAD_DIM, + ACTUAL_HEAD_DIM, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + alibi_slope, + start_m, + start_n, + end_n, + num_steps, + descale_q, + descale_k, + descale_v, + descale_do, + MASK=True, # + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + end_n -= num_steps * MASK_BLOCK_N2 + num_steps = tl.cdiv(end_n, BLOCK_N2) + start_n = max(end_n - num_steps * BLOCK_N2, 0) + if DEBUG_TRITON: + print( + f"unMasked: start_m: {start_m}, start_n: {start_n}, end_n: {end_n}, num_steps: {num_steps}" + ) # noqa: E701 + dq = _bwd_dq_inner( + dq, + q, + K, + V, + do, + m, + Delta_ptr, + sm_scale, + stride_qm, + stride_qd, + stride_kn, + stride_kd, + stride_vn, + stride_vd, + stride_dropoutm, + stride_dropoutn, + stride_deltam, + seqlen_q, + seqlen_k, + 
BLOCK_M2, + BLOCK_N2, + HEAD_DIM, + ACTUAL_HEAD_DIM, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + alibi_slope, + start_m, + start_n, + end_n, + num_steps, + descale_q, + descale_k, + descale_v, + descale_do, + MASK=False, + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + # Write back dQ. + adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm + offs_dq = offs_m[:, None] * stride_dqm + offs_d[None, :] * stride_dqd + dq *= sm_scale + tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) + # end of GQA/MQA of dq + + +@triton.jit +def bwd_kernel_noncausal( + Q, + K, + V, + sm_scale, + DO, + DQ, + DK, + DV, + M, + Delta, + stride_qb_in, + stride_qh_in, + stride_qm_in, + stride_qd_in, + stride_kb_in, + stride_kh_in, + stride_kn_in, + stride_kd_in, + stride_vb_in, + stride_vh_in, + stride_vn_in, + stride_vd_in, + stride_dqb_in, + stride_dqh_in, + stride_dqm_in, + stride_dqd_in, + stride_dkb_in, + stride_dkh_in, + stride_dkn_in, + stride_dkd_in, + stride_dvb_in, + stride_dvh_in, + stride_dvn_in, + stride_dvd_in, + stride_deltab_in, + stride_deltah_in, + stride_deltam_in, + stride_dob_in, + stride_doh_in, + stride_dom_in, + stride_dod_in, + stride_dropoutb_in, + stride_dropouth_in, + stride_dropoutm_in, + stride_dropoutn_in, + stride_descale_q_z_in, + stride_descale_k_z_in, + stride_descale_v_z_in, + stride_descale_do_z_in, + stride_az_in, + stride_ah_in, + HQ, + HK, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + Dropout_mask, + dropout_p, + philox_seed, + philox_offset_base_in, + Alibi_slopes, + Descale_q, + Descale_k, + Descale_v, + Descale_do, + BLOCK_M1: tl.constexpr, # 32 + BLOCK_N1: tl.constexpr, # 128 + BLOCK_M2: tl.constexpr, # 128 + BLOCK_N2: tl.constexpr, # 32 + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + ENABLE_DROPOUT: 
tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + FP8_OUTPUT: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, + USE_INT64_STRIDES: tl.constexpr, +): + if USE_INT64_STRIDES: + stride_qb = tl.cast(stride_qb_in, tl.int64) + stride_qh = tl.cast(stride_qh_in, tl.int64) + stride_qm = tl.cast(stride_qm_in, tl.int64) + stride_qd = tl.cast(stride_qd_in, tl.int64) + stride_kb = tl.cast(stride_kb_in, tl.int64) + stride_kh = tl.cast(stride_kh_in, tl.int64) + stride_kn = tl.cast(stride_kn_in, tl.int64) + stride_kd = tl.cast(stride_kd_in, tl.int64) + stride_vb = tl.cast(stride_vb_in, tl.int64) + stride_vh = tl.cast(stride_vh_in, tl.int64) + stride_vn = tl.cast(stride_vn_in, tl.int64) + stride_vd = tl.cast(stride_vd_in, tl.int64) + stride_dqb = tl.cast(stride_dqb_in, tl.int64) + stride_dqh = tl.cast(stride_dqh_in, tl.int64) + stride_dqm = tl.cast(stride_dqm_in, tl.int64) + stride_dqd = tl.cast(stride_dqd_in, tl.int64) + stride_dkb = tl.cast(stride_dkb_in, tl.int64) + stride_dkh = tl.cast(stride_dkh_in, tl.int64) + stride_dkn = tl.cast(stride_dkn_in, tl.int64) + stride_dkd = tl.cast(stride_dkd_in, tl.int64) + stride_dvb = tl.cast(stride_dvb_in, tl.int64) + stride_dvh = tl.cast(stride_dvh_in, tl.int64) + stride_dvn = tl.cast(stride_dvn_in, tl.int64) + stride_dvd = tl.cast(stride_dvd_in, tl.int64) + stride_deltab = tl.cast(stride_deltab_in, tl.int64) + stride_deltah = tl.cast(stride_deltah_in, tl.int64) + stride_deltam = tl.cast(stride_deltam_in, tl.int64) + stride_dob = tl.cast(stride_dob_in, tl.int64) + stride_doh = tl.cast(stride_doh_in, tl.int64) + stride_dom = tl.cast(stride_dom_in, tl.int64) + stride_dod = tl.cast(stride_dod_in, tl.int64) + philox_offset_base = tl.cast(philox_offset_base_in, tl.int64) + stride_dropoutb = tl.cast(stride_dropoutb_in, tl.int64) + stride_dropouth = tl.cast(stride_dropouth_in, tl.int64) + stride_dropoutm = 
tl.cast(stride_dropoutm_in, tl.int64) + stride_dropoutn = tl.cast(stride_dropoutn_in, tl.int64) + if IS_FP8: + stride_descale_q_z = tl.cast(stride_descale_q_z_in, tl.int64) + stride_descale_k_z = tl.cast(stride_descale_k_z_in, tl.int64) + stride_descale_v_z = tl.cast(stride_descale_v_z_in, tl.int64) + stride_descale_do_z = tl.cast(stride_descale_do_z_in, tl.int64) + stride_az = tl.cast(stride_az_in, tl.int64) + stride_ah = tl.cast(stride_ah_in, tl.int64) + else: + stride_qb = stride_qb_in + stride_qh = stride_qh_in + stride_qm = stride_qm_in + stride_qd = stride_qd_in + stride_kb = stride_kb_in + stride_kh = stride_kh_in + stride_kn = stride_kn_in + stride_kd = stride_kd_in + stride_vb = stride_vb_in + stride_vh = stride_vh_in + stride_vn = stride_vn_in + stride_vd = stride_vd_in + stride_dqb = stride_dqb_in + stride_dqh = stride_dqh_in + stride_dqm = stride_dqm_in + stride_dqd = stride_dqd_in + stride_dkb = stride_dkb_in + stride_dkh = stride_dkh_in + stride_dkn = stride_dkn_in + stride_dkd = stride_dkd_in + stride_dvb = stride_dvb_in + stride_dvh = stride_dvh_in + stride_dvn = stride_dvn_in + stride_dvd = stride_dvd_in + stride_deltab = stride_deltab_in + stride_deltah = stride_deltah_in + stride_deltam = stride_deltam_in + stride_dob = stride_dob_in + stride_doh = stride_doh_in + stride_dom = stride_dom_in + stride_dod = stride_dod_in + philox_offset_base = philox_offset_base_in + stride_dropoutb = stride_dropoutb_in + stride_dropouth = stride_dropouth_in + stride_dropoutm = stride_dropoutm_in + stride_dropoutn = stride_dropoutn_in + stride_descale_q_z = stride_descale_q_z_in + stride_descale_k_z = stride_descale_k_z_in + stride_descale_v_z = stride_descale_v_z_in + stride_descale_do_z = stride_descale_do_z_in + stride_az = stride_az_in + stride_ah = stride_ah_in + + # program ids + hkid = tl.program_id(0) + pid = tl.program_id(1) + bid = tl.program_id(2) + if DEBUG_TRITON: + print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}") # noqa: E701 + # figure out varlen 
start and end + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + PADDED_HEAD: tl.constexpr = ACTUAL_HEAD_DIM != HEAD_DIM + offs_d = tl.arange(0, HEAD_DIM) + GROUP_SIZE: tl.constexpr = HQ // HK + + start_n = pid * BLOCK_N1 + if start_n < seqlen_k: + dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + + offs_n = start_n + tl.arange(0, BLOCK_N1) + # Mask for loading K and V + mask_kv = offs_n[:, None] < seqlen_k + if PADDED_HEAD: + mask_d = offs_d < ACTUAL_HEAD_DIM + mask_kv &= mask_d[None, :] + # NOTE: don't assume that the strides for k and v are the same! + # K/V tensors not changed for the group + adj_k = ( + bid * stride_kb + + hkid * stride_kh + + k_start * stride_kn + + offs_n[:, None] * stride_kn + + offs_d[None, :] * stride_kd + ) + adj_v = ( + bid * stride_vb + + hkid * stride_vh + + k_start * stride_vn + + offs_n[:, None] * stride_vn + + offs_d[None, :] * stride_vd + ) + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + adj_k, mask=mask_kv, other=0.0) + v = tl.load(V + adj_v, mask=mask_kv, other=0.0) + # If MQA / GQA, set the K and V head offsets appropriately. 
+ for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + Q_ptr = Q + adj_q + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + DO_ptr = DO + adj_do + adj_delta = ( + bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + ) + M_ptr = M + adj_delta + Delta_ptr = Delta + adj_delta + + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = ( + philox_offset_base + bid * stride_dropoutb + hqid * stride_dropouth + ) + dropout_offset = ( + Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + ) + + if IS_FP8: + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + # because there is no causal, we always start from the beginning + start_m = 0 + num_steps = tl.cdiv(seqlen_q, BLOCK_M1) + dk, dv = _bwd_dkdv_inner( + dk, + dv, # output tensors + Q_ptr, + k, + v, + DO_ptr, + M_ptr, + Delta_ptr, + sm_scale, # input tensors + stride_qm, + stride_qd, # strides for q + stride_dom, + stride_dod, # strides for o + stride_dropoutm, + stride_dropoutn, # strides for dropout + stride_deltam, + BLOCK_M1, + BLOCK_N1, # block dim + HEAD_DIM, + ACTUAL_HEAD_DIM, # head dim + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, # + alibi_slope, + seqlen_q, + seqlen_k, # max sequence length for q and k + 
start_n, + start_m, + num_steps, # iteration numbers + descale_q, + descale_k, + descale_v, + descale_do, # fp8 descale factors from user + MASK=False, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + # Write back dV + adj_dv = bid * stride_dvb + hkid * stride_dvh + k_start * stride_dvn + offs_dv = offs_n[:, None] * stride_dvn + offs_d[None, :] * stride_dvd + tl.store(DV + adj_dv + offs_dv, dv, mask=mask_kv) + # write back dk + adj_dk = bid * stride_dkb + hkid * stride_dkh + k_start * stride_dkn + offs_dk = offs_n[:, None] * stride_dkn + offs_d[None, :] * stride_dkd + dk *= sm_scale + tl.store(DK + adj_dk + offs_dk, dk, mask=mask_kv) + + # THIS PART DOES DQ + start_m = pid * BLOCK_M2 + if start_m < seqlen_q: + offs_m = start_m + tl.arange(0, BLOCK_M2) + # Mask for loading K and V + mask_q = offs_m[:, None] < seqlen_q + if PADDED_HEAD: + mask_d = offs_d < ACTUAL_HEAD_DIM + mask_q &= mask_d[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + offs_do = offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod + K += bid * stride_kb + hkid * stride_kh + k_start * stride_kn + V += bid * stride_vb + hkid * stride_vh + k_start * stride_vn + # If MQA / GQA, set the K and V head offsets appropriately. 
+ for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + adj_delta = ( + bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + ) + Delta_ptr = Delta + adj_delta + + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = ( + philox_offset_base + bid * stride_dropoutb + hqid * stride_dropouth + ) + dropout_offset = ( + Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + ) + + q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) + do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) + m = tl.load(M + adj_delta + offs_m * stride_deltam, mask=offs_m < seqlen_q) + m = m[:, None] + + if IS_FP8: + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + # start can only be 0 at minimum + start_n = 0 + end_n = seqlen_k + num_steps = tl.cdiv(seqlen_k, BLOCK_N2) + + dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) + dq = _bwd_dq_inner( + dq, + q, + K, + V, + do, + m, + Delta_ptr, + sm_scale, + stride_qm, + stride_qd, + stride_kn, + stride_kd, + stride_vn, + stride_vd, + stride_dropoutm, + stride_dropoutn, + stride_deltam, + seqlen_q, + seqlen_k, + BLOCK_M2, + BLOCK_N2, + HEAD_DIM, + ACTUAL_HEAD_DIM, + dropout_p, + philox_seed, + 
batch_philox_offset, + dropout_offset, + alibi_slope, + start_m, + start_n, + end_n, + num_steps, + descale_q, + descale_k, + descale_v, + descale_do, + MASK=False, + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + # Write back dQ. + adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm + offs_dq = offs_m[:, None] * stride_dqm + offs_d[None, :] * stride_dqd + dq *= sm_scale + tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) + + +def is_contiguous(x, name): + if x.is_contiguous(): + return x + else: + print(f"{name} is not contiguous") + return x.contiguous() + + +@functools.lru_cache(maxsize=1024) +def _get_config(): + if not hasattr(_get_config, "_config_dict"): + dev = arch_info.get_device() + _get_config._config_dict = {} + fpath = f"{AITER_TRITON_CONFIGS_PATH}/{dev}-MHA-DEFAULT.json" + with open(fpath, "r") as file: + config = json.load(file) + _get_config._config_dict = config + + return _get_config._config_dict["bkwd_onekernel"] + + +def flash_attn_onekernel_backward( + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + softmax_lse: torch.Tensor, + dq: torch.Tensor, + dk: torch.Tensor, + dv: torch.Tensor, + dbias: torch.Tensor, + sm_scale: float, + alibi_slopes: Optional[torch.Tensor], + causal: bool, + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + max_seqlen_q: int, + max_seqlen_k: int, + dropout_p: float, + philox_seed: Optional[int] = 0, + philox_offset: Optional[int] = 0, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_do: Optional[torch.Tensor] = None, + USE_INT64_STRIDES: Optional[bool] = False, + config: Optional[Dict[str, any]] = None, +): + if dbias is not None: + raise ValueError("Bias is not supported yet in the Triton 
Backend") + + use_alibi, (stride_az, stride_ah) = ( + (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0)) + ) + + IS_FP8 = _is_fp8(q) + if IS_FP8: + FP8_MAX = torch.finfo(q.dtype).max + descale_strides = ( + descale_q.stride(0), + descale_k.stride(0), + descale_v.stride(0), + descale_do.stride(0), + ) + else: + FP8_MAX = None + stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = ( + stride_descale_do_z + ) = None + descale_strides = ( + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + stride_descale_do_z, + ) + + IS_VARLEN = True if cu_seqlens_q is not None else False + + # get strides and shape + if IS_VARLEN: + # Layout for q,k,v is thd ie [total tokens, num_head, head_dim] + batch, seqlen_q, num_q_heads, head_sz = ( + len(cu_seqlens_q) - 1, + max_seqlen_q, + q.shape[1], + q.shape[2], + ) + _, num_k_heads = max_seqlen_k, k.shape[1] + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) + v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) + o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + dq_strides = (0, dq.stride(1), dq.stride(0), dq.stride(2)) + dk_strides = (0, dk.stride(1), dk.stride(0), dk.stride(2)) + dv_strides = (0, dv.stride(1), dv.stride(0), dv.stride(2)) + do_strides = (0, do.stride(1), do.stride(0), do.stride(2)) + else: + # Layout for q,k,v is bshd ie [batch, seq_len, num_head, head_dim] + batch, seqlen_q, num_q_heads, head_sz = q.shape + _, num_k_heads = k.shape[1], k.shape[2] + q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) + k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) + v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) + o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + dq_strides = (dq.stride(0), dq.stride(2), dq.stride(1), dq.stride(3)) + dk_strides = (dk.stride(0), dk.stride(2), dk.stride(1), 
dk.stride(3)) + dv_strides = (dv.stride(0), dv.stride(2), dv.stride(1), dv.stride(3)) + do_strides = (do.stride(0), do.stride(2), do.stride(1), do.stride(3)) + + # BLOCK_D_MODEL, BLOCK_D_MODEL_POW2 + # padding for head_dim. Power of 2 or 16 + BLOCK_D_MODEL_POW2 = triton.next_power_of_2(head_sz) + BLOCK_D_MODEL_POW2 = max(BLOCK_D_MODEL_POW2, 16) + + # Configs + if config is None: + config = _get_config() + + # init delta + delta = torch.zeros_like(softmax_lse) + if IS_VARLEN: + # [total_tokens, num_q_heads, seqlen_q] + delta_strides = (0, delta.stride(1), delta.stride(0)) + else: + # [batch, num_q_heads, seqlen_q] + delta_strides = delta.stride() + + # preprocess + # compute D(delta) = rowsum(dO*O). Note, multiplication is element-wise. + pre_grid = ( + triton.cdiv(max_seqlen_q, config["preprocess_kernel"]["PRE_BLOCK"]), + batch, + num_q_heads, + ) + _bwd_preprocess[pre_grid]( + o, + do, + delta, + *o_strides, + *delta_strides, + descale_strides[3], + cu_seqlens_q, + max_seqlen_q, + descale_do, + BLOCK_M=config["preprocess_kernel"]["PRE_BLOCK"], + BLOCK_D_MODEL=head_sz, + BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + ) + + # dropout_mask + use_dropout = dropout_p > 0.0 + if use_dropout: + dropout_mask = torch.zeros( + (batch, num_q_heads, max_seqlen_q, max_seqlen_k), + device=q.device, + dtype=torch.float32, + ) + dropout_strides = dropout_mask.stride() + else: + dropout_mask = None + dropout_strides = (0, 0, 0, 0) + + seqlen = max(max_seqlen_q, max_seqlen_k) + + config_onekernel = config["onekernel"] + grid = ( + num_k_heads, + triton.cdiv(seqlen, config_onekernel["BLOCK_N1"]), + batch, + ) + + if causal: + bwd_kernel_causal[grid]( + q, + k, + v, + sm_scale, + do, + dq, + dk, + dv, + softmax_lse, + delta, + *q_strides, + *k_strides, + *v_strides, + *dq_strides, + *dk_strides, + *dv_strides, + *delta_strides, + *do_strides, + *dropout_strides, + *descale_strides, + stride_az, + stride_ah, + num_q_heads, + num_k_heads, + 
cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + alibi_slopes, + descale_q, + descale_k, + descale_v, + descale_do, + HEAD_DIM=head_sz, + ACTUAL_HEAD_DIM=BLOCK_D_MODEL_POW2, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_ALIBI=use_alibi, + USE_EXP2=True, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + FP8_OUTPUT=False, + DEBUG_TRITON=False, + DEBUG_TRITON_DETAIL=False, + USE_INT64_STRIDES=USE_INT64_STRIDES, + **config_onekernel, + ) + else: + bwd_kernel_noncausal[grid]( + q, + k, + v, + sm_scale, + do, + dq, + dk, + dv, + softmax_lse, + delta, + *q_strides, + *k_strides, + *v_strides, + *dq_strides, + *dk_strides, + *dv_strides, + *delta_strides, + *do_strides, + *dropout_strides, + *descale_strides, + stride_az, + stride_ah, + num_q_heads, + num_k_heads, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + alibi_slopes, + descale_q, + descale_k, + descale_v, + descale_do, + HEAD_DIM=head_sz, + ACTUAL_HEAD_DIM=BLOCK_D_MODEL_POW2, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_ALIBI=use_alibi, + USE_EXP2=True, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + FP8_OUTPUT=False, + DEBUG_TRITON=False, + DEBUG_TRITON_DETAIL=False, + USE_INT64_STRIDES=USE_INT64_STRIDES, + **config_onekernel, + ) + + return delta From aeb5fc2579c74706f6944b96ebd9f85d68914a15 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Thu, 24 Jul 2025 15:05:43 +0000 Subject: [PATCH 2/9] remove aiter stuff --- flash_attn/flash_attn_triton_amd/mha.py | 2020 ----------------- .../flash_attn_triton_amd/mha_fused_bwd.py | 1272 ----------- .../mha_onekernel_bwd.py | 1806 --------------- 3 files changed, 5098 deletions(-) delete mode 100644 flash_attn/flash_attn_triton_amd/mha.py delete mode 100644 flash_attn/flash_attn_triton_amd/mha_fused_bwd.py delete mode 100644 flash_attn/flash_attn_triton_amd/mha_onekernel_bwd.py diff --git 
a/flash_attn/flash_attn_triton_amd/mha.py b/flash_attn/flash_attn_triton_amd/mha.py deleted file mode 100644 index b425db59351..00000000000 --- a/flash_attn/flash_attn_triton_amd/mha.py +++ /dev/null @@ -1,2020 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. - -from typing import Optional, Tuple -import functools -import json -import torch -import triton -import triton.language as tl - -import aiter.ops.triton.utils.arch_info as arch_info -from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH -from aiter.ops.triton.utils.pid_preprocessing import remap_xcd -from aiter.ops.triton.mha_onekernel_bwd import flash_attn_onekernel_backward -from aiter.ops.triton.mha_fused_bwd import flash_attn_fused_backward -from aiter.ops.triton.utils.mha_kernel_utils import ( - _compute_fp8_scaling_factors, - _is_fp8, -) - -global _USE_FUSED_BWD_KERNEL -_USE_FUSED_BWD_KERNEL = False - - -def mha_set_use_fused_bwd_kernel(value: bool): - global _USE_FUSED_BWD_KERNEL - _USE_FUSED_BWD_KERNEL = value - - -_USE_INT64_STRIDES = True - - -def mha_set_use_int64_strides(value: bool): - """Use 64-bit integer strides to prevent integer overflows with very large tensors.""" - global _USE_INT64_STRIDES - _USE_INT64_STRIDES = value - - -def _cast_to_fp8( - x: torch.Tensor, - fp8_dtype, - layout, - clamp_val=1e-9, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Convert a tensor to FP8 format, returning an FP8 tensor and a descale factor. 
- Args: - - x (torch.Tensor): shape [batch, seq_len, heads, dim] - Returns: - - x_fp8 (torch.Tensor): FP8 tensor with the same shape as x - - descale_factor (torch.Tensor): tensor of shape [batch, 1, heads, 1] - """ - if len(x.shape) != 4: - raise ValueError( - f"'bshd' tensor should have shape [batch, seqlen, heads, dim], got {x.shape}" - ) - reduce_dims = (1, 3) # seq_len and dim dimensions - - # Compute the absolute max along reduce_dims, clamped to avoid 0-scale - x_abs_max = x.abs().amax(dim=reduce_dims) - x_abs_max = torch.maximum(x_abs_max, x.new_tensor(clamp_val)) - - # Unsqueeze back to a shape suitable for broadcast - unsqueeze_dims = sorted(reduce_dims) - for d in unsqueeze_dims: - x_abs_max = x_abs_max.unsqueeze(d) - - # compute scale and descale - fp8_max = torch.finfo(fp8_dtype).max - scale = fp8_max / x_abs_max - descale_factor = x_abs_max / fp8_max - - # cast to FP8, optionally setting requires_grad - x_fp8 = (x * scale).to(fp8_dtype) - - return x_fp8, descale_factor - - -def _cast_varlen_to_fp8( - x: torch.Tensor, - fp8_dtype: torch.dtype, - cu_seqlens, - clamp_val: float = 1e-9, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Convert a tensor of sequences with variable seq_len into fp8. 
- Args: - - x (torch.Tensor): shape [total_seq_len, heads, dim] - Returns: - - x_fp8 (torch.Tensor): shape [total_seq_len, heads, dim] - - descale_factors (torch.Tensor): shape [batch, heads] - """ - # validate tensor shape - if len(x.shape) != 3: - raise ValueError( - f"tensor should have shape [total_seqlen, heads, dim], got {x.shape}" - ) - num_heads = x.shape[1] - - # Get batch size from cu_seqlens - batch = cu_seqlens.shape[0] - 1 - fp8_max = torch.finfo(fp8_dtype).max - - # Compute scale and descale factors per sequence - x_fp8 = torch.zeros_like(x, dtype=fp8_dtype) - descale_factors = torch.zeros( - (batch, num_heads), device=x.device, dtype=torch.float32 - ) - - for i in range(batch): - start = cu_seqlens[i] - end = cu_seqlens[i + 1] - x_slice = x[start:end] # Slice for current sequence - - # Standard tensor (0: seq_len, 2: head_dim) - x_abs_max = x_slice.abs().amax(dim=(0, 2)) # [heads] - - # apply minimum clamping - x_abs_max = torch.maximum(x_abs_max, x.new_tensor(clamp_val)) - - # compute scale and descale factors - scale_i = fp8_max / x_abs_max - descale_i = x_abs_max / fp8_max - - # store descale factors - descale_factors[i, :] = descale_i - - scale_reshape = scale_i.reshape(1, num_heads, 1) - - # scale and cast to FP8 - x_fp8[start:end] = (x_slice * scale_reshape).to(fp8_dtype) - - return x_fp8, descale_factors - - -@triton.jit -def _cdiv_fn(x, y): - return (x + y - 1) // y - - -@triton.jit -def _load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second): - if offset_first is not None and offset_second is not None: - mask = (offset_first[:, None] < boundary_first) & ( - offset_second[None, :] < boundary_second - ) - tensor = tl.load(ptrs, mask=mask, other=0.0) - elif offset_first is not None: - mask = offset_first[:, None] < boundary_first - tensor = tl.load(ptrs, mask=mask, other=0.0) - elif offset_second is not None: - mask = offset_second[None, :] < boundary_second - tensor = tl.load(ptrs, mask=mask, other=0.0) - else: - tensor = 
tl.load(ptrs) - return tensor - - -@triton.jit -def _compute_alibi_block( - alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False -): - # when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix - # for casual mask we want something like this where (1 is kept and 0 is masked) - # seqlen_q = 2 and seqlen_k = 5 - # 1 1 1 1 0 - # 1 1 1 1 1 - # seqlen_q = 5 and seqlen_k = 2 - # 0 0 - # 0 0 - # 0 0 - # 1 0 - # 1 1 - # for alibi the diagonal is 0 indicating no penalty for attending to that spot and increasing penalty for attending further from the diagonal - # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False - # 1. offs_m[:,None] = [[0], - # [1], - # 2. offs_m[:,None] + seqlen_k = [[5], - # [6], - # 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3], - # [4], - # 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], - # [4], [ 4, 3, 2, 1, 0]] - # 5. 
-1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1], - # [ -4, -3, -2, -1, 0]], - relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] - alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) - if transpose: - return alibi_block.T - else: - return alibi_block - - -@triton.jit -def _attn_fwd_inner( - acc, - l_i, - m_i, - q, - k_ptrs, - v_ptrs, - stride_kn, - stride_vk, - stride_sn, - start_m, - seqlen_k, - seqlen_q, - dropout_p, - sd_mask_ptrs, - dropout_mask_ptrs, - philox_seed, - philox_ptrs, - block_min, - block_max, - offs_n_causal, - masked_blocks, - n_extra_tokens, - alibi_slope, - descale_q, - descale_k, - descale_v, - OFFS_M: tl.constexpr, - OFFS_N: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_DMODEL_POW2: tl.constexpr, - SM_SCALE: tl.constexpr, - IS_CAUSAL: tl.constexpr, - MASK_STEPS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - RETURN_SCORES: tl.constexpr, - PADDED_HEAD: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, -): - RCP_LN2: tl.constexpr = 1.4426950408889634 - - # loop over k, v, and update accumulator - - for start_n in range(block_min, block_max, BLOCK_N): - # For padded blocks, we will overrun the tensor size if - # we load all BLOCK_N. For others, the blocks are all within range. - if MASK_STEPS: - k_offs_n = start_n + tl.arange(0, BLOCK_N) - else: - k_offs_n = None - k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL_POW2) - k = _load_fn(k_ptrs, k_offs_k, k_offs_n, BLOCK_DMODEL, seqlen_k) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - # We start from end of seqlen_k so only the first iteration would need - # to be checked for padding if it is not a multiple of block_n - # TODO: This can be optimized to only be true for the padded block. 
- mask = tl.full([BLOCK_M, BLOCK_N], True, dtype=tl.int1) - if MASK_STEPS: - # If this is the last block / iteration, we want to - # mask if the sequence length is not a multiple of block size - # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. - # last step might get wasted but that is okay. check if this masking works For - # that case. - - # remove the old if condition - # if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): - # Though this will unconditionally compute mask_partial at runtime, - # the causal for loop does not have the if-else block any more, which - # helps instruction scheduling and register pressure. - bound_cond = (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0) - boundary_m = tl.full([BLOCK_M], seqlen_k, dtype=tl.int32) - size_n = start_n + OFFS_N[None, :] - mask_partial = size_n < boundary_m[:, None] - mask = tl.where(bound_cond, mask_partial, mask) - - # compute masks - q_mask = OFFS_M[:, None] < seqlen_q - k_mask = (start_n + tl.arange(0, BLOCK_N))[None, :] < seqlen_k - p_mask = q_mask & k_mask - - # -- compute qk ---- - if IS_FP8: - qk += tl.dot(q, k) * descale_q * descale_k - else: - qk += tl.dot(q, k) - - if IS_CAUSAL: - causal_boundary = start_n + offs_n_causal - causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] - mask = mask and causal_mask - - qk = tl.where(mask, qk, float("-inf")) - - if alibi_slope is not None: - # Compute the global position of each token within the sequence - global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - global_n_positions = start_n + tl.arange(0, BLOCK_N) - alibi_block = _compute_alibi_block( - alibi_slope, seqlen_q, seqlen_k, global_m_positions, global_n_positions - ) - qk += alibi_block / SM_SCALE - # get max scores so far - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - m_ij_scaled = m_ij * SM_SCALE * RCP_LN2 - - # scale and subtract max - q_shifted = qk * SM_SCALE * RCP_LN2 - m_ij_scaled[:, None] - - # Compute scaled QK and softmax 
probabilities - p = tl.math.exp2(q_shifted) - - # CAVEAT: Must update l_ij before applying dropout - l_ij = tl.sum(p, 1) - if ENABLE_DROPOUT: - rng_output = tl.rand( - philox_seed, philox_ptrs - ) # TODO: use tl.randint for better performance - dropout_mask = rng_output > dropout_p - tl.store(dropout_mask_ptrs, dropout_mask, mask=p_mask) - - # return scores with negative values for dropped vals - sd_mask = tl.where(dropout_mask, p, -p) - tl.store(sd_mask_ptrs, sd_mask, mask=p_mask) - - # apply dropout mask in place - p = tl.where(dropout_mask, p, 0.0) - elif RETURN_SCORES: - # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that - tl.store(sd_mask_ptrs, p, mask=p_mask) - - # -- update output accumulator -- - # alpha is an adjustment factor for acc and li as we loop and find new maxes - # store the diff in maxes to adjust acc and li as we discover new maxes - m_diff_scaled = m_i * SM_SCALE * RCP_LN2 - m_ij_scaled - alpha = tl.math.exp2(m_diff_scaled) - acc = acc * alpha[:, None] - v = _load_fn(v_ptrs, k_offs_n, k_offs_k, seqlen_k, BLOCK_DMODEL) - # -- update m_i and l_i - l_i = l_i * alpha + l_ij - # update m_i and l_i - m_i = m_ij - - if IS_FP8: - scale_p, descale_p = _compute_fp8_scaling_factors(p, FP8_MAX) - acc += ( - tl.dot((p * scale_p).to(v.type.element_ty), v) * descale_p * descale_v - ) - else: - acc += tl.dot(p.to(v.type.element_ty), v) - - k_ptrs += BLOCK_N * stride_kn - v_ptrs += BLOCK_N * stride_vk - if RETURN_SCORES: - sd_mask_ptrs += BLOCK_N * stride_sn - - if ENABLE_DROPOUT: - dropout_mask_ptrs += BLOCK_N * stride_sn - philox_ptrs += BLOCK_N * stride_sn - - return acc, l_i, m_i - - -@triton.jit -def _attn_fwd( - q_ptr: torch.Tensor, - k_ptr: torch.Tensor, - v_ptr: torch.Tensor, - descale_q_ptr: torch.Tensor, - descale_k_ptr: torch.Tensor, - descale_v_ptr: torch.Tensor, - out_ptr: torch.Tensor, - alibi_slopes_ptr: torch.Tensor, - s_dmask_ptr: torch.Tensor, - 
dropout_mask_ptr: torch.Tensor, - softmax_lse_ptr: torch.Tensor, - stride_qz_in, - stride_qh_in, - stride_qm_in, - stride_qk_in, - stride_kz_in, - stride_kh_in, - stride_kn_in, - stride_kk_in, - stride_vz_in, - stride_vh_in, - stride_vn_in, - stride_vk_in, - stride_descale_q_z_in, - stride_descale_k_z_in, - stride_descale_v_z_in, - stride_oz_in, - stride_oh_in, - stride_om_in, - stride_on_in, - stride_alibi_z_in, - stride_alibi_h_in, - stride_sd_z_in, - stride_sd_h_in, - stride_sd_m_in, - stride_sd_n_in, - stride_lse_z_in, - stride_lse_h_in, - stride_lse_m_in, - sm_scale, - cu_seqlens_q, - cu_seqlens_k, - dropout_p, - philox_seed, - philox_offset_base_in, - SEQLEN_Q: tl.constexpr, - SEQLEN_K: tl.constexpr, - IS_CAUSAL: tl.constexpr, - NUM_Q_HEADS: tl.constexpr, - NUM_K_HEADS: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_DMODEL_POW2: tl.constexpr, - RETURN_SCORES: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, - VARLEN: tl.constexpr, - BATCH, - NUM_XCD: tl.constexpr, - USE_INT64_STRIDES: tl.constexpr, -): - NUM_BLOCKS = (SEQLEN_Q + BLOCK_M - 1) // BLOCK_M - # calculate offsets - wid = tl.program_id( - 0 - ) # workgroup id ranging: 0,1,2,...., (BATCH * NUM_Q_HEADS * NUM_BLOCKS - 1) - # num blocks along seqlen - - off_q_head = wid % NUM_Q_HEADS - off_q_head = remap_xcd(off_q_head, NUM_Q_HEADS, NUM_XCD) - start_m = (wid // NUM_Q_HEADS) % NUM_BLOCKS - off_z = (wid // (NUM_BLOCKS * NUM_Q_HEADS)) % BATCH - - # offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL_POW2) - - # NOTE: - # Workaround for int64 strides, In the absence of strides being int64, parts of the offset - # computation is done in 32 bit and overflows resulting in segfaults - # If input strides are defined as int64, it disables vectorized loads which drops perf - # If we define new strides as stride_x = 
stride_x_in.to(tl.int64), that does not work - # because strides are tl.constexpr and cannot be upcasted - # If we define new strides as stride_x: tl.int64 = stride_x_in, segfault remains - # The permanent solution is to enable upcasting of tl.constexpr - # In the meantime, the following workaround provides correctness and does not drop perf - if USE_INT64_STRIDES: - stride_qz = tl.cast(stride_qz_in, tl.int64) - stride_qh = tl.cast(stride_qh_in, tl.int64) - stride_qm = tl.cast(stride_qm_in, tl.int64) - stride_qk = tl.cast(stride_qk_in, tl.int64) - stride_kz = tl.cast(stride_kz_in, tl.int64) - stride_kh = tl.cast(stride_kh_in, tl.int64) - stride_kn = tl.cast(stride_kn_in, tl.int64) - stride_kk = tl.cast(stride_kk_in, tl.int64) - stride_vz = tl.cast(stride_vz_in, tl.int64) - stride_vh = tl.cast(stride_vh_in, tl.int64) - stride_vn = tl.cast(stride_vn_in, tl.int64) - stride_vk = tl.cast(stride_vk_in, tl.int64) - if IS_FP8: - stride_descale_q_z = tl.cast(stride_descale_q_z_in, tl.int64) - stride_descale_k_z = tl.cast(stride_descale_k_z_in, tl.int64) - stride_descale_v_z = tl.cast(stride_descale_v_z_in, tl.int64) - stride_oz = tl.cast(stride_oz_in, tl.int64) - stride_oh = tl.cast(stride_oh_in, tl.int64) - stride_om = tl.cast(stride_om_in, tl.int64) - stride_on = tl.cast(stride_on_in, tl.int64) - stride_alibi_z = tl.cast(stride_alibi_z_in, tl.int64) - stride_alibi_h = tl.cast(stride_alibi_h_in, tl.int64) - - # NOTE: philox offset is need in dropout pointer calculations - philox_offset_base = tl.cast(philox_offset_base_in, tl.int64) - stride_sd_z = tl.cast(stride_sd_z_in, tl.int64) - stride_sd_h = tl.cast(stride_sd_h_in, tl.int64) - stride_sd_m = tl.cast(stride_sd_m_in, tl.int64) - stride_sd_n = tl.cast(stride_sd_n_in, tl.int64) - stride_lse_z = tl.cast(stride_lse_z_in, tl.int64) - stride_lse_h = tl.cast(stride_lse_h_in, tl.int64) - stride_lse_m = tl.cast(stride_lse_m_in, tl.int64) - else: - stride_qz = stride_qz_in - stride_qm = stride_qm_in - stride_qk = stride_qk_in - 
stride_qh = stride_qh_in - stride_kz = stride_kz_in - stride_kh = stride_kh_in - stride_kn = stride_kn_in - stride_kk = stride_kk_in - stride_vz = stride_vz_in - stride_vh = stride_vh_in - stride_vn = stride_vn_in - stride_vk = stride_vk_in - stride_descale_q_z = stride_descale_q_z_in - stride_descale_k_z = stride_descale_k_z_in - stride_descale_v_z = stride_descale_v_z_in - stride_oz = stride_oz_in - stride_oh = stride_oh_in - stride_om = stride_om_in - stride_on = stride_on_in - stride_alibi_z = stride_alibi_z_in - stride_alibi_h = stride_alibi_h_in - philox_offset_base = philox_offset_base_in - stride_sd_z = stride_sd_z_in - stride_sd_h = stride_sd_h_in - stride_sd_m = stride_sd_m_in - stride_sd_n = stride_sd_n_in - stride_lse_z = stride_lse_z_in - stride_lse_h = stride_lse_h_in - stride_lse_m = stride_lse_m_in - - if VARLEN: - cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) - cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) - - seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start - # We have a one-size-fits-all grid in id(0). Some seqlens might be too - # small for all start_m so for those we return early. - if start_m * BLOCK_M > seqlen_q: - return - cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) - cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) - seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start - else: - cu_seqlens_q_start = 0 - cu_seqlens_k_start = 0 - seqlen_q = SEQLEN_Q - seqlen_k = SEQLEN_K - - n_blocks = _cdiv_fn(seqlen_k, BLOCK_N) - - # Now we compute whether we need to exit early due to causal masking. - # This is because for seqlen_q > seqlen_k, M rows of the attn scores - # are completely masked, resulting in 0s written to the output, and - # inf written to LSE. We don't need to do any GEMMs in this case. - # This block of code determines what N is, and if this WG is operating - # on those M rows. - if IS_CAUSAL: - # If seqlen_q == seqlen_k, the attn scores are a square matrix. 
- # If seqlen_q != seqlen_k, attn scores are rectangular which means - # the causal mask boundary is bottom right aligned, and ends at either - # the top edge (seqlen_q < seqlen_k) or left edge. - - # This captures the decrease in n_blocks if we have a rectangular attn matrix - n_blocks_seqlen = _cdiv_fn( - (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N - ) - - # This is what adjusts the block_max for the current WG, only - # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks - n_blocks = min(n_blocks, n_blocks_seqlen) - - # If we have no blocks after adjusting for seqlen deltas, this WG is part of - # the blocks that are all 0. We exit early. - if n_blocks <= 0: - offs_out = ( - off_z * stride_oz - + off_q_head * stride_oh - + cu_seqlens_q_start * stride_om - + offs_m[:, None] * stride_om - + offs_d[None, :] * stride_on - ) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_POW2], dtype=out_ptr.type.element_ty) - out_mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < BLOCK_DMODEL) - tl.store(out_ptr + offs_out, acc, mask=out_mask) - - if softmax_lse_ptr is not None: - offs_lse = ( - off_z * stride_lse_z - + off_q_head * stride_lse_h - + cu_seqlens_q_start * stride_lse_m - + offs_m * stride_lse_m - ) - lse_mask = offs_m < SEQLEN_Q - lse = tl.full([BLOCK_M], value=0.0, dtype=tl.float32) - tl.store(softmax_lse_ptr + offs_lse, lse, mask=lse_mask) - # TODO: Should dropout and return encoded softmax be handled here too? 
- - return - - grp_sz: tl.constexpr = NUM_Q_HEADS // NUM_K_HEADS - if grp_sz != 1: # Grouped Query Attention - off_k_head = off_q_head // grp_sz - else: - off_k_head = off_q_head - - # q,k,v offsets - q_offs = ( - off_z * stride_qz - + off_q_head * stride_qh - + cu_seqlens_q_start * stride_qm - + offs_m[:, None] * stride_qm - + offs_d[None, :] * stride_qk - ) - q_ptrs = q_ptr + q_offs - - k_offs = ( - off_z * stride_kz - + off_k_head * stride_kh - + cu_seqlens_k_start * stride_kn - + offs_d[:, None] * stride_kk - + offs_n[None, :] * stride_kn - ) - k_ptrs = k_ptr + k_offs - - v_offs = ( - off_z * stride_vz - + off_k_head * stride_vh - + cu_seqlens_k_start * stride_vn - + offs_n[:, None] * stride_vn - + offs_d[None, :] * stride_vk - ) - v_ptrs = v_ptr + v_offs - - # alibi slopes - if alibi_slopes_ptr is not None: - alibi_offs = off_z * stride_alibi_z + off_q_head * stride_alibi_h - alibi_slope = tl.load(alibi_slopes_ptr + alibi_offs) - else: - alibi_slope = None - - # s_dmask (return_scores) - if s_dmask_ptr is not None: - s_dmask_offs = ( - off_z * stride_sd_z - + off_q_head * stride_sd_h - + offs_m[:, None] * stride_sd_m - + offs_n[None, :] * stride_sd_n - ) - s_dmask_ptrs = s_dmask_ptr + s_dmask_offs - else: - s_dmask_ptrs = None - - # dropout - if dropout_mask_ptr is not None: - dropout_mask_offs = ( - off_z * stride_sd_z - + off_q_head * stride_sd_h - + offs_m[:, None] * stride_sd_m - + offs_n[None, :] * stride_sd_n - ) - dropout_mask_ptrs = dropout_mask_ptr + dropout_mask_offs - philox_ptrs = ( - philox_offset_base - + off_z * stride_sd_z - + off_q_head * stride_sd_h - + offs_m[:, None] * stride_sd_m - + offs_n[None, :] * stride_sd_n - ) - else: - dropout_mask_ptrs = None - philox_ptrs = None - - m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) - l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_POW2], dtype=tl.float32) - if BLOCK_DMODEL == BLOCK_DMODEL_POW2: - q_mask = offs_m[:, None] < seqlen_q - else: - q_mask 
= (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < BLOCK_DMODEL) - q = tl.load(q_ptrs, mask=q_mask, other=0.0) - if IS_FP8: - descale_q = tl.load(descale_q_ptr + off_z * stride_descale_q_z + off_q_head) - descale_k = tl.load(descale_k_ptr + off_z * stride_descale_k_z + off_k_head) - descale_v = tl.load(descale_v_ptr + off_z * stride_descale_v_z + off_k_head) - else: - descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 - - n_extra_tokens = 0 - if seqlen_k < BLOCK_N: - n_extra_tokens = BLOCK_N - seqlen_k - elif seqlen_k % BLOCK_N: - n_extra_tokens = seqlen_k % BLOCK_N - - # if CAUSAL, then determine masked_blocks and full blocks - # Here we compute how many full and masked blocks we have. - padded_block_k = n_extra_tokens != 0 - is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) - if IS_CAUSAL: - # There are always at least BLOCK_M // BLOCK_N masked blocks. - # Additionally there might be one more due to dissimilar seqlens. - masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) - else: - # Padding on Q does not need to be masked in the FA loop. - masked_blocks = padded_block_k - # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block. - # In this case we might exceed n_blocks so pick the min. - masked_blocks = min(masked_blocks, n_blocks) - n_full_blocks = n_blocks - masked_blocks - block_min = 0 - block_max = n_blocks * BLOCK_N - # Compute for full blocks. Here we set causal to false regardless of its actual - # value because there is no masking. Similarly we do not need padding. 
- if n_full_blocks > 0: - block_max = (n_blocks - masked_blocks) * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner( - acc, - l_i, - m_i, - q, - k_ptrs, - v_ptrs, - stride_kn, - stride_vn, - stride_sd_n, - start_m, - seqlen_k, - seqlen_q, - dropout_p, - s_dmask_ptrs, - dropout_mask_ptrs, - philox_seed, - philox_ptrs, - block_min, - block_max, - 0, - 0, - 0, - alibi_slope, - descale_q, - descale_k, - descale_v, - offs_m, - offs_n, - BLOCK_M, - BLOCK_N, - BLOCK_DMODEL, - BLOCK_DMODEL_POW2, - sm_scale, - False, - MASK_STEPS=False, - ENABLE_DROPOUT=ENABLE_DROPOUT, - RETURN_SCORES=RETURN_SCORES, - PADDED_HEAD=BLOCK_DMODEL != BLOCK_DMODEL_POW2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - ) - block_min = block_max - block_max = n_blocks * BLOCK_N - - # Remaining blocks, if any, are full / not masked. - if masked_blocks > 0: - if IS_CAUSAL: - offs_n_causal = offs_n + (seqlen_q - seqlen_k) - else: - offs_n_causal = 0 - k_ptrs += n_full_blocks * BLOCK_N * stride_kn - v_ptrs += n_full_blocks * BLOCK_N * stride_vn - if RETURN_SCORES: - s_dmask_ptrs += n_full_blocks * BLOCK_N * stride_sd_n - if ENABLE_DROPOUT: - dropout_mask_ptrs += n_full_blocks * BLOCK_N * stride_sd_n - acc, l_i, m_i = _attn_fwd_inner( - acc, - l_i, - m_i, - q, - k_ptrs, - v_ptrs, - stride_kn, - stride_vn, - stride_sd_n, - start_m, - seqlen_k, - seqlen_q, - dropout_p, - s_dmask_ptrs, - dropout_mask_ptrs, - philox_seed, - philox_ptrs, - block_min, - block_max, - offs_n_causal, - masked_blocks, - n_extra_tokens, - alibi_slope, - descale_q, - descale_k, - descale_v, - offs_m, - offs_n, - BLOCK_M, - BLOCK_N, - BLOCK_DMODEL, - BLOCK_DMODEL_POW2, - sm_scale, - IS_CAUSAL, - MASK_STEPS=True, - ENABLE_DROPOUT=ENABLE_DROPOUT, - RETURN_SCORES=RETURN_SCORES, - PADDED_HEAD=BLOCK_DMODEL != BLOCK_DMODEL_POW2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - ) - # epilogue - # This helps the compiler do Newton Raphson on l_i vs on acc which is much larger. 
- l_recip = 1 / l_i[:, None] - acc = acc * l_recip - if ENABLE_DROPOUT: - dropout_scale = 1 / (1 - dropout_p) - acc = acc * dropout_scale - # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, - # then we have one block with a row of all NaNs which come from computing - # softmax over a row of all -infs (-inf - inf = NaN). We check for that here - # and store 0s where there are NaNs as these rows should've been zeroed out. - end_m_idx = (start_m + 1) * BLOCK_M - start_m_idx = start_m * BLOCK_M - causal_start_idx = seqlen_q - seqlen_k - if IS_CAUSAL: - if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: - out_mask_boundary = tl.full( - (BLOCK_DMODEL_POW2,), causal_start_idx, dtype=tl.int32 - ) - mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) - out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] - z = 0.0 - acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) - - # write back LSE(Log Sum Exponents), the log of the normalization constant - overflow_size = end_m_idx - seqlen_q - if softmax_lse_ptr is not None: - RCP_LN2: tl.constexpr = 1.4426950408889634 - LN2: tl.constexpr = 0.6931471824645996 - # compute log-sum-exp in base 2 units - # mi_base2 = m_i * RCP_LN2 - mi_base2 = m_i * RCP_LN2 * sm_scale - softmax_lse = mi_base2 + tl.math.log2(l_i) - # convert back to natural units - softmax_lse *= LN2 - - if IS_CAUSAL: - # zero out nans caused by -infs when doing causal - lse_causal_mask = (start_m_idx + tl.arange(0, BLOCK_M)) < causal_start_idx - softmax_lse = tl.where(lse_causal_mask, 0.0, softmax_lse) - - # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. - # This is only true for the last M block. 
For others, overflow_size will be -ve - offs_lse = ( - off_z * stride_lse_z - + off_q_head * stride_lse_h - + cu_seqlens_q_start * stride_lse_m - + offs_m * stride_lse_m - ) - if overflow_size > 0: - boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) - lse_mask = tl.arange(0, BLOCK_M) < boundary - tl.store( - softmax_lse_ptr + offs_lse, softmax_lse, mask=lse_mask - ) # the log of the normalization constant - else: - tl.store( - softmax_lse_ptr + offs_lse, softmax_lse - ) # the log of the normalization constant - - # write back O - offs_out = ( - off_z * stride_oz - + off_q_head * stride_oh - + cu_seqlens_q_start * stride_om - + offs_m[:, None] * stride_om - + offs_d[None, :] * stride_on - ) - out_mask = tl.full([BLOCK_M, BLOCK_DMODEL_POW2], 1, dtype=tl.int1) - if overflow_size > 0: - out_mask = out_mask & (offs_m[:, None] < seqlen_q) - if BLOCK_DMODEL != BLOCK_DMODEL_POW2: - out_mask = out_mask & (offs_d[None, :] < BLOCK_DMODEL) - op = acc.to(out_ptr.dtype.element_ty) - tl.store(out_ptr + offs_out, op, mask=out_mask) - - -@functools.lru_cache(maxsize=1024) -def _get_config( - enable_dropout: bool, - dtype: torch.dtype, -): - if not hasattr(_get_config, "_config_dict"): - dev = arch_info.get_device() - _get_config._config_dict = {} - fpath = f"{AITER_TRITON_CONFIGS_PATH}/{dev}-MHA-DEFAULT.json" - with open(fpath, "r") as file: - config = json.load(file) - _get_config._config_dict["default"] = config - - if enable_dropout or dtype == torch.float32: - return _get_config._config_dict["default"]["fwd"]["dropout_or_fp32"] - else: - return _get_config._config_dict["default"]["fwd"]["default"] - - -def _flash_attn_forward( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - bias: Optional[torch.Tensor], - alibi_slopes: Optional[torch.Tensor], - return_lse: bool, - return_softmax: bool, - max_seqlen_q: int, - max_seqlen_k: int, - 
cu_seqlens_q: Optional[torch.Tensor] = None, - cu_seqlens_k: Optional[torch.Tensor] = None, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - config: Optional[dict[str, any]] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - if bias is not None: - raise ValueError("Bias is not supported yet in the Triton Backend") - if window_size_left != -1 or window_size_right != -1: - raise ValueError("Sliding Window is not supported yet in the Triton Backend") - - # FP8 - IS_FP8 = _is_fp8(q) - FP8_MAX: tl.constexpr = torch.finfo(q.dtype).max - is_varlen = True if cu_seqlens_q is not None else False - - if IS_FP8: - o = torch.zeros_like(q, dtype=torch.float32) - else: - o = torch.zeros_like(q) - if is_varlen: - # Layout for q,k,v is thd ie [total_tokens, num_head, head_dim] - batch, seqlen_q, num_q_heads, head_sz = ( - len(cu_seqlens_q) - 1, - max_seqlen_q, - q.shape[1], - q.shape[2], - ) - seqlen_k, num_k_heads = max_seqlen_k, k.shape[1] - q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) - k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) - v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) - o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) - else: - # Layout for q,k,v is bshd ie [batch, seq_len, num_head, head_dim] - batch, seqlen_q, num_q_heads, head_sz = q.shape - seqlen_k = k.shape[1] - num_k_heads = k.shape[2] - q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) - k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) - v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) - o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) - - # padding for head_dim. 
Power of 2 or 16 - BLOCK_DMODEL_POW2 = triton.next_power_of_2(head_sz) - BLOCK_DMODEL_POW2 = max(BLOCK_DMODEL_POW2, 16) - - # softmax_lse [batch, num_q_heads, seqlen_q] - if is_varlen: - softmax_lse = torch.zeros( - (q.shape[0], num_q_heads), device=q.device, dtype=torch.float32 - ) - stride_lse_z, stride_lse_h, stride_lse_m = ( - 0, - softmax_lse.stride(1), - softmax_lse.stride(0), - ) - else: - softmax_lse = torch.zeros( - (batch, num_q_heads, max_seqlen_q), device=q.device, dtype=torch.float32 - ) - stride_lse_z, stride_lse_h, stride_lse_m = softmax_lse.stride() - - # exp_scores [batch, num_q_heads, seqlen_q, seqlen_k] - enable_dropout = dropout_p > 0.0 - if enable_dropout: - philox_seed = torch.randint(0, 0xFFFFFF, (1,))[ - 0 - ].item() # No specific reason to restrict range to 0xffffff - philox_offset = torch.randint(0, 0xFFFFFF, (1,))[ - 0 - ].item() # Pass in an int, not Tensor - else: - philox_seed = 0 - philox_offset = 0 - if return_softmax or enable_dropout: - s_dmask = torch.zeros( - (batch, num_q_heads, max_seqlen_q, max_seqlen_k), - device=q.device, - dtype=torch.float32, - ) - dropout_mask = torch.zeros( - (batch, num_q_heads, max_seqlen_q, max_seqlen_k), - device=q.device, - dtype=torch.float32, - ) - else: - s_dmask = None - dropout_mask = None - - if config is None: - config = _get_config(enable_dropout, q.dtype) - - """ - # Tuned for MI300x - config = { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 2, - "num_warps": 4, - "num_ctas": 1, - "num_stages": 1, - } - # Dropout significantly increases VGPR usage so use small tiles - if enable_dropout or q.dtype == torch.float32: - config = { - "BLOCK_M": 32, - "BLOCK_N": 32, - "waves_per_eu": 1, - "num_warps": 2, - "num_ctas": 1, - "num_stages": 1, - } - """ - - grid = lambda META: ( # noqa: E731 - batch * num_q_heads * triton.cdiv(seqlen_q, META["BLOCK_M"]), - ) - - _attn_fwd[grid]( - q, - k, - v, - descale_q, - descale_k, - descale_v, - o, - alibi_slopes, - s_dmask, - dropout_mask, - softmax_lse, 
- *q_strides, - *k_strides, - *v_strides, - descale_q.stride(0) if descale_q is not None else 0, - descale_k.stride(0) if descale_k is not None else 0, - descale_v.stride(0) if descale_v is not None else 0, - *o_strides, - alibi_slopes.stride(0) if alibi_slopes is not None else 0, - alibi_slopes.stride(1) if alibi_slopes is not None else 0, - s_dmask.stride(0) if s_dmask is not None else 0, - s_dmask.stride(1) if s_dmask is not None else 0, - s_dmask.stride(2) if s_dmask is not None else 0, - s_dmask.stride(3) if s_dmask is not None else 0, - stride_lse_z if softmax_lse is not None else 0, - stride_lse_h if softmax_lse is not None else 0, - stride_lse_m if softmax_lse is not None else 0, - softmax_scale, - cu_seqlens_q, - cu_seqlens_k, - dropout_p, - philox_seed, - philox_offset, - SEQLEN_Q=max_seqlen_q, - SEQLEN_K=max_seqlen_k, - IS_CAUSAL=causal, - NUM_Q_HEADS=num_q_heads, - NUM_K_HEADS=num_k_heads, - BLOCK_DMODEL=head_sz, - BLOCK_DMODEL_POW2=BLOCK_DMODEL_POW2, - RETURN_SCORES=return_softmax, - ENABLE_DROPOUT=enable_dropout, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - VARLEN=is_varlen, - BATCH=batch, - NUM_XCD=8, - USE_INT64_STRIDES=_USE_INT64_STRIDES, - **config, - ) - - return o, softmax_lse, s_dmask, philox_seed, philox_offset - - -class _FlashAttnFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - bias, - alibi_slopes, - deterministic, - return_lse, - return_softmax, - is_grad_enabled, - config=None, - ): - is_grad = is_grad_enabled and any(x.requires_grad for x in [q, k, v]) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - head_size_og = q.size(3) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, philox_seed, philox_offset = ( - _flash_attn_forward( - q, - k, - v, - 
dropout_p, - softmax_scale, - causal=causal, - window_size_left=int(window_size[0]), - window_size_right=int(window_size[1]), - bias=bias, - alibi_slopes=alibi_slopes, - return_lse=return_lse, - return_softmax=return_softmax and dropout_p > 0, - max_seqlen_q=q.shape[1], - max_seqlen_k=k.shape[1], - config=config, - ) - ) - - if is_grad: - ctx.save_for_backward(q, k, v, out_padded, softmax_lse) - ctx.philox_seed = philox_seed - ctx.philox_offset = philox_offset - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.bias = bias - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - - out = out_padded[..., :head_size_og] - result = [out] - if return_lse: - result.append(softmax_lse) - if return_softmax: - result.append(S_dmask) - - return result[0] if len(result) == 1 else tuple(result) - - @staticmethod - def backward(ctx, do, *args): - q, k, v, out, softmax_lse = ctx.saved_tensors - bias = ctx.bias - dbias = torch.empty_like(bias) if bias is not None else None - dq, dk, dv = torch.zeros_like(q), torch.empty_like(k), torch.empty_like(v) - head_size_v_og = do.size(3) - do_padded = do - if head_size_v_og % 8 != 0: - do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_v_og % 8]) - - print("Using fused backward kernel:", _USE_FUSED_BWD_KERNEL) - - if _USE_FUSED_BWD_KERNEL: - flash_attn_fused_backward( - do_padded, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - dbias, - ctx.softmax_scale, - ctx.alibi_slopes, - ctx.causal, - None, - None, - max_seqlen_q=q.shape[1], - max_seqlen_k=k.shape[1], - dropout_p=ctx.dropout_p, - philox_seed=ctx.philox_seed, - philox_offset=ctx.philox_offset, - USE_INT64_STRIDES=_USE_INT64_STRIDES, - ) - else: - flash_attn_onekernel_backward( - do_padded, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - dbias, - ctx.softmax_scale, - ctx.alibi_slopes, - ctx.causal, - None, - None, - max_seqlen_q=q.shape[1], - max_seqlen_k=k.shape[1], - 
dropout_p=ctx.dropout_p, - philox_seed=ctx.philox_seed, - philox_offset=ctx.philox_offset, - USE_INT64_STRIDES=_USE_INT64_STRIDES, - ) - - dq = dq[..., : q.shape[-1]] # We could have padded the head dimension - dk = dk[..., : k.shape[-1]] - dv = dv[..., : v.shape[-1]] - return ( - dq, - dk, - dv, - None, - None, - None, - None, - dbias, - None, - None, - None, - None, - None, - None, - ) - - -def flash_attn_func( - q, - k, - v, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - bias=None, - alibi_slopes=None, - deterministic=True, - return_lse=False, - return_attn_probs=False, - config: Optional[dict[str, any]] = None, -): - """dropout_p should be set to 0.0 during evaluation - Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between - [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - - Arguments: - q: (batch_size, seqlen, nheads, headdim) - k: (batch_size, seqlen, nheads_k, headdim) - v: (batch_size, seqlen, nheads_k, headdim) - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim_q). 
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - bias: (seqlen_q, seqlen_k) - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of - (-alibi_slope * |i + seqlen_k - seqlen_q - j|) - is added to the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (batch_size, seqlen, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). 
- """ - - return _FlashAttnFunc.apply( - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - bias, - alibi_slopes, - deterministic, - return_lse, - return_attn_probs, - torch.is_grad_enabled(), - config, - ) - - -class _FlashAttnFP8Func(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_lse, - return_softmax, - is_grad_enabled, - config=None, - ): - is_grad = is_grad_enabled and any(x.requires_grad for x in [q, k, v]) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - head_size_og = q.size(3) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - - # cast input to fp8 - fp8_dtype = arch_info.get_fp8_e4m3_dtype() - q_fp8, descale_q = _cast_to_fp8(q, fp8_dtype, "bshd") - k_fp8, descale_k = _cast_to_fp8(k, fp8_dtype, "bshd") - v_fp8, descale_v = _cast_to_fp8(v, fp8_dtype, "bshd") - - out_padded, softmax_lse, S_dmask, philox_seed, philox_offset = ( - _flash_attn_forward( - q_fp8, - k_fp8, - v_fp8, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=int(window_size[0]), - window_size_right=int(window_size[1]), - bias=None, - alibi_slopes=alibi_slopes, - return_lse=return_lse, - return_softmax=return_softmax and dropout_p > 0, - max_seqlen_q=q.shape[1], - max_seqlen_k=k.shape[1], - cu_seqlens_q=None, - cu_seqlens_k=None, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - config=config, - ) - ) - - if is_grad: - ctx.save_for_backward( - q_fp8, - k_fp8, - v_fp8, - out_padded, - softmax_lse, - descale_q, - descale_k, - descale_v, - ) - ctx.philox_seed = philox_seed - ctx.philox_offset = philox_offset - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - 
ctx.alibi_slopes = alibi_slopes - - out = out_padded[..., :head_size_og] - result = [out] - if return_lse: - result.append(softmax_lse) - if return_softmax: - result.append(S_dmask) - - return result[0] if len(result) == 1 else tuple(result) - - @staticmethod - def backward(ctx, do, *args): - q_fp8, k_fp8, v_fp8, out, softmax_lse, descale_q, descale_k, descale_v = ( - ctx.saved_tensors - ) - dq, dk, dv = ( - torch.zeros_like(q_fp8, dtype=torch.float32), - torch.zeros_like(k_fp8, dtype=torch.float32), - torch.zeros_like(v_fp8, dtype=torch.float32), - ) - head_size_v_og = do.size(3) - do_padded = do - if head_size_v_og % 8 != 0: - do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_v_og % 8]) - - fp8_dtype = arch_info.get_fp8_e4m3_dtype() - do_padded_fp8, descale_do = _cast_to_fp8(do_padded, fp8_dtype, "bshd") - if _USE_FUSED_BWD_KERNEL: - flash_attn_fused_backward( - do_padded_fp8, - q_fp8, - k_fp8, - v_fp8, - out, - softmax_lse, - dq, - dk, - dv, - None, - ctx.softmax_scale, - ctx.alibi_slopes, - ctx.causal, - None, - None, - max_seqlen_q=q_fp8.shape[1], - max_seqlen_k=k_fp8.shape[1], - dropout_p=ctx.dropout_p, - philox_seed=ctx.philox_seed, - philox_offset=ctx.philox_offset, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - descale_do=descale_do, - USE_INT64_STRIDES=_USE_INT64_STRIDES, - ) - else: - flash_attn_onekernel_backward( - do_padded_fp8, - q_fp8, - k_fp8, - v_fp8, - out, - softmax_lse, - dq, - dk, - dv, - None, - ctx.softmax_scale, - ctx.alibi_slopes, - ctx.causal, - None, - None, - max_seqlen_q=q_fp8.shape[1], - max_seqlen_k=k_fp8.shape[1], - dropout_p=ctx.dropout_p, - philox_seed=ctx.philox_seed, - philox_offset=ctx.philox_offset, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - descale_do=descale_do, - USE_INT64_STRIDES=_USE_INT64_STRIDES, - ) - - # dq = dq[..., : q_fp8.shape[-1]] # We could have padded the head dimension - # dk = dk[..., : k_fp8.shape[-1]] - # dv = dv[..., : v_fp8.shape[-1]] - return ( 
- dq, - dk, - dv, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ) - - -def flash_attn_fp8_func( - q, - k, - v, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_lse=False, - return_attn_probs=False, - config: Optional[dict[str, any]] = None, -): - return _FlashAttnFP8Func.apply( - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_lse, - return_attn_probs, - torch.is_grad_enabled(), - config, - ) - - -class _FlashAttnVarlenFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - bias, - alibi_slopes, - deterministic, - return_lse, - return_softmax, - block_table, - out, - is_grad_enabled, - config=None, - ): - is_grad = is_grad_enabled and any(x.requires_grad for x in [q, k, v]) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - head_size_og = q.size(2) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, philox_seed, philox_offset = ( - _flash_attn_forward( - q, - k, - v, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=int(window_size[0]), - window_size_right=int(window_size[1]), - bias=bias, - alibi_slopes=alibi_slopes, - return_lse=return_lse, - return_softmax=return_softmax and dropout_p > 0.0, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - config=config, - ) - ) - if is_grad: - ctx.save_for_backward( - q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k - ) - ctx.max_seqlen_q = max_seqlen_q - 
ctx.max_seqlen_k = max_seqlen_k - ctx.philox_seed = philox_seed - ctx.philox_offset = philox_offset - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.bias = bias - ctx.alibi_slopes = alibi_slopes - out = out_padded[..., :head_size_og] - - result = [out] - if return_lse: - result.append(softmax_lse) - if return_softmax: - result.append(S_dmask) - - return result[0] if len(result) == 1 else tuple(result) - - @staticmethod - def backward(ctx, do, *args): - q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors - dq, dk, dv = torch.zeros_like(q), torch.empty_like(k), torch.empty_like(v) - bias = ctx.bias - dbias = torch.empty_like(bias) if bias is not None else None - head_size_og = do.size(2) - do_padded = do - if head_size_og % 8 != 0: - do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_og % 8]) - - if _USE_FUSED_BWD_KERNEL: - flash_attn_fused_backward( - do_padded, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - dbias, - ctx.softmax_scale, - ctx.alibi_slopes, - ctx.causal, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q=ctx.max_seqlen_q, - max_seqlen_k=ctx.max_seqlen_k, - dropout_p=ctx.dropout_p, - philox_seed=ctx.philox_seed, - philox_offset=ctx.philox_offset, - USE_INT64_STRIDES=_USE_INT64_STRIDES, - ) - else: - flash_attn_onekernel_backward( - do_padded, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - dbias, - ctx.softmax_scale, - ctx.alibi_slopes, - ctx.causal, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q=ctx.max_seqlen_q, - max_seqlen_k=ctx.max_seqlen_k, - dropout_p=ctx.dropout_p, - philox_seed=ctx.philox_seed, - philox_offset=ctx.philox_offset, - USE_INT64_STRIDES=_USE_INT64_STRIDES, - ) - - dq = dq[..., : q.shape[-1]] # We could have padded the head dimension - dk = dk[..., : k.shape[-1]] - dv = dv[..., : v.shape[-1]] - return ( - dq, - dk, - dv, - None, - None, - None, - None, - None, - None, - None, - None, - dbias, - None, - None, - None, 
- None, - None, - None, - None, - None, - None, - ) - - -def flash_attn_varlen_func( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - bias=None, - alibi_slopes=None, - deterministic=False, - return_lse=False, - return_attn_probs=False, - block_table=None, - out=None, - config: Optional[dict[str, any]] = None, -): - """dropout_p should be set to 0.0 during evaluation - Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between - [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - - Arguments: - q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. - k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. - v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. - cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into q. - cu_seqlens_k: (batch_size + 1,), dtype torch.int32. 
The cumulative sequence lengths - of the sequences in the batch, used to index into kv. - max_seqlen_q: int. Maximum query sequence length in the batch. - max_seqlen_k: int. Maximum key sequence length in the batch. - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - bias: (seqlen_q, seqlen_k) - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of - (-alibi_slope * |i + seqlen_k - seqlen_q - j|) - is added to the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (total, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). 
- """ - return _FlashAttnVarlenFunc.apply( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - bias, - alibi_slopes, - deterministic, - return_lse, - return_attn_probs, - block_table, - out, - torch.is_grad_enabled(), - config, - ) - - -class _FlashAttnVarlenFP8Func(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_lse, - return_softmax, - block_table, - is_grad_enabled, - config=None, - ): - is_grad = is_grad_enabled and any(x.requires_grad for x in [q, k, v]) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - head_size_og = q.size(2) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - - # cast input to fp8 - fp8_dtype = arch_info.get_fp8_e4m3_dtype() - q_fp8, descale_q = _cast_varlen_to_fp8(q, fp8_dtype, cu_seqlens=cu_seqlens_q) - k_fp8, descale_k = _cast_varlen_to_fp8(k, fp8_dtype, cu_seqlens=cu_seqlens_k) - v_fp8, descale_v = _cast_varlen_to_fp8(v, fp8_dtype, cu_seqlens=cu_seqlens_k) - - out_padded, softmax_lse, S_dmask, philox_seed, philox_offset = ( - _flash_attn_forward( - q_fp8, - k_fp8, - v_fp8, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=int(window_size[0]), - window_size_right=int(window_size[1]), - bias=None, - alibi_slopes=alibi_slopes, - return_lse=return_lse, - return_softmax=return_softmax and dropout_p > 0, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - config=config, - ) - ) - if is_grad: - ctx.save_for_backward( - q_fp8, - k_fp8, - v_fp8, - 
out_padded, - softmax_lse, - cu_seqlens_q, - cu_seqlens_k, - descale_q, - descale_k, - descale_v, - ) - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k - ctx.philox_seed = philox_seed - ctx.philox_offset = philox_offset - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - - out = out_padded[..., :head_size_og] - result = [out] - if return_lse: - result.append(softmax_lse) - if return_softmax: - result.append(S_dmask) - - return result[0] if len(result) == 1 else tuple(result) - - @staticmethod - def backward(ctx, do, *args): - ( - q_fp8, - k_fp8, - v_fp8, - out, - softmax_lse, - cu_seqlens_q, - cu_seqlens_k, - descale_q, - descale_k, - descale_v, - ) = ctx.saved_tensors - dq, dk, dv = ( - torch.zeros_like(q_fp8, dtype=torch.float32), - torch.zeros_like(k_fp8, dtype=torch.float32), - torch.zeros_like(v_fp8, dtype=torch.float32), - ) - head_size_v_og = do.size(3) - do_padded = do - if head_size_v_og % 8 != 0: - do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_v_og % 8]) - - fp8_dtype = arch_info.get_fp8_e4m3_dtype() - do_padded_fp8, descale_do = _cast_varlen_to_fp8( - do_padded, fp8_dtype, "thd", cu_seqlens_q - ) - if _USE_FUSED_BWD_KERNEL: - flash_attn_fused_backward( - do_padded_fp8, - q_fp8, - k_fp8, - v_fp8, - out, - softmax_lse, - dq, - dk, - dv, - None, - ctx.softmax_scale, - ctx.alibi_slopes, - ctx.causal, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q=ctx.max_seqlen_q, - max_seqlen_k=ctx.max_seqlen_k, - dropout_p=ctx.dropout_p, - philox_seed=ctx.philox_seed, - philox_offset=ctx.philox_offset, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - descale_do=descale_do, - USE_INT64_STRIDES=_USE_INT64_STRIDES, - ) - else: - flash_attn_onekernel_backward( - do_padded_fp8, - q_fp8, - k_fp8, - v_fp8, - out, - softmax_lse, - dq, - dk, - dv, - None, - ctx.softmax_scale, - ctx.alibi_slopes, - ctx.causal, - cu_seqlens_q, - 
cu_seqlens_k, - max_seqlen_q=ctx.max_seqlen_q, - max_seqlen_k=ctx.max_seqlen_k, - dropout_p=ctx.dropout_p, - philox_seed=ctx.philox_seed, - philox_offset=ctx.philox_offset, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - descale_do=descale_do, - USE_INT64_STRIDES=_USE_INT64_STRIDES, - ) - dq = dq[..., : q_fp8.shape[-1]] # We could have padded the head dimension - dk = dk[..., : k_fp8.shape[-1]] - dv = dv[..., : v_fp8.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None - - -def flash_attn_varlen_fp8_func( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_lse=False, - return_attn_probs=False, - block_table=None, - config: Optional[dict[str, any]] = None, -): - return _FlashAttnVarlenFP8Func.apply( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_lse, - return_attn_probs, - block_table, - torch.is_grad_enabled(), - config, - ) diff --git a/flash_attn/flash_attn_triton_amd/mha_fused_bwd.py b/flash_attn/flash_attn_triton_amd/mha_fused_bwd.py deleted file mode 100644 index 7073d969602..00000000000 --- a/flash_attn/flash_attn_triton_amd/mha_fused_bwd.py +++ /dev/null @@ -1,1272 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
- -from typing import Optional, Dict -import functools -import json -import torch -import triton -import triton.language as tl - -import aiter.ops.triton.utils.arch_info as arch_info -from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH -from aiter.ops.triton.utils.pid_preprocessing import remap_xcd -from aiter.ops.triton.utils.mha_kernel_utils import ( - _compute_fp8_scaling_factors, - _is_fp8, -) - - -# This function computes delta given output Out and gradient DO -# Here is the I/O shape: -# Out: (batch, nhead_q, max_seqlens_q, headDim) -# DO: (batch, nhead_q, max_seqlens_q, headDim) -# Delta: (batch, nheads_q, max_seqlens_q), same as softmax_lse defined at -@triton.jit -def _bwd_preprocess( - o_ptr, - do_ptr, # noqa: E741 - delta_ptr, - stride_o_b, - stride_o_h, - stride_o_m, - stride_o_k, - stride_delta_b, - stride_delta_h, - stride_delta_m, - stride_descale_do_z, - cu_seqlens_q, - max_seqlen_q, - descale_do_ptr, - BLOCK_M: tl.constexpr, - BLOCK_D_MODEL: tl.constexpr, - BLOCK_D_MODEL_POW2: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_FP8: tl.constexpr, -): - pid_m = tl.program_id(0) # seqlen - bid = tl.program_id(1) # batch - hid = tl.program_id(2) # head - - # Handle varlen - q_start = 0 - seqlen_q = max_seqlen_q - if IS_VARLEN: - q_start = tl.load(cu_seqlens_q + bid) - q_end = tl.load(cu_seqlens_q + bid + 1) - seqlen_q = q_end - q_start - else: - q_start = 0 - seqlen_q = max_seqlen_q - - # Compute offsets - offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) - - # Offset O/DO by batch, head and q_start - offs = ( - bid * stride_o_b - + hid * stride_o_h - + q_start * stride_o_m - + offs_m[:, None] * stride_o_m - + offs_k[None, :] * stride_o_k - ) - - # create masks - mask_m = offs_m < seqlen_q - mask = mask_m[:, None] - PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2 - if PADDED_HEAD: - mask &= offs_k[None, :] < BLOCK_D_MODEL - - # load [BLOCK_M, BLOCK_D_MODEL_POW2] - o = tl.load(o_ptr + 
offs, mask=mask, other=0.0) - do = tl.load(do_ptr + offs, mask=mask, other=0.0) - - # compute and write-back to delta - if IS_FP8: - descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hid) - - # NOTE: do is in the fp8 range and o is not in fp8 - delta = tl.sum(o.to(tl.float32) * (do.to(tl.float32) * descale_do), axis=1) - else: - delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) - - offs_delta = ( - bid * stride_delta_b - + hid * stride_delta_h - + q_start * stride_delta_m - + offs_m * stride_delta_m - ) - tl.store(delta_ptr + offs_delta, delta, mask=mask_m) - - -@triton.jit -def _bwd_dkdvdq_inner( - dk, - dv, - Q, - k, - v, - DO, - DQ, - M, - D, - sm_scale, - stride_q_m, - stride_q_k, - stride_dq_m, - stride_dq_k, - stride_do_m, - stride_do_k, - stride_dropout_m, - stride_dropout_n, - stride_deltam, - dropout_p, - philox_seed, - batch_philox_offset, - dropout_offset, - seqlen_q, - seqlen_k, - start_n, - start_m, - num_steps, - descale_q, - descale_k, - descale_v, - descale_do, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_D_MODEL: tl.constexpr, - BLOCK_D_MODEL_POW2: tl.constexpr, - MASK: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, - workgroup_id, -): - tl.assume(stride_q_m >= 0) - tl.assume(stride_q_k >= 0) - tl.assume(stride_dq_m >= 0) - tl.assume(stride_dq_k >= 0) - tl.assume(stride_do_m >= 0) - tl.assume(stride_do_k >= 0) - tl.assume(stride_deltam >= 0) - - PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2 - delta_qk = seqlen_q - seqlen_k - offs_m = start_m + tl.arange(0, BLOCK_M) - offs_n = start_n + tl.arange(0, BLOCK_N) - offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) - - # mask to make sure not OOB of seqlen_q - mask_n = offs_n < seqlen_k - - qT_ptrs_start = ( - Q + offs_m[None, :] * stride_q_m + offs_k[:, None] * stride_q_k - ) # [BLOCK_D_MODEL_POW2, BLOCK_M] - dq_ptrs_start = ( - DQ + offs_m[:, None] * stride_dq_m + offs_k[None, :] * stride_dq_k - ) # 
[BLOCK_M, BLOCK_D_MODEL_POW2] - - do_ptrs_start = DO + offs_m[:, None] * stride_do_m + offs_k[None, :] * stride_do_k - curr_m = start_m - step_m = BLOCK_M - curr_philox_offset = batch_philox_offset - - # Iterate over blocks(BLOCK_M size) of Q while calculating - # a fixed block(BLOCK_N) of dk and dv. Note, during backward - # pass P has to be recomputed. However, this kernel computes - # dV and dK, so we compute we need P^T and S^T. See backward pass - # equations - # - # From Flash Attention Paper: - # ForwardPass: S = QkT, P=softmax(S), O=PV - # - # BackwardPass equations - # dV = P^TdO - # dP = dOV^T - # dS = dsoftmax(dP) - # dQ = dSK - # dK = QdS^T - - for iter in range(num_steps): - # Permute the iteration order to reduce the probability that concurrent workgroups (that share the same q head idx and batch idx) are at the same iteration - blk_idx = (iter + workgroup_id) % num_steps - - curr_m = start_m + blk_idx * step_m - qT_ptrs = qT_ptrs_start + blk_idx * step_m * stride_q_m - dq_ptrs = dq_ptrs_start + blk_idx * step_m * stride_dq_m - do_ptrs = do_ptrs_start + blk_idx * step_m * stride_do_m - - offs_m = curr_m + tl.arange(0, BLOCK_M) - mask_m = offs_m < seqlen_q - mask_qT = mask_m[None, :] - mask_do = mask_m[:, None] - mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) - - if PADDED_HEAD: - mask_qT &= offs_k[:, None] < BLOCK_D_MODEL - mask_do &= offs_k[None, :] < BLOCK_D_MODEL - - # load qT - qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) - - # dropout - if ENABLE_DROPOUT: - # NOTE: dropout is transposed because it is used to mask pT - philox_offs = ( - curr_philox_offset - + offs_m[None, :] * stride_dropout_m - + offs_n[:, None] * stride_dropout_n - ) - rand_vals = tl.rand(philox_seed, philox_offs) - dropout_mask = rand_vals > dropout_p - dropout_scale = 1.0 / (1 - dropout_p) - - # Load M - m = tl.load(M + offs_m * stride_deltam, mask=mask_m, other=0.0) - - # Compute qkT - if IS_FP8: - qkT = tl.dot(k, qT) * descale_q * descale_k - else: - qkT = 
tl.dot(k, qT) - - # Compute pT(use m and also apply sm_scale) - pT = tl.math.exp(qkT * sm_scale - m[None, :]) - - if MASK: - causal_mask = (offs_m[None, :] - delta_qk) >= (offs_n[:, None]) - mask = causal_mask & mask_nm - pT = tl.where(mask, pT, 0.0) - - # load DO - do = tl.load(do_ptrs, mask=mask_do, other=0.0) - - # dV - if ENABLE_DROPOUT: - pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale - if IS_FP8: - scale_p_dropout, descale_p_dropout = _compute_fp8_scaling_factors( - pT_dropout, FP8_MAX - ) - dv += ( - tl.dot((pT_dropout * scale_p_dropout).to(do.type.element_ty), do) - * descale_p_dropout - * descale_do - ) - else: - dv += tl.dot(pT_dropout.to(do.type.element_ty), do) - else: - if IS_FP8: - scale_pT, descale_pT = _compute_fp8_scaling_factors(pT, FP8_MAX) - dv += ( - tl.dot((pT * scale_pT).to(do.type.element_ty), do) - * descale_pT - * descale_do - ) - else: - dv += tl.dot(pT.to(do.type.element_ty), do) - - # Load delta - Di = tl.load(D + offs_m * stride_deltam, mask=mask_m) - - # Compute dP and dS - if IS_FP8: - dpT = tl.dot(v, tl.trans(do)) * descale_v * descale_do - else: - dpT = tl.dot(v, tl.trans(do)) - - if ENABLE_DROPOUT: - dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale - - delta_i = Di[None, :] - dsT = pT * (dpT - delta_i) - - # compute dk - if IS_FP8: - scale_dsT, descale_dsT = _compute_fp8_scaling_factors(dsT, FP8_MAX) - dk += ( - tl.dot((dsT * scale_dsT).to(qT.type.element_ty), tl.trans(qT)) - * descale_dsT - * descale_q - ) - else: - dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) - - # We can compute the dq_partial here and do a atomic add to the correct memory location - # NOTE: Possible problems with the atomic add: contention, is inside a loop which has achieved bad perf before - # (BLOCK_M, BLOCK_N) x (BLOCK_N, D) - if IS_FP8: - dq_partial = ( - tl.dot((dsT * scale_dsT).to(k.dtype).T, k) * descale_dsT * descale_k - ) - else: - dq_partial = tl.dot(dsT.to(k.dtype).T, k) - tl.atomic_add( - dq_ptrs, - dq_partial * 
sm_scale, - mask=mask_m[:, None] & (offs_k[None, :] < BLOCK_D_MODEL), - sem="relaxed", - ) - - return dk, dv - - -@triton.jit -def _bwd_kernel_dkdvdq_causal( - q_ptr, - k_ptr, - v_ptr, - sm_scale, - do_ptr, - dk_ptr, - dv_ptr, - dq_ptr, - m_ptr, - delta_ptr, - stride_q_b_in, - stride_q_h_in, - stride_q_m_in, - stride_q_k_in, - stride_k_b_in, - stride_k_h_in, - stride_k_n_in, - stride_k_k_in, - stride_v_b_in, - stride_v_h_in, - stride_v_n_in, - stride_v_k_in, - stride_dk_b_in, - stride_dk_h_in, - stride_dk_n_in, - stride_dk_k_in, - stride_dq_b_in, - stride_dq_h_in, - stride_dq_m_in, - stride_dq_k_in, - stride_delta_b_in, - stride_delta_h_in, - stride_delta_m_in, - stride_do_b_in, - stride_do_h_in, - stride_do_m_in, - stride_do_k_in, - stride_dropout_b_in, - stride_dropout_h_in, - stride_dropout_m_in, - stride_dropout_n_in, - stride_descale_q_z_in, - stride_descale_k_z_in, - stride_descale_v_z_in, - stride_descale_do_z_in, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_mask, - dropout_p, - philox_seed, - philox_offset_base_in, - descale_q_ptr, - descale_k_ptr, - descale_v_ptr, - descale_do_ptr, - NUM_Q_HEADS: tl.constexpr, - NUM_K_HEADS: tl.constexpr, - BATCH, - NUM_K_PIDS, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLK_SLICE_FACTOR: tl.constexpr, - BLOCK_D_MODEL: tl.constexpr, - BLOCK_D_MODEL_POW2: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, - NUM_SMS: tl.constexpr, - USE_INT64_STRIDES: tl.constexpr, -): - if USE_INT64_STRIDES: - stride_q_b = tl.cast(stride_q_b_in, tl.int64) - stride_q_h = tl.cast(stride_q_h_in, tl.int64) - stride_q_m = tl.cast(stride_q_m_in, tl.int64) - stride_q_k = tl.cast(stride_q_k_in, tl.int64) - stride_k_b = tl.cast(stride_k_b_in, tl.int64) - stride_k_h = tl.cast(stride_k_h_in, tl.int64) - stride_k_n = tl.cast(stride_k_n_in, tl.int64) - stride_k_k = tl.cast(stride_k_k_in, tl.int64) - stride_v_b = tl.cast(stride_v_b_in, tl.int64) - 
stride_v_h = tl.cast(stride_v_h_in, tl.int64) - stride_v_n = tl.cast(stride_v_n_in, tl.int64) - stride_v_k = tl.cast(stride_v_k_in, tl.int64) - stride_dk_b = tl.cast(stride_dk_b_in, tl.int64) - stride_dk_h = tl.cast(stride_dk_h_in, tl.int64) - stride_dk_n = tl.cast(stride_dk_n_in, tl.int64) - stride_dk_k = tl.cast(stride_dk_k_in, tl.int64) - stride_dq_b = tl.cast(stride_dq_b_in, tl.int64) - stride_dq_h = tl.cast(stride_dq_h_in, tl.int64) - stride_dq_m = tl.cast(stride_dq_m_in, tl.int64) - stride_dq_k = tl.cast(stride_dq_k_in, tl.int64) - stride_delta_b = tl.cast(stride_delta_b_in, tl.int64) - stride_delta_h = tl.cast(stride_delta_h_in, tl.int64) - stride_delta_m = tl.cast(stride_delta_m_in, tl.int64) - stride_do_b = tl.cast(stride_do_b_in, tl.int64) - stride_do_h = tl.cast(stride_do_h_in, tl.int64) - stride_do_m = tl.cast(stride_do_m_in, tl.int64) - stride_do_k = tl.cast(stride_do_k_in, tl.int64) - stride_dropout_b = tl.cast(stride_dropout_b_in, tl.int64) - stride_dropout_h = tl.cast(stride_dropout_h_in, tl.int64) - stride_dropout_m = tl.cast(stride_dropout_m_in, tl.int64) - stride_dropout_n = tl.cast(stride_dropout_n_in, tl.int64) - philox_offset_base = tl.cast(philox_offset_base_in, tl.int64) - if IS_FP8: - stride_descale_q_z = tl.cast(stride_descale_q_z_in, tl.int64) - stride_descale_k_z = tl.cast(stride_descale_k_z_in, tl.int64) - stride_descale_v_z = tl.cast(stride_descale_v_z_in, tl.int64) - stride_descale_do_z = tl.cast(stride_descale_do_z_in, tl.int64) - else: - stride_q_b = stride_q_b_in - stride_q_h = stride_q_h_in - stride_q_m = stride_q_m_in - stride_q_k = stride_q_k_in - stride_k_b = stride_k_b_in - stride_k_h = stride_k_h_in - stride_k_n = stride_k_n_in - stride_k_k = stride_k_k_in - stride_v_b = stride_v_b_in - stride_v_h = stride_v_h_in - stride_v_n = stride_v_n_in - stride_v_k = stride_v_k_in - stride_dk_b = stride_dk_b_in - stride_dk_h = stride_dk_h_in - stride_dk_n = stride_dk_n_in - stride_dk_k = stride_dk_k_in - stride_dq_b = stride_dq_b_in - 
stride_dq_h = stride_dq_h_in - stride_dq_m = stride_dq_m_in - stride_dq_k = stride_dq_k_in - stride_delta_b = stride_delta_b_in - stride_delta_h = stride_delta_h_in - stride_delta_m = stride_delta_m_in - stride_do_b = stride_do_b_in - stride_do_h = stride_do_h_in - stride_do_m = stride_do_m_in - stride_do_k = stride_do_k_in - stride_dropout_b = stride_dropout_b_in - stride_dropout_h = stride_dropout_h_in - stride_dropout_m = stride_dropout_m_in - stride_dropout_n = stride_dropout_n_in - philox_offset_base = philox_offset_base_in - stride_descale_q_z = stride_descale_q_z_in - stride_descale_k_z = stride_descale_k_z_in - stride_descale_v_z = stride_descale_v_z_in - stride_descale_do_z = stride_descale_do_z_in - - GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS - wid = tl.program_id(0) # workgoup id: 0, ..., NUM_Q_PIDS * BATCH * NUM_K_HEADS - 1 - - NUM_XCD: tl.constexpr = 8 - head_q_idx = wid % NUM_Q_HEADS - head_q_idx = remap_xcd(head_q_idx, NUM_Q_HEADS, NUM_XCD) - seq_k_blk_idx = (wid // NUM_Q_HEADS) % NUM_K_PIDS - batch_idx = (wid // (NUM_K_PIDS * NUM_Q_HEADS)) % BATCH - - # In the backward we dont want concurrent workgroups to handle consecutive heads or blocks, so remap them to be far apart. - head_q_idx = (head_q_idx * 29) % NUM_Q_HEADS - # seq_k_blk_idx = (seq_k_blk_idx * 29) % NUM_K_PIDS - - head_k_idx = head_q_idx // GROUP_SIZE - - # Determine q and k start along with seqlen_q and seqlen_k - q_start = 0 - k_start = 0 - seqlen_q = max_seqlen_q - seqlen_k = max_seqlen_k - if IS_VARLEN: - q_start = tl.load(cu_seqlens_q + batch_idx) - q_end = tl.load(cu_seqlens_q + batch_idx + 1) - k_start = tl.load(cu_seqlens_k + batch_idx) - k_end = tl.load(cu_seqlens_k + batch_idx + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start - - dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) - dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) - - # Figure out causal starting block since we have seqlen_q >=< seqlen_k. 
- # Unlike forward pass where we tile on M dim and iterate on N dim, so that - # we can skip some M blocks, in backward pass, we tile on the N dim for kv - # and iterate over the M. In this way, we cannot skip N blocks, but only to - # determine the starting M blocks to skip some initial blocks masked by - # causal. - delta_qk = seqlen_q - seqlen_k - - # q < k: some blocks will have no Masked block, other needs to re-calc - # starting position - # delta_qk is negative so flip it, only multiple of BLOCK_N can skip the - # masked op - num_blocks_skip = -delta_qk // BLOCK_N - delta_aligned = (num_blocks_skip + 1) * BLOCK_N + delta_qk - start_delta_q_lt_k = delta_aligned // BLOCK_M * BLOCK_M - if delta_qk >= 0: - start_delta = delta_qk - else: - start_delta = start_delta_q_lt_k - - start_n = seq_k_blk_idx * BLOCK_N - - offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) - offs_n = start_n + tl.arange(0, BLOCK_N) - # Mask for loading K and V - mask_kv = offs_n[:, None] < seqlen_k - PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2 - if PADDED_HEAD: - mask_k = offs_k < BLOCK_D_MODEL - mask_kv &= mask_k[None, :] - - GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS - adj_k = ( - batch_idx * stride_k_b - + head_k_idx * stride_k_h - + k_start * stride_k_n - + offs_n[:, None] * stride_k_n - + offs_k[None, :] * stride_k_k - ) - adj_v = ( - batch_idx * stride_v_b - + head_k_idx * stride_v_h - + k_start * stride_v_n - + offs_n[:, None] * stride_v_n - + offs_k[None, :] * stride_v_k - ) - # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(k_ptr + adj_k, mask=mask_kv, other=0.0) - v = tl.load(v_ptr + adj_v, mask=mask_kv, other=0.0) - - # If MQA / GQA, set the K and V head offsets appropriately. 
- # for head_q_idx in range(head_k_idx * GROUP_SIZE, head_k_idx * GROUP_SIZE + GROUP_SIZE): - if delta_qk >= 0: - start_m = start_n + start_delta - len_m = BLOCK_N - else: - start_m = max(start_n + delta_qk, 0) - start_m = (start_m // BLOCK_M) * BLOCK_M - # because we might shift the masked blocks up, we are deeper into - # the masked out region, so we would potentially increase the total - # steps with masked operation to get out of it - residue_m = max(start_n + delta_qk - start_m, 0) - len_m = BLOCK_N + residue_m - - # offset input and output tensor by batch and Q/K heads - adj_q = batch_idx * stride_q_b + head_q_idx * stride_q_h + q_start * stride_q_m - adj_dq = batch_idx * stride_dq_b + head_q_idx * stride_dq_h + q_start * stride_dq_m - - q_ptr_adj = q_ptr + adj_q - dq_ptr_adj = dq_ptr + adj_dq - - adj_do = batch_idx * stride_do_b + head_q_idx * stride_do_h + q_start * stride_do_m - do_ptr_adj = do_ptr + adj_do - adj_delta = ( - batch_idx * stride_delta_b - + head_q_idx * stride_delta_h - + q_start * stride_delta_m - ) - m_ptr_adj = m_ptr + adj_delta - delta_ptr_adj = delta_ptr + adj_delta - - # batch_philox_offset is the ACTUALLY dropout offset - # dropout_offset is for debug purpose and will be removed later - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = ( - philox_offset_base - + batch_idx * stride_dropout_b - + head_q_idx * stride_dropout_h - ) - dropout_offset = ( - dropout_mask + batch_idx * stride_dropout_b + head_q_idx * stride_dropout_h - ) - - MASK_BLOCK_M: tl.constexpr = BLOCK_M // BLK_SLICE_FACTOR - # bound the masked operation to q len so it does not have to wast cycles - len_m = min(len_m, seqlen_q) - num_steps = tl.cdiv(len_m, MASK_BLOCK_M) - - # when q < k, we may skip the initial masked op - if seq_k_blk_idx < num_blocks_skip: - num_steps = 0 - - if IS_FP8: - descale_q = tl.load(descale_q_ptr + batch_idx * stride_descale_q_z + head_q_idx) - descale_k = tl.load(descale_k_ptr + batch_idx * 
stride_descale_k_z + head_k_idx) - descale_v = tl.load(descale_v_ptr + batch_idx * stride_descale_v_z + head_k_idx) - descale_do = tl.load( - descale_do_ptr + batch_idx * stride_descale_do_z + head_q_idx - ) - else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - - # if unaligned start_m is negative, the current N-tile has no block on the - # diagonal of causal mask, so everything have no causal mask - dk, dv = _bwd_dkdvdq_inner( - dk, - dv, # output tensors - q_ptr_adj, - k, - v, - do_ptr_adj, - dq_ptr_adj, - m_ptr_adj, - delta_ptr_adj, - sm_scale, # input tensors - stride_q_m, - stride_q_k, # strides for q - stride_dq_m, - stride_dq_k, # strides for q - stride_do_m, - stride_do_k, # strides for o - stride_dropout_m, - stride_dropout_n, # strides for dropout - stride_delta_m, - dropout_p, - philox_seed, - batch_philox_offset, - dropout_offset, # - seqlen_q, - seqlen_k, # max sequence length for q and k - start_n, - start_m, - num_steps, # iteration numbers - descale_q, - descale_k, - descale_v, - descale_do, # fp8 descale factors from user - MASK_BLOCK_M, - BLOCK_N, # block dim - BLOCK_D_MODEL, - BLOCK_D_MODEL_POW2, # head dim - MASK=True, # causal masking - ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - workgroup_id=seq_k_blk_idx, - ) - - start_m += num_steps * MASK_BLOCK_M - num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M) - - dk, dv = _bwd_dkdvdq_inner( - dk, - dv, # output tensors - q_ptr_adj, - k, - v, - do_ptr_adj, - dq_ptr_adj, - m_ptr_adj, - delta_ptr_adj, - sm_scale, # input tensors - stride_q_m, - stride_q_k, # strides for q - stride_dq_m, - stride_dq_k, # strides for dq - stride_do_m, - stride_do_k, # strides for o - stride_dropout_m, - stride_dropout_n, # strides for dropout - stride_delta_m, - dropout_p, - philox_seed, - batch_philox_offset, - dropout_offset, # - seqlen_q, - seqlen_k, # max sequence length for q and k - start_n, - start_m, - num_steps, # iteration numbers - descale_q, - 
descale_k, - descale_v, - descale_do, # fp8 descale factors from user - BLOCK_M, - BLOCK_N, # block dim - BLOCK_D_MODEL, - BLOCK_D_MODEL_POW2, # head dim - MASK=False, # causal masking - ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - workgroup_id=seq_k_blk_idx, - ) - - # Write back dV and dK. - offs_dkdv = ( - batch_idx * stride_dk_b - + head_k_idx * stride_dk_h - + k_start * stride_dk_n - + offs_n[:, None] * stride_dk_n - + offs_k[None, :] * stride_dk_k - ) - tl.atomic_add(dv_ptr + offs_dkdv, dv, mask=mask_kv, sem="relaxed") - dk *= sm_scale - tl.atomic_add(dk_ptr + offs_dkdv, dk, mask=mask_kv, sem="relaxed") - - -@triton.jit -def _bwd_kernel_dkdvdq_noncausal( - Q, - K, - V, - sm_scale, - DO, - DK, - DV, - DQ, - M, - Delta, - stride_qb_in, - stride_qh_in, - stride_qm_in, - stride_qk_in, - stride_kb_in, - stride_kh_in, - stride_kn_in, - stride_kk_in, - stride_vb_in, - stride_vh_in, - stride_vn_in, - stride_vk_in, - stride_dkb_in, - stride_dkh_in, - stride_dkn_in, - stride_dkk_in, - stride_dqb_in, - stride_dqh_in, - stride_dqm_in, - stride_dqk_in, - stride_deltab_in, - stride_deltah_in, - stride_deltam_in, - stride_dob_in, - stride_doh_in, - stride_dom_in, - stride_dok_in, - stride_dropoutb_in, - stride_dropouth_in, - stride_dropoutm_in, - stride_dropoutn_in, - stride_descale_q_z_in, - stride_descale_k_z_in, - stride_descale_v_z_in, - stride_descale_do_z_in, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_mask, - dropout_p, - philox_seed, - philox_offset, - descale_q_ptr, - descale_k_ptr, - descale_v_ptr, - descale_do_ptr, - NUM_Q_HEADS: tl.constexpr, - NUM_K_HEADS: tl.constexpr, - BATCH, - NUM_K_PIDS, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLK_SLICE_FACTOR: tl.constexpr, - BLOCK_D_MODEL: tl.constexpr, - BLOCK_D_MODEL_POW2: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, - NUM_SMS: tl.constexpr, - 
USE_INT64_STRIDES: tl.constexpr, -): - if USE_INT64_STRIDES: - stride_qb = tl.cast(stride_qb_in, tl.int64) - stride_qh = tl.cast(stride_qh_in, tl.int64) - stride_qm = tl.cast(stride_qm_in, tl.int64) - stride_qk = tl.cast(stride_qk_in, tl.int64) - stride_kb = tl.cast(stride_kb_in, tl.int64) - stride_kh = tl.cast(stride_kh_in, tl.int64) - stride_kn = tl.cast(stride_kn_in, tl.int64) - stride_kk = tl.cast(stride_kk_in, tl.int64) - stride_vb = tl.cast(stride_vb_in, tl.int64) - stride_vh = tl.cast(stride_vh_in, tl.int64) - stride_vn = tl.cast(stride_vn_in, tl.int64) - stride_vk = tl.cast(stride_vk_in, tl.int64) - stride_dkb = tl.cast(stride_dkb_in, tl.int64) - stride_dkh = tl.cast(stride_dkh_in, tl.int64) - stride_dkn = tl.cast(stride_dkn_in, tl.int64) - stride_dkk = tl.cast(stride_dkk_in, tl.int64) - stride_dqb = tl.cast(stride_dqb_in, tl.int64) - stride_dqh = tl.cast(stride_dqh_in, tl.int64) - stride_dqm = tl.cast(stride_dqm_in, tl.int64) - stride_dqk = tl.cast(stride_dqk_in, tl.int64) - stride_deltab = tl.cast(stride_deltab_in, tl.int64) - stride_deltah = tl.cast(stride_deltah_in, tl.int64) - stride_deltam = tl.cast(stride_deltam_in, tl.int64) - stride_dob = tl.cast(stride_dob_in, tl.int64) - stride_doh = tl.cast(stride_doh_in, tl.int64) - stride_dom = tl.cast(stride_dom_in, tl.int64) - stride_dok = tl.cast(stride_dok_in, tl.int64) - stride_dropoutb = tl.cast(stride_dropoutb_in, tl.int64) - stride_dropouth = tl.cast(stride_dropouth_in, tl.int64) - stride_dropoutm = tl.cast(stride_dropoutm_in, tl.int64) - stride_dropoutn = tl.cast(stride_dropoutn_in, tl.int64) - if IS_FP8: - stride_descale_q_z = tl.cast(stride_descale_q_z_in, tl.int64) - stride_descale_k_z = tl.cast(stride_descale_k_z_in, tl.int64) - stride_descale_v_z = tl.cast(stride_descale_v_z_in, tl.int64) - stride_descale_do_z = tl.cast(stride_descale_do_z_in, tl.int64) - else: - stride_qb = stride_qb_in - stride_qh = stride_qh_in - stride_qm = stride_qm_in - stride_qk = stride_qk_in - stride_kb = stride_kb_in - 
stride_kh = stride_kh_in - stride_kn = stride_kn_in - stride_kk = stride_kk_in - stride_vb = stride_vb_in - stride_vh = stride_vh_in - stride_vn = stride_vn_in - stride_vk = stride_vk_in - stride_dkb = stride_dkb_in - stride_dkh = stride_dkh_in - stride_dkn = stride_dkn_in - stride_dkk = stride_dkk_in - stride_dqb = stride_dqb_in - stride_dqh = stride_dqh_in - stride_dqm = stride_dqm_in - stride_dqk = stride_dqk_in - stride_deltab = stride_deltab_in - stride_deltah = stride_deltah_in - stride_deltam = stride_deltam_in - stride_dob = stride_dob_in - stride_doh = stride_doh_in - stride_dom = stride_dom_in - stride_dok = stride_dok_in - stride_dropoutb = stride_dropoutb_in - stride_dropouth = stride_dropouth_in - stride_dropoutm = stride_dropoutm_in - stride_dropoutn = stride_dropoutn_in - stride_descale_q_z = stride_descale_q_z_in - stride_descale_k_z = stride_descale_k_z_in - stride_descale_v_z = stride_descale_v_z_in - stride_descale_do_z = stride_descale_do_z_in - - # workgroup id - wid = tl.program_id(0) # 0, ..., NUM_K_PIDS * BATCH * NUM_K_HEADS - 1 - - # Workgroups get launched first along batch dim, then in head_k dim, and then in seq k block dim - # This is in order to avoid contention for the tl.atomic_add (inside _bwd_dkdvdq_inner) that happens between workgroups that share the same batch and head_k. 
- bid = wid % BATCH - hkid = wid // BATCH % NUM_K_HEADS - pid = wid // (BATCH * NUM_K_HEADS) % NUM_K_PIDS - - q_start = 0 - k_start = 0 - seqlen_q = max_seqlen_q - seqlen_k = max_seqlen_k - - if IS_VARLEN: - q_start = tl.load(cu_seqlens_q + bid) - q_end = tl.load(cu_seqlens_q + bid + 1) - k_start = tl.load(cu_seqlens_k + bid) - k_end = tl.load(cu_seqlens_k + bid + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start - - dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) - dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) - - start_n = pid * BLOCK_N - - offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) - offs_n = start_n + tl.arange(0, BLOCK_N) - mask_kv = offs_n[:, None] < seqlen_k - PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2 - if PADDED_HEAD: - mask_kv &= offs_k < BLOCK_D_MODEL - - GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS - adj_k = ( - bid * stride_kb - + hkid * stride_kh - + k_start * stride_kn - + offs_n[:, None] * stride_kn - + offs_k[None, :] * stride_kk - ) - adj_v = ( - bid * stride_vb - + hkid * stride_vh - + k_start * stride_vn - + offs_n[:, None] * stride_vn - + offs_k[None, :] * stride_vk - ) - - k = tl.load(K + adj_k, mask=mask_kv, other=0.0) - v = tl.load(V + adj_v, mask=mask_kv, other=0.0) - - for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): - adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm - adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm - - Q_ptr = Q + adj_q - DQ_ptr = DQ + adj_dq - - adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom - DO_ptr = DO + adj_do - adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam - M_ptr = M + adj_delta - Delta_ptr = Delta + adj_delta - - # dropout - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = ( - philox_offset + bid * stride_dropoutb + hqid * stride_dropouth - ) - dropout_offset = ( - dropout_mask + bid * stride_dropoutb + 
hqid * stride_dropouth - ) - - if IS_FP8: - descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hqid) - descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid) - descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid) - descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hqid) - else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - - start_m = 0 - num_steps = tl.cdiv(seqlen_q, BLOCK_M) - - dk, dv = _bwd_dkdvdq_inner( - dk, - dv, - Q_ptr, - k, - v, - DO_ptr, - DQ_ptr, - M_ptr, - Delta_ptr, - sm_scale, - stride_qm, - stride_qk, - stride_dqm, - stride_dqk, - stride_dom, - stride_dok, - stride_dropoutm, - stride_dropoutn, - stride_deltam, - dropout_p, - philox_seed, - batch_philox_offset, - dropout_offset, - seqlen_q, - seqlen_k, - start_n, - start_m, - num_steps, - descale_q, - descale_k, - descale_v, - descale_do, - BLOCK_M, - BLOCK_N, - BLOCK_D_MODEL, - BLOCK_D_MODEL_POW2, - MASK=False, - ENABLE_DROPOUT=ENABLE_DROPOUT, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - workgroup_id=wid, - ) - - adj_dkdv = ( - bid * stride_dkb - + hkid * stride_dkh - + k_start * stride_dkn - + offs_n[:, None] * stride_dkn - + offs_k[None, :] * stride_dkk - ) - tl.store(DV + adj_dkdv, dv, mask=mask_kv) - dk *= sm_scale - tl.store(DK + adj_dkdv, dk, mask=mask_kv) - - -@functools.lru_cache(maxsize=1024) -def _get_config(): - if not hasattr(_get_config, "_config_dict"): - dev = arch_info.get_device() - _get_config._config_dict = {} - fpath = f"{AITER_TRITON_CONFIGS_PATH}/{dev}-MHA-DEFAULT.json" - with open(fpath, "r") as file: - config = json.load(file) - _get_config._config_dict = config - - return _get_config._config_dict["bkwd_fused"] - - -def flash_attn_fused_backward( - do: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - o: torch.Tensor, - softmax_lse: torch.Tensor, - dq: torch.Tensor, - dk: torch.Tensor, - dv: torch.Tensor, - dbias: torch.Tensor, - sm_scale: float, - alibi_slopes: 
Optional[torch.Tensor], - causal: bool, - cu_seqlens_q: Optional[torch.Tensor], - cu_seqlens_k: Optional[torch.Tensor], - max_seqlen_q: int, - max_seqlen_k: int, - dropout_p: float, - philox_seed: Optional[int] = 0, - philox_offset: Optional[int] = 0, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None, - USE_INT64_STRIDES: Optional[bool] = False, - config: Optional[Dict[str, any]] = None, -): - if dbias is not None: - raise ValueError("Bias is not supported yet in the Triton Backend") - - IS_FP8 = _is_fp8(q) - if IS_FP8: - FP8_MAX = torch.finfo(q.dtype).max - descale_strides = ( - descale_q.stride(0), - descale_k.stride(0), - descale_v.stride(0), - descale_do.stride(0), - ) - else: - FP8_MAX = None - stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = ( - stride_descale_do_z - ) = None - descale_strides = ( - stride_descale_q_z, - stride_descale_k_z, - stride_descale_v_z, - stride_descale_do_z, - ) - - IS_VARLEN = True if cu_seqlens_q is not None else False - - # get strides and shape - if IS_VARLEN: - # Layout for q,k,v is thd ie [total tokens, num_head, head_dim] - batch, seqlen_q, num_q_heads, head_sz = ( - len(cu_seqlens_q) - 1, - max_seqlen_q, - q.shape[1], - q.shape[2], - ) - _, num_k_heads = max_seqlen_k, k.shape[1] - q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) - q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) - k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) - v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) - o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) - dq_strides = (0, dq.stride(1), dq.stride(0), dq.stride(2)) - dk_strides = (0, dk.stride(1), dk.stride(0), dk.stride(2)) - do_strides = (0, do.stride(1), do.stride(0), do.stride(2)) - else: - # Layout for q,k,v is bshd ie [batch, seq_len, num_head, head_dim] - batch, seqlen_q, num_q_heads, head_sz = q.shape - _, num_k_heads = 
k.shape[1], k.shape[2] - q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) - k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) - v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) - o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) - dq_strides = (dq.stride(0), dq.stride(2), dq.stride(1), dq.stride(3)) - dk_strides = (dk.stride(0), dk.stride(2), dk.stride(1), dk.stride(3)) - do_strides = (do.stride(0), do.stride(2), do.stride(1), do.stride(3)) - - # BLOCK_D_MODEL, BLOCK_D_MODEL_POW2 - # padding for head_dim. Power of 2 or 16 - BLOCK_D_MODEL_POW2 = triton.next_power_of_2(head_sz) - BLOCK_D_MODEL_POW2 = max(BLOCK_D_MODEL_POW2, 16) - - # init delta - delta = torch.zeros_like(softmax_lse) - if IS_VARLEN: - # [total_tokens, num_q_heads, seqlen_q] - delta_strides = (0, delta.stride(1), delta.stride(0)) - else: - # [batch, num_q_heads, seqlen_q] - delta_strides = delta.stride() - - # preprocess - # compute D(delta) = rowsum(dO*O). Note, multiplication is element-wise. 
- if config is None: - config = _get_config() - - pre_grid = ( - triton.cdiv(max_seqlen_q, config["preprocess_kernel"]["PRE_BLOCK"]), - batch, - num_q_heads, - ) - - _bwd_preprocess[pre_grid]( - o, - do, - delta, - *o_strides, - *delta_strides, - descale_strides[3], - cu_seqlens_q, - max_seqlen_q, - descale_do, - BLOCK_M=config["preprocess_kernel"]["PRE_BLOCK"], - BLOCK_D_MODEL=head_sz, - BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, - IS_VARLEN=IS_VARLEN, - IS_FP8=IS_FP8, - ) - # dropout_mask - use_dropout = dropout_p > 0.0 - if use_dropout: - dropout_mask = torch.zeros( - (batch, num_q_heads, max_seqlen_q, max_seqlen_k), - device=q.device, - dtype=torch.float32, - ) - dropout_strides = dropout_mask.stride() - else: - dropout_mask = None - dropout_strides = (0, 0, 0, 0) - - # Fuses dk,dv and dq computations into one kernel using atomics - if BLOCK_D_MODEL_POW2 > 160 or q.dtype == torch.float32: - config_dkdvdq = config["dkdvdq_kernel_N64"] - else: - config_dkdvdq = config["dkdvdq_kernel_N128"] - - num_k_pids = (max_seqlen_k + config_dkdvdq["BLOCK_N"] - 1) // config_dkdvdq[ - "BLOCK_N" - ] - NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - if causal: - grid_dkdvdq = (batch * num_q_heads * num_k_pids,) - - _bwd_kernel_dkdvdq_causal[grid_dkdvdq]( - q, - k, - v, - sm_scale, - do, - dk, - dv, - dq, - softmax_lse, - delta, - *q_strides, - *k_strides, - *v_strides, - *dk_strides, - *dq_strides, - *delta_strides, - *do_strides, - *dropout_strides, - *descale_strides, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_mask, - dropout_p, - philox_seed, - philox_offset, - descale_q, - descale_k, - descale_v, - descale_do, - NUM_Q_HEADS=num_q_heads, - NUM_K_HEADS=num_k_heads, - BATCH=batch, - NUM_K_PIDS=num_k_pids, - BLOCK_D_MODEL=head_sz, - BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - NUM_SMS=NUM_SMS, - USE_INT64_STRIDES=USE_INT64_STRIDES, - 
**config_dkdvdq, - ) - else: - # in non causal inner loop over grouped q heads - grid_dkdvdq = (batch * num_k_heads * num_k_pids,) - _bwd_kernel_dkdvdq_noncausal[grid_dkdvdq]( - q, - k, - v, - sm_scale, - do, - dk, - dv, - dq, - softmax_lse, - delta, - *q_strides, - *k_strides, - *v_strides, - *dk_strides, - *dq_strides, - *delta_strides, - *do_strides, - *dropout_strides, - *descale_strides, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_mask, - dropout_p, - philox_seed, - philox_offset, - descale_q, - descale_k, - descale_v, - descale_do, - NUM_Q_HEADS=num_q_heads, - NUM_K_HEADS=num_k_heads, - BATCH=batch, - NUM_K_PIDS=num_k_pids, - BLOCK_D_MODEL=head_sz, - BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - NUM_SMS=NUM_SMS, - USE_INT64_STRIDES=USE_INT64_STRIDES, - **config_dkdvdq, - ) - - return delta diff --git a/flash_attn/flash_attn_triton_amd/mha_onekernel_bwd.py b/flash_attn/flash_attn_triton_amd/mha_onekernel_bwd.py deleted file mode 100644 index aace3dabc45..00000000000 --- a/flash_attn/flash_attn_triton_amd/mha_onekernel_bwd.py +++ /dev/null @@ -1,1806 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
- -from typing import Optional, Dict -import functools -import json -import torch -import triton # type: ignore -import triton.language as tl # type: ignore -import aiter.ops.triton.utils.arch_info as arch_info -from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH -from aiter.ops.triton.utils.mha_kernel_utils import ( - _compute_fp8_scaling_factors, - _is_fp8, -) - - -# NOTE: triton fails to import tl.constexprs so create them here for the file -DROPOUT_USE_PYTORCH = False -DROPOUT_DUMP = False - -tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) -tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) - - -# This function computes delta given output Out and gradient DO -# Here is the I/O shape: -# Out: (batch, nhead_q, max_seqlens_q, headDim) -# DO: (batch, nhead_q, max_seqlens_q, headDim) -# Delta: (batch, nheads_q, max_seqlens_q), same as softmax_lse defined at -@triton.jit -def _bwd_preprocess( - o_ptr, - do_ptr, # noqa: E741 - delta_ptr, - stride_o_b, - stride_o_h, - stride_o_m, - stride_o_k, - stride_delta_b, - stride_delta_h, - stride_delta_m, - stride_descale_do_z, - cu_seqlens_q, - max_seqlen_q, - descale_do_ptr, - BLOCK_M: tl.constexpr, - BLOCK_D_MODEL: tl.constexpr, - BLOCK_D_MODEL_POW2: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_FP8: tl.constexpr, -): - pid_m = tl.program_id(0) # seqlen - bid = tl.program_id(1) # batch - hid = tl.program_id(2) # head - - # Handle varlen - q_start = 0 - seqlen_q = max_seqlen_q - if IS_VARLEN: - q_start = tl.load(cu_seqlens_q + bid) - q_end = tl.load(cu_seqlens_q + bid + 1) - seqlen_q = q_end - q_start - else: - q_start = 0 - seqlen_q = max_seqlen_q - - # Compute offsets - offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) - - # Offset O/DO by batch, head and q_start - offs = ( - bid * stride_o_b - + hid * stride_o_h - + q_start * stride_o_m - + offs_m[:, None] * stride_o_m - + offs_k[None, :] * stride_o_k - ) - - 
# create masks - mask_m = offs_m < seqlen_q - mask = mask_m[:, None] - PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2 - if PADDED_HEAD: - mask &= offs_k[None, :] < BLOCK_D_MODEL - - # load [BLOCK_M, BLOCK_D_MODEL_POW2] - o = tl.load(o_ptr + offs, mask=mask, other=0.0) - do = tl.load(do_ptr + offs, mask=mask, other=0.0) - - # compute and write-back to delta - if IS_FP8: - descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hid) - - # NOTE: do is in the fp8 range and o is not in fp8 - delta = tl.sum(o.to(tl.float32) * (do.to(tl.float32) * descale_do), axis=1) - else: - delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) - - offs_delta = ( - bid * stride_delta_b - + hid * stride_delta_h - + q_start * stride_delta_m - + offs_m * stride_delta_m - ) - tl.store(delta_ptr + offs_delta, delta, mask=mask_m) - - -# The main inner-loop logic for computing dK and dV. -@triton.jit -def _bwd_dkdv_inner( - dk, - dv, # output - Q, - k, - v, - DO, - M, - D, - sm_scale, # input tensor - stride_qm, - stride_qk, - stride_dom, - stride_dok, - stride_dropoutm, - stride_dropoutn, - stride_deltam, - BLOCK_M: tl.constexpr, # 16 - BLOCK_N: tl.constexpr, # 128 - HEAD_DIM: tl.constexpr, # - ACTUAL_HEAD_DIM: tl.constexpr, # - dropout_p, - philox_seed, - batch_philox_offset, - dropout_offset, - alibi_slope, - seqlen_q, - seqlen_k, # max sequence length for q and k - # Filled in by the wrapper. 
- start_n, - start_m, - num_steps, # iteration numbers - descale_q, - descale_k, - descale_v, - descale_do, # fp8 descale factors from user - MASK: tl.constexpr, # causal masking, only apply to tiles on mask diagonal - ENABLE_DROPOUT: tl.constexpr, # activate dropout - USE_ALIBI: tl.constexpr, - USE_EXP2: tl.constexpr, # activate exp2 - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, - DEBUG_TRITON: tl.constexpr, - DEBUG_TRITON_DETAIL: tl.constexpr, -): - # if HEAD_DIM is padded - PADDED_HEAD: tl.constexpr = ACTUAL_HEAD_DIM != HEAD_DIM - delta_qk = seqlen_q - seqlen_k - offs_m = start_m + tl.arange(0, BLOCK_M) # start_m + (0, 15) - offs_n = start_n + tl.arange(0, BLOCK_N) # start_m + (0, 127) - offs_k = tl.arange(0, HEAD_DIM) - # mask to make sure not OOB of seqlen_q - mask_n = offs_n < seqlen_k - # Q and DO are (seqlen_q, head_dim) - # qT_ptrs = (1, BLOCK_M) + (HEAD_DIM, 1), transpose of q - qT_ptrs = Q + offs_m[None, :] * stride_qm + offs_k[:, None] * stride_qk - # do_ptrs = (BLOCK_M, 1) + (1, HEAD_DIM), NOT transposed - do_ptrs = DO + offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok - # BLOCK_N must be a multiple of BLOCK_M, otherwise the code wouldn't work. 
- tl.static_assert(BLOCK_N % BLOCK_M == 0) - curr_m = start_m - step_m = BLOCK_M - curr_philox_offset = batch_philox_offset - curr_dropout_offset = dropout_offset - RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) - - for blk_idx in range(num_steps): - if DEBUG_TRITON: - print(f"iter {blk_idx}: curr_m = {curr_m}") # noqa: E701 - offs_m = curr_m + tl.arange(0, BLOCK_M) - # update the mask because offs_m advanced - mask_m = offs_m < seqlen_q - mask_qT = mask_m[None, :] - mask_do = mask_m[:, None] - mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) - if PADDED_HEAD: - mask_qT &= offs_k[:, None] < ACTUAL_HEAD_DIM - mask_do &= offs_k[None, :] < ACTUAL_HEAD_DIM - qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) - # generate dropout mask - if ENABLE_DROPOUT: - # NOTE: dropout is transposed because it is used to mask pT - philox_offs = ( - curr_philox_offset - + offs_m[None, :] * stride_dropoutm - + offs_n[:, None] * stride_dropoutn - ) - if tl_DROPOUT_USE_PYTORCH: - dropout_offs = ( - offs_m[None, :] * stride_dropoutm - + offs_n[:, None] * stride_dropoutn - ) - dropout_mask = tl.load(curr_dropout_offset + dropout_offs, mask=mask_nm) - else: - rand_vals = tl.rand(philox_seed, philox_offs) - dropout_mask = rand_vals > dropout_p - dropout_scale = 1.0 / (1 - dropout_p) - # Load m before computing qk to reduce pipeline stall. 
- m = tl.load(M + offs_m * stride_deltam, mask=mask_m, other=0.0) - if IS_FP8: - qkT = tl.dot(k, qT) * descale_q * descale_k - else: - qkT = tl.dot(k, qT) - qkT_scaled = qkT * sm_scale - - if USE_ALIBI: - relative_pos_block = offs_n[:, None] + seqlen_q - seqlen_k - offs_m[None, :] - alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) - qkT_scaled += alibi_block - - if DEBUG_TRITON_DETAIL: - if start_n == 256: - print(f"qT: {qT.shape}\n", qT) - print(f"k: {k.shape}\n", k) - print(f"qkT scaled: {qkT.shape}\n", qkT_scaled) - # TODO: remove the scaling of m later when we removed re-scaling in fwd - if USE_EXP2: - pT = tl.math.exp2(qkT_scaled * RCP_LN2 - m[None, :] * RCP_LN2) - else: - pT = tl.math.exp(qkT_scaled - m[None, :]) - - # Autoregressive masking. - if MASK: - # offset offs_m with delta_qk since the causal mask starts at - # bottom right of the (seqlen_q, seqlen_k) matrix - causal_mask = (offs_m[None, :] - delta_qk) >= offs_n[:, None] - mask = causal_mask & mask_nm - if DEBUG_TRITON_DETAIL: - if start_n == 256: - print(f"causal_mask: {causal_mask.shape}\n", causal_mask) - print( - f"qkT after causal: {qkT.shape}\n", - tl.where(causal_mask, qkT * sm_scale, 0.0), - ) - pT = tl.where(mask, pT, 0.0) - do = tl.load(do_ptrs, mask=mask_do, other=0.0) - # Compute dV. 
- if ENABLE_DROPOUT: - pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale - if IS_FP8: - scale_p_dropout, descale_p_dropout = _compute_fp8_scaling_factors( - pT_dropout, FP8_MAX - ) - dv += ( - tl.dot((pT_dropout * scale_p_dropout).to(do.type.element_ty), do) - * descale_p_dropout - * descale_do - ) - else: - dv += tl.dot(pT_dropout.to(do.type.element_ty), do) - else: - if IS_FP8: - scale_pT, descale_pT = _compute_fp8_scaling_factors(pT, FP8_MAX) - dv += ( - tl.dot((pT * scale_pT).to(do.type.element_ty), do) - * descale_pT - * descale_do - ) - else: - dv += tl.dot(pT.to(do.type.element_ty), do) - - if DEBUG_TRITON_DETAIL: - if start_n == 256: - print(f"pT: {pT.shape}\n", pT) - # D (= delta) is pre-divided by ds_scale. - Di = tl.load(D + offs_m * stride_deltam, mask=mask_m) - # Compute dP and dS. - if IS_FP8: - dpT = tl.dot(v, tl.trans(do)) * descale_v * descale_do - else: - dpT = tl.dot(v, tl.trans(do)) - if ENABLE_DROPOUT: - dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale - delta_i = Di[None, :] - dsT = pT * (dpT - delta_i) - if IS_FP8: - scale_dsT, descale_dsT = _compute_fp8_scaling_factors(dsT, FP8_MAX) - dk += ( - tl.dot((dsT * scale_dsT).to(qT.type.element_ty), tl.trans(qT)) - * descale_dsT - * descale_q - ) - else: - dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) - # Increment pointers. - curr_m += step_m - qT_ptrs += step_m * stride_qm - do_ptrs += step_m * stride_dom - return dk, dv - - -# the main inner-loop logic for computing dQ -@triton.jit -def _bwd_dq_inner( - dq, # output - q, - K, - V, - do, - m, - Delta, - sm_scale, # input - # shared by Q/K/V. 
- stride_qm, - stride_qk, - stride_kn, - stride_kk, - stride_vn, - stride_vk, - stride_dropoutm, - stride_dropoutn, # stride for dropout - stride_deltam, - seqlen_q, - seqlen_k, # - BLOCK_M2: tl.constexpr, # - BLOCK_N2: tl.constexpr, # - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, # - dropout_p, - philox_seed, - batch_philox_offset, - dropout_offset, - alibi_slope, - # Filled in by the wrapper. - start_m, - start_n, - end_n, - num_steps, # - descale_q, - descale_k, - descale_v, - descale_do, # fp8 descale factors from user - MASK: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - USE_ALIBI: tl.constexpr, - USE_EXP2: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, - DEBUG_TRITON: tl.constexpr, - DEBUG_TRITON_DETAIL: tl.constexpr, -): - # if HEAD_DIM is padded - PADDED_HEAD: tl.constexpr = ACTUAL_HEAD_DIM != HEAD_DIM - delta_qk = seqlen_q - seqlen_k - offs_m = start_m + tl.arange(0, BLOCK_M2) - offs_n = start_n + tl.arange(0, BLOCK_N2) - offs_k = tl.arange(0, HEAD_DIM) - - # mask to make sure not OOB of seqlen_q - mask_m = offs_m < seqlen_q - - kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k[:, None] * stride_kk - vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k[:, None] * stride_vk - # D (= delta) is pre-divided by ds_scale. - Di = tl.load(Delta + offs_m * stride_deltam, mask=mask_m, other=0.0) - # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
- tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) - curr_n = start_n - step_n = BLOCK_N2 - curr_philox_offset = batch_philox_offset - curr_dropout_offset = dropout_offset - RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) - for blk_idx in range(num_steps): - if DEBUG_TRITON: - print(f"iter {blk_idx}: curr_n = {curr_n}") # noqa: E701 - offs_n = curr_n + tl.arange(0, BLOCK_N2) - # end_n is needed because the end of causal True might not be perfectly - # aligned with the end of the block - mask_n = offs_n < end_n - if DEBUG_TRITON_DETAIL: - print( - f"start_n = {start_n}, end_n = {end_n}, offs_n: {offs_n.shape}\n{offs_n}" - ) # noqa: E701 - if DEBUG_TRITON_DETAIL: - print(f"mask_n: {mask_n.shape}\n{mask_n}") # noqa: E701 - mask_kT = mask_n[None, :] - mask_mn = mask_m[:, None] & (offs_n[None, :] < end_n) - if PADDED_HEAD: - mask_kT &= offs_k[:, None] < ACTUAL_HEAD_DIM - - kT = tl.load(kT_ptrs, mask=mask_kT, other=0.0) - vT = tl.load(vT_ptrs, mask=mask_kT, other=0.0) - - if ENABLE_DROPOUT: - # NOTE: dropout is transposed because it is used to mask pT - philox_offs = ( - curr_philox_offset - + offs_m[:, None] * stride_dropoutm - + offs_n[None, :] * stride_dropoutn - ) - if tl_DROPOUT_USE_PYTORCH: - dropout_offs = ( - offs_m[:, None] * stride_dropoutm - + offs_n[None, :] * stride_dropoutn - ) - dropout_mask = tl.load(curr_dropout_offset + dropout_offs, mask=mask_mn) - else: - rand_vals = tl.rand(philox_seed, philox_offs) - dropout_mask = rand_vals > dropout_p - dropout_scale = 1 / (1 - dropout_p) - - if IS_FP8: - qk = tl.dot(q, kT) * descale_q * descale_k - else: - qk = tl.dot(q, kT) - qk_scaled = qk * sm_scale - - if USE_ALIBI: - relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] - alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) - qk_scaled += alibi_block - - if DEBUG_TRITON_DETAIL: - print(f"qk scaled: {qk.shape}\n", qk_scaled) # noqa: E701 - if USE_EXP2: - p = tl.math.exp2(qk_scaled * RCP_LN2 - m * RCP_LN2) - else: - p = 
tl.math.exp(qk_scaled - m) - - # Autoregressive masking. - if MASK: - causal_mask = (offs_m[:, None] - delta_qk) >= offs_n[None, :] - mask = causal_mask & mask_mn - p = tl.where(mask, p, 0.0) - # Compute dP and dS. - if IS_FP8: - dp = tl.dot(do, vT) * descale_do * descale_v - else: - dp = tl.dot(do, vT) - if ENABLE_DROPOUT: - dp = tl.where(dropout_mask, dp, 0.0) * dropout_scale - delta_i = Di[:, None] - ds = p * (dp - delta_i) - # Compute dQ. - # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. - if IS_FP8: - scale_ds, descale_ds = _compute_fp8_scaling_factors(ds, FP8_MAX) - dq += ( - tl.dot((ds * scale_ds).to(kT.type.element_ty), tl.trans(kT)) - * descale_ds - * descale_k - ) - else: - dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) - # Increment pointers. - curr_n += step_n - kT_ptrs += step_n * stride_kn - vT_ptrs += step_n * stride_vn - return dq - - -@triton.jit -def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nheads_q) - Q, - K, - V, - sm_scale, - DO, - DQ, - DK, - DV, - M, - Delta, - stride_qb_in, - stride_qh_in, - stride_qm_in, - stride_qd_in, - stride_kb_in, - stride_kh_in, - stride_kn_in, - stride_kd_in, - stride_vb_in, - stride_vh_in, - stride_vn_in, - stride_vd_in, - stride_dqb_in, - stride_dqh_in, - stride_dqm_in, - stride_dqd_in, - stride_dkb_in, - stride_dkh_in, - stride_dkn_in, - stride_dkd_in, - stride_dvb_in, - stride_dvh_in, - stride_dvn_in, - stride_dvd_in, - stride_deltab_in, - stride_deltah_in, - stride_deltam_in, - stride_dob_in, - stride_doh_in, - stride_dom_in, - stride_dod_in, - stride_dropoutb_in, - stride_dropouth_in, - stride_dropoutm_in, - stride_dropoutn_in, - stride_descale_q_z_in, - stride_descale_k_z_in, - stride_descale_v_z_in, - stride_descale_do_z_in, - stride_az_in, - stride_ah_in, - HQ, - HK, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - Dropout_mask, - dropout_p, - philox_seed, - philox_offset_base_in, - Alibi_slopes, - Descale_q, - Descale_k, - 
Descale_v, - Descale_do, - BLOCK_M1: tl.constexpr, - BLOCK_N1: tl.constexpr, - BLOCK_M2: tl.constexpr, - BLOCK_N2: tl.constexpr, - BLK_SLICE_FACTOR: tl.constexpr, - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_VARLEN: tl.constexpr, - USE_ALIBI: tl.constexpr, - USE_EXP2: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, - FP8_OUTPUT: tl.constexpr, - DEBUG_TRITON: tl.constexpr, - DEBUG_TRITON_DETAIL: tl.constexpr, - USE_INT64_STRIDES: tl.constexpr, -): - if USE_INT64_STRIDES: - stride_qb = tl.cast(stride_qb_in, tl.int64) - stride_qh = tl.cast(stride_qh_in, tl.int64) - stride_qm = tl.cast(stride_qm_in, tl.int64) - stride_qd = tl.cast(stride_qd_in, tl.int64) - stride_kb = tl.cast(stride_kb_in, tl.int64) - stride_kh = tl.cast(stride_kh_in, tl.int64) - stride_kn = tl.cast(stride_kn_in, tl.int64) - stride_kd = tl.cast(stride_kd_in, tl.int64) - stride_vb = tl.cast(stride_vb_in, tl.int64) - stride_vh = tl.cast(stride_vh_in, tl.int64) - stride_vn = tl.cast(stride_vn_in, tl.int64) - stride_vd = tl.cast(stride_vd_in, tl.int64) - stride_dqb = tl.cast(stride_dqb_in, tl.int64) - stride_dqh = tl.cast(stride_dqh_in, tl.int64) - stride_dqm = tl.cast(stride_dqm_in, tl.int64) - stride_dqd = tl.cast(stride_dqd_in, tl.int64) - stride_dkb = tl.cast(stride_dkb_in, tl.int64) - stride_dkh = tl.cast(stride_dkh_in, tl.int64) - stride_dkn = tl.cast(stride_dkn_in, tl.int64) - stride_dkd = tl.cast(stride_dkd_in, tl.int64) - stride_dvb = tl.cast(stride_dvb_in, tl.int64) - stride_dvh = tl.cast(stride_dvh_in, tl.int64) - stride_dvn = tl.cast(stride_dvn_in, tl.int64) - stride_dvd = tl.cast(stride_dvd_in, tl.int64) - stride_deltab = tl.cast(stride_deltab_in, tl.int64) - stride_deltah = tl.cast(stride_deltah_in, tl.int64) - stride_deltam = tl.cast(stride_deltam_in, tl.int64) - stride_dob = tl.cast(stride_dob_in, tl.int64) - stride_doh = tl.cast(stride_doh_in, tl.int64) - stride_dom = tl.cast(stride_dom_in, tl.int64) - stride_dod = 
tl.cast(stride_dod_in, tl.int64) - philox_offset_base = tl.cast(philox_offset_base_in, tl.int64) - stride_dropoutb = tl.cast(stride_dropoutb_in, tl.int64) - stride_dropouth = tl.cast(stride_dropouth_in, tl.int64) - stride_dropoutm = tl.cast(stride_dropoutm_in, tl.int64) - stride_dropoutn = tl.cast(stride_dropoutn_in, tl.int64) - if IS_FP8: - stride_descale_q_z = tl.cast(stride_descale_q_z_in, tl.int64) - stride_descale_k_z = tl.cast(stride_descale_k_z_in, tl.int64) - stride_descale_v_z = tl.cast(stride_descale_v_z_in, tl.int64) - stride_descale_do_z = tl.cast(stride_descale_do_z_in, tl.int64) - stride_az = tl.cast(stride_az_in, tl.int64) - stride_ah = tl.cast(stride_ah_in, tl.int64) - else: - stride_qb = stride_qb_in - stride_qh = stride_qh_in - stride_qm = stride_qm_in - stride_qd = stride_qd_in - stride_kb = stride_kb_in - stride_kh = stride_kh_in - stride_kn = stride_kn_in - stride_kd = stride_kd_in - stride_vb = stride_vb_in - stride_vh = stride_vh_in - stride_vn = stride_vn_in - stride_vd = stride_vd_in - stride_dqb = stride_dqb_in - stride_dqh = stride_dqh_in - stride_dqm = stride_dqm_in - stride_dqd = stride_dqd_in - stride_dkb = stride_dkb_in - stride_dkh = stride_dkh_in - stride_dkn = stride_dkn_in - stride_dkd = stride_dkd_in - stride_dvb = stride_dvb_in - stride_dvh = stride_dvh_in - stride_dvn = stride_dvn_in - stride_dvd = stride_dvd_in - stride_deltab = stride_deltab_in - stride_deltah = stride_deltah_in - stride_deltam = stride_deltam_in - stride_dob = stride_dob_in - stride_doh = stride_doh_in - stride_dom = stride_dom_in - stride_dod = stride_dod_in - philox_offset_base = philox_offset_base_in - stride_dropoutb = stride_dropoutb_in - stride_dropouth = stride_dropouth_in - stride_dropoutm = stride_dropoutm_in - stride_dropoutn = stride_dropoutn_in - stride_descale_q_z = stride_descale_q_z_in - stride_descale_k_z = stride_descale_k_z_in - stride_descale_v_z = stride_descale_v_z_in - stride_descale_do_z = stride_descale_do_z_in - stride_az = 
stride_az_in - stride_ah = stride_ah_in - - # program ids - hkid = tl.program_id(0) - pid = tl.program_id(1) - bid = tl.program_id(2) - if DEBUG_TRITON: - print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}") # noqa: E701 - # figure out varlen start and end - q_start = 0 - k_start = 0 - seqlen_q = max_seqlen_q - seqlen_k = max_seqlen_k - if IS_VARLEN: - # Compute actual sequence lengths - q_start = tl.load(cu_seqlens_q + bid) - q_end = tl.load(cu_seqlens_q + bid + 1) - k_start = tl.load(cu_seqlens_k + bid) - k_end = tl.load(cu_seqlens_k + bid + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start - - delta_qk = seqlen_q - seqlen_k - if DEBUG_TRITON: - print(f"delta_qk = {delta_qk}") # noqa: E701 - PADDED_HEAD: tl.constexpr = ACTUAL_HEAD_DIM != HEAD_DIM - offs_d = tl.arange(0, HEAD_DIM) - GROUP_SIZE: tl.constexpr = HQ // HK - - # align the delta_qk - start_n = pid * BLOCK_N1 - if start_n < seqlen_k: - # This section does dk and dv - dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) - dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) - - # q > k: diretcly skip all the way until the start of causal block - start_delta_q_gt_k = delta_qk - # q < k: some blocks will have no Masked block, other needs to re-calc - # starting position - # delta_qk is negative so flip it, only multiple of BLOCK_N can skip the - # masked op - num_blocks_skip = -delta_qk // BLOCK_N1 - delta_aligned = (num_blocks_skip + 1) * BLOCK_N1 + delta_qk - start_delta_q_lt_k = delta_aligned // BLOCK_M1 * BLOCK_M1 - if delta_qk >= 0: - start_delta = delta_qk - if DEBUG_TRITON: - print( - f"q >= k: start_delta = delta_qk aligned to BLOCK_M = {start_delta_q_gt_k}" - ) # noqa: E701 - else: - start_delta = start_delta_q_lt_k - if DEBUG_TRITON: - print( - f"q < k: start_delta = residue btw multiple BLOCK_N and delta_qk = {delta_aligned} = aligned to BLOCK_M = {start_delta_q_lt_k}" - ) # noqa: E701 - - offs_n = start_n + tl.arange(0, BLOCK_N1) - # Mask for loading K and V - mask_kv = offs_n[:, None] 
< seqlen_k - if PADDED_HEAD: - mask_d = offs_d < ACTUAL_HEAD_DIM - mask_kv &= mask_d[None, :] - - # K/V tensors not changed for the group - adj_k = ( - bid * stride_kb - + hkid * stride_kh - + k_start * stride_kn - + offs_n[:, None] * stride_kn - + offs_d[None, :] * stride_kd - ) - adj_v = ( - bid * stride_vb - + hkid * stride_vh - + k_start * stride_vn - + offs_n[:, None] * stride_vn - + offs_d[None, :] * stride_vd - ) - # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + adj_k, mask=mask_kv, other=0.0) - v = tl.load(V + adj_v, mask=mask_kv, other=0.0) - # If MQA / GQA, set the K and V head offsets appropriately. - # hqid = hkid - for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): - if delta_qk >= 0: - start_m = start_n + start_delta - len_m = BLOCK_N1 - else: - start_m = max(start_n + delta_qk, 0) - start_m = start_m // BLOCK_M1 * BLOCK_M1 - # because we might shift the masked blocks up, we are deeper into - # the masked out region, so we would potentially increase the total - # steps with masked operation to get out of it - residue_m = max(start_n + delta_qk - start_m, 0) - len_m = BLOCK_N1 + residue_m - if DEBUG_TRITON: - print(f"residue_m = {residue_m}") # noqa: E701 - - # offset input and output tensor by batch and Q/K heads - adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm - Q_ptr = Q + adj_q - adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom - DO_ptr = DO + adj_do - adj_delta = ( - bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam - ) - M_ptr = M + adj_delta - Delta_ptr = Delta + adj_delta - - if USE_ALIBI: - alibi_offset = bid * stride_az + hqid * stride_ah - alibi_slope = tl.load(Alibi_slopes + alibi_offset) - else: - alibi_slope = None - - # batch_philox_offset is the ACTUALLY dropout offset - # dropout_offset is for debug purpose and will be removed later - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = ( - 
philox_offset_base + bid * stride_dropoutb + hqid * stride_dropouth - ) - dropout_offset = ( - Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth - ) - - if IS_FP8: - descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) - descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) - descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) - descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) - else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - - MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR - # bound the masked operation to q len so it does not have to wast cycles - len_m = min(len_m, seqlen_q) - num_steps = tl.cdiv(len_m, MASK_BLOCK_M1) - # when q < k, we may skip the initial masked op - if pid < num_blocks_skip: - num_steps = 0 - - # if start_m is negative, the current N-tile has no block on the - # diagonal of causal mask, so everything have no causal mask - if DEBUG_TRITON: - print( - f"Masked: start_n: {start_n}; start_m: {start_m}, num_steps: {num_steps}" - ) # noqa: E701 - dk, dv = _bwd_dkdv_inner( - dk, - dv, # output tensors - Q_ptr, - k, - v, - DO_ptr, - M_ptr, - Delta_ptr, - sm_scale, # input tensors - stride_qm, - stride_qd, # strides for q - stride_dom, - stride_dod, # strides for o - stride_dropoutm, - stride_dropoutn, # strides for dropout - stride_deltam, - MASK_BLOCK_M1, - BLOCK_N1, # block dim - HEAD_DIM, - ACTUAL_HEAD_DIM, # head dim - dropout_p, - philox_seed, - batch_philox_offset, - dropout_offset, - alibi_slope, - seqlen_q, - seqlen_k, # max sequence length for q and k - start_n, - start_m, - num_steps, # iteration numbers - descale_q, - descale_k, - descale_v, - descale_do, - MASK=True, # causal masking - ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout - USE_ALIBI=USE_ALIBI, - USE_EXP2=USE_EXP2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - start_m += num_steps * MASK_BLOCK_M1 - 
num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M1) - end_m = start_m + num_steps * BLOCK_M1 - - if DEBUG_TRITON: - print( - f"start_m after Masked step: {start_m}; num_steps: {num_steps}" - ) # noqa: E701 - if DEBUG_TRITON: - print( - f"unMasked: start_n: {start_n}, start_m: {start_m}, end_m: {end_m}, num_steps: {num_steps}" - ) # noqa: E701 - if DEBUG_TRITON: - print("unMasked") # noqa: E701 - dk, dv = _bwd_dkdv_inner( - dk, - dv, # output tensors - Q_ptr, - k, - v, - DO_ptr, - M_ptr, - Delta_ptr, - sm_scale, # input tensors - stride_qm, - stride_qd, # strides for q - stride_dom, - stride_dod, # strides for o - stride_dropoutm, - stride_dropoutn, # strides for dropout - stride_deltam, - BLOCK_M1, - BLOCK_N1, # block dim - HEAD_DIM, - ACTUAL_HEAD_DIM, # head dim - dropout_p, - philox_seed, - batch_philox_offset, - dropout_offset, - alibi_slope, - seqlen_q, - seqlen_k, # max sequence length for q and k - start_n, - start_m, - num_steps, # iteration numbers - descale_q, - descale_k, - descale_v, - descale_do, - MASK=False, # causal masking - ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout - USE_ALIBI=USE_ALIBI, - USE_EXP2=USE_EXP2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - # end of GQA/MQA of dkdv - # Write back dV - adj_dv = bid * stride_dvb + hkid * stride_dvh + k_start * stride_dvn - offs_dv = offs_n[:, None] * stride_dvn + offs_d[None, :] * stride_dvd - tl.store(DV + adj_dv + offs_dv, dv, mask=mask_kv) - # write back dk - adj_dk = bid * stride_dkb + hkid * stride_dkh + k_start * stride_dkn - offs_dk = offs_n[:, None] * stride_dkn + offs_d[None, :] * stride_dkd - dk *= sm_scale - tl.store(DK + adj_dk + offs_dk, dk, mask=mask_kv) - - # This part does dq - start_m = pid * BLOCK_M2 - if start_m < seqlen_q: - # seqlen_q > seqlen_k, no need to process these tile for dq - if DEBUG_TRITON: - print( - f"end_n = start_m + BLOCK_M = {start_m} + {BLOCK_M2} = {start_m + BLOCK_M2}" - ) # noqa: E701 - if 
start_m + BLOCK_M2 < delta_qk: - if DEBUG_TRITON: - print( - f"start_m + BLOCK_M2 = {start_m} + {BLOCK_M2} = {start_m + BLOCK_M2} < delta_qk of {delta_qk}" - ) # noqa: E701 - return - - offs_m = start_m + tl.arange(0, BLOCK_M2) - # Mask for loading K and V - mask_q = offs_m[:, None] < seqlen_q - if PADDED_HEAD: - mask_d = offs_d < ACTUAL_HEAD_DIM - mask_q &= mask_d[None, :] - offs_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd - offs_do = offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod - # NOTE: don't assume that the strides for k and v are the same! - K += bid * stride_kb + hkid * stride_kh + k_start * stride_kn - V += bid * stride_vb + hkid * stride_vh + k_start * stride_vn - - # If MQA / GQA, set the K and V head offsets appropriately. - for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): - # seqlen_q < seqlen_k: delta_qk more kv tokens are added at the front - # for every M-tile - end_n = start_m + BLOCK_M2 - delta_qk - # clamp end_n at [0, seqlen_k] - end_n = max(min(end_n, seqlen_k), 0) - if DEBUG_TRITON: - print(f"delta_qk: {delta_qk}; end_n: {end_n}") # noqa: E701 - # offset input and output tensor by batch and Q/K heads - adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm - adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom - adj_delta = ( - bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam - ) - Delta_ptr = Delta + adj_delta - - if USE_ALIBI: - alibi_offset = bid * stride_az + hqid * stride_ah - alibi_slope = tl.load(Alibi_slopes + alibi_offset) - else: - alibi_slope = None - - # batch_philox_offset is the ACTUALLY dropout offset - # dropout_offset is for debug purpose and will be removed later - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = ( - philox_offset_base + bid * stride_dropoutb + hqid * stride_dropouth - ) - dropout_offset = ( - Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth - ) - q = 
tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) - do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) - m = tl.load(M + adj_delta + offs_m * stride_deltam, mask=offs_m < seqlen_q) - m = m[:, None] - - MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR - # start can only be 0 at minimum - start_n = max(end_n - BLOCK_M2, 0) - num_steps = tl.cdiv(end_n - start_n, MASK_BLOCK_N2) - - if IS_FP8: - descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) - descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) - descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) - descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) - else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - - dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) - dq = _bwd_dq_inner( - dq, - q, - K, - V, - do, - m, - Delta_ptr, - sm_scale, - stride_qm, - stride_qd, - stride_kn, - stride_kd, - stride_vn, - stride_vd, - stride_dropoutm, - stride_dropoutn, - stride_deltam, - seqlen_q, - seqlen_k, - BLOCK_M2, - MASK_BLOCK_N2, - HEAD_DIM, - ACTUAL_HEAD_DIM, - dropout_p, - philox_seed, - batch_philox_offset, - dropout_offset, - alibi_slope, - start_m, - start_n, - end_n, - num_steps, - descale_q, - descale_k, - descale_v, - descale_do, - MASK=True, # - ENABLE_DROPOUT=ENABLE_DROPOUT, - USE_ALIBI=USE_ALIBI, - USE_EXP2=USE_EXP2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - end_n -= num_steps * MASK_BLOCK_N2 - num_steps = tl.cdiv(end_n, BLOCK_N2) - start_n = max(end_n - num_steps * BLOCK_N2, 0) - if DEBUG_TRITON: - print( - f"unMasked: start_m: {start_m}, start_n: {start_n}, end_n: {end_n}, num_steps: {num_steps}" - ) # noqa: E701 - dq = _bwd_dq_inner( - dq, - q, - K, - V, - do, - m, - Delta_ptr, - sm_scale, - stride_qm, - stride_qd, - stride_kn, - stride_kd, - stride_vn, - stride_vd, - stride_dropoutm, - stride_dropoutn, - stride_deltam, - seqlen_q, - seqlen_k, - 
@triton.jit
def bwd_kernel_noncausal(
    Q,
    K,
    V,
    sm_scale,
    DO,
    DQ,
    DK,
    DV,
    M,
    Delta,
    stride_qb_in,
    stride_qh_in,
    stride_qm_in,
    stride_qd_in,
    stride_kb_in,
    stride_kh_in,
    stride_kn_in,
    stride_kd_in,
    stride_vb_in,
    stride_vh_in,
    stride_vn_in,
    stride_vd_in,
    stride_dqb_in,
    stride_dqh_in,
    stride_dqm_in,
    stride_dqd_in,
    stride_dkb_in,
    stride_dkh_in,
    stride_dkn_in,
    stride_dkd_in,
    stride_dvb_in,
    stride_dvh_in,
    stride_dvn_in,
    stride_dvd_in,
    stride_deltab_in,
    stride_deltah_in,
    stride_deltam_in,
    stride_dob_in,
    stride_doh_in,
    stride_dom_in,
    stride_dod_in,
    stride_dropoutb_in,
    stride_dropouth_in,
    stride_dropoutm_in,
    stride_dropoutn_in,
    stride_descale_q_z_in,
    stride_descale_k_z_in,
    stride_descale_v_z_in,
    stride_descale_do_z_in,
    stride_az_in,
    stride_ah_in,
    HQ,
    HK,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    Dropout_mask,
    dropout_p,
    philox_seed,
    philox_offset_base_in,
    Alibi_slopes,
    Descale_q,
    Descale_k,
    Descale_v,
    Descale_do,
    BLOCK_M1: tl.constexpr,  # 32
    BLOCK_N1: tl.constexpr,  # 128
    BLOCK_M2: tl.constexpr,  # 128
    BLOCK_N2: tl.constexpr,  # 32
    BLK_SLICE_FACTOR: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    ACTUAL_HEAD_DIM: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    USE_ALIBI: tl.constexpr,
    USE_EXP2: tl.constexpr,
    IS_FP8: tl.constexpr,
    FP8_MAX: tl.constexpr,
    FP8_OUTPUT: tl.constexpr,
    DEBUG_TRITON: tl.constexpr,
    DEBUG_TRITON_DETAIL: tl.constexpr,
    USE_INT64_STRIDES: tl.constexpr,
):
    """Single-launch backward pass for attention without causal masking.

    Program ids: axis 0 is the K/V head (``hkid``), axis 1 is the tile index
    (``pid``) and axis 2 is the batch entry (``bid``). Each program does two
    independent passes that reuse the same ``pid``:

    1. dK/dV for the K/V tile starting at ``pid * BLOCK_N1``, accumulated over
       every query head of the GQA/MQA group and over the whole Q range.
    2. dQ for the Q tile starting at ``pid * BLOCK_M2``, computed per query
       head of the group over the whole K range.

    ``HEAD_DIM`` is the extent passed to ``tl.arange`` for the head dimension;
    ``ACTUAL_HEAD_DIM`` is the number of valid columns, and columns at or
    beyond it are masked out whenever the two differ. ``*_in`` strides are
    optionally widened to int64 (``USE_INT64_STRIDES``). ``BLK_SLICE_FACTOR``
    and ``FP8_OUTPUT`` are accepted but not referenced in this kernel.
    """
    if USE_INT64_STRIDES:
        stride_qb = tl.cast(stride_qb_in, tl.int64)
        stride_qh = tl.cast(stride_qh_in, tl.int64)
        stride_qm = tl.cast(stride_qm_in, tl.int64)
        stride_qd = tl.cast(stride_qd_in, tl.int64)
        stride_kb = tl.cast(stride_kb_in, tl.int64)
        stride_kh = tl.cast(stride_kh_in, tl.int64)
        stride_kn = tl.cast(stride_kn_in, tl.int64)
        stride_kd = tl.cast(stride_kd_in, tl.int64)
        stride_vb = tl.cast(stride_vb_in, tl.int64)
        stride_vh = tl.cast(stride_vh_in, tl.int64)
        stride_vn = tl.cast(stride_vn_in, tl.int64)
        stride_vd = tl.cast(stride_vd_in, tl.int64)
        stride_dqb = tl.cast(stride_dqb_in, tl.int64)
        stride_dqh = tl.cast(stride_dqh_in, tl.int64)
        stride_dqm = tl.cast(stride_dqm_in, tl.int64)
        stride_dqd = tl.cast(stride_dqd_in, tl.int64)
        stride_dkb = tl.cast(stride_dkb_in, tl.int64)
        stride_dkh = tl.cast(stride_dkh_in, tl.int64)
        stride_dkn = tl.cast(stride_dkn_in, tl.int64)
        stride_dkd = tl.cast(stride_dkd_in, tl.int64)
        stride_dvb = tl.cast(stride_dvb_in, tl.int64)
        stride_dvh = tl.cast(stride_dvh_in, tl.int64)
        stride_dvn = tl.cast(stride_dvn_in, tl.int64)
        stride_dvd = tl.cast(stride_dvd_in, tl.int64)
        stride_deltab = tl.cast(stride_deltab_in, tl.int64)
        stride_deltah = tl.cast(stride_deltah_in, tl.int64)
        stride_deltam = tl.cast(stride_deltam_in, tl.int64)
        stride_dob = tl.cast(stride_dob_in, tl.int64)
        stride_doh = tl.cast(stride_doh_in, tl.int64)
        stride_dom = tl.cast(stride_dom_in, tl.int64)
        stride_dod = tl.cast(stride_dod_in, tl.int64)
        philox_offset_base = tl.cast(philox_offset_base_in, tl.int64)
        stride_dropoutb = tl.cast(stride_dropoutb_in, tl.int64)
        stride_dropouth = tl.cast(stride_dropouth_in, tl.int64)
        stride_dropoutm = tl.cast(stride_dropoutm_in, tl.int64)
        stride_dropoutn = tl.cast(stride_dropoutn_in, tl.int64)
        # The descale strides are only cast (and only read) on the FP8 path;
        # the host passes None for them otherwise.
        if IS_FP8:
            stride_descale_q_z = tl.cast(stride_descale_q_z_in, tl.int64)
            stride_descale_k_z = tl.cast(stride_descale_k_z_in, tl.int64)
            stride_descale_v_z = tl.cast(stride_descale_v_z_in, tl.int64)
            stride_descale_do_z = tl.cast(stride_descale_do_z_in, tl.int64)
        stride_az = tl.cast(stride_az_in, tl.int64)
        stride_ah = tl.cast(stride_ah_in, tl.int64)
    else:
        stride_qb = stride_qb_in
        stride_qh = stride_qh_in
        stride_qm = stride_qm_in
        stride_qd = stride_qd_in
        stride_kb = stride_kb_in
        stride_kh = stride_kh_in
        stride_kn = stride_kn_in
        stride_kd = stride_kd_in
        stride_vb = stride_vb_in
        stride_vh = stride_vh_in
        stride_vn = stride_vn_in
        stride_vd = stride_vd_in
        stride_dqb = stride_dqb_in
        stride_dqh = stride_dqh_in
        stride_dqm = stride_dqm_in
        stride_dqd = stride_dqd_in
        stride_dkb = stride_dkb_in
        stride_dkh = stride_dkh_in
        stride_dkn = stride_dkn_in
        stride_dkd = stride_dkd_in
        stride_dvb = stride_dvb_in
        stride_dvh = stride_dvh_in
        stride_dvn = stride_dvn_in
        stride_dvd = stride_dvd_in
        stride_deltab = stride_deltab_in
        stride_deltah = stride_deltah_in
        stride_deltam = stride_deltam_in
        stride_dob = stride_dob_in
        stride_doh = stride_doh_in
        stride_dom = stride_dom_in
        stride_dod = stride_dod_in
        philox_offset_base = philox_offset_base_in
        stride_dropoutb = stride_dropoutb_in
        stride_dropouth = stride_dropouth_in
        stride_dropoutm = stride_dropoutm_in
        stride_dropoutn = stride_dropoutn_in
        stride_descale_q_z = stride_descale_q_z_in
        stride_descale_k_z = stride_descale_k_z_in
        stride_descale_v_z = stride_descale_v_z_in
        stride_descale_do_z = stride_descale_do_z_in
        stride_az = stride_az_in
        stride_ah = stride_ah_in

    # program ids
    hkid = tl.program_id(0)
    pid = tl.program_id(1)
    bid = tl.program_id(2)
    if DEBUG_TRITON:
        print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}")  # noqa: E701
    # figure out varlen start and end
    q_start = 0
    k_start = 0
    seqlen_q = max_seqlen_q
    seqlen_k = max_seqlen_k
    if IS_VARLEN:
        # Compute actual sequence lengths
        q_start = tl.load(cu_seqlens_q + bid)
        q_end = tl.load(cu_seqlens_q + bid + 1)
        k_start = tl.load(cu_seqlens_k + bid)
        k_end = tl.load(cu_seqlens_k + bid + 1)
        seqlen_q = q_end - q_start
        seqlen_k = k_end - k_start

    PADDED_HEAD: tl.constexpr = ACTUAL_HEAD_DIM != HEAD_DIM
    offs_d = tl.arange(0, HEAD_DIM)
    GROUP_SIZE: tl.constexpr = HQ // HK

    # ------------------------------------------------------------------
    # Pass 1: dK / dV for the K/V tile owned by this program.
    # ------------------------------------------------------------------
    start_n = pid * BLOCK_N1
    if start_n < seqlen_k:
        dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
        dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)

        offs_n = start_n + tl.arange(0, BLOCK_N1)
        # Mask for loading K and V
        mask_kv = offs_n[:, None] < seqlen_k
        if PADDED_HEAD:
            mask_d = offs_d < ACTUAL_HEAD_DIM
            mask_kv &= mask_d[None, :]
        # NOTE: don't assume that the strides for k and v are the same!
        # K/V tensors not changed for the group
        adj_k = (
            bid * stride_kb
            + hkid * stride_kh
            + k_start * stride_kn
            + offs_n[:, None] * stride_kn
            + offs_d[None, :] * stride_kd
        )
        adj_v = (
            bid * stride_vb
            + hkid * stride_vh
            + k_start * stride_vn
            + offs_n[:, None] * stride_vn
            + offs_d[None, :] * stride_vd
        )
        # load K and V: they stay in SRAM throughout the inner loop.
        k = tl.load(K + adj_k, mask=mask_kv, other=0.0)
        v = tl.load(V + adj_v, mask=mask_kv, other=0.0)
        # If MQA / GQA, iterate over every query head that shares this K/V head.
        for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE):
            # offset input and output tensor by batch and Q/K heads
            adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm
            Q_ptr = Q + adj_q
            adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom
            DO_ptr = DO + adj_do
            adj_delta = (
                bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam
            )
            # M and Delta are addressed with the same (delta) strides.
            M_ptr = M + adj_delta
            Delta_ptr = Delta + adj_delta

            if USE_ALIBI:
                alibi_offset = bid * stride_az + hqid * stride_ah
                alibi_slope = tl.load(Alibi_slopes + alibi_offset)
            else:
                alibi_slope = None

            # batch_philox_offset is the actual dropout offset
            # dropout_offset is for debug purpose and will be removed later
            batch_philox_offset = 0
            dropout_offset = 0
            if ENABLE_DROPOUT:
                batch_philox_offset = (
                    philox_offset_base + bid * stride_dropoutb + hqid * stride_dropouth
                )
                dropout_offset = (
                    Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth
                )

            if IS_FP8:
                descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid)
                descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid)
                descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid)
                descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid)
            else:
                descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0

            # because there is no causal, we always start from the beginning
            start_m = 0
            num_steps = tl.cdiv(seqlen_q, BLOCK_M1)
            dk, dv = _bwd_dkdv_inner(
                dk,
                dv,  # output tensors
                Q_ptr,
                k,
                v,
                DO_ptr,
                M_ptr,
                Delta_ptr,
                sm_scale,  # input tensors
                stride_qm,
                stride_qd,  # strides for q
                stride_dom,
                stride_dod,  # strides for o
                stride_dropoutm,
                stride_dropoutn,  # strides for dropout
                stride_deltam,
                BLOCK_M1,
                BLOCK_N1,  # block dim
                HEAD_DIM,
                ACTUAL_HEAD_DIM,  # head dim
                dropout_p,
                philox_seed,
                batch_philox_offset,
                dropout_offset,  #
                alibi_slope,
                seqlen_q,
                seqlen_k,  # max sequence length for q and k
                start_n,
                start_m,
                num_steps,  # iteration numbers
                descale_q,
                descale_k,
                descale_v,
                descale_do,  # fp8 descale factors from user
                MASK=False,  # causal masking
                ENABLE_DROPOUT=ENABLE_DROPOUT,  # activate dropout
                USE_ALIBI=USE_ALIBI,
                USE_EXP2=USE_EXP2,
                IS_FP8=IS_FP8,
                FP8_MAX=FP8_MAX,
                DEBUG_TRITON=DEBUG_TRITON,
                DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL,
            )

        # Write back dV
        adj_dv = bid * stride_dvb + hkid * stride_dvh + k_start * stride_dvn
        offs_dv = offs_n[:, None] * stride_dvn + offs_d[None, :] * stride_dvd
        tl.store(DV + adj_dv + offs_dv, dv, mask=mask_kv)
        # write back dk
        adj_dk = bid * stride_dkb + hkid * stride_dkh + k_start * stride_dkn
        offs_dk = offs_n[:, None] * stride_dkn + offs_d[None, :] * stride_dkd
        dk *= sm_scale
        tl.store(DK + adj_dk + offs_dk, dk, mask=mask_kv)

    # ------------------------------------------------------------------
    # Pass 2: dQ for the Q tile owned by this program.
    # ------------------------------------------------------------------
    start_m = pid * BLOCK_M2
    if start_m < seqlen_q:
        offs_m = start_m + tl.arange(0, BLOCK_M2)
        # Mask for loading Q and dO (and for storing dQ)
        mask_q = offs_m[:, None] < seqlen_q
        if PADDED_HEAD:
            mask_d = offs_d < ACTUAL_HEAD_DIM
            mask_q &= mask_d[None, :]
        offs_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd
        offs_do = offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod
        # Advance the K/V base pointers to this batch entry and K/V head.
        K += bid * stride_kb + hkid * stride_kh + k_start * stride_kn
        V += bid * stride_vb + hkid * stride_vh + k_start * stride_vn
        # If MQA / GQA, iterate over every query head that shares this K/V head.
        for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE):
            # offset input and output tensor by batch and Q/K heads
            adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm
            adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom
            adj_delta = (
                bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam
            )
            Delta_ptr = Delta + adj_delta

            if USE_ALIBI:
                alibi_offset = bid * stride_az + hqid * stride_ah
                alibi_slope = tl.load(Alibi_slopes + alibi_offset)
            else:
                alibi_slope = None

            # batch_philox_offset is the actual dropout offset
            # dropout_offset is for debug purpose and will be removed later
            batch_philox_offset = 0
            dropout_offset = 0
            if ENABLE_DROPOUT:
                batch_philox_offset = (
                    philox_offset_base + bid * stride_dropoutb + hqid * stride_dropouth
                )
                dropout_offset = (
                    Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth
                )

            q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0)
            do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0)
            m = tl.load(M + adj_delta + offs_m * stride_deltam, mask=offs_m < seqlen_q)
            m = m[:, None]

            if IS_FP8:
                descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid)
                descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid)
                descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid)
                descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid)
            else:
                descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0

            # no causal mask: every Q tile walks the full K range [0, seqlen_k)
            start_n = 0
            end_n = seqlen_k
            num_steps = tl.cdiv(seqlen_k, BLOCK_N2)

            dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
            dq = _bwd_dq_inner(
                dq,
                q,
                K,
                V,
                do,
                m,
                Delta_ptr,
                sm_scale,
                stride_qm,
                stride_qd,
                stride_kn,
                stride_kd,
                stride_vn,
                stride_vd,
                stride_dropoutm,
                stride_dropoutn,
                stride_deltam,
                seqlen_q,
                seqlen_k,
                BLOCK_M2,
                BLOCK_N2,
                HEAD_DIM,
                ACTUAL_HEAD_DIM,
                dropout_p,
                philox_seed,
                batch_philox_offset,
                dropout_offset,
                alibi_slope,
                start_m,
                start_n,
                end_n,
                num_steps,
                descale_q,
                descale_k,
                descale_v,
                descale_do,
                MASK=False,
                ENABLE_DROPOUT=ENABLE_DROPOUT,
                USE_ALIBI=USE_ALIBI,
                USE_EXP2=USE_EXP2,
                IS_FP8=IS_FP8,
                FP8_MAX=FP8_MAX,
                DEBUG_TRITON=DEBUG_TRITON,
                DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL,
            )
            # Write back dQ.
            adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm
            offs_dq = offs_m[:, None] * stride_dqm + offs_d[None, :] * stride_dqd
            dq *= sm_scale
            tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q)


def is_contiguous(x, name):
    """Return ``x`` if it is contiguous, otherwise a contiguous copy.

    A notice naming the tensor is printed whenever a copy has to be made.
    """
    if x.is_contiguous():
        return x
    print(f"{name} is not contiguous")
    return x.contiguous()


@functools.lru_cache(maxsize=1024)
def _get_config():
    """Load the default MHA tuning config for the current device.

    Reads ``{AITER_TRITON_CONFIGS_PATH}/{dev}-MHA-DEFAULT.json`` once and
    returns its ``"bkwd_onekernel"`` section. The parsed file is also kept on
    the ``_config_dict`` function attribute; that is redundant with
    ``lru_cache`` but retained in case other code reads the attribute.
    """
    if not hasattr(_get_config, "_config_dict"):
        dev = arch_info.get_device()
        fpath = f"{AITER_TRITON_CONFIGS_PATH}/{dev}-MHA-DEFAULT.json"
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(fpath, "r", encoding="utf-8") as file:
            _get_config._config_dict = json.load(file)

    return _get_config._config_dict["bkwd_onekernel"]


def flash_attn_onekernel_backward(
    do: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    o: torch.Tensor,
    softmax_lse: torch.Tensor,
    dq: torch.Tensor,
    dk: torch.Tensor,
    dv: torch.Tensor,
    dbias: torch.Tensor,
    sm_scale: float,
    alibi_slopes: Optional[torch.Tensor],
    causal: bool,
    cu_seqlens_q: Optional[torch.Tensor],
    cu_seqlens_k: Optional[torch.Tensor],
    max_seqlen_q: int,
    max_seqlen_k: int,
    dropout_p: float,
    philox_seed: Optional[int] = 0,
    philox_offset: Optional[int] = 0,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
    descale_v: Optional[torch.Tensor] = None,
    descale_do: Optional[torch.Tensor] = None,
    USE_INT64_STRIDES: Optional[bool] = False,
    config: Optional[Dict[str, any]] = None,
):
    """Run the one-kernel attention backward pass, filling ``dq``/``dk``/``dv``.

    Two launches are issued: ``_bwd_preprocess`` computes
    ``delta = rowsum(dO * O)``, then either ``bwd_kernel_causal`` or
    ``bwd_kernel_noncausal`` writes the gradients in place.

    Args:
        do, q, k, v, o: gradient of the output, the attention inputs and the
            forward output. Layout is ``[batch, seq_len, heads, head_dim]``,
            or ``[total_tokens, heads, head_dim]`` when ``cu_seqlens_q`` is
            given (variable-length mode).
        softmax_lse: passed to the kernel as ``M``; ``delta`` is allocated
            with the same shape.
        dq, dk, dv: pre-allocated output gradients.
        dbias: must be ``None``; bias gradients are not supported.
        sm_scale: softmax scale forwarded to the kernel.
        alibi_slopes: optional 2-D tensor of ALiBi slopes.
        causal: selects the causal or the non-causal kernel.
        cu_seqlens_q, cu_seqlens_k: cumulative sequence lengths; a non-``None``
            ``cu_seqlens_q`` enables variable-length mode.
        max_seqlen_q, max_seqlen_k: maximum sequence lengths.
        dropout_p, philox_seed, philox_offset: dropout probability and RNG
            state forwarded to the kernel.
        descale_q, descale_k, descale_v, descale_do: FP8 descale tensors;
            required when ``q`` is an FP8 tensor.
        USE_INT64_STRIDES: widen strides to int64 inside the kernel.
        config: tuning config with ``"preprocess_kernel"`` and ``"onekernel"``
            sections; loaded via ``_get_config()`` when omitted.

    Returns:
        The ``delta`` tensor computed by the preprocess kernel.

    Raises:
        ValueError: if ``dbias`` is given, or if FP8 inputs are used without
            all four descale tensors.
    """
    if dbias is not None:
        raise ValueError("Bias is not supported yet in the Triton Backend")

    use_alibi, (stride_az, stride_ah) = (
        (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0))
    )

    IS_FP8 = _is_fp8(q)
    if IS_FP8:
        # Fail with a clear message instead of AttributeError on None.stride.
        if any(t is None for t in (descale_q, descale_k, descale_v, descale_do)):
            raise ValueError(
                "descale_q, descale_k, descale_v and descale_do are required "
                "for FP8 inputs"
            )
        FP8_MAX = torch.finfo(q.dtype).max
        descale_strides = (
            descale_q.stride(0),
            descale_k.stride(0),
            descale_v.stride(0),
            descale_do.stride(0),
        )
    else:
        FP8_MAX = None
        descale_strides = (None, None, None, None)

    IS_VARLEN = cu_seqlens_q is not None

    # Strides are handed to the kernels in (batch, head, seq, dim) order.
    if IS_VARLEN:
        # Layout for q,k,v is thd ie [total tokens, num_head, head_dim]
        batch = len(cu_seqlens_q) - 1
        num_q_heads, head_sz = q.shape[1], q.shape[2]
        num_k_heads = k.shape[1]
        # No batch dimension: sequences are located through cu_seqlens.
        q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
        k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
        v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
        o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
        dq_strides = (0, dq.stride(1), dq.stride(0), dq.stride(2))
        dk_strides = (0, dk.stride(1), dk.stride(0), dk.stride(2))
        dv_strides = (0, dv.stride(1), dv.stride(0), dv.stride(2))
        do_strides = (0, do.stride(1), do.stride(0), do.stride(2))
    else:
        # Layout for q,k,v is bshd ie [batch, seq_len, num_head, head_dim]
        batch, _, num_q_heads, head_sz = q.shape
        num_k_heads = k.shape[2]
        q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
        k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
        v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
        o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
        dq_strides = (dq.stride(0), dq.stride(2), dq.stride(1), dq.stride(3))
        dk_strides = (dk.stride(0), dk.stride(2), dk.stride(1), dk.stride(3))
        dv_strides = (dv.stride(0), dv.stride(2), dv.stride(1), dv.stride(3))
        do_strides = (do.stride(0), do.stride(2), do.stride(1), do.stride(3))

    # Padded head dimension: next power of two, at least 16.
    BLOCK_D_MODEL_POW2 = max(triton.next_power_of_2(head_sz), 16)

    # Configs
    if config is None:
        config = _get_config()

    # init delta
    delta = torch.zeros_like(softmax_lse)
    if IS_VARLEN:
        # presumably [total_tokens, num_q_heads] - TODO confirm against caller
        delta_strides = (0, delta.stride(1), delta.stride(0))
    else:
        # [batch, num_q_heads, seqlen_q]
        delta_strides = delta.stride()

    # preprocess
    # compute D(delta) = rowsum(dO*O). Note, multiplication is element-wise.
    pre_block = config["preprocess_kernel"]["PRE_BLOCK"]
    pre_grid = (
        triton.cdiv(max_seqlen_q, pre_block),
        batch,
        num_q_heads,
    )
    _bwd_preprocess[pre_grid](
        o,
        do,
        delta,
        *o_strides,
        *delta_strides,
        descale_strides[3],
        cu_seqlens_q,
        max_seqlen_q,
        descale_do,
        BLOCK_M=pre_block,
        BLOCK_D_MODEL=head_sz,
        BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2,
        IS_VARLEN=IS_VARLEN,
        IS_FP8=IS_FP8,
    )

    # dropout_mask
    use_dropout = dropout_p > 0.0
    if use_dropout:
        dropout_mask = torch.zeros(
            (batch, num_q_heads, max_seqlen_q, max_seqlen_k),
            device=q.device,
            dtype=torch.float32,
        )
        dropout_strides = dropout_mask.stride()
    else:
        dropout_mask = None
        dropout_strides = (0, 0, 0, 0)

    seqlen = max(max_seqlen_q, max_seqlen_k)

    config_onekernel = config["onekernel"]
    # One program index drives both the dK/dV pass (tiles of BLOCK_N1 over K)
    # and the dQ pass (tiles of BLOCK_M2 over Q). Size the grid by the smaller
    # tile so neither pass is left with uncovered rows when the two differ;
    # surplus programs fall through the kernels' start-vs-seqlen range checks.
    # NOTE(review): those checks are visible in bwd_kernel_noncausal; the
    # causal kernel is assumed to have the same ones - confirm.
    block_n1 = config_onekernel["BLOCK_N1"]
    grid_tile = min(block_n1, config_onekernel.get("BLOCK_M2", block_n1))
    grid = (
        num_k_heads,
        triton.cdiv(seqlen, grid_tile),
        batch,
    )

    # Both kernels take exactly the same argument list.
    bwd_kernel = bwd_kernel_causal if causal else bwd_kernel_noncausal
    bwd_kernel[grid](
        q,
        k,
        v,
        sm_scale,
        do,
        dq,
        dk,
        dv,
        softmax_lse,
        delta,
        *q_strides,
        *k_strides,
        *v_strides,
        *dq_strides,
        *dk_strides,
        *dv_strides,
        *delta_strides,
        *do_strides,
        *dropout_strides,
        *descale_strides,
        stride_az,
        stride_ah,
        num_q_heads,
        num_k_heads,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_mask,
        dropout_p,
        philox_seed,
        philox_offset,
        alibi_slopes,
        descale_q,
        descale_k,
        descale_v,
        descale_do,
        # The kernels build tl.arange(0, HEAD_DIM) and mask columns with
        # offs_d < ACTUAL_HEAD_DIM, so HEAD_DIM must be the padded
        # power-of-two size and ACTUAL_HEAD_DIM the true head size.
        HEAD_DIM=BLOCK_D_MODEL_POW2,
        ACTUAL_HEAD_DIM=head_sz,
        ENABLE_DROPOUT=use_dropout,
        IS_VARLEN=IS_VARLEN,
        USE_ALIBI=use_alibi,
        USE_EXP2=True,
        IS_FP8=IS_FP8,
        FP8_MAX=FP8_MAX,
        FP8_OUTPUT=False,
        DEBUG_TRITON=False,
        DEBUG_TRITON_DETAIL=False,
        USE_INT64_STRIDES=USE_INT64_STRIDES,
        **config_onekernel,
    )

    return delta
@triton.jit -def compute_masking(seqlen_k, seqlen_q, start_m, +def compute_block_masking(seqlen_k, seqlen_q, start_m, IS_CAUSAL: tl.constexpr, USE_SLIDING_WINDOW: tl.constexpr, WINDOW_SIZE_LEFT: tl.constexpr, WINDOW_SIZE_RIGHT: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): @@ -423,7 +423,78 @@ def compute_masking(seqlen_k, seqlen_q, start_m, if IS_CAUSAL: return 0, 0, 0, total_k_blocks, n_extra_tokens else: - return 0, 0, 0, total_k_blocks, n_extra_tokens + # ------------------------------------------------------------------ + # token bounds seen by FIRST and LAST rows in this Q‑block + # ------------------------------------------------------------------ + q_start = start_m * BLOCK_M + q_end = tl.minimum((start_m + 1) * BLOCK_M - 1, seqlen_q - 1) + base = seqlen_k - seqlen_q + + # left‑hand side + if WINDOW_SIZE_LEFT < 0: # un‑bounded + left_min = 0 # earliest row + left_max = 0 # latest row + else: + left_min = tl.maximum(0, q_start + base - WINDOW_SIZE_LEFT) + left_max = tl.maximum(0, q_end + base - WINDOW_SIZE_LEFT) + + # right‑hand side + right_min = tl.minimum(seqlen_k - 1, + q_start + base + WINDOW_SIZE_RIGHT) + right_max = tl.minimum(seqlen_k - 1, + q_end + base + WINDOW_SIZE_RIGHT) + + # window vanishes → early exit + if right_max < left_min: + return 0, 0, 0, 0, n_extra_tokens + + # ------------------------------------------------------------------ + # make sure full_left_block never outruns the visible range + # ------------------------------------------------------------------ + first_block = left_min // BLOCK_N + last_block = right_max // BLOCK_N # right‑most block that *any* row touches + + # “first block that is fully visible for all rows” + full_left_block = left_max // BLOCK_N + (left_max % BLOCK_N != 0) + + # clip to avoid front‑mask length > total_visible + clipped_left = tl.minimum(full_left_block, last_block + 1) + + # ------------------------------------------------------------------ + # block counts + # 
------------------------------------------------------------------ + n_front_skip_blocks = first_block + n_front_masked_blocks = tl.maximum(0, clipped_left - first_block) + + tmp = right_min // BLOCK_N + if (tmp + 1) * BLOCK_N - 1 > right_min: # ensure block fits earliest row + tmp -= 1 + full_right_block = tl.maximum(tmp, clipped_left - 1) + + n_full_blocks = tl.maximum(0, full_right_block - clipped_left + 1) + n_back_masked_blocks = tl.maximum(0, last_block - full_right_block) + + # ------------------------------------------------------------ + # padded last‑K block + # ------------------------------------------------------------ + padded_last_k = (n_extra_tokens != 0) & (last_block == total_k_blocks - 1) + last_block_in_front = clipped_left > last_block # ← last block ended up on the left side + + if padded_last_k & (n_back_masked_blocks == 0): + if last_block_in_front: + # move the last block from front‑masked → back‑masked + n_front_masked_blocks = tl.maximum(0, n_front_masked_blocks - 1) + else: + # move the last block from full → back‑masked + n_full_blocks = tl.maximum(0, n_full_blocks - 1) + + n_back_masked_blocks = 1 # ensure it is handled with padding info + + return (n_front_skip_blocks, + n_front_masked_blocks, + n_full_blocks, + n_back_masked_blocks, + n_extra_tokens) else: if IS_CAUSAL: # ========== CAUSAL MODE: Classify K Blocks ========== @@ -575,7 +646,7 @@ def attn_fwd(Q, K, V, bias, # figure out masking pattern - n_front_skip_blocks, n_front_masked_blocks, n_full_blocks, n_back_masked_blocks, n_extra_tokens = compute_masking( + n_front_skip_blocks, n_front_masked_blocks, n_full_blocks, n_back_masked_blocks, n_extra_tokens = compute_block_masking( seqlen_k, seqlen_q, start_m, IS_CAUSAL, USE_SLIDING_WINDOW, WINDOW_SIZE_LEFT, WINDOW_SIZE_RIGHT, BLOCK_M, BLOCK_N ) From 4d5d1d63482c49aeb4b7f0b2ecaf2887dadd8ada Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Thu, 24 Jul 2025 16:51:25 +0000 Subject: [PATCH 4/9] causal and sliding window block 
masking --- .../flash_attn_triton_amd/fwd_prefill.py | 76 ++++++++++++++++++- 1 file changed, 73 insertions(+), 3 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index a9d8b834050..1364377ac06 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -418,10 +418,80 @@ def compute_block_masking(seqlen_k, seqlen_q, start_m, n_extra_tokens = 0 if USE_SLIDING_WINDOW: - # TODO: Optimize by computing which blocks can be fully skipped - # For now, process all blocks with the mask function if IS_CAUSAL: - return 0, 0, 0, total_k_blocks, n_extra_tokens + # ------------------------------------------------------------------ + # causal + sliding‑window block classification + # ------------------------------------------------------------------ + # window per row i: + # left_i = max(0, i + base − W_left) (if W_left >= 0) + # right_i = min(sk‑1, i + base) (causal cap) + # (if W_right < 0 then i+base+W_right) + # + # to be “full” a K‑block has to lie inside the *intersection* + # of every row’s window ⇒ use + # left_max = max_i left_i (earliest col seen by all rows) + # right_min = min_i right_i (latest col seen by all rows) + # any block wholly inside [left_max , right_min] is un‑masked. 
+ # ------------------------------------------------------------------ + + q_start = start_m * BLOCK_M + q_end = tl.minimum((start_m + 1) * BLOCK_M - 1, seqlen_q - 1) + base = seqlen_k - seqlen_q + + # ------------------ left edge ------------------ + if WINDOW_SIZE_LEFT < 0: + left_min = 0 + left_max = 0 + else: + left_min = tl.maximum(0, q_start + base - WINDOW_SIZE_LEFT) + left_max = tl.maximum(0, q_end + base - WINDOW_SIZE_LEFT) + + # ------------------ right edge ----------------- + if WINDOW_SIZE_RIGHT < 0: + right_min = tl.minimum(seqlen_k - 1, q_start + base + WINDOW_SIZE_RIGHT) + right_max = tl.minimum(seqlen_k - 1, q_end + base + WINDOW_SIZE_RIGHT) + else: + # causal cap: col ≤ row + base + right_min = tl.minimum(seqlen_k - 1, q_start + base) + right_max = tl.minimum(seqlen_k - 1, q_end + base) + + # no overlap → nothing visible + if right_max < left_min: + return 0, 0, 0, 0, n_extra_tokens + + # ---------------- block geometry --------------- + first_block = left_min // BLOCK_N + last_block = right_max // BLOCK_N + + full_left_block = left_max // BLOCK_N + (left_max % BLOCK_N != 0) + clipped_left = tl.minimum(full_left_block, last_block + 1) + + n_front_skip_blocks = first_block + n_front_masked_blocks = tl.maximum(0, clipped_left - first_block) + + tmp = right_min // BLOCK_N + if (tmp + 1) * BLOCK_N - 1 > right_min: # ensure block fits earliest row + tmp -= 1 + full_right_block = tl.maximum(tmp, clipped_left - 1) + + n_full_blocks = tl.maximum(0, full_right_block - clipped_left + 1) + n_back_masked_blocks = tl.maximum(0, last_block - full_right_block) + + # ------------- padded last‑K block ------------- + padded_last_k = (n_extra_tokens != 0) & (last_block == total_k_blocks - 1) + last_block_in_front = clipped_left > last_block + if padded_last_k & (n_back_masked_blocks == 0): + if last_block_in_front: + n_front_masked_blocks = tl.maximum(0, n_front_masked_blocks - 1) + else: + n_full_blocks = tl.maximum(0, n_full_blocks - 1) + n_back_masked_blocks = 1 
+ + return (n_front_skip_blocks, + n_front_masked_blocks, + n_full_blocks, + n_back_masked_blocks, + n_extra_tokens) else: # ------------------------------------------------------------------ # token bounds seen by FIRST and LAST rows in this Q‑block From ac5ae567b9deba2d841bbd7e057f02193eeb46f1 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Thu, 24 Jul 2025 18:02:39 +0000 Subject: [PATCH 5/9] extract common --- .../flash_attn_triton_amd/fwd_prefill.py | 38 ++++++++----------- 1 file changed, 15 insertions(+), 23 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index 1364377ac06..5326e47c9c8 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -416,6 +416,11 @@ def compute_block_masking(seqlen_k, seqlen_q, start_m, n_extra_tokens = seqlen_k % BLOCK_N else: n_extra_tokens = 0 + + # common + q_start = start_m * BLOCK_M + q_end = tl.minimum((start_m + 1) * BLOCK_M - 1, seqlen_q - 1) + diag = seqlen_k - seqlen_q if USE_SLIDING_WINDOW: if IS_CAUSAL: @@ -434,26 +439,22 @@ def compute_block_masking(seqlen_k, seqlen_q, start_m, # any block wholly inside [left_max , right_min] is un‑masked. 
# ------------------------------------------------------------------ - q_start = start_m * BLOCK_M - q_end = tl.minimum((start_m + 1) * BLOCK_M - 1, seqlen_q - 1) - base = seqlen_k - seqlen_q - # ------------------ left edge ------------------ if WINDOW_SIZE_LEFT < 0: left_min = 0 left_max = 0 else: - left_min = tl.maximum(0, q_start + base - WINDOW_SIZE_LEFT) - left_max = tl.maximum(0, q_end + base - WINDOW_SIZE_LEFT) + left_min = tl.maximum(0, q_start + diag - WINDOW_SIZE_LEFT) + left_max = tl.maximum(0, q_end + diag - WINDOW_SIZE_LEFT) # ------------------ right edge ----------------- if WINDOW_SIZE_RIGHT < 0: - right_min = tl.minimum(seqlen_k - 1, q_start + base + WINDOW_SIZE_RIGHT) - right_max = tl.minimum(seqlen_k - 1, q_end + base + WINDOW_SIZE_RIGHT) + right_min = tl.minimum(seqlen_k - 1, q_start + diag + WINDOW_SIZE_RIGHT) + right_max = tl.minimum(seqlen_k - 1, q_end + diag + WINDOW_SIZE_RIGHT) else: # causal cap: col ≤ row + base - right_min = tl.minimum(seqlen_k - 1, q_start + base) - right_max = tl.minimum(seqlen_k - 1, q_end + base) + right_min = tl.minimum(seqlen_k - 1, q_start + diag) + right_max = tl.minimum(seqlen_k - 1, q_end + diag) # no overlap → nothing visible if right_max < left_min: @@ -496,23 +497,19 @@ def compute_block_masking(seqlen_k, seqlen_q, start_m, # ------------------------------------------------------------------ # token bounds seen by FIRST and LAST rows in this Q‑block # ------------------------------------------------------------------ - q_start = start_m * BLOCK_M - q_end = tl.minimum((start_m + 1) * BLOCK_M - 1, seqlen_q - 1) - base = seqlen_k - seqlen_q - # left‑hand side if WINDOW_SIZE_LEFT < 0: # un‑bounded left_min = 0 # earliest row left_max = 0 # latest row else: - left_min = tl.maximum(0, q_start + base - WINDOW_SIZE_LEFT) - left_max = tl.maximum(0, q_end + base - WINDOW_SIZE_LEFT) + left_min = tl.maximum(0, q_start + diag - WINDOW_SIZE_LEFT) + left_max = tl.maximum(0, q_end + diag - WINDOW_SIZE_LEFT) # right‑hand 
side right_min = tl.minimum(seqlen_k - 1, - q_start + base + WINDOW_SIZE_RIGHT) + q_start + diag + WINDOW_SIZE_RIGHT) right_max = tl.minimum(seqlen_k - 1, - q_end + base + WINDOW_SIZE_RIGHT) + q_end + diag + WINDOW_SIZE_RIGHT) # window vanishes → early exit if right_max < left_min: @@ -585,11 +582,6 @@ def compute_block_masking(seqlen_k, seqlen_q, start_m, # 1. figure out, in tokens, the right-most K position # this Q-block may attend to # ------------------------------------------------------------ - q_start = start_m * BLOCK_M - q_end = tl.minimum((start_m + 1) * BLOCK_M - 1, seqlen_q - 1) - - # causal diagonal offset between the two streams - diag = seqlen_k - seqlen_q # 0 when |Q| == |K| k_max_token = q_end + diag # last visible K index # this Q-block is entirely above the diagonal ⇒ nothing to do From 913d4a8e62adff390fb219124722f0afd55b5105 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Thu, 24 Jul 2025 18:06:07 +0000 Subject: [PATCH 6/9] clean up typo --- flash_attn/flash_attn_triton_amd/fwd_prefill.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index 5326e47c9c8..5b26425ff34 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -428,9 +428,9 @@ def compute_block_masking(seqlen_k, seqlen_q, start_m, # causal + sliding‑window block classification # ------------------------------------------------------------------ # window per row i: - # left_i = max(0, i + base − W_left) (if W_left >= 0) - # right_i = min(sk‑1, i + base) (causal cap) - # (if W_right < 0 then i+base+W_right) + # left_i = max(0, i + diag − W_left) (if W_left >= 0) + # right_i = min(sk‑1, i + diag) (causal cap) + # (if W_right < 0 then i+diag+W_right) # # to be “full” a K‑block has to lie inside the *intersection* # of every row’s window ⇒ use @@ -452,7 +452,7 @@ def compute_block_masking(seqlen_k, 
seqlen_q, start_m, right_min = tl.minimum(seqlen_k - 1, q_start + diag + WINDOW_SIZE_RIGHT) right_max = tl.minimum(seqlen_k - 1, q_end + diag + WINDOW_SIZE_RIGHT) else: - # causal cap: col ≤ row + base + # causal cap: col ≤ row + diag right_min = tl.minimum(seqlen_k - 1, q_start + diag) right_max = tl.minimum(seqlen_k - 1, q_end + diag) From d676dd778f5fac3fdc281d5930d7b27d62627f48 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Thu, 24 Jul 2025 18:51:38 +0000 Subject: [PATCH 7/9] helper for swa --- .../flash_attn_triton_amd/fwd_prefill.py | 276 ++++++++---------- 1 file changed, 120 insertions(+), 156 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index 5b26425ff34..82c1064c7ca 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -382,26 +382,83 @@ def _attn_fwd_mask(acc, l_i, m_i, @triton.jit -def compute_block_masking(seqlen_k, seqlen_q, start_m, - IS_CAUSAL: tl.constexpr, USE_SLIDING_WINDOW: tl.constexpr, - WINDOW_SIZE_LEFT: tl.constexpr, WINDOW_SIZE_RIGHT: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): - """ - Classify K blocks for attention computation with sliding window support. 
+def compute_window_bounds(q_start, q_end, diag, seqlen_k, + WINDOW_SIZE_LEFT: tl.constexpr, + WINDOW_SIZE_RIGHT: tl.constexpr, + IS_CAUSAL: tl.constexpr): + """Calculate the window boundaries for a query block.""" + # Left boundary + if WINDOW_SIZE_LEFT < 0: + left_min = 0 + left_max = 0 + else: + left_min = tl.maximum(0, q_start + diag - WINDOW_SIZE_LEFT) + left_max = tl.maximum(0, q_end + diag - WINDOW_SIZE_LEFT) - Returns: - - n_front_skip_blocks: Blocks completely before the window - - n_front_masked_blocks: Blocks partially overlapping window front - - n_full_blocks: Blocks completely inside the window - - n_back_masked_blocks: Blocks partially overlapping window back - - n_extra_tokens: Padding tokens in last K block - """ - # Example case - # BLOCK_M = 4, BLOCK_N = 4, seqlen_q = 8, seqlen_k = 10 + # Right boundary + if IS_CAUSAL: + # Causal cap: col ≤ row + diag + right_min = tl.minimum(seqlen_k - 1, q_start + diag) + right_max = tl.minimum(seqlen_k - 1, q_end + diag) + else: + if WINDOW_SIZE_RIGHT < 0: + right_min = tl.minimum(seqlen_k - 1, q_start + diag + WINDOW_SIZE_RIGHT) + right_max = tl.minimum(seqlen_k - 1, q_end + diag + WINDOW_SIZE_RIGHT) + else: + # Non-causal doesn't have the diagonal constraint + right_min = tl.minimum(seqlen_k - 1, q_start + diag + WINDOW_SIZE_RIGHT) + right_max = tl.minimum(seqlen_k - 1, q_end + diag + WINDOW_SIZE_RIGHT) + + return left_min, left_max, right_min, right_max - # Total K blocks in the key sequence - total_k_blocks = tl.cdiv(seqlen_k, BLOCK_N) +@triton.jit +def classify_window_blocks(left_min, left_max, right_min, right_max, + BLOCK_N: tl.constexpr): + """Classify blocks based on window boundaries.""" + # First and last blocks that have ANY overlap with window + first_block = left_min // BLOCK_N + last_block = right_max // BLOCK_N + + # First block that is FULLY visible for all rows in Q block + full_left_block = left_max // BLOCK_N + (left_max % BLOCK_N != 0) + clipped_left = tl.minimum(full_left_block, 
last_block + 1) + + # Last block that is FULLY visible for all rows in Q block + last_full_block_candidate = right_min // BLOCK_N + if (last_full_block_candidate + 1) * BLOCK_N - 1 > right_min: + last_full_block_candidate -= 1 + full_right_block = tl.maximum(last_full_block_candidate, clipped_left - 1) + + # Calculate counts + n_front_skip_blocks = first_block + n_front_masked_blocks = tl.maximum(0, clipped_left - first_block) + n_full_blocks = tl.maximum(0, full_right_block - clipped_left + 1) + n_back_masked_blocks = tl.maximum(0, last_block - full_right_block) + + return (n_front_skip_blocks, n_front_masked_blocks, + n_full_blocks, n_back_masked_blocks, + clipped_left) # Return clipped_left for padded block handling + +@triton.jit +def handle_padded_last_block(n_extra_tokens, last_block, total_k_blocks, + clipped_left, n_front_masked_blocks, + n_full_blocks, n_back_masked_blocks): + """Adjust block counts when last K block has padding.""" + padded_last_k = (n_extra_tokens != 0) & (last_block == total_k_blocks - 1) + + if padded_last_k & (n_back_masked_blocks == 0): + last_block_in_front = clipped_left > last_block + if last_block_in_front: + n_front_masked_blocks = tl.maximum(0, n_front_masked_blocks - 1) + else: + n_full_blocks = tl.maximum(0, n_full_blocks - 1) + n_back_masked_blocks = 1 + + return n_front_masked_blocks, n_full_blocks, n_back_masked_blocks +@triton.jit +def compute_padding_info(seqlen_k, BLOCK_N: tl.constexpr): + """Calculate padding information for the last K block.""" # check if we will need to do masking due either BLOCK_N being bigger than seqlen_k or seqlen_k not being a factor of BLOCK_N # n_extra_tokens = 10 % 4 = 2 # This means the last K block has 2 valid tokens and 2 padding positions @@ -415,153 +472,60 @@ def compute_block_masking(seqlen_k, seqlen_q, start_m, elif seqlen_k % BLOCK_N: n_extra_tokens = seqlen_k % BLOCK_N else: - n_extra_tokens = 0 + n_extra_tokens = 0 + return n_extra_tokens + +@triton.jit +def 
compute_block_masking(seqlen_k, seqlen_q, start_m, + IS_CAUSAL: tl.constexpr, USE_SLIDING_WINDOW: tl.constexpr, + WINDOW_SIZE_LEFT: tl.constexpr, WINDOW_SIZE_RIGHT: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + """ + Classify K blocks for attention computation with sliding window support. + + Returns: + - n_front_skip_blocks: Blocks completely before the window + - n_front_masked_blocks: Blocks partially overlapping window front + - n_full_blocks: Blocks completely inside the window + - n_back_masked_blocks: Blocks partially overlapping window back + - n_extra_tokens: Padding tokens in last K block + """ # common q_start = start_m * BLOCK_M q_end = tl.minimum((start_m + 1) * BLOCK_M - 1, seqlen_q - 1) diag = seqlen_k - seqlen_q + total_k_blocks = tl.cdiv(seqlen_k, BLOCK_N) + n_extra_tokens = compute_padding_info(seqlen_k, BLOCK_N) if USE_SLIDING_WINDOW: - if IS_CAUSAL: - # ------------------------------------------------------------------ - # causal + sliding‑window block classification - # ------------------------------------------------------------------ - # window per row i: - # left_i = max(0, i + diag − W_left) (if W_left >= 0) - # right_i = min(sk‑1, i + diag) (causal cap) - # (if W_right < 0 then i+diag+W_right) - # - # to be “full” a K‑block has to lie inside the *intersection* - # of every row’s window ⇒ use - # left_max = max_i left_i (earliest col seen by all rows) - # right_min = min_i right_i (latest col seen by all rows) - # any block wholly inside [left_max , right_min] is un‑masked. 
- # ------------------------------------------------------------------ - - # ------------------ left edge ------------------ - if WINDOW_SIZE_LEFT < 0: - left_min = 0 - left_max = 0 - else: - left_min = tl.maximum(0, q_start + diag - WINDOW_SIZE_LEFT) - left_max = tl.maximum(0, q_end + diag - WINDOW_SIZE_LEFT) - - # ------------------ right edge ----------------- - if WINDOW_SIZE_RIGHT < 0: - right_min = tl.minimum(seqlen_k - 1, q_start + diag + WINDOW_SIZE_RIGHT) - right_max = tl.minimum(seqlen_k - 1, q_end + diag + WINDOW_SIZE_RIGHT) - else: - # causal cap: col ≤ row + diag - right_min = tl.minimum(seqlen_k - 1, q_start + diag) - right_max = tl.minimum(seqlen_k - 1, q_end + diag) - - # no overlap → nothing visible - if right_max < left_min: - return 0, 0, 0, 0, n_extra_tokens - - # ---------------- block geometry --------------- - first_block = left_min // BLOCK_N - last_block = right_max // BLOCK_N - - full_left_block = left_max // BLOCK_N + (left_max % BLOCK_N != 0) - clipped_left = tl.minimum(full_left_block, last_block + 1) - - n_front_skip_blocks = first_block - n_front_masked_blocks = tl.maximum(0, clipped_left - first_block) - - tmp = right_min // BLOCK_N - if (tmp + 1) * BLOCK_N - 1 > right_min: # ensure block fits earliest row - tmp -= 1 - full_right_block = tl.maximum(tmp, clipped_left - 1) - - n_full_blocks = tl.maximum(0, full_right_block - clipped_left + 1) - n_back_masked_blocks = tl.maximum(0, last_block - full_right_block) - - # ------------- padded last‑K block ------------- - padded_last_k = (n_extra_tokens != 0) & (last_block == total_k_blocks - 1) - last_block_in_front = clipped_left > last_block - if padded_last_k & (n_back_masked_blocks == 0): - if last_block_in_front: - n_front_masked_blocks = tl.maximum(0, n_front_masked_blocks - 1) - else: - n_full_blocks = tl.maximum(0, n_full_blocks - 1) - n_back_masked_blocks = 1 - - return (n_front_skip_blocks, - n_front_masked_blocks, - n_full_blocks, - n_back_masked_blocks, - n_extra_tokens) - else: 
- # ------------------------------------------------------------------ - # token bounds seen by FIRST and LAST rows in this Q‑block - # ------------------------------------------------------------------ - # left‑hand side - if WINDOW_SIZE_LEFT < 0: # un‑bounded - left_min = 0 # earliest row - left_max = 0 # latest row - else: - left_min = tl.maximum(0, q_start + diag - WINDOW_SIZE_LEFT) - left_max = tl.maximum(0, q_end + diag - WINDOW_SIZE_LEFT) - - # right‑hand side - right_min = tl.minimum(seqlen_k - 1, - q_start + diag + WINDOW_SIZE_RIGHT) - right_max = tl.minimum(seqlen_k - 1, - q_end + diag + WINDOW_SIZE_RIGHT) - - # window vanishes → early exit - if right_max < left_min: - return 0, 0, 0, 0, n_extra_tokens - - # ------------------------------------------------------------------ - # make sure full_left_block never outruns the visible range - # ------------------------------------------------------------------ - first_block = left_min // BLOCK_N - last_block = right_max // BLOCK_N # right‑most block that *any* row touches - - # “first block that is fully visible for all rows” - full_left_block = left_max // BLOCK_N + (left_max % BLOCK_N != 0) - - # clip to avoid front‑mask length > total_visible - clipped_left = tl.minimum(full_left_block, last_block + 1) - - # ------------------------------------------------------------------ - # block counts - # ------------------------------------------------------------------ - n_front_skip_blocks = first_block - n_front_masked_blocks = tl.maximum(0, clipped_left - first_block) - - tmp = right_min // BLOCK_N - if (tmp + 1) * BLOCK_N - 1 > right_min: # ensure block fits earliest row - tmp -= 1 - full_right_block = tl.maximum(tmp, clipped_left - 1) - - n_full_blocks = tl.maximum(0, full_right_block - clipped_left + 1) - n_back_masked_blocks = tl.maximum(0, last_block - full_right_block) - - # ------------------------------------------------------------ - # padded last‑K block - # 
------------------------------------------------------------ - padded_last_k = (n_extra_tokens != 0) & (last_block == total_k_blocks - 1) - last_block_in_front = clipped_left > last_block # ← last block ended up on the left side - - if padded_last_k & (n_back_masked_blocks == 0): - if last_block_in_front: - # move the last block from front‑masked → back‑masked - n_front_masked_blocks = tl.maximum(0, n_front_masked_blocks - 1) - else: - # move the last block from full → back‑masked - n_full_blocks = tl.maximum(0, n_full_blocks - 1) - - n_back_masked_blocks = 1 # ensure it is handled with padding info + # get window bounds + left_min, left_max, right_min, right_max = compute_window_bounds( + q_start, q_end, diag, seqlen_k, + WINDOW_SIZE_LEFT, WINDOW_SIZE_RIGHT, IS_CAUSAL + ) - return (n_front_skip_blocks, - n_front_masked_blocks, - n_full_blocks, - n_back_masked_blocks, - n_extra_tokens) + # window vanishes → early exit + if right_max < left_min: + return 0, 0, 0, 0, n_extra_tokens + + # classify blocks + (n_front_skip_blocks, n_front_masked_blocks, + n_full_blocks, n_back_masked_blocks, + clipped_left) = classify_window_blocks( + left_min, left_max, right_min, right_max, BLOCK_N + ) + + # handle padded last block if needed + if n_extra_tokens != 0: + last_block = right_max // BLOCK_N + n_front_masked_blocks, n_full_blocks, n_back_masked_blocks = handle_padded_last_block( + n_extra_tokens, last_block, total_k_blocks, + clipped_left, n_front_masked_blocks, + n_full_blocks, n_back_masked_blocks + ) + return (n_front_skip_blocks, n_front_masked_blocks, + n_full_blocks, n_back_masked_blocks, n_extra_tokens) else: if IS_CAUSAL: # ========== CAUSAL MODE: Classify K Blocks ========== From 3e1323466ddf9bb7af51322571279d910831c2d3 Mon Sep 17 00:00:00 2001 From: Michael Date: Thu, 7 Aug 2025 13:59:18 -0500 Subject: [PATCH 8/9] ignore .amd --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 97991419fdb..2fd004dc706 
100644 --- a/.gitignore +++ b/.gitignore @@ -51,4 +51,4 @@ training/data # ck modules csrc/composable_kernel csrc/cutlass -.analysis \ No newline at end of file +.amd \ No newline at end of file From f8d3fa85f5c50e34c82d2de717734c537823f874 Mon Sep 17 00:00:00 2001 From: Michael Date: Fri, 8 Aug 2025 03:35:54 -0500 Subject: [PATCH 9/9] fix last block bug --- .../flash_attn_triton_amd/fwd_prefill.py | 38 +++++++++++++------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index 82c1064c7ca..3a2bd56fda4 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -441,19 +441,33 @@ def classify_window_blocks(left_min, left_max, right_min, right_max, @triton.jit def handle_padded_last_block(n_extra_tokens, last_block, total_k_blocks, - clipped_left, n_front_masked_blocks, - n_full_blocks, n_back_masked_blocks): - """Adjust block counts when last K block has padding.""" + clipped_left, n_front_masked_blocks, + n_full_blocks, n_back_masked_blocks): + """Ensure a padded last K-block is never classified as 'full'. + + We move the padded last block (if visible) into the back-masked bucket. + If it's already back-masked, we do nothing. If it was counted in the + front-masked range, we decrement front-masked; if it was counted as full, + we decrement full. Then we increment back-masked. 
+ """ padded_last_k = (n_extra_tokens != 0) & (last_block == total_k_blocks - 1) - - if padded_last_k & (n_back_masked_blocks == 0): - last_block_in_front = clipped_left > last_block - if last_block_in_front: - n_front_masked_blocks = tl.maximum(0, n_front_masked_blocks - 1) - else: - n_full_blocks = tl.maximum(0, n_full_blocks - 1) - n_back_masked_blocks = 1 - + + if padded_last_k: + # current 'full' range right edge + full_right_block = clipped_left + n_full_blocks - 1 + + # If last_block is already beyond full_right_block, it's already in back-masked → nothing to do + last_already_back_masked = last_block > full_right_block + if not last_already_back_masked: + # If the window starts past last_block, it was counted in front-masked + if clipped_left > last_block: + n_front_masked_blocks = tl.maximum(0, n_front_masked_blocks - 1) + else: + # Otherwise it was counted 'full' → move it out of full + n_full_blocks = tl.maximum(0, n_full_blocks - 1) + # In both cases we need one more back-masked block + n_back_masked_blocks = n_back_masked_blocks + 1 + return n_front_masked_blocks, n_full_blocks, n_back_masked_blocks @triton.jit