Skip to content

transformers: in-place KV cache for decode via a fused InPlaceKvSdpa op#2321

Open
czoli1976 wants to merge 1 commit into
sonos:mainfrom
czoli1976:feature/inplace-kv-cache
Open

transformers: in-place KV cache for decode via a fused InPlaceKvSdpa op#2321
czoli1976 wants to merge 1 commit into
sonos:mainfrom
czoli1976:feature/inplace-kv-cache

Conversation

@czoli1976
Copy link
Copy Markdown
Contributor

@czoli1976 czoli1976 commented May 31, 2026

@kali — opening this as much as an RFC as a PR: it adds an opt-in in-place KV-cache path for decode, and since it sits right next to your DynKeyValueCache + freeze_into work I'd like your read on the shape before it's merge-bound. It's additive and opt-in — no default behavior changes.

Independent of #2319 / #2320 — not stacked

This branches off main and can be reviewed and merged on its own, in any order:

The problem

DynKeyValueCache grows the cache with TypedConcat([past, new]) every step, so step t copies the whole t-token past into a fresh buffer — O(T²) total copy over a T-token decode (plus an allocation per step). The attention compute is already O(T²); this is pure cache-management overhead on top.

The design choice (the RFC bit)

The catch: in-place growth only pays off if the consumer reads the cache without re-copying. A naive "preallocate + write-at-cursor, then slice [0..len] for the consumer" is a wash — Tensor::slice copies, so the per-step slice-to-valid reintroduces the O(T²). And tract Tensors can't be zero-copy views across an op boundary.

So I kept the K/V buffers inside the op that consumes them: a stateful fused op (InPlaceKvSdpa) that owns two in-place caches and runs the existing CPU SDPA (FlashSdpaOp::flash_attention_gqa) over zero-copy [0..len] views. It's a drop-in for the {kv_cache(K), kv_cache(V), Sdpa} subgraph — same output by construction (same kernel, same K/V) — and because it does GQA internally, fusing also removes the unsqueeze/broadcast/reshape chain.

The question for you: is fusing cache+attention the direction you want, or would you rather (a) make DynKeyValueCache itself in-place + a length-aware Sdpa reading buffer + length, or (b) a core-level zero-copy sub-tensor (Tensor = Arc buffer + offset/len) so the cache can output a view — more general, bigger change? I built (the fused op) because it needs no core changes and is provably equivalent; happy to redirect.

What's in the PR

  • InPlaceKvCache — geometric-growth (Vec-style doubling) in-place cache; appends via Tensor::assign_slice (strided-safe for any axis); valid_view() is a zero-copy ndarray view of [0..len]. O(T) amortized copy. (For the seq axis of [B,H,S,D] a per-head slice of the capacity buffer is a contiguous prefix, so the consumer reads it at concat cost.)
  • InPlaceKvSdpa — the stateful fused op (Op/EvalOp/TypedOp + OpState + OpStateFreeze).
  • InPlaceKvSdpaTransform — an opt-in ModelTransform (like KeyValueCacheTransform) that strips the GQA broadcast chain (reusing your fuse_kv_cache_broadcast_rule) then fuses cache → Sdpa into InPlaceKvSdpa. Apply it to get in-place decode; don't, and nothing changes.
  • NNEF ser/de and resume (save_to/load_from checkpoint the cache as [K,V]; freeze/unfreeze snapshot the running state — extends your freeze_into).

Numbers (Apple M-series, f32, B=1 H=8 D=128)

measurement result
end-to-end decode through the op (update + attention), the representative figure 1.10× (T=256) → 1.63× (T=2048), grows with T
cache-update only (the asymptotic mechanism) 21× (256) → 709× (4096), O(T²) → O(T)
resume checkpoint (save_to) O(len), 0.10 ms (256) → 1.76 ms (4096), one-time

These are op-level microbenches on synthetic attention — I can add a real decode-model wall-clock A/B (transform on/off, M1/M4) if you'd want that for the merge bar.

Note on framing — this is not Apple's reported ~13× (1.25 → 16.3 tok/s), and the difference is a baseline mismatch, not an apples-to-apples KV-cache delta:

  • Apple's 13× is the aggregate of their whole Core ML deployment (int4 palettization + fused attention + stateful KV + ANE execution), measured against a much weaker starting point. Critically, their baseline had no effective cache, so the 13× folds in eliminating O(T²) recompute of K/V (and the quantization win) — not just cache copies.
  • tract already caches (no recompute), so the only thing left to remove is the cache-management copy. This PR's honest contribution is narrowly O(T²) → O(T) on that copy — the 1.1–1.6× end-to-end above. Quantization/fused-attention gains are orthogonal (and on Metal already land via metal: fused Sdpa via the vendored MetalFlashAttention kernel (~2×) #2320).

