Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
log_eval_status_and_rows,
parse_ep_completion_params,
parse_ep_completion_params_overwrite,
parse_ep_max_concurrent_evaluations,
parse_ep_max_concurrent_rollouts,
parse_ep_max_rows,
parse_ep_num_runs,
Expand Down Expand Up @@ -201,6 +202,7 @@ def evaluation_test(
# into input_params (e.g., '{"temperature":0,"extra_body":{"reasoning":{"effort":"low"}}}').
num_runs = parse_ep_num_runs(num_runs)
max_concurrent_rollouts = parse_ep_max_concurrent_rollouts(max_concurrent_rollouts)
max_concurrent_evaluations = parse_ep_max_concurrent_evaluations(max_concurrent_evaluations)
max_dataset_rows = parse_ep_max_rows(max_dataset_rows)
completion_params = parse_ep_completion_params(completion_params)
completion_params = parse_ep_completion_params_overwrite(completion_params)
Expand Down
9 changes: 9 additions & 0 deletions eval_protocol/pytest/evaluation_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,15 @@ def parse_ep_max_concurrent_rollouts(default_value: int) -> int:
return int(raw) if raw is not None else default_value


def parse_ep_max_concurrent_evaluations(default_value: int) -> int:
    """Resolve the evaluation concurrency limit, honoring the env override.

    If ``EP_MAX_CONCURRENT_EVALUATIONS`` is present in the environment, its
    value is converted with ``int`` and returned; otherwise ``default_value``
    is returned unchanged.

    No validation is performed here: the value is assumed to have been
    validated by plugin.py beforehand.
    """
    override = os.environ.get("EP_MAX_CONCURRENT_EVALUATIONS")
    if override is None:
        # Variable not set: keep the caller-supplied default.
        return default_value
    return int(override)


def parse_ep_completion_params(
completion_params: Sequence[CompletionParams | None] | None,
) -> Sequence[CompletionParams | None]:
Expand Down
19 changes: 15 additions & 4 deletions eval_protocol/pytest/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ def pytest_addoption(parser) -> None:
default=None,
help=("Override the maximum number of concurrent rollouts. Pass an integer (e.g., 8, 50, 100)."),
)
group.addoption(
"--ep-max-concurrent-evaluations",
action="store",
default=None,
help=("Override the maximum number of concurrent evaluations. Pass an integer (e.g., 8, 50, 100)."),
)
group.addoption(
"--ep-print-summary",
action="store_true",
Expand Down Expand Up @@ -242,10 +248,15 @@ def pytest_configure(config) -> None:
if norm_runs is not None:
os.environ["EP_NUM_RUNS"] = norm_runs

max_concurrent_val = config.getoption("--ep-max-concurrent-rollouts")
norm_concurrent = _normalize_number(max_concurrent_val)
if norm_concurrent is not None:
os.environ["EP_MAX_CONCURRENT_ROLLOUTS"] = norm_concurrent
max_concurrent_rollouts_val = config.getoption("--ep-max-concurrent-rollouts")
norm_concurrent_rollouts = _normalize_number(max_concurrent_rollouts_val)
if norm_concurrent_rollouts is not None:
os.environ["EP_MAX_CONCURRENT_ROLLOUTS"] = norm_concurrent_rollouts

max_concurrent_evals_val = config.getoption("--ep-max-concurrent-evaluations")
norm_concurrent_evals = _normalize_number(max_concurrent_evals_val)
if norm_concurrent_evals is not None:
os.environ["EP_MAX_CONCURRENT_EVALUATIONS"] = norm_concurrent_evals

if config.getoption("--ep-print-summary"):
os.environ["EP_PRINT_SUMMARY"] = "1"
Expand Down
Loading