Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 56 additions & 8 deletions flash_sparse_attn/ops/triton/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,43 @@

@triton.jit
def check_inf(x):
return tl.where(x == float("-inf"), 0.0, x)
return tl.maximum(x, -1e6)


@triton.jit
def exp2(x):
"""
Compute 2^x, using fast hardware approximation on CUDA and default implementation on other backends.

:param x: Input tensor of shape [BLOCK_M, BLOCK_N].

:return: Tensor of shape [BLOCK_M, BLOCK_N] containing 2^x values.
"""
if tl.target_info.is_cuda():
return tl.inline_asm_elementwise(
"ex2.approx.ftz.f32 $0, $1;",
"=r, r",
[x],
dtype=tl.float32,
is_pure=True,
pack=1,
)
else:
return tl.exp2(x)


@triton.jit
def exp(x):
"""
Compute e^x, using fast hardware approximation on CUDA and default implementation on other backends.

:param x: Input tensor of shape [BLOCK_M, BLOCK_N].

:return: Tensor of shape [BLOCK_M, BLOCK_N] containing e^x values.
"""
log2_e: tl.constexpr = 1.4426950408889634
x *= log2_e
return exp2(x)


@triton.jit
Expand Down Expand Up @@ -48,15 +84,15 @@ def online_softmax(
if RESCALE_THRESHOLD > 0.0:
# Triton can only skip computation at block granularity
if tl.min(acc_scale_log2) < -RESCALE_THRESHOLD:
row_scale = tl.exp2(acc_scale_log2)
row_scale = exp2(acc_scale_log2)
else:
row_max_new = row_max
row_scale = acc_scale_log2 * 0.0 + 1.0
else:
row_scale = tl.exp2(acc_scale_log2)
row_scale = exp2(acc_scale_log2)

# Compute attention weights
p = tl.exp2(acc_s * scale_log2 - row_max_new[:, None] * scale_log2)
p = exp2(acc_s * scale_log2 - row_max_new[:, None] * scale_log2)

# Update row sum
row_sum_cur = tl.sum(p, axis=1)
Expand Down Expand Up @@ -125,10 +161,10 @@ def online_sparse_softmax(
acc_scale_log2 = (row_max - row_max_new) * scale_log2

# Compute row scale
row_scale = tl.exp2(acc_scale_log2)
row_scale = exp2(acc_scale_log2)

# Compute attention weights
p = tl.exp2(acc_s * scale_log2 - row_max_new[:, None] * scale_log2)
p = exp2(acc_s * scale_log2 - row_max_new[:, None] * scale_log2)

# Update row sum
row_sum_cur = tl.sum(p, axis=1)
Expand Down Expand Up @@ -172,7 +208,7 @@ def finalize(
lse = row_max * scale_log2 + tl.log2(row_sum)
if not IS_LOG2:
# ln2 = math.log(2.0)
ln2 = 0.6931471805599453
ln2: tl.constexpr = 0.6931471805599453
lse *= ln2
return row_scale, lse

Expand Down Expand Up @@ -201,6 +237,18 @@ def rescale_o(
return acc_o


@triton.jit
def sigmoid(x):
"""
Compute sigmoid of x.

:param x: Input tensor of shape [BLOCK_M, BLOCK_N].

:return: Tensor of shape [BLOCK_M, BLOCK_N] containing sigmoid values.
"""
return 1.0 / (1.0 + exp(-x))


@triton.jit
def log_sigmoid(x, FASTMATH: tl.constexpr):
"""
Expand All @@ -217,7 +265,7 @@ def log_sigmoid(x, FASTMATH: tl.constexpr):
out = tl.minimum(x, 0.0) - 0.05674870 * xc2 + 0.37664706 * xc - 0.65169323
return out
else:
out = tl.minimum(x, 0.0) - tl.log(1.0 + tl.exp(-tl.abs(x)))
out = tl.minimum(x, 0.0) - tl.log(1.0 + exp(-tl.abs(x)))
return out


Expand Down
5 changes: 4 additions & 1 deletion flash_sparse_attn/ops/triton/flash_dense_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
launch_grid,
seqlen_info,
block_info,
activations,
mask,
flash_bwd_preprocess,
flash_bwd_postprocess,
Expand Down Expand Up @@ -77,7 +78,9 @@ def _bwd_inner_dense_base_kernel(
)

# Compute attention weights
p = tl.math.exp2(acc_s * softmax_scale_log2 - lse_log2[None, :]).to(q_tile.dtype)
p = activations.exp2(acc_s * softmax_scale_log2 - lse_log2[None, :]).to(
q_tile.dtype
)

# Load output gradients tile
do_tile = tl.load(do_ptrs, boundary_check=(0, 1), cache_modifier=".cg")
Expand Down
2 changes: 1 addition & 1 deletion flash_sparse_attn/ops/triton/flash_gated_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def _bwd_inner_gated_base_kernel(
lse_ptrs = tl.advance(lse_ptrs, (TILE_M,))

# Compute attention weights
p = tl.math.exp2(acc_s * softmax_scale_log2 - lse_log2[None, :]).to(
p = activations.exp2(acc_s * softmax_scale_log2 - lse_log2[None, :]).to(
q_tile.dtype
)

Expand Down
3 changes: 2 additions & 1 deletion flash_sparse_attn/ops/triton/flash_sparse_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
launch_grid,
seqlen_info,
block_info,
activations,
mask,
flash_bwd_preprocess,
flash_bwd_postprocess,
Expand Down Expand Up @@ -90,7 +91,7 @@ def _bwd_inner_sparse_base_kernel(
lse_ptrs = tl.advance(lse_ptrs, (TILE_M,))

# Compute attention weights
p = tl.math.exp2(acc_s * softmax_scale_log2 - lse_log2[None, :]).to(
p = activations.exp2(acc_s * softmax_scale_log2 - lse_log2[None, :]).to(
q_tile.dtype
)

Expand Down
Loading