Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ def build_parser() -> argparse.ArgumentParser:
default=True,
help="Enable prefix caching (default: True). Use --no-enable-prefix-caching to disable.",
)
parser.add_argument(
"--prefix-cache-backend",
default="hash",
choices=("hash", "radix"),
help="Prefix cache backend: hash or radix (default: hash).",
)
parser.add_argument(
"--enable-chunked-prefill",
action=argparse.BooleanOptionalAction,
Expand Down Expand Up @@ -114,6 +120,7 @@ def build_serving_engine_config(args: argparse.Namespace) -> EngineConfig:
long_prefill_token_threshold=args.long_prefill_token_threshold,
enable_prefix_cache=args.enable_prefix_caching,
enable_chunk_prefill=args.enable_chunked_prefill,
prefix_cache_backend=args.prefix_cache_backend,
)


Expand Down Expand Up @@ -188,6 +195,7 @@ async def shutdown():
print(f" Max scheduled tokens/iter: {config.max_num_scheduled_tokens}")
print(f" Chunked prefill threshold: {config.long_prefill_token_threshold}")
print(f" Prefix cache: {'enabled' if config.enable_prefix_cache else 'disabled'}")
print(f" Prefix cache backend: {config.prefix_cache_backend}")
print(f" Chunk prefill: {'enabled' if config.enable_chunk_prefill else 'disabled'}")
print(" Endpoints: /v1/completions, /v1/chat/completions, /v1/models, /health")

Expand Down
3 changes: 3 additions & 0 deletions python/core/async_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class EngineConfig:
# Feature flags
enable_prefix_cache: bool = True
enable_chunk_prefill: bool = True
prefix_cache_backend: str = "hash"


@dataclass
Expand Down Expand Up @@ -102,6 +103,7 @@ def __init__(
max_seq_len=runtime.max_seq_len,
enable_prefix_cache=self.config.enable_prefix_cache,
enable_chunk_prefill=self.config.enable_chunk_prefill,
prefix_cache_backend=self.config.prefix_cache_backend,
)
self.scheduler = Scheduler(config=scheduler_config, kv_cache_manager=self.kv_cache_manager)

