Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion src/trade_study/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,14 @@
)

if TYPE_CHECKING:
from collections.abc import Callable

import optuna

from .design import Factor

ProgressCallback = Callable[[int, int, TrialResult], None]


def _run_single(
world: Simulator,
Expand All @@ -52,6 +56,7 @@ def run_grid(
*,
annotations: list[Annotation] | None = None,
n_jobs: int = 1,
callback: ProgressCallback | None = None,
) -> ResultsTable:
"""Run all configurations in a grid.

Expand All @@ -62,18 +67,29 @@ def run_grid(
observables: Observable definitions (for column ordering).
annotations: Optional external annotations (costs, etc.).
n_jobs: Number of parallel workers (-1 for all CPUs).
callback: Optional progress callback invoked after each trial
with ``(trial_index, total_trials, trial_result)``.

Returns:
ResultsTable with scored results.
"""
total = len(grid)
if n_jobs == 1:
results = [_run_single(world, scorer, cfg) for cfg in grid]
results: list[TrialResult] = []
for i, cfg in enumerate(grid):
r = _run_single(world, scorer, cfg)
results.append(r)
if callback is not None:
callback(i, total, r)
else:
from joblib import Parallel, delayed # type: ignore[import-untyped]

results = Parallel(n_jobs=n_jobs)(
delayed(_run_single)(world, scorer, cfg) for cfg in grid
)
if callback is not None:
for i, r in enumerate(results):
callback(i, total, r)

obs_names = [o.name for o in observables]
score_matrix = np.array([
Expand Down
15 changes: 14 additions & 1 deletion src/trade_study/study.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Scorer,
Simulator,
)
from .runner import ProgressCallback

GridCallable = Callable[[ResultsTable, list[Observable]], list[dict[str, Any]]]

Expand Down Expand Up @@ -165,9 +166,19 @@ class Study:

_results: dict[str, ResultsTable] = field(default_factory=dict, init=False)

def run(self, *, n_jobs: int = 1) -> None:
def run(
self,
*,
n_jobs: int = 1,
callback: ProgressCallback | None = None,
) -> None:
"""Execute all phases sequentially.

Args:
n_jobs: Number of parallel workers for grid phases.
callback: Optional progress callback invoked after each trial
with ``(trial_index, total_trials, trial_result)``.

Raises:
ValueError: If a callable grid is used on the first phase
(no previous results to pass).
Expand Down Expand Up @@ -198,6 +209,7 @@ def run(self, *, n_jobs: int = 1) -> None:
self.observables,
annotations=self.annotations or None,
n_jobs=n_jobs,
callback=callback,
)
else:
grid = (
Expand All @@ -210,6 +222,7 @@ def run(self, *, n_jobs: int = 1) -> None:
self.observables,
annotations=self.annotations or None,
n_jobs=n_jobs,
callback=callback,
)

self._results[phase.name] = result
Expand Down
40 changes: 39 additions & 1 deletion tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest

from trade_study.design import Factor, FactorType
from trade_study.protocols import Annotation, Direction, Observable
from trade_study.protocols import Annotation, Direction, Observable, TrialResult
from trade_study.runner import run_adaptive, run_grid

# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -320,3 +320,41 @@ def test_run_adaptive_deterministic_seed(
r1 = run_adaptive(world, scorer, factors, observables, n_trials=10, seed=7)
r2 = run_adaptive(world, scorer, factors, observables, n_trials=10, seed=7)
np.testing.assert_allclose(r1.scores, r2.scores)


# ---------------------------------------------------------------------------
# Progress callback (#77)
# ---------------------------------------------------------------------------


def test_run_grid_callback_called(
world: _ToySimulator,
scorer: _ToyScorer,
observables: list[Observable],
) -> None:
"""Callback is invoked once per trial with correct arguments."""
grid = [{"alpha": v} for v in [0.0, 0.25, 0.5]]
calls: list[tuple[int, int, TrialResult]] = []
run_grid(
world,
scorer,
grid,
observables,
callback=lambda i, n, r: calls.append((i, n, r)),
)
assert len(calls) == 3
for i, (idx, total, result) in enumerate(calls):
assert idx == i
assert total == 3
assert isinstance(result, TrialResult)


def test_run_grid_callback_none(
world: _ToySimulator,
scorer: _ToyScorer,
observables: list[Observable],
) -> None:
"""No callback (default) runs without error."""
grid = [{"alpha": 0.5}]
result = run_grid(world, scorer, grid, observables)
assert len(result.configs) == 1
57 changes: 56 additions & 1 deletion tests/test_study.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@
import pytest

from trade_study.design import Factor, FactorType
from trade_study.protocols import Annotation, Direction, Observable, ResultsTable
from trade_study.protocols import (
Annotation,
Direction,
Observable,
ResultsTable,
TrialResult,
)
from trade_study.study import Phase, Study, top_k_pareto_filter, weighted_sum_filter

# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -833,3 +839,52 @@ def test_weighted_sum_filter_in_phase(
)
study.run()
assert len(study.results("refine").configs) == 2


# ---------------------------------------------------------------------------
# Study.run() progress callback (#77)
# ---------------------------------------------------------------------------


def test_study_run_callback(
world: _ToySimulator,
scorer: _ToyScorer,
observables: list[Observable],
) -> None:
"""Study.run(callback=...) invokes callback for each trial."""
grid = [{"alpha": v} for v in [0.0, 0.25, 0.5]]
calls: list[tuple[int, int, TrialResult]] = []
study = Study(
world=world,
scorer=scorer,
observables=observables,
phases=[Phase(name="p1", grid=grid)],
)
study.run(callback=lambda i, n, r: calls.append((i, n, r)))
assert len(calls) == 3


def test_study_run_callback_multi_phase(
world: _ToySimulator,
scorer: _ToyScorer,
observables: list[Observable],
) -> None:
"""Callback fires across multiple grid phases."""
grid = [{"alpha": v} for v in [0.0, 0.25, 0.5, 0.75, 1.0]]
calls: list[tuple[int, int, TrialResult]] = []
study = Study(
world=world,
scorer=scorer,
observables=observables,
phases=[
Phase(
name="disc",
grid=grid,
filter_fn=top_k_pareto_filter(k=2),
),
Phase(name="refine", grid="carry"),
],
)
study.run(callback=lambda i, n, r: calls.append((i, n, r)))
# 5 trials in phase 1 + 2 trials in phase 2
assert len(calls) == 7
Loading