Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 239 additions & 0 deletions benchmarks/single_node/agentic/kimik2.5_fp4_mi355x.sh
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,238 @@ if os.environ.get("LMCACHE_ROCM_MP_BLOCK_FALLBACK") == "1":
raise ValueError(f"Unsupported transfer direction: {direction}")

lmc.multi_layer_block_kv_transfer = multi_layer_block_kv_transfer

# ---- Chunked KV loading (prevents GPU block exhaustion at high concurrency) ----
if os.environ.get("CHUNKED_LMCACHE_MAX_TOKENS_PER_LOAD", "0") != "0":
import chunked_connector_patch # noqa: F401

# ---- vLLM scheduler assertion fix (stale KV transfer notifications) ----
import scheduler_assertion_patch # noqa: F401
PY
}

write_chunked_connector_patch() {
local patch_dir="$1"
mkdir -p "$patch_dir"
cat > "$patch_dir/chunked_connector_patch.py" <<'PY'
"""
Monkey-patch for LMCacheMPConnector to add chunked KV loading.

Fixes GPU block exhaustion deadlock at high concurrency by capping
the number of external tokens reported AND retrieved per scheduling step.

Usage: set CHUNKED_LMCACHE_MAX_TOKENS_PER_LOAD=<tokens> and import this
module from sitecustomize.py before LMCache is loaded.
"""

import logging
import os
import sys
import builtins

logger = logging.getLogger("chunked_lmcache_patch")

_MAX_TOKENS = int(os.environ.get("CHUNKED_LMCACHE_MAX_TOKENS_PER_LOAD", "32768"))

# Per-request chunk tracking (module-level, survives across calls)
_chunk_state: dict[str, dict] = {}


def _apply_patch():
"""Patch LMCacheMPConnector in-place."""
mod = sys.modules.get("lmcache.integration.vllm.lmcache_mp_connector")
if mod is None:
return
cls = getattr(mod, "LMCacheMPConnector", None)
if cls is None or getattr(cls, "_chunked_patch_applied", False):
return

LMCacheMPRequestState = getattr(mod, "LMCacheMPRequestState", None)
_orig_get_matched = cls.get_num_new_matched_tokens
_orig_get_finished = cls.get_finished

def _get_blocks_per_chunk(self):
block_size = getattr(self, "block_size", 1)
return max(1, _MAX_TOKENS // block_size)

def _patched_get_num_new_matched_tokens(self, request, num_computed_tokens):
full_match = _orig_get_matched(self, request, num_computed_tokens)
if full_match <= 0 or _MAX_TOKENS <= 0:
return full_match

req_id = request.request_id
block_size = getattr(self, "block_size", 1)
blocks_per_chunk = _get_blocks_per_chunk(self)
full_match_blocks = full_match // block_size

state = _chunk_state.get(req_id)
if state is None or state.get("num_computed_at_start") != num_computed_tokens:
state = {
"full_match_blocks": full_match_blocks,
"chunk_end_blocks": 0,
"num_computed_at_start": num_computed_tokens,
"lookup_done": False,
}
_chunk_state[req_id] = state

if state["lookup_done"]:
return 0

remaining = state["full_match_blocks"] - state["chunk_end_blocks"]
if remaining <= 0:
state["lookup_done"] = True
return 0

this_chunk = min(remaining, blocks_per_chunk)
state["chunk_end_blocks"] += this_chunk
if state["chunk_end_blocks"] >= state["full_match_blocks"]:
state["lookup_done"] = True

capped = this_chunk * block_size
if capped < full_match:
logger.debug(
"Chunked LMCache: req %s capped %d -> %d tokens "
"(chunk %d/%d blocks)",
req_id, full_match, capped, this_chunk, full_match_blocks,
)

# Cap the tracker's hit blocks to match what we report
tracker = getattr(request, "kv_transfer_params", None)
if tracker is not None:
orig_hits = getattr(tracker, "num_lmcache_hit_blocks", 0)
if orig_hits > this_chunk:
tracker.num_lmcache_hit_blocks = this_chunk

return capped

def _patched_get_finished(self, scheduler_output):
result = _orig_get_finished(self, scheduler_output)
# Clean up chunk state for finished requests.
# vLLM passes scheduler_output as a set of request-ID strings
# (not a SchedulerOutput object), so iterate directly when it
# is a set/frozenset; fall back to the attribute path for
# forward compatibility.
if isinstance(scheduler_output, (set, frozenset)):
finished = scheduler_output
else:
finished = getattr(scheduler_output, "finished_req_ids", [])
for req in finished:
_chunk_state.pop(req, None)
Comment thread
cursor[bot] marked this conversation as resolved.
return result

cls.get_num_new_matched_tokens = _patched_get_num_new_matched_tokens
cls.get_finished = _patched_get_finished
cls._chunked_patch_applied = True
logger.info(
"Chunked LMCache connector patch applied "
"(max_tokens_per_load=%d)", _MAX_TOKENS,
)


_orig_import = builtins.__import__


def _patching_import(name, *args, **kwargs):
module = _orig_import(name, *args, **kwargs)
if (
name == "lmcache.integration.vllm.lmcache_mp_connector"
or (
name.startswith("lmcache")
and "lmcache.integration.vllm.lmcache_mp_connector" in sys.modules
)
):
_apply_patch()
return module


builtins.__import__ = _patching_import
_apply_patch()
PY
}

write_scheduler_assertion_patch() {
local patch_dir="$1"
mkdir -p "$patch_dir"
cat > "$patch_dir/scheduler_assertion_patch.py" <<'PY'
"""
Patch vLLM scheduler to handle stale finished_recving gracefully.

The assertion at scheduler.py crashes when a KV transfer reports
"finished recving" but the request is already in RUNNING state.
This happens when transfers complete asynchronously and the scheduler
has already moved the request forward.

Fix: Instead of asserting, log a warning and skip.
"""

import logging
import sys
import builtins

logger = logging.getLogger("scheduler_assertion_patch")


def _apply_patch():
"""Patch vLLM scheduler's _update_from_kv_xfer_finished."""
sched_mod = sys.modules.get("vllm.v1.core.sched.scheduler")
if sched_mod is None:
return
req_mod = sys.modules.get("vllm.v1.request")
if req_mod is None:
return
Scheduler = getattr(sched_mod, "Scheduler", None)
RequestStatus = getattr(req_mod, "RequestStatus", None)
if Scheduler is None or RequestStatus is None:
return
if getattr(Scheduler, "_kv_xfer_patch_applied", False):
return

_orig_update = Scheduler._update_from_kv_xfer_finished

def _patched_update(self, kv_connector_output):
if self.connector is not None:
self.connector.update_connector_output(kv_connector_output)
for req_id in kv_connector_output.finished_recving or ():
if req_id not in self.requests:
continue
req = self.requests[req_id]
if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
self.finished_recving_kv_req_ids.add(req_id)
elif RequestStatus.is_finished(req.status):
self._free_blocks(self.requests[req_id])
else:
logger.warning(
"Stale finished_recving for req %s in status %s; skipping.",
req_id, req.status.name,
)
for req_id in kv_connector_output.finished_sending or ():
if req_id not in self.requests:
continue
self._free_blocks(self.requests[req_id])

Scheduler._update_from_kv_xfer_finished = _patched_update
Scheduler._kv_xfer_patch_applied = True
logger.info("Scheduler KV transfer assertion patch applied")


_orig_import = builtins.__import__


def _patching_import(name, *args, **kwargs):
module = _orig_import(name, *args, **kwargs)
if (
name == "vllm.v1.core.sched.scheduler"
or (
name.startswith("vllm")
and "vllm.v1.core.sched.scheduler" in sys.modules
)
):
_apply_patch()
return module


builtins.__import__ = _patching_import
_apply_patch()
PY
}

Expand Down Expand Up @@ -481,9 +713,16 @@ if not getattr(cupy_runtime, "is_hip", False):
PY
LMCACHE_ROCM_PATCH_DIR="$RESULT_DIR/lmcache_rocm_patch"
write_lmcache_rocm_mp_patch "$LMCACHE_ROCM_PATCH_DIR"
write_chunked_connector_patch "$LMCACHE_ROCM_PATCH_DIR"
write_scheduler_assertion_patch "$LMCACHE_ROCM_PATCH_DIR"
export LMCACHE_ROCM_MP_BLOCK_FALLBACK=1
export LMCACHE_ROCM_MP_BLOCK_FALLBACK_DTYPE=bfloat16
export LMCACHE_ROCM_DEMAND_PINNED_ALLOCATOR=1
# Cap external KV tokens loaded per scheduling step to prevent GPU
# block exhaustion deadlock at high concurrency (c>=32). Default
# 32768 keeps peak block demand within the GPU KV pool. Set to 0 to
# disable chunking (only safe at low concurrency).
export CHUNKED_LMCACHE_MAX_TOKENS_PER_LOAD="${CHUNKED_LMCACHE_MAX_TOKENS_PER_LOAD:-32768}"
export PYTHONPATH="$LMCACHE_ROCM_PATCH_DIR${PYTHONPATH:+:$PYTHONPATH}"
python3 -c "import lmcache.integration.vllm.lmcache_mp_connector" >/dev/null

Expand Down