From a56c558ce03b2c11d6098e71340a4eec5d02519f Mon Sep 17 00:00:00 2001 From: MohanKumar21! Date: Wed, 13 May 2026 17:00:08 +0530 Subject: [PATCH] perf(e2e): pool local workers across classes and cluster tests by config Signed-off-by: MohanKumar21! --- e2e_test/conftest.py | 10 +- e2e_test/fixtures/__init__.py | 2 + e2e_test/fixtures/hooks.py | 144 ++++++++++++++++---- e2e_test/fixtures/setup_backend.py | 211 +++++++++++++++++++++++++++-- 4 files changed, 329 insertions(+), 38 deletions(-) diff --git a/e2e_test/conftest.py b/e2e_test/conftest.py index 4a9d33fa8..6faf56a10 100644 --- a/e2e_test/conftest.py +++ b/e2e_test/conftest.py @@ -19,8 +19,12 @@ Fixtures -------- -setup_backend: Class-scoped fixture that launches workers + gateway per test class. - Returns (backend_name, model_path, client, gateway). +setup_backend: Class-scoped fixture that launches (or reuses) workers and a + fresh gateway per test class. Workers for non-PD local backends are + pooled across classes that share ``(model_id, engine, mode, count)`` so + cold-start cost is paid once per distinct config per session. Set + ``E2E_DISABLE_WORKER_POOL=1`` to fall back to per-class worker + start/stop. Returns ``(backend_name, model_path, client, gateway)``. model: Convenience fixture that returns just the model_path from setup_backend. """ @@ -108,6 +112,7 @@ def pytest_runtest_logstart(nodeid: str, location: tuple) -> None: pytest_collection_modifyitems, pytest_configure, pytest_runtest_setup, + pytest_sessionfinish, setup_backend, ) from smg_client import SmgClient @@ -151,6 +156,7 @@ def api_client(request, setup_backend): "pytest_runtest_setup", "pytest_collection_modifyitems", "pytest_configure", + "pytest_sessionfinish", # Fixtures "setup_backend", "backend_router", diff --git a/e2e_test/fixtures/__init__.py b/e2e_test/fixtures/__init__.py index 291915a2a..3cfc8ae10 100644 --- a/e2e_test/fixtures/__init__.py +++ b/e2e_test/fixtures/__init__.py @@ -12,6 +12,7 @@ pytest_collection_modifyitems, pytest_configure, pytest_runtest_setup, + pytest_sessionfinish, ) # Marker helpers @@ -25,6 +26,7 @@ "pytest_collection_modifyitems", "pytest_configure", "pytest_runtest_setup", + "pytest_sessionfinish", # Backend fixtures "setup_backend", "backend_router", diff --git a/e2e_test/fixtures/hooks.py b/e2e_test/fixtures/hooks.py index ca15e0027..36d686591 100644 --- a/e2e_test/fixtures/hooks.py +++ b/e2e_test/fixtures/hooks.py @@ -3,16 +3,23 @@ This module handles: - Marker registration: Defining custom pytest markers - Test filtering: Env-var-based filtering by engine, vendor, and GPU tier +- Test ordering: Cluster items by backend config so the session-scoped + worker pool in ``setup_backend`` can amortize cold starts across classes. +- Session cleanup: Stop any pooled workers at session end. """ from __future__ import annotations +import logging import os import pytest from infra import get_runtime from .markers import resolve_class_marker +from .setup_backend import shutdown_worker_pool + +logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Marker registration @@ -106,36 +113,127 @@ def _get_marker(item: pytest.Item, name: str): return resolve_class_marker(item, name) +def _parametrize_argnames_to_set(argnames: object) -> set[str]: + if isinstance(argnames, str): + return {n.strip() for n in argnames.split(",") if n.strip()} + if isinstance(argnames, (tuple, list)): + return {str(n) for n in argnames} + return set() + + +def _class_level_backend_sort_token(item: pytest.Item) -> str | None: + """Stable token from class ``pytestmark`` ``parametrize`` for backend fixtures. + + When ``setup_backend`` / ``backend_router`` are parametrized on the class, + using only ``callspec.params`` splits items by concrete value and scatters + them in the global sort. Aggregating class-level ``parametrize`` keeps all + variants of that class adjacent (stable sort preserves intra-class order). + """ + cls = getattr(item, "cls", None) + if cls is None: + return None + tokens: list[str] = [] + for base in cls.__mro__: + if base is object: + continue + marks = getattr(base, "pytestmark", None) + if marks is None: + continue + if not isinstance(marks, list): + marks = [marks] + for mark in marks: + if getattr(mark, "name", None) != "parametrize": + continue + mark_args = tuple(getattr(mark, "args", ()) or ()) + if len(mark_args) < 2: + continue + argnames_obj = mark_args[0] + argvalues_obj = mark_args[1] + names = _parametrize_argnames_to_set(argnames_obj) + if not (names & {"setup_backend", "backend_router"}): + continue + tokens.append(repr((argnames_obj, argvalues_obj))) + if not tokens: + return None + return "|".join(sorted(set(tokens))) + + +def _backend_sort_key(item: pytest.Item) -> tuple: + """Stable ordering key used to cluster test classes by backend config. + + Items that share ``(model_id, workers_config, backend_bucket)`` end up + adjacent so the session-scoped worker pool in ``setup_backend`` can + reuse warm workers between consecutive same-config classes. The backend + bucket prefers an aggregate token from class-level ``parametrize`` marks + for ``setup_backend`` / ``backend_router`` so mixed parametrizations do + not scatter tests; otherwise falls back to ``callspec.params``. Items + without these markers/params fall into a single neutral bucket and + keep their original collection order via Python's stable sort. + """ + model = resolve_class_marker(item, "model") + workers = resolve_class_marker(item, "workers") + + model_id = model.args[0] if model and model.args else "" + if workers is not None: + wcount = workers.kwargs.get("count") or 0 + wprefill = workers.kwargs.get("prefill") or 0 + wdecode = workers.kwargs.get("decode") or 0 + else: + wcount = wprefill = wdecode = 0 + + backend = _class_level_backend_sort_token(item) + if backend is None: + callspec = getattr(item, "callspec", None) + if callspec is not None: + params = getattr(callspec, "params", {}) + backend = str(params.get("setup_backend") or params.get("backend_router") or "") + else: + backend = "" + + return (model_id, wcount, wprefill, wdecode, backend) + + def pytest_collection_modifyitems( config: pytest.Config, items: list[pytest.Item], ) -> None: - """Filter collected tests based on E2E_ENGINE, E2E_VENDOR, and E2E_GPU_TIER env vars.""" + """Filter collected tests based on env vars, then cluster by backend config.""" engine = os.environ.get("E2E_ENGINE") or None vendor = os.environ.get("E2E_VENDOR") or None gpu_tier = os.environ.get("E2E_GPU_TIER") or None - if not any([engine, vendor, gpu_tier]): - return + if any([engine, vendor, gpu_tier]): + selected: list[pytest.Item] = [] + for item in items: + if engine: + engine_marker = _get_marker(item, "engine") + if not engine_marker or engine not in engine_marker.args: + continue + if vendor: + vendor_marker = _get_marker(item, "vendor") + if not vendor_marker or vendor not in vendor_marker.args: + continue + if gpu_tier is not None: + gpu_marker = _get_marker(item, "gpu") + gpu_count = gpu_marker.args[0] if gpu_marker else 1 + if str(gpu_count) != gpu_tier: + continue + selected.append(item) + + items[:] = selected + + # Stable sort: equal keys preserve the original (post-filter) order, so + # tests within a single class stay together and only the class-level + # grouping changes. + if os.environ.get("E2E_DISABLE_TEST_SORT", "").lower() not in ("1", "true", "yes"): + items.sort(key=_backend_sort_key) + + +# --------------------------------------------------------------------------- +# Session lifecycle +# --------------------------------------------------------------------------- - selected: list[pytest.Item] = [] - for item in items: - # Filter by engine - if engine: - engine_marker = _get_marker(item, "engine") - if not engine_marker or engine not in engine_marker.args: - continue - # Filter by vendor - if vendor: - vendor_marker = _get_marker(item, "vendor") - if not vendor_marker or vendor not in vendor_marker.args: - continue - # Filter by GPU tier - if gpu_tier is not None: - gpu_marker = _get_marker(item, "gpu") - gpu_count = gpu_marker.args[0] if gpu_marker else 1 - if str(gpu_count) != gpu_tier: - continue - selected.append(item) - items[:] = selected +def pytest_sessionfinish(session: pytest.Session, exitstatus: int) -> None: + """Stop any workers held by the ``setup_backend`` pool at session end.""" + shutdown_worker_pool() diff --git a/e2e_test/fixtures/setup_backend.py b/e2e_test/fixtures/setup_backend.py index 54a7d71d7..435638c31 100644 --- a/e2e_test/fixtures/setup_backend.py +++ b/e2e_test/fixtures/setup_backend.py @@ -1,14 +1,29 @@ """Backend setup fixtures for E2E tests. -Simplified backend lifecycle -- one set of workers and gateway per test class. -No model pool, no thread-local caching. Direct worker management via -start_workers/stop_workers. +Backend lifecycle: one gateway per test class, but the underlying workers +are pooled across classes that share ``(model_id, engine, mode, count)``. + +Workers are the expensive part (cold start is 1-3 minutes); gateways start +in seconds. Decoupling worker lifetime from class scope lets consecutive +same-config classes reuse warm workers and just spin a fresh gateway. + +Pool behaviour: +- Keyed on ``(model_id, engine, mode, count)``; PD and cloud backends are + never pooled. +- LRU with a single-entry default (``E2E_WORKER_POOL_SIZE``) so GPU memory + stays bounded. +- Flushed at session end via ``shutdown_worker_pool``; on eviction the + evicted workers are stopped immediately. +- Opt out with ``E2E_DISABLE_WORKER_POOL=1`` (falls back to per-class + start/stop semantics). """ from __future__ import annotations import logging import os +from collections import OrderedDict +from typing import NamedTuple import anthropic import openai @@ -58,6 +73,103 @@ def _start_workers_tracked(**kwargs) -> list: raise +def _require_exact_worker_count(*, role: str, requested: int, workers: list) -> None: + """Fail the run if we did not acquire exactly the requested worker count.""" + got = len(workers) if workers is not None else 0 + if got != requested: + pytest.fail( + f"E2E worker acquisition failed: expected exactly {requested} {role} " + f"worker(s), obtained {got}" + ) + + +# --------------------------------------------------------------------------- +# Session-scoped worker pool +# --------------------------------------------------------------------------- + + +class _WorkerKey(NamedTuple): + model_id: str + engine: str + mode: str # "http" | "grpc" + worker_count: int + + +_worker_pool: OrderedDict[_WorkerKey, list] = OrderedDict() + +_raw_pool_size = os.environ.get("E2E_WORKER_POOL_SIZE", "1") or "1" +try: + _parsed_pool_size = int(_raw_pool_size) +except (TypeError, ValueError): + logger.warning("Invalid E2E_WORKER_POOL_SIZE=%r, using 1", _raw_pool_size) + _parsed_pool_size = 1 +_POOL_MAX = max(1, _parsed_pool_size) + +_POOL_DISABLED = os.environ.get("E2E_DISABLE_WORKER_POOL", "").lower() in ("1", "true", "yes") + + +def _pool_enabled() -> bool: + return not _POOL_DISABLED + + +def _pool_get(key: _WorkerKey) -> list | None: + """Return cached workers for ``key`` and mark them most-recently-used.""" + workers = _worker_pool.get(key) + if workers is not None: + _worker_pool.move_to_end(key) + return workers + + +def _pool_put(key: _WorkerKey, workers: list) -> None: + """Insert ``workers`` into the pool, evicting LRU entries past the cap. + + Eviction stops the evicted workers eagerly so GPU memory is reclaimed + before the next cold start. + """ + _worker_pool[key] = workers + _worker_pool.move_to_end(key) + while len(_worker_pool) > _POOL_MAX: + evict_key = next(iter(_worker_pool)) + if evict_key == key: + # Should not happen with _POOL_MAX >= 1, but guard anyway. + break + evict_workers = _worker_pool.pop(evict_key) + logger.info("Evicting cached workers for %s", evict_key) + stop_workers(evict_workers) + + +def _pool_make_room_for(key: _WorkerKey) -> None: + """Evict LRU pool entries to free GPUs before launching ``key`` fresh. + + Pooled workers occupy GPUs 0..tp*count-1; a brand-new worker set for a + different config would collide with them. Drop everything that does not + match ``key`` until the pool has strictly less than ``_POOL_MAX`` entries, + leaving room for the new entry. Called only on cache-miss so a true cache + hit never evicts anyone. + """ + while len(_worker_pool) >= _POOL_MAX: + evict_key = next(iter(_worker_pool)) + if evict_key == key: + # Pool already has our key (shouldn't happen on a miss); bail out. + break + evict_workers = _worker_pool.pop(evict_key) + logger.info("Evicting cached workers for %s to free GPUs", evict_key) + stop_workers(evict_workers) + + +def _pool_drop(key: _WorkerKey) -> list | None: + """Remove a key from the pool without stopping its workers.""" + return _worker_pool.pop(key, None) + + +def shutdown_worker_pool() -> None: + """Stop and forget every pooled worker. Called at session end.""" + while _worker_pool: + key, workers = _worker_pool.popitem(last=False) + logger.info("Stopping pooled workers for %s at session end", key) + stop_workers(workers) + + def _start_gateway(gateway: Gateway, gateway_config: dict, **mode_kwargs) -> None: """Start gateway with mode-specific kwargs and shared config.""" gateway.start( @@ -175,17 +287,50 @@ def _setup_local( backend_name, log_dir, ): - """Launch regular workers + gateway, yield result tuple, tear down.""" + """Launch (or reuse) regular workers + a fresh gateway, yield, tear down. + + Workers are pooled across test classes that share + ``(model_id, engine, connection_mode, count)`` so the 1-3 min cold start + is paid once per distinct config per session instead of per class. The + gateway is always fresh — its startup is cheap and isolates router state. + """ num_workers = workers_config.get("count") or 1 - logger.info("Starting %s backend: model=%s, workers=%d", backend_name, model_id, num_workers) - - workers = _start_workers_tracked( - model_id=model_id, - engine=engine, - mode=connection_mode, - count=num_workers, - log_dir=log_dir, + key = _WorkerKey(model_id, engine, connection_mode.value, num_workers) + use_pool = _pool_enabled() + + cached = _pool_get(key) if use_pool else None + if cached is not None: + logger.info( + "Reusing pooled workers for %s backend: model=%s, workers=%d", + backend_name, + model_id, + num_workers, + ) + workers = cached + is_fresh = False + else: + if use_pool: + # Free up GPUs occupied by stale pool entries before launching. + _pool_make_room_for(key) + logger.info( + "Starting %s backend: model=%s, workers=%d", backend_name, model_id, num_workers + ) + workers = _start_workers_tracked( + model_id=model_id, + engine=engine, + mode=connection_mode, + count=num_workers, + log_dir=log_dir, + ) + is_fresh = True + + _require_exact_worker_count( + role=f"{backend_name} ({model_id})", + requested=num_workers, + workers=workers, ) + + gateway_started = False try: _start_gateway( gateway, @@ -193,12 +338,30 @@ def _setup_local( worker_urls=[w.base_url for w in workers], model_path=model_path, ) + gateway_started = True logger.info("%s backend ready at %s", backend_name, gateway.base_url) yield backend_name, model_path, _make_openai_client(gateway), gateway finally: logger.info("Tearing down %s backend", backend_name) gateway.shutdown() - stop_workers(workers) + if not use_pool: + if is_fresh: + stop_workers(workers) + elif gateway_started: + # Gateway came up healthy → workers are good to keep warm. + if is_fresh: + _pool_put(key, workers) + # else: cached workers stayed in pool throughout (never popped). + else: + # Gateway never came up: treat these workers as suspect. + if is_fresh: + logger.info("Discarding workers after gateway start failure") + stop_workers(workers) + else: + evicted = _pool_drop(key) + if evicted is not None: + logger.info("Evicting cached workers after gateway start failure") + stop_workers(evicted) # --------------------------------------------------------------------------- @@ -231,6 +394,13 @@ def _setup_pd( num_decode, ) + # PD workers assign GPUs starting at 0, which would collide with any + # regular (non-PD) workers kept warm in the pool. Drain the pool first + # so the GPUs are free. + if _worker_pool: + logger.info("Draining %d pooled worker set(s) before PD setup", len(_worker_pool)) + shutdown_worker_pool() + all_workers: list = [] try: prefill_workers = _start_workers_tracked( @@ -241,6 +411,11 @@ def _setup_pd( worker_type=WorkerType.PREFILL, log_dir=log_dir, ) + _require_exact_worker_count( + role=f"PD prefill ({model_id})", + requested=num_prefill, + workers=prefill_workers, + ) all_workers.extend(prefill_workers) # Decode workers start on GPUs after prefill workers @@ -254,6 +429,11 @@ def _setup_pd( log_dir=log_dir, gpu_offset=decode_gpu_offset, ) + _require_exact_worker_count( + role=f"PD decode ({model_id})", + requested=num_decode, + workers=decode_workers, + ) all_workers.extend(decode_workers) _start_gateway( @@ -332,6 +512,11 @@ def test_router_state(backend_router): model_path = get_model_spec(model_id)["model"] workers = start_workers(model_id, engine=get_runtime(), mode=connection_mode, count=1) + _require_exact_worker_count( + role=f"backend_router {backend_name} ({model_id})", + requested=1, + workers=workers, + ) gateway = Gateway() try: gateway.start(worker_urls=[w.base_url for w in workers], model_path=model_path)