From e4ebf71fe8a09c4e3a53d7afeefa82bf30408082 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Wed, 14 Jan 2026 11:53:58 -0800 Subject: [PATCH] ep max evals flag --- eval_protocol/pytest/evaluation_test.py | 2 ++ eval_protocol/pytest/evaluation_test_utils.py | 9 +++++++++ eval_protocol/pytest/plugin.py | 19 +++++++++++++++---- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index 84a66805..eb05aa35 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -57,6 +57,7 @@ log_eval_status_and_rows, parse_ep_completion_params, parse_ep_completion_params_overwrite, + parse_ep_max_concurrent_evaluations, parse_ep_max_concurrent_rollouts, parse_ep_max_rows, parse_ep_num_runs, @@ -201,6 +202,7 @@ def evaluation_test( # into input_params (e.g., '{"temperature":0,"extra_body":{"reasoning":{"effort":"low"}}}'). num_runs = parse_ep_num_runs(num_runs) max_concurrent_rollouts = parse_ep_max_concurrent_rollouts(max_concurrent_rollouts) + max_concurrent_evaluations = parse_ep_max_concurrent_evaluations(max_concurrent_evaluations) max_dataset_rows = parse_ep_max_rows(max_dataset_rows) completion_params = parse_ep_completion_params(completion_params) completion_params = parse_ep_completion_params_overwrite(completion_params) diff --git a/eval_protocol/pytest/evaluation_test_utils.py b/eval_protocol/pytest/evaluation_test_utils.py index 64f0c8b3..b0ebd235 100644 --- a/eval_protocol/pytest/evaluation_test_utils.py +++ b/eval_protocol/pytest/evaluation_test_utils.py @@ -226,6 +226,15 @@ def parse_ep_max_concurrent_rollouts(default_value: int) -> int: return int(raw) if raw is not None else default_value +def parse_ep_max_concurrent_evaluations(default_value: int) -> int: + """Read EP_MAX_CONCURRENT_EVALUATIONS env override as int. + + Assumes the environment variable was already validated by plugin.py. + """ + raw = os.getenv("EP_MAX_CONCURRENT_EVALUATIONS") + return int(raw) if raw is not None else default_value + + def parse_ep_completion_params( completion_params: Sequence[CompletionParams | None] | None, ) -> Sequence[CompletionParams | None]: diff --git a/eval_protocol/pytest/plugin.py b/eval_protocol/pytest/plugin.py index 7b17b6d9..a6d71b31 100644 --- a/eval_protocol/pytest/plugin.py +++ b/eval_protocol/pytest/plugin.py @@ -45,6 +45,12 @@ def pytest_addoption(parser) -> None: default=None, help=("Override the maximum number of concurrent rollouts. Pass an integer (e.g., 8, 50, 100)."), ) + group.addoption( + "--ep-max-concurrent-evaluations", + action="store", + default=None, + help=("Override the maximum number of concurrent evaluations. Pass an integer (e.g., 8, 50, 100)."), + ) group.addoption( "--ep-print-summary", action="store_true", @@ -242,10 +248,15 @@ def pytest_configure(config) -> None: if norm_runs is not None: os.environ["EP_NUM_RUNS"] = norm_runs - max_concurrent_val = config.getoption("--ep-max-concurrent-rollouts") - norm_concurrent = _normalize_number(max_concurrent_val) - if norm_concurrent is not None: - os.environ["EP_MAX_CONCURRENT_ROLLOUTS"] = norm_concurrent + max_concurrent_rollouts_val = config.getoption("--ep-max-concurrent-rollouts") + norm_concurrent_rollouts = _normalize_number(max_concurrent_rollouts_val) + if norm_concurrent_rollouts is not None: + os.environ["EP_MAX_CONCURRENT_ROLLOUTS"] = norm_concurrent_rollouts + + max_concurrent_evals_val = config.getoption("--ep-max-concurrent-evaluations") + norm_concurrent_evals = _normalize_number(max_concurrent_evals_val) + if norm_concurrent_evals is not None: + os.environ["EP_MAX_CONCURRENT_EVALUATIONS"] = norm_concurrent_evals if config.getoption("--ep-print-summary"): os.environ["EP_PRINT_SUMMARY"] = "1"