transformers: in-place KV cache for decode via a fused InPlaceKvSdpa op #2321
Open
czoli1976 wants to merge 1 commit into
@kali look at me first |
kali pushed a commit that referenced this pull request on Jun 2, 2026
For models trained with sliding-window attention (Mistral, Gemma-style local/global): a fixed-capacity ring buffer that overwrites the oldest slot on append, so decode runs at CONSTANT memory + per-step cost regardless of context length, losslessly (the model is trained to attend only within the window). Cheap because decode attention is ORDER-INVARIANT over keys (O = Σ softmax_j·V_j is unchanged under a (K,V) permutation), so the ring buffer never needs un-rotation — the consumer attends over the W physical slots as-is. Validated: holds the last-W as a set (incl. prefill chunk > window); windowed attention == ordered last-W attention (close, float summation order); memory bounded at W. Companion to the in-place cache (#2321) = 'in-place cache with a cap + wraparound'. 3 tests, fmt+clippy clean. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
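The two ideas in this commit message, overwrite-oldest appends and order-invariant attention over the stored keys, can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the code in the commit: `RingKv` and `attend` are hypothetical names, the cache holds one head as plain `Vec<f32>` rows, and tokens are appended one at a time.

```rust
/// Hypothetical fixed-capacity ring of (key, value) rows for one head.
/// Illustration only; this is not tract's API.
pub struct RingKv {
    pub keys: Vec<Vec<f32>>,
    pub values: Vec<Vec<f32>>,
    window: usize,
    next: usize, // physical slot the next append writes to
}

impl RingKv {
    pub fn new(window: usize) -> Self {
        assert!(window > 0);
        RingKv { keys: Vec::new(), values: Vec::new(), window, next: 0 }
    }

    /// Append one token. Once the ring is full, the oldest slot is overwritten,
    /// so memory stays bounded at `window` rows.
    pub fn push(&mut self, k: Vec<f32>, v: Vec<f32>) {
        if self.keys.len() < self.window {
            self.keys.push(k);
            self.values.push(v);
        } else {
            self.keys[self.next] = k;
            self.values[self.next] = v;
        }
        self.next = (self.next + 1) % self.window;
    }
}

/// Single-query softmax attention over whatever order the slots are in.
/// The output is a weighted sum over (key, value) pairs, so permuting the
/// pairs changes only the floating-point summation order.
pub fn attend(q: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32> {
    assert!(!keys.is_empty() && keys.len() == values.len());
    let scale = 1.0 / (q.len() as f32).sqrt();
    let scores: Vec<f32> = keys
        .iter()
        .map(|k| scale * k.iter().zip(q).map(|(a, b)| a * b).sum::<f32>())
        .collect();
    // Subtract the maximum score before exponentiating for numerical stability.
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let weights: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let denom: f32 = weights.iter().sum();
    let mut out = vec![0.0f32; values[0].len()];
    for (w, v) in weights.iter().zip(values) {
        for (o, x) in out.iter_mut().zip(v) {
            *o += w / denom * x;
        }
    }
    out
}
```

After five pushes into a window of three, the physical slots hold tokens 3, 4, 2 in that order; attending over them as-is should agree with attending over the ordered last three tokens up to rounding, which is why no un-rotation step is needed.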
…rite + NNEF + resume
tract's DynKeyValueCache grows by TypedConcat([past, new]) each step, copying the
whole t-token past into a fresh buffer -> O(T^2) total copy over a T-token decode.
Apple Core ML "stateful in-place KV" lever. Pieces:
1. InPlaceKvCache: geometric-growth in-place cache. Buffer with spare capacity along
`axis`, write each new chunk at the cursor (Tensor::assign_slice, strided-safe for
any axis), double only when capacity is exceeded -> O(T) amortized copy.
valid_view() exposes the live [0..len] region as a ZERO-COPY ndarray view (the path
that realizes the win). For the seq axis of [B,H,S,D] a per-head slice of the
capacity buffer is a contiguous prefix, so a consumer reads it at concat cost.
2. InPlaceKvSdpa: stateful fused op owning the K/V in-place caches, running the CPU
SDPA (FlashSdpaOp::flash_attention_gqa) over the zero-copy views. tract Tensors
cannot be zero-copy views ACROSS an op boundary (Tensor::slice copies), so keeping
the buffers inside the consuming op is what makes the saving real. Drop-in for
{kv_cache(K), kv_cache(V), Sdpa}; does GQA internally.
3. InPlaceKvSdpaTransform: rewrite pass that strips the GQA broadcast chain
(fuse_kv_cache_broadcast_rule) then fuses {cache(K), cache(V), Sdpa} -> InPlaceKvSdpa
so existing decode models adopt the in-place cache transparently.
4. NNEF ser/de: round-trips via tract_transformers_inplace_kv_sdpa (registered).
5. Resume: save_to/load_from checkpoint the cache as [K,V] tensors; freeze/unfreeze
snapshot the running state in-process. Both bit-exact resume; snapshot is O(len).
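Item 1 above (write at the cursor, double only on overflow, expose the live region as a borrowed view) can be sketched as follows. This is a simplified stand-in, not the PR's `InPlaceKvCache`: `GrowCache` is a hypothetical name, it stores a single head as a flat `Vec<f32>` of rows of width `d`, and it uses plain slices where the PR uses `Tensor::assign_slice` and an ndarray view.

```rust
/// Hypothetical stand-in for an in-place, geometrically growing cache.
/// One head, rows of width `d`, sequence axis outermost.
pub struct GrowCache {
    buf: Vec<f32>,       // capacity * d elements; only the first len * d are live
    d: usize,            // row width
    len: usize,          // live rows
    capacity: usize,     // allocated rows
    pub reallocs: usize, // number of times the buffer was regrown
    pub copied: usize,   // elements moved by regrowth (not by appends)
}

impl GrowCache {
    pub fn new(d: usize) -> Self {
        assert!(d > 0);
        GrowCache { buf: Vec::new(), d, len: 0, capacity: 0, reallocs: 0, copied: 0 }
    }

    /// Write `chunk` (a whole number of rows) at the cursor.
    /// Capacity doubles only when the chunk does not fit.
    pub fn append(&mut self, chunk: &[f32]) {
        assert_eq!(chunk.len() % self.d, 0);
        let needed = self.len + chunk.len() / self.d;
        if needed > self.capacity {
            let mut new_cap = self.capacity.max(1);
            while new_cap < needed {
                new_cap *= 2;
            }
            let live = self.len * self.d;
            let mut grown = vec![0.0f32; new_cap * self.d];
            grown[..live].copy_from_slice(&self.buf[..live]);
            self.copied += live;
            self.buf = grown;
            self.capacity = new_cap;
            self.reallocs += 1;
        }
        let start = self.len * self.d;
        self.buf[start..start + chunk.len()].copy_from_slice(chunk);
        self.len = needed;
    }

    /// Borrow the live rows without copying them.
    pub fn valid_view(&self) -> &[f32] {
        &self.buf[..self.len * self.d]
    }
}
```

The design point is that `valid_view` hands out a borrow rather than a fresh buffer. In this sketch the consumer must live where the borrow is valid, which mirrors the PR's reason for keeping the buffers inside the consuming op.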
Validated (11 tests): in-place bit-exact vs concat-grow; fused op matches concat-cache
+ FlashSdpaOp baseline (prefill+decode, GQA, causal/non-causal); runs end-to-end via a
persistent SimpleState; the rewrite fires + the rewritten model matches baseline; NNEF
round-trip; freeze/unfreeze and save/load resume bit-exact; growth amortized. fmt +
clippy clean; transformers lib 23/0 no-regression.
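The contiguous-prefix remark in item 1 can be checked with plain index arithmetic. The sketch below assumes a dense row-major buffer of shape `[B, H, C, D]`, where `C` is the allocated capacity along the sequence axis; `offset` and `head_prefix` are hypothetical helpers, not functions from the PR.

```rust
use std::ops::Range;

/// Flat row-major offset of element (b, h, s, d) in a [B, H, C, D] buffer.
pub fn offset(dims: [usize; 4], idx: [usize; 4]) -> usize {
    let [_, h_n, c_n, d_n] = dims;
    let [b, h, s, d] = idx;
    ((b * h_n + h) * c_n + s) * d_n + d
}

/// Flat range covered by the first `len` sequence positions of head (b, h).
/// It is one contiguous run of len * D elements starting at the head's base.
pub fn head_prefix(dims: [usize; 4], b: usize, h: usize, len: usize) -> Range<usize> {
    assert!(len <= dims[2]);
    let start = offset(dims, [b, h, 0, 0]);
    start..start + len * dims[3]
}
```

Across heads the live regions are separated by the unused capacity, so only the per-head slice is contiguous, not the whole `[B, H, len, D]` region.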
Benched (release, B=1 H=8 D=128):
- cache-update only: 21x (T=256) -> 709x (T=4096), O(T^2) -> O(T)
- end-to-end via the op: 1.10x (256) -> 1.63x (2048), 39% faster decode @2k
- resume checkpoint: O(len), 0.10ms (256) -> 1.76ms (4096), one-time
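The O(T^2) -> O(T) scaling of the cache update can be seen by counting rows copied. The count below is a simplified model, assuming one token per step, a concat step that writes all t rows into a fresh buffer, and in-place capacity that doubles starting from 1. It counts rows only, so it is not expected to reproduce the measured speedups above.

```latex
\begin{align*}
C_{\text{concat}}(T) &= \sum_{t=1}^{T} t = \frac{T(T+1)}{2} = O(T^2), \\
C_{\text{in-place}}(T) &= \underbrace{T}_{\text{appends}}
  + \underbrace{\sum_{k=0}^{\lceil \log_2 T \rceil - 1} 2^k}_{\text{growth copies}}
  = T + 2^{\lceil \log_2 T \rceil} - 1 < 3T = O(T).
\end{align*}
```

The bound uses $2^{\lceil \log_2 T \rceil} < 2T$: each doubling copies the rows that were live when the buffer filled, and those copies form a geometric series.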
Follow-up: GPU coupling (sonos#2320 MFA kernel reading capacity buffer + length).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
eb63080 to
8837149
Compare
@kali, I'm opening this as much as an RFC as a PR: it adds an opt-in in-place KV-cache path for decode, and since it sits right next to your `DynKeyValueCache` + `freeze_into` work I'd like your read on the shape before it's merge-bound. It's additive and opt-in; no default behavior changes.

Independent of #2319 / #2320 (not stacked)

This branches off `main` and can be reviewed and merged on its own, in any order:

`flash_sdpa.rs`; this only adds `inplace_kv_cache.rs` plus two wiring lines (`ops/mod.rs`, `lib.rs`). No conflict. The fused op reuses the existing `FlashSdpaOp::flash_attention_gqa` consumer that #2319 (transformers: CPU FlashSdpa, contiguous P·V GEMM + head-parallel exec + seq-len lowering heuristic) happens to speed up, so the two are synergistic but not coupled: this works on the current `main` consumer regardless of #2319.

The problem
`DynKeyValueCache` grows the cache with `TypedConcat([past, new])` every step, so step `t` copies the whole `t`-token past into a fresh buffer: O(T²) total copy over a `T`-token decode (plus an allocation per step). The attention compute is already O(T²); this is pure cache-management overhead on top.

The design choice (the RFC bit)
The catch: in-place growth only pays off if the consumer reads the cache without re-copying. A naive "preallocate + write-at-cursor, then slice `[0..len]` for the consumer" is a wash: `Tensor::slice` copies, so the per-step slice-to-valid reintroduces the O(T²). And tract `Tensor`s can't be zero-copy views across an op boundary.

So I kept the K/V buffers inside the op that consumes them: a stateful fused op (`InPlaceKvSdpa`) that owns two in-place caches and runs the existing CPU SDPA (`FlashSdpaOp::flash_attention_gqa`) over zero-copy `[0..len]` views. It's a drop-in for the `{kv_cache(K), kv_cache(V), Sdpa}` subgraph, with the same output by construction (same kernel, same K/V), and because it does GQA internally, fusing also removes the unsqueeze/broadcast/reshape chain.

The question for you: is fusing cache+attention the direction you want, or would you rather (a) make `DynKeyValueCache` itself in-place + a length-aware `Sdpa` reading `buffer + length`, or (b) a core-level zero-copy sub-tensor (`Tensor` = `Arc` buffer + offset/len) so the cache can output a view (more general, bigger change)? I built the fused op because it needs no core changes and is provably equivalent; happy to redirect.

What's in the PR
- `InPlaceKvCache`: geometric-growth (Vec-style doubling) in-place cache; appends via `Tensor::assign_slice` (strided-safe for any axis); `valid_view()` is a zero-copy ndarray view of `[0..len]`. O(T) amortized copy. (For the seq axis of `[B,H,S,D]` a per-head slice of the capacity buffer is a contiguous prefix, so the consumer reads it at concat cost.)
- `InPlaceKvSdpa`: the stateful fused op (`Op`/`EvalOp`/`TypedOp` + `OpState` + `OpStateFreeze`).
- `InPlaceKvSdpaTransform`: an opt-in `ModelTransform` (like `KeyValueCacheTransform`) that strips the GQA broadcast chain (reusing your `fuse_kv_cache_broadcast_rule`) then fuses `cache → Sdpa` into `InPlaceKvSdpa`. Apply it to get in-place decode; don't, and nothing changes.
- `save_to`/`load_from` checkpoint the cache as `[K,V]`; `freeze`/`unfreeze` snapshot the running state (extends your `freeze_into`).

Numbers (Apple M-series, f32, B=1 H=8 D=128)

[Table not recovered; the last row label ends in `save_to)`.]

These are op-level microbenches on synthetic attention. I can add a real decode-model wall-clock A/B (transform on/off, M1/M4) if you'd want that for the merge bar.
Correctness
`InPlaceKvCache` bit-exact vs concat-grow (multi-axis + decode); the fused op matches the `concat-cache + FlashSdpaOp` baseline over prefill+decode, GQA, causal/non-causal; runs end-to-end through a persistent `SimpleState`; the rewrite fires + the rewritten model matches the baseline; NNEF round-trip; freeze/unfreeze and save/load resume bit-exact; growth amortized (≤12 reallocs / 1024 pushes). `cargo build --workspace` clean; `tract-transformers` + blast-radius suite green; fmt + clippy clean.

Apple research & prior art
`MLState`/`StateType`, in-place mutable cache/recurrent tensors).
`kv_cache.rs`; llama.cpp `llama_memory_i` / past-present share-buffer; ONNX Runtime `past_present_share_buffer`.

Related
`C` for both the key-loop bound and the K address stride, and K is `[H,D,C]` so the valid prefix is strided; I validated this on-GPU). In-place K on Metal would need an owned `.metal` port, as noted on #2320 (metal: fused Sdpa via the vendored MetalFlashAttention kernel (~2×)). This PR is the CPU/cross-backend path available today.

🤖 Generated with Claude Code