Correctness

InPlaceKvCache bit-exact vs concat-grow (multi-axis + decode); the fused op matches the concat-cache + FlashSdpaOp baseline over prefill+decode, GQA, causal/non-causal; runs end-to-end through a persistent SimpleState; the rewrite fires + the rewritten model matches the baseline; NNEF round-trip; freeze/unfreeze and save/load resume bit-exact; growth amortized (≤12 reallocs / 1024 pushes). cargo build --workspace clean; tract-transformers + blast-radius suite green; fmt + clippy clean.

Apple research & prior art

  • Motivation — Apple Core ML stateful in-place KV cache (Llama 3.1 on Core ML: MLState/StateType, in-place mutable cache/recurrent tensors).
  • Prior art for the cache mechanism: candle-nn kv_cache.rs; llama.cpp llama_memory_i / past-present share-buffer; ONNX Runtime past_present_share_buffer.

Related

🤖 Generated with Claude Code

@czoli1976
Copy link
Copy Markdown
Contributor Author

@kali look at me first

kali pushed a commit that referenced this pull request Jun 2, 2026
For models trained with sliding-window attention (Mistral, Gemma-style local/global):
a fixed-capacity ring buffer that overwrites the oldest slot on append, so decode runs
at CONSTANT memory + per-step cost regardless of context length, losslessly (the model
is trained to attend only within the window).

Cheap because decode attention is ORDER-INVARIANT over keys (O = Σ softmax_j·V_j is
unchanged under a (K,V) permutation), so the ring buffer never needs un-rotation — the
consumer attends over the W physical slots as-is. Validated: holds the last-W as a set
(incl. prefill chunk > window); windowed attention == ordered last-W attention (close,
float summation order); memory bounded at W. Companion to the in-place cache (#2321) =
'in-place cache with a cap + wraparound'. 3 tests, fmt+clippy clean.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…rite + NNEF + resume

tract's DynKeyValueCache grows by TypedConcat([past, new]) each step, copying the
whole t-token past into a fresh buffer -> O(T^2) total copy over a T-token decode.
Apple Core ML "stateful in-place KV" lever. Pieces:

1. InPlaceKvCache: geometric-growth in-place cache. Buffer with spare capacity along
   `axis`, write each new chunk at the cursor (Tensor::assign_slice, strided-safe for
   any axis), double only when capacity is exceeded -> O(T) amortized copy.
   valid_view() exposes the live [0..len] region as a ZERO-COPY ndarray view (the path
   that realizes the win). For the seq axis of [B,H,S,D] a per-head slice of the
   capacity buffer is a contiguous prefix, so a consumer reads it at concat cost.

2. InPlaceKvSdpa: stateful fused op owning the K/V in-place caches, running the CPU
   SDPA (FlashSdpaOp::flash_attention_gqa) over the zero-copy views. tract Tensors
   cannot be zero-copy views ACROSS an op boundary (Tensor::slice copies), so keeping
   the buffers inside the consuming op is what makes the saving real. Drop-in for
   {kv_cache(K), kv_cache(V), Sdpa}; does GQA internally.

3. InPlaceKvSdpaTransform: rewrite pass that strips the GQA broadcast chain
   (fuse_kv_cache_broadcast_rule) then fuses {cache(K), cache(V), Sdpa} -> InPlaceKvSdpa
   so existing decode models adopt the in-place cache transparently.

4. NNEF ser/de: round-trips via tract_transformers_inplace_kv_sdpa (registered).

5. Resume: save_to/load_from checkpoint the cache as [K,V] tensors; freeze/unfreeze
   snapshot the running state in-process. Both bit-exact resume; snapshot is O(len).

Validated (11 tests): in-place bit-exact vs concat-grow; fused op matches concat-cache
+ FlashSdpaOp baseline (prefill+decode, GQA, causal/non-causal); runs end-to-end via a
persistent SimpleState; the rewrite fires + the rewritten model matches baseline; NNEF
round-trip; freeze/unfreeze and save/load resume bit-exact; growth amortized. fmt +
clippy clean; transformers lib 23/0 no-regression.

Benched (release, B=1 H=8 D=128):
  - cache-update only:      21x (T=256) -> 709x (T=4096), O(T^2) -> O(T)
  - end-to-end via the op:  1.10x (256) -> 1.63x (2048), 39% faster decode @2k
  - resume checkpoint:      O(len), 0.10ms (256) -> 1.76ms (4096), one-time

Follow-up: GPU coupling (sonos#2320 MFA kernel reading capacity buffer + length).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@czoli1976 czoli1976 force-pushed the feature/inplace-kv-cache branch from eb63080 to 8837149 Compare June 3, 2026 19:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant