Skip to content
87 changes: 63 additions & 24 deletions omlx/cache/paged_ssd_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,8 +732,15 @@ def _hot_cache_put(self, block_hash: bytes, entry: Dict) -> None:
self._hot_cache_total_bytes += entry_size

# Flush evicted entries to SSD outside the hot cache lock
for evicted_hash, evicted in evicted_entries:
self._enqueue_ssd_write(evicted_hash, evicted)
# Skip SSD write when hot_cache_only=True (evicted entries are discarded)
if not self._hot_cache_only:
for evicted_hash, evicted in evicted_entries:
self._enqueue_ssd_write(evicted_hash, evicted)
elif evicted_entries:
logger.debug(
f"Discarding {len(evicted_entries)} evicted blocks "
f"(hot_cache_only=True)"
)

def _enqueue_ssd_write(self, block_hash: bytes, entry: Dict) -> bool:
"""Enqueue a hot cache entry for SSD background write.
Expand Down Expand Up @@ -792,9 +799,7 @@ def _hot_cache_remove(self, block_hash: bytes) -> None:
with self._hot_cache_lock:
old = self._hot_cache.pop(block_hash, None)
if old:
self._hot_cache_total_bytes -= self._hot_cache_entry_size(
old['tensors_raw']
)
self._hot_cache_total_bytes -= self._hot_cache_entry_size(old)

def _promote_to_hot_cache(
self,
Expand Down Expand Up @@ -1244,6 +1249,21 @@ def save_block(
'block_metadata': block_metadata,
}

if self._hot_cache_only:
# Hot cache only mode: store mx.array directly (not tensors_raw).
# This avoids memory doubling on cache hit - we reuse the same
# GPU memory instead of creating new mx.array objects.
cache_entry = {
"arrays": arrays,
"file_metadata": metadata,
"num_layers": len(cache_data),
"layer_cache_types": layer_cache_types,
"block_metadata": block_metadata,
}
self._hot_cache_put(block_hash, cache_entry)
self._stats["saves"] += 1
return True

if self._hot_cache_enabled:
# Write-back mode: store only in hot cache, no SSD index entry.
# SSD index entry is created later when block is evicted or
Expand All @@ -1252,10 +1272,6 @@ def save_block(
self._stats["saves"] += 1
return True

if self._hot_cache_only:
# Hot cache disabled but hot_cache_only set: block is not retained.
return False

# SSD path: add to index for SSD file tracking
self._index.add(block_metadata)

Expand Down Expand Up @@ -1455,14 +1471,24 @@ def load_block(
# Check hot cache first (in-memory, no I/O)
entry = self._hot_cache_get(block_hash)
if entry is not None:
# Entries from _promote_to_hot_cache() store mx.array objects directly
# (safe — they come from SSD loads, not active inference).
# Entries from save_block() use tensors_raw (raw bytes).
arrays = entry.get('arrays') or self._arrays_from_tensors_raw(entry['tensors_raw'])
cache_data = self._reconstruct_cache_data(
arrays, entry['file_metadata'],
entry['num_layers'], entry['layer_cache_types'],
)
# Check for "arrays" first (direct mx.array storage).
# This is the fast path for hot_cache_only mode.
if "arrays" in entry:
cache_data = self._reconstruct_cache_data(
entry["arrays"],
entry["file_metadata"],
entry["num_layers"],
entry["layer_cache_types"],
)
else:
# Fallback: tensors_raw path
arrays = self._arrays_from_tensors_raw(entry["tensors_raw"])
cache_data = self._reconstruct_cache_data(
arrays,
entry["file_metadata"],
entry["num_layers"],
entry["layer_cache_types"],
)
if cache_data is not None:
self._index.touch(block_hash)
self._stats["loads"] += 1
Expand Down Expand Up @@ -1574,12 +1600,25 @@ def load_block_with_metadata(
# Check hot cache first (in-memory, no I/O)
entry = self._hot_cache_get(block_hash)
if entry is not None:
blk_meta = entry['block_metadata']
arrays = entry.get('arrays') or self._arrays_from_tensors_raw(entry['tensors_raw'])
cache_data = self._reconstruct_cache_data(
arrays, entry['file_metadata'],
entry['num_layers'], entry['layer_cache_types'],
)
blk_meta = entry["block_metadata"]
# Check for "arrays" first (direct mx.array storage)
# This is the fast path for hot_cache_only mode.
if "arrays" in entry:
cache_data = self._reconstruct_cache_data(
entry["arrays"],
entry["file_metadata"],
entry["num_layers"],
entry["layer_cache_types"],
)
else:
# Fallback: tensors_raw path
arrays = self._arrays_from_tensors_raw(entry["tensors_raw"])
cache_data = self._reconstruct_cache_data(
arrays,
entry["file_metadata"],
entry["num_layers"],
entry["layer_cache_types"],
)
if cache_data is None:
return None, None

Expand Down Expand Up @@ -1938,7 +1977,7 @@ def _matches(candidate: str) -> bool:
if blk_meta is None or not _matches(blk_meta.model_name):
continue
hot_entries.append(entry)
hot_size += self._hot_cache_entry_size(entry["tensors_raw"])
hot_size += self._hot_cache_entry_size(entry)

return PagedSSDCacheStats(
hits=self._stats["hits"],
Expand Down
36 changes: 14 additions & 22 deletions omlx/cache/prefix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,12 +856,10 @@ def _extract_block_tensor_slice(
block_slices.append(
(self._clone_tensor(keys_slice), self._clone_tensor(values_slice))
)
elif cache_type_name == 'RotatingKVCache':
# RotatingKVCache: last-block-only or boundary-snapshot strategy
has_valid_state = is_last_block or (
snapshot_cache_data is not None
and layer_idx < len(snapshot_cache_data)
)
elif cache_type_name == "RotatingKVCache":
# RotatingKVCache: Always store actual data for all blocks.
# This enables walk-back during restore without boundary snapshots.
has_valid_state = True
if has_valid_state:
# Use snapshot state if available, otherwise use main state
if (
Expand Down Expand Up @@ -948,15 +946,10 @@ def _extract_block_tensor_slice(
block_slices.append((mx.zeros((1,)), mx.zeros((1,))))
else:
# Other non-sliceable cache (ArraysCache/MambaCache)
# GDN recurrent state summarizes the ENTIRE sequence in a
# fixed-size matrix. Each block boundary snapshot captures
# the state at that point in the sequence. Without a snapshot,
# non-last blocks get a placeholder so partial matches are
# detected and rejected during reconstruction.
has_valid_state = is_last_block or (
snapshot_cache_data is not None
and layer_idx < len(snapshot_cache_data)
)
# GDN recurrent state summarizes the ENTIRE sequence.
# Always store actual data for all blocks to enable walk-back
# during restore without boundary snapshots.
has_valid_state = True
if has_valid_state:
# Use snapshot state if available, otherwise main state
if (
Expand Down Expand Up @@ -1583,19 +1576,18 @@ def reconstruct_cache(
if isinstance(ms, (list, tuple)) and len(ms) >= 3:
tq_bits = float(ms[1])
tq_seed = int(ms[2])
# Dequantize back to fp16 KVCache for merge compatibility.
# Keep TurboQuantKVCache in quantized form to avoid memory doubling.
# TQ will be re-applied at decode start (lazy quantization).
# This avoids the dequantize() step which creates FP16 copy.
tq = TurboQuantKVCache(bits=tq_bits, seed=tq_seed)
tq.keys = cat_ks
tq.values = cat_vs
tq.offset = _state_length(cat_ks)
_rebuild_codecs(tq, cat_ks, cat_vs)
keys, values = tq.dequantize()
cache = KVCache()
cache.keys = keys
cache.values = values
cache.offset = keys.shape[2]
reconstructed_caches.append(cache)

# Use TurboQuantKVCache directly without dequantizing
# This keeps memory at quantized level (vs 16-bit)
reconstructed_caches.append(tq)
except Exception as e:
logger.error(f"TQ layer {layer_idx}: reconstruction failed: {e}")
return None
Expand Down
14 changes: 13 additions & 1 deletion omlx/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1375,7 +1375,8 @@ def _do_external_prefill(
# Boundary snapshot setup
block_size = self.config.paged_cache_block_size
boundary_enabled = (
block_size > 0
not self.config.hot_cache_only
and block_size > 0
and self.block_aware_cache is not None
and _prompt_cache_needs_snapshots(prompt_cache)
)
Expand Down Expand Up @@ -3764,6 +3765,17 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None:
f"{len(request.prompt_token_ids)} prompt + "
f"{len(request.output_token_ids)} output)"
)
# Immediately release _extracted_cache to free copy #1
# (store_cache already cloned to PagedCache blocks)
request._extracted_cache = None

# Clear boundary snapshots for this request after store to prevent memory leak.
# Boundary snapshots were needed for proper block storage but are no longer needed.
if request_id in self._boundary_cache_snapshots:
del self._boundary_cache_snapshots[request_id]
logger.debug(
f"Cleared boundary snapshots for request {request_id}"
)
except Exception as e:
logger.debug(f"Failed to submit async store for {request_id}: {e}")
else:
Expand Down
75 changes: 75 additions & 0 deletions tests/test_hot_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,81 @@ def save(idx):
mgr.close()


@pytest.mark.skipif(not HAS_MLX, reason="MLX not available")
class TestHotCacheOnlyMode:
    """Contrast behavior when hot_cache_only=True vs False."""

    def test_eviction_discard_vs_write(self, tmp_path):
        """hot_cache_only=True should discard evicted blocks, False should write to SSD."""
        block_hash = b"evict_contrast_hash"
        cache_data = [(mx.zeros((1, 4, 16, 16)), mx.zeros((1, 4, 16, 16)))]

        # Case 1: hot_cache_only = True (Discard)
        mgr_true = PagedSSDCacheManager(
            cache_dir=tmp_path / "evict_true",
            max_size_bytes=1024**3,
            hot_cache_max_bytes=100,  # Force immediate eviction
            hot_cache_only=True,
        )
        mgr_true.save_block(block_hash, cache_data, 16, layer_cache_types=["KVCache"])
        # Save another block to trigger eviction of first
        mgr_true.save_block(
            b"evict_trigger", cache_data, 16, layer_cache_types=["KVCache"]
        )

        # No SSD files should exist: evicted entries are discarded, not flushed.
        ssd_files = list((tmp_path / "evict_true").rglob("*.safetensors"))
        assert len(ssd_files) == 0
        mgr_true.close()

        # Case 2: hot_cache_only = False (Write to SSD)
        mgr_false = PagedSSDCacheManager(
            cache_dir=tmp_path / "evict_false",
            max_size_bytes=1024**3,
            hot_cache_max_bytes=100,  # Force immediate eviction
            hot_cache_only=False,
        )
        mgr_false.save_block(block_hash, cache_data, 16, layer_cache_types=["KVCache"])
        # Save another block to trigger eviction
        mgr_false.save_block(
            b"evict_trigger", cache_data, 16, layer_cache_types=["KVCache"]
        )

        # Wait for the background writer with bounded polling instead of a
        # fixed sleep: finishes as soon as the file appears, and gives a
        # generous deadline on slow CI machines instead of a flaky 0.5s.
        import time

        ssd_files = []
        deadline = time.monotonic() + 5.0
        while time.monotonic() < deadline:
            ssd_files = list((tmp_path / "evict_false").rglob("*.safetensors"))
            if ssd_files:
                break
            time.sleep(0.05)
        assert len(ssd_files) >= 1
        mgr_false.close()

    def test_entry_size_calculation(self, tmp_path):
        """Verify size calculation for both array and tensor_raw formats."""
        # Use the pytest tmp_path fixture rather than a hardcoded /tmp dir so
        # runs are isolated and leave no state behind.
        mgr = PagedSSDCacheManager(
            cache_dir=tmp_path / "size_test", max_size_bytes=1024**3
        )
        try:
            # 1. Array format (direct mx.array storage)
            arrays = {"k": mx.zeros((1, 8, 16, 16)), "v": mx.zeros((1, 8, 16, 16))}
            entry_arrays = {"arrays": arrays}
            # 1*8*16*16 * 4 bytes * 2 tensors = 16384 bytes
            expected_size = 1 * 8 * 16 * 16 * 4 * 2
            assert mgr._hot_cache_entry_size(entry_arrays) == expected_size

            # 2. Tensors_raw format (raw bytes + dtype + shape tuples)
            raw_data = bytes(1024)
            tensors_raw = {
                "k": (raw_data, "F32", [1, 8, 16, 16]),
                "v": (raw_data, "F32", [1, 8, 16, 16]),
            }
            entry_raw = {"tensors_raw": tensors_raw}
            # 1024 * 2 = 2048 bytes
            assert mgr._hot_cache_entry_size(entry_raw) == 2048

            # 3. Empty/Unknown entries contribute zero bytes
            assert mgr._hot_cache_entry_size({}) == 0
        finally:
            # Always release the manager (background writer thread, files).
            mgr.close()


@pytest.mark.skipif(not HAS_MLX, reason="MLX not available")
class TestHotCacheWriteBack:
"""Test write-back behavior: no SSD writes until eviction or shutdown."""
Expand Down
12 changes: 6 additions & 6 deletions tests/test_hybrid_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,10 @@ def test_rotating_non_last_block_placeholder(self, prefix_cache):
keys0, values0 = result[0]
assert keys0.shape == (1, 8, 4, 64)

# RotatingKVCache layer: placeholder
# RotatingKVCache layer: full state (always stored to enable walk-back)
keys1, values1 = result[1]
assert keys1.shape == (1,)
assert values1.shape == (1,)
assert keys1.shape == (1, 8, 256, 64)
assert values1.shape == (1, 8, 256, 64)

def test_rotating_last_block_full_state(self, prefix_cache):
"""Test RotatingKVCache last block stores full state."""
Expand Down Expand Up @@ -349,15 +349,15 @@ def test_hybrid_model_multiple_blocks(self, prefix_cache):
)
assert block0 is not None
assert block0[0][0].shape == (1, 8, 4, 64) # KVCache slice
assert block0[1][0].shape == (1,) # RotatingKVCache placeholder
assert block0[1][0].shape == (1, 8, 256, 64) # RotatingKVCache full state

# Block 1 (non-last): KVCache sliced, RotatingKVCache placeholder
# Block 1 (non-last): KVCache sliced, RotatingKVCache full state
block1 = prefix_cache._extract_block_tensor_slice(
cache_data, 4, 8, model_cache_config=config, is_last_block=False
)
assert block1 is not None
assert block1[0][0].shape == (1, 8, 4, 64)
assert block1[1][0].shape == (1,)
assert block1[1][0].shape == (1, 8, 256, 64)

# Block 3 (last): KVCache sliced, RotatingKVCache full state
block3 = prefix_cache._extract_block_tensor_slice(
Expand Down
Loading