transformers: KIVI-style KV-cache quantization — packed u8 storage, ~4× memory vs f32#2329
Open
czoli1976 wants to merge 3 commits into
…e bits)
Training-free affine quantize<->dequantize for the KV cache: keep every token but at fewer bits (configurable, 1..16). Keys per-CHANNEL (outlier channels get their own scale), Values per-TOKEN (KIVI, Liu et al. 2024). Gentler than evicting; works for any model. (CommVQ's RoPE-commutative codebook is a fancier follow-on.)
Validated: round-trip error <= scale/2 and shrinks with bits; per-channel >> per-token on outlier channels; 8-bit near-lossless for attention output. Real GPT-2 (harness/kv_quant_real.py): int8 ~0.5% attention deviation (near-lossless, 2x mem), graceful to int2; int4 per-channel-K beats per-token-K 1.75-1.9x on early layers. Memory = bits/16 of the f16 cache (int8 2x, int4 4x, int2 8x). 3 tests, fmt+clippy clean.
Follow-on: packed-int storage + a quantized KV-cache op (dequant-on-attend), composing with the in-place (sonos#2321) / sliding-window (sonos#2327) caches; CommVQ codebook variant.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…ansform
Completes the KIVI-style KV-cache quantization integration:
1. QuantKeyCache: per-channel u8 storage for Keys. D channels each have a running
scale; new tokens quantized under the current channel scale. Memory: T*D + D*8 bytes.
2. QuantValueCache: per-token u8 storage for Values. Each token D bytes + 2 f32 params.
Memory: T*D + T*8 bytes (~4x vs f32 at large D).
3. QuantizedKvSdpa: stateful fused op (Op/EvalOp/TypedOp + OpState + freeze) that
stores K/V in packed u8, dequantizes per-head on each decode step, attends via
FlashSdpaOp (GQA handled). Real u8 bytes, not just float round-trip quality test.
4. QuantizedKvSdpaTransform: auto-wires {cache(K), cache(V), Sdpa} -> QuantizedKvSdpa.
6 tests: quant quality (3 existing) + packed_u8_saves_memory_vs_f32 (>3x saving) +
quantized_kv_sdpa_runs_in_model (engine correctness: near-lossless vs f32 reference) +
transform_fuses_cache_sdpa_to_quantized (structural auto-wiring). fmt+clippy clean,
transformers 18/0 no regression.
Configurable via the bits parameter (1..=16); int8 = near-lossless 4x vs f32 / 2x vs
f16. CommVQ codebook variant is the follow-on.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
tract_transformers_quantized_kv_sdpa primitive: axis + optional scale.
Round-trip test: axis and scale survive write_to_tar -> model_for_read.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@kali is this an area of interest? I was even thinking of SSD offload, but I'm not sure whether that goes too far and should be managed outside of tract.
czoli1976 pushed a commit to czoli1976/tract that referenced this pull request on Jun 5, 2026
Training-free KV-cache quantization: store K and V in packed u8 bytes instead of f32, keeping every token at ~4× less memory — the "keep everything, write in shorthand" alternative to eviction.
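The ~4× figure can be checked from the byte counts this PR gives for the two caches (T×D + T×8 for per-token Values, T×D + D×8 for per-channel Keys). The following is a minimal sketch of that arithmetic only; the function names are hypothetical and not part of the PR, and the two f32 parameters per group are assumed to be a scale and an offset.

```rust
/// Bytes for an f32 cache of `t` tokens with `d` channels each.
fn f32_cache_bytes(t: usize, d: usize) -> usize {
    t * d * 4
}

/// Per-token u8 cache: one byte per element plus two f32 parameters
/// per token (assumed here to be a scale and an offset).
fn per_token_u8_bytes(t: usize, d: usize) -> usize {
    t * d + t * 8
}

/// Per-channel u8 cache: one byte per element plus two f32 parameters
/// per channel.
fn per_channel_u8_bytes(t: usize, d: usize) -> usize {
    t * d + d * 8
}

/// How many times smaller the quantized cache is than the f32 cache.
fn saving(f32_bytes: usize, quant_bytes: usize) -> f64 {
    f32_bytes as f64 / quant_bytes as f64
}
```

For the per-token layout the ratio works out to 4D / (D + 8), so it approaches 4× only as D grows; for the per-channel layout the parameter overhead is amortized over tokens instead, so it approaches 4× as T grows.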
The idea
The core asymmetry (Liu et al. 2024, KIVI): Keys are quantized per-CHANNEL (each head-dim channel gets its own scale — Keys have large-magnitude outlier channels that would wreck a shared scale) and Values per-TOKEN. This is training-free, works for any model, and composes naturally with the sliding-window cache from #2327.
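To make "gets its own scale" concrete, here is a minimal sketch of affine quantization for a single group of values. It is an illustration under stated assumptions, not the code in this PR: `Affine`, `quantize` and `dequantize` are hypothetical names, the parameters are a min/scale pair, and the bit width is limited to 1..=8 so that each code fits in a u8. For Values a group would be one token's row; for Keys a group would be one channel's column.

```rust
/// Affine parameters for one group (one token row or one channel column).
#[derive(Debug, Clone, Copy)]
struct Affine {
    scale: f32,
    min: f32,
}

/// Quantize one group to unsigned codes of `bits` bits.
/// Restricted to 1..=8 bits in this sketch so a code fits in a u8.
fn quantize(group: &[f32], bits: u32) -> (Vec<u8>, Affine) {
    assert!((1..=8).contains(&bits));
    let levels = ((1u32 << bits) - 1) as f32;
    let min = group.iter().copied().fold(f32::INFINITY, f32::min);
    let max = group.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    // A constant (or empty) group has no range; use scale 1 to avoid dividing by zero.
    let scale = if max > min { (max - min) / levels } else { 1.0 };
    let codes = group
        .iter()
        .map(|&x| ((x - min) / scale).round().clamp(0.0, levels) as u8)
        .collect();
    (codes, Affine { scale, min })
}

/// Map codes back to approximate f32 values.
fn dequantize(codes: &[u8], p: Affine) -> Vec<f32> {
    codes.iter().map(|&q| q as f32 * p.scale + p.min).collect()
}
```

Because every value is rounded to the nearest of the evenly spaced levels between the group's min and max, the round-trip error of each element is at most half a step (scale/2), and the step shrinks as `bits` grows.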
Validated on real GPT-2 K/V activations:
The per-channel-K layout matters: int4 per-channel-K is 1.75–1.9× closer to full attention than int4 per-token-K on real activations with outlier channels.
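The effect of the grouping choice can be seen on a toy example. The sketch below builds a synthetic Key matrix with one large-magnitude channel and measures the mean round-trip error under both groupings. Everything here is hypothetical (helper names, the toy data, min/scale parameters), and it measures element-wise error on synthetic values, so it does not reproduce the 1.75–1.9× attention-level figure from the GPT-2 run.

```rust
/// Quantize a group to `bits`-bit levels with its own min/scale, then dequantize.
fn round_trip(group: &[f32], bits: u32) -> Vec<f32> {
    let levels = ((1u32 << bits) - 1) as f32;
    let min = group.iter().copied().fold(f32::INFINITY, f32::min);
    let max = group.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let scale = if max > min { (max - min) / levels } else { 1.0 };
    group
        .iter()
        .map(|&x| ((x - min) / scale).round().clamp(0.0, levels) * scale + min)
        .collect()
}

/// Mean absolute round-trip error of a row-major `t x d` matrix,
/// grouping either by channel (column) or by token (row).
fn mean_abs_error(k: &[f32], t: usize, d: usize, bits: u32, per_channel: bool) -> f32 {
    assert!(d > 0);
    assert_eq!(k.len(), t * d);
    let mut err = 0.0f32;
    if per_channel {
        for c in 0..d {
            let col: Vec<f32> = (0..t).map(|r| k[r * d + c]).collect();
            let back = round_trip(&col, bits);
            err += col.iter().zip(&back).map(|(a, b)| (a - b).abs()).sum::<f32>();
        }
    } else {
        for row in k.chunks(d) {
            let back = round_trip(row, bits);
            err += row.iter().zip(&back).map(|(a, b)| (a - b).abs()).sum::<f32>();
        }
    }
    err / (t * d) as f32
}

/// Toy keys: small values everywhere, except channel 0, which is a large outlier.
fn toy_keys(t: usize, d: usize) -> Vec<f32> {
    (0..t * d)
        .map(|i| {
            let (r, c) = (i / d, i % d);
            let small = ((r * 7 + c * 3) % 11) as f32 * 0.1 - 0.5;
            if c == 0 { 50.0 + small } else { small }
        })
        .collect()
}
```

Calling `mean_abs_error(&toy_keys(16, 8), 16, 8, 4, true)` and the same with `false` compares the two groupings at int4. With per-token grouping, the outlier in channel 0 stretches every row's range, so the small channels collapse onto a single level; with per-channel grouping, the outlier only affects its own channel's scale.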
What's in the PR
- QuantValueCache: per-token u8 storage. Each token takes D bytes plus 2 f32 params. Memory: T×D + T×8 bytes.
- QuantKeyCache: per-channel u8 storage with a running scale per channel, updated on each new token. Memory: T×D + D×8 bytes.
- QuantizedKvSdpa: stateful fused op that owns the K/V packed caches, dequantizes per-head on each decode step, and attends via FlashSdpaOp (GQA handled). Inputs are [Q, K_new, V_new]; the output has Q's shape.
- QuantizedKvSdpaTransform: auto-wires {DynKeyValueCache(K), DynKeyValueCache(V), Sdpa} → QuantizedKvSdpa, so existing decode models adopt quantized storage transparently (mirrors the pattern from #2321, the in-place KV cache for decode via a fused InPlaceKvSdpa op, and #2327, sliding-window attention).
- tract_transformers_quantized_kv_sdpa, registered.
Correctness & gates
7 tests:
- quality validation (round-trip bounded, per-channel beats per-token on outlier channels, 8-bit near-lossless for attention);
- packed_u8_saves_memory_vs_f32 (>3× measured);
- quantized_kv_sdpa_runs_in_model (runs through the engine, near-lossless vs f32 reference);
- transform_fuses_cache_sdpa_to_quantized (structural auto-wiring);
- NNEF round-trip.
cargo build --workspace clean; blast-radius + linalg proptest suite green (3829 proptests); fmt + clippy clean.
Relationship to other PRs
~33 MB → ~8 MB KV.
Research & prior art
🤖 Generated with Claude Code