From 10f7fac81c832386ef4fc93e9b50c3dfb13c7893 Mon Sep 17 00:00:00 2001 From: zmnobug Date: Thu, 18 Jun 2026 11:56:45 +0800 Subject: [PATCH 1/2] Add radix prefix cache backend --- python/cli/main.py | 8 + python/core/async_engine.py | 3 + python/core/kv_cache.py | 51 +++++- python/core/radix_cache.py | 305 ++++++++++++++++++++++++++++++++++++ python/core/scheduler.py | 80 +++++++++- tests/test_cli.py | 14 ++ tests/test_radix_cache.py | 214 +++++++++++++++++++++++++ 7 files changed, 671 insertions(+), 4 deletions(-) create mode 100644 python/core/radix_cache.py create mode 100644 tests/test_radix_cache.py diff --git a/python/cli/main.py b/python/cli/main.py index 32342b9..aea44b3 100644 --- a/python/cli/main.py +++ b/python/cli/main.py @@ -73,6 +73,12 @@ def build_parser() -> argparse.ArgumentParser: default=True, help="Enable prefix caching (default: True). Use --no-enable-prefix-caching to disable.", ) + parser.add_argument( + "--prefix-cache-backend", + default="hash", + choices=("hash", "radix"), + help="Prefix cache backend: hash or radix (default: hash).", + ) parser.add_argument( "--enable-chunked-prefill", action=argparse.BooleanOptionalAction, @@ -114,6 +120,7 @@ def build_serving_engine_config(args: argparse.Namespace) -> EngineConfig: long_prefill_token_threshold=args.long_prefill_token_threshold, enable_prefix_cache=args.enable_prefix_caching, enable_chunk_prefill=args.enable_chunked_prefill, + prefix_cache_backend=args.prefix_cache_backend, ) @@ -188,6 +195,7 @@ async def shutdown(): print(f" Max scheduled tokens/iter: {config.max_num_scheduled_tokens}") print(f" Chunked prefill threshold: {config.long_prefill_token_threshold}") print(f" Prefix cache: {'enabled' if config.enable_prefix_cache else 'disabled'}") + print(f" Prefix cache backend: {config.prefix_cache_backend}") print(f" Chunk prefill: {'enabled' if config.enable_chunk_prefill else 'disabled'}") print(" Endpoints: /v1/completions, /v1/chat/completions, /v1/models, /health") diff --git a/python/core/async_engine.py b/python/core/async_engine.py index 5374389..f5084a3 100644 --- a/python/core/async_engine.py +++ b/python/core/async_engine.py @@ -49,6 +49,7 @@ class EngineConfig: # Feature flags enable_prefix_cache: bool = True enable_chunk_prefill: bool = True + prefix_cache_backend: str = "hash" @dataclass @@ -102,6 +103,7 @@ def __init__( max_seq_len=runtime.max_seq_len, enable_prefix_cache=self.config.enable_prefix_cache, enable_chunk_prefill=self.config.enable_chunk_prefill, + prefix_cache_backend=self.config.prefix_cache_backend, ) self.scheduler = Scheduler(config=scheduler_config, kv_cache_manager=self.kv_cache_manager) @@ -178,6 +180,7 @@ async def add_request( request_id=request_id, prompt_token_ids=prompt_token_ids, max_new_tokens=config.max_new_tokens, + model_id=self.config.model_id, arrival_time=time.time(), stop_strings=tuple(config.stop) if config.stop else (), eos_token_id=self.eos_token_id, diff --git a/python/core/kv_cache.py b/python/core/kv_cache.py index 9d0096e..3dbadae 100644 --- a/python/core/kv_cache.py +++ b/python/core/kv_cache.py @@ -14,6 +14,7 @@ import torch +from .radix_cache import RadixKey, RadixMatch, RadixPrefixCache from .types import KvAllocation, ModelConfig, RuntimeConfig @@ -132,6 +133,7 @@ def __init__( self.free_queue = FreeKVCacheBlockQueue() self.hash_to_block: dict[int, KVCacheBlock] = {} self.request_blocks: dict[str, list[KVCacheBlock]] = {} + self.radix_cache = self._new_radix_cache() if num_blocks is not None: self._init_blocks(num_blocks, block_size) @@ -154,6 +156,7 @@ def _init_blocks(self, num_blocks: int, block_size: int) -> None: self.blocks = [KVCacheBlock(block_id=i) for i in range(num_blocks)] for block in self.blocks: self.free_queue.append(block) + self.radix_cache = self._new_radix_cache() def register_model(self, model_id: str, config: ModelConfig, runtime: RuntimeConfig) -> None: """Create the KV page pool for a model if it is not already registered.""" @@ -232,8 +235,27 @@ def allocate_block_ids(self, num_blocks: int) -> list[int] | None: def release_blocks_by_ids(self, *block_id_groups: list[int]) -> None: """Release request references for one or more groups of physical block IDs.""" for block_ids in block_id_groups: - for block_id in block_ids: - self.release(self.blocks[block_id]) + self.release_pages_from_request(block_ids) + + def retain_pages_for_request(self, page_ids: list[int]) -> None: + """Take one active-request reference on each physical KV page.""" + for page_id in page_ids: + self.retain(self.blocks[page_id]) + + def release_pages_from_request(self, page_ids: list[int]) -> None: + """Release one active-request reference from each physical KV page.""" + for page_id in page_ids: + self.release(self.blocks[page_id]) + + def retain_pages_for_cache(self, page_ids: list[int]) -> None: + """Take one prefix-cache reference on each physical KV page.""" + for page_id in page_ids: + self.retain(self.blocks[page_id]) + + def release_pages_from_cache(self, page_ids: list[int]) -> None: + """Release one prefix-cache reference from each physical KV page.""" + for page_id in page_ids: + self.release(self.blocks[page_id]) def release_cached_blocks(self, blocks: list[KVCacheBlock]) -> None: """Release cached block objects returned by ``get_computed_blocks``.""" @@ -284,6 +306,12 @@ def release(self, block: KVCacheBlock) -> None: if block.ref_cnt == 0: self.free_queue.append(block) + def retain(self, block: KVCacheBlock) -> None: + """Retain one reference to a physical KV block.""" + if block.ref_cnt == 0: + self.free_queue.remove(block) + block.ref_cnt += 1 + def _iter_block_hashes(self, token_ids: list[int]): """Yield (block_index, block_hash) for each full block in the token sequence.""" parent_hash = NONE_HASH @@ -310,6 +338,18 @@ def compute_block_hashes(self, token_ids: list[int]) -> list[int]: """Compute chained hashes for all full blocks in the token sequence.""" return [block_hash for _, block_hash in self._iter_block_hashes(token_ids)] + def match_radix_prefix(self, model_id: str, token_ids: list[int]) -> RadixMatch: + """Find and lock the longest page-aligned radix prefix for one model.""" + if not self.enable_prefix_cache: + return self.radix_cache.match(RadixKey.from_tokens([], extra_key=(model_id,))) + return self.radix_cache.match(RadixKey.from_tokens(token_ids, extra_key=(model_id,))) + + def insert_radix_prefix(self, model_id: str, token_ids: list[int], page_ids: list[int]) -> None: + """Publish a page-aligned request prefix to the radix cache.""" + if not self.enable_prefix_cache: + return + self.radix_cache.insert(RadixKey.from_tokens(token_ids, extra_key=(model_id,)), page_ids) + def ensure_one_more_slot(self, alloc: KvAllocation) -> int: """Ensure a request has capacity for one more token and return its slot.""" pool = self._pool(alloc.model_id) @@ -413,3 +453,10 @@ def _pool(self, model_id: str) -> _CachePool: if model_id not in self._pools: raise KeyError(f"Model {model_id} is not registered with the KV cache manager.") return self._pools[model_id] + + def _new_radix_cache(self) -> RadixPrefixCache: + return RadixPrefixCache( + self.block_size, + retain_pages=self.retain_pages_for_cache, + release_pages=self.release_pages_from_cache, + ) diff --git a/python/core/radix_cache.py b/python/core/radix_cache.py new file mode 100644 index 0000000..0850b40 --- /dev/null +++ b/python/core/radix_cache.py @@ -0,0 +1,305 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- + +from __future__ import annotations + +import time +from collections.abc import Callable +from dataclasses import dataclass, field + + +ExtraKey = tuple[str, ...] + + +@dataclass(frozen=True) +class RadixKey: + """Token prefix lookup key with an explicit cache namespace.""" + + token_ids: tuple[int, ...] + extra_key: ExtraKey = () + + @classmethod + def from_tokens(cls, token_ids: list[int] | tuple[int, ...], *, extra_key: ExtraKey = ()) -> "RadixKey": + return cls(tuple(int(token_id) for token_id in token_ids), tuple(extra_key)) + + +@dataclass +class RadixNode: + """Compressed radix-tree node whose edge stores page-aligned token and page IDs.""" + + parent: "RadixNode | None" = None + extra_key: ExtraKey = () + tokens: tuple[int, ...] = () + page_ids: list[int] = field(default_factory=list) + children: dict[object, "RadixNode"] = field(default_factory=dict) + lock_ref: int = 0 + last_access_time: float = field(default_factory=time.monotonic) + priority: int = 0 + + +@dataclass(frozen=True) +class RadixMatch: + """Longest page-aligned prefix match result.""" + + prefix_len: int + page_ids: list[int] + last_node: RadixNode + + +@dataclass(frozen=True) +class RadixInsertResult: + """Insert accounting for already-present and newly-inserted pages.""" + + existing_prefix_len: int + inserted_len: int + + +class RadixPrefixCache: + """Page-aligned radix-tree prefix cache for physical KV page IDs.""" + + def __init__( + self, + page_size: int, + *, + retain_pages: Callable[[list[int]], None] | None = None, + release_pages: Callable[[list[int]], None] | None = None, + ) -> None: + if page_size <= 0: + raise ValueError("page_size must be positive") + self.page_size = int(page_size) + self._retain_pages = retain_pages + self._release_pages = release_pages + self.root = RadixNode(priority=-(2**63)) + + def reset(self) -> None: + """Drop all radix metadata without releasing page references.""" + self.root = RadixNode(priority=-(2**63)) + + def match(self, key: RadixKey) -> RadixMatch: + """Return the longest cached prefix and lock its path against eviction.""" + tokens = self._page_aligned_tokens(key.token_ids) + if not tokens: + return RadixMatch(prefix_len=0, page_ids=[], last_node=self.root) + + node = self.root + matched_pages: list[int] = [] + offset = 0 + access_time = time.monotonic() + self.root.last_access_time = access_time + + while offset < len(tokens): + child_key = self._child_key(key.extra_key, tokens[offset:]) + child = node.children.get(child_key) + if child is None: + break + + child.last_access_time = access_time + prefix_len = self._page_aligned_len(self._common_prefix_len(child.tokens, tokens[offset:])) + if prefix_len == 0: + break + if prefix_len < len(child.tokens): + child = self._split_node(child, prefix_len) + child.last_access_time = access_time + + matched_pages.extend(child.page_ids) + offset += len(child.tokens) + node = child + + if node is not self.root: + self.inc_lock_ref(node) + return RadixMatch(prefix_len=offset, page_ids=list(matched_pages), last_node=node) + + def insert(self, key: RadixKey, page_ids: list[int], *, priority: int = 0) -> RadixInsertResult: + """Insert a page-aligned token prefix and retain cache refs for new pages.""" + tokens, page_ids = self._align_tokens_and_pages(key.token_ids, page_ids) + if not tokens: + return RadixInsertResult(existing_prefix_len=0, inserted_len=0) + + node = self.root + offset = 0 + page_offset = 0 + existing_prefix_len = 0 + access_time = time.monotonic() + node.last_access_time = access_time + node.priority = max(node.priority, priority) + + while offset < len(tokens): + child_key = self._child_key(key.extra_key, tokens[offset:]) + child = node.children.get(child_key) + if child is None: + break + + child.last_access_time = access_time + child.priority = max(child.priority, priority) + prefix_len = self._page_aligned_len(self._common_prefix_len(child.tokens, tokens[offset:])) + if prefix_len == 0: + break + if prefix_len < len(child.tokens): + child = self._split_node(child, prefix_len) + child.priority = max(child.priority, priority) + child.last_access_time = access_time + + offset += len(child.tokens) + page_offset += len(child.page_ids) + existing_prefix_len += len(child.tokens) + node = child + + inserted_len = len(tokens) - offset + if inserted_len > 0: + suffix_tokens = tokens[offset:] + suffix_pages = list(page_ids[page_offset:]) + new_node = RadixNode( + parent=node, + extra_key=key.extra_key, + tokens=suffix_tokens, + page_ids=suffix_pages, + priority=priority, + ) + node.children[self._child_key(key.extra_key, suffix_tokens)] = new_node + if suffix_pages and self._retain_pages is not None: + self._retain_pages(suffix_pages) + + return RadixInsertResult(existing_prefix_len=existing_prefix_len, inserted_len=inserted_len) + + def inc_lock_ref(self, node: RadixNode) -> None: + """Protect a matched path from radix eviction.""" + while node is not self.root: + node.lock_ref += 1 + node = node.parent + if node is None: + raise RuntimeError("radix node is detached from its root") + + def dec_lock_ref(self, node: RadixNode) -> None: + """Release one eviction lock from a matched path.""" + while node is not self.root: + if node.lock_ref <= 0: + raise RuntimeError("radix node lock_ref is already zero") + node.lock_ref -= 1 + node = node.parent + if node is None: + raise RuntimeError("radix node is detached from its root") + + def evict_pages(self, min_pages: int) -> int: + """Evict at least ``min_pages`` radix-owned pages from unlocked LRU leaves.""" + if min_pages <= 0: + return 0 + + evicted = 0 + while evicted < min_pages: + leaf = self._oldest_evictable_leaf() + if leaf is None: + break + parent = leaf.parent + if parent is None: + break + if leaf.page_ids and self._release_pages is not None: + self._release_pages(list(leaf.page_ids)) + evicted += len(leaf.page_ids) + parent.children.pop(self._child_key(leaf.extra_key, leaf.tokens), None) + leaf.parent = None + leaf.children.clear() + leaf.page_ids = [] + leaf.tokens = () + return evicted + + def total_pages(self) -> int: + return sum(len(node.page_ids) for node in self._iter_nodes() if node is not self.root) + + def protected_pages(self) -> int: + return sum( + len(node.page_ids) + for node in self._iter_nodes() + if node is not self.root and node.lock_ref > 0 + ) + + def evictable_pages(self) -> int: + return sum(len(node.page_ids) for node in self._iter_nodes() if self._is_evictable_leaf(node)) + + def _split_node(self, child: RadixNode, split_len: int) -> RadixNode: + if split_len <= 0 or split_len >= len(child.tokens): + raise ValueError("split_len must split the child edge") + if split_len % self.page_size != 0: + raise ValueError("split_len must be page-aligned") + + parent = child.parent + if parent is None: + raise RuntimeError("cannot split detached radix node") + + split_pages = split_len // self.page_size + prefix_tokens = child.tokens[:split_len] + prefix_pages = child.page_ids[:split_pages] + suffix_tokens = child.tokens[split_len:] + suffix_pages = child.page_ids[split_pages:] + + new_node = RadixNode( + parent=parent, + extra_key=child.extra_key, + tokens=prefix_tokens, + page_ids=list(prefix_pages), + lock_ref=child.lock_ref, + last_access_time=child.last_access_time, + priority=child.priority, + ) + + old_child_key = self._child_key(child.extra_key, child.tokens) + parent.children[old_child_key] = new_node + + child.parent = new_node + child.tokens = suffix_tokens + child.page_ids = list(suffix_pages) + new_node.children[self._child_key(child.extra_key, suffix_tokens)] = child + return new_node + + def _oldest_evictable_leaf(self) -> RadixNode | None: + leaves = [node for node in self._iter_nodes() if self._is_evictable_leaf(node)] + if not leaves: + return None + return min(leaves, key=lambda node: node.last_access_time) + + def _is_evictable_leaf(self, node: RadixNode) -> bool: + return node is not self.root and node.lock_ref == 0 and not node.children and bool(node.page_ids) + + def _iter_nodes(self): + stack = [self.root] + while stack: + node = stack.pop() + yield node + stack.extend(node.children.values()) + + def _align_tokens_and_pages( + self, + token_ids: tuple[int, ...], + page_ids: list[int], + ) -> tuple[tuple[int, ...], list[int]]: + tokens = self._page_aligned_tokens(token_ids) + num_pages = len(tokens) // self.page_size + if len(page_ids) < num_pages: + raise ValueError(f"page_ids has {len(page_ids)} pages, need at least {num_pages}") + return tokens, list(page_ids[:num_pages]) + + def _page_aligned_tokens(self, token_ids: tuple[int, ...]) -> tuple[int, ...]: + aligned_len = self._page_aligned_len(len(token_ids)) + return tuple(token_ids[:aligned_len]) + + def _page_aligned_len(self, length: int) -> int: + return (int(length) // self.page_size) * self.page_size + + def _child_key(self, extra_key: ExtraKey, tokens: tuple[int, ...]) -> object: + if len(tokens) < self.page_size: + raise ValueError("child key requires at least one full page") + page_key = tuple(tokens[: self.page_size]) + return (tuple(extra_key), page_key) + + @staticmethod + def _common_prefix_len(left: tuple[int, ...], right: tuple[int, ...]) -> int: + limit = min(len(left), len(right)) + idx = 0 + while idx < limit and left[idx] == right[idx]: + idx += 1 + return idx diff --git a/python/core/scheduler.py b/python/core/scheduler.py index 6f9305c..1f8d96d 100644 --- a/python/core/scheduler.py +++ b/python/core/scheduler.py @@ -9,6 +9,7 @@ from __future__ import annotations +import os import time from collections import deque from dataclasses import dataclass, field @@ -45,6 +46,7 @@ class SchedulerConfig: # Feature flags enable_prefix_cache: bool = True enable_chunk_prefill: bool = True + prefix_cache_backend: str = "hash" @dataclass @@ -52,6 +54,7 @@ class Request: request_id: str prompt_token_ids: list[int] max_new_tokens: int + model_id: str = "" arrival_time: float = field(default_factory=time.time) status: RequestStatus = RequestStatus.WAITING num_computed_tokens: int = 0 @@ -121,15 +124,22 @@ class Scheduler: def __init__(self, config: SchedulerConfig, kv_cache_manager: KvCacheManager) -> None: self.config = config + if self.config.prefix_cache_backend not in {"hash", "radix"}: + raise ValueError( + "prefix_cache_backend must be either 'hash' or 'radix', " + f"got {self.config.prefix_cache_backend!r}" + ) self.kv_cache_manager = kv_cache_manager self.waiting: deque[Request] = deque() self.running: list[Request] = [] self.requests: dict[str, Request] = {} + self._radix_nodes: dict[str, object] = {} + self._radix_debug = os.environ.get("PYPTO_RADIX_DEBUG", "").lower() not in {"", "0", "false", "no"} def add_request(self, request: Request) -> None: if len(request.prompt_token_ids) > self.config.max_seq_len: request.prompt_token_ids = request.prompt_token_ids[: self.config.max_seq_len] - if self.config.enable_prefix_cache: + if self.config.enable_prefix_cache and self.config.prefix_cache_backend == "hash": request.block_hashes = self.kv_cache_manager.compute_block_hashes(request.prompt_token_ids) request.status = RequestStatus.WAITING self.waiting.append(request) @@ -224,7 +234,10 @@ def schedule(self) -> SchedulerOutput: request = self.waiting.popleft() # Prefix cache lookup - if self.config.enable_prefix_cache: + if self.config.enable_prefix_cache and self.config.prefix_cache_backend == "radix": + self._match_radix_prefix(request) + cached_blocks = [] + elif self.config.enable_prefix_cache: cached_blocks = self.kv_cache_manager.get_computed_blocks(request.prompt_token_ids) if cached_blocks: request.cached_block_ids = [b.block_id for b in cached_blocks] @@ -238,12 +251,18 @@ def schedule(self) -> SchedulerOutput: num_new = min(num_new, token_budget) if num_new <= 0: + if self.config.prefix_cache_backend == "radix": + self.kv_cache_manager.release_pages_from_request(request.cached_block_ids) + self._release_radix_prefix(request) remaining_waiting.append(request) continue num_blocks_needed = self._blocks_needed(request, num_new) if not self._try_allocate_blocks(request, num_blocks_needed): self.kv_cache_manager.release_cached_blocks(cached_blocks) + if self.config.prefix_cache_backend == "radix": + self.kv_cache_manager.release_pages_from_request(request.cached_block_ids) + self._release_radix_prefix(request) request.cached_block_ids = [] request.num_computed_tokens = 0 remaining_waiting.append(request) @@ -261,6 +280,14 @@ def schedule(self) -> SchedulerOutput: block_ids=list(all_block_ids), ) ) + self._debug_radix( + "schedule", + request, + scheduled_prefill_tokens=num_new, + computed_tokens=request.num_computed_tokens, + cached_pages=len(request.cached_block_ids), + allocated_pages=len(request.allocated_block_ids), + ) output.num_prefill_tokens += num_new token_budget -= num_new @@ -395,6 +422,7 @@ def _free_request_blocks(self, request: Request) -> None: request.cached_block_ids, request.allocated_block_ids, ) + self._release_radix_prefix(request) request.cached_block_ids = [] request.allocated_block_ids = [] @@ -402,6 +430,20 @@ def _cache_completed_blocks(self, request: Request) -> None: """Register completed blocks in the prefix cache.""" if not self.config.enable_prefix_cache: return + if self.config.prefix_cache_backend == "radix": + all_block_ids = request.cached_block_ids + request.allocated_block_ids + self.kv_cache_manager.insert_radix_prefix( + request.model_id, + request.all_token_ids, + all_block_ids, + ) + self._debug_radix( + "insert", + request, + total_tokens=len(request.all_token_ids), + pages=len(all_block_ids), + ) + return total_blocks_computed = request.num_computed_tokens // self.kv_cache_manager.block_size already_cached = len(request.cached_block_ids) all_block_ids = request.cached_block_ids + request.allocated_block_ids @@ -411,3 +453,37 @@ def _cache_completed_blocks(self, request: Request) -> None: already_cached, total_blocks_computed, ) + + def _match_radix_prefix(self, request: Request) -> None: + """Attach a radix prefix hit to a waiting request.""" + max_match_len = max(0, request.num_prompt_tokens - 1) + match_tokens = request.prompt_token_ids[:max_match_len] + match = self.kv_cache_manager.match_radix_prefix(request.model_id, match_tokens) + self._debug_radix( + "match", + request, + prompt_tokens=request.num_prompt_tokens, + max_match_len=max_match_len, + matched_tokens=match.prefix_len, + matched_pages=len(match.page_ids), + ) + if match.prefix_len <= 0: + return + request.cached_block_ids = list(match.page_ids) + request.num_computed_tokens = match.prefix_len + self.kv_cache_manager.retain_pages_for_request(request.cached_block_ids) + self._radix_nodes[request.request_id] = match.last_node + + def _debug_radix(self, event: str, request: Request, **fields: object) -> None: + if not self._radix_debug or self.config.prefix_cache_backend != "radix": + return + details = " ".join(f"{key}={value}" for key, value in fields.items()) + print( + f"[radix] {event} request={request.request_id} model={request.model_id} {details}", + flush=True, + ) + + def _release_radix_prefix(self, request: Request) -> None: + node = self._radix_nodes.pop(request.request_id, None) + if node is not None: + self.kv_cache_manager.radix_cache.dec_lock_ref(node) diff --git a/tests/test_cli.py b/tests/test_cli.py index f88e5cb..a5ce3bf 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -46,6 +46,7 @@ def test_build_serving_engine_config_uses_cli_args(tmp_path, monkeypatch): "--max-num-batched-tokens", "256", "--long-prefill-token-threshold", "64", "--no-enable-prefix-caching", + "--prefix-cache-backend", "radix", "--no-enable-chunked-prefill", ]) @@ -70,9 +71,22 @@ def test_build_serving_engine_config_uses_cli_args(tmp_path, monkeypatch): assert config.max_num_scheduled_tokens == 256 assert config.long_prefill_token_threshold == 64 assert config.enable_prefix_cache is False + assert config.prefix_cache_backend == "radix" assert config.enable_chunk_prefill is False +def test_prefix_cache_backend_defaults_to_hash_even_with_env(tmp_path, monkeypatch): + model_dir = tmp_path / "model" + model_dir.mkdir() + monkeypatch.setenv("PYPTO_PREFIX_CACHE_BACKEND", "radix") + + args = _parse_args(["--model", str(model_dir), "--backend", "npu"]) + config = cli.build_serving_engine_config(args) + + assert args.prefix_cache_backend == "hash" + assert config.prefix_cache_backend == "hash" + + def test_parser_rejects_invalid_backend(tmp_path): model_dir = tmp_path / "model" model_dir.mkdir() diff --git a/tests/test_radix_cache.py b/tests/test_radix_cache.py new file mode 100644 index 0000000..1443321 --- /dev/null +++ b/tests/test_radix_cache.py @@ -0,0 +1,214 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- + +import pytest + +from python.core.radix_cache import RadixKey, RadixPrefixCache +from python.core.kv_cache import KvCacheManager +from python.core.scheduler import Request, Scheduler, SchedulerConfig + + +def _key(tokens: list[int], extra_key: tuple[str, ...] = ("model",)) -> RadixKey: + return RadixKey.from_tokens(tokens, extra_key=extra_key) + + +def test_radix_cache_matches_inserted_prefix(): + cache = RadixPrefixCache(page_size=2) + + result = cache.insert(_key([1, 2, 3, 4]), [10, 11]) + match = cache.match(_key([1, 2, 3, 4, 5, 6])) + + assert result.existing_prefix_len == 0 + assert result.inserted_len == 4 + assert match.prefix_len == 4 + assert match.page_ids == [10, 11] + + +def test_radix_cache_splits_node_on_page_aligned_partial_match(): + cache = RadixPrefixCache(page_size=2) + cache.insert(_key([1, 2, 3, 4, 5, 6]), [10, 11, 12]) + + match = cache.match(_key([1, 2, 3, 4, 9, 9])) + insert = cache.insert(_key([1, 2, 3, 4, 9, 9]), [10, 11, 13]) + branch_match = cache.match(_key([1, 2, 3, 4, 9, 9, 7, 7])) + original_match = cache.match(_key([1, 2, 3, 4, 5, 6])) + + assert match.prefix_len == 4 + assert match.page_ids == [10, 11] + assert insert.existing_prefix_len == 4 + assert insert.inserted_len == 2 + assert branch_match.prefix_len == 6 + assert branch_match.page_ids == [10, 11, 13] + assert original_match.prefix_len == 6 + assert original_match.page_ids == [10, 11, 12] + + +def test_radix_cache_page_aligns_keys_and_values(): + cache = RadixPrefixCache(page_size=4) + + result = cache.insert(_key([1, 2, 3, 4, 5, 6]), [20, 21]) + match = cache.match(_key([1, 2, 3, 4, 7, 8])) + + assert result.inserted_len == 4 + assert cache.total_pages() == 1 + assert match.prefix_len == 4 + assert match.page_ids == [20] + + +def test_radix_cache_extra_key_isolates_prefixes(): + cache = RadixPrefixCache(page_size=2) + + cache.insert(_key([1, 2, 3, 4], ("model-a",)), [10, 11]) + + assert cache.match(_key([1, 2, 3, 4], ("model-a",))).page_ids == [10, 11] + assert cache.match(_key([1, 2, 3, 4], ("model-b",))).page_ids == [] + + +def test_radix_cache_retain_and_release_callbacks(): + retained: list[int] = [] + released: list[int] = [] + cache = RadixPrefixCache( + page_size=2, + retain_pages=lambda page_ids: retained.extend(page_ids), + release_pages=lambda page_ids: released.extend(page_ids), + ) + + cache.insert(_key([1, 2, 3, 4]), [10, 11]) + match = cache.match(_key([1, 2, 3, 4])) + + assert retained == [10, 11] + assert cache.protected_pages() == 2 + assert cache.evict_pages(1) == 0 + + cache.dec_lock_ref(match.last_node) + + assert cache.evict_pages(1) == 2 + assert released == [10, 11] + assert cache.total_pages() == 0 + + +def test_radix_cache_rejects_missing_page_ids(): + cache = RadixPrefixCache(page_size=2) + + with pytest.raises(ValueError, match="need at least 2"): + cache.insert(_key([1, 2, 3, 4]), [10]) + + +def test_scheduler_uses_radix_prefix_for_suffix_prefill(): + manager = KvCacheManager(num_blocks=8, block_size=2) + cached_pages = manager.allocate_block_ids(2) + assert cached_pages is not None + manager.insert_radix_prefix("model", [1, 2, 3, 4], cached_pages) + manager.release_pages_from_request(cached_pages) + + scheduler = Scheduler( + SchedulerConfig( + max_num_running_reqs=1, + max_num_scheduled_tokens=16, + max_seq_len=8, + enable_prefix_cache=True, + prefix_cache_backend="radix", + ), + manager, + ) + request = Request( + request_id="req-0", + model_id="model", + prompt_token_ids=[1, 2, 3, 4, 5], + max_new_tokens=1, + ) + + scheduler.add_request(request) + scheduled = scheduler.schedule() + + assert len(scheduled.scheduled_requests) == 1 + sr = scheduled.scheduled_requests[0] + assert sr.num_computed_tokens == 4 + assert sr.num_new_tokens == 1 + assert sr.block_ids[:2] == cached_pages + assert request.cached_block_ids == cached_pages + assert len(request.allocated_block_ids) == 1 + assert manager.blocks[cached_pages[0]].ref_cnt == 2 + + scheduler.abort_request(request.request_id) + + assert manager.blocks[cached_pages[0]].ref_cnt == 1 + + +def test_scheduler_default_hash_backend_ignores_radix_cache(): + manager = KvCacheManager(num_blocks=8, block_size=2) + cached_pages = manager.allocate_block_ids(2) + assert cached_pages is not None + manager.insert_radix_prefix("model", [1, 2, 3, 4], cached_pages) + manager.release_pages_from_request(cached_pages) + + scheduler = Scheduler( + SchedulerConfig( + max_num_running_reqs=1, + max_num_scheduled_tokens=16, + max_seq_len=8, + enable_prefix_cache=True, + ), + manager, + ) + request = Request( + request_id="req-0", + model_id="model", + prompt_token_ids=[1, 2, 3, 4, 5], + max_new_tokens=1, + ) + + scheduler.add_request(request) + scheduled = scheduler.schedule() + + assert scheduler.config.prefix_cache_backend == "hash" + assert len(scheduled.scheduled_requests) == 1 + sr = scheduled.scheduled_requests[0] + assert sr.num_computed_tokens == 0 + assert sr.num_new_tokens == 5 + assert request.cached_block_ids == [] + assert sr.block_ids[:2] != cached_pages + + scheduler.abort_request(request.request_id) + + assert manager.blocks[cached_pages[0]].ref_cnt == 1 + + +def test_scheduler_releases_radix_match_when_suffix_allocation_fails(): + manager = KvCacheManager(num_blocks=2, block_size=2) + cached_pages = manager.allocate_block_ids(2) + assert cached_pages is not None + manager.insert_radix_prefix("model", [1, 2, 3, 4], cached_pages) + manager.release_pages_from_request(cached_pages) + + scheduler = Scheduler( + SchedulerConfig( + max_num_running_reqs=1, + max_num_scheduled_tokens=16, + max_seq_len=8, + enable_prefix_cache=True, + prefix_cache_backend="radix", + ), + manager, + ) + request = Request( + request_id="req-0", + model_id="model", + prompt_token_ids=[1, 2, 3, 4, 5], + max_new_tokens=1, + ) + + scheduler.add_request(request) + scheduled = scheduler.schedule() + + assert scheduled.is_empty + assert request.cached_block_ids == [] + assert request.num_computed_tokens == 0 + assert manager.blocks[cached_pages[0]].ref_cnt == 1 + assert manager.radix_cache.protected_pages() == 0 From 107db2d90296440f1cd78ebfc5de551e8d9a3a2e Mon Sep 17 00:00:00 2001 From: zmnobug Date: Thu, 18 Jun 2026 13:16:50 +0800 Subject: [PATCH 2/2] Address radix cache review feedback --- python/core/kv_cache.py | 2 + python/core/radix_cache.py | 42 ++++++++++------- python/core/scheduler.py | 16 ++++--- tests/test_radix_cache.py | 97 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 134 insertions(+), 23 deletions(-) diff --git a/python/core/kv_cache.py b/python/core/kv_cache.py index 3dbadae..00b9b95 100644 --- a/python/core/kv_cache.py +++ b/python/core/kv_cache.py @@ -209,6 +209,8 @@ def allocate_blocks(self, num_blocks: int) -> list[KVCacheBlock] | None: """Allocate physical KV blocks, evicting stale prefix hashes as needed.""" if num_blocks <= 0: return [] + if self.num_free_blocks < num_blocks: + self.radix_cache.evict_pages(num_blocks - self.num_free_blocks) if self.num_free_blocks < num_blocks: return None blocks: list[KVCacheBlock] = [] diff --git a/python/core/radix_cache.py b/python/core/radix_cache.py index 0850b40..1f58aaf 100644 --- a/python/core/radix_cache.py +++ b/python/core/radix_cache.py @@ -192,20 +192,13 @@ def evict_pages(self, min_pages: int) -> int: evicted = 0 while evicted < min_pages: - leaf = self._oldest_evictable_leaf() - if leaf is None: + leaves = self._evictable_leaves_by_lru() + if not leaves: break - parent = leaf.parent - if parent is None: - break - if leaf.page_ids and self._release_pages is not None: - self._release_pages(list(leaf.page_ids)) - evicted += len(leaf.page_ids) - parent.children.pop(self._child_key(leaf.extra_key, leaf.tokens), None) - leaf.parent = None - leaf.children.clear() - leaf.page_ids = [] - leaf.tokens = () + for leaf in leaves: + if evicted >= min_pages: + break + evicted += self._evict_leaf(leaf) return evicted def total_pages(self) -> int: @@ -256,11 +249,26 @@ def _split_node(self, child: RadixNode, split_len: int) -> RadixNode: new_node.children[self._child_key(child.extra_key, suffix_tokens)] = child return new_node - def _oldest_evictable_leaf(self) -> RadixNode | None: + def _evictable_leaves_by_lru(self) -> list[RadixNode]: leaves = [node for node in self._iter_nodes() if self._is_evictable_leaf(node)] - if not leaves: - return None - return min(leaves, key=lambda node: node.last_access_time) + leaves.sort(key=lambda node: node.last_access_time) + return leaves + + def _evict_leaf(self, leaf: RadixNode) -> int: + if not self._is_evictable_leaf(leaf): + return 0 + parent = leaf.parent + if parent is None: + return 0 + evicted = len(leaf.page_ids) + if leaf.page_ids and self._release_pages is not None: + self._release_pages(list(leaf.page_ids)) + parent.children.pop(self._child_key(leaf.extra_key, leaf.tokens), None) + leaf.parent = None + leaf.children.clear() + leaf.page_ids = [] + leaf.tokens = () + return evicted def _is_evictable_leaf(self, node: RadixNode) -> bool: return node is not self.root and node.lock_ref == 0 and not node.children and bool(node.page_ids) diff --git a/python/core/scheduler.py b/python/core/scheduler.py index 1f8d96d..79cdffb 100644 --- a/python/core/scheduler.py +++ b/python/core/scheduler.py @@ -369,8 +369,6 @@ def _blocks_needed(self, request: Request, num_new_tokens: int) -> int: def _try_allocate_blocks(self, request: Request, num_blocks: int) -> bool: if num_blocks <= 0: return True - if self.kv_cache_manager.num_free_blocks < num_blocks: - return False block_ids = self.kv_cache_manager.allocate_block_ids(num_blocks) if block_ids is None: return False @@ -432,16 +430,22 @@ def _cache_completed_blocks(self, request: Request) -> None: return if self.config.prefix_cache_backend == "radix": all_block_ids = request.cached_block_ids + request.allocated_block_ids + completed_len = min(request.num_computed_tokens, len(request.all_token_ids)) + completed_pages = min( + completed_len // self.kv_cache_manager.block_size, + len(all_block_ids), + ) + completed_len = completed_pages * self.kv_cache_manager.block_size self.kv_cache_manager.insert_radix_prefix( request.model_id, - request.all_token_ids, - all_block_ids, + request.all_token_ids[:completed_len], + all_block_ids[:completed_pages], ) self._debug_radix( "insert", request, - total_tokens=len(request.all_token_ids), - pages=len(all_block_ids), + total_tokens=completed_len, + pages=completed_pages, ) return total_blocks_computed = request.num_computed_tokens // self.kv_cache_manager.block_size diff --git a/tests/test_radix_cache.py b/tests/test_radix_cache.py index 1443321..071cc9b 100644 --- a/tests/test_radix_cache.py +++ b/tests/test_radix_cache.py @@ -93,6 +93,17 @@ def test_radix_cache_retain_and_release_callbacks(): assert cache.total_pages() == 0 +def test_radix_cache_evicts_parent_after_leaf_batch(): + released: list[int] = [] + cache = RadixPrefixCache(page_size=2, release_pages=lambda page_ids: released.extend(page_ids)) + cache.insert(_key([1, 2, 3, 4]), [10, 11]) + cache.insert(_key([1, 2, 5, 6]), [10, 12]) + + assert cache.evict_pages(3) == 3 + assert sorted(released) == [10, 11, 12] + assert cache.total_pages() == 0 + + def test_radix_cache_rejects_missing_page_ids(): cache = RadixPrefixCache(page_size=2) @@ -100,6 +111,23 @@ def test_radix_cache_rejects_missing_page_ids(): cache.insert(_key([1, 2, 3, 4]), [10]) +def test_radix_owned_pages_are_evicted_for_new_allocations(): + manager = KvCacheManager(num_blocks=2, block_size=2) + cached_pages = manager.allocate_block_ids(2) + assert cached_pages is not None + manager.insert_radix_prefix("model", [1, 2, 3, 4], cached_pages) + manager.release_pages_from_request(cached_pages) + + assert manager.num_free_blocks == 0 + + allocated = manager.allocate_block_ids(1) + + assert allocated is not None + assert len(allocated) == 1 + assert manager.blocks[allocated[0]].ref_cnt == 1 + assert manager.radix_cache.total_pages() == 0 + + def test_scheduler_uses_radix_prefix_for_suffix_prefill(): manager = KvCacheManager(num_blocks=8, block_size=2) cached_pages = manager.allocate_block_ids(2) @@ -141,6 +169,40 @@ def test_scheduler_uses_radix_prefix_for_suffix_prefill(): assert manager.blocks[cached_pages[0]].ref_cnt == 1 +def test_scheduler_allocation_can_evict_unlocked_radix_pages(): + manager = KvCacheManager(num_blocks=2, block_size=2) + cached_pages = manager.allocate_block_ids(2) + assert cached_pages is not None + manager.insert_radix_prefix("model", [1, 2, 3, 4], cached_pages) + manager.release_pages_from_request(cached_pages) + + scheduler = Scheduler( + SchedulerConfig( + max_num_running_reqs=1, + max_num_scheduled_tokens=16, + max_seq_len=8, + enable_prefix_cache=True, + prefix_cache_backend="radix", + ), + manager, + ) + request = Request( + request_id="req-0", + model_id="model", + prompt_token_ids=[9, 9], + max_new_tokens=1, + ) + + scheduler.add_request(request) + scheduled = scheduler.schedule() + + assert len(scheduled.scheduled_requests) == 1 + assert request.allocated_block_ids + assert manager.radix_cache.total_pages() == 0 + + scheduler.abort_request(request.request_id) + + def test_scheduler_default_hash_backend_ignores_radix_cache(): manager = KvCacheManager(num_blocks=8, block_size=2) cached_pages = manager.allocate_block_ids(2) @@ -180,6 +242,41 @@ def test_scheduler_default_hash_backend_ignores_radix_cache(): assert manager.blocks[cached_pages[0]].ref_cnt == 1 +def test_scheduler_radix_insert_uses_only_completed_full_pages_for_chunked_prefill(): + manager = KvCacheManager(num_blocks=8, block_size=2) + scheduler = Scheduler( + SchedulerConfig( + max_num_running_reqs=1, + max_num_scheduled_tokens=16, + long_prefill_token_threshold=3, + max_seq_len=8, + enable_prefix_cache=True, + enable_chunk_prefill=True, + prefix_cache_backend="radix", + ), + manager, + ) + request = Request( + request_id="req-0", + model_id="model", + prompt_token_ids=[1, 2, 3, 4, 5, 6], + max_new_tokens=1, + ) + + scheduler.add_request(request) + scheduled = scheduler.schedule() + scheduler.update_from_output(scheduled, {}) + + assert request.num_computed_tokens == 3 + assert manager.radix_cache.total_pages() == 1 + match = manager.radix_cache.match(_key([1, 2, 3, 4])) + assert match.prefix_len == 2 + assert len(match.page_ids) == 1 + + manager.radix_cache.dec_lock_ref(match.last_node) + scheduler.abort_request(request.request_id) + + def test_scheduler_releases_radix_match_when_suffix_allocation_fails(): manager = KvCacheManager(num_blocks=2, block_size=2) cached_pages = manager.allocate_block_ids(2)