diff --git a/benchmarks/single_node/agentic/kimik2.5_fp4_mi355x.sh b/benchmarks/single_node/agentic/kimik2.5_fp4_mi355x.sh
index 1e716aa4e..eb2dab447 100755
--- a/benchmarks/single_node/agentic/kimik2.5_fp4_mi355x.sh
+++ b/benchmarks/single_node/agentic/kimik2.5_fp4_mi355x.sh
@@ -336,6 +336,290 @@ if os.environ.get("LMCACHE_ROCM_MP_BLOCK_FALLBACK") == "1":
         raise ValueError(f"Unsupported transfer direction: {direction}")
 
     lmc.multi_layer_block_kv_transfer = multi_layer_block_kv_transfer
+
+# ---- Chunked KV loading (prevents GPU block exhaustion at high concurrency) ----
+if os.environ.get("CHUNKED_LMCACHE_MAX_TOKENS_PER_LOAD", "0") != "0":
+    import chunked_connector_patch  # noqa: F401
+
+# ---- vLLM scheduler assertion fix (stale KV transfer notifications) ----
+import scheduler_assertion_patch  # noqa: F401
+PY
+}
+
+# Write chunked_connector_patch.py into the directory given as $1.
+write_chunked_connector_patch() {
+    local patch_dir="$1"
+    mkdir -p "$patch_dir"
+    cat > "$patch_dir/chunked_connector_patch.py" <<'PY'
+"""
+Monkey-patch for LMCacheMPConnector to add chunked KV loading.
+
+Fixes GPU block exhaustion deadlock at high concurrency by capping
+the number of external tokens reported AND retrieved per scheduling step.
+
+Usage: set CHUNKED_LMCACHE_MAX_TOKENS_PER_LOAD=N (N = token cap) and import
+this module from sitecustomize.py before LMCache is loaded.
+"""
+
+import logging
+import os
+import sys
+import builtins
+
+logger = logging.getLogger("chunked_lmcache_patch")
+
+_DEFAULT_MAX_TOKENS = 32768
+
+
+def _read_max_tokens():
+    """Return the per-load token cap configured in the environment.
+
+    A malformed value falls back to the default rather than raising: an
+    exception at import time would propagate into sitecustomize.py and
+    stop it before the imports that follow this module.
+    """
+    raw = os.environ.get(
+        "CHUNKED_LMCACHE_MAX_TOKENS_PER_LOAD", str(_DEFAULT_MAX_TOKENS)
+    )
+    try:
+        return int(raw)
+    except ValueError:
+        logger.warning(
+            "Invalid CHUNKED_LMCACHE_MAX_TOKENS_PER_LOAD=%r; using %d",
+            raw, _DEFAULT_MAX_TOKENS,
+        )
+        return _DEFAULT_MAX_TOKENS
+
+
+_MAX_TOKENS = _read_max_tokens()
+
+# Per-request chunk tracking (module-level, survives across calls)
+_chunk_state: dict[str, dict] = {}
+
+
+def _apply_patch():
+    """Patch LMCacheMPConnector in-place.
+
+    Does nothing when the connector module has not been imported yet, when
+    the class is not defined on it, or when the patch is already applied.
+    """
+    mod = sys.modules.get("lmcache.integration.vllm.lmcache_mp_connector")
+    if mod is None:
+        return
+    cls = getattr(mod, "LMCacheMPConnector", None)
+    if cls is None or getattr(cls, "_chunked_patch_applied", False):
+        return
+
+    _orig_get_matched = cls.get_num_new_matched_tokens
+    _orig_get_finished = cls.get_finished
+
+    def _get_blocks_per_chunk(self):
+        # At least one block per chunk, even when the cap is below block_size.
+        block_size = getattr(self, "block_size", 1)
+        return max(1, _MAX_TOKENS // block_size)
+
+    def _patched_get_num_new_matched_tokens(self, request, num_computed_tokens):
+        result = _orig_get_matched(self, request, num_computed_tokens)
+        # NOTE(review): vLLM's KVConnectorBase_V1 documents this hook as
+        # returning (num_tokens_or_None, load_async). Accept that shape as
+        # well as a bare int so the comparisons below cannot raise
+        # TypeError -- confirm against the installed LMCache version.
+        is_pair = isinstance(result, tuple) and len(result) == 2
+        full_match = result[0] if is_pair else result
+        if full_match is None or full_match <= 0 or _MAX_TOKENS <= 0:
+            return result
+
+        def _reply(num_tokens):
+            # Answer in the same shape the wrapped method used.
+            if not is_pair:
+                return num_tokens
+            # Keep the async flag only when tokens are actually reported.
+            return num_tokens, bool(result[1]) and num_tokens > 0
+
+        req_id = request.request_id
+        block_size = getattr(self, "block_size", 1)
+        blocks_per_chunk = _get_blocks_per_chunk(self)
+        full_match_blocks = full_match // block_size
+
+        # Start fresh state for a new request, or when the request is
+        # looked up again from a different num_computed_tokens.
+        state = _chunk_state.get(req_id)
+        if state is None or state.get("num_computed_at_start") != num_computed_tokens:
+            state = {
+                "full_match_blocks": full_match_blocks,
+                "chunk_end_blocks": 0,
+                "num_computed_at_start": num_computed_tokens,
+                "lookup_done": False,
+            }
+            _chunk_state[req_id] = state
+
+        if state["lookup_done"]:
+            return _reply(0)
+
+        remaining = state["full_match_blocks"] - state["chunk_end_blocks"]
+        if remaining <= 0:
+            state["lookup_done"] = True
+            return _reply(0)
+
+        this_chunk = min(remaining, blocks_per_chunk)
+        state["chunk_end_blocks"] += this_chunk
+        if state["chunk_end_blocks"] >= state["full_match_blocks"]:
+            state["lookup_done"] = True
+
+        capped = this_chunk * block_size
+        if capped < full_match:
+            logger.debug(
+                "Chunked LMCache: req %s capped %d -> %d tokens "
+                "(chunk %d/%d blocks)",
+                req_id, full_match, capped, this_chunk, full_match_blocks,
+            )
+
+        # Cap the tracker's hit blocks to match what we report
+        tracker = getattr(request, "kv_transfer_params", None)
+        if tracker is not None:
+            orig_hits = getattr(tracker, "num_lmcache_hit_blocks", 0)
+            if orig_hits > this_chunk:
+                tracker.num_lmcache_hit_blocks = this_chunk
+
+        return _reply(capped)
+
+    def _patched_get_finished(self, scheduler_output):
+        result = _orig_get_finished(self, scheduler_output)
+        # Clean up chunk state for finished requests.
+        # vLLM passes scheduler_output as a set of request-ID strings
+        # (not a SchedulerOutput object), so iterate directly when it
+        # is a set/frozenset; fall back to the attribute path for
+        # forward compatibility.
+        if isinstance(scheduler_output, (set, frozenset)):
+            finished = scheduler_output
+        else:
+            finished = getattr(scheduler_output, "finished_req_ids", [])
+        for req in finished:
+            _chunk_state.pop(req, None)
+        return result
+
+    cls.get_num_new_matched_tokens = _patched_get_num_new_matched_tokens
+    cls.get_finished = _patched_get_finished
+    cls._chunked_patch_applied = True
+    logger.info(
+        "Chunked LMCache connector patch applied "
+        "(max_tokens_per_load=%d)", _MAX_TOKENS,
+    )
+
+
+_orig_import = builtins.__import__
+
+
+def _patching_import(name, *args, **kwargs):
+    """Delegate to the real importer, then retry the patch once the
+    connector module is present in sys.modules."""
+    module = _orig_import(name, *args, **kwargs)
+    if (
+        name == "lmcache.integration.vllm.lmcache_mp_connector"
+        or (
+            name.startswith("lmcache")
+            and "lmcache.integration.vllm.lmcache_mp_connector" in sys.modules
+        )
+    ):
+        _apply_patch()
+    return module
+
+
+builtins.__import__ = _patching_import
+_apply_patch()
+PY
+}
+
+# Write scheduler_assertion_patch.py into the directory given as $1.
+write_scheduler_assertion_patch() {
+    local patch_dir="$1"
+    mkdir -p "$patch_dir"
+    cat > "$patch_dir/scheduler_assertion_patch.py" <<'PY'
+"""
+Patch vLLM scheduler to handle stale finished_recving gracefully.
+
+The assertion at scheduler.py crashes when a KV transfer reports
+"finished recving" but the request is already in RUNNING state.
+This happens when transfers complete asynchronously and the scheduler
+has already moved the request forward.
+
+Fix: Instead of asserting, log a warning and skip.
+"""
+
+import logging
+import sys
+import builtins
+
+logger = logging.getLogger("scheduler_assertion_patch")
+
+
+def _apply_patch():
+    """Patch vLLM scheduler's _update_from_kv_xfer_finished.
+
+    Does nothing until both vllm.v1.core.sched.scheduler and
+    vllm.v1.request are in sys.modules, or once already applied.
+    """
+    sched_mod = sys.modules.get("vllm.v1.core.sched.scheduler")
+    if sched_mod is None:
+        return
+    req_mod = sys.modules.get("vllm.v1.request")
+    if req_mod is None:
+        return
+    Scheduler = getattr(sched_mod, "Scheduler", None)
+    RequestStatus = getattr(req_mod, "RequestStatus", None)
+    if Scheduler is None or RequestStatus is None:
+        return
+    if getattr(Scheduler, "_kv_xfer_patch_applied", False):
+        return
+
+    def _patched_update(self, kv_connector_output):
+        # NOTE(review): this replaces the upstream method wholesale rather
+        # than wrapping it; re-check against upstream when bumping vLLM.
+        if self.connector is not None:
+            self.connector.update_connector_output(kv_connector_output)
+        for req_id in kv_connector_output.finished_recving or ():
+            if req_id not in self.requests:
+                continue
+            req = self.requests[req_id]
+            if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
+                self.finished_recving_kv_req_ids.add(req_id)
+            elif RequestStatus.is_finished(req.status):
+                self._free_blocks(req)
+            else:
+                logger.warning(
+                    "Stale finished_recving for req %s in status %s; skipping.",
+                    req_id, req.status.name,
+                )
+        for req_id in kv_connector_output.finished_sending or ():
+            if req_id not in self.requests:
+                continue
+            self._free_blocks(self.requests[req_id])
+
+    Scheduler._update_from_kv_xfer_finished = _patched_update
+    Scheduler._kv_xfer_patch_applied = True
+    logger.info("Scheduler KV transfer assertion patch applied")
+
+
+_orig_import = builtins.__import__
+
+
+def _patching_import(name, *args, **kwargs):
+    """Delegate to the real importer, then retry the patch once the
+    scheduler module is present in sys.modules."""
+    module = _orig_import(name, *args, **kwargs)
+    if (
+        name == "vllm.v1.core.sched.scheduler"
+        or (
+            name.startswith("vllm")
+            and "vllm.v1.core.sched.scheduler" in sys.modules
+        )
+    ):
+        _apply_patch()
+    return module
+
+
+builtins.__import__ = _patching_import
+_apply_patch()
 PY
 }
 
@@ -481,9 +765,16 @@ if not getattr(cupy_runtime, "is_hip", False):
 PY
     LMCACHE_ROCM_PATCH_DIR="$RESULT_DIR/lmcache_rocm_patch"
     write_lmcache_rocm_mp_patch "$LMCACHE_ROCM_PATCH_DIR"
+    write_chunked_connector_patch "$LMCACHE_ROCM_PATCH_DIR"
+    write_scheduler_assertion_patch "$LMCACHE_ROCM_PATCH_DIR"
     export LMCACHE_ROCM_MP_BLOCK_FALLBACK=1
     export LMCACHE_ROCM_MP_BLOCK_FALLBACK_DTYPE=bfloat16
     export LMCACHE_ROCM_DEMAND_PINNED_ALLOCATOR=1
+    # Cap external KV tokens loaded per scheduling step to prevent GPU
+    # block exhaustion deadlock at high concurrency (c>=32). Default
+    # 32768 keeps peak block demand within the GPU KV pool. Set to 0 to
+    # disable chunking (only safe at low concurrency).
+    export CHUNKED_LMCACHE_MAX_TOKENS_PER_LOAD="${CHUNKED_LMCACHE_MAX_TOKENS_PER_LOAD:-32768}"
     export PYTHONPATH="$LMCACHE_ROCM_PATCH_DIR${PYTHONPATH:+:$PYTHONPATH}"
     python3 -c "import lmcache.integration.vllm.lmcache_mp_connector" >/dev/null
 