Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions mlx_lm/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,17 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
if kv_bits is None:
return
for e, c in enumerate(prompt_cache):
if hasattr(c, "to_quantized") and c.offset >= quantized_kv_start:
prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)
if isinstance(c, QuantizedKVCache) or not hasattr(c, "to_quantized"):
continue
if c.offset >= quantized_kv_start:
try:
prompt_cache[e] = c.to_quantized(
group_size=kv_group_size, bits=kv_bits
)
except NotImplementedError:
# e.g. RotatingKVCache used by hybrid/sliding-window models has
# no quantized variant yet; leave that layer unquantized.
pass


def generate_step(
Expand Down
47 changes: 45 additions & 2 deletions mlx_lm/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,11 @@ def _make_state_machine(
return sm, sequences

def _is_batchable(self, args):
# The batched generator uses BatchKVCache, which has no quantized
# variant yet, so route through the single-sequence path (which does
# support KV quantization) whenever --kv-bits is requested.
if self.cli_args.kv_bits is not None:
return False
return self.model_provider.is_batchable and args.seed is None

def _generate(self):
Expand Down Expand Up @@ -968,9 +973,15 @@ def progress(tokens_processed, tokens_total):
ctx.prompt_cache_count = len(prompt) - len(rest)
cache_key = prompt[:]
if cache is None:
cache = make_prompt_cache(self.model_provider.model)
cache = make_prompt_cache(
self.model_provider.model,
max_kv_size=self.cli_args.max_kv_size,
)
if self.model_provider.draft_model is not None:
cache += make_prompt_cache(self.model_provider.draft_model)
cache += make_prompt_cache(
self.model_provider.draft_model,
max_kv_size=self.cli_args.max_kv_size,
)

# Process the prompt and generate tokens
for gen in stream_generate(
Expand All @@ -985,6 +996,9 @@ def progress(tokens_processed, tokens_total):
num_draft_tokens=args.num_draft_tokens,
prompt_progress_callback=progress,
prefill_step_size=self.cli_args.prefill_step_size,
kv_bits=self.cli_args.kv_bits,
kv_group_size=self.cli_args.kv_group_size,
quantized_kv_start=self.cli_args.quantized_kv_start,
):
finish_reason = gen.finish_reason
sm_state, match_sequence, current_state = sm.match(sm_state, gen.token)
Expand Down Expand Up @@ -1884,6 +1898,35 @@ def main():
action="store_true",
help="Use pipelining instead of tensor parallelism",
)
parser.add_argument(
"--kv-bits",
type=int,
default=None,
help="Number of bits for KV cache quantization. Defaults to no "
"quantization. Note: enabling this serves requests sequentially "
"(the batched path has no quantized cache yet).",
)
parser.add_argument(
"--kv-group-size",
type=int,
default=64,
help="Group size for KV cache quantization (default: 64).",
)
parser.add_argument(
"--quantized-kv-start",
type=int,
default=5000,
help="When --kv-bits is set, begin quantizing the KV cache after "
"this many tokens (default: 5000).",
)
parser.add_argument(
"--max-kv-size",
type=int,
default=None,
help="Maximum KV cache size in tokens; older entries are evicted "
"beyond this limit (uses a RotatingKVCache). Ignored for models that "
"provide their own make_cache (e.g. some hybrid architectures).",
)
args = parser.parse_args()
if mx.metal.is_available():
wired_limit = mx.device_info()["max_recommended_working_set_size"]
Expand Down
22 changes: 21 additions & 1 deletion tests/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
batch_generate,
generate,
generate_step,
maybe_quantize_kv_cache,
stream_generate,
)
from mlx_lm.models.cache import KVCache, RotatingKVCache
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
from mlx_lm.sample_utils import make_logits_processors, make_sampler
from mlx_lm.utils import load

Expand Down Expand Up @@ -806,6 +807,25 @@ def test_batch_max_kv_size_none_creates_regular_cache(self):
for cache in r.prompt_cache:
self.assertIsInstance(cache, KVCache)

def test_maybe_quantize_kv_cache_hybrid(self):
    """Quantizing a hybrid prompt cache must skip the rotating layer.

    A hybrid cache pairs a full-attention KVCache with a sliding-window
    RotatingKVCache (as used by models like Gemma). RotatingKVCache has
    no quantized variant, so maybe_quantize_kv_cache must leave it alone
    instead of crashing with NotImplementedError.
    """
    keys = mx.zeros((1, 1, 4, 64))
    values = mx.zeros((1, 1, 4, 64))

    full_attention_cache = KVCache()
    sliding_window_cache = RotatingKVCache(max_size=8)
    # Populate both layers with the same four-token update.
    for layer_cache in (full_attention_cache, sliding_window_cache):
        layer_cache.update_and_fetch(keys, values)

    layers = [full_attention_cache, sliding_window_cache]
    # quantized_kv_start=0 makes every populated layer eligible.
    maybe_quantize_kv_cache(
        layers, quantized_kv_start=0, kv_group_size=32, kv_bits=4
    )

    quantized_layer, untouched_layer = layers
    self.assertIsInstance(quantized_layer, QuantizedKVCache)
    self.assertIsInstance(untouched_layer, RotatingKVCache)

def test_batch_generate_return_logprobs(self):
"""Test that batch_generate returns per-token logprobs when requested."""
prompts = [
Expand Down