diff --git a/flash_sparse_attn/ops/triton/activations.py b/flash_sparse_attn/ops/triton/activations.py index d8a47b0..0e1b648 100644 --- a/flash_sparse_attn/ops/triton/activations.py +++ b/flash_sparse_attn/ops/triton/activations.py @@ -4,7 +4,43 @@ @triton.jit def check_inf(x): - return tl.where(x == float("-inf"), 0.0, x) + return tl.maximum(x, -1e6) + + +@triton.jit +def exp2(x): + """ + Compute 2^x, using fast hardware approximation on CUDA and default implementation on other backends. + + :param x: Input tensor of shape [BLOCK_M, BLOCK_N]. + + :return: Tensor of shape [BLOCK_M, BLOCK_N] containing 2^x values. + """ + if tl.target_info.is_cuda(): + return tl.inline_asm_elementwise( + "ex2.approx.ftz.f32 $0, $1;", + "=r, r", + [x], + dtype=tl.float32, + is_pure=True, + pack=1, + ) + else: + return tl.exp2(x) + + +@triton.jit +def exp(x): + """ + Compute e^x, using fast hardware approximation on CUDA and default implementation on other backends. + + :param x: Input tensor of shape [BLOCK_M, BLOCK_N]. + + :return: Tensor of shape [BLOCK_M, BLOCK_N] containing e^x values. + """ + log2_e: tl.constexpr = 1.4426950408889634 + x *= log2_e + return exp2(x) @triton.jit @@ -48,15 +84,15 @@ def online_softmax( if RESCALE_THRESHOLD > 0.0: # Triton can only skip computation at block granularity if tl.min(acc_scale_log2) < -RESCALE_THRESHOLD: - row_scale = tl.exp2(acc_scale_log2) + row_scale = exp2(acc_scale_log2) else: row_max_new = row_max row_scale = acc_scale_log2 * 0.0 + 1.0 else: - row_scale = tl.exp2(acc_scale_log2) + row_scale = exp2(acc_scale_log2) # Compute attention weights - p = tl.exp2(acc_s * scale_log2 - row_max_new[:, None] * scale_log2) + p = exp2(acc_s * scale_log2 - row_max_new[:, None] * scale_log2) # Update row sum row_sum_cur = tl.sum(p, axis=1) @@ -125,10 +161,10 @@ def online_sparse_softmax( acc_scale_log2 = (row_max - row_max_new) * scale_log2 # Compute row scale - row_scale = tl.exp2(acc_scale_log2) + row_scale = exp2(acc_scale_log2) # Compute attention weights - p = tl.exp2(acc_s * scale_log2 - row_max_new[:, None] * scale_log2) + p = exp2(acc_s * scale_log2 - row_max_new[:, None] * scale_log2) # Update row sum row_sum_cur = tl.sum(p, axis=1) @@ -172,7 +208,7 @@ def finalize( lse = row_max * scale_log2 + tl.log2(row_sum) if not IS_LOG2: # ln2 = math.log(2.0) - ln2 = 0.6931471805599453 + ln2: tl.constexpr = 0.6931471805599453 lse *= ln2 return row_scale, lse @@ -201,6 +237,18 @@ def rescale_o( return acc_o +@triton.jit +def sigmoid(x): + """ + Compute sigmoid of x. + + :param x: Input tensor of shape [BLOCK_M, BLOCK_N]. + + :return: Tensor of shape [BLOCK_M, BLOCK_N] containing sigmoid values. + """ + return 1.0 / (1.0 + exp(-x)) + + @triton.jit def log_sigmoid(x, FASTMATH: tl.constexpr): """ @@ -217,7 +265,7 @@ def log_sigmoid(x, FASTMATH: tl.constexpr): out = tl.minimum(x, 0.0) - 0.05674870 * xc2 + 0.37664706 * xc - 0.65169323 return out else: - out = tl.minimum(x, 0.0) - tl.log(1.0 + tl.exp(-tl.abs(x))) + out = tl.minimum(x, 0.0) - tl.log(1.0 + exp(-tl.abs(x))) return out diff --git a/flash_sparse_attn/ops/triton/flash_dense_bwd.py b/flash_sparse_attn/ops/triton/flash_dense_bwd.py index 6825f19..b770c3a 100644 --- a/flash_sparse_attn/ops/triton/flash_dense_bwd.py +++ b/flash_sparse_attn/ops/triton/flash_dense_bwd.py @@ -12,6 +12,7 @@ launch_grid, seqlen_info, block_info, + activations, mask, flash_bwd_preprocess, flash_bwd_postprocess, @@ -77,7 +78,9 @@ def _bwd_inner_dense_base_kernel( ) # Compute attention weights - p = tl.math.exp2(acc_s * softmax_scale_log2 - lse_log2[None, :]).to(q_tile.dtype) + p = activations.exp2(acc_s * softmax_scale_log2 - lse_log2[None, :]).to( + q_tile.dtype + ) # Load output gradients tile do_tile = tl.load(do_ptrs, boundary_check=(0, 1), cache_modifier=".cg") diff --git a/flash_sparse_attn/ops/triton/flash_gated_bwd.py b/flash_sparse_attn/ops/triton/flash_gated_bwd.py index 8d00533..e70b972 100644 --- a/flash_sparse_attn/ops/triton/flash_gated_bwd.py +++ b/flash_sparse_attn/ops/triton/flash_gated_bwd.py @@ -135,7 +135,7 @@ def _bwd_inner_gated_base_kernel( lse_ptrs = tl.advance(lse_ptrs, (TILE_M,)) # Compute attention weights - p = tl.math.exp2(acc_s * softmax_scale_log2 - lse_log2[None, :]).to( + p = activations.exp2(acc_s * softmax_scale_log2 - lse_log2[None, :]).to( q_tile.dtype ) diff --git a/flash_sparse_attn/ops/triton/flash_sparse_bwd.py b/flash_sparse_attn/ops/triton/flash_sparse_bwd.py index 099a133..d5e3c44 100644 --- a/flash_sparse_attn/ops/triton/flash_sparse_bwd.py +++ b/flash_sparse_attn/ops/triton/flash_sparse_bwd.py @@ -12,6 +12,7 @@ launch_grid, seqlen_info, block_info, + activations, mask, flash_bwd_preprocess, flash_bwd_postprocess, @@ -90,7 +91,7 @@ def _bwd_inner_sparse_base_kernel( lse_ptrs = tl.advance(lse_ptrs, (TILE_M,)) # Compute attention weights - p = tl.math.exp2(acc_s * softmax_scale_log2 - lse_log2[None, :]).to( + p = activations.exp2(acc_s * softmax_scale_log2 - lse_log2[None, :]).to( q_tile.dtype )