Build the server's prompt cache through a new make_lru_prompt_cache(args)
helper instead of calling LRUPromptCache directly in run().

The helper reads an optional `prompt_cache_bytes` attribute from the CLI
args and passes it to LRUPromptCache as `max_bytes`; when the attribute is
missing or None it falls back to 1 << 63. A unit test is added that builds
the cache from a SimpleNamespace with prompt_cache_bytes=10, inserts four
entries, and expects three entries / 9 bytes to remain.

NOTE(review): this patch was recovered from a copy whose line breaks had
been collapsed. Whitespace-only context lines were re-inserted so that each
hunk matches the line counts declared in its @@ header -- confirm against
the original commit before applying.

diff --git a/mlx_lm/server.py b/mlx_lm/server.py
index ce8d95817..b9416ccf9 100644
--- a/mlx_lm/server.py
+++ b/mlx_lm/server.py
@@ -1740,7 +1740,7 @@ def run(
     handler_class=APIHandler,
 ):
     group = mx.distributed.init()
-    prompt_cache = LRUPromptCache(model_provider.cli_args.prompt_cache_size)
+    prompt_cache = make_lru_prompt_cache(model_provider.cli_args)
     response_generator = ResponseGenerator(model_provider, prompt_cache)
     if group.rank() == 0:
         _run_http_server(host, port, response_generator)
@@ -1748,6 +1748,13 @@ def run(
         response_generator.join()
 
 
+def make_lru_prompt_cache(args):
+    max_bytes = getattr(args, "prompt_cache_bytes", None)
+    if max_bytes is None:
+        max_bytes = 1 << 63
+    return LRUPromptCache(args.prompt_cache_size, max_bytes=max_bytes)
+
+
 def main():
     parser = argparse.ArgumentParser(description="MLX Http Server.")
     parser.add_argument(
diff --git a/tests/test_server.py b/tests/test_server.py
index 9a8a2ad14..c3a465ffa 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -17,6 +17,7 @@
     Response,
     ResponseGenerator,
     _process_control_tokens,
+    make_lru_prompt_cache,
 )
 from mlx_lm.utils import load
 
@@ -518,6 +519,19 @@ def keepalive_callback(processed_tokens, total_tokens):
 
 
 class TestLRUPromptCache(unittest.TestCase):
+    def test_server_prompt_cache_uses_byte_limit(self):
+        args = types.SimpleNamespace(prompt_cache_size=100, prompt_cache_bytes=10)
+        cache = make_lru_prompt_cache(args)
+        model = ("test", None, None)
+
+        cache.insert_cache(model, [1, 2], [MockCache("aaa")])
+        cache.insert_cache(model, [3, 4], [MockCache("bbb")])
+        cache.insert_cache(model, [4, 5], [MockCache("ccc")])
+        cache.insert_cache(model, [6, 7], [MockCache("ddd")])
+
+        self.assertEqual(len(cache), 3)
+        self.assertEqual(cache.nbytes, 9)
+
     def test_caching(self):
         cache = LRUPromptCache(max_size=10)
 