Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 14 additions & 16 deletions mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,6 @@
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def mamba_mimo_bwd_fwd(
B,
S,
H,
G,
N,
P,
R,
Expand All @@ -56,6 +52,13 @@ def mamba_mimo_bwd_fwd(
threads: int = 128,
num_stages: int = 0,
) -> torch.Tensor:
# Dynamic dimensions declared inside the factory to avoid TileLang
# cache-key churn when batch/sequence/head dimensions vary at runtime
# (tile-ai/tilelang#1934).
B = T.dynamic("B")
S = T.dynamic("S")
H = T.dynamic("H")
G = T.dynamic("G")

accum_dtype = 'float32'

Expand Down Expand Up @@ -503,10 +506,6 @@ def mamba_mimo_bwd_fwd_kernel(
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def mamba_mimo_bwd_bwd(
B,
S,
H,
G,
N,
P,
R,
Expand All @@ -519,6 +518,13 @@ def mamba_mimo_bwd_bwd(
threads: int = 256,
num_stages: int = 0,
) -> torch.Tensor:
# Dynamic dimensions declared inside the factory to avoid TileLang
# cache-key churn when batch/sequence/head dimensions vary at runtime
# (tile-ai/tilelang#1934).
B = T.dynamic("B")
S = T.dynamic("S")
H = T.dynamic("H")
G = T.dynamic("G")

accum_dtype = 'float32'

Expand Down Expand Up @@ -1203,10 +1209,6 @@ def mamba_mimo_bwd_combined(
else:
dtype_str = dtype
bwd_fwd_kernel = mamba_mimo_bwd_fwd(
T.dynamic("B"),
T.dynamic("S"),
T.dynamic("H"),
T.dynamic("G"),
N, P, R,
z is not None,
D is not None,
Expand Down Expand Up @@ -1261,10 +1263,6 @@ def mamba_mimo_bwd_combined(


bwd_bwd_kernel = mamba_mimo_bwd_bwd(
T.dynamic("B"),
T.dynamic("S"),
T.dynamic("H"),
T.dynamic("G"),
N, P, R,
z is not None,
D is not None,
Expand Down
46 changes: 24 additions & 22 deletions mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_bwd_varlen.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,12 @@
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def mamba_mimo_bwd_fwd(
B,
S,
H,
G,
N,
P,
R,
hasZ,
hasD,
reduceO,
NS: int = 1,
isVarlen: bool = True,
chunk_size: int = 16,
rotary_dim_divisor: int = 4,
Expand All @@ -96,6 +91,15 @@ def mamba_mimo_bwd_fwd(
* STATES shape: ``[B, H, max_nchunks, N, P]`` with
``max_nchunks = (S // chunk_size) + NS``.
"""
# Dynamic dimensions declared inside the factory to avoid TileLang
# cache-key churn when batch/sequence/head/sequence-count dimensions
# vary at runtime (tile-ai/tilelang#1934).
B = T.dynamic("B")
S = T.dynamic("S")
H = T.dynamic("H")
G = T.dynamic("G")
NS = T.dynamic("NS")

accum_dtype = 'float32'
max_nchunks = (S // chunk_size) + NS
fused_chunk_size = chunk_size * R
Expand Down Expand Up @@ -543,17 +547,12 @@ def mamba_mimo_bwd_fwd_kernel(
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def mamba_mimo_bwd_bwd(
B,
S,
H,
G,
N,
P,
R,
hasZ,
hasD,
reduceO,
NS: int = 1,
isVarlen: bool = False,
chunk_size: int = 16,
rotary_dim_divisor: int = 4,
Expand All @@ -578,6 +577,15 @@ def mamba_mimo_bwd_bwd(
* DSSDA shape: ``[B, H, max_nchunks, C, C]`` with
``max_nchunks = (S // chunk_size) + NS``.
"""
# Dynamic dimensions declared inside the factory to avoid TileLang
# cache-key churn when batch/sequence/head/sequence-count dimensions
# vary at runtime (tile-ai/tilelang#1934).
B = T.dynamic("B")
S = T.dynamic("S")
H = T.dynamic("H")
G = T.dynamic("G")
NS = T.dynamic("NS")

accum_dtype = 'float32'
max_nchunks = (S // chunk_size) + NS
fused_chunk_size = chunk_size * R
Expand Down Expand Up @@ -1319,14 +1327,11 @@ def mamba_mimo_bwd_combined_varlen(
qk_dot = torch.zeros([B, H, S, R, R], dtype=q.dtype, device=q.device)

bwd_fwd_kernel = mamba_mimo_bwd_fwd(
T.dynamic("B"),
T.dynamic("S"),
T.dynamic("H"),
T.dynamic("G"),
N, P, R,
z is not None, D is not None, reduceO,
T.dynamic("NS"), cu_seqlens is not None, chunk_size, rotary_dim_divisor, dtype_str,
bf_threads, bf_num_stages)
isVarlen=cu_seqlens is not None,
chunk_size=chunk_size, rotary_dim_divisor=rotary_dim_divisor, dtype=dtype_str,
threads=bf_threads, num_stages=bf_num_stages)

bwd_fwd_kernel(
dout, q, k, v, q_bias, k_bias, mimo_v, mimo_o,
Expand All @@ -1352,14 +1357,11 @@ def mamba_mimo_bwd_combined_varlen(
ddA_cs = torch.zeros([B, H, S], dtype=torch.float32, device=dt.device)

bwd_bwd_kernel = mamba_mimo_bwd_bwd(
T.dynamic("B"),
T.dynamic("S"),
T.dynamic("H"),
T.dynamic("G"),
N, P, R,
z is not None, D is not None, reduceO,
T.dynamic("NS"), cu_seqlens is not None, chunk_size, rotary_dim_divisor, dtype_str,
bb_threads, bb_num_stages)
isVarlen=cu_seqlens is not None,
chunk_size=chunk_size, rotary_dim_divisor=rotary_dim_divisor, dtype=dtype_str,
threads=bb_threads, num_stages=bb_num_stages)

bwd_bwd_kernel(
dout, q, k, v, q_bias, k_bias, mimo_v, mimo_o,
Expand Down
19 changes: 10 additions & 9 deletions mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,6 @@
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def mamba_mimo_fwd(
B,
S,
H,
G,
N,
P,
R,
Expand All @@ -53,6 +49,15 @@ def mamba_mimo_fwd(
threads: int = 128,
num_stages: int = 0,
) -> torch.Tensor:
# Dynamic dimensions are declared inside the factory so they are NOT
# included in TileLang's cache key. Passing them as Python arguments
# would cause a distinct cache entry for every (B, S, H, G) combination
# even though the generated PrimFunc is structurally identical, leading
# to cache churn / recompilation warnings (tile-ai/tilelang#1934).
B = T.dynamic("B")
S = T.dynamic("S")
H = T.dynamic("H")
G = T.dynamic("G")

accum_dtype = 'float32'

Expand Down Expand Up @@ -436,11 +441,7 @@ def mamba_mimo_forward(q, k, v,
else:
tl_dtype = dtype
reduceO = mimo_o is not None
kernel = mamba_mimo_fwd(T.dynamic("B"),
T.dynamic("S"),
T.dynamic("H"),
T.dynamic("G"),
N, P, R,
kernel = mamba_mimo_fwd(N, P, R,
z is not None,
D is not None,
reduceO,
Expand Down
20 changes: 9 additions & 11 deletions mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_fwd_varlen.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,12 @@
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def mamba_mimo_fwd(
B,
S,
H,
G,
N,
P,
R,
hasZ,
hasD,
reduceO,
NS: int = 1,
isVarlen: bool = True,
return_final_state=False,
chunk_size: int = 16,
Expand All @@ -93,6 +88,14 @@ def mamba_mimo_fwd(
When NS == 1 (or cu_seqlens is None) the kernel degenerates to the
non-varlen behaviour.
"""
# Dynamic dimensions declared inside the factory to avoid TileLang
# cache-key churn when batch/sequence/head/sequence-count dimensions
# vary at runtime (tile-ai/tilelang#1934).
B = T.dynamic("B")
S = T.dynamic("S")
H = T.dynamic("H")
G = T.dynamic("G")
NS = T.dynamic("NS")

accum_dtype = 'float32'

Expand Down Expand Up @@ -600,15 +603,10 @@ def mamba_mimo_forward_varlen(q, k, v,
NS = 1

reduceO = mimo_o is not None
kernel = mamba_mimo_fwd(T.dynamic("B"),
T.dynamic("S"),
T.dynamic("H"),
T.dynamic("G"),
N, P, R,
kernel = mamba_mimo_fwd(N, P, R,
z is not None,
D is not None,
reduceO,
NS=T.dynamic("NS"),
isVarlen=cu_seqlens is not None,
return_final_state=return_state,
chunk_size=chunk_size,
Expand Down