From 74852e173fe1d7566f7aaa2133d669db8e498aeb Mon Sep 17 00:00:00 2001 From: Joshua Date: Fri, 17 Apr 2026 10:11:15 -0400 Subject: [PATCH] feat: add progress callback to run_grid and Study.run - Add callback: ProgressCallback | None parameter to run_grid() Called after each trial with (trial_index, total_trials, result) - In serial mode, callback fires immediately after each trial - In parallel mode, callback fires post-hoc after all trials complete - Forward callback from Study.run() to run_grid() for grid phases - Add ProgressCallback type alias in runner module - 4 new tests (2 in test_runner, 2 in test_study) Closes #77 --- src/trade_study/runner.py | 18 ++++++++++++- src/trade_study/study.py | 15 ++++++++++- tests/test_runner.py | 40 ++++++++++++++++++++++++++- tests/test_study.py | 57 ++++++++++++++++++++++++++++++++++++++- 4 files changed, 126 insertions(+), 4 deletions(-) diff --git a/src/trade_study/runner.py b/src/trade_study/runner.py index 43c1cd7..1eed569 100644 --- a/src/trade_study/runner.py +++ b/src/trade_study/runner.py @@ -22,10 +22,14 @@ ) if TYPE_CHECKING: + from collections.abc import Callable + import optuna from .design import Factor + ProgressCallback = Callable[[int, int, TrialResult], None] + def _run_single( world: Simulator, @@ -52,6 +56,7 @@ def run_grid( *, annotations: list[Annotation] | None = None, n_jobs: int = 1, + callback: ProgressCallback | None = None, ) -> ResultsTable: """Run all configurations in a grid. @@ -62,18 +67,29 @@ def run_grid( observables: Observable definitions (for column ordering). annotations: Optional external annotations (costs, etc.). n_jobs: Number of parallel workers (-1 for all CPUs). + callback: Optional progress callback invoked after each trial + with ``(trial_index, total_trials, trial_result)``. Returns: ResultsTable with scored results. """ + total = len(grid) if n_jobs == 1: - results = [_run_single(world, scorer, cfg) for cfg in grid] + results: list[TrialResult] = [] + for i, cfg in enumerate(grid): + r = _run_single(world, scorer, cfg) + results.append(r) + if callback is not None: + callback(i, total, r) else: from joblib import Parallel, delayed # type: ignore[import-untyped] results = Parallel(n_jobs=n_jobs)( delayed(_run_single)(world, scorer, cfg) for cfg in grid ) + if callback is not None: + for i, r in enumerate(results): + callback(i, total, r) obs_names = [o.name for o in observables] score_matrix = np.array([ diff --git a/src/trade_study/study.py b/src/trade_study/study.py index e54aa3c..8f1d220 100644 --- a/src/trade_study/study.py +++ b/src/trade_study/study.py @@ -28,6 +28,7 @@ Scorer, Simulator, ) + from .runner import ProgressCallback GridCallable = Callable[[ResultsTable, list[Observable]], list[dict[str, Any]]] @@ -165,9 +166,19 @@ class Study: _results: dict[str, ResultsTable] = field(default_factory=dict, init=False) - def run(self, *, n_jobs: int = 1) -> None: + def run( + self, + *, + n_jobs: int = 1, + callback: ProgressCallback | None = None, + ) -> None: """Execute all phases sequentially. + Args: + n_jobs: Number of parallel workers for grid phases. + callback: Optional progress callback invoked after each trial + with ``(trial_index, total_trials, trial_result)``. + Raises: ValueError: If a callable grid is used on the first phase (no previous results to pass). @@ -198,6 +209,7 @@ def run(self, *, n_jobs: int = 1) -> None: self.observables, annotations=self.annotations or None, n_jobs=n_jobs, + callback=callback, ) else: grid = ( @@ -210,6 +222,7 @@ def run(self, *, n_jobs: int = 1) -> None: self.observables, annotations=self.annotations or None, n_jobs=n_jobs, + callback=callback, ) self._results[phase.name] = result diff --git a/tests/test_runner.py b/tests/test_runner.py index fc5d556..ec0ef19 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -8,7 +8,7 @@ import pytest from trade_study.design import Factor, FactorType -from trade_study.protocols import Annotation, Direction, Observable +from trade_study.protocols import Annotation, Direction, Observable, TrialResult from trade_study.runner import run_adaptive, run_grid # --------------------------------------------------------------------------- @@ -320,3 +320,41 @@ def test_run_adaptive_deterministic_seed( r1 = run_adaptive(world, scorer, factors, observables, n_trials=10, seed=7) r2 = run_adaptive(world, scorer, factors, observables, n_trials=10, seed=7) np.testing.assert_allclose(r1.scores, r2.scores) + + +# --------------------------------------------------------------------------- +# Progress callback (#77) +# --------------------------------------------------------------------------- + + +def test_run_grid_callback_called( + world: _ToySimulator, + scorer: _ToyScorer, + observables: list[Observable], +) -> None: + """Callback is invoked once per trial with correct arguments.""" + grid = [{"alpha": v} for v in [0.0, 0.25, 0.5]] + calls: list[tuple[int, int, TrialResult]] = [] + run_grid( + world, + scorer, + grid, + observables, + callback=lambda i, n, r: calls.append((i, n, r)), + ) + assert len(calls) == 3 + for i, (idx, total, result) in enumerate(calls): + assert idx == i + assert total == 3 + assert isinstance(result, TrialResult) + + +def test_run_grid_callback_none( + world: _ToySimulator, + scorer: _ToyScorer, + observables: list[Observable], +) -> None: + """No callback (default) runs without error.""" + grid = [{"alpha": 0.5}] + result = run_grid(world, scorer, grid, observables) + assert len(result.configs) == 1 diff --git a/tests/test_study.py b/tests/test_study.py index 2c28d22..9ffa1ca 100644 --- a/tests/test_study.py +++ b/tests/test_study.py @@ -8,7 +8,13 @@ import pytest from trade_study.design import Factor, FactorType -from trade_study.protocols import Annotation, Direction, Observable, ResultsTable +from trade_study.protocols import ( + Annotation, + Direction, + Observable, + ResultsTable, + TrialResult, +) from trade_study.study import Phase, Study, top_k_pareto_filter, weighted_sum_filter # --------------------------------------------------------------------------- @@ -833,3 +839,52 @@ def test_weighted_sum_filter_in_phase( ) study.run() assert len(study.results("refine").configs) == 2 + + +# --------------------------------------------------------------------------- +# Study.run() progress callback (#77) +# --------------------------------------------------------------------------- + + +def test_study_run_callback( + world: _ToySimulator, + scorer: _ToyScorer, + observables: list[Observable], +) -> None: + """Study.run(callback=...) invokes callback for each trial.""" + grid = [{"alpha": v} for v in [0.0, 0.25, 0.5]] + calls: list[tuple[int, int, TrialResult]] = [] + study = Study( + world=world, + scorer=scorer, + observables=observables, + phases=[Phase(name="p1", grid=grid)], + ) + study.run(callback=lambda i, n, r: calls.append((i, n, r))) + assert len(calls) == 3 + + +def test_study_run_callback_multi_phase( + world: _ToySimulator, + scorer: _ToyScorer, + observables: list[Observable], +) -> None: + """Callback fires across multiple grid phases.""" + grid = [{"alpha": v} for v in [0.0, 0.25, 0.5, 0.75, 1.0]] + calls: list[tuple[int, int, TrialResult]] = [] + study = Study( + world=world, + scorer=scorer, + observables=observables, + phases=[ + Phase( + name="disc", + grid=grid, + filter_fn=top_k_pareto_filter(k=2), + ), + Phase(name="refine", grid="carry"), + ], + ) + study.run(callback=lambda i, n, r: calls.append((i, n, r))) + # 5 trials in phase 1 + 2 trials in phase 2 + assert len(calls) == 7