From 323bf428a47435a5438448f6ad90766fd87582c8 Mon Sep 17 00:00:00 2001 From: Daniel Molnar Date: Fri, 5 Jun 2026 18:40:28 +0200 Subject: [PATCH 1/3] server: add --kv-bits / --kv-group-size / --quantized-kv-start Wire KV cache quantization (already supported by generate_step) into the server, addressing #1043. Because the batched path's BatchKVCache has no quantized variant yet, enabling --kv-bits routes requests through the single-sequence stream_generate path, which does support quantized KV. Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com> --- mlx_lm/server.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/mlx_lm/server.py b/mlx_lm/server.py index ce8d95817..71d835cb1 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -683,6 +683,11 @@ def _make_state_machine( return sm, sequences def _is_batchable(self, args): + # The batched generator uses BatchKVCache, which has no quantized + # variant yet, so route through the single-sequence path (which does + # support KV quantization) whenever --kv-bits is requested. + if self.cli_args.kv_bits is not None: + return False return self.model_provider.is_batchable and args.seed is None def _generate(self): @@ -985,6 +990,9 @@ def progress(tokens_processed, tokens_total): num_draft_tokens=args.num_draft_tokens, prompt_progress_callback=progress, prefill_step_size=self.cli_args.prefill_step_size, + kv_bits=self.cli_args.kv_bits, + kv_group_size=self.cli_args.kv_group_size, + quantized_kv_start=self.cli_args.quantized_kv_start, ): finish_reason = gen.finish_reason sm_state, match_sequence, current_state = sm.match(sm_state, gen.token) @@ -1884,6 +1892,27 @@ def main(): action="store_true", help="Use pipelining instead of tensor parallelism", ) + parser.add_argument( + "--kv-bits", + type=int, + default=None, + help="Number of bits for KV cache quantization. Defaults to no " + "quantization. Note: enabling this serves requests sequentially " + "(the batched path has no quantized cache yet).", + ) + parser.add_argument( + "--kv-group-size", + type=int, + default=64, + help="Group size for KV cache quantization (default: 64).", + ) + parser.add_argument( + "--quantized-kv-start", + type=int, + default=5000, + help="When --kv-bits is set, begin quantizing the KV cache after " + "this many tokens (default: 5000).", + ) args = parser.parse_args() if mx.metal.is_available(): wired_limit = mx.device_info()["max_recommended_working_set_size"] From 059a29c9308c33364ba28f37b359baee3345c30c Mon Sep 17 00:00:00 2001 From: Daniel Molnar Date: Sun, 7 Jun 2026 19:55:06 +0200 Subject: [PATCH 2/3] server: add --max-kv-size flag Wire make_prompt_cache(max_kv_size=...) through ModelProvider so the server can bound the per-request KV cache (RotatingKVCache eviction) for standard models. Folds in the --max-kv-size flag from the closed PR #1362 (thanks @0xSoftBoi), making this PR a superset. Note: models providing their own make_cache (some hybrid architectures) ignore max_kv_size by design. Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com> --- mlx_lm/server.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/mlx_lm/server.py b/mlx_lm/server.py index 71d835cb1..1b63cada7 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -973,9 +973,15 @@ def progress(tokens_processed, tokens_total): ctx.prompt_cache_count = len(prompt) - len(rest) cache_key = prompt[:] if cache is None: - cache = make_prompt_cache(self.model_provider.model) + cache = make_prompt_cache( + self.model_provider.model, + max_kv_size=self.cli_args.max_kv_size, + ) if self.model_provider.draft_model is not None: - cache += make_prompt_cache(self.model_provider.draft_model) + cache += make_prompt_cache( + self.model_provider.draft_model, + max_kv_size=self.cli_args.max_kv_size, + ) # Process the prompt and generate tokens for gen in stream_generate( @@ -1913,6 +1919,14 @@ def main(): help="When --kv-bits is set, begin quantizing the KV cache after " "this many tokens (default: 5000).", ) + parser.add_argument( + "--max-kv-size", + type=int, + default=None, + help="Maximum KV cache size in tokens; older entries are evicted " + "beyond this limit (uses a RotatingKVCache). Ignored for models that " + "provide their own make_cache (e.g. some hybrid architectures).", + ) args = parser.parse_args() if mx.metal.is_available(): wired_limit = mx.device_info()["max_recommended_working_set_size"] From 6b55bb112a79d5c7132fbd2acdeb7c9e54107b2e Mon Sep 17 00:00:00 2001 From: Daniel Molnar Date: Thu, 11 Jun 2026 15:53:40 +0200 Subject: [PATCH 3/3] generate: skip RotatingKVCache when quantizing hybrid KV caches --- mlx_lm/generate.py | 13 +++++++++++-- tests/test_generate.py | 22 +++++++++++++++++++++- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 3573b2640..f2b0b0368 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -300,8 +300,17 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_ if kv_bits is None: return for e, c in enumerate(prompt_cache): - if hasattr(c, "to_quantized") and c.offset >= quantized_kv_start: - prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits) + if isinstance(c, QuantizedKVCache) or not hasattr(c, "to_quantized"): + continue + if c.offset >= quantized_kv_start: + try: + prompt_cache[e] = c.to_quantized( + group_size=kv_group_size, bits=kv_bits + ) + except NotImplementedError: + # e.g. RotatingKVCache used by hybrid/sliding-window models has + # no quantized variant yet; leave that layer unquantized. + pass def generate_step( diff --git a/tests/test_generate.py b/tests/test_generate.py index 4f5bb4c91..96f339332 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -13,9 +13,10 @@ batch_generate, generate, generate_step, + maybe_quantize_kv_cache, stream_generate, ) -from mlx_lm.models.cache import KVCache, RotatingKVCache +from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache from mlx_lm.sample_utils import make_logits_processors, make_sampler from mlx_lm.utils import load @@ -806,6 +807,25 @@ def test_batch_max_kv_size_none_creates_regular_cache(self): for cache in r.prompt_cache: self.assertIsInstance(cache, KVCache) + def test_maybe_quantize_kv_cache_hybrid(self): + # Hybrid cache (full-attention + sliding-window), as used by models + # like Gemma. RotatingKVCache has no quantized variant, so it must be + # skipped rather than crashing with NotImplementedError. + full = KVCache() + rotating = RotatingKVCache(max_size=8) + keys = mx.zeros((1, 1, 4, 64)) + values = mx.zeros((1, 1, 4, 64)) + full.update_and_fetch(keys, values) + rotating.update_and_fetch(keys, values) + + prompt_cache = [full, rotating] + maybe_quantize_kv_cache( + prompt_cache, quantized_kv_start=0, kv_group_size=32, kv_bits=4 + ) + + self.assertIsInstance(prompt_cache[0], QuantizedKVCache) + self.assertIsInstance(prompt_cache[1], RotatingKVCache) + if __name__ == "__main__": unittest.main()