A faster Mamba selective scan for Apple Silicon (custom Metal kernel) #1278
createcentury started this conversation in Show and tell
Hi mlx-lm team & community,
`mlx_lm/models/mamba.py` currently implements the selective scan as a Python `for t in range(T)` loop. This is correct but sequential: every timestep launches new MLX ops, so prefill cost grows linearly with `T`, with every step executed serially.

I wrote a Metal Shading Language kernel for the same recurrence that uses a parallel prefix scan over the associative `(a, b)` pair operator, the same approach Mamba's CUDA `selective_scan_fwd_kernel.cuh` takes. It runs through `mx.fast.metal_kernel` and supports:

- `delta_softplus`
- `seqlen` handling via a chunked SRAM running prefix
- `ssm_state_out` output for inference state caching (mirrors Mamba CUDA's `params.x_ptr`)

Same M4 Max, same `mamba-130m-hf` checkpoint, same prompt + 50 decoded tokens (greedy):

All five `state-spaces/mamba-{130m, 370m, 790m, 1.4b, 2.8b}-hf` checkpoints load and generate end-to-end.

Repo: https://github.com/createcentury/mamba-metal (kernels are first-class `.metal` files under `mamba_metal/kernels/`).

Happy to discuss integration approaches if there's interest: the parallel-scan kernel could drop into mlx-lm's Mamba path for prefill while keeping the existing per-step loop for decode (which is already optimal at `T=1`).
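For readers less familiar with the trick, here is a minimal pure-Python sketch of why the per-step loop can be replaced by a prefix scan. It is not the kernel from the repo: it reduces the problem to a single scalar (channel, state) element, assumes Mamba's discretization `a_t = exp(delta_t * A)` and `b_t = delta_t * B_t * x_t`, and uses a Hillis-Steele scan as a stand-in for whatever scan layout the Metal kernel actually uses. All function names (`discretize`, `combine`, `chunked_scan`, etc.) are hypothetical and chosen for illustration.

```python
import math
import random

def softplus(x):
    # Numerically stable softplus: log(1 + exp(x)).
    return math.log1p(math.exp(-abs(x))) + max(x, 0.0)

def discretize(delta_raw, A, B, x, delta_softplus=True):
    # Assumed Mamba-style discretization for one scalar element:
    # a_t = exp(delta_t * A), b_t = delta_t * B_t * x_t.
    pairs = []
    for d, b_t, x_t in zip(delta_raw, B, x):
        delta = softplus(d) if delta_softplus else d
        pairs.append((math.exp(delta * A), delta * b_t * x_t))
    return pairs

def sequential_scan(pairs, h0=0.0):
    # Reference: the per-timestep loop, h_t = a_t * h_{t-1} + b_t.
    h, out = h0, []
    for a, b in pairs:
        h = a * h + b
        out.append(h)
    return out

def combine(left, right):
    # Associative operator on (a, b) pairs: applying `left` then `right`
    # to a state h gives a2 * (a1 * h + b1) + b2 = (a1 * a2) * h + (a2 * b1 + b2).
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)

def inclusive_scan(pairs):
    # Hillis-Steele scan: ceil(log2(T)) rounds. Within a round every index is
    # computed from the previous round only, so all of them could run in parallel.
    out = list(pairs)
    step = 1
    while step < len(out):
        out = [combine(out[i - step], out[i]) if i >= step else out[i]
               for i in range(len(out))]
        step *= 2
    return out

def chunked_scan(pairs, chunk=4, h0=0.0):
    # Scan each chunk independently, then apply the carried state from the
    # previous chunks (a running prefix), so T is not limited by chunk size.
    h, ys = h0, []
    for start in range(0, len(pairs), chunk):
        prefix = inclusive_scan(pairs[start:start + chunk])
        ys.extend(a * h + b for a, b in prefix)
        h = ys[-1]  # final state of this chunk, carried into the next one
    return ys, h    # h plays the role of a cached final SSM state

random.seed(0)
T = 10
A = -0.5  # Mamba keeps A negative, so each a_t lies in (0, 1)
pairs = discretize([random.gauss(0, 1) for _ in range(T)], A,
                   [random.gauss(0, 1) for _ in range(T)],
                   [random.gauss(0, 1) for _ in range(T)])
ref = sequential_scan(pairs)
ys, final_state = chunked_scan(pairs, chunk=4)  # T=10 is not a multiple of 4
assert all(abs(r - y) < 1e-9 for r, y in zip(ref, ys))

# Decode (T=1): continue from the cached state with one plain recurrence step.
next_pair = (0.9, 0.1)
h_next = next_pair[0] * final_state + next_pair[1]
assert abs(h_next - sequential_scan(pairs + [next_pair])[-1]) < 1e-9
```

The last two lines mirror the split proposed above: the scan handles prefill and returns a final state, and a single-step update from that cached state is all decode needs, which is why a parallel scan buys nothing at `T=1`.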