Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions flash_sparse_attn/ops/triton/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,8 @@ def flash_dense_attn_func(
is_quant: bool = False,
is_split_kv: bool = False,
pack_gqa: bool = False,
out: Optional[torch.Tensor] = None,
lse: Optional[torch.Tensor] = None,
is_autotune: bool = False,
skip_checks: bool = False,
return_lse: bool = False,
Expand All @@ -782,6 +784,8 @@ def flash_dense_attn_func(
:param is_quant: Whether to quantize inputs to FP8 for attention computation. If True, query_scale, key_scale, and value_scale must be provided or will be computed from the input tensors.
:param is_split_kv: Whether to enable split-KV for occupancy.
:param pack_gqa: Whether to pack grouped-query attention.
:param out: Optional preallocated output tensor with shape [batch_size, seqlen_q, num_heads, head_dim].
:param lse: Optional preallocated logsumexp tensor with shape [batch_size, num_heads, seqlen_q].
:param is_autotune: Whether to use Triton autotuner for kernel launch configuration.
:param skip_checks: Whether to skip input validation checks for faster performance.
:param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out.
Expand All @@ -801,6 +805,8 @@ def flash_dense_attn_func(
is_quant,
is_split_kv,
pack_gqa,
out,
lse,
is_autotune,
skip_checks,
return_lse,
Expand Down Expand Up @@ -889,6 +895,8 @@ def flash_dense_attn_varlen_func(
seqused_k: Optional[torch.Tensor] = None,
is_split_kv: bool = False,
pack_gqa: bool = False,
out: Optional[torch.Tensor] = None,
lse: Optional[torch.Tensor] = None,
is_autotune: bool = False,
skip_checks: bool = False,
return_lse: bool = False,
Expand All @@ -914,6 +922,8 @@ def flash_dense_attn_varlen_func(
:param seqused_k: Optional tensor of shape [total_seqlen_k] indicating the actual sequence lengths for keys/values. If provided, overrides cu_seqlens_k for masking.
:param is_split_kv: Whether to enable split-KV for occupancy.
:param pack_gqa: Whether to pack grouped-query attention.
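:param out: Optional preallocated output tensor. Presumably uses the packed variable-length layout [total_seqlen_q, num_heads, head_dim] rather than the batched [batch_size, seqlen_q, num_heads, head_dim] (TODO confirm against the kernel).
:param lse: Optional preallocated logsumexp tensor. Presumably uses the packed variable-length layout (e.g. [num_heads, total_seqlen_q]) rather than the batched [batch_size, num_heads, seqlen_q] (TODO confirm against the kernel).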
:param is_autotune: Whether to use Triton autotuner for kernel launch configuration.
:param skip_checks: Whether to skip input validation checks for faster performance.
:param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out.
Expand All @@ -939,6 +949,8 @@ def flash_dense_attn_varlen_func(
seqused_k,
is_split_kv,
pack_gqa,
out,
lse,
is_autotune,
skip_checks,
return_lse,
Expand Down Expand Up @@ -1031,6 +1043,8 @@ def flash_sparse_attn_func(
is_quant: bool = False,
is_split_kv: bool = False,
pack_gqa: bool = False,
out: Optional[torch.Tensor] = None,
lse: Optional[torch.Tensor] = None,
is_autotune: bool = False,
skip_checks: bool = False,
return_lse: bool = False,
Expand All @@ -1051,6 +1065,8 @@ def flash_sparse_attn_func(
:param is_quant: Whether to quantize inputs to FP8 for attention computation. If True, query_scale, key_scale, and value_scale must be provided or will be computed from the input tensors.
:param is_split_kv: Whether to enable split-KV for occupancy.
:param pack_gqa: Whether to pack grouped-query attention.
:param out: Optional preallocated output tensor with shape [batch_size, seqlen_q, num_heads, head_dim].
:param lse: Optional preallocated logsumexp tensor with shape [batch_size, num_heads, seqlen_q].
:param is_autotune: Whether to use Triton autotuner for kernel launch configuration.
:param skip_checks: Whether to skip input validation checks for faster performance.
:param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out.
Expand All @@ -1071,6 +1087,8 @@ def flash_sparse_attn_func(
is_quant,
is_split_kv,
pack_gqa,
out,
lse,
is_autotune,
skip_checks,
return_lse,
Expand Down Expand Up @@ -1163,6 +1181,8 @@ def flash_sparse_attn_varlen_func(
seqused_k: Optional[torch.Tensor] = None,
is_split_kv: bool = False,
pack_gqa: bool = False,
out: Optional[torch.Tensor] = None,
lse: Optional[torch.Tensor] = None,
is_autotune: bool = False,
skip_checks: bool = False,
return_lse: bool = False,
Expand All @@ -1189,6 +1209,8 @@ def flash_sparse_attn_varlen_func(
:param seqused_k: Optional tensor of shape [total_seqlen_k] indicating the actual sequence lengths for keys/values. If provided, overrides cu_seqlens_k for masking.
:param is_split_kv: Whether to enable split-KV for occupancy.
:param pack_gqa: Whether to pack grouped-query attention.
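:param out: Optional preallocated output tensor. Presumably uses the packed variable-length layout [total_seqlen_q, num_heads, head_dim] rather than the batched [batch_size, seqlen_q, num_heads, head_dim] (TODO confirm against the kernel).
:param lse: Optional preallocated logsumexp tensor. Presumably uses the packed variable-length layout (e.g. [num_heads, total_seqlen_q]) rather than the batched [batch_size, num_heads, seqlen_q] (TODO confirm against the kernel).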
:param is_autotune: Whether to use Triton autotuner for kernel launch configuration.
:param skip_checks: Whether to skip input validation checks for faster performance.
:param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out.
Expand All @@ -1215,6 +1237,8 @@ def flash_sparse_attn_varlen_func(
seqused_k,
is_split_kv,
pack_gqa,
out,
lse,
is_autotune,
skip_checks,
return_lse,
Expand Down Expand Up @@ -1315,6 +1339,8 @@ def flash_gated_attn_func(
is_quant: bool = False,
is_split_kv: bool = False,
pack_gqa: bool = False,
out: Optional[torch.Tensor] = None,
lse: Optional[torch.Tensor] = None,
is_autotune: bool = False,
skip_checks: bool = False,
return_lse: bool = False,
Expand All @@ -1340,6 +1366,8 @@ def flash_gated_attn_func(
:param is_quant: Whether to quantize inputs to FP8 for attention computation. If True, query_scale, key_scale, and value_scale must be provided or will be computed from the input tensors.
:param is_split_kv: Whether to enable split-KV for occupancy.
:param pack_gqa: Whether to pack grouped-query attention.
:param out: Optional preallocated output tensor with shape [batch_size, seqlen_q, num_heads, head_dim].
:param lse: Optional preallocated logsumexp tensor with shape [batch_size, num_heads, seqlen_q].
:param is_autotune: Whether to use Triton autotuner for kernel launch configuration.
:param skip_checks: Whether to skip input validation checks for faster performance.
:param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out.
Expand All @@ -1365,6 +1393,8 @@ def flash_gated_attn_func(
is_quant,
is_split_kv,
pack_gqa,
out,
lse,
is_autotune,
skip_checks,
return_lse,
Expand Down Expand Up @@ -1474,6 +1504,8 @@ def flash_gated_attn_varlen_func(
seqused_k: Optional[torch.Tensor] = None,
is_split_kv: bool = False,
pack_gqa: bool = False,
out: Optional[torch.Tensor] = None,
lse: Optional[torch.Tensor] = None,
is_autotune: bool = False,
skip_checks: bool = False,
return_lse: bool = False,
Expand Down Expand Up @@ -1505,6 +1537,8 @@ def flash_gated_attn_varlen_func(
:param seqused_k: Optional tensor of shape [total_seqlen_k] indicating the actual sequence lengths for keys/values. If provided, overrides cu_seqlens_k for masking.
:param is_split_kv: Whether to enable split-KV for occupancy.
:param pack_gqa: Whether to pack grouped-query attention.
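:param out: Optional preallocated output tensor. Presumably uses the packed variable-length layout [total_seqlen_q, num_heads, head_dim] rather than the batched [batch_size, seqlen_q, num_heads, head_dim] (TODO confirm against the kernel).
:param lse: Optional preallocated logsumexp tensor. Presumably uses the packed variable-length layout (e.g. [num_heads, total_seqlen_q]) rather than the batched [batch_size, num_heads, seqlen_q] (TODO confirm against the kernel).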
:param is_autotune: Whether to use Triton autotuner for kernel launch configuration.
:param skip_checks: Whether to skip input validation checks for faster performance.
:param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out.
Expand Down Expand Up @@ -1536,6 +1570,8 @@ def flash_gated_attn_varlen_func(
seqused_k,
is_split_kv,
pack_gqa,
out,
lse,
is_autotune,
skip_checks,
return_lse,
Expand Down
Loading