metal: fused Sdpa via the vendored MetalFlashAttention kernel (~2×) #2320
Open
czoli1976 wants to merge 1 commit into
tract's metal crate already vendors libMetalFlashAttention (Apple / Philip
Turner's flash-attention kernels, MIT) but only used its sgemm/hgemm entry
points -- the fused `attention` kernel shipped inside the metallib was never
dispatched. The MetalTransform instead explodes every Sdpa into
einsum + softmax + einsum, materializing the full (B,H,Sq,Sk) score matrix in
device memory and round-tripping it through three separate kernels.
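To put a number on that intermediate, here is a small sketch of the size arithmetic in Rust. The function names are hypothetical and not part of tract; the shapes plugged in afterwards are the benchmark configuration quoted further down in this description.

```rust
/// Bytes occupied by a dense tensor with the given dimensions.
/// `elem_bytes` is 4 for f32 and 2 for f16.
fn tensor_bytes(dims: &[usize], elem_bytes: usize) -> usize {
    dims.iter().product::<usize>() * elem_bytes
}

/// Bytes needed to materialize the full (B,H,Sq,Sk) score matrix,
/// which is what the einsum + softmax + einsum path writes to device memory.
fn score_buffer_bytes(b: usize, h: usize, sq: usize, sk: usize, elem_bytes: usize) -> usize {
    tensor_bytes(&[b, h, sq, sk], elem_bytes)
}
```

For f32 with B=1 H=8 S=512 D=64 this gives 8 MiB for the score buffer, against 1 MiB for each of Q, K, V and O. The buffer grows quadratically with sequence length: at S=4096 the same formula gives 512 MiB.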
This wires the fused kernel:
* `dispatch_metal_mfa_attention` drives the vendored `attention` function
(online softmax / flash attention; ABI reconstructed from the v1.0.1
source + on-GPU pipeline reflection). f32 and f16.
* `mfa_attention_head_major` adapts tract's native [B,H,S,D] layout to MFA's
(Q/O=[R,H,D], K=[H,D,C], V=[C,H,D]) on-device via copy_nd.
* `MetalMfaSdpa` op + a `register_metal_op!(Sdpa)` translator route a real
Sdpa node to the fused kernel; unsupported shapes fall back to the existing
explode path. The new `rewire_sdpa_metal` only flattens the Sdpa nodes the
kernel can't take, leaving fusable ones intact (cuda keeps the shared
`rewire_sdpa`).
* causal masking via an additive [Sq,Sk] mask -- the metallib's `triangular`
function-constant alone is a no-op, pinned by a regression test.
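The layout adaptation in the second bullet is an index permutation. The sketch below shows it on the host for a single batch element, assuming R and C stand for the query and key sequence lengths. The helper names are hypothetical, and the PR itself performs the equivalent permutation on-device via copy_nd, not with loops like these.

```rust
/// [H, S, D] (row-major) -> [S, H, D]: the layout quoted for Q, O and V.
fn hsd_to_shd(src: &[f32], h: usize, s: usize, d: usize) -> Vec<f32> {
    assert_eq!(src.len(), h * s * d);
    let mut dst = vec![0.0; src.len()];
    for hi in 0..h {
        for si in 0..s {
            for di in 0..d {
                dst[(si * h + hi) * d + di] = src[(hi * s + si) * d + di];
            }
        }
    }
    dst
}

/// [H, S, D] (row-major) -> [H, D, S]: the layout quoted for K,
/// i.e. a per-head transpose of the last two axes.
fn hsd_to_hds(src: &[f32], h: usize, s: usize, d: usize) -> Vec<f32> {
    assert_eq!(src.len(), h * s * d);
    let mut dst = vec![0.0; src.len()];
    for hi in 0..h {
        for si in 0..s {
            for di in 0..d {
                dst[(hi * d + di) * s + si] = src[(hi * s + si) * d + di];
            }
        }
    }
    dst
}
```

The first permutation only moves whole D-length rows, while the second transposes within each head.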
Eliminates the (B,H,Sq,Sk) intermediate and collapses three kernels to one.
Measured on M-series (f32, B=1 H=8 S=512 D=64): the kernel is ~2x as fast as the
explode path, and an 8-layer all-attention stack (amortizing host sync) runs 2.70x
faster on the attention portion, so a real model's end-to-end gain depends on
attention's share of total compute and approaches 2.70x only when attention dominates.
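One way to turn the attention-only figure into an end-to-end estimate is the usual Amdahl-style formula. This is an assumption about how to apply the number, not a formula taken from the PR, and `end_to_end_speedup` is an illustrative name.

```rust
/// Amdahl-style estimate of whole-model speedup.
/// `attn_share` is the fraction of baseline runtime spent in attention (0.0..=1.0);
/// `attn_speedup` is how much faster only that fraction becomes.
fn end_to_end_speedup(attn_share: f64, attn_speedup: f64) -> f64 {
    1.0 / ((1.0 - attn_share) + attn_share / attn_speedup)
}
```

With a 2.70x attention speedup, a model that spends half its time in attention would see roughly 1.46x overall under this estimate, and a model that spends none of its time there would see no change.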
Correctness: bit-close to a CPU reference across f32/f16, head dims 16..128,
masked, causal, multi-head, and head-major layout; an e2e test builds a real
Sdpa model, runs the MetalTransform, asserts it routes to MetalMfaSdpa, and
matches the CPU FlashSdpa output.
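The kind of reference comparison described above can be illustrated on a single query row. The sketch below is plain Rust, not tract or Metal code, and all names are hypothetical. It checks an online-softmax accumulation (the idea behind flash attention) against the naive version that materializes the score row, using an additive causal mask. It assumes Sq == Sk so that query i may attend to keys 0..=i.

```rust
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Additive causal mask row: 0 for visible keys, -inf for future keys.
fn causal_mask_row(i: usize, sk: usize) -> Vec<f32> {
    (0..sk).map(|j| if j <= i { 0.0 } else { f32::NEG_INFINITY }).collect()
}

/// Reference: build the whole score row, softmax it, then weight V.
fn attention_row_naive(q: &[f32], k: &[Vec<f32>], v: &[Vec<f32>], mask: &[f32], scale: f32) -> Vec<f32> {
    let scores: Vec<f32> = k.iter().zip(mask).map(|(kj, m)| dot(q, kj) * scale + *m).collect();
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (*s - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    let mut out = vec![0.0; v[0].len()];
    for (e, vj) in exps.iter().zip(v) {
        for (o, x) in out.iter_mut().zip(vj) {
            *o += (*e / sum) * *x;
        }
    }
    out
}

/// Online softmax: one pass over the keys, keeping a running max `m`,
/// a running denominator `l`, and an unnormalized accumulator `acc`.
/// No score row is stored. Assumes at least one key is unmasked.
fn attention_row_online(q: &[f32], k: &[Vec<f32>], v: &[Vec<f32>], mask: &[f32], scale: f32) -> Vec<f32> {
    let mut m = f32::NEG_INFINITY;
    let mut l = 0.0f32;
    let mut acc = vec![0.0f32; v[0].len()];
    for ((kj, vj), mj) in k.iter().zip(v).zip(mask) {
        let s = dot(q, kj) * scale + *mj;
        if s == f32::NEG_INFINITY {
            continue; // fully masked key contributes nothing
        }
        let m_new = m.max(s);
        let correction = (m - m_new).exp(); // rescales earlier terms; 0 on the first key
        let p = (s - m_new).exp();
        l = l * correction + p;
        for (a, x) in acc.iter_mut().zip(vj) {
            *a = *a * correction + p * *x;
        }
        m = m_new;
    }
    acc.iter().map(|a| *a / l).collect()
}
```

Both functions return the same row up to floating-point rounding. The online version only ever holds one score at a time, which is why the fused kernel needs no (B,H,Sq,Sk) buffer.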
Apple MetalFlashAttention: https://github.com/philipturner/metal-flash-attention
Prior art (fused-attention dispatch): llama.cpp ggml-metal flash-attn kernel;
candle-metal-kernels.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This was referenced May 31, 2026
czoli1976 added a commit to czoli1976/tract that referenced this pull request Jun 3, 2026
…rite + NNEF + resume
tract's DynKeyValueCache grows by TypedConcat([past, new]) each step, copying the
whole t-token past into a fresh buffer -> O(T^2) total copy over a T-token decode.
This change applies the Apple Core ML "stateful in-place KV" lever. Pieces:
1. InPlaceKvCache: geometric-growth in-place cache. Buffer with spare capacity along
`axis`, write each new chunk at the cursor (Tensor::assign_slice, strided-safe for
any axis), double only when capacity is exceeded -> O(T) amortized copy.
valid_view() exposes the live [0..len] region as a ZERO-COPY ndarray view (the path
that realizes the win). For the seq axis of [B,H,S,D] a per-head slice of the
capacity buffer is a contiguous prefix, so a consumer reads it at concat cost.
2. InPlaceKvSdpa: stateful fused op owning the K/V in-place caches, running the CPU
SDPA (FlashSdpaOp::flash_attention_gqa) over the zero-copy views. tract Tensors
cannot be zero-copy views ACROSS an op boundary (Tensor::slice copies), so keeping
the buffers inside the consuming op is what makes the saving real. Drop-in for
{kv_cache(K), kv_cache(V), Sdpa}; does GQA internally.
3. InPlaceKvSdpaTransform: rewrite pass that strips the GQA broadcast chain
(fuse_kv_cache_broadcast_rule) then fuses {cache(K), cache(V), Sdpa} -> InPlaceKvSdpa
so existing decode models adopt the in-place cache transparently.
4. NNEF ser/de: round-trips via tract_transformers_inplace_kv_sdpa (registered).
5. Resume: save_to/load_from checkpoint the cache as [K,V] tensors; freeze/unfreeze
snapshot the running state in-process. Both resume bit-exactly; the snapshot is O(len).
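The geometric-growth idea in item 1 can be sketched with a flat buffer. This is a simplified, hypothetical stand-in (`GrowableCache` is not the PR's InPlaceKvCache): it appends whole tokens along a single leading axis, whereas the real cache writes along an arbitrary axis with Tensor::assign_slice.

```rust
/// Simplified in-place cache: each token is a row of `width` f32 values.
struct GrowableCache {
    buf: Vec<f32>, // capacity buffer holding `cap * width` elements
    width: usize,  // elements per token, must be > 0
    len: usize,    // number of live tokens (the write cursor)
    cap: usize,    // number of tokens the buffer can hold
}

impl GrowableCache {
    fn new(width: usize) -> Self {
        GrowableCache { buf: Vec::new(), width, len: 0, cap: 0 }
    }

    /// Write `chunk` (a whole number of tokens) at the cursor.
    /// The buffer is reallocated only when capacity is exceeded,
    /// and then at least doubles, so copies are amortized O(1) per token.
    fn append(&mut self, chunk: &[f32]) {
        assert_eq!(chunk.len() % self.width, 0);
        let n = chunk.len() / self.width;
        if self.len + n > self.cap {
            let new_cap = (self.len + n).max(self.cap * 2).max(1);
            let mut grown = vec![0.0; new_cap * self.width];
            let live = self.len * self.width;
            grown[..live].copy_from_slice(&self.buf[..live]);
            self.buf = grown;
            self.cap = new_cap;
        }
        let start = self.len * self.width;
        self.buf[start..start + chunk.len()].copy_from_slice(chunk);
        self.len += n;
    }

    /// Borrow the live region without copying it.
    fn valid_view(&self) -> &[f32] {
        &self.buf[..self.len * self.width]
    }
}
```

Appending eight tokens one at a time reallocates at capacities 1, 2, 4 and 8, and `valid_view` hands out a borrowed slice of the live prefix, which mirrors the zero-copy view described in item 1.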
Validated (11 tests): in-place bit-exact vs concat-grow; fused op matches concat-cache
+ FlashSdpaOp baseline (prefill+decode, GQA, causal/non-causal); runs end-to-end via a
persistent SimpleState; the rewrite fires + the rewritten model matches baseline; NNEF
round-trip; freeze/unfreeze and save/load resume bit-exact; growth amortized. fmt +
clippy clean; transformers lib 23/0 no-regression.
Benched (release, B=1 H=8 D=128):
- cache-update only: 21x (T=256) -> 709x (T=4096), O(T^2) -> O(T)
- end-to-end via the op: 1.10x (256) -> 1.63x (2048), 39% faster decode @2k
- resume checkpoint: O(len), 0.10ms (256) -> 1.76ms (4096), one-time
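The O(T^2) -> O(T) claim for the cache update can be checked by counting how many past tokens each strategy moves when decoding one token at a time. This is a counting sketch with hypothetical function names. It models copies only, so its ratios are not expected to reproduce the measured speedups above.

```rust
/// Concat-grow: step t copies the whole t-token past into a fresh buffer,
/// so T steps move 0 + 1 + ... + (T-1) tokens in total.
fn concat_grow_copies(t: u64) -> u64 {
    t * t.saturating_sub(1) / 2
}

/// Capacity doubling: past tokens are moved only when the buffer is full.
fn doubling_copies(t: u64) -> u64 {
    let (mut cap, mut len, mut copied) = (0u64, 0u64, 0u64);
    for _ in 0..t {
        if len + 1 > cap {
            copied += len; // move the live tokens into the larger buffer
            cap = (cap * 2).max(1);
        }
        len += 1;
    }
    copied
}
```

At T=256 the counts are 32640 versus 255 moved tokens, and at T=4096 they are 8386560 versus 4095, so the gap widens roughly in proportion to T.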
Follow-up: GPU coupling (sonos#2320 MFA kernel reading capacity buffer + length).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@kali: I'm opening this as much as an RFC as a PR, so please read the design question below before the diff.

I've seen you've been actively building out the GPU path lately: mirroring the CUDA kernels into Metal (`gather`, `diag_gather`), landing `scaled_masked_softmax` (bool mask + post-softmax mask) on both backends, and adding the transform pre-check plus the CPU fallback when a GPU op rejects a shape. I wanted to surface something adjacent that's been sitting unused in the Metal crate and check the direction with you before investing further.

The opportunity

The vendored `libMetalFlashAttention.metallib` (Apple / Philip Turner's MetalFlashAttention, MIT, already in-tree) ships four functions: `sgemm`, `hgemm`, `convolution`, and `attention`. We only ever dispatch `sgemm`/`hgemm`. The `attention` entry is a fully implemented fused flash-attention kernel (online softmax, never materializes the score matrix), and it's simply never called. (For completeness: `convolution` in this metallib is an empty stub, so `attention` is the only unused-but-real kernel here.)

Meanwhile `MetalTransform` explodes every `Sdpa` into `einsum → softmax → einsum`, which materializes the full `(B,H,Sq,Sk)` score buffer in device memory and round-trips it through three kernels, the middle one being the `scaled_masked_softmax` you just landed.

This PR wires the fused kernel and routes `Sdpa` to it.

The design question (why this is an RFC)

Dispatching a vendored, 2023-era `macosx13` metallib that we don't own is a real commitment, and I noticed you just dropped the unused `GgmlFlashAttn` kernel library on the CUDA side (73dada812), which cuts the other way. So I'd rather ask directly than assume:

Is wiring this vendored Metal kernel the direction you want, or would you prefer a fresh, owned port (e.g. translating the MLX / `ggml-metal` flash-attention kernel into a `.metal` source we control)?

The case for wiring it now: it's already in-tree, fully implemented, MIT, and measures ~2×, which makes it a low-risk way to close the fused-Metal-SDPA gap today. The owned `.metal` port stays open as a follow-up hedge if metallib longevity on M3/M4 worries you. This PR is the "wire what's already there" option, fully validated, for you to accept or redirect.

What it does

* `dispatch_metal_mfa_attention`: drives the vendored `attention` function. ABI (buffers / function-constants / grid geometry) reconstructed from the MFA v1.0.1 source + on-GPU pipeline reflection. f32 + f16.
* `mfa_attention_head_major`: adapts tract's native `[B,H,S,D]` to MFA's layout (`Q/O=[R,H,D]`, `K=[H,D,C]`, `V=[C,H,D]`) on-device via `copy_nd`. The one unavoidable copy is the `K` transpose (candidate to fold later).
* `MetalMfaSdpa` op + `register_metal_op!(Sdpa)` translator: routes a real `Sdpa` node to the fused kernel. Unsupported shapes return `None` and fall through to the CPU-fallback path you just added (85255fdb9). A Metal-local `rewire_sdpa_metal` flattens only the `Sdpa` nodes the kernel can't take, leaving fusable ones intact (CUDA keeps the shared `rewire_sdpa` untouched).
* Causal masking via an additive `[Sq,Sk]` mask: the metallib's `triangular` function-constant alone is a no-op (it computes full attention), pinned by a regression test.

Numbers (Apple M-series, f32, B=1 H=8 S=512 D=64)

[Results table lost in extraction; surviving cell fragments: `dispatch_eval`, `(B,H,Sq,Sk)` score buffer.]

(A single-op model bench reports 3.9×, but that's overhead-inflated; the 2.70× multi-layer figure is the honest one, and it's consistent with the kernel-level ~2×.)

Correctness & gates

* An e2e test builds a real `Sdpa` model, runs `MetalTransform`, asserts it routes to `MetalMfaSdpa`, and matches the CPU `FlashSdpa` output.
* `tract-metal` suite 71/0; `cargo build --workspace` clean; fmt + clippy clean on the new code.

Orthogonal to #2319 (that one is the CPU `FlashSdpa` path; this is Metal).

Credits / prior art

* Prior art (fused-attention dispatch): llama.cpp `ggml-metal` flash-attn; `candle-metal-kernels`.

🤖 Generated with Claude Code