Fix Mamba-3 MIMO TileLang kernel cache churn for dynamic batch/seq/head dims #946
Draft
Copilot wants to merge 2 commits into
…ilelang#1934) Agent-Logs-Url: https://github.com/state-spaces/mamba/sessions/6350492e-c136-47ab-93bf-c04150d6948e Co-authored-by: aakashlahoti <22520615+aakashlahoti@users.noreply.github.com>
Copilot AI changed the title from "[WIP] Update Mamba-3 MIMO TileLang kernels to avoid dynamic-dimension cache churn" to "Fix Mamba-3 MIMO TileLang kernel cache churn for dynamic batch/seq/head dims" on May 10, 2026
Thanks for looking at this. I believe I get cache misses even with constant B, S, H, G dimensions on the current main branch. This PR seems to fix it.
In TileLang 0.1.8, `@tilelang.jit` cache keys are derived from Python call-site arguments. Passing `T.dynamic("B")` etc. as function parameters still creates distinct cache entries per unique argument object, causing recompilation and cache-miss warnings even for structurally identical kernels (tile-ai/tilelang#1934).
Fix
Move dynamic dimension declarations (`B`, `S`, `H`, `G`, and `NS` for varlen paths) inside the jitted factory bodies instead of passing them as parameters. TileLang then binds them from runtime tensor shapes at invocation time, and the cache key reduces to true compile-time specialization knobs only.
Before:
After:
Files changed
- `mamba3_mimo_fwd.py` (`mamba_mimo_fwd`): remove `B, S, H, G` from signature; declare inside; update `mamba_mimo_forward` call site.
- `mamba3_mimo_bwd.py` (`mamba_mimo_bwd_fwd` and `mamba_mimo_bwd_bwd`): same; update `mamba_mimo_bwd_combined` call sites.
- `mamba3_mimo_fwd_varlen.py` (`mamba_mimo_fwd`, varlen): additionally removes `NS`; retains `isVarlen` as it gates kernel structure branches.
- `mamba3_mimo_bwd_varlen.py` (`mamba_mimo_bwd_fwd` and `mamba_mimo_bwd_bwd`, varlen): same; update `mamba_mimo_bwd_combined_varlen` call sites.
No math or output behavior changes.
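The cache-key effect described under Fix can be illustrated with a toy model. The sketch below does not use TileLang: `Dyn` and `toy_jit` are hypothetical stand-ins for `T.dynamic(...)` and `@tilelang.jit`, and the assumption that a cache keyed on call arguments treats every fresh symbol object as a new key is taken from the description above, not from TileLang's source.

```python
class Dyn:
    """Hypothetical stand-in for a symbolic dimension such as T.dynamic("B").

    It keeps Python's default identity-based hashing, so two symbols with the
    same name are still different cache keys.
    """

    def __init__(self, name):
        self.name = name


def toy_jit(factory):
    """Toy memoizer keyed on call arguments (an assumption, not TileLang's cache)."""
    cache = {}

    def wrapper(*args, **kwargs):
        key = (args, tuple(sorted(kwargs.items())))
        if key not in cache:
            wrapper.compiles += 1  # count simulated compilations
            cache[key] = factory(*args, **kwargs)
        return cache[key]

    wrapper.compiles = 0
    return wrapper


@toy_jit
def factory_before(B, S, chunk_size):
    # Dynamic symbols arrive as parameters, so they become part of the key.
    return ("kernel", B.name, S.name, chunk_size)


@toy_jit
def factory_after(chunk_size):
    # Dynamic symbols are declared inside the body; only chunk_size is keyed.
    B, S = Dyn("B"), Dyn("S")
    return ("kernel", B.name, S.name, chunk_size)


for _ in range(3):
    factory_before(Dyn("B"), Dyn("S"), chunk_size=16)  # fresh symbols per call
    factory_after(chunk_size=16)

print(factory_before.compiles, factory_after.compiles)  # 3 1
```

In this model the "before" factory is rebuilt on every call even though the resulting kernel is identical, while the "after" factory is built once per distinct `chunk_size`. This mirrors the intent of keeping only compile-time specialization knobs in the signature.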
Original prompt
Update the Mamba-3 MIMO TileLang kernels in the repository `state-spaces/mamba` to avoid TileLang dynamic-dimension cache churn described in `tile-ai/tilelang#1934`.
Context:
- … `state-spaces/mamba`
- … `tilelang==0.1.8` in `pyproject.toml`.
- … `@tilelang.jit`-decorated functions and then call them using `T.dynamic(...)` at the call site.
- … `tile-ai/tilelang#1934`, cache keys depend on the Python arguments to the jitted function. Passing dynamic dimensions through the jitted function signature can therefore cause recompilation/cache misses or noisy cache warnings even when the generated PrimFunc is structurally identical.
Please modify the following files:
- `mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_fwd.py`
- `mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_bwd.py`
- `mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_fwd_varlen.py`
- `mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_bwd_varlen.py`
Goal:
- … `@tilelang.jit` kernel factory signatures so that dynamic runtime dimensions are not passed as Python parameters to the jitted factories.
- … `T.dynamic(...)` and let runtime tensor shapes bind them at kernel invocation time.
- … `hasZ`/`hasD`/`reduceO`, and tuning/config params such as `chunk_size`, `threads`, `num_stages`, `dtype`, etc.).
Likely dynamic dimensions to internalize (confirm from the code):
- … `B`, `S`, `H`, `G`
- … `B`, `S`, `H`, `G`, and any per-call sequence-count dimension that should be runtime-bound if it is only dynamic shape state rather than true kernel specialization. Be careful not to remove parameters that actually determine distinct kernel structure or memory layout.
Implementation guidance:
- … `T.dynamic("...")` arguments into the jitted factory for dimensions moved inside the factory.
- … `nchunks`, `max_nchunks`, output tensor shapes, tensor annotations, kernels, and grid definitions) so they continue to work after moving dynamic symbols inside the factory.
Validation:
Please create a pull request with a concise title and summary describing the cache-key / dynamic-dimension fix for Mamba-3 MIMO TileLang kernels.
The following is the prior conversation context from the user's chat exploration (may be truncated):
User: https://github.com/state-spaces/mamba/blob/main/pyproject.toml
this is the right github id. can you go and look at the correct kernel file and come back with a diagnosis. we are using mamba-3 mimo tilelang kernels
For better performance, consider using `@tilelang.jit` instead of direct kernel caching.
2026-05-08 09:06:47 [TileLang:tilelang.cache.kernel_cache:WARNING] (kernel_cache.py:322): Found kernel 'mamba_mimo_bwd_bwd_kernel' in memory cache. For better performance, consider using `@tilelang.jit` instead of direct kernel caching.
2026-05-08 09:06:47 [TileLang:tilelang.cache.kernel_cache:WARNING] (kernel_cache.py:322): Found kernel 'mamba_mimo_fwd_kerne
Assistant: Yes, the warning is coming from the correct Mamba-3 MIMO TileLang kernels, and the immediate diagnosis is:
this is probably not a functional kernel bug; it’s a TileLang kernel-caching usage warning, and the Mamba repo is already using `@tilelang.jit` on the right kernels.
So if you’re seeing this repeatedly, the more likely problem is how these kernels are being invoked/imported in your runtime, not that the kernel files themselves forgot to use `@tilelang.jit`.
What I checked
From the correct repo `state-spaces/mamba`:
- `pyproject.toml` pins: `tilelang==0.1.8`
- `mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_fwd.py`
- `mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_bwd.py`
And the exact kernels from your warning are defined there:
- `mamba_mimo_fwd_kernel`
- `mamba_mimo_bwd_bwd_kernel`
Important finding
Both kernel factorie...
This pull request was created from Copilot chat.