Expand Down Expand Up @@ -178,6 +180,7 @@ async def add_request(
request_id=request_id,
prompt_token_ids=prompt_token_ids,
max_new_tokens=config.max_new_tokens,
model_id=self.config.model_id,
arrival_time=time.time(),
stop_strings=tuple(config.stop) if config.stop else (),
eos_token_id=self.eos_token_id,
Expand Down
53 changes: 51 additions & 2 deletions python/core/kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import torch

from .radix_cache import RadixKey, RadixMatch, RadixPrefixCache
from .types import KvAllocation, ModelConfig, RuntimeConfig


Expand Down Expand Up @@ -132,6 +133,7 @@ def __init__(
self.free_queue = FreeKVCacheBlockQueue()
self.hash_to_block: dict[int, KVCacheBlock] = {}
self.request_blocks: dict[str, list[KVCacheBlock]] = {}
self.radix_cache = self._new_radix_cache()
if num_blocks is not None:
self._init_blocks(num_blocks, block_size)

Expand All @@ -154,6 +156,7 @@ def _init_blocks(self, num_blocks: int, block_size: int) -> None:
self.blocks = [KVCacheBlock(block_id=i) for i in range(num_blocks)]
for block in self.blocks:
self.free_queue.append(block)
self.radix_cache = self._new_radix_cache()

def register_model(self, model_id: str, config: ModelConfig, runtime: RuntimeConfig) -> None:
"""Create the KV page pool for a model if it is not already registered."""
Expand Down Expand Up @@ -206,6 +209,8 @@ def allocate_blocks(self, num_blocks: int) -> list[KVCacheBlock] | None:
"""Allocate physical KV blocks, evicting stale prefix hashes as needed."""
if num_blocks <= 0:
return []
if self.num_free_blocks < num_blocks:
self.radix_cache.evict_pages(num_blocks - self.num_free_blocks)
if self.num_free_blocks < num_blocks:
return None
blocks: list[KVCacheBlock] = []
Expand All @@ -232,8 +237,27 @@ def allocate_block_ids(self, num_blocks: int) -> list[int] | None:
def release_blocks_by_ids(self, *block_id_groups: list[int]) -> None:
"""Release request references for one or more groups of physical block IDs."""
for block_ids in block_id_groups:
for block_id in block_ids:
self.release(self.blocks[block_id])
self.release_pages_from_request(block_ids)

def retain_pages_for_request(self, page_ids: list[int]) -> None:
"""Take one active-request reference on each physical KV page."""
for page_id in page_ids:
self.retain(self.blocks[page_id])

def release_pages_from_request(self, page_ids: list[int]) -> None:
"""Release one active-request reference from each physical KV page."""
for page_id in page_ids:
self.release(self.blocks[page_id])

def retain_pages_for_cache(self, page_ids: list[int]) -> None:
"""Take one prefix-cache reference on each physical KV page."""
for page_id in page_ids:
self.retain(self.blocks[page_id])

def release_pages_from_cache(self, page_ids: list[int]) -> None:
"""Release one prefix-cache reference from each physical KV page."""
for page_id in page_ids:
self.release(self.blocks[page_id])

def release_cached_blocks(self, blocks: list[KVCacheBlock]) -> None:
"""Release cached block objects returned by ``get_computed_blocks``."""
Expand Down Expand Up @@ -284,6 +308,12 @@ def release(self, block: KVCacheBlock) -> None:
if block.ref_cnt == 0:
self.free_queue.append(block)

def retain(self, block: KVCacheBlock) -> None:
"""Retain one reference to a physical KV block."""
if block.ref_cnt == 0:
self.free_queue.remove(block)
block.ref_cnt += 1

def _iter_block_hashes(self, token_ids: list[int]):
"""Yield (block_index, block_hash) for each full block in the token sequence."""
parent_hash = NONE_HASH
Expand All @@ -310,6 +340,18 @@ def compute_block_hashes(self, token_ids: list[int]) -> list[int]:
"""Compute chained hashes for all full blocks in the token sequence."""
return [block_hash for _, block_hash in self._iter_block_hashes(token_ids)]

def match_radix_prefix(self, model_id: str, token_ids: list[int]) -> RadixMatch:
"""Find and lock the longest page-aligned radix prefix for one model."""
if not self.enable_prefix_cache:
return self.radix_cache.match(RadixKey.from_tokens([], extra_key=(model_id,)))
return self.radix_cache.match(RadixKey.from_tokens(token_ids, extra_key=(model_id,)))

def insert_radix_prefix(self, model_id: str, token_ids: list[int], page_ids: list[int]) -> None:
"""Publish a page-aligned request prefix to the radix cache."""
if not self.enable_prefix_cache:
return
self.radix_cache.insert(RadixKey.from_tokens(token_ids, extra_key=(model_id,)), page_ids)
Comment on lines +349 to +353

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Missing Eviction Mechanism for Radix Cache

The radix cache implementation pins physical KV pages in memory by calling retain_pages_for_cache (which increments ref_cnt and removes blocks from self.free_queue) when prefixes are inserted. However, there is no eviction mechanism integrated into the block allocation path (allocate_blocks in KvCacheManager).

As a result, once blocks are cached in the radix tree, they are never evicted, leading to block starvation and eventually causing the serving engine to hang or fail to schedule new requests.

Remedy:
In KvCacheManager.allocate_blocks, if self.num_free_blocks < num_blocks, trigger eviction from the radix cache by calling self.radix_cache.evict_pages(num_blocks - self.num_free_blocks) to free up enough blocks.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved in 107db2d. KvCacheManager.allocate_blocks() now calls radix_cache.evict_pages() before returning allocation failure, and the scheduler allocation helper no longer short-circuits on num_free_blocks, so radix-owned unlocked pages can be reclaimed during scheduling. Added regression coverage in test_radix_owned_pages_are_evicted_for_new_allocations and test_scheduler_allocation_can_evict_unlocked_radix_pages.


def ensure_one_more_slot(self, alloc: KvAllocation) -> int:
"""Ensure a request has capacity for one more token and return its slot."""
pool = self._pool(alloc.model_id)
Expand Down Expand Up @@ -413,3 +455,10 @@ def _pool(self, model_id: str) -> _CachePool:
if model_id not in self._pools:
raise KeyError(f"Model {model_id} is not registered with the KV cache manager.")
return self._pools[model_id]

def _new_radix_cache(self) -> RadixPrefixCache:
return RadixPrefixCache(
self.block_size,
retain_pages=self.retain_pages_for_cache,
release_pages=self.release_pages_from_cache,
)
Loading