From 8dd3010744aaffbdf97e9cfec8b3e05aac39fe63 Mon Sep 17 00:00:00 2001 From: Agisilaos Tsarampoulidis Date: Thu, 11 Jun 2026 07:42:24 +0200 Subject: [PATCH] fix(server): honor prompt cache byte limit What changed: - Construct the server prompt cache through a helper that passes prompt_cache_bytes into LRUPromptCache as max_bytes. - Add a regression test showing a server-created prompt cache evicts entries by byte budget. Why: - --prompt-cache-bytes was parsed and partially used in batching but not applied to the LRU cache itself. - Sequential serving could therefore keep prompt caches unbounded by bytes and contribute to Metal OOM aborts as long prompts accumulated. Alternatives considered: - Trimming only in the sequential serve path was considered, but wiring the limit into LRUPromptCache keeps enforcement centralized for every insertion path. --- mlx_lm/server.py | 9 ++++++++- tests/test_server.py | 14 ++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/mlx_lm/server.py b/mlx_lm/server.py index ce8d95817..b9416ccf9 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -1740,7 +1740,7 @@ def run( handler_class=APIHandler, ): group = mx.distributed.init() - prompt_cache = LRUPromptCache(model_provider.cli_args.prompt_cache_size) + prompt_cache = make_lru_prompt_cache(model_provider.cli_args) response_generator = ResponseGenerator(model_provider, prompt_cache) if group.rank() == 0: _run_http_server(host, port, response_generator) @@ -1748,6 +1748,13 @@ def run( response_generator.join() +def make_lru_prompt_cache(args): + max_bytes = getattr(args, "prompt_cache_bytes", None) + if max_bytes is None: + max_bytes = 1 << 63 + return LRUPromptCache(args.prompt_cache_size, max_bytes=max_bytes) + + def main(): parser = argparse.ArgumentParser(description="MLX Http Server.") parser.add_argument( diff --git a/tests/test_server.py b/tests/test_server.py index 9a8a2ad14..c3a465ffa 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -17,6 +17,7 @@ Response, ResponseGenerator, _process_control_tokens, + make_lru_prompt_cache, ) from mlx_lm.utils import load @@ -518,6 +519,19 @@ def keepalive_callback(processed_tokens, total_tokens): class TestLRUPromptCache(unittest.TestCase): + def test_server_prompt_cache_uses_byte_limit(self): + args = types.SimpleNamespace(prompt_cache_size=100, prompt_cache_bytes=10) + cache = make_lru_prompt_cache(args) + model = ("test", None, None) + + cache.insert_cache(model, [1, 2], [MockCache("aaa")]) + cache.insert_cache(model, [3, 4], [MockCache("bbb")]) + cache.insert_cache(model, [4, 5], [MockCache("ccc")]) + cache.insert_cache(model, [6, 7], [MockCache("ddd")]) + + self.assertEqual(len(cache), 3) + self.assertEqual(cache.nbytes, 9) + def test_caching(self): cache = LRUPromptCache(max_size=10)