diff --git a/flash_sparse_attn/ops/triton/flash_dense_bwd.py b/flash_sparse_attn/ops/triton/flash_dense_bwd.py index 8db3b49..87df6c1 100644 --- a/flash_sparse_attn/ops/triton/flash_dense_bwd.py +++ b/flash_sparse_attn/ops/triton/flash_dense_bwd.py @@ -510,7 +510,7 @@ def _bwd_dense_kernel( TILE_M=TILE_M, TILE_N=TILE_N, IS_MASK=True, - MASK_CAUSAL=True, + MASK_CAUSAL=IS_CAUSAL, MASK_LOCAL=True if IS_LOCAL else False, ) ) diff --git a/flash_sparse_attn/ops/triton/flash_dense_fwd.py b/flash_sparse_attn/ops/triton/flash_dense_fwd.py index 2479eca..1e13477 100644 --- a/flash_sparse_attn/ops/triton/flash_dense_fwd.py +++ b/flash_sparse_attn/ops/triton/flash_dense_fwd.py @@ -449,7 +449,7 @@ def _fwd_dense_kernel( TILE_N=TILE_N, QHEAD_PER_KVHEAD_PACKGQA=QHEAD_PER_KVHEAD_PACKGQA, IS_MASK=True, - MASK_CAUSAL=True, + MASK_CAUSAL=IS_CAUSAL, MASK_LOCAL=True if IS_LOCAL else False, CHECK_INF=True, ) diff --git a/flash_sparse_attn/ops/triton/flash_gated_bwd.py b/flash_sparse_attn/ops/triton/flash_gated_bwd.py index b2cc2ef..d3fbf7a 100644 --- a/flash_sparse_attn/ops/triton/flash_gated_bwd.py +++ b/flash_sparse_attn/ops/triton/flash_gated_bwd.py @@ -760,7 +760,7 @@ def _bwd_gated_kernel( TILE_M=TILE_M, TILE_N=TILE_N, IS_MASK=True, - MASK_CAUSAL=True, + MASK_CAUSAL=IS_CAUSAL, MASK_LOCAL=True if IS_LOCAL else False, IS_LOGSIGMOID_GATE=IS_LOGSIGMOID_GATE, ) diff --git a/flash_sparse_attn/ops/triton/flash_gated_fwd.py b/flash_sparse_attn/ops/triton/flash_gated_fwd.py index 64df940..4eb6fe8 100644 --- a/flash_sparse_attn/ops/triton/flash_gated_fwd.py +++ b/flash_sparse_attn/ops/triton/flash_gated_fwd.py @@ -664,7 +664,7 @@ def _fwd_gated_kernel( TILE_N=TILE_N, QHEAD_PER_KVHEAD_PACKGQA=QHEAD_PER_KVHEAD_PACKGQA, IS_MASK=True, - MASK_CAUSAL=True, + MASK_CAUSAL=IS_CAUSAL, MASK_LOCAL=True if IS_LOCAL else False, IS_LOGSIGMOID_GATE=IS_LOGSIGMOID_GATE, CHECK_INF=True, diff --git a/flash_sparse_attn/ops/triton/flash_sparse_bwd.py b/flash_sparse_attn/ops/triton/flash_sparse_bwd.py index daab325..afeccd9 100644 --- a/flash_sparse_attn/ops/triton/flash_sparse_bwd.py +++ b/flash_sparse_attn/ops/triton/flash_sparse_bwd.py @@ -545,7 +545,7 @@ def _bwd_sparse_kernel( TILE_M=TILE_M, TILE_N=TILE_N, IS_MASK=True, - MASK_CAUSAL=True, + MASK_CAUSAL=IS_CAUSAL, MASK_LOCAL=True if IS_LOCAL else False, ) ) diff --git a/flash_sparse_attn/ops/triton/flash_sparse_fwd.py b/flash_sparse_attn/ops/triton/flash_sparse_fwd.py index 1ac63ca..e8dd2be 100644 --- a/flash_sparse_attn/ops/triton/flash_sparse_fwd.py +++ b/flash_sparse_attn/ops/triton/flash_sparse_fwd.py @@ -471,7 +471,7 @@ def _fwd_sparse_kernel( TILE_N=TILE_N, QHEAD_PER_KVHEAD_PACKGQA=QHEAD_PER_KVHEAD_PACKGQA, IS_MASK=True, - MASK_CAUSAL=True, + MASK_CAUSAL=IS_CAUSAL, MASK_LOCAL=True if IS_LOCAL else False, CHECK_INF=True, )