Skip to content

Fix Mamba-3 MIMO TileLang kernel cache churn for dynamic batch/seq/head dims#946

Draft
Copilot wants to merge 2 commits into
mainfrom
copilot/update-mamba3-kernels-for-tilelang-cache
Draft

Fix Mamba-3 MIMO TileLang kernel cache churn for dynamic batch/seq/head dims#946
Copilot wants to merge 2 commits into
mainfrom
copilot/update-mamba3-kernels-for-tilelang-cache

Conversation

Copy link
Copy Markdown

Copilot AI commented May 10, 2026

In TileLang 0.1.8, @tilelang.jit cache keys are derived from Python call-site arguments. Passing T.dynamic("B") etc. as function parameters still creates distinct cache entries per unique argument object, causing recompilation and cache-miss warnings even for structurally identical kernels (tile-ai/tilelang#1934).

Fix

Move dynamic dimension declarations (B, S, H, G, and NS for varlen paths) inside the jitted factory bodies instead of passing them as parameters. TileLang then binds them from runtime tensor shapes at invocation time, and the cache key reduces to true compile-time specialization knobs only.

Before:

kernel = mamba_mimo_fwd(T.dynamic("B"), T.dynamic("S"), T.dynamic("H"), T.dynamic("G"),
                        N, P, R, hasZ, hasD, reduceO, ...)

After:

# Inside mamba_mimo_fwd factory — not in the signature:
B = T.dynamic("B")
S = T.dynamic("S")
H = T.dynamic("H")
G = T.dynamic("G")

# Call site — only static specialization args:
kernel = mamba_mimo_fwd(N, P, R, hasZ, hasD, reduceO, ...)

Files changed

  • mamba3_mimo_fwd.pymamba_mimo_fwd: remove B, S, H, G from signature; declare inside; update mamba_mimo_forward call site.
  • mamba3_mimo_bwd.pymamba_mimo_bwd_fwd and mamba_mimo_bwd_bwd: same; update mamba_mimo_bwd_combined call sites.
  • mamba3_mimo_fwd_varlen.pymamba_mimo_fwd (varlen): additionally removes NS; retains isVarlen as it gates kernel structure branches.
  • mamba3_mimo_bwd_varlen.pymamba_mimo_bwd_fwd and mamba_mimo_bwd_bwd (varlen): same; update mamba_mimo_bwd_combined_varlen call sites.

No math or output behavior changes.

Original prompt

Update the Mamba-3 MIMO TileLang kernels in the repository state-spaces/mamba to avoid TileLang dynamic-dimension cache churn described in tile-ai/tilelang#1934.

Context:

  • Repository: state-spaces/mamba
  • The repository pins tilelang==0.1.8 in pyproject.toml.
  • The current TileLang MIMO kernel factories pass dynamic dimensions as parameters to @tilelang.jit-decorated functions and then call them using T.dynamic(...) at the call site.
  • In TileLang 0.1.8, per issue tile-ai/tilelang#1934, cache keys depend on the Python arguments to the jitted function. Passing dynamic dimensions through the jitted function signature can therefore cause recompilation/cache misses or noisy cache warnings even when the generated PrimFunc is structurally identical.

Please modify the following files:

  • mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_fwd.py
  • mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_bwd.py
  • mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_fwd_varlen.py
  • mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_bwd_varlen.py

Goal:

  • Refactor the @tilelang.jit kernel factory signatures so that dynamic runtime dimensions are not passed as Python parameters to the jitted factories.
  • Instead, define the dynamic dimensions inside the jitted factory using T.dynamic(...) and let runtime tensor shapes bind them at kernel invocation time.
  • Preserve current math and output behavior.
  • Preserve existing compile-time specialization parameters where appropriate (for example dimensions/flags that truly affect kernel structure, booleans such as hasZ/hasD/reduceO, and tuning/config params such as chunk_size, threads, num_stages, dtype, etc.).
  • Ensure both dense and varlen paths are updated consistently.

Likely dynamic dimensions to internalize (confirm from the code):

  • dense forward/backward: B, S, H, G
  • varlen forward/backward: likely B, S, H, G, and any per-call sequence-count dimension that should be runtime-bound if it is only dynamic shape state rather than true kernel specialization. Be careful not to remove parameters that actually determine distinct kernel structure or memory layout.

Implementation guidance:

  • Keep the public wrapper functions callable the same way if possible.
  • Update the wrapper call sites so they no longer pass T.dynamic("...") arguments into the jitted factory for dimensions moved inside the factory.
  • Review all dependent shape expressions (nchunks, max_nchunks, output tensor shapes, tensor annotations, kernels, and grid definitions) so they continue to work after moving dynamic symbols inside the factory.
  • Ensure the code remains valid TileLang/TIR and consistent with current style.
  • If needed, add or update comments/docstrings to explain that dynamic dims are intentionally declared inside the jitted factory to improve TileLang cache reuse under 0.1.8.

Validation:

  • Make sure imports and signatures are coherent after refactor.
  • Confirm no stale references remain that still pass runtime dynamic dims into the jitted factories.
  • If there are tests or simple sanity checks for these kernels in the repo, run or update the most relevant ones.

Please create a pull request with a concise title and summary describing the cache-key / dynamic-dimension fix for Mamba-3 MIMO TileLang kernels.

The following is the prior conversation context from the user's chat exploration (may be truncated):

User: https://github.com/state-spaces/mamba/blob/main/pyproject.toml

this is the right github id. can you go and look at the correct kernel file and come back with a diagnosis. we are using mamba-3 mimo tilelang kernels

For better performance, consider using @tilelang. jit instead of direct kernel caching.
2026-05-0809:06:47 [Titelang:tilelang.cache.kernel_cache:WARNING] (kernel_cache.py:322): Found kernel mamba_mimo_bwd_bwd_k
ernel' in memory cache. For better performance, consider using @tilelang. jit' instead of direct kernel caching.
2026-05-08 09:06:47mo[TileLang: tilelang.cache.kernel_cache:WARNING] (kernel_cache.py: 322): Found kernel 'mamba_mimo_fwd_kerne

Assistant: Yes — the warning is coming from the correct Mamba-3 MIMO TileLang kernels, and the immediate diagnosis is:

this is probably not a functional kernel bug; it’s a TileLang kernel-caching usage warning, and the Mamba repo is already using @tilelang.jit on the right kernels.
So if you’re seeing this repeatedly, the more likely problem is how these kernels are being invoked/imported in your runtime, not that the kernel files themselves forgot to use @tilelang.jit.

What I checked

From the correct repo state-spaces/mamba:

  • pyproject.toml pins:
    • tilelang==0.1.8
  • The Mamba-3 MIMO TileLang kernels are here:
    • mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_fwd.py
    • mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_bwd.py

And the exact kernels from your warning are defined there:

  • mamba_mimo_fwd_kernel
  • mamba_mimo_bwd_bwd_kernel

Important finding

Both kernel factorie...

This pull request was created from Copilot chat.

…ilelang#1934)

Agent-Logs-Url: https://github.com/state-spaces/mamba/sessions/6350492e-c136-47ab-93bf-c04150d6948e

Co-authored-by: aakashlahoti <22520615+aakashlahoti@users.noreply.github.com>
Copilot AI changed the title [WIP] Update Mamba-3 MIMO TileLang kernels to avoid dynamic-dimension cache churn Fix Mamba-3 MIMO TileLang kernel cache churn for dynamic batch/seq/head dims May 10, 2026
Copilot AI requested a review from aakashlahoti May 10, 2026 04:21
@peterbjorgensen
Copy link
Copy Markdown

Thanks for looking at this. I believe I get cache misses even with constant B, S, H, G dimensions on the current main branch. This PR seems to fix it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants