From dab34008572fb05e7b52f3c9947d1e9a80693dc9 Mon Sep 17 00:00:00 2001 From: Dhinesh Ponnarasan <160256912+DhineshPonnarasan@users.noreply.github.com> Date: Wed, 10 Jun 2026 11:47:11 -0400 Subject: [PATCH 1/5] [None][test] Add MLA chunked-prefill SM dispatch regression coverage (#13904) Signed-off-by: Dhinesh Ponnarasan --- .../_torch/attention/test_attention_mla.py | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/tests/unittest/_torch/attention/test_attention_mla.py b/tests/unittest/_torch/attention/test_attention_mla.py index bd01ed363dc4..7d0d118b6542 100644 --- a/tests/unittest/_torch/attention/test_attention_mla.py +++ b/tests/unittest/_torch/attention/test_attention_mla.py @@ -379,6 +379,78 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: } +@pytest.mark.parametrize( + "sm_version,expected_path", + [ + (90, "cached_kv"), + (99, "cached_kv"), + (100, "chunked_prefill"), + ], +) +def test_mla_chunked_prefill_dispatch_by_sm(sm_version, expected_path, + monkeypatch): + import tensorrt_llm._torch.modules.attention as attention_module + + class FakeTrtllmAttention: + + @staticmethod + def has_cached_kv_for_mla_context_warmup(_metadata): + return False + + @staticmethod + def is_chunked_prefill_for_mla_context(_metadata): + return True + + @staticmethod + def has_cached_kv_for_mla_context(_metadata): + return False + + class FakeMetadata: + pass + + class FakeAttention: + + def __init__(self): + self.mha = FakeTrtllmAttention() + + @staticmethod + def forward_context_with_chunked_prefill(*_args, **_kwargs): + return "chunked_prefill" + + @staticmethod + def forward_context_with_cached_kv(*_args, **_kwargs): + return "cached_kv" + + @staticmethod + def forward_context_default(*_args, **_kwargs): + return "default" + + monkeypatch.setattr(attention_module, "TrtllmAttention", + FakeTrtllmAttention) + monkeypatch.setattr(attention_module, "TrtllmAttentionMetadata", + FakeMetadata) + monkeypatch.setattr(attention_module, "get_sm_version", lambda: sm_version) + + q = torch.empty((1, 8), dtype=torch.float16) + compressed_kv = torch.empty((1, 4), dtype=torch.float16) + k_pe = torch.empty((1, 4), dtype=torch.float16) + position_ids = torch.zeros((1, ), dtype=torch.int64) + output = torch.empty((1, 8), dtype=torch.float16) + latent_cache = torch.empty((1, 1, 8), dtype=torch.float16) + + result = attention_module.MLA.forward_context( + FakeAttention(), + q, + compressed_kv, + k_pe, + position_ids, + FakeMetadata(), + output, + latent_cache, + ) + assert result == expected_path + + # Convert parameterized tests to pytest parametrize @pytest.mark.parametrize("scenario", scenarios, ids=lambda x: f"scenario: {x}") @pytest.mark.parametrize("context_sequence_lengths", From 0be1447c71c7bdbceac6060984dfd2cfd7cc36e3 Mon Sep 17 00:00:00 2001 From: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com> Date: Wed, 10 Jun 2026 09:06:47 -0700 Subject: [PATCH 2/5] [TRTLLM-12648][test] enable disagg cancellation stress test (#15174) Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com> --- .../defs/stress_test/disagg_cancel/README.md | 173 ++++-- .../disagg_cancel/configs/README.md | 22 +- .../configs/marathon_cpp_v1_deepseek.yaml | 25 +- .../configs/marathon_python_v2_qwen.yaml | 5 + .../defs/stress_test/disagg_cancel/harness.py | 493 ++++++++++++++++-- .../test_disagg_cancel_stress.py | 143 +++-- .../test_lists/qa/llm_function_stress.txt | 1 + 7 files changed, 757 insertions(+), 105 deletions(-) diff --git a/tests/integration/defs/stress_test/disagg_cancel/README.md b/tests/integration/defs/stress_test/disagg_cancel/README.md index 8cc9650a8bf3..26ef1b83d780 100644 --- a/tests/integration/defs/stress_test/disagg_cancel/README.md +++ b/tests/integration/defs/stress_test/disagg_cancel/README.md @@ -1,6 +1,6 @@ # Disaggregated Cancellation Stress-Test Suite -Marathon-style stress tests that gate regressions of the bug class +Disaggregated stress tests that gate regressions of the bug class fixed by (cleanup / lifetime / quiescence invariants in the disagg KV transceiver under heavy mid-flight cancellation). @@ -13,8 +13,22 @@ transceiver under heavy mid-flight cancellation). ## Status -The harness class structure and lifecycle are in place. Thread bodies -land incrementally: +The registered QA stress entry now launches a real C++/V1 DeepSeek +disaggregated cluster in `log_only` mode. That mode sends normal +non-cancel completion probes through the front-end and scans saved +worker/server logs for UAF, broken-promise, and segmentation-fault +signatures. It is intentionally narrow so it can run regularly before +in-flight cancellation and poison-buffer hardening are available. + +The full cancellation/poison marathon is implemented as an explicit +mode switch, but it is not the registered default yet. + +| Mode | CI status | Threads | Coverage | +|------|-----------|---------|----------| +| `log_only` | Registered in `qa/llm_function_stress.txt` | log-only probe + log scanner | startup/data-path crash guard: UAF, broken promise, segfault-class logs | +| `full_cancel_poison` | Opt-in only | load, canary, injector, log scanner, metrics | cancellation load, failure injection, poison canaries, KV-growth guard | + +Thread bodies: | Thread | Status | |--------|--------| @@ -26,8 +40,8 @@ land incrementally: Component-level coverage: `test_log_scanner.py`, `test_metrics_thread.py`, `test_injector.py`, `test_canary.py`, `test_load_thread.py`. The -parametrized marathon pytest still runs a lifecycle smoke until -`setup()` launches a real cluster. +parametrized C++/V1 DeepSeek run is registered in the QA stress test +list as a real `log_only` guardrail. ## File layout @@ -52,17 +66,76 @@ Future additions: - `tools/generate_canary_references.py` — one-shot reference generator that records greedy-decode token IDs for the canary prompts. - `configs/stress_canary_prompts.json` — canary prompts + recorded - reference token IDs (consumed by the canary thread). + reference token IDs for `full_cancel_poison`. - Per-scenario YAMLs covering additional axes: 1P1D, 4P2D, V1+Python, UCX, block-reuse-off, overlap-off, aggressive-timeout, multi-node (all Python-only test-side configuration). +## Mode Switch + +The active mode is controlled by +`configs/marathon_cpp_v1_deepseek.yaml`: + +```yaml +stress_config: + mode: log_only + duration_min: 10 +``` + +Use `log_only` for regular CI until both runtime features are in +place: + +- in-flight request cancellation support for the disaggregated path. +- poison-buffer hardening that makes poisoned cache transfers + expected and recoverable. + +To switch to the full cancellation/poison marathon after those +features are ready: + +1. Set `stress_config.mode: full_cancel_poison`. +2. Set `stress_config.duration_min: 120` for the two-hour marathon. +3. Keep or tune `base_concurrency`, `bursts`, and `injections`. +4. Add `configs/stress_canary_prompts.json` with token references and + keep `canary.check_token_equivalent: true`. +5. Add the poison-buffer hard-zero/expected-recovery patterns that + match the finalized runtime behavior. +6. Raise the test-list timeout back to a full-marathon budget, e.g. + `TIMEOUT (150)`. + ## How to run -The marathons are **not** registered in pre-merge CI. They are run -nightly / weekly via -`tests/integration/test_lists/qa/llm_function_stress.txt` (wiring -lands with the explicit CI-registration change). +### Scheduled QA stress run + +The C++/V1 DeepSeek marathon is registered in +`tests/integration/test_lists/qa/llm_function_stress.txt`, which makes +it eligible for the QA/Jenkins job that consumes that stress list. This +PR does not create or modify the scheduler for that job; the exact +cadence and wall-clock start time are owned by QA CI configuration +outside this directory. The in-repo QA README describes QA lists as +regular daily/release and weekly/release/on-demand coverage, but does +not define a file-specific cadence for `llm_function_stress.txt`. + +The registered entry is: + +```text +stress_test/disagg_cancel/test_disagg_cancel_stress.py::test_disagg_cancellation_marathon[marathon_cpp_v1_deepseek.yaml] TIMEOUT (45) +``` + +The integration test-list parser interprets `TIMEOUT (45)` in +minutes. CI should run the list from `tests/integration/defs` with: + +```bash +pytest --test-list=../test_lists/qa/llm_function_stress.txt \ + --output-dir= \ + -s -v +``` + +The scheduled runner must use the normal TRT-LLM integration container +or virtual environment with GPU access, `trtllm-serve` on `PATH`, and +`LLM_MODELS_ROOT` set so `DeepSeek-V3-Lite/bf16` resolves to local +model weights. The current registered run is `log_only`: setup can +take up to 20 minutes, then the harness probes for 10 minutes and +tails worker/server logs. ### Unit tests (no GPU, no cluster) @@ -105,12 +178,13 @@ python3 -m pytest -c /dev/null -o addopts= \ tests/integration/defs/stress_test/disagg_cancel/test_disagg_cancel_stress.py::test_all_marathon_yamls_parse_and_validate -v ``` -All component tests together: +All component tests together, excluding the real cluster entry: ```bash python3 -m pytest -c /dev/null -o addopts= \ --confcutdir=tests/integration/defs/stress_test \ - tests/integration/defs/stress_test/disagg_cancel/ -q + tests/integration/defs/stress_test/disagg_cancel/ \ + -k "not test_disagg_cancellation_marathon" -q ``` In a full TRT-LLM dev container/venv (with `transformers` installed), @@ -120,27 +194,63 @@ the same tests also run under the normal integration pytest path: pytest -sv tests/integration/defs/stress_test/disagg_cancel/test_injector.py ``` -### Lifecycle smoke (injector not exercised on real workers) +### Manual regular guardrail run + +From a full TRT-LLM integration environment: + +```bash +cd /path/to/TensorRT-LLM/tests/integration/defs +export LLM_MODELS_ROOT=/path/to/model/root + +pytest stress_test/disagg_cancel/test_disagg_cancel_stress.py \ + --test-list=../test_lists/qa/llm_function_stress.txt \ + --output-dir=/tmp/trtllm-disagg-cancel-stress \ + -s -v +``` + +To collect without running: ```bash -pytest -sv tests/integration/defs/stress_test/disagg_cancel/test_disagg_cancel_stress.py::test_disagg_cancellation_marathon +pytest stress_test/disagg_cancel/test_disagg_cancel_stress.py \ + --test-list=../test_lists/qa/llm_function_stress.txt \ + --output-dir=/tmp/trtllm-disagg-cancel-stress \ + -s --co -q +``` + +### Manual CI trigger + +On a GitHub pull request, ask the CI bot which stress stages are +available, then trigger the QA stress stage that consumes +`tests/integration/test_lists/qa/llm_function_stress.txt`: + +```text +/bot help +/bot run --extra-stage "" ``` -`setup()` is still a stub, so this only checks harness lifecycle -(`setup` → `start` → `wait` → `stop`). The injector thread exits -immediately because no workers are registered via -`bind_tracked_workers()`. +The bot stage name is owned by CI/Jenkins configuration and is not +declared in this directory. -### Local marathon (after `setup()` lands) +### Manual full cancellation/poison run -Once `setup()` launches a real 3P3D cluster and registers workers, -the full 2-hour marathon runs via the same pytest entry point. For -development, set `duration_min: 10` and trim `injections:` in the -YAML. +After the runtime support is in place, switch the YAML to +`mode: full_cancel_poison`, set the intended duration and canary +references, then run the same pytest entry point. For development, +use a shorter `duration_min` and trim `injections:` locally before +restoring the checked-in values. ## Pass criteria -A marathon run is "clean" iff all of the following hold: +`log_only` is clean iff all of the following hold: + +- The 3P3D disaggregated cluster starts and reaches readiness. +- At least one normal completion probe succeeds through the + disaggregated front-end. +- No hard-zero log patterns for UAF, broken promise, or + segmentation-fault-class failures appear in any saved worker or + disagg-server log. + +`full_cancel_poison` is clean iff all of the following hold: - No hard-zero log patterns (e.g. `Cannot cancel request`, `Broken promise`, `unquiesced`, double-free / UAF traces) appear in any @@ -157,23 +267,22 @@ A marathon run is "clean" iff all of the following hold: - KV-cache utilization growth ≤ 10 percentage points end-to-end (leak guard). -Concrete thresholds for each metric are declared in the marathon -YAML's `pass_criteria:` block. +Concrete thresholds for each metric are declared in the marathon YAML. ## How to debug a failure -(Stub — the full debug guide lands together with the thread -implementations.) - -For now, when the skeleton test fails: +When the regular guardrail fails: 1. Confirm the YAML parses: ```bash python -c "from harness import StressConfig; StressConfig.from_yaml_path('configs/marathon_cpp_v1_deepseek.yaml')" ``` 2. Check the `failure_reason` field in `collect_results()` output. -3. Look at the pytest stdout for harness `logger` lines (each thread - logs its identity on entry / exit). +3. Inspect the log tails printed by `disagg_test_utils.terminate()` + during teardown; saved worker logs and `disagg_server.log` are + tailed before cleanup. +4. If setup times out, confirm `LLM_MODELS_ROOT`, GPU count, and + `trtllm-serve` availability in the integration environment. ## Cross-references diff --git a/tests/integration/defs/stress_test/disagg_cancel/configs/README.md b/tests/integration/defs/stress_test/disagg_cancel/configs/README.md index a95545b7284a..b29f02390005 100644 --- a/tests/integration/defs/stress_test/disagg_cancel/configs/README.md +++ b/tests/integration/defs/stress_test/disagg_cancel/configs/README.md @@ -25,6 +25,25 @@ The new `stress_config:` top-level block is consumed by `StressConfig` itself (dataclass field docstrings) and the example values in `marathon_cpp_v1_deepseek.yaml`. +## Harness modes + +`stress_config.mode` is the switch between the regular guardrail and +the full cancellation/poison marathon: + +- `log_only`: registered CI mode. The harness launches the real + disaggregated cluster, sends normal non-cancel probes, and scans + worker/server logs for UAF, broken-promise, and segmentation-fault + signatures. This mode is safe before in-flight cancellation and + poison-buffer hardening are available. +- `full_cancel_poison`: opt-in mode for the completed runtime. The + harness enables the cancellation load, SIGSTOP/SIGKILL injections, + token-equivalent canaries, metrics scraping, and KV-growth checks. + +When switching from `log_only` to `full_cancel_poison`, update +`duration_min`, canary references, poison-buffer log expectations, and +the test-list timeout together. The top-level README has the exact +checklist. + ## Backend-knob axis: KV-cache manager × transceiver runtime Two knobs select which (KV cache manager × transceiver runtime) @@ -51,7 +70,8 @@ Python changes** required beyond extending the parametrize list. To add a new config: 1. Copy `marathon_cpp_v1_deepseek.yaml` as a template. -2. Adjust `model`, `kv_cache_manager`, `transceiver`, and any +2. Choose `stress_config.mode`, then adjust `model`, + `kv_cache_manager`, `transceiver`, and any load-shape knobs (`base_concurrency`, `client_cancel_rate`, `output_length`, `injections:`, `pass_criteria:`). 3. Add the new filename to `_MARATHON_CONFIGS` in diff --git a/tests/integration/defs/stress_test/disagg_cancel/configs/marathon_cpp_v1_deepseek.yaml b/tests/integration/defs/stress_test/disagg_cancel/configs/marathon_cpp_v1_deepseek.yaml index 7764035ce05f..e1fb72872d09 100644 --- a/tests/integration/defs/stress_test/disagg_cancel/configs/marathon_cpp_v1_deepseek.yaml +++ b/tests/integration/defs/stress_test/disagg_cancel/configs/marathon_cpp_v1_deepseek.yaml @@ -52,13 +52,29 @@ generation_servers: # Schema documented in ../README.md. # ============================================================================ stress_config: - duration_min: 120 # 2 h marathon + # MODE SWITCH (intentionally loud): + # - log_only: current registered CI mode. Launches the real disagg cluster, + # sends normal non-cancel probes, and scans saved worker/server logs for + # UAF, broken-promise, and segmentation-fault signatures. This mode does + # NOT require in-flight cancellation or poison-buffer support. + # - full_cancel_poison: future opt-in mode after in-flight cancellation and + # poison-buffer hardening are available. It enables the cancellation load, + # fault injections, token-equivalent canaries, and KV-growth checks below. + mode: log_only + duration_min: 10 # regular guardrail run; use 120 for full_cancel_poison # Backend-knob axis selectors. Must match # context_servers / generation_servers above. kv_cache_manager: v1 # V1 (C++) KV cache manager transceiver: cpp # C++-backed transceiver (BindKvCacheTransceiver) + log_only_probe: + interval_s: 30 + prompt: "Write one sentence about reliable distributed inference." + max_tokens: 32 + seed: 42 + request_timeout_s: 30 + base_concurrency: 64 client_cancel_rate: 0.10 input_length: @@ -120,10 +136,13 @@ stress_config: log_scan: hard_zero_patterns: - "Broken promise" - - "NO RECOVERY" - "Segfault" + - "Segmentation fault" - "SIGSEGV" - "0xffffffffffffffff" - - "Poisoned .* cache transfer buffer" + - "use-after-free" + - "heap-use-after-free" + - "AddressSanitizer:.*use-after-free" + - "double[- ]free" kv_cache_growth_max: 0.10 # final utilization ≤ baseline + 10 percentage points diff --git a/tests/integration/defs/stress_test/disagg_cancel/configs/marathon_python_v2_qwen.yaml b/tests/integration/defs/stress_test/disagg_cancel/configs/marathon_python_v2_qwen.yaml index d74108335002..79b7a4b73b38 100644 --- a/tests/integration/defs/stress_test/disagg_cancel/configs/marathon_python_v2_qwen.yaml +++ b/tests/integration/defs/stress_test/disagg_cancel/configs/marathon_python_v2_qwen.yaml @@ -54,6 +54,11 @@ generation_servers: kv_transfer_timeout_ms: 60000 stress_config: + # Placeholder template for the future full marathon. This YAML is + # intentionally not parametrized until the runtime supports the + # cancellation + poison-buffer contract and canary references are + # recorded for Qwen2.5-7B-Instruct. + mode: full_cancel_poison duration_min: 120 kv_cache_manager: v2 diff --git a/tests/integration/defs/stress_test/disagg_cancel/harness.py b/tests/integration/defs/stress_test/disagg_cancel/harness.py index d85b9d88b936..f2eb18cb5f7a 100644 --- a/tests/integration/defs/stress_test/disagg_cancel/harness.py +++ b/tests/integration/defs/stress_test/disagg_cancel/harness.py @@ -33,8 +33,10 @@ import os import random import re +import shutil import signal import subprocess +import tempfile import threading import time import urllib.error @@ -47,6 +49,10 @@ logger = logging.getLogger(__name__) +_STRESS_MODE_LOG_ONLY = "log_only" +_STRESS_MODE_FULL_CANCEL_POISON = "full_cancel_poison" +_STRESS_MODES = (_STRESS_MODE_LOG_ONLY, _STRESS_MODE_FULL_CANCEL_POISON) + # --------------------------------------------------------------------------- # Config dataclasses @@ -58,6 +64,7 @@ # simply aren't passed to the constructor, so the field defaults # apply automatically and are not duplicated here. _STRESS_CONFIG_COERCERS: dict[str, Callable[[Any], Any]] = { + "mode": str, "duration_min": float, "kv_cache_manager": str, "transceiver": str, @@ -76,6 +83,7 @@ class StressConfig: pass them around without re-parsing. """ + mode: str = _STRESS_MODE_LOG_ONLY duration_min: float = 120.0 kv_cache_manager: str = "v1" # v1 | v2 (v2 + CPP is invalid) transceiver: str = "cpp" # cpp | python @@ -137,6 +145,8 @@ def validate(self) -> None: supplied (the C++ transceiver only supports the V1 KV cache manager). """ + if self.mode not in _STRESS_MODES: + raise ValueError(f"mode must be one of {_STRESS_MODES}, got {self.mode!r}") if self.kv_cache_manager == "v2" and self.transceiver == "cpp": # The C++ transceiver (BindKvCacheTransceiver) only supports # the V1 KV cache manager. V2 must be paired with the Python @@ -152,6 +162,16 @@ def validate(self) -> None: if self.transceiver not in ("cpp", "python"): raise ValueError(f"transceiver must be 'cpp' or 'python', got {self.transceiver!r}") + @property + def is_log_only(self) -> bool: + """True when the harness should run the regular CI guardrail mode.""" + return self.mode == _STRESS_MODE_LOG_ONLY + + @property + def is_full_cancel_poison(self) -> bool: + """True when the harness should run the full cancellation/poison marathon.""" + return self.mode == _STRESS_MODE_FULL_CANCEL_POISON + _INJECTION_TARGET_RE = re.compile(r"^(ctx|gen)_worker_(\d+)$") @@ -778,7 +798,8 @@ def _load_iteration_shape(config: StressConfig, elapsed_s: float) -> dict[str, A if interval_s <= 0.0: raise ValueError( - f"stress_config.bursts.interval_min must be positive, got {bursts.get('interval_min')!r}" + "stress_config.bursts.interval_min must be positive, got " + f"{bursts.get('interval_min')!r}" ) if duration_s <= 0.0: raise ValueError( @@ -916,6 +937,7 @@ def __init__( self._cluster: Any = None # tuple returned by setup_disagg_cluster self._worker_specs: list[WorkerLaunchSpec] = [] self._tracked_workers: list[_TrackedWorker] = [] + self._server_log_path: Optional[str] = None self._marathon_start_monotonic: float = 0.0 # Disagg-server front-end the canary targets; populated by @@ -926,6 +948,7 @@ def __init__( # Thread handles (populated by start()). self._load_thread: Optional[threading.Thread] = None + self._log_only_thread: Optional[threading.Thread] = None self._canary_thread: Optional[threading.Thread] = None self._injector_thread: Optional[threading.Thread] = None self._log_scanner_thread: Optional[threading.Thread] = None @@ -944,13 +967,256 @@ def __init__( def setup(self) -> None: """Launch the disagg cluster from the YAML and record launch specs. - Stub: real implementation delegates to ``setup_disagg_cluster`` - in ``tests/integration/defs/disaggregated/test_disaggregated.py`` + Delegates the process launch to ``setup_disagg_cluster`` in + ``tests/integration/defs/disaggregated/test_disaggregated.py`` and shadow-tracks per-worker ``WorkerLaunchSpec`` so the injector thread can later relaunch a SIGKILLed worker without - modifying shared infrastructure. + modifying shared infrastructure. The harness-only + ``stress_config`` block is stripped from the temporary YAML + passed to the shared launcher so worker config validation only + sees normal ``trtllm-serve`` settings. """ - logger.info("[harness] setup() — stub: cluster not actually launched") + from test_disaggregated import ( + build_worker_config, + get_default_disagg_cluster_config, + get_ucx_tls, + setup_disagg_cluster, + ) + + cluster_config = self._load_sanitized_cluster_config() + raw_model_name = str(cluster_config.get("model") or "") + if not raw_model_name: + raise ValueError(f"YAML at {self.yaml_path} is missing top-level model") + model_name = self._resolve_model_name(raw_model_name) + + server_start_timeout_s = int(self.config.raw.get("server_start_timeout_s", 1200)) + run_env = os.environ.copy() + run_env["UCX_TLS"] = get_ucx_tls() + + setup_yaml_path = self._write_sanitized_cluster_yaml(cluster_config) + try: + self._cluster = setup_disagg_cluster( + setup_yaml_path, + model_name=model_name, + env=run_env, + server_start_timeout=server_start_timeout_s, + save_log=True, + ) + finally: + try: + os.unlink(setup_yaml_path) + except OSError: + logger.debug("[harness] could not unlink %s; ignoring", setup_yaml_path) + + config, ctx_workers, gen_workers, disagg_server, server_port, work_dir = self._cluster + server_host = config.get("hostname", "localhost") + server_url = f"http://{server_host}:{server_port}" + + disagg_cluster = get_default_disagg_cluster_config() + disagg_cluster["cluster_uri"] = server_url + ctx_servers = config.get("context_servers", {}) + gen_servers = config.get("generation_servers", {}) + disagg_cluster["minimal_instances"] = { + "context_servers": ctx_servers.get("num_instances", 1), + "generation_servers": gen_servers.get("num_instances", 1), + } + ctx_worker_config = build_worker_config(config, ctx_servers, disagg_cluster) + gen_worker_config = build_worker_config(config, gen_servers, disagg_cluster) + ctx_specs, gen_specs = self._build_worker_launch_specs( + ctx_workers=ctx_workers, + gen_workers=gen_workers, + ctx_worker_config=ctx_worker_config, + gen_worker_config=gen_worker_config, + ctx_servers=ctx_servers, + gen_servers=gen_servers, + model_name=model_name, + work_dir=work_dir, + env=run_env, + host=server_host, + ) + self._refresh_worker_ports_from_cluster_info(server_url, ctx_specs, gen_specs) + self.bind_tracked_workers(ctx_workers, gen_workers, ctx_specs, gen_specs) + self.bind_server_endpoint(server_url, model_name) + self._server_log_path = getattr(disagg_server, "log_path", None) + logger.info( + "[harness] setup() launched %d ctx worker(s), %d gen worker(s), server=%s", + len(ctx_workers), + len(gen_workers), + server_url, + ) + + def _load_sanitized_cluster_config(self) -> dict[str, Any]: + """Load YAML and remove harness-only fields before cluster launch.""" + with self.yaml_path.open("r", encoding="utf-8") as f: + doc = yaml.safe_load(f) + if not isinstance(doc, dict): + raise ValueError(f"YAML at {self.yaml_path} must be a mapping") + cluster_config = dict(doc) + cluster_config.pop("stress_config", None) + return cluster_config + + def _write_sanitized_cluster_yaml(self, cluster_config: dict[str, Any]) -> str: + """Write the launcher-facing YAML to a temporary file.""" + fd, path = tempfile.mkstemp(prefix="disagg_cancel_cluster_", suffix=".yaml") + with os.fdopen(fd, "w", encoding="utf-8") as f: + yaml.safe_dump(cluster_config, f) + return path + + def _resolve_model_name(self, model_name: str) -> str: + """Resolve relative model names against ``LLM_MODELS_ROOT`` when set.""" + path = Path(model_name).expanduser() + if path.is_absolute() or path.exists(): + return str(path) + models_root = os.environ.get("LLM_MODELS_ROOT") + if models_root: + return str(Path(models_root).expanduser() / model_name) + return model_name + + def _build_worker_launch_specs( + self, + *, + ctx_workers: list[Any], + gen_workers: list[Any], + ctx_worker_config: dict[str, Any], + gen_worker_config: dict[str, Any], + ctx_servers: dict[str, Any], + gen_servers: dict[str, Any], + model_name: str, + work_dir: str, + env: dict[str, str], + host: str, + ) -> tuple[list[WorkerLaunchSpec], list[WorkerLaunchSpec]]: + """Reconstruct worker launch metadata for log scanning and respawn.""" + import torch + + num_gpus = torch.cuda.device_count() + if num_gpus <= 0: + raise RuntimeError("setup_disagg_cluster returned, but torch reports no CUDA devices") + + gpus_per_ctx = ( + int(ctx_servers.get("tensor_parallel_size", 1)) + * int(ctx_servers.get("pipeline_parallel_size", 1)) + * int(ctx_servers.get("context_parallel_size", 1)) + ) + gpus_per_gen = ( + int(gen_servers.get("tensor_parallel_size", 1)) + * int(gen_servers.get("pipeline_parallel_size", 1)) + * int(gen_servers.get("context_parallel_size", 1)) + ) + + ctx_specs: list[WorkerLaunchSpec] = [] + gen_specs: list[WorkerLaunchSpec] = [] + next_device = 0 + for index, wrapper in enumerate(ctx_workers): + device = self._format_device_ids(next_device, gpus_per_ctx, num_gpus) + next_device += gpus_per_ctx + ctx_specs.append( + self._make_worker_launch_spec( + role="ctx", + index=index, + wrapper=wrapper, + worker_config=ctx_worker_config, + model_name=model_name, + work_dir=work_dir, + device=device, + env=env, + host=host, + ) + ) + for index, wrapper in enumerate(gen_workers): + device = self._format_device_ids(next_device, gpus_per_gen, num_gpus) + next_device += gpus_per_gen + gen_specs.append( + self._make_worker_launch_spec( + role="gen", + index=index, + wrapper=wrapper, + worker_config=gen_worker_config, + model_name=model_name, + work_dir=work_dir, + device=device, + env=env, + host=host, + ) + ) + return ctx_specs, gen_specs + + def _format_device_ids(self, first_device: int, count: int, num_gpus: int) -> str: + """Return the CUDA_VISIBLE_DEVICES string used by setup_disagg_cluster.""" + return ",".join( + str(d) for d in dict.fromkeys((first_device + j) % num_gpus for j in range(count)) + ) + + def _make_worker_launch_spec( + self, + *, + role: str, + index: int, + wrapper: Any, + worker_config: dict[str, Any], + model_name: str, + work_dir: str, + device: str, + env: dict[str, str], + host: str, + ) -> WorkerLaunchSpec: + """Create one shadow launch spec from the shared ProcessWrapper.""" + return WorkerLaunchSpec( + role=role, + index=index, + model_name=model_name, + worker_config=worker_config, + work_dir=work_dir, + port=int(getattr(wrapper, "port", 0) or 0), + device=device, + env=env.copy(), + log_path=getattr(wrapper, "log_path", None), + host=host, + ) + + def _refresh_worker_ports_from_cluster_info( + self, + server_url: str, + ctx_specs: list[WorkerLaunchSpec], + gen_specs: list[WorkerLaunchSpec], + ) -> None: + """Populate worker host/port from disagg ``/cluster_info`` when available.""" + try: + with urllib.request.urlopen(f"{server_url}/cluster_info", timeout=5.0) as response: + info = json.loads(response.read().decode("utf-8", errors="replace")) + except (json.JSONDecodeError, TimeoutError, OSError, urllib.error.URLError) as exc: + logger.warning("[harness] could not read cluster_info for worker ports: %s", exc) + return + + current_workers = info.get("current_workers") or {} + for specs, key in ( + (ctx_specs, "context_servers"), + (gen_specs, "generation_servers"), + ): + workers = current_workers.get(key) or [] + if len(workers) != len(specs): + logger.warning( + "[harness] cluster_info %s count mismatch: %d worker(s), %d spec(s)", + key, + len(workers), + len(specs), + ) + for spec, worker_info in zip(specs, workers): + if not isinstance(worker_info, dict): + continue + host = worker_info.get("host") + port = worker_info.get("port") + if isinstance(host, str) and host: + spec.host = host + try: + spec.port = int(port) + except (TypeError, ValueError): + logger.warning( + "[harness] cluster_info %s worker %d has invalid port %r", + key, + spec.index, + port, + ) def bind_tracked_workers( self, @@ -988,30 +1254,47 @@ def bind_server_endpoint(self, server_url: str, model_name: str) -> None: self._model_name = model_name def start(self) -> None: - """Spawn the five worker threads. Returns immediately. + """Spawn the mode-specific worker threads. Returns immediately. + + ``log_only`` mode runs the regular CI guardrail: a normal + non-cancel probe loop plus log-pattern fail-fast. It + intentionally avoids the cancellation load, fault injector, + poison canary, and KV-growth gates until those runtime fixes + are present. - If ``setup()`` has not bound a live server endpoint yet, the - load thread warns and signals ``stop_event`` so the lifecycle - smoke still completes cleanly without waiting out the - ``wait_until_done`` timeout. + ``full_cancel_poison`` mode runs all five full-stress threads. """ self._marathon_start_monotonic = time.monotonic() - logger.info("[harness] start() — spawning worker threads") - self._load_thread = threading.Thread( - target=self._load_thread_body, name="stress-load", daemon=True - ) - self._canary_thread = threading.Thread( - target=self._canary_thread_body, name="stress-canary", daemon=True - ) - self._injector_thread = threading.Thread( - target=self._injector_thread_body, name="stress-injector", daemon=True - ) - self._log_scanner_thread = threading.Thread( - target=self._log_scanner_thread_body, name="stress-log-scanner", daemon=True - ) - self._metrics_thread = threading.Thread( - target=self._metrics_thread_body, name="stress-metrics", daemon=True - ) + logger.info("[harness] start() — mode=%s", self.config.mode) + if self.config.is_log_only: + self._log_only_thread = threading.Thread( + target=self._log_only_thread_body, + name="stress-log-only-probe", + daemon=True, + ) + self._log_scanner_thread = threading.Thread( + target=self._log_scanner_thread_body, + name="stress-log-scanner", + daemon=True, + ) + elif self.config.is_full_cancel_poison: + self._load_thread = threading.Thread( + target=self._load_thread_body, name="stress-load", daemon=True + ) + self._canary_thread = threading.Thread( + target=self._canary_thread_body, name="stress-canary", daemon=True + ) + self._injector_thread = threading.Thread( + target=self._injector_thread_body, name="stress-injector", daemon=True + ) + self._log_scanner_thread = threading.Thread( + target=self._log_scanner_thread_body, + name="stress-log-scanner", + daemon=True, + ) + self._metrics_thread = threading.Thread( + target=self._metrics_thread_body, name="stress-metrics", daemon=True + ) for t in self._all_threads(): t.start() @@ -1143,9 +1426,126 @@ def collect_results(self) -> dict[str, Any]: } # ------------------------------------------------------------------ - # Thread bodies (stubs — implemented incrementally) + # Thread bodies # ------------------------------------------------------------------ + def _configured_duration_s(self) -> float: + """Return the active run duration, honoring unit-test overrides.""" + if self._load_duration_s is not None: + return self._load_duration_s + return float(self.config.duration_min) * 60.0 + + def _log_only_thread_body(self) -> None: + """Run regular CI protection without cancellation or poison gates. + + This mode still launches the real disaggregated cluster and + sends normal completion probes through the front-end. It fails + the test on probe errors and runs concurrently with + ``log_scanner_thread`` so UAF, broken-promise, and segfault + signatures in worker/server logs remain hard-zero failures. + """ + if not self._server_url: + self.mark_failed("log_only mode requires setup() to bind a server endpoint") + self.stop_event.set() + return + + duration_s = self._configured_duration_s() + if duration_s <= 0.0: + logger.info("[log_only] non-positive duration %.3fs; exiting", duration_s) + self.stop_event.set() + return + + probe_cfg = self.config.raw.get("log_only_probe") or {} + try: + interval_s = float(probe_cfg.get("interval_s", 30.0)) + max_tokens = int(probe_cfg.get("max_tokens", 32)) + seed = int(probe_cfg.get("seed", 42)) + timeout_s = float(probe_cfg.get("request_timeout_s", self._canary_request_timeout_s)) + prompt = str( + probe_cfg.get( + "prompt", + "Write one sentence about reliable distributed inference.", + ) + ) + except (TypeError, ValueError) as exc: + self.mark_failed(f"log_only_probe config error: {exc}") + self.stop_event.set() + return + if interval_s <= 0.0: + self.mark_failed(f"log_only_probe.interval_s must be positive, got {interval_s}") + self.stop_event.set() + return + + deadline = time.monotonic() + duration_s + logger.info( + "[log_only] probing %s every %.1fs for %.1fs", + self._server_url, + interval_s, + duration_s, + ) + + while ( + time.monotonic() < deadline + and not self.stop_event.is_set() + and not self.failed_event.is_set() + ): + send_start = time.monotonic() + token_ids, _, err = self._send_log_only_probe( + prompt=prompt, + max_tokens=max_tokens, + seed=seed, + timeout_s=timeout_s, + ) + success = err is None + self._canary_records.append( + { + "timestamp": time.time(), + "elapsed_s": time.monotonic() - self._marathon_start_monotonic, + "mode": _STRESS_MODE_LOG_ONLY, + "prompt_index": 0, + "success": success, + "token_equivalent": None, + "latency_s": time.monotonic() - send_start, + "error": err, + "token_count": len(token_ids or []), + } + ) + if not success: + self.mark_failed(f"log_only probe failed: {err}") + break + + remaining = min(interval_s, max(0.0, deadline - time.monotonic())) + if remaining > 0.0: + self.stop_event.wait(timeout=remaining) + + if not self.failed_event.is_set() and not any( + record.get("success") for record in self._canary_records + ): + self.mark_failed("log_only mode completed without a successful probe") + if not self.failed_event.is_set(): + logger.info("[log_only] completed; signalling stop_event") + self.stop_event.set() + + def _send_log_only_probe( + self, + *, + prompt: str, + max_tokens: int, + seed: int, + timeout_s: float, + ) -> tuple[Optional[list[int]], Optional[str], Optional[str]]: + """Send one normal completion request for ``log_only`` mode.""" + if self._server_url is None: + return None, None, "missing_server_url" + return _send_canary_request( + self._server_url, + self._model_name or "log-only-probe", + prompt, + max_tokens, + seed, + timeout_s, + ) + def _load_thread_body(self) -> None: """Wrap ``run_cancel_stress_test`` in a duration-bounded loop. @@ -1164,11 +1564,7 @@ def _load_thread_body(self) -> None: self.stop_event.set() return - duration_s = ( - self._load_duration_s - if self._load_duration_s is not None - else float(self.config.duration_min) * 60.0 - ) + duration_s = self._configured_duration_s() if duration_s <= 0.0: logger.info("[load_thread] non-positive duration %.3fs; exiting", duration_s) self.stop_event.set() @@ -1570,6 +1966,19 @@ def _log_scanner_thread_body(self) -> None: return sources: list[_LogSource] = [] + if self._server_log_path is not None: + server_spec = WorkerLaunchSpec( + role="server", + index=0, + model_name=self._model_name or "disagg-server", + worker_config={}, + work_dir="", + port=0, + device="", + env={}, + log_path=self._server_log_path, + ) + sources.append(_LogSource(spec=server_spec, path=Path(self._server_log_path))) for spec in self._worker_specs: if spec.log_path is None: logger.warning( @@ -1589,7 +1998,7 @@ def _log_scanner_thread_body(self) -> None: return logger.info( - "[log_scanner] tailing %d worker log(s) against %d hard_zero pattern(s)", + "[log_scanner] tailing %d log source(s) against %d hard_zero pattern(s)", len(sources), len(patterns), ) @@ -1682,6 +2091,7 @@ def _all_threads(self) -> list[threading.Thread]: t for t in ( self._load_thread, + self._log_only_thread, self._canary_thread, self._injector_thread, self._log_scanner_thread, @@ -1691,10 +2101,17 @@ def _all_threads(self) -> list[threading.Thread]: ] def _teardown_cluster(self) -> None: - """Best-effort cluster shutdown via ``terminate()``. - - Stub: no-op since ``setup()`` doesn't actually launch yet. - """ + """Best-effort cluster shutdown via ``terminate()``.""" if self._cluster is None: return - logger.info("[harness] _teardown_cluster — stub") + from disagg_test_utils import terminate + + config, ctx_workers, gen_workers, disagg_server, _server_port, work_dir = self._cluster + del config + logger.info("[harness] tearing down disagg cluster work_dir=%s", work_dir) + try: + terminate(*ctx_workers, *gen_workers, disagg_server) + finally: + shutil.rmtree(work_dir, ignore_errors=True) + self._cluster = None + self._server_log_path = None diff --git a/tests/integration/defs/stress_test/disagg_cancel/test_disagg_cancel_stress.py b/tests/integration/defs/stress_test/disagg_cancel/test_disagg_cancel_stress.py index 36d5b607d3bb..aa5c74b51779 100644 --- a/tests/integration/defs/stress_test/disagg_cancel/test_disagg_cancel_stress.py +++ b/tests/integration/defs/stress_test/disagg_cancel/test_disagg_cancel_stress.py @@ -25,7 +25,11 @@ from __future__ import annotations +import textwrap +import threading +import time from pathlib import Path +from typing import Any import pytest @@ -57,29 +61,7 @@ def test_all_marathon_yamls_parse_and_validate() -> None: @pytest.mark.parametrize("config_filename", _MARATHON_CONFIGS) def test_disagg_cancellation_marathon(config_filename: str) -> None: - """Drive a long-running disagg cancellation marathon and assert pass criteria. - - Current scope: only what the already-implemented thread bodies - can contribute. The marathon entry point exists; the marathon - *content* lands incrementally as setup / pass-criteria wiring is - completed: - - - lifecycle plumbing (setup -> start -> wait -> stop -> - collect_results, fail-fast event propagation, dict-shape - contract). - - log-pattern fail-fast — a hard-zero pattern in any worker log - trips ``failure_reason`` via the log_scanner thread - (component-level coverage in ``test_log_scanner.py``). - - Marathon pass criteria not yet enforced here (will land alongside - their owning result aggregation in follow-up changes): canary error - rate, recovery time after each injection, KV-cache utilization - growth bound, injection-schedule completeness, sustained load - throughput. Until those land, this test passes trivially after - the lifecycle smoke completes; the value at this stage is that - the entry point and result-dict contract are pinned down so the - follow-up commits can extend in place rather than restructure. - """ + """Drive the configured disagg stress mode and assert current pass criteria.""" config_path = _CONFIG_DIR / config_filename assert config_path.exists(), ( f"Marathon config not found: {config_path}. " @@ -90,13 +72,8 @@ def test_disagg_cancellation_marathon(config_filename: str) -> None: try: harness.setup() harness.start() - # setup() is still a stub, so no server endpoint is bound. - # The load thread exits and signals ``stop_event`` on that - # no-endpoint path, which lets this lifecycle smoke complete - # almost instantly. Once setup launches a real cluster, the - # timeout becomes ``stress_config.duration_min`` plus a safety - # margin. - clean = harness.wait_until_done(timeout_s=10.0) + timeout_s = float(harness.config.duration_min) * 60.0 + 300.0 + clean = harness.wait_until_done(timeout_s=timeout_s) assert clean is True, ( f"wait_until_done did not return cleanly; failure_reason={harness.failure_reason!r}" ) @@ -112,5 +89,109 @@ def test_disagg_cancellation_marathon(config_filename: str) -> None: assert "kv_utilization_samples" in results assert "injection_events" in results assert results["failure_reason"] is None, ( - f"Harness tripped fail-fast in skeleton run: {results['failure_reason']!r}" + f"Harness tripped fail-fast: {results['failure_reason']!r}" + ) + if harness.config.is_log_only: + assert any( + record.get("mode") == "log_only" and record.get("success") + for record in results["canary_records"] + ), "log_only mode completed without a successful server probe" + + +def _write_mode_yaml(tmp_path: Path, stress_config: str) -> Path: + """Write a minimal marathon YAML for mode-level harness tests.""" + yaml_path = tmp_path / "mode.yaml" + content = textwrap.dedent( + """\ + hostname: localhost + model: dummy + backend: pytorch + context_servers: {} + generation_servers: {} + stress_config: + """ + ) + content += textwrap.indent(textwrap.dedent(stress_config).strip(), " ") + "\n" + yaml_path.write_text(content) + return yaml_path + + +@pytest.mark.parametrize("mode", ["log_only", "full_cancel_poison"]) +def test_stress_config_accepts_supported_modes(tmp_path: Path, mode: str) -> None: + """Both supported mode strings should parse and expose helper predicates.""" + cfg = StressConfig.from_yaml_path( + _write_mode_yaml( + tmp_path, + f"""\ + mode: {mode} + duration_min: 1 + kv_cache_manager: v1 + transceiver: cpp + """, + ) + ) + + assert cfg.mode == mode + assert cfg.is_log_only is (mode == "log_only") + assert cfg.is_full_cancel_poison is (mode == "full_cancel_poison") + + +def test_stress_config_rejects_unknown_mode(tmp_path: Path) -> None: + """Typos in mode must fail during YAML validation.""" + with pytest.raises(ValueError, match="mode must be one of"): + StressConfig.from_yaml_path( + _write_mode_yaml( + tmp_path, + """\ + mode: accidental + duration_min: 1 + kv_cache_manager: v1 + transceiver: cpp + """, + ) + ) + + +def test_log_only_thread_sends_probe_and_stops( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """The regular-protection mode should require at least one clean probe.""" + h = DisaggCancellationStressHarness( + _write_mode_yaml( + tmp_path, + """\ + mode: log_only + duration_min: 1 + kv_cache_manager: v1 + transceiver: cpp + log_only_probe: + interval_s: 0.01 + max_tokens: 8 + request_timeout_s: 1 + log_scan: + hard_zero_patterns: + - "Broken promise" + """, + ), + load_duration_s=0.03, ) + h.bind_server_endpoint("http://127.0.0.1:8000", "test-model") + h._marathon_start_monotonic = time.monotonic() + + calls: list[dict[str, Any]] = [] + + def fake_probe(**kwargs: Any) -> tuple[list[int], None, None]: + calls.append(kwargs) + return [1, 2], None, None + + monkeypatch.setattr(h, "_send_log_only_probe", fake_probe) + + thread = threading.Thread(target=h._log_only_thread_body, name="test-log-only", daemon=True) + thread.start() + thread.join(timeout=2.0) + + assert not thread.is_alive() + assert h.stop_event.is_set() + assert not h.failed_event.is_set() + assert calls + assert any(record["success"] for record in h._canary_records) diff --git a/tests/integration/test_lists/qa/llm_function_stress.txt b/tests/integration/test_lists/qa/llm_function_stress.txt index 1bf7c4f3f77b..9d31b6149c46 100644 --- a/tests/integration/test_lists/qa/llm_function_stress.txt +++ b/tests/integration/test_lists/qa/llm_function_stress.txt @@ -5,6 +5,7 @@ disaggregated/test_disaggregated.py::test_disaggregated_stress_test[input8k-outp disaggregated/test_disaggregated.py::test_disaggregated_stress_test[input8k-output1k-conc512-gpt_oss_120b_eagle_trtllm_stress] disaggregated/test_disaggregated.py::test_disaggregated_stress_test[input8k-output1k-conc512-gpt_oss_120b_triton_stress] disaggregated/test_disaggregated.py::test_disaggregated_stress_test[input8k-output1k-conc512-qwen3_5_4b_fp8_stress] +stress_test/disagg_cancel/test_disagg_cancel_stress.py::test_disagg_cancellation_marathon[marathon_cpp_v1_deepseek.yaml] TIMEOUT (45) accuracy/test_llm_api_pytorch.py::TestDeepSeekR1LongBenchV2::test_fp8_8gpus accuracy/test_llm_api_pytorch.py::TestDeepSeekR1LongBenchV2::test_nvfp4_4gpus accuracy/test_llm_api_pytorch.py::TestKimiK2::test_nvfp4_longseq_trtllm_moe_stress From 03ed843cd76420985329d5689deef5fb6c928b1d Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Wed, 10 Jun 2026 10:09:52 -0700 Subject: [PATCH 3/5] [None][feat] Preserve cache_salt string in KV cache events (#13051) Signed-off-by: jthomson04 --- benchmarks/cpp/disaggServerBenchmark.cpp | 4 +- benchmarks/cpp/gptManagerBenchmark.cpp | 4 +- .../tensorrt_llm/batch_manager/blockKey.h | 12 +-- .../batch_manager/kvCacheManager.h | 1 - .../tensorrt_llm/batch_manager/llmRequest.h | 34 ++++---- cpp/include/tensorrt_llm/executor/executor.h | 15 ++-- cpp/include/tensorrt_llm/executor/types.h | 1 - cpp/include/tensorrt_llm/runtime/common.h | 1 - cpp/tensorrt_llm/batch_manager/blockKey.cpp | 10 +-- .../batch_manager/kvCacheEventManager.cpp | 3 +- .../batch_manager/kvCacheManager.cpp | 6 +- cpp/tensorrt_llm/executor/request.cpp | 14 +-- cpp/tensorrt_llm/executor/requestImpl.h | 31 +++++-- cpp/tensorrt_llm/executor/serialization.cpp | 17 ++-- .../nanobind/batch_manager/bindings.cpp | 20 ++--- .../nanobind/batch_manager/llmRequest.cpp | 4 +- .../nanobind/batch_manager/llmRequest.h | 3 +- .../nanobind/executor/bindings.cpp | 1 + .../nanobind/executor/request.cpp | 16 ++-- .../batch_manager/agentTreeTest.cpp | 2 +- .../batch_manager/kvCacheManagerTest.cpp | 52 ++++++------ .../executor/serializeUtilsTest.cpp | 2 +- .../cpp/executor/executorExampleKvEvents.cpp | 12 +-- examples/llm-api/llm_kv_cache_connector.py | 11 ++- .../connectors/kv_cache_connector.py | 6 +- tensorrt_llm/_torch/pyexecutor/llm_request.py | 2 +- .../_torch/pyexecutor/resource_manager.py | 31 ++++--- tensorrt_llm/_utils.py | 2 + tensorrt_llm/executor/base_worker.py | 2 +- tensorrt_llm/executor/executor.py | 9 +- tensorrt_llm/executor/request.py | 20 ++++- tensorrt_llm/inputs/__init__.py | 4 +- tensorrt_llm/inputs/utils.py | 14 +-- tensorrt_llm/llmapi/llm.py | 6 +- .../llmapi/test_llm_kv_cache_events.py | 85 ++++++++++++++++++- 35 files changed, 287 insertions(+), 170 deletions(-) diff --git a/benchmarks/cpp/disaggServerBenchmark.cpp b/benchmarks/cpp/disaggServerBenchmark.cpp index 057e7898afeb..bc3a7a2659fd 100644 --- a/benchmarks/cpp/disaggServerBenchmark.cpp +++ b/benchmarks/cpp/disaggServerBenchmark.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -543,7 +543,7 @@ texec::Request makeExecutorContextRequest(Sample const& sample, SizeType32 const std::nullopt, // logitsPostProcessorName std::nullopt, // logitsPostProcessor encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt, - std::nullopt); // cacheSaltID + std::nullopt); // cacheSalt request.setRequestType(tensorrt_llm::executor::RequestType::REQUEST_TYPE_CONTEXT_ONLY); return request; } diff --git a/benchmarks/cpp/gptManagerBenchmark.cpp b/benchmarks/cpp/gptManagerBenchmark.cpp index b4f0948c1155..287cbba343ce 100644 --- a/benchmarks/cpp/gptManagerBenchmark.cpp +++ b/benchmarks/cpp/gptManagerBenchmark.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -838,7 +838,7 @@ texec::Request makeExecutorRequest(Sample const& sample, SizeType32 const& beamW std::nullopt, // logitsPostProcessorName std::nullopt, // logitsPostProcessor encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt, - std::nullopt); // cacheSaltID + std::nullopt); // cacheSalt } void benchmarkExecutor(std::optional const& decoderEngineDir, diff --git a/cpp/include/tensorrt_llm/batch_manager/blockKey.h b/cpp/include/tensorrt_llm/batch_manager/blockKey.h index 002b4356c869..920212845331 100644 --- a/cpp/include/tensorrt_llm/batch_manager/blockKey.h +++ b/cpp/include/tensorrt_llm/batch_manager/blockKey.h @@ -29,7 +29,6 @@ using VecTokens = std::vector; using UniqueToken = tensorrt_llm::runtime::UniqueToken; using VecUniqueTokens = tensorrt_llm::runtime::VecUniqueTokens; using LoraTaskIdType = tensorrt_llm::runtime::LoraTaskIdType; -using CacheSaltIDType = tensorrt_llm::runtime::CacheSaltIDType; using MmKey = tensorrt_llm::executor::MmKey; //! \brief Generate the multimodal extra keys for a single KV cache block. @@ -49,7 +48,8 @@ struct BlockKey // Extra keys for multimodal data (similar to VLLM's approach) // Each extra key is a pair of (mm_hash, start_offset_in_block) std::vector extraKeys; - std::optional cacheSaltID = std::nullopt; + // Cache salt string. Used as part of the block key so blocks from different salts do not match. + std::optional cacheSalt = std::nullopt; BlockKey() = default; @@ -64,12 +64,12 @@ struct BlockKey } explicit BlockKey(bool usesExtraIds, std::optional loraTaskId, VecUniqueTokens uniqueTokens, - std::vector extraKeys = {}, std::optional cacheSaltID = std::nullopt) + std::vector extraKeys = {}, std::optional cacheSalt = std::nullopt) : usesExtraIds{usesExtraIds} , loraTaskId{loraTaskId} , uniqueTokens{std::move(uniqueTokens)} , extraKeys{std::move(extraKeys)} - , cacheSaltID{cacheSaltID} + , cacheSalt{std::move(cacheSalt)} { } @@ -86,7 +86,7 @@ struct BlockKey } //! \brief Count the number of leading tokens that match between this key and \p other. - //! \details Returns 0 immediately when loraTaskId, extraKeys, or cacheSaltID differ, because those fields must + //! \details Returns 0 immediately when loraTaskId, extraKeys, or cacheSalt differ, because those fields must //! match exactly before token content is considered. //! \param other The key to compare against. //! \return Number of leading uniqueTokens that are identical in both keys. @@ -94,7 +94,7 @@ struct BlockKey { SizeType32 numMatched{0}; if (usesExtraIds == other.usesExtraIds && loraTaskId == other.loraTaskId && extraKeys == other.extraKeys - && cacheSaltID == other.cacheSaltID) + && cacheSalt == other.cacheSalt) { auto [matchEnd, otherMatchEnd] = std::mismatch( uniqueTokens.begin(), uniqueTokens.end(), other.uniqueTokens.begin(), other.uniqueTokens.end()); diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index d3966adf2f20..c665f7a8df95 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -82,7 +82,6 @@ using UniqueToken = tensorrt_llm::runtime::UniqueToken; using VecUniqueTokens = tensorrt_llm::runtime::VecUniqueTokens; using LoraTaskIdType = tensorrt_llm::runtime::LoraTaskIdType; using BlocksPerWindow = std::map>; -using CacheSaltIDType = tensorrt_llm::runtime::CacheSaltIDType; using MmKey = tensorrt_llm::executor::MmKey; using WindowSizeType = SizeType32; diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index 263a15b50970..886147a09c73 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -108,7 +108,6 @@ class GenericLlmRequest using MillisecondsType = std::chrono::milliseconds; using TimePoint = std::chrono::time_point; using Duration = std::chrono::time_point::duration; - using CacheSaltIDType = runtime::CacheSaltIDType; GenericLlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::shared_ptr const& inputTokens, runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional endId = std::nullopt, @@ -147,11 +146,12 @@ class GenericLlmRequest std::optional languageAdapterUid = std::nullopt, std::optional allottedTimeMs = std::nullopt, std::optional const& contextPhaseParams = std::nullopt, - std::optional cacheSaltID = std::nullopt, std::optional arrivalTime = std::nullopt, + std::optional arrivalTime = std::nullopt, std::optional>> agent_hierarchy = std::nullopt, std::optional>> multimodalItemRunCuOffsets = std::nullopt, std::optional>> multimodalRunPositions = std::nullopt, - std::optional>> multimodalRunLengths = std::nullopt) + std::optional>> multimodalRunLengths = std::nullopt, + std::optional cacheSalt = std::nullopt) : mRequestId(requestId) , mPromptLen(inputTokens->size()) , mMaxNewTokens(maxNewTokens) @@ -213,7 +213,7 @@ class GenericLlmRequest , mGuidedDecodingParams(std::move(guidedDecodingParams)) , mLanguageAdapterUid(languageAdapterUid) , mAllottedTimeMs(allottedTimeMs) - , mCacheSaltID(cacheSaltID) + , mCacheSalt(std::move(cacheSalt)) , mAgentHierarchy(std::move(agent_hierarchy)) { if (mEncoderTokens.has_value() || encoderInputFeatures.has_value()) @@ -242,7 +242,7 @@ class GenericLlmRequest executor::PriorityType priority = executor::Request::kDefaultPriority, SizeType32 numReturnSequences = 1, std::optional languageAdapterUid = std::nullopt, std::optional const& contextPhaseParams = std::nullopt, - std::optional cacheSaltID = std::nullopt) + std::optional cacheSalt = std::nullopt) : mRequestId(requestId) , mPromptLen(inputTokens.size()) , mMaxNewTokens(maxNewTokens) @@ -283,7 +283,7 @@ class GenericLlmRequest , mContextPhaseParams(contextPhaseParams) , mNumReturnSequences(numReturnSequences) , mLanguageAdapterUid(languageAdapterUid) - , mCacheSaltID(cacheSaltID) + , mCacheSalt(std::move(cacheSalt)) { if (mEncoderTokens.has_value()) { @@ -323,7 +323,7 @@ class GenericLlmRequest , mGuidedDecodingParams(req.getGuidedDecodingParams()) , mLanguageAdapterUid(req.getLanguageAdapterUid()) , mAllottedTimeMs(req.getAllottedTimeMs()) - , mCacheSaltID(req.getCacheSaltID()) + , mCacheSalt(req.getCacheSalt()) { if (req.getRequestType() == executor::RequestType::REQUEST_TYPE_GENERATION_ONLY) { @@ -1897,9 +1897,9 @@ class GenericLlmRequest return mLanguageAdapterUid; } - [[nodiscard]] std::optional getCacheSaltID() const + [[nodiscard]] std::optional getCacheSalt() const { - return mCacheSaltID; + return mCacheSalt; } std::vector getLanguageAdapterRouting( @@ -2196,8 +2196,8 @@ class GenericLlmRequest bool mUseDraftModel{false}; - // Cache salt id for each request. - std::optional mCacheSaltID{std::nullopt}; + // Cache salt string. Used in BlockKey hashing/matching and surfaced in KV cache events. + std::optional mCacheSalt{std::nullopt}; std::optional>> mAgentHierarchy{std::nullopt}; @@ -2394,11 +2394,12 @@ class LlmRequest : public GenericLlmRequest std::optional languageAdapterUid = std::nullopt, std::optional allottedTimeMs = std::nullopt, std::optional const& contextPhaseParams = std::nullopt, - std::optional cacheSaltID = std::nullopt, std::optional arrivalTime = std::nullopt, + std::optional arrivalTime = std::nullopt, std::optional>> agent_hierarchy = std::nullopt, std::optional> multimodalItemRunCuOffsets = std::nullopt, std::optional> multimodalRunPositions = std::nullopt, - std::optional> multimodalRunLengths = std::nullopt) + std::optional> multimodalRunLengths = std::nullopt, + std::optional cacheSalt = std::nullopt) : Base(requestId, maxNewTokens, std::make_shared>(std::move(inputTokens)), samplingConfig, isStreaming, endId, padId, std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList), @@ -2431,8 +2432,8 @@ class LlmRequest : public GenericLlmRequest inputTokenExtraIds ? std::make_optional(std::make_shared(std::move(*inputTokenExtraIds))) : std::optional>(std::nullopt), numReturnSequences, std::move(eagleConfig), skipCrossAttnBlocks, returnPerfMetrics, - std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams, cacheSaltID, - arrivalTime, std::move(agent_hierarchy), + std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams, arrivalTime, + std::move(agent_hierarchy), multimodalItemRunCuOffsets.has_value() ? std::make_shared>(std::move(multimodalItemRunCuOffsets.value())) : std::optional>>(std::nullopt), @@ -2441,7 +2442,8 @@ class LlmRequest : public GenericLlmRequest : std::optional>>(std::nullopt), multimodalRunLengths.has_value() ? std::make_shared>(std::move(multimodalRunLengths.value())) - : std::optional>>(std::nullopt)) + : std::optional>>(std::nullopt), + std::move(cacheSalt)) { } diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h index 1f625d57084c..f716bef6e3cc 100644 --- a/cpp/include/tensorrt_llm/executor/executor.h +++ b/cpp/include/tensorrt_llm/executor/executor.h @@ -714,8 +714,9 @@ class Request /// @param allottedTimeMs The allotted time in milliseconds after which the request is cancelled with a timedOut /// finish reason. The request may exceed this time slightly, but at most by 1 forward pass (in pipeline parallelism /// that may involve multiple micro-batches). A request can be timed-out before ever being scheduled. - /// @param cacheSaltID Salt ID for KV cache blocks to limit the kv cache reuse to the requests with the same string. /// @param disaggRequestId Disaggregated request ID. + /// @param cacheSalt Optional cache salt string. If provided, KV cache blocks are tagged so reuse is limited to + /// requests with the same salt. The string is also surfaced in KV cache events. Defaults to std::nullopt. Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming = false, SamplingConfig const& samplingConfig = SamplingConfig(), OutputConfig const& outputConfig = OutputConfig(), std::optional const& endId = std::nullopt, std::optional const& padId = std::nullopt, @@ -743,8 +744,7 @@ class Request std::optional guidedDecodingParams = std::nullopt, std::optional languageAdapterUid = std::nullopt, std::optional allottedTimeMs = std::nullopt, - std::optional cacheSaltID = std::nullopt, - std::optional disaggRequestId = std::nullopt); + std::optional disaggRequestId = std::nullopt, std::optional cacheSalt = std::nullopt); /// @brief This logits postprocessor name will dispatch to the batched logits postprocessor static auto constexpr kBatchedPostProcessorName = "batched"; @@ -792,7 +792,7 @@ class Request [[nodiscard]] std::optional getGuidedDecodingParams() const; [[nodiscard]] std::optional getLanguageAdapterUid() const; [[nodiscard]] std::optional getAllottedTimeMs() const; - [[nodiscard]] std::optional getCacheSaltID() const; + [[nodiscard]] std::optional getCacheSalt() const; [[nodiscard]] std::optional> getAdditionalOutputNames() const; [[nodiscard]] std::optional getDisaggRequestId() const; @@ -829,7 +829,7 @@ class Request void setGuidedDecodingParams(GuidedDecodingParams const& guidedDecodingParams); void setLanguageAdapterUid(SizeType32 languageAdapterUid); void setAllottedTimeMs(MillisecondsType allottedTimeMs); - void setCacheSaltID(CacheSaltIDType cacheSaltID); + void setCacheSalt(std::optional cacheSalt); void setDisaggRequestId(IdType disaggRequestId); private: @@ -1729,13 +1729,14 @@ struct KVCacheStoredBlockData KVCacheStoredBlockData(IdType blockHash, tensorrt_llm::runtime::VecUniqueTokens tokens, std::optional loraId, SizeType32 cacheLevel, SizeType32 priority, - std::vector mmKeys = {}) + std::vector mmKeys = {}, std::optional cacheSalt = std::nullopt) : blockHash{blockHash} , tokens{std::move(tokens)} , loraId{loraId} , cacheLevel{cacheLevel} , priority{priority} , mmKeys{std::move(mmKeys)} + , cacheSalt{std::move(cacheSalt)} { } @@ -1751,6 +1752,8 @@ struct KVCacheStoredBlockData SizeType32 priority; /// @brief The multimodal keys of the block std::vector mmKeys; + /// @brief The original cache salt string of the block, if any + std::optional cacheSalt; }; struct KVCacheStoredData diff --git a/cpp/include/tensorrt_llm/executor/types.h b/cpp/include/tensorrt_llm/executor/types.h index 0800865df7f1..2e6051291629 100644 --- a/cpp/include/tensorrt_llm/executor/types.h +++ b/cpp/include/tensorrt_llm/executor/types.h @@ -59,7 +59,6 @@ using RandomSeedType = std::uint64_t; using VecLogProbs = std::vector; using StreamPtr = std::shared_ptr; using MillisecondsType = std::chrono::milliseconds; -using CacheSaltIDType = std::uint64_t; using LogitsPostProcessor = std::function)>; using LogitsPostProcessorMap = std::unordered_map; diff --git a/cpp/include/tensorrt_llm/runtime/common.h b/cpp/include/tensorrt_llm/runtime/common.h index 7a3079d0bd75..2cda8821c133 100644 --- a/cpp/include/tensorrt_llm/runtime/common.h +++ b/cpp/include/tensorrt_llm/runtime/common.h @@ -44,7 +44,6 @@ using TokenIdType = std::int32_t; using LoraTaskIdType = std::uint64_t; using TokenExtraIdType = std::uint64_t; using VecTokenExtraIds = std::vector; -using CacheSaltIDType = std::uint64_t; struct UniqueToken { diff --git a/cpp/tensorrt_llm/batch_manager/blockKey.cpp b/cpp/tensorrt_llm/batch_manager/blockKey.cpp index 33092a5a37fa..e8125b0106f4 100644 --- a/cpp/tensorrt_llm/batch_manager/blockKey.cpp +++ b/cpp/tensorrt_llm/batch_manager/blockKey.cpp @@ -334,7 +334,7 @@ std::vector buildBlockKeys( currentTokenIdx += uniqueTokens.size(); blockKeys.emplace_back(llmRequest.getInputTokensExtraIds().has_value(), llmRequest.getLoraTaskId(), - std::move(uniqueTokens), std::move(extraKeys), llmRequest.getCacheSaltID()); + std::move(uniqueTokens), std::move(extraKeys), llmRequest.getCacheSalt()); } return blockKeys; } @@ -342,7 +342,7 @@ std::vector buildBlockKeys( bool BlockKey::operator==(BlockKey const& other) const noexcept { return (usesExtraIds == other.usesExtraIds && loraTaskId == other.loraTaskId && uniqueTokens == other.uniqueTokens - && extraKeys == other.extraKeys && cacheSaltID == other.cacheSaltID); + && extraKeys == other.extraKeys && cacheSalt == other.cacheSalt); } BlockKey BlockKey::shorten(int newNumTokens) const @@ -364,10 +364,10 @@ size_t BlockKeyHasher::hash(BlockKey const& blockKey, std::size_t parentHash) no // Constants provide very good distribution - each input bit affects each output bit with ~50% probability. size_t seed = blockKey.uniqueTokens.size() ^ parentHash * UINT64_C(0xbf58476d1ce4e5b9); - if (parentHash == 0 && blockKey.cacheSaltID) + if (parentHash == 0 && blockKey.cacheSalt) { - // Only hashing the cache salt ID for the first block in the sequence - uint64_t c = blockKey.cacheSaltID.value(); + // Only mix the cache salt into the hash for the first block in the sequence. + uint64_t c = static_cast(std::hash{}(blockKey.cacheSalt.value())); seed = hash64Mix(c, seed); } diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheEventManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheEventManager.cpp index c8f6ddd474f4..2a986adf310f 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheEventManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheEventManager.cpp @@ -105,7 +105,8 @@ void KVCacheEventManager::enqueueStoredEvent(std::vector const& blocks for (auto const& block : blocks) { data.blocks.emplace_back(block->getHash(), block->getUniqueTokens(), block->getBlockKey().loraTaskId, - block->isPrimary() ? kPrimaryLevel : kSecondaryLevel, block->getPriority(), block->getExtraKeys()); + block->isPrimary() ? kPrimaryLevel : kSecondaryLevel, block->getPriority(), block->getExtraKeys(), + block->getBlockKey().cacheSalt); } enqueueEvent({mEventId++, data, windowSize, mAttentionDpRank}); diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 1a74166e5235..d1112439686e 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -2076,7 +2076,7 @@ std::shared_ptr WindowBlockManager::findBlocksInReuseTreeByBlockKe for (auto const& blockedUniqueTokensList : blockedUniqueTokens) { blockKeys.emplace_back(blockKey.usesExtraIds, blockKey.loraTaskId, blockedUniqueTokensList, blockKey.extraKeys, - blockKey.cacheSaltID); + blockKey.cacheSalt); } return searchReuseTree(blockKeys); } @@ -4460,7 +4460,7 @@ std::vector KVCacheManager::commitAndGetBlockHashesForRequest( bool const usesExtraIds = llmRequest.getInputTokensExtraIds().has_value(); auto const loraTaskId = llmRequest.getLoraTaskId(); - auto const cacheSaltID = llmRequest.getCacheSaltID(); + auto const cacheSalt = llmRequest.getCacheSalt(); std::vector hashes; hashes.reserve(static_cast(limit)); @@ -4476,7 +4476,7 @@ std::vector KVCacheManager::commitAndGetBlockHashesForRequest( SizeType32 const tokenEnd = tokenStart + tokensPerBlock; auto extraKeys = generateBlockHashExtraKeys(llmRequest, tokenStart, tokenEnd); VecUniqueTokens blockTokens(uniqueTokens.begin() + tokenStart, uniqueTokens.begin() + tokenEnd); - BlockKey blockKey(usesExtraIds, loraTaskId, std::move(blockTokens), std::move(extraKeys), cacheSaltID); + BlockKey blockKey(usesExtraIds, loraTaskId, std::move(blockTokens), std::move(extraKeys), cacheSalt); block->setBlockKey(blockKey, /*isFull=*/true); // setHash() chains through mPrevBlockInSeq, which was wired in addBlockToBeam. The // loop walks blocks in allocation order, so by the time we reach block b its diff --git a/cpp/tensorrt_llm/executor/request.cpp b/cpp/tensorrt_llm/executor/request.cpp index 5ac62d3fcb64..e32045892ba7 100644 --- a/cpp/tensorrt_llm/executor/request.cpp +++ b/cpp/tensorrt_llm/executor/request.cpp @@ -40,8 +40,8 @@ Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming, std::optional encoderOutputLength, std::optional crossAttentionMask, SizeType32 numReturnSequences, std::optional eagleConfig, std::optional skipCrossAttnBlocks, std::optional guidedDecodingParams, std::optional languageAdapterUid, - std::optional allottedTimeMs, std::optional cacheSaltID, - std::optional disaggRequestId) + std::optional allottedTimeMs, std::optional disaggRequestId, + std::optional cacheSalt) : mImpl(std::make_unique(std::move(inputTokenIds), maxTokens, streaming, samplingConfig, outputConfig, endId, padId, std::move(positionIds), std::move(badWords), std::move(stopWords), std::move(embeddingBias), std::move(externalDraftTokensConfig), std::move(pTuningConfig), std::move(multimodalInput), @@ -50,7 +50,7 @@ Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming, std::move(encoderInputTokenIds), clientId, returnAllGeneratedTokens, priority, type, std::move(contextPhaseParams), std::move(encoderInputFeatures), encoderOutputLength, crossAttentionMask, numReturnSequences, eagleConfig, skipCrossAttnBlocks, std::move(guidedDecodingParams), languageAdapterUid, - allottedTimeMs, cacheSaltID, disaggRequestId)) + allottedTimeMs, disaggRequestId, std::move(cacheSalt))) { } @@ -249,9 +249,9 @@ std::optional Request::getLanguageAdapterUid() const return mImpl->getLanguageAdapterUid(); } -std::optional Request::getCacheSaltID() const +std::optional Request::getCacheSalt() const { - return mImpl->getCacheSaltID(); + return mImpl->getCacheSalt(); } std::optional Request::getDisaggRequestId() const @@ -424,9 +424,9 @@ void Request::setLanguageAdapterUid(SizeType32 languageAdapterUid) mImpl->setLanguageAdapterUid(languageAdapterUid); } -void Request::setCacheSaltID(CacheSaltIDType cacheSaltID) +void Request::setCacheSalt(std::optional cacheSalt) { - mImpl->setCacheSaltID(cacheSaltID); + mImpl->setCacheSalt(std::move(cacheSalt)); } void Request::setDisaggRequestId(IdType disaggRequestId) diff --git a/cpp/tensorrt_llm/executor/requestImpl.h b/cpp/tensorrt_llm/executor/requestImpl.h index 281f81d462a7..55610885b1ae 100644 --- a/cpp/tensorrt_llm/executor/requestImpl.h +++ b/cpp/tensorrt_llm/executor/requestImpl.h @@ -32,6 +32,21 @@ class Request::Impl { public: + //! Maximum allowed length of a cache salt string. Cache salts are copied into every BlockKey and emitted + //! with KV cache events, so unbounded strings would inflate memory and serialization cost proportional to + //! the number of blocks. + static constexpr std::size_t kMaxCacheSaltLength{256}; + + static std::optional validateCacheSalt(std::optional cacheSalt) + { + if (cacheSalt.has_value() && cacheSalt->size() > kMaxCacheSaltLength) + { + TLLM_THROW("cacheSalt length (%zu) exceeds the maximum supported length (%zu).", cacheSalt->size(), + kMaxCacheSaltLength); + } + return cacheSalt; + } + Impl(VecTokens inputTokenIds, SizeType32 maxNewTokens, bool streaming, SamplingConfig const& samplingConfig, OutputConfig outputConfig, std::optional const& endId, std::optional const& padId, std::optional> positionIds, std::optional> badWords, @@ -48,7 +63,7 @@ class Request::Impl std::optional crossAttentionMask, SizeType32 numReturnSequences, std::optional eagleConfig, std::optional skipCrossAttnBlocks, std::optional guidedDecodingParams, std::optional languageAdapterUid, std::optional allottedTimeMs, - std::optional cacheSaltID, std::optional disaggRequestId) + std::optional disaggRequestId, std::optional cacheSalt = std::nullopt) : mInputTokenIds(std::move(inputTokenIds)) , mMaxNewTokens(maxNewTokens) , mStreaming(streaming) @@ -85,7 +100,7 @@ class Request::Impl , mGuidedDecodingParams(std::move(guidedDecodingParams)) , mLanguageAdapterUid(languageAdapterUid) , mAllottedTimeMs(allottedTimeMs) - , mCacheSaltID(cacheSaltID) + , mCacheSalt(validateCacheSalt(std::move(cacheSalt))) , mDisaggRequestId(disaggRequestId) { validate(); @@ -298,9 +313,9 @@ class Request::Impl return mLanguageAdapterUid; } - [[nodiscard]] std::optional getCacheSaltID() const + [[nodiscard]] std::optional getCacheSalt() const { - return mCacheSaltID; + return mCacheSalt; } [[nodiscard]] std::optional getDisaggRequestId() const @@ -482,9 +497,9 @@ class Request::Impl mLanguageAdapterUid = languageAdapterUid; } - void setCacheSaltID(CacheSaltIDType cacheSaltID) + void setCacheSalt(std::optional cacheSalt) { - mCacheSaltID = cacheSaltID; + mCacheSalt = validateCacheSalt(std::move(cacheSalt)); } void setDisaggRequestId(IdType disaggRequestId) @@ -565,8 +580,8 @@ class Request::Impl lambda(mGuidedDecodingParams); lambda(mLanguageAdapterUid); lambda(mAllottedTimeMs ? std::make_optional(mAllottedTimeMs->count()) : std::nullopt); - lambda(mCacheSaltID); lambda(mDisaggRequestId); + lambda(mCacheSalt); } VecTokens mInputTokenIds; @@ -605,7 +620,7 @@ class Request::Impl std::optional mGuidedDecodingParams; std::optional mLanguageAdapterUid; std::optional mAllottedTimeMs; - std::optional mCacheSaltID; + std::optional mCacheSalt; std::optional mDisaggRequestId; }; diff --git a/cpp/tensorrt_llm/executor/serialization.cpp b/cpp/tensorrt_llm/executor/serialization.cpp index e4c325423126..ce081e10c603 100644 --- a/cpp/tensorrt_llm/executor/serialization.cpp +++ b/cpp/tensorrt_llm/executor/serialization.cpp @@ -885,8 +885,8 @@ Request Serialization::deserializeRequest(std::istream& is) auto allottedTimeMs = allottedTimeInt ? std::optional(std::chrono::milliseconds(*allottedTimeInt)) : std::nullopt; - auto cacheSaltID = su::deserialize>(is); auto disaggRequestId = su::deserialize>(is); + auto cacheSalt = su::deserialize>(is); return Request(std::move(inputTokenIds), maxNewTokens, streaming, samplingConfig, outputConfig, endId, padId, std::move(positionIds), std::move(badWords), std::move(stopWords), std::move(embeddingBias), @@ -896,7 +896,7 @@ Request Serialization::deserializeRequest(std::istream& is) std::move(encoderInputTokenIds), clientId, returnAllGeneratedTokens, priority, requestType, std::move(contextPhaseParams), std::move(encoderInputFeatures), encoderOutputLength, std::move(crossAttentionMask), numReturnSequences, std::move(eagleConfig), std::move(skipCrossAttnBlocks), - std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, cacheSaltID, disaggRequestId); + std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, disaggRequestId, std::move(cacheSalt)); } void Serialization::serialize(Request const& request, std::ostream& os) @@ -2517,6 +2517,7 @@ size_t Serialization::serializedSize(KVCacheStoredBlockData const& data) totalSize += su::serializedSize(data.cacheLevel); totalSize += su::serializedSize(data.priority); totalSize += su::serializedSize(data.mmKeys); + totalSize += su::serializedSize(data.cacheSalt); return totalSize; } @@ -2528,6 +2529,7 @@ void Serialization::serialize(KVCacheStoredBlockData const& data, std::ostream& su::serialize(data.cacheLevel, os); su::serialize(data.priority, os); su::serialize(data.mmKeys, os); + su::serialize(data.cacheSalt, os); } KVCacheStoredBlockData Serialization::deserializeKVCacheStoredBlockData(std::istream& is) @@ -2538,8 +2540,9 @@ KVCacheStoredBlockData Serialization::deserializeKVCacheStoredBlockData(std::ist auto cacheLevel = su::deserialize(is); auto priority = su::deserialize(is); auto mmKeys = su::deserialize>(is); + auto cacheSalt = su::deserialize>(is); - return KVCacheStoredBlockData{blockHash, tokens, loraId, cacheLevel, priority, mmKeys}; + return KVCacheStoredBlockData{blockHash, tokens, loraId, cacheLevel, priority, mmKeys, cacheSalt}; } // KVcacheRemovedData @@ -2686,7 +2689,7 @@ size_t Serialization::serializedSize(tensorrt_llm::batch_manager::kv_cache_manag totalSize += su::serializedSize(key.uniqueTokens); // std::vector where MmKey is pair, SizeType32> totalSize += su::serializedSize(key.extraKeys); - totalSize += su::serializedSize(key.cacheSaltID); + totalSize += su::serializedSize(key.cacheSalt); return totalSize; } @@ -2696,7 +2699,7 @@ void Serialization::serialize(tensorrt_llm::batch_manager::kv_cache_manager::Blo su::serialize(key.loraTaskId, os); su::serialize(key.uniqueTokens, os); su::serialize(key.extraKeys, os); - su::serialize(key.cacheSaltID, os); + su::serialize(key.cacheSalt, os); } tensorrt_llm::batch_manager::kv_cache_manager::BlockKey Serialization::deserializeBlockKey(std::istream& is) @@ -2705,13 +2708,13 @@ tensorrt_llm::batch_manager::kv_cache_manager::BlockKey Serialization::deseriali auto loraTaskId = su::deserialize>(is); auto uniqueTokens = su::deserialize>(is); auto extraKeys = su::deserialize>(is); - auto cacheSaltID = su::deserialize>(is); + auto cacheSalt = su::deserialize>(is); tensorrt_llm::batch_manager::kv_cache_manager::BlockKey key; key.usesExtraIds = usesExtraIds; key.loraTaskId = std::move(loraTaskId); key.uniqueTokens = std::move(uniqueTokens); key.extraKeys = std::move(extraKeys); - key.cacheSaltID = std::move(cacheSaltID); + key.cacheSalt = std::move(cacheSalt); return key; } diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp index 5447f13c60e2..086cdf7547c2 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp @@ -202,7 +202,7 @@ void initBindings(nb::module_& m) .def_prop_ro("llm_request_type", &GenLlmReq::getLlmRequestType) .def_prop_ro("parent_request_id", &GenLlmReq::getParentRequestId) .def_prop_ro("is_child", &GenLlmReq::isChild) - .def_prop_ro("cache_salt_id", &GenLlmReq::getCacheSaltID) + .def_prop_ro("cache_salt", &GenLlmReq::getCacheSalt) .def_prop_ro("kv_cache_retention_config", &GenLlmReq::getKvCacheRetentionConfig) .def_prop_ro("multimodal_hashes", [](GenLlmReq& self) @@ -347,12 +347,12 @@ void initBindings(nb::module_& m) std::optional language_adapter_uid, std::optional allotted_time_ms, std::optional context_phase_params, - std::optional cache_salt_id, std::optional arrival_time, std::optional>> agent_hierarchy, std::optional> multimodal_item_run_cu_offsets, std::optional> multimodal_run_positions, - std::optional> multimodal_run_lengths) + std::optional> multimodal_run_lengths, + std::optional cache_salt) { auto makeOptionalTensor = [](std::optional const& atTensor, bool unsqueeze = false) { @@ -392,9 +392,9 @@ void initBindings(nb::module_& m) encoder_input_tokens, return_encoder_output, client_id, priority, encoder_input_features_tensor_ptr, encoder_output_length, cross_attention_mask_tensor_ptr, llm_request_type, input_token_extra_ids, num_return_sequences, eagle_config, skip_cross_attn_blocks_tensor_ptr, return_perf_metrics, - guided_decoding_params, language_adapter_uid, allotted_time_ms, context_phase_params, cache_salt_id, - arrival_time, std::move(agent_hierarchy), multimodal_item_run_cu_offsets, multimodal_run_positions, - multimodal_run_lengths}; + guided_decoding_params, language_adapter_uid, allotted_time_ms, context_phase_params, arrival_time, + std::move(agent_hierarchy), multimodal_item_run_cu_offsets, multimodal_run_positions, + multimodal_run_lengths, std::move(cache_salt)}; }, nb::arg("request_id"), nb::arg("max_new_tokens"), nb::arg("input_tokens"), nb::arg("sampling_config"), nb::arg("is_streaming"), nb::arg("end_id") = std::nullopt, nb::arg("pad_id") = std::nullopt, @@ -420,10 +420,10 @@ void initBindings(nb::module_& m) nb::arg("eagle_config") = std::nullopt, nb::arg("skip_cross_attn_blocks") = std::nullopt, nb::arg("return_perf_metrics") = false, nb::arg("guided_decoding_params") = std::nullopt, nb::arg("language_adapter_uid") = std::nullopt, nb::arg("allotted_time_ms") = std::nullopt, - nb::arg("context_phase_params") = std::nullopt, nb::arg("cache_salt_id") = std::nullopt, - nb::arg("arrival_time") = std::nullopt, nb::arg("agent_hierarchy") = std::nullopt, - nb::arg("multimodal_item_run_cu_offsets") = std::nullopt, - nb::arg("multimodal_run_positions") = std::nullopt, nb::arg("multimodal_run_lengths") = std::nullopt) + nb::arg("context_phase_params") = std::nullopt, nb::arg("arrival_time") = std::nullopt, + nb::arg("agent_hierarchy") = std::nullopt, nb::arg("multimodal_item_run_cu_offsets") = std::nullopt, + nb::arg("multimodal_run_positions") = std::nullopt, nb::arg("multimodal_run_lengths") = std::nullopt, + nb::arg("cache_salt") = std::nullopt) .def("check_token_id_range", &tb::LlmRequest::checkTokenIdRange, nb::arg("vocab_size")) .def(nb::init()) .def("validate", &tb::LlmRequest::validate, nb::arg("max_input_len"), nb::arg("max_seq_len"), diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp index 796909bd419a..21f9ea39823b 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp @@ -127,11 +127,11 @@ std::shared_ptr LlmRequest::toTrtLlm() const mLanguageAdapterUid, // mAllottedTimeMs, // mContextPhaseParams, // - mCacheSaltID, // mPerfMetrics.timingMetrics.arrivalTime, // mAgentHierarchy, // mMultimodalItemRunCuOffsets, // mMultimodalRunPositions, // - mMultimodalRunLengths // + mMultimodalRunLengths, // + mCacheSalt // ); } diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h index 967870c8177c..387d915ab4aa 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h +++ b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h @@ -86,7 +86,7 @@ class LlmRequest : public tb::GenericLlmRequest std::optional languageAdapterUid = std::nullopt, std::optional allottedTimeMs = std::nullopt, std::optional const& contextPhaseParams = std::nullopt, - std::optional cacheSaltID = std::nullopt, std::optional arrivalTime = std::nullopt, + std::optional arrivalTime = std::nullopt, std::optional>> agent_hierarchy = std::nullopt, std::optional> multimodalItemRunCuOffsets = std::nullopt, std::optional> multimodalRunPositions = std::nullopt, @@ -155,7 +155,6 @@ class LlmRequest : public tb::GenericLlmRequest languageAdapterUid, // allottedTimeMs, // contextPhaseParams, // - cacheSaltID, // arrivalTime, // std::move(agent_hierarchy), // multimodalItemRunCuOffsets.has_value() diff --git a/cpp/tensorrt_llm/nanobind/executor/bindings.cpp b/cpp/tensorrt_llm/nanobind/executor/bindings.cpp index fbec513de3a1..b0ad31b7347e 100644 --- a/cpp/tensorrt_llm/nanobind/executor/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/executor/bindings.cpp @@ -230,6 +230,7 @@ void initBindings(nb::module_& m) .def_ro("lora_id", &tle::KVCacheStoredBlockData::loraId) .def_ro("cache_level", &tle::KVCacheStoredBlockData::cacheLevel) .def_ro("priority", &tle::KVCacheStoredBlockData::priority) + .def_ro("cache_salt", &tle::KVCacheStoredBlockData::cacheSalt) .def_prop_ro("mm_keys", [](tle::KVCacheStoredBlockData const& self) { diff --git a/cpp/tensorrt_llm/nanobind/executor/request.cpp b/cpp/tensorrt_llm/nanobind/executor/request.cpp index 6cffe7740c13..b502370504d1 100644 --- a/cpp/tensorrt_llm/nanobind/executor/request.cpp +++ b/cpp/tensorrt_llm/nanobind/executor/request.cpp @@ -613,7 +613,7 @@ void initRequestBindings(nb::module_& m) self.getClientId(), self.getReturnAllGeneratedTokens(), self.getPriority(), self.getRequestType(), self.getContextPhaseParams(), self.getEncoderInputFeatures(), self.getEncoderOutputLength(), self.getCrossAttentionMask(), self.getEagleConfig(), self.getSkipCrossAttnBlocks(), - self.getGuidedDecodingParams(), self.getCacheSaltID(), self.getDisaggRequestId()); + self.getGuidedDecodingParams(), self.getDisaggRequestId(), self.getCacheSalt()); }; auto requestSetstate = [](tle::Request& self, nb::tuple const& state) { @@ -642,7 +642,7 @@ void initRequestBindings(nb::module_& m) nb::cast>(state[29]), 1, nb::cast>(state[30]), nb::cast>(state[31]), nb::cast>(state[32]), std::nullopt, std::nullopt, - nb::cast>(state[33]), nb::cast>(state[34])); + nb::cast>(state[33]), nb::cast>(state[34])); }; nb::class_ request(m, "Request", nb::dynamic_attr()); @@ -683,8 +683,8 @@ void initRequestBindings(nb::module_& m) std::optional, // guidedDecodingParams std::optional, // languageAdapterUid std::optional, // allottedTimeMs - std::optional, // cacheSaltID - std::optional // disaggRequestId + std::optional, // disaggRequestId + std::optional // cacheSalt >(), // clang-format off nb::arg("input_token_ids"), @@ -724,9 +724,9 @@ void initRequestBindings(nb::module_& m) nb::arg("guided_decoding_params") = nb::none(), nb::arg("language_adapter_uid") = nb::none(), nb::arg("allotted_time_ms") = nb::none(), - nb::arg("cache_salt_id") = nb::none(), - nb::arg("disagg_request_id") = nb::none() - ) // clang-format on + nb::arg("disagg_request_id") = nb::none(), + nb::arg("cache_salt") = nb::none() + ) // clang-format on .def_prop_ro("input_token_ids", &tle::Request::getInputTokenIds) .def_prop_ro("max_tokens", &tle::Request::getMaxTokens) .def_prop_rw("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming) @@ -768,7 +768,7 @@ void initRequestBindings(nb::module_& m) .def_prop_rw( "guided_decoding_params", &tle::Request::getGuidedDecodingParams, &tle::Request::setGuidedDecodingParams) .def_prop_rw("allotted_time_ms", &tle::Request::getAllottedTimeMs, &tle::Request::setAllottedTimeMs) - .def_prop_rw("cache_salt_id", &tle::Request::getCacheSaltID, &tle::Request::setCacheSaltID) + .def_prop_rw("cache_salt", &tle::Request::getCacheSalt, &tle::Request::setCacheSalt) .def_prop_rw("context_phase_params", &tle::Request::getContextPhaseParams, &tle::Request::setContextPhaseParams) .def_prop_rw("disagg_request_id", &tle::Request::getDisaggRequestId, &tle::Request::setDisaggRequestId) .def_prop_rw("priority", &tle::Request::getPriority, &tle::Request::setPriority) diff --git a/cpp/tests/unit_tests/batch_manager/agentTreeTest.cpp b/cpp/tests/unit_tests/batch_manager/agentTreeTest.cpp index 32e0b3e0ea17..5104f2a88a71 100644 --- a/cpp/tests/unit_tests/batch_manager/agentTreeTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/agentTreeTest.cpp @@ -64,7 +64,7 @@ class AgentTreeTest : public ::testing::Test std::nullopt, std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, tensorrt_llm::executor::Request::kDefaultPriority, std::nullopt, std::nullopt, std::nullopt, tb::LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, 1, std::nullopt, std::nullopt, - false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, agentHierarchy); + false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, agentHierarchy); } LlmRequestPtr createAgentDeepResearchRequest(SizeType32 nodeId, SizeType32 requestId) diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp index bebc6dd920ac..bf1d7589ce31 100644 --- a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp @@ -2094,12 +2094,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); } -TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) +TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltTest) { - // Test that cache_salt_id prevents KV cache reuse between requests with same tokens - // but different cache_salt_id values. + // Test that cache_salt prevents KV cache reuse between requests with same tokens + // but different cache_salt values. using VecTokenExtraIds = LlmRequest::VecTokenExtraIds; - using CacheSaltIDType = LlmRequest::CacheSaltIDType; auto constexpr numLayers = 12; auto constexpr numKvHeads = 6; @@ -2135,7 +2134,7 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) auto const inputLength = static_cast(inputTokens->size()); /////////////////////////////////////////////////////////////////////////// - // Test Case 1: Request without cache_salt_id + // Test Case 1: Request without cache_salt LlmRequest::RequestIdType requestId{0}; auto llmRequest0 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, @@ -2143,8 +2142,8 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt, - std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - std::nullopt); // No cache_salt_id + std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + std::nullopt); // No cache_salt GenerationRequest seq0{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; @@ -2177,21 +2176,22 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); /////////////////////////////////////////////////////////////////////////// - // Test Case 2: Request with same tokens but with cache_salt_id = 12345 + // Test Case 2: Request with same tokens but with cache_salt = "tenant-A" requestId = 1; - CacheSaltIDType cacheSaltId1{12345}; + std::string const cacheSalt1{"tenant-A"}; auto llmRequest1 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt, - std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - cacheSaltId1); // With cache_salt_id = 12345 + std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, + cacheSalt1); // With cache_salt = "tenant-A" GenerationRequest seq1{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; - // Should NOT reuse blocks despite same tokens, because cache_salt_id is different + // Should NOT reuse blocks despite same tokens, because cache_salt is different auto promptLen1 = llmRequest1->getNumTokens(beamIdx); auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); auto prepopulatedPromptLen1 = blockManager @@ -2215,7 +2215,7 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); /////////////////////////////////////////////////////////////////////////// - // Test Case 3: Request with same tokens and same cache_salt_id = 12345 + // Test Case 3: Request with same tokens and same cache_salt = "tenant-A" requestId = 2; auto llmRequest2 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, @@ -2223,12 +2223,13 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt, - std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - cacheSaltId1); // Same cache_salt_id = 12345 + std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, + cacheSalt1); // Same cache_salt = "tenant-A" GenerationRequest seq2{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; - // SHOULD reuse blocks because both tokens and cache_salt_id match + // SHOULD reuse blocks because both tokens and cache_salt match auto promptLen2 = llmRequest2->getNumTokens(beamIdx); auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock()); auto prepopulatedPromptLen2 = blockManager @@ -2252,21 +2253,22 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); /////////////////////////////////////////////////////////////////////////// - // Test Case 4: Request with same tokens but different cache_salt_id = 67890 + // Test Case 4: Request with same tokens but different cache_salt = "tenant-B" requestId = 3; - CacheSaltIDType cacheSaltId2{67890}; + std::string const cacheSalt2{"tenant-B"}; auto llmRequest3 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt, - std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - cacheSaltId2); // Different cache_salt_id = 67890 + std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, + cacheSalt2); // Different cache_salt = "tenant-B" GenerationRequest seq3{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; - // Should NOT reuse blocks from any previous request because cache_salt_id is different + // Should NOT reuse blocks from any previous request because cache_salt is different auto promptLen3 = llmRequest3->getNumTokens(beamIdx); auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock()); auto prepopulatedPromptLen3 = blockManager @@ -2284,7 +2286,7 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks); /////////////////////////////////////////////////////////////////////////// - // Test Case 5: Request without cache_salt_id again + // Test Case 5: Request without cache_salt again requestId = 4; auto llmRequest4 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, @@ -2292,12 +2294,12 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt, - std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - std::nullopt); // No cache_salt_id + std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + std::nullopt); // No cache_salt GenerationRequest seq4{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; - // Should reuse blocks from request0 (blocks 0,1) because both have no cache_salt_id + // Should reuse blocks from request0 (blocks 0,1) because both have no cache_salt auto promptLen4 = llmRequest4->getNumTokens(beamIdx); auto numContextBlocks4 = tc::ceilDiv(promptLen4, blockManager.getTokensPerBlock()); auto prepopulatedPromptLen4 = blockManager diff --git a/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp b/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp index d80d0be456b4..98569bfc2aa0 100644 --- a/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp +++ b/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp @@ -1097,7 +1097,7 @@ TEST(SerializeUtilsTest, BlockKeyWithExtras) VecUniqueTokens uniqueTokens{UniqueToken{10, 100}, UniqueToken{20, 200}}; std::optional loraTaskId = LoraTaskIdType{42}; - // Note: cacheSaltID is intentionally not set since it is not serialized + // Note: cacheSalt is intentionally not set; round-tripping with it set is covered separately. BlockKey key(true, loraTaskId, uniqueTokens, extraKeys); testSerializeDeserialize(key); diff --git a/examples/cpp/executor/executorExampleKvEvents.cpp b/examples/cpp/executor/executorExampleKvEvents.cpp index a48cbdfa9769..ea1923294382 100644 --- a/examples/cpp/executor/executorExampleKvEvents.cpp +++ b/examples/cpp/executor/executorExampleKvEvents.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -50,13 +50,14 @@ struct RuntimeOptions struct KVCacheBlock { KVCacheBlock(size_t hash, int cacheLevel, int priority, std::optional loraId = std::nullopt, - std::shared_ptr prevBlock = nullptr); + std::shared_ptr prevBlock = nullptr, std::optional cacheSalt = std::nullopt); size_t hash; int cacheLevel; int priority; std::optional loraId; + std::optional cacheSalt; std::shared_ptr prevBlock; std::unordered_map> nextBlocks; @@ -196,12 +197,13 @@ RuntimeOptions parseArgs(int argc, char* argv[]) return runtimeOpts; } -KVCacheBlock::KVCacheBlock( - size_t hash, int cacheLevel, int priority, std::optional loraId, std::shared_ptr prevBlock) +KVCacheBlock::KVCacheBlock(size_t hash, int cacheLevel, int priority, std::optional loraId, + std::shared_ptr prevBlock, std::optional cacheSalt) : hash{hash} , cacheLevel{cacheLevel} , priority{priority} , loraId{loraId} + , cacheSalt{std::move(cacheSalt)} , prevBlock{prevBlock} , nextBlocks{} { @@ -255,7 +257,7 @@ void RadixTree::pollEvents() TLLM_CHECK(block.tokens.size() > 0); auto thisBlock = std::make_shared( - block.blockHash, block.cacheLevel, block.priority, block.loraId, prevBlock); + block.blockHash, block.cacheLevel, block.priority, block.loraId, prevBlock, block.cacheSalt); blockTable[block.blockHash] = thisBlock; // Link the parent to the new block diff --git a/examples/llm-api/llm_kv_cache_connector.py b/examples/llm-api/llm_kv_cache_connector.py index 0e6aa3d83aa9..882478993e5a 100644 --- a/examples/llm-api/llm_kv_cache_connector.py +++ b/examples/llm-api/llm_kv_cache_connector.py @@ -192,7 +192,7 @@ def build_connector_meta(self, scheduler_output: SchedulerOutput): len(block_ids)): if len(chunks[block_pos]) == self.block_size: hashed_tokens = self._hash_tokens(chunks[block_pos], - req.cache_salt_id) + req.cache_salt) file_path = self._file_path(hashed_tokens) @@ -202,11 +202,10 @@ def build_connector_meta(self, scheduler_output: SchedulerOutput): return metadata - def _hash_tokens(self, tokens: list[int], - cache_salt_id: Optional[int]) -> int: - # cache_salt_id must participate in the hash so that requests carrying + def _hash_tokens(self, tokens: list[int], cache_salt: Optional[str]) -> int: + # cache_salt must participate in the hash so that requests carrying # different salts (or no salt) cannot collide on the same cache file. - return abs(hash((cache_salt_id, tuple(tokens)))) + return abs(hash((cache_salt, tuple(tokens)))) def _file_path(self, hash_value: int) -> Path: return Path(self.cache_folder) / f"{hash_value}.pt" @@ -238,7 +237,7 @@ def get_num_new_matched_tokens( for chunk in remaining_chunks: # Only do full blocks. if len(chunk) == self.block_size: - hashed_tokens = self._hash_tokens(chunk, request.cache_salt_id) + hashed_tokens = self._hash_tokens(chunk, request.cache_salt) file_path = self._file_path(hashed_tokens) diff --git a/tensorrt_llm/_torch/pyexecutor/connectors/kv_cache_connector.py b/tensorrt_llm/_torch/pyexecutor/connectors/kv_cache_connector.py index f5034256e142..99da2a42265c 100644 --- a/tensorrt_llm/_torch/pyexecutor/connectors/kv_cache_connector.py +++ b/tensorrt_llm/_torch/pyexecutor/connectors/kv_cache_connector.py @@ -82,9 +82,9 @@ class RequestData: # Per-request cache salt that the KV cache manager uses to isolate reuse # between requests carrying different salts. Connectors that key cached # content on token sequences (e.g. by hashing tokens to a file path or - # remote object id) MUST mix cache_salt_id into their identifiers, + # remote object id) MUST mix cache_salt into their identifiers, # otherwise blocks from a different salt could be incorrectly reused. - cache_salt_id: Optional[int] = None + cache_salt: Optional[str] = None # A class to store some basic data regarding all inflight requests. @@ -361,7 +361,7 @@ def update_and_build_data(self, req: LlmRequest, kv_cache_manager: "KVCacheManag num_scheduled_tokens, block_hashes=block_hashes, priorities=priorities, - cache_salt_id=req.cache_salt_id, + cache_salt=req.cache_salt, ) diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index 21bdd6c0b3e4..a7d6614d6f4f 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -1099,7 +1099,7 @@ def executor_request_to_llm_request( priority=executor_request.priority, llm_request_type=llm_request_type, context_phase_params=executor_request.context_phase_params, - cache_salt_id=executor_request.cache_salt_id, + cache_salt=executor_request.cache_salt, arrival_time=getattr(executor_request, "py_arrival_time", None), py_multimodal_data=getattr(executor_request, "py_multimodal_data", None), diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index d8e948d15689..c67bec2f22b9 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -3,6 +3,7 @@ import copy import enum +import hashlib import math import os from abc import ABC, abstractmethod @@ -2996,11 +2997,10 @@ def prepare_context(self, req: LlmRequest) -> bool: all_tokens, req, end=len(all_tokens) - 1) else: tokens = None - kv_cache = self._create_kv_cache( - req.py_request_id, - req.lora_task_id, - tokens, - cache_salt_id=req.cache_salt_id) + kv_cache = self._create_kv_cache(req.py_request_id, + req.lora_task_id, + tokens, + cache_salt=req.cache_salt) if kv_cache is None: return False kv_cache.cuda_stream = self._stream.cuda_stream @@ -3123,11 +3123,10 @@ def _prepare_draft_resources(self, scheduled_batch: ScheduledRequests): for req in scheduled_batch.context_requests: kv_cache = self.kv_cache_map.get(req.py_request_id) if kv_cache is None: - kv_cache = self._create_kv_cache( - req.py_request_id, - req.lora_task_id, - None, - cache_salt_id=req.cache_salt_id) + kv_cache = self._create_kv_cache(req.py_request_id, + req.lora_task_id, + None, + cache_salt=req.cache_salt) kv_cache.stop_committing() if not self._resume_and_restore(req.py_request_id, kv_cache): raise RuntimeError( @@ -3295,7 +3294,7 @@ def release_resources(current_request: LlmRequest, if prepare_resource: # Dummy/warmup request. ``stop_committing()`` below blocks all # writes to the radix tree, so the choice of branch does not - # affect committed state. ``cache_salt_id`` is left defaulted + # affect committed state. ``cache_salt`` is left defaulted # to None to avoid coupling synthetic data to any salted branch. kv_cache = self._create_kv_cache(req.py_request_id, req.lora_task_id, input_tokens) @@ -3674,7 +3673,7 @@ def _create_kv_cache(self, request_id: int, lora_task_id: int | None, input_tokens: Sequence[TokenIdExt] | None, - cache_salt_id: int | None = None): + cache_salt: str | None = None): assert request_id not in self.kv_cache_map, f"KV cache for request {request_id} already exists" if self.index_mapper.num_free_slots() == 0: logger.warning( @@ -3683,8 +3682,14 @@ def _create_kv_cache(self, "Skipping KV cache creation; request will retry next iteration.", request_id, self.index_mapper.size(), self.index_mapper.size()) return None + # ReuseScope.salt is int|None; derive a deterministic int from the + # cache_salt string so the same string yields the same reuse namespace + # across processes (matches C++ blockKey hashing on cacheSalt). + salt_int = (int.from_bytes( + hashlib.sha256(cache_salt.encode("utf-8")).digest()[:8], "little") + if cache_salt is not None else None) kv_cache = self.impl.create_kv_cache( - ReuseScope(lora_id=lora_task_id, salt=cache_salt_id), + ReuseScope(lora_id=lora_task_id, salt=salt_int), input_tokens, ) self.kv_cache_map[request_id] = kv_cache diff --git a/tensorrt_llm/_utils.py b/tensorrt_llm/_utils.py index 841f396104f5..955a0ab55082 100644 --- a/tensorrt_llm/_utils.py +++ b/tensorrt_llm/_utils.py @@ -1212,6 +1212,8 @@ def _stored_block_to_json(data): for token in data.tokens ], # "lora_id": data.lora_id, # TODO (shreyasm): enable serialization of lora_id + "cache_salt": + data.cache_salt, "cache_level": data.cache_level, "priority": diff --git a/tensorrt_llm/executor/base_worker.py b/tensorrt_llm/executor/base_worker.py index e42f2b25b24c..090742b38d19 100644 --- a/tensorrt_llm/executor/base_worker.py +++ b/tensorrt_llm/executor/base_worker.py @@ -594,8 +594,8 @@ def _deduce_max_tokens(request: GenerationRequest, kv_cache_retention_config=request.kv_cache_retention_config, context_phase_params=context_phase_params, type=request_type, - cache_salt_id=request.cache_salt_id, disagg_request_id=disagg_request_id, + cache_salt=request.cache_salt, priority=request.priority) executor_request.py_original_end_id = request.sampling_params.end_id executor_request.py_num_logprobs = request.sampling_params.logprobs diff --git a/tensorrt_llm/executor/executor.py b/tensorrt_llm/executor/executor.py index a5f5efea6427..5ea904531f22 100644 --- a/tensorrt_llm/executor/executor.py +++ b/tensorrt_llm/executor/executor.py @@ -134,7 +134,7 @@ def generate_async( postproc_params: Optional[PostprocParams] = None, multimodal_params: Optional[MultimodalParams] = None, scheduling_params: Optional[SchedulingParams] = None, - cache_salt_id: Optional[int] = None, + cache_salt: Optional[str] = None, arrival_time: Optional[float] = None, priority: float = DEFAULT_REQUEST_PRIORITY, ) -> GenerationResult: @@ -162,7 +162,7 @@ def generate_async( trace_headers=trace_headers, multimodal_params=multimodal_params, scheduling_params=scheduling_params, - cache_salt_id=cache_salt_id, + cache_salt=cache_salt, arrival_time=arrival_time, priority=priority) result = self.submit(request) @@ -180,6 +180,7 @@ def generate( prompt_adapter_request: Optional[Union[ PromptAdapterRequest, List[PromptAdapterRequest]]] = None, disaggregated_params: Optional[DisaggregatedParams] = None, + cache_salt: Optional[Union[str, List[Optional[str]]]] = None, ) -> Union[GenerationResult, List[GenerationResult]]: """Generate output for the given prompt token ids in the synchronous mode. Synchronous generation accepts either single prompt or batched prompts. @@ -205,6 +206,7 @@ def generate( pa_req = prompt_adapter_request[i] else: pa_req = prompt_adapter_request + cs = cache_salt[i] if isinstance(cache_salt, list) else cache_salt future = self.generate_async( p, sampling_params=sp, @@ -212,7 +214,8 @@ def generate( lora_request=lora_req, prompt_adapter_request=pa_req, streaming=False, - disaggregated_params=disaggregated_params) + disaggregated_params=disaggregated_params, + cache_salt=cs) futures.append(future) for future in futures: diff --git a/tensorrt_llm/executor/request.py b/tensorrt_llm/executor/request.py index adbc2358b372..43ea11706e54 100644 --- a/tensorrt_llm/executor/request.py +++ b/tensorrt_llm/executor/request.py @@ -89,6 +89,8 @@ def local_path(self): class GenerationRequest: + # Mirrors C++ Request::Impl::kMaxCacheSaltLength + MAX_CACHE_SALT_LEN: int = 256 def __init__( self, @@ -105,7 +107,7 @@ def __init__( postproc_params: Optional[PostprocParams] = None, multimodal_params: Optional[MultimodalParams] = None, scheduling_params: Optional[SchedulingParams] = None, - cache_salt_id: Optional[int] = None, + cache_salt: Optional[str] = None, arrival_time: Optional[float] = None, priority: float = DEFAULT_REQUEST_PRIORITY, ): @@ -134,7 +136,21 @@ def __init__( self.disaggregated_params = disaggregated_params self.trace_headers = trace_headers self.scheduling_params = scheduling_params - self.cache_salt_id = cache_salt_id + if cache_salt is not None: + if not isinstance(cache_salt, str): + raise TypeError( + f"cache_salt must be str or None, got {type(cache_salt).__name__}" + ) + # The C++ side validates against UTF-8 byte length, so do the same here + # (Python `len()` would count Unicode code points, which can pass this + # guard but fail at C++ dispatch for non-ASCII salts). + cache_salt_byte_len = len(cache_salt.encode("utf-8")) + if cache_salt_byte_len > self.MAX_CACHE_SALT_LEN: + raise ValueError( + f"cache_salt UTF-8 byte length ({cache_salt_byte_len}) " + f"exceeds the maximum supported length " + f"({self.MAX_CACHE_SALT_LEN}).") + self.cache_salt = cache_salt self.arrival_time = arrival_time if not (0.0 <= priority <= 1.0): raise ValueError( diff --git a/tensorrt_llm/inputs/__init__.py b/tensorrt_llm/inputs/__init__.py index 4f2d4cc7e99b..3b9a5d51d052 100644 --- a/tensorrt_llm/inputs/__init__.py +++ b/tensorrt_llm/inputs/__init__.py @@ -24,8 +24,7 @@ async_load_audio, async_load_image, async_load_video, convert_image_mode, default_multimodal_input_loader, encode_base64_content_from_url, encode_base64_image, - get_cache_salt_id, load_base64_image_embeds, load_image, - load_video) + load_base64_image_embeds, load_image, load_video) # yapf: enable @@ -69,7 +68,6 @@ "encode_base64_image", "load_image", "load_video", - "get_cache_salt_id", "compute_retained_tokens_count", "compute_retained_tokens_from_tubelet_budget", "compute_retention_mask", diff --git a/tensorrt_llm/inputs/utils.py b/tensorrt_llm/inputs/utils.py index 59483e625790..eaa98cad44ef 100644 --- a/tensorrt_llm/inputs/utils.py +++ b/tensorrt_llm/inputs/utils.py @@ -27,8 +27,7 @@ _safe_aiohttp_get, _safe_request_get) from tensorrt_llm.inputs.media_io import \ convert_image_mode as convert_image_mode -from tensorrt_llm.inputs.multimodal import (MultimodalServerConfig, - default_hasher) +from tensorrt_llm.inputs.multimodal import MultimodalServerConfig from tensorrt_llm.inputs.multimodal_data import \ BaseModalityData as BaseModalityData from tensorrt_llm.inputs.multimodal_data import VideoData as VideoData @@ -894,14 +893,3 @@ def convert_to_conversation_message( inputs.append(input) return inputs - - -def get_cache_salt_id(cache_salt: str) -> int: - b = cache_salt.encode("utf-8") - h = default_hasher(b).digest(length=8) - cache_salt_id = int.from_bytes(h, "little", signed=False) - if cache_salt_id < 0 or cache_salt_id >= (1 << 64): - raise ValueError( - f"cache_salt_id must be in [0, 2**64 - 1], got {cache_salt_id}.") - - return cache_salt_id diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index d8c471d8624e..f99aa1593c9f 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -36,7 +36,7 @@ from ..executor.utils import (RequestError, create_mpi_comm_session, get_spawn_proxy_process_env) from ..inputs import (PromptInputs, create_input_processor, - create_input_processor_with_hash, get_cache_salt_id, + create_input_processor_with_hash, maybe_compute_mm_embed_cumsum, prompt_inputs) from ..logger import logger from ..sampling_params import SamplingParams @@ -476,8 +476,6 @@ def generate_async( sampling_params = self._prepare_sampling_params(sampling_params) - cache_salt_id = get_cache_salt_id( - cache_salt) if cache_salt is not None else None # With pytorch backend, py_executor has logic to handle max_tokens of 1, # so set to 1 to avoid allocating unnecessary KV cache blocks for single request # TODO: Also support for trt backend @@ -520,7 +518,7 @@ def generate_async( postproc_params=_postproc_params, multimodal_params=multimodal_params, scheduling_params=scheduling_params, - cache_salt_id=cache_salt_id, + cache_salt=cache_salt, arrival_time=arrival_time, priority=priority, ) diff --git a/tests/unittest/llmapi/test_llm_kv_cache_events.py b/tests/unittest/llmapi/test_llm_kv_cache_events.py index 5831c5b7a57a..ee002905d26c 100644 --- a/tests/unittest/llmapi/test_llm_kv_cache_events.py +++ b/tests/unittest/llmapi/test_llm_kv_cache_events.py @@ -109,6 +109,9 @@ def test_kv_cache_event_data_serialization(): # Verify mm_keys field exists (empty for text-only requests) assert "mm_keys" in serialized_event[0]["data"]["blocks"][0] assert serialized_event[0]["data"]["blocks"][0]["mm_keys"] == [] + # Verify cache_salt field exists (None for unsalted requests) + assert "cache_salt" in serialized_event[0]["data"]["blocks"][0] + assert serialized_event[0]["data"]["blocks"][0]["cache_salt"] is None req2 = create_llm_request(1, [1, 2, 3, 4, 5]) kv_cache_manager.impl.add_sequence_batch( @@ -779,7 +782,7 @@ def test_mm_keys_in_stored_events(): events = llm.get_kv_cache_events(5) - # Find stored events and verify mm_keys field + # Find stored events and verify mm_keys and cache_salt fields for event in events: if event and event["data"]["type"] == "stored": blocks = event["data"]["blocks"] @@ -789,6 +792,86 @@ def test_mm_keys_in_stored_events(): assert isinstance(block["mm_keys"], list) # For text-only requests, mm_keys should be empty assert block["mm_keys"] == [] + # cache_salt should be present (None for unsalted requests) + assert "cache_salt" in block + assert block["cache_salt"] is None + + +def test_cache_salt_in_stored_events(): + """Test that cache_salt string is preserved in stored block events.""" + llm = create_llm() + sampling_params = SamplingParams(max_tokens=6, temperature=0.01) + prompt = "Hello, my name is" + + _ = llm.generate(prompt, + sampling_params=sampling_params, + cache_salt="tenant-A") + + events = llm.get_kv_cache_events(5) + + # Find stored events and verify cache_salt field + found_stored = False + for event in events: + if event and event["data"]["type"] == "stored": + found_stored = True + blocks = event["data"]["blocks"] + for block in blocks: + assert "cache_salt" in block + assert block["cache_salt"] == "tenant-A" + + assert found_stored, "No stored events found" + + +def test_cache_salt_max_length_validation(): + """cache_salt longer than MAX_CACHE_SALT_LEN UTF-8 bytes is rejected.""" + from tensorrt_llm.executor.request import GenerationRequest + + max_len = GenerationRequest.MAX_CACHE_SALT_LEN + sampling_params = SamplingParams() + + # ASCII salt at the limit is accepted. + GenerationRequest(prompt_token_ids=[1, 2, 3], + sampling_params=sampling_params, + cache_salt="a" * max_len) + + # ASCII salt one byte over the limit is rejected. + with pytest.raises(ValueError, match="cache_salt UTF-8 byte length"): + GenerationRequest(prompt_token_ids=[1, 2, 3], + sampling_params=sampling_params, + cache_salt="a" * (max_len + 1)) + + # Non-ASCII salt: each character is 3 UTF-8 bytes. A salt whose + # `len()` is below the limit but whose UTF-8 byte count exceeds it + # must be rejected (this is the case Python's len()-based check missed). + char_count = (max_len // 3) + 1 # len() is well below max_len + salt = "中" * char_count # Chinese character, 3 UTF-8 bytes each + assert len(salt) <= max_len + assert len(salt.encode("utf-8")) > max_len + with pytest.raises(ValueError, match="cache_salt UTF-8 byte length"): + GenerationRequest(prompt_token_ids=[1, 2, 3], + sampling_params=sampling_params, + cache_salt=salt) + + +def test_non_ascii_cache_salt_in_stored_events(): + """Test that a non-ASCII cache_salt string is preserved in stored block events.""" + llm = create_llm() + sampling_params = SamplingParams(max_tokens=6, temperature=0.01) + prompt = "Hello, my name is" + salt = "tenant-中文" # mixed ASCII + Chinese + + _ = llm.generate(prompt, sampling_params=sampling_params, cache_salt=salt) + + events = llm.get_kv_cache_events(5) + + found_stored = False + for event in events: + if event and event["data"]["type"] == "stored": + found_stored = True + for block in event["data"]["blocks"]: + assert block.get("cache_salt") == salt + + assert found_stored, "No stored events found" def test_expected_kv_cache_events(): From 309c76423b294c9bf6b7ec4a738e210dd013cc8f Mon Sep 17 00:00:00 2001 From: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com> Date: Wed, 10 Jun 2026 12:35:07 -0700 Subject: [PATCH 4/5] [https://nvbugs/6104831][fix] Port dataTransceiver shared_ptr lifetime fix (#14979) Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com> --- .../batch_manager/cacheTransceiver.cpp | 8 +- .../batch_manager/dataTransceiver.cpp | 90 +++++++++++++------ .../batch_manager/dataTransceiver.h | 12 +-- .../multi_gpu/cacheTransceiverTest.cpp | 17 ++-- 4 files changed, 82 insertions(+), 45 deletions(-) diff --git a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp index de146a7652de..ed37464bb2aa 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp @@ -394,7 +394,7 @@ void CacheTransceiver::respondAndSendAsync(std::shared_ptr llmReques return; } setContextState(llmRequest.get()); - auto future = mCacheSender->sendAsync(*llmRequest); + auto future = mCacheSender->sendAsync(llmRequest); mSenderFutures.emplace_back(std::move(llmRequest), std::move(future)); } @@ -410,7 +410,7 @@ void CacheTransceiver::respondAndSendLayerWise( llmRequest->setState(LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS); setContextState(llmRequest.get()); - auto future = mCacheSender->sendAsync(*llmRequest); + auto future = mCacheSender->sendAsync(llmRequest); mSenderFutures.emplace_back(llmRequest, std::move(future)); } } @@ -419,7 +419,7 @@ void CacheTransceiver::requestAndReceiveSync(std::shared_ptr llmRequ { TLLM_CHECK(llmRequest && llmRequest->isGenerationOnlyRequest()); { - auto future = mCacheReceiver->receiveAsync(*llmRequest); + auto future = mCacheReceiver->receiveAsync(llmRequest); future.get(); } llmRequest->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE); @@ -438,7 +438,7 @@ void CacheTransceiver::requestAndReceiveAsync(std::shared_ptr llmReq return; } - auto future = mCacheReceiver->receiveAsync(*llmRequest); + auto future = mCacheReceiver->receiveAsync(llmRequest); auto* requestPtr = llmRequest.get(); mRequesterFutures.emplace_back(std::move(llmRequest), std::move(future)); requestPtr->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS); diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp index 13e95dd86e48..853866687d53 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp @@ -296,16 +296,16 @@ class CacheSender::Impl } } - [[nodiscard]] std::future sendAsync(LlmRequest& llmRequest) + [[nodiscard]] std::future sendAsync(std::shared_ptr const& llmRequest) { + TLLM_CHECK(llmRequest != nullptr); std::promise promise; auto future = promise.get_future(); - llmRequest.setKvCacheTransferStart(LlmRequest::getSteadyClockNow()); + llmRequest->setKvCacheTransferStart(LlmRequest::getSteadyClockNow()); { { std::scoped_lock lkResp(mSenderMutex); - mReadyResponses.emplace( - llmRequest.mRequestId, Response{std::addressof(llmRequest), std::move(promise)}); + mReadyResponses.emplace(llmRequest->mRequestId, Response{llmRequest, std::move(promise)}); } std::unique_lock lkCond(mCondMutex); mAnyReady = true; @@ -477,7 +477,9 @@ class CacheSender::Impl private: struct Response { - LlmRequest* mRequest; + // shared_ptr so this struct co-owns the request until the promise resolves; + // protects worker-side dereferences and the promise itself from premature destruction. + std::shared_ptr mRequest; std::promise mPromise; }; @@ -511,7 +513,12 @@ class CacheSender::Impl resp = std::move(resource.mSendQueue.front()); resource.mSendQueue.pop_front(); } - sendAndRemoveResponse(resp.mRequest->mRequestId, std::move(resp)); + // Sequence the read before the move: argument initializations + // are indeterminately sequenced, so inlining resp.mRequest->... + // alongside std::move(resp) is UB once mRequest is a shared_ptr. + TLLM_CHECK(resp.mRequest != nullptr); + auto const reqId = resp.mRequest->mRequestId; + sendAndRemoveResponse(reqId, std::move(resp)); } } @@ -584,14 +591,23 @@ class CacheSender::Impl { // TODO: if the generation does not require the kv cache, the request will // not be removed from mCancelledRequests. This should be handled by timeout. - auto it = mReadyResponses.find(mCurrentRequest.value()); - TLLM_CHECK(it != mReadyResponses.end()); + auto const cancelledReqId = mCurrentRequest.value(); + Response cancelledResponse; { std::scoped_lock lkResp(mSenderMutex); + auto it = mReadyResponses.find(cancelledReqId); + TLLM_CHECK(it != mReadyResponses.end()); + // Move out before erasing so the promise survives the + // map cleanup and can be resolved (vs. destroyed unfulfilled, + // which would surface as std::future_error: Broken promise). + cancelledResponse = std::move(it->second); mReadyResponses.erase(it); - mCancelledRequests.erase(mCurrentRequest.value()); - mRemainSendCount.erase(mCurrentRequest.value()); + mCancelledRequests.erase(cancelledReqId); + mRemainSendCount.erase(cancelledReqId); } + cancelledResponse.mPromise.set_exception(std::make_exception_ptr( + TLLM_REQUEST_EXCEPTION(cancelledReqId, common::RequestErrorCode::kNETWORK_ERROR, + "KV cache transfer for request %zu was cancelled", cancelledReqId))); mCurrentRequest = std::nullopt; if (mReadyResponses.empty()) @@ -762,23 +778,27 @@ class CacheReceiver::Impl TLLM_CUDA_CHECK(cudaGetDevice(&mDeviceId)); } - [[nodiscard]] std::future receiveAsync(LlmRequest& llmRequest) + [[nodiscard]] std::future receiveAsync(std::shared_ptr const& llmRequest) { + TLLM_CHECK(llmRequest != nullptr); // TODO: Modify the implementation here to avoid frequent thread creation. - return std::async(std::launch::async, &CacheReceiver::Impl::requestSync, this, std::ref(llmRequest)); + // Capture by value so the async task owns a strong reference for its lifetime. + auto llmRequestCopy = llmRequest; + return std::async(std::launch::async, [this, llmRequestCopy]() { requestSync(*llmRequestCopy); }); } - [[nodiscard]] std::future requestAndReceiveAsyncMultiThreads(LlmRequest& llmRequest) + [[nodiscard]] std::future requestAndReceiveAsyncMultiThreads(std::shared_ptr const& llmRequest) { + TLLM_CHECK(llmRequest != nullptr); try { auto promise = std::make_unique>(); auto future = promise->get_future(); - TLLM_CHECK(llmRequest.getDataTransceiverState().getCommState().has_value()); + TLLM_CHECK(llmRequest->getDataTransceiverState().getCommState().has_value()); std::string processInfo = kDefaultProcessInfo; if (common::getEnvRequestKVCacheConcurrent()) { - processInfo = llmRequest.getDataTransceiverState().getCommState()->toString(); + processInfo = llmRequest->getDataTransceiverState().getCommState()->toString(); } if (mInstanceToAsyncResource.find(processInfo) == mInstanceToAsyncResource.end()) { @@ -791,7 +811,7 @@ class CacheReceiver::Impl auto& asyncResource = mInstanceToAsyncResource.at(processInfo); { std::unique_lock lck(asyncResource->mMtxForQueue); - asyncResource->mRequestsQueue.emplace_back(std::addressof(llmRequest), std::move(promise)); + asyncResource->mRequestsQueue.emplace_back(llmRequest, std::move(promise)); } asyncResource->mCVforQueue.notify_all(); return future; @@ -1003,6 +1023,23 @@ class CacheReceiver::Impl { return requestAndPromise.mRequest->mRequestId == llmRequest.mRequestId; }); if (it != asyncResource->mRequestsQueue.end()) { + // Resolve the promise before erasing so the future returned by + // receiveAsync surfaces a structured cancellation error rather + // than std::future_error: Broken promise from the destroyed promise. + if (it->mPromise) + { + try + { + it->mPromise->set_exception(std::make_exception_ptr( + TLLM_REQUEST_EXCEPTION(llmRequest.mRequestId, common::RequestErrorCode::kNETWORK_ERROR, + "Generation KV cache request cancelled before send for request %zu", + llmRequest.mRequestId))); + } + catch (std::future_error const&) + { + // Promise already satisfied; nothing to do. + } + } asyncResource->mRequestsQueue.erase(it); isCancelled = true; } @@ -1083,7 +1120,9 @@ class CacheReceiver::Impl struct RequestAndPromise { - LlmRequest* mRequest; + // shared_ptr so this struct co-owns the request until the promise resolves; + // protects worker-side dereferences and the promise itself from premature destruction. + std::shared_ptr mRequest; std::unique_ptr> mPromise; RequestAndPromise() @@ -1092,8 +1131,8 @@ class CacheReceiver::Impl { } - RequestAndPromise(LlmRequest* request, std::unique_ptr>&& promise) - : mRequest(request) + RequestAndPromise(std::shared_ptr request, std::unique_ptr>&& promise) + : mRequest(std::move(request)) , mPromise(std::move(promise)) { } @@ -1101,26 +1140,23 @@ class CacheReceiver::Impl RequestAndPromise(RequestAndPromise const&) = delete; RequestAndPromise(RequestAndPromise&& other) noexcept - : mRequest(other.mRequest) + : mRequest(std::move(other.mRequest)) , mPromise(std::move(other.mPromise)) { - other.mRequest = nullptr; } RequestAndPromise& operator=(RequestAndPromise&& other) noexcept { if (this != &other) { - mRequest = nullptr; + mRequest.reset(); if (mPromise) { mPromise.reset(); } - mRequest = other.mRequest; + mRequest = std::move(other.mRequest); mPromise = std::move(other.mPromise); - - other.mRequest = nullptr; } return *this; } @@ -1228,7 +1264,7 @@ CacheSender::CacheSender( { } -std::future CacheSender::sendAsync(LlmRequest& llmRequest) const +std::future CacheSender::sendAsync(std::shared_ptr const& llmRequest) const { return mImpl->sendAsync(llmRequest); } @@ -1277,7 +1313,7 @@ CacheReceiver::CacheReceiver( { } -std::future CacheReceiver::receiveAsync(LlmRequest& llmRequest) const +std::future CacheReceiver::receiveAsync(std::shared_ptr const& llmRequest) const { return mImpl->requestAndReceiveAsyncMultiThreads(llmRequest); } diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.h b/cpp/tensorrt_llm/batch_manager/dataTransceiver.h index 14d0ba8c6e5d..3362574da902 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.h +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.h @@ -256,10 +256,10 @@ class CacheSender CacheSender() = default; /// @brief Asynchronously respond to the request and send data. - /// @param llmRequest Request object. Its data should be ready when called, and the data for this request - /// should remain valid until future synchronization. + /// @param llmRequest Request object. Its data should be ready when called. shared_ptr so the async send + /// worker can extend the request's lifetime past the caller's reference. /// @return Once the data is fully sent, the future object will become valid. - [[nodiscard]] virtual std::future sendAsync(LlmRequest& llmRequest) const; + [[nodiscard]] virtual std::future sendAsync(std::shared_ptr const& llmRequest) const; /// @brief Return the internal communicator status. /// @return The communicator status. @@ -319,10 +319,10 @@ class CacheReceiver CacheReceiver() = default; /// @brief Asynchronously send a request to receive data. - /// @param llmRequest Request object. Its data should be in an allocated but unwritten state when called, and the - /// data for this request should remain intact only after future synchronization. + /// @param llmRequest Request object. Its data should be in an allocated but unwritten state when called. + /// shared_ptr so the async receive worker can extend the request's lifetime past the caller's reference. /// @return Once the data is fully received, the future object will become valid. - [[nodiscard]] virtual std::future receiveAsync(LlmRequest& llmRequest) const; + [[nodiscard]] virtual std::future receiveAsync(std::shared_ptr const& llmRequest) const; virtual TransferSession sendRequestInfo(LlmRequest const& llmRequest); diff --git a/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp b/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp index 4514dfe26030..d1ca104cca1a 100644 --- a/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp +++ b/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp @@ -348,11 +348,11 @@ class SymmetricalCacheTest : public ::testing::Test // NOLINT(cppcoreguidelines- TLLM_CUDA_CHECK(cudaMemset(it->data(), llmRequest->getPromptLen(), it->getSizeInBytes())); } } - mFutures.emplace_back(mSender->sendAsync(*llmRequest)); + mFutures.emplace_back(mSender->sendAsync(llmRequest)); } else { - auto future = mRequester->receiveAsync(*llmRequest); + auto future = mRequester->receiveAsync(llmRequest); future.get(); TLLM_CUDA_CHECK(cudaDeviceSynchronize()); auto blockRange = BlockRange::fromAllBlockIds(*mManager, llmRequest->mRequestId); @@ -468,12 +468,13 @@ struct CPMetaData struct WrappedLlmRequest { - std::unique_ptr mLlmRequest; + // shared_ptr to match CacheSender::sendAsync / CacheReceiver::receiveAsync signatures. + std::shared_ptr mLlmRequest; std::optional mCPMetaData; using RequestIdType = LlmRequest::RequestIdType; - WrappedLlmRequest(std::unique_ptr llmRequest, std::optional cpMetaData) + WrappedLlmRequest(std::shared_ptr llmRequest, std::optional cpMetaData) : mLlmRequest(std::move(llmRequest)) , mCPMetaData(std::move(cpMetaData)) { @@ -887,7 +888,7 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam(mRequestId++, std::move(request)); + auto llmRequestPtr = std::make_shared(mRequestId++, std::move(request)); return std::make_unique(std::move(llmRequestPtr), cpMetaData); } @@ -919,7 +920,7 @@ class AsymmetricalCacheTest : public ::testing::TestWithParamsetCacheState(cacheState); auto stats = texec::ContextPhaseParams({}, requestId, state.release(), std::nullopt); request.setContextPhaseParams(std::move(stats)); - auto llmRequestPtr = std::make_unique(requestId, std::move(request)); + auto llmRequestPtr = std::make_shared(requestId, std::move(request)); return std::make_unique(std::move(llmRequestPtr), cpMetaData); } @@ -973,7 +974,7 @@ class AsymmetricalCacheTest : public ::testing::TestWithParamsendAsync(*llmRequest); + auto future = mSender->sendAsync(llmRequest); return future; } @@ -984,7 +985,7 @@ class AsymmetricalCacheTest : public ::testing::TestWithParammLlmRequest; mManager->addSequenceBatch( {{{llmRequest->mRequestId, llmRequest->getNumTokens(beamIdx), beamWidth}}}, {std::ref(*llmRequest)}); - return mRequester->receiveAsync(*llmRequest); + return mRequester->receiveAsync(llmRequest); } void generationVerifyKVCache(std::shared_ptr const& request) From 7301075055e9b943cbdecb45446d208737b93d78 Mon Sep 17 00:00:00 2001 From: Bala Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com> Date: Wed, 10 Jun 2026 12:49:06 -0700 Subject: [PATCH 5/5] [None][fix] Fix AutoDeploy transform docs generation (#15228) Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com> --- docs/source/_ext/trtllm_auto_deploy.py | 252 ++++++++++++++++++++----- 1 file changed, 202 insertions(+), 50 deletions(-) diff --git a/docs/source/_ext/trtllm_auto_deploy.py b/docs/source/_ext/trtllm_auto_deploy.py index 44943432b842..e38f2f39d535 100644 --- a/docs/source/_ext/trtllm_auto_deploy.py +++ b/docs/source/_ext/trtllm_auto_deploy.py @@ -4,19 +4,23 @@ from __future__ import annotations import ast -import pkgutil from dataclasses import dataclass from pathlib import Path import yaml from docutils import nodes from docutils.statemachine import StringList +from docutils.utils import SystemMessage from sphinx.application import Sphinx +from sphinx.errors import SphinxError +from sphinx.util import logging from sphinx.util.docutils import SphinxDirective from sphinx.util.nodes import nested_parse_with_titles -AUTO_DEPLOY_TRANSFORM_LIBRARY_PACKAGE = "tensorrt_llm._torch.auto_deploy.transform.library" -AUTO_DEPLOY_TRANSFORM_LIBRARY_PATH = Path("tensorrt_llm/_torch/auto_deploy/transform/library") +LOGGER = logging.getLogger(__name__) + +AUTO_DEPLOY_TRANSFORM_PACKAGE = "tensorrt_llm._torch.auto_deploy.transform" +AUTO_DEPLOY_TRANSFORM_PATH = Path("tensorrt_llm/_torch/auto_deploy/transform") AUTO_DEPLOY_TRANSFORM_CONFIGS = ( ("graph", Path("tensorrt_llm/_torch/auto_deploy/config/default.yaml")), ( @@ -67,31 +71,31 @@ @dataclass(frozen=True) class RegisteredTransform: key: str + package_name: str module_name: str class_name: str config_class_name: str + config_package_name: str | None config_module_name: str | None @property def qualified_class_name(self) -> str: - return f"{AUTO_DEPLOY_TRANSFORM_LIBRARY_PACKAGE}.{self.module_name}.{self.class_name}" + return f"{self.package_name}.{self.module_name}.{self.class_name}" @property def qualified_module_name(self) -> str: - return f"{AUTO_DEPLOY_TRANSFORM_LIBRARY_PACKAGE}.{self.module_name}" + return f"{self.package_name}.{self.module_name}" @property def qualified_config_class_name(self) -> str | None: if self.config_module_name is None: return None - return ( - f"{AUTO_DEPLOY_TRANSFORM_LIBRARY_PACKAGE}.{self.config_module_name}" - f".{self.config_class_name}" - ) + return f"{self.config_package_name}.{self.config_module_name}.{self.config_class_name}" @dataclass(frozen=True) class ParsedClass: + package_name: str module_name: str class_name: str base_class_names: tuple[str, ...] @@ -110,22 +114,43 @@ def _repo_root_from_source_dir(source_dir: str) -> Path: """Return the nearest ancestor that contains the AutoDeploy transform library.""" source_path = Path(source_dir).resolve() for path in (source_path, *source_path.parents): - if (path / AUTO_DEPLOY_TRANSFORM_LIBRARY_PATH).is_dir(): + if (path / AUTO_DEPLOY_TRANSFORM_PATH).is_dir(): return path raise FileNotFoundError( - f"Could not find repository root containing {AUTO_DEPLOY_TRANSFORM_LIBRARY_PATH}" + f"Could not find repository root containing {AUTO_DEPLOY_TRANSFORM_PATH}" + ) + + +def _transform_sources(repo_root: Path) -> tuple[tuple[str, Path], ...]: + transform_path = repo_root / AUTO_DEPLOY_TRANSFORM_PATH + if not transform_path.is_dir(): + LOGGER.warning("AutoDeploy transform root not found: %s", transform_path) + return () + + sources = [(AUTO_DEPLOY_TRANSFORM_PACKAGE, transform_path)] + sources.extend( + ( + f"{AUTO_DEPLOY_TRANSFORM_PACKAGE}.{path.name}", + path, + ) + for path in sorted( + transform_path.iterdir(), key=lambda path: (path.name != "library", path.name) + ) + if path.is_dir() and not path.name.startswith("_") and (path / "__init__.py").is_file() ) + return tuple(sources) def _discover_transform_modules(library_path: Path) -> list[str]: """Discover public AutoDeploy transform modules without importing them.""" if not library_path.is_dir(): - raise FileNotFoundError(f"AutoDeploy transform library not found: {library_path}") + LOGGER.warning("AutoDeploy transform source not found: %s", library_path) + return [] return sorted( - module_info.name - for module_info in pkgutil.iter_modules([str(library_path)]) - if not module_info.name.startswith("_") + module_path.stem + for module_path in library_path.glob("*.py") + if module_path.name != "__init__.py" and not module_path.stem.startswith("_") ) @@ -179,6 +204,7 @@ def _get_config_class_name(node: ast.ClassDef) -> str | None: def _parse_transform_classes( + package_name: str, library_path: Path, ) -> tuple[list[ParsedClass], dict[str, list[ParsedClass]]]: parsed_classes: list[ParsedClass] = [] @@ -186,13 +212,22 @@ def _parse_transform_classes( for module_name in _discover_transform_modules(library_path): module_path = library_path / f"{module_name}.py" - tree = ast.parse(module_path.read_text(encoding="utf-8")) + try: + tree = ast.parse(module_path.read_text(encoding="utf-8"), filename=str(module_path)) + except (OSError, SyntaxError, UnicodeDecodeError) as error: + LOGGER.warning( + "Skipping AutoDeploy transform module %s while generating docs: %s", + module_path, + error, + ) + continue for node in tree.body: if not isinstance(node, ast.ClassDef): continue parsed_class = ParsedClass( + package_name=package_name, module_name=module_name, class_name=node.name, base_class_names=tuple( @@ -215,6 +250,7 @@ def _parse_transform_classes( def _get_library_class( class_name: str, + package_name: str, module_name: str, classes_by_name: dict[str, list[ParsedClass]], ) -> ParsedClass | None: @@ -222,7 +258,7 @@ def _get_library_class( if len(classes) == 1: return classes[0] for parsed_class in classes: - if parsed_class.module_name == module_name: + if parsed_class.package_name == package_name and parsed_class.module_name == module_name: return parsed_class return None @@ -238,6 +274,7 @@ def _resolve_config_class( return None return _get_library_class( parsed_class.config_class_name, + parsed_class.package_name, parsed_class.module_name, classes_by_name, ) @@ -255,26 +292,42 @@ def _resolve_config_class( return None -def _discover_registered_transforms(library_path: Path) -> dict[str, RegisteredTransform]: +def _discover_registered_transforms(repo_root: Path) -> dict[str, RegisteredTransform]: """Discover registered transform classes without importing transform modules.""" registered_transforms: dict[str, RegisteredTransform] = {} - parsed_classes, classes_by_name = _parse_transform_classes(library_path) + parsed_classes: list[ParsedClass] = [] + classes_by_name: dict[str, list[ParsedClass]] = {} + + for package_name, library_path in _transform_sources(repo_root): + source_classes, source_classes_by_name = _parse_transform_classes( + package_name, library_path + ) + parsed_classes.extend(source_classes) + for class_name, class_entries in source_classes_by_name.items(): + classes_by_name.setdefault(class_name, []).extend(class_entries) for parsed_class in parsed_classes: config_class = _resolve_config_class(parsed_class, classes_by_name) for transform_key in parsed_class.transform_keys: if transform_key in registered_transforms: previous = registered_transforms[transform_key] - raise ValueError( - f"Transform {transform_key!r} is registered by both " - f"{previous.qualified_class_name} and " - f"{parsed_class.module_name}.{parsed_class.class_name}" + LOGGER.warning( + "Transform %r is registered by both %s and %s.%s.%s; using %s.", + transform_key, + previous.qualified_class_name, + parsed_class.package_name, + parsed_class.module_name, + parsed_class.class_name, + previous.qualified_class_name, ) + continue registered_transforms[transform_key] = RegisteredTransform( key=transform_key, + package_name=parsed_class.package_name, module_name=parsed_class.module_name, class_name=parsed_class.class_name, config_class_name=config_class.class_name if config_class else "TransformConfig", + config_package_name=config_class.package_name if config_class else None, config_module_name=config_class.module_name if config_class else None, ) @@ -287,23 +340,67 @@ def _load_configured_transforms(repo_root: Path) -> list[ConfiguredTransform]: configured_transforms: list[ConfiguredTransform] = [] for mode, config_path in AUTO_DEPLOY_TRANSFORM_CONFIGS: - config = yaml.safe_load((repo_root / config_path).read_text(encoding="utf-8")) + try: + config = yaml.safe_load((repo_root / config_path).read_text(encoding="utf-8")) + except (OSError, yaml.YAMLError) as error: + LOGGER.warning( + "Skipping AutoDeploy config %s while generating docs: %s", config_path, error + ) + continue + + if config is None: + continue + if not isinstance(config, dict): + LOGGER.warning("Skipping AutoDeploy config %s: expected a mapping.", config_path) + continue + transforms = config.get("transforms", {}) + if transforms is None: + continue + if not isinstance(transforms, dict): + LOGGER.warning( + "Skipping AutoDeploy config %s transforms: expected a mapping.", + config_path, + ) + continue for transform_key, transform_config in transforms.items(): + if not isinstance(transform_key, str): + LOGGER.warning( + "Skipping AutoDeploy transform entry %r in %s: expected a string key.", + transform_key, + config_path, + ) + continue + if not isinstance(transform_config, dict): + LOGGER.warning( + "Skipping AutoDeploy transform %r in %s: expected a mapping.", + transform_key, + config_path, + ) + continue + stage = transform_config.get("stage") - if not stage: - raise ValueError( - f"Transform {transform_key!r} in {config_path} does not define a stage" + if not isinstance(stage, str) or not stage: + LOGGER.warning( + "Skipping AutoDeploy transform %r in %s: missing string stage.", + transform_key, + config_path, ) + continue configured_transform = configured_by_key.get(transform_key) if configured_transform is not None: if configured_transform.stage != stage: - raise ValueError( - f"Transform {transform_key!r} has stages " - f"{configured_transform.stage!r} and {stage!r}" + LOGGER.warning( + "Skipping AutoDeploy transform %r in %s: stage %r conflicts with " + "previous stage %r.", + transform_key, + config_path, + stage, + configured_transform.stage, ) + continue configured_transform.modes.append(mode) continue @@ -366,14 +463,33 @@ def _transform_section( def _note_auto_deploy_dependencies(directive: SphinxDirective, repo_root: Path) -> None: - library_path = repo_root / AUTO_DEPLOY_TRANSFORM_LIBRARY_PATH - directive.env.note_dependency(str(library_path)) - for path in sorted(library_path.glob("*.py")): - directive.env.note_dependency(str(path)) + for _, library_path in _transform_sources(repo_root): + directive.env.note_dependency(str(library_path)) + for path in sorted(library_path.glob("*.py")): + directive.env.note_dependency(str(path)) for _, config_path in AUTO_DEPLOY_TRANSFORM_CONFIGS: directive.env.note_dependency(str(repo_root / config_path)) +def _unavailable_nodes(message: str) -> list[nodes.Node]: + return [nodes.paragraph(text=message)] + + +def _parse_generated_lines( + directive: SphinxDirective, + generated_lines: StringList, +) -> list[nodes.Node]: + container = nodes.container() + try: + nested_parse_with_titles(directive.state, generated_lines, container) + except (SystemMessage, SphinxError) as error: + LOGGER.warning("Skipping generated AutoDeploy transform docs: %s", error) + return _unavailable_nodes( + "AutoDeploy transform documentation is unavailable in this build." + ) + return container.children + + class AutoDeployTransformStageDirective(SphinxDirective): """Render autodoc sections for configured transforms in one pipeline stage.""" @@ -382,11 +498,18 @@ class AutoDeployTransformStageDirective(SphinxDirective): def run(self) -> list[nodes.Node]: stage = self.arguments[0] - repo_root = _repo_root_from_source_dir(self.env.app.srcdir) - library_path = repo_root / AUTO_DEPLOY_TRANSFORM_LIBRARY_PATH + try: + repo_root = _repo_root_from_source_dir(self.env.app.srcdir) + except FileNotFoundError as error: + LOGGER.warning("Skipping AutoDeploy transform docs: %s", error) + return _unavailable_nodes( + "AutoDeploy transform documentation is unavailable in this build." + ) + + transform_path = repo_root / AUTO_DEPLOY_TRANSFORM_PATH _note_auto_deploy_dependencies(self, repo_root) - registered_transforms = _discover_registered_transforms(library_path) + registered_transforms = _discover_registered_transforms(repo_root) configured_transforms = [ transform for transform in _load_configured_transforms(repo_root) @@ -402,22 +525,39 @@ def run(self) -> list[nodes.Node]: ] generated_lines = StringList() + rendered_count = 0 + missing_transforms: list[str] = [] for configured_transform in configured_transforms: registered_transform = registered_transforms.get(configured_transform.key) if registered_transform is None: - raise ValueError( - f"Configured transform {configured_transform.key!r} is not registered" + missing_transforms.append(configured_transform.key) + LOGGER.warning( + "Configured AutoDeploy transform %r is not discoverable by the docs " + "extension; skipping autodoc for it.", + configured_transform.key, ) + continue for line in _transform_section( configured_transform.key, registered_transform, configured_transform.modes, ): - generated_lines.append(line, source=str(library_path)) + generated_lines.append(line, source=str(transform_path)) + rendered_count += 1 + + if rendered_count == 0: + title = STAGE_TITLES.get(stage, stage) + if missing_transforms: + return [ + nodes.paragraph( + text=( + f"No discoverable AutoDeploy transforms are documented for the " + f"{title} stage." + ) + ) + ] - container = nodes.container() - nested_parse_with_titles(self.state, generated_lines, container) - return container.children + return _parse_generated_lines(self, generated_lines) class AutoDeployAdditionalTransformsDirective(SphinxDirective): @@ -426,11 +566,25 @@ class AutoDeployAdditionalTransformsDirective(SphinxDirective): has_content = False def run(self) -> list[nodes.Node]: - repo_root = _repo_root_from_source_dir(self.env.app.srcdir) - library_path = repo_root / AUTO_DEPLOY_TRANSFORM_LIBRARY_PATH + try: + repo_root = _repo_root_from_source_dir(self.env.app.srcdir) + except FileNotFoundError as error: + LOGGER.warning("Skipping AutoDeploy additional transform docs: %s", error) + return _unavailable_nodes( + "AutoDeploy transform documentation is unavailable in this build." + ) + + transform_path = repo_root / AUTO_DEPLOY_TRANSFORM_PATH _note_auto_deploy_dependencies(self, repo_root) - registered_transforms = _discover_registered_transforms(library_path) + registered_transforms = _discover_registered_transforms(repo_root) + if not registered_transforms: + return [ + nodes.paragraph( + text="No discoverable AutoDeploy transforms are documented in this build." + ) + ] + configured_keys = {transform.key for transform in _load_configured_transforms(repo_root)} additional_transforms = [ registered_transform @@ -451,11 +605,9 @@ def run(self) -> list[nodes.Node]: registered_transform.key, registered_transform, ): - generated_lines.append(line, source=str(library_path)) + generated_lines.append(line, source=str(transform_path)) - container = nodes.container() - nested_parse_with_titles(self.state, generated_lines, container) - return container.children + return _parse_generated_lines(self, generated_lines) def setup(app: Sphinx) -> dict[str, bool | str]: