Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 38 additions & 3 deletions mlx_engine/cache_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
PromptProgressReporter,
StopPromptProcessing,
)
from mlx_engine.utils.turboquant import (
TurboQuantKVCache,
make_wh_rotation,
lloyd_max_centroids,
)


PROMPT_PROCESSING_CHUNK_SIZE = 2048
Expand Down Expand Up @@ -201,7 +206,11 @@ def _prefill_cache(

current_chunk = remaining_tokens[:current_chunk_size]
model(current_chunk[None], cache=cache)
maybe_quantize_kv_cache(prompt_cache=cache, **self._kv_cache_qtn_params)
kv_bits = self._kv_cache_qtn_params.get("kv_bits")
if isinstance(kv_bits, str) and kv_bits.startswith("turbo"):
_apply_turboquant_to_cache(cache, kv_bits)
else:
maybe_quantize_kv_cache(prompt_cache=cache, **self._kv_cache_qtn_params)
self._live_cache[cache_start : cache_start + len(cache)] = cache
mx.eval([entry.state for entry in cache])

Expand All @@ -226,8 +235,34 @@ def _prefill_cache(
logger.info("Prompt processing was cancelled by the user.")
live_cache_size = self._num_tokens_in_cache()
if live_cache_size is None:
self._live_tokens = None
self._live_cache = self._make_cache()
self._live_tokens = None

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Restore cancellation branch indentation

Importing `cache_wrapper.py` now fails with `IndentationError: expected an indented block after 'if' statement on line 237`, because the body of `if live_cache_size is None:` was dedented out of `_prefill_cache`. This blocks the engine from starting for any configuration, before TurboQuant is even selected.

Useful? React with 👍 / 👎.

self._live_cache = self._make_cache()


def _apply_turboquant_to_cache(cache: List[Any], turbo_mode: str) -> None:
    """Apply TurboQuant PolarQuant compression to all cache entries.

    Every entry that exposes a populated ``state`` is rewritten in place:
    its ``keys`` and ``values`` are replaced by the index tensors returned
    from ``TurboQuantKVCache.quantize_kv``, and the matching norms, head
    dimension and bit width are attached as ``_turbo_*`` attributes.

    Entries that have no ``state`` attribute, whose ``state`` is ``None``,
    or whose keys are ``None`` are left untouched.

    Args:
        cache: List of KVCache entries from the model. Mutated in place.
        turbo_mode: One of "turbo2", "turbo3", "turbo4". Any other value
            falls back to 3 bits.
    """
    bit_map = {"turbo2": 2, "turbo3": 3, "turbo4": 4}
    bits = bit_map.get(turbo_mode, 3)

    for entry in cache:
        # Read ``state`` once; it may be a computed property on the entry.
        state = getattr(entry, "state", None)
        if state is None:
            continue

        keys, values = state
        if keys is None:
            continue

        # The last axis of the key tensor is used as the head dimension.
        head_dim = keys.shape[-1]
        tq = TurboQuantKVCache(head_dim=head_dim, k_bits=bits, v_bits=bits)
        k_idx, v_idx, k_norm, v_norm = tq.quantize_kv(keys, values)

        entry.keys = k_idx
        # Keep the quantized value indices. Previously this was set to None,
        # which discarded ``v_idx`` and lost the value half of the cache.
        entry.values = v_idx
        entry._turbo_k_norms = k_norm
        entry._turbo_v_norms = v_norm
        entry._turbo_head_dim = head_dim
        entry._turbo_bits = bits
else:
self._live_tokens = self._live_tokens[:live_cache_size]
raise StopPromptProcessing
Expand Down
20 changes: 17 additions & 3 deletions mlx_engine/utils/kv_cache_quantization.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,27 @@
from typing import Optional, Tuple
from typing import Optional, Tuple, Union

# https://github.com/ml-explore/mlx/blob/f288db8d34c0bcfa0867b6458ab0277c5e86ed45/mlx/fast.cpp#L782
VALID_KV_BITS = (2, 3, 4, 6, 8)
# https://github.com/ml-explore/mlx/blob/f288db8d34c0bcfa0867b6458ab0277c5e86ed45/mlx/fast.cpp#L775
VALID_KV_GROUP_SIZE = (32, 64, 128)

# TurboQuant KV cache compression modes
# These are the only string values accepted for `kv_bits`; any other string
# is rejected with a ValueError by get_kv_cache_quantization_params.
# These use PolarQuant + WHT rotation instead of standard MLX quantization
# See mlx_engine/utils/turboquant.py for the implementation
VALID_TURBO_MODES = ("turbo2", "turbo3", "turbo4")


def get_kv_cache_quantization_params(
kv_bits: Optional[int],
kv_bits: Optional[Union[int, str]],
kv_group_size: Optional[int],
quantized_kv_start: Optional[int],
) -> Tuple[Optional[int], Optional[int], Optional[int]]:
) -> Tuple[Optional[Union[int, str]], Optional[int], Optional[int]]:
"""
Validates and processes KV cache quantization parameters.

Args:
kv_bits: Number of bits for quantization. If None, disables quantization.
Can also be a string like "turbo3" to use TurboQuant compression.
kv_group_size: Group size for quantization. Defaults to 64 if quantization enabled.
quantized_kv_start: Step to begin quantization. Defaults to 0 if quantization enabled.

Expand All @@ -31,6 +37,14 @@ def get_kv_cache_quantization_params(
if kv_bits is None:
return None, None, None

# Check for TurboQuant string modes
if isinstance(kv_bits, str):
if kv_bits not in VALID_TURBO_MODES:
raise ValueError(
f"Invalid turbo mode '{kv_bits}'. Must be one of {VALID_TURBO_MODES}"
)
return kv_bits, kv_group_size, quantized_kv_start

# defaults taken from here:
# https://github.com/ml-explore/mlx-examples/blob/3d793ec/llms/mlx_lm/utils.py#L352-L353
if kv_group_size is None:
Expand Down
Loading
Loading