diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index 5b172c43..381f8be9 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -65,6 +65,7 @@ def __init__( quantized_kv_start=quantized_kv_start, ) + self._history = self._make_history() self._history_key = "session" self._live_tokens: Optional[mx.array] = None @@ -90,6 +91,10 @@ def _num_tokens_in_cache(self, cache: Optional[List[Any]] = None) -> int | None: for entry in cache: if hasattr(entry, "offset"): return entry.offset + # Fallback: use the length of tracked tokens if cache offset is unavailable + # This handles models where cache layers don't expose offset (e.g., GPT-OSS) + if self.tokens is not None: + return len(self.tokens) return None def _store_snapshot( diff --git a/mlx_engine/model_kit/batched_model_kit.py b/mlx_engine/model_kit/batched_model_kit.py index 4dcd21eb..bbf44c6b 100644 --- a/mlx_engine/model_kit/batched_model_kit.py +++ b/mlx_engine/model_kit/batched_model_kit.py @@ -340,8 +340,11 @@ def get_next_request(timeout=None): ) # Track this request + # Use separate tracking for cross-prompt cache key (original prompt only) + # and live cache key (updated during generation for intra-request caching) self._batch_results[uid] = { - "cache_key": request.prompt_tokens[:], + "cross_prompt_cache_key": request.prompt_tokens[:], + "live_cache_key": request.prompt_tokens[:], "rqueue": request.rqueue, "detokenizer": self.tokenizer.detokenizer, "top_logprobs": request.top_logprobs, @@ -375,7 +378,7 @@ def get_next_request(timeout=None): # Create response object result = self._batch_results[r.uid] detokenizer = result["detokenizer"] - result["cache_key"].append(r.token) + result["live_cache_key"].append(r.token) if r.finish_reason != "stop": detokenizer.add_token(r.token) if r.finish_reason is not None: @@ -417,8 +420,13 @@ def get_next_request(timeout=None): # Clean up if necessary if r.finish_reason is not None: result["rqueue"].put(None) + # Use 
cross_prompt_cache_key for cross-prompt caching + # This ensures the cache is keyed by the original prompt, + # not by the prompt + generated tokens self._prompt_cache.insert_cache( - current_model_key, result["cache_key"], r.prompt_cache + current_model_key, + result["cross_prompt_cache_key"], + r.prompt_cache, ) del self._batch_results[r.uid] diff --git a/tests/test_cache_wrapper.py b/tests/test_cache_wrapper.py index df3693cf..48a88ce7 100644 --- a/tests/test_cache_wrapper.py +++ b/tests/test_cache_wrapper.py @@ -341,6 +341,21 @@ def quantize_same_prompt_cache(prompt_cache, **_): _, reporter = self._run_update_cache(session, prompt) self.assertEqual(reporter.events[0]["cached_tokens"], 0) + def test_get_num_tokens_in_cache_without_offset(self): + """Test that _num_tokens_in_cache falls back to len(self.tokens) when offset is unavailable""" + mock_cache = [object() for _ in range(10)] + + wrapper = object.__new__(CacheWrapper) + wrapper.cache = mock_cache + wrapper.tokens = mx.array([1, 2, 3, 4, 5]) + + result = wrapper._num_tokens_in_cache() + self.assertEqual(result, 5) + + wrapper.tokens = None + result = wrapper._num_tokens_in_cache() + self.assertIsNone(result) + if __name__ == "__main__": unittest.main(verbosity=2)