Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions e2e_test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@

Fixtures
--------
setup_backend: Class-scoped fixture that launches workers + gateway per test class.
Returns (backend_name, model_path, client, gateway).
setup_backend: Class-scoped fixture that launches (or reuses) workers and a
fresh gateway per test class. Workers for non-PD local backends are
pooled across classes that share ``(model_id, engine, mode, count)`` so
cold-start cost is paid once per distinct config per session. Set
``E2E_DISABLE_WORKER_POOL=1`` to fall back to per-class worker
start/stop. Returns ``(backend_name, model_path, client, gateway)``.
model: Convenience fixture that returns just the model_path from setup_backend.
"""

Expand Down Expand Up @@ -108,6 +112,7 @@ def pytest_runtest_logstart(nodeid: str, location: tuple) -> None:
pytest_collection_modifyitems,
pytest_configure,
pytest_runtest_setup,
pytest_sessionfinish,
setup_backend,
)
from smg_client import SmgClient
Expand Down Expand Up @@ -151,6 +156,7 @@ def api_client(request, setup_backend):
"pytest_runtest_setup",
"pytest_collection_modifyitems",
"pytest_configure",
"pytest_sessionfinish",
# Fixtures
"setup_backend",
"backend_router",
Expand Down
2 changes: 2 additions & 0 deletions e2e_test/fixtures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
pytest_collection_modifyitems,
pytest_configure,
pytest_runtest_setup,
pytest_sessionfinish,
)

# Marker helpers
Expand All @@ -25,6 +26,7 @@
"pytest_collection_modifyitems",
"pytest_configure",
"pytest_runtest_setup",
"pytest_sessionfinish",
# Backend fixtures
"setup_backend",
"backend_router",
Expand Down
144 changes: 121 additions & 23 deletions e2e_test/fixtures/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,23 @@
This module handles:
- Marker registration: Defining custom pytest markers
- Test filtering: Env-var-based filtering by engine, vendor, and GPU tier
- Test ordering: Cluster items by backend config so the session-scoped
worker pool in ``setup_backend`` can amortize cold starts across classes.
- Session cleanup: Stop any pooled workers at session end.
"""

from __future__ import annotations

import logging
import os

import pytest
from infra import get_runtime

from .markers import resolve_class_marker
from .setup_backend import shutdown_worker_pool

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Marker registration
Expand Down Expand Up @@ -106,36 +113,127 @@ def _get_marker(item: pytest.Item, name: str):
return resolve_class_marker(item, name)


def _parametrize_argnames_to_set(argnames: object) -> set[str]:
if isinstance(argnames, str):
return {n.strip() for n in argnames.split(",") if n.strip()}
if isinstance(argnames, (tuple, list)):
return {str(n) for n in argnames}
return set()


def _class_level_backend_sort_token(item: pytest.Item) -> str | None:
    """Build a sort token from class-level ``parametrize`` marks on backend fixtures.

    Every class in the item's MRO (except ``object``) is inspected for a
    ``pytestmark`` attribute. Each ``parametrize`` mark whose argnames
    include ``setup_backend`` or ``backend_router`` contributes the ``repr``
    of its ``(argnames, argvalues)`` pair. Using this aggregate instead of
    the concrete ``callspec.params`` value keeps all parametrized variants
    of one class under the same sort key.

    Args:
        item: The collected test item.

    Returns:
        The distinct contributions, sorted and joined with ``"|"``, or
        ``None`` when the item has no class or no matching mark was found.
    """
    owner = getattr(item, "cls", None)
    if owner is None:
        return None

    backend_fixtures = {"setup_backend", "backend_router"}
    seen: set[str] = set()
    for klass in owner.__mro__:
        if klass is object:
            continue
        declared = getattr(klass, "pytestmark", None)
        if declared is None:
            continue
        # ``pytestmark`` may be a single mark rather than a list of marks.
        candidates = declared if isinstance(declared, list) else [declared]
        for candidate in candidates:
            if getattr(candidate, "name", None) != "parametrize":
                continue
            positional = tuple(getattr(candidate, "args", ()) or ())
            if len(positional) < 2:
                # Only positional (argnames, argvalues) marks are considered.
                continue
            argnames, argvalues = positional[0], positional[1]
            if backend_fixtures.isdisjoint(_parametrize_argnames_to_set(argnames)):
                continue
            seen.add(repr((argnames, argvalues)))

    return "|".join(sorted(seen)) if seen else None


def _backend_sort_key(item: pytest.Item) -> tuple:
    """Return the key used to cluster test items by backend configuration.

    The key is ``(model_id, count, prefill, decode, backend)``:

    - ``model_id`` is the first positional argument of the ``model`` marker,
      or ``""`` when the marker is missing or has no arguments.
    - ``count`` / ``prefill`` / ``decode`` come from the ``workers`` marker's
      keyword arguments; a missing marker or falsy value yields ``0``.
    - ``backend`` is the class-level parametrize token when one exists;
      otherwise the ``setup_backend`` (or, failing that, ``backend_router``)
      entry of ``callspec.params`` as a string, or ``""`` when the item has
      no callspec.

    Items with equal keys keep their relative order because the caller sorts
    with Python's stable sort.
    """
    model_mark = resolve_class_marker(item, "model")
    workers_mark = resolve_class_marker(item, "workers")

    model_id = model_mark.args[0] if model_mark and model_mark.args else ""

    worker_kwargs = workers_mark.kwargs if workers_mark is not None else {}
    wcount, wprefill, wdecode = (
        worker_kwargs.get(field) or 0 for field in ("count", "prefill", "decode")
    )

    backend = _class_level_backend_sort_token(item)
    if backend is None:
        callspec = getattr(item, "callspec", None)
        if callspec is None:
            backend = ""
        else:
            params = getattr(callspec, "params", {})
            chosen = params.get("setup_backend") or params.get("backend_router") or ""
            backend = str(chosen)

    return (model_id, wcount, wprefill, wdecode, backend)


def pytest_collection_modifyitems(
    config: pytest.Config,
    items: list[pytest.Item],
) -> None:
    """Filter collected tests based on env vars, then cluster by backend config.

    Filtering is controlled by ``E2E_ENGINE``, ``E2E_VENDOR`` and
    ``E2E_GPU_TIER``; an unset or empty variable disables that filter.
    When the engine or vendor filter is active, items without the matching
    marker (or whose marker args do not contain the requested value) are
    dropped. For the GPU tier, an item without a ``gpu`` marker counts as
    tier ``1``.

    After filtering, items are sorted by ``_backend_sort_key``. The sort
    runs whether or not any filter is set, and is skipped only when
    ``E2E_DISABLE_TEST_SORT`` is ``1``, ``true`` or ``yes``
    (case-insensitive).

    Args:
        config: The pytest config object (not used here).
        items: Collected test items; filtered and reordered in place.
    """
    engine = os.environ.get("E2E_ENGINE") or None
    vendor = os.environ.get("E2E_VENDOR") or None
    gpu_tier = os.environ.get("E2E_GPU_TIER") or None

    # Filter only when at least one filter is set. There is deliberately no
    # early return here: the clustering sort below must still run.
    if any([engine, vendor, gpu_tier]):
        selected: list[pytest.Item] = []
        for item in items:
            if engine:
                engine_marker = _get_marker(item, "engine")
                if not engine_marker or engine not in engine_marker.args:
                    continue
            if vendor:
                vendor_marker = _get_marker(item, "vendor")
                if not vendor_marker or vendor not in vendor_marker.args:
                    continue
            if gpu_tier is not None:
                gpu_marker = _get_marker(item, "gpu")
                gpu_count = gpu_marker.args[0] if gpu_marker else 1
                # The env var is a string, so compare in string form.
                if str(gpu_count) != gpu_tier:
                    continue
            selected.append(item)

        items[:] = selected

    # Stable sort: equal keys preserve the original (post-filter) order, so
    # tests within a single class stay together and only the class-level
    # grouping changes.
    if os.environ.get("E2E_DISABLE_TEST_SORT", "").lower() not in ("1", "true", "yes"):
        items.sort(key=_backend_sort_key)


# ---------------------------------------------------------------------------
# Session lifecycle
# ---------------------------------------------------------------------------

# NOTE(review): this span looks like the removed (pre-change) side of the diff
# hunk for pytest_collection_modifyitems: it repeats that function's filter
# loop with no enclosing def and no indentation. Presumably it is not part of
# the resulting hooks.py; confirm against the real file before relying on it.
selected: list[pytest.Item] = []
for item in items:
# Filter by engine
if engine:
engine_marker = _get_marker(item, "engine")
if not engine_marker or engine not in engine_marker.args:
continue
# Filter by vendor
if vendor:
vendor_marker = _get_marker(item, "vendor")
if not vendor_marker or vendor not in vendor_marker.args:
continue
# Filter by GPU tier
if gpu_tier is not None:
gpu_marker = _get_marker(item, "gpu")
gpu_count = gpu_marker.args[0] if gpu_marker else 1
if str(gpu_count) != gpu_tier:
continue
selected.append(item)

items[:] = selected
def pytest_sessionfinish(session: pytest.Session, exitstatus: int) -> None:
    """Release pooled workers once the test session is over.

    Calls ``shutdown_worker_pool`` from the ``setup_backend`` module.
    Neither argument is used.

    Args:
        session: The finishing pytest session.
        exitstatus: The session's exit status.
    """
    shutdown_worker_pool()
Loading