
fix(vllm): prevent EngineCore crash when shared physical KV pool is exhausted (fix #262)#301

Open
mahendrarathore1742 wants to merge 2 commits into
ovg-project:mainfrom
mahendrarathore1742:fix/issue-262-enginecore-kvpool-exhaustion

Conversation

@mahendrarathore1742

Summary

This PR fixes a crash in ElasticBlockPool.get_new_blocks() under multi-instance load, triggered when the shared physical KV pool is exhausted between the scheduler's availability check and the allocation call.

Previously, kv_cache_manager.alloc() could return None, and the code asserted on that result, causing a fatal AssertionError and crashing EngineCore.

Root Cause

available_size() and alloc() are not atomic across multiple instances sharing the same physical pool.
A race can occur where one instance consumes physical pages after another instance checks capacity but before it calls alloc().
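This is a classic check-then-act (TOCTOU) race. A minimal sketch of the failure mode, using a toy pool rather than the real kvcached API (all names here are illustrative):

```python
import threading

class SharedPool:
    """Toy stand-in for a shared physical KV pool (hypothetical)."""

    def __init__(self, pages):
        self.pages = pages
        self._lock = threading.Lock()

    def available_size(self):
        return self.pages

    def alloc(self, n):
        # The allocation itself is atomic, but nothing ties it to an
        # earlier available_size() check made by another instance.
        with self._lock:
            if self.pages < n:
                return None  # pool exhausted between check and alloc
            self.pages -= n
            return list(range(n))

pool = SharedPool(pages=4)

# Instance A checks capacity and sees 4 free pages...
assert pool.available_size() >= 4
# ...but instance B allocates first, draining the pool.
assert pool.alloc(4) is not None
# Instance A's alloc now returns None despite its earlier check.
assert pool.alloc(4) is None
```

With the old code, that final None tripped an assert and took down EngineCore; the fix turns it into a recoverable error.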

What Changed

  • Updated ElasticBlockPool.get_new_blocks() in kvcached/integration/vllm/patches.py:
    • Replaced fatal assert with graceful handling when alloc() returns None.
    • Added one retry path after opportunistic eviction from _evictable_blocks (when prefix cache is enabled).
    • If allocation still fails, raise ValueError with clear context (requested, available) instead of crashing.
    • Added a defensive check for unexpected partial allocation length.
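The changes above can be sketched as follows. This is a simplified illustration with hypothetical helper names, not the actual patch in kvcached/integration/vllm/patches.py:

```python
def get_new_blocks_sketch(manager, num_blocks, evict_from_prefix_cache=None):
    """Sketch of graceful handling when alloc() returns None.

    Assumes `manager` exposes alloc(n) -> list | None and
    available_size(), and that `evict_from_prefix_cache(n)` frees up
    to n evictable blocks when prefix caching is enabled. All of
    these signatures are illustrative.
    """
    block_ids = manager.alloc(num_blocks)
    if block_ids is None and evict_from_prefix_cache is not None:
        # One bounded retry after opportunistic eviction.
        evict_from_prefix_cache(num_blocks)
        block_ids = manager.alloc(num_blocks)
    if block_ids is None:
        # Controlled failure with context instead of a fatal assert.
        raise ValueError(
            f"KV pool exhausted: requested={num_blocks}, "
            f"available={manager.available_size()}")
    if len(block_ids) != num_blocks:
        # Defensive check for an unexpected partial allocation.
        raise ValueError(
            f"unexpected allocation length: requested={num_blocks}, "
            f"got={len(block_ids)}")
    return block_ids
```

The key property is that the ValueError propagates to the caller as a per-request failure rather than killing the engine process.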

Why This Is Safe

  • Keeps existing allocation flow and prefix-cache behavior intact.
  • Only changes failure-mode handling from hard crash to controlled error.
  • The retry path is bounded to a single retry, so it cannot loop indefinitely.

Tests

Added regression tests in tests/test_prefix_cache.py:

  • test_get_new_blocks_raises_when_alloc_returns_none
  • test_get_new_blocks_retries_after_eviction_when_alloc_returns_none

Result

  • ✅ Both new regression tests pass (2 passed).

Issue

Closes #262

- replace fatal assert in ElasticBlockPool.get_new_blocks with graceful ValueError when alloc returns None

- retry allocation after opportunistic eviction to mitigate cross-instance race

- add regression tests for alloc(None) failure and retry path
Copilot AI review requested due to automatic review settings April 8, 2026 10:50
Contributor

Copilot AI left a comment


Pull request overview

This PR hardens the vLLM ElasticBlockPool integration against a multi-instance race where the shared physical KV pool can be exhausted between an availability check and the actual allocation, preventing an AssertionError crash by handling alloc() failures gracefully.

Changes:

  • Replace a fatal assert on kv_cache_manager.alloc() with a bounded retry (with optional eviction) and a contextual ValueError on failure.
  • Add defensive validation for unexpected allocation lengths.
  • Add regression tests covering alloc() returning None and the retry path.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

File | Description
kvcached/integration/vllm/patches.py | Adds retry + error handling in ElasticBlockPool.get_new_blocks() when physical-pool allocation fails.
tests/test_prefix_cache.py | Adds regression tests to ensure alloc() returning None no longer crashes and triggers the retry behavior.


Comment on lines +309 to +313
if len(block_ids) != num_blocks:
    raise ValueError(
        "KV cache manager returned an unexpected number of blocks; "
        f"requested={num_blocks}, got={len(block_ids)}"
    )

Copilot AI Apr 8, 2026


If kv_cache_manager.alloc() ever returns a non-None list with len(block_ids) != num_blocks, this code raises without freeing the partially allocated block_ids, which would leak blocks in the manager. Consider freeing block_ids before raising (or converting this to an assertion if partial allocation is truly impossible).
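The fix the reviewer suggests can be sketched like this, assuming the manager exposes a free() that takes block ids (a hypothetical signature, used here only for illustration):

```python
def checked_alloc(manager, num_blocks):
    """Allocate exactly num_blocks, None, or raise without leaking.

    Assumes manager.alloc(n) -> list | None and manager.free(ids);
    both signatures are illustrative. A None result is passed
    through for the caller's retry/error handling.
    """
    block_ids = manager.alloc(num_blocks)
    if block_ids is not None and len(block_ids) != num_blocks:
        # Return the partial allocation to the manager before
        # failing, so the blocks are not leaked.
        manager.free(block_ids)
        raise ValueError(
            f"unexpected allocation length: requested={num_blocks}, "
            f"got={len(block_ids)}")
    return block_ids
```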

Comment thread kvcached/integration/vllm/patches.py Outdated
Comment on lines +289 to +302
if kvcached_free < num_blocks:
    self._evict_blocks_from_pool(num_blocks - kvcached_free)
block_ids = self.kv_cache_manager.alloc(num_blocks)


Copilot AI Apr 8, 2026


The eviction+alloc logic is duplicated (initial eviction block and the retry-on-None path repeat the same available_size()/_evict_blocks_from_pool() sequence). Consider factoring this into a small helper or a bounded loop over 2 attempts to reduce duplication and keep the eviction policy consistent in one place.
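The bounded-loop refactor the reviewer describes could look roughly like this (a sketch with illustrative stand-ins for kv_cache_manager and _evict_blocks_from_pool, not the merged code):

```python
def alloc_with_eviction(manager, evict, num_blocks, attempts=2):
    """Bounded alloc loop: evict when short, retry at most `attempts` times.

    `manager` needs available_size() and alloc(n) -> list | None;
    `evict(n)` frees up to n blocks. Both are hypothetical stand-ins
    that keep the eviction policy in one place.
    """
    for _ in range(attempts):
        free = manager.available_size()
        if free < num_blocks:
            evict(num_blocks - free)
        block_ids = manager.alloc(num_blocks)
        if block_ids is not None:
            return block_ids
    return None  # caller raises with context
```

A fixed attempt count keeps the behavior equivalent to the initial-attempt-plus-one-retry structure while removing the duplicated eviction sequence.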

Comment thread tests/test_prefix_cache.py Outdated
def test_get_new_blocks_retries_after_eviction_when_alloc_returns_none(self, pool_factory):
    pool, _ = pool_factory(5)

    blocks = _simulate_request(pool, [b"h0", b"h1", b"h2", b"h3"])

Copilot AI Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test uses byte hashes (b"h0", etc.) while most other tests in this file use string hashes (e.g. "h0"). To keep the test suite consistent and avoid type-related surprises, consider standardizing on one hash type across the file (ideally matching what ElasticBlockPool._make_cache_key expects).

Suggested change
blocks = _simulate_request(pool, [b"h0", b"h1", b"h2", b"h3"])
blocks = _simulate_request(pool, ["h0", "h1", "h2", "h3"])

…tions

- replace duplicated eviction/alloc sequence with bounded two-attempt loop

- free partial block_ids before raising on allocation size mismatch

- support str hash inputs in cache key and add regression coverage
@mahendrarathore1742
Author

Hi @jiarong0907 @cui36 @ztang2370 — could you review this when you get a chance? Thanks!

@cui36
Collaborator

cui36 commented May 12, 2026

@mahendrarathore1742 Thanks for the PR. Reviewing soon!



Development

Successfully merging this pull request may close these issues.

EngineCore crashes when physical KV pool is exhausted under multi-instance load
