Skip to content

server: add KV cache quantization flags (--kv-bits, --kv-group-size, --quantized-kv-start)#1353

Open
soobrosa wants to merge 3 commits into
ml-explore:mainfrom
soobrosa:feat/server-kv-cache-quant
Open

server: add KV cache quantization flags (--kv-bits, --kv-group-size, --quantized-kv-start)#1353
soobrosa wants to merge 3 commits into
ml-explore:mainfrom
soobrosa:feat/server-kv-cache-quant

Conversation

@soobrosa

@soobrosa soobrosa commented Jun 5, 2026

Copy link
Copy Markdown

Add KV cache quantization flags to mlx_lm.server

Closes #1043 (see also #1308).

mlx_lm.generate has supported quantized KV cache for a while, but mlx_lm.server did
not expose it, so long-context serving was stuck at fp16 KV — which OOMs well before the
model's real context limit on memory-bound Apple Silicon.

What this does

Adds three CLI flags to mlx_lm.server:

  • --kv-bits (default None → no quantization)
  • --kv-group-size (default 64)
  • --quantized-kv-start (default 5000)

These are threaded straight into the existing stream_generate path, reusing the same
quantized-KV implementation mlx_lm.generate already uses.

Design note (batched path)

BatchKVCache has no quantized variant yet, so when --kv-bits is set, _is_batchable()
returns False and requests are served through the single-sequence generator (which does
support quantized KV). This trades batched throughput for correctness when quantization is
explicitly requested, and is forward-compatible with the batched-quantized work in #1322.

Validation

Tested against Qwen3.6-35B-A3B (hybrid GDN + MoE, MLX 4-bit) served with
--kv-bits 4 --kv-group-size 64:

  • short-prompt correctness: OK
  • long-context fact recall: a 22.3K-token prompt (facts up front, question after ~16K
    filler — past quantized_kv_start) returned the correct answer, no crash on the hybrid
    cache layout.

vs #1309

This PR additionally exposes --kv-group-size / --quantized-kv-start and routes through
the proven stream_generate path rather than building caches manually.

@soobrosa soobrosa marked this pull request as ready for review June 6, 2026 15:47
@nastya236 nastya236 added the enhancement New feature or request label Jun 7, 2026
soobrosa pushed a commit to soobrosa/dataroom that referenced this pull request Jun 7, 2026
Detects --kv-bits in mlx_lm.server (ml-explore/mlx-lm#1353); when present,
serves with --kv-bits 4 --kv-group-size 64 and raises MLX_CTX_CAP 75K->85K
(verified greedy-lossless + long-ctx recall). Falls back to fp16 KV + 75K cap
on stock mlx-lm, so the change is safe before #1353 ships.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Daniel Molnar and others added 2 commits June 7, 2026 19:55
Wire KV cache quantization (already supported by generate_step) into the
server, addressing ml-explore#1043. Because the batched path's BatchKVCache has no
quantized variant yet, enabling --kv-bits routes requests through the
single-sequence stream_generate path, which does support quantized KV.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Wire make_prompt_cache(max_kv_size=...) through ModelProvider so the server can
bound the per-request KV cache (RotatingKVCache eviction) for standard models.
Folds in the --max-kv-size flag from the closed PR ml-explore#1362 (thanks @0xSoftBoi),
making this PR a superset. Note: models providing their own make_cache (some
hybrid architectures) ignore max_kv_size by design.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
@soobrosa soobrosa force-pushed the feat/server-kv-cache-quant branch from 290c282 to 059a29c Compare June 7, 2026 18:04
@sunpazed

sunpazed commented Jun 9, 2026

Copy link
Copy Markdown

I noticed an issue with this PR, where RotatingKVCache quantization causes crash in hybrid cache models (ie; Gemma 4). When using models with hybrid caching strategies (like Gemma 4) that include RotatingKVCache instances, the server crashes with NotImplementedError: RotatingKVCache Quantization NYI if KV cache quantization is enabled.

mlx_lm.server --model mlx-community/gemma-4-26B-A4B-it-qat-4bit --kv-bits 8

I believe that the Gemma 4 model architecture uses a hybrid cache strategy, ie; some layers use regular KVCache() (full attention), and some layers use RotatingKVCache() with a sliding window.

@soobrosa

Copy link
Copy Markdown
Author

@sunpazed this looks better to me now

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add KV cache quantization support to server

3 participants