diff --git a/src/trade_study/__init__.py b/src/trade_study/__init__.py
index dd47245..8d167d0 100644
--- a/src/trade_study/__init__.py
+++ b/src/trade_study/__init__.py
@@ -10,6 +10,7 @@
 from .io import load_results, save_results
 from .protocols import (
     Annotation,
+    Constraint,
     Direction,
     Observable,
     ResultsTable,
@@ -19,11 +20,18 @@
 )
 from .runner import run_adaptive, run_grid
 from .stacking import ensemble_predict, stack_bayesian, stack_scores
-from .study import Phase, Study, top_k_pareto_filter, weighted_sum_filter
+from .study import (
+    Phase,
+    Study,
+    feasibility_filter,
+    top_k_pareto_filter,
+    weighted_sum_filter,
+)
 from .viz import plot_calibration, plot_front, plot_parallel, plot_scores
 
 __all__ = [
     "Annotation",
+    "Constraint",
     "Direction",
     "Factor",
     "FactorType",
@@ -39,6 +47,7 @@
     "coverage_curve",
     "ensemble_predict",
     "extract_front",
+    "feasibility_filter",
     "hypervolume",
     "igd_plus",
     "load_results",
diff --git a/src/trade_study/protocols.py b/src/trade_study/protocols.py
index 7095d7f..b472179 100644
--- a/src/trade_study/protocols.py
+++ b/src/trade_study/protocols.py
@@ -2,11 +2,14 @@
 
 from __future__ import annotations
 
+import operator as _operator
 from dataclasses import dataclass, field
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
 
 if TYPE_CHECKING:
+    from collections.abc import Callable
+
     import numpy as np
     from numpy.typing import NDArray
 
@@ -34,6 +37,61 @@ class Observable:
     weight: float = 1.0
 
 
+_OP_MAP: dict[str, Callable[[Any, Any], bool]] = {
+    ">=": _operator.ge,
+    "<=": _operator.le,
+    ">": _operator.gt,
+    "<": _operator.lt,
+    "==": _operator.eq,
+    "!=": _operator.ne,
+}
+
+
+@dataclass(frozen=True)
+class Constraint:
+    """Feasibility constraint on an observable or annotation.
+
+    A design is feasible when ``scores[observable] <op> threshold`` is
+    true.
+
+    Attributes:
+        name: Human-readable label (e.g. ``"min_conversion"``).
+        observable: Name of the observable or annotation column to test.
+        op: Comparison operator as a string (``">="``, ``"<="``, ``">"``,
+            ``"<"``, ``"=="``, ``"!="``).
+        threshold: Scalar threshold value.
+    """
+
+    name: str
+    observable: str
+    op: str
+    threshold: float
+
+    def __post_init__(self) -> None:
+        """Validate the comparison operator.
+
+        Raises:
+            ValueError: If *op* is not one of the supported operators.
+        """
+        if self.op not in _OP_MAP:
+            msg = (
+                f"Constraint {self.name!r}: unsupported operator {self.op!r}. "
+                f"Use one of {sorted(_OP_MAP)}"
+            )
+            raise ValueError(msg)
+
+    def check(self, value: float) -> bool:
+        """Test whether a scalar value satisfies the constraint.
+
+        Args:
+            value: Scalar score or annotation value to test.
+
+        Returns:
+            ``True`` if the value satisfies the constraint.
+        """
+        return bool(_OP_MAP[self.op](value, self.threshold))
+
+
 @runtime_checkable
 class Simulator(Protocol):
     """Protocol for generating ground truth and observations.
@@ -131,3 +189,41 @@ class ResultsTable:
     annotations: NDArray[np.floating[Any]] | None = None  # (n_trials, n_annotations)
     annotation_names: list[str] = field(default_factory=list)
     metadata: list[dict[str, Any]] = field(default_factory=list)
+
+    def feasible(self, constraints: list[Constraint]) -> NDArray[np.bool_]:
+        """Return a boolean mask indicating which rows satisfy all constraints.
+
+        Each constraint references an observable or annotation column by
+        name. A row is feasible only when **every** constraint evaluates
+        to ``True``.
+
+        Args:
+            constraints: Constraint objects to evaluate.
+
+        Returns:
+            Boolean array of shape ``(n_trials,)``.
+
+        Raises:
+            KeyError: If a constraint references a column not found in
+                either ``observable_names`` or ``annotation_names``.
+        """
+        import numpy as np
+
+        mask = np.ones(len(self.configs), dtype=np.bool_)
+        for con in constraints:
+            if con.observable in self.observable_names:
+                col_idx = self.observable_names.index(con.observable)
+                values = self.scores[:, col_idx]
+            elif (
+                con.observable in self.annotation_names and self.annotations is not None
+            ):
+                col_idx = self.annotation_names.index(con.observable)
+                values = self.annotations[:, col_idx]
+            else:
+                msg = (
+                    f"Constraint {con.name!r}: column {con.observable!r} "
+                    f"not found in observables or annotations"
+                )
+                raise KeyError(msg)
+            mask &= _OP_MAP[con.op](values, con.threshold)
+        return mask
diff --git a/src/trade_study/study.py b/src/trade_study/study.py
index 8f1d220..4ca2a57 100644
--- a/src/trade_study/study.py
+++ b/src/trade_study/study.py
@@ -23,6 +23,7 @@
 
 from .protocols import (
     Annotation,
+    Constraint,
     Observable,
     ResultsTable,
     Scorer,
@@ -144,6 +145,28 @@
     return _filter
 
 
+def feasibility_filter(
+    constraints: list[Constraint],
+) -> Callable[[ResultsTable, list[Observable]], NDArray[np.intp]]:
+    """Create a filter that keeps only designs satisfying all constraints.
+
+    Args:
+        constraints: Constraint objects to evaluate against results.
+
+    Returns:
+        Filter function compatible with ``Phase.filter_fn``.
+    """
+
+    def _filter(
+        results: ResultsTable,
+        _observables: list[Observable],
+    ) -> NDArray[np.intp]:
+        mask = results.feasible(constraints)
+        return np.nonzero(mask)[0].astype(np.intp)
+
+    return _filter
+
+
 @dataclass
 class Study:
     """Multi-phase model criticism study.
diff --git a/tests/test_protocols.py b/tests/test_protocols.py
index 57e6a14..d9aaac0 100644
--- a/tests/test_protocols.py
+++ b/tests/test_protocols.py
@@ -8,7 +8,13 @@
 import pytest
 
 from trade_study import Direction, Observable
-from trade_study.protocols import Annotation, ResultsTable, Scorer, Simulator
+from trade_study.protocols import (
+    Annotation,
+    Constraint,
+    ResultsTable,
+    Scorer,
+    Simulator,
+)
 
 
 # -- Direction enum ------------------------------------------------------------
@@ -298,3 +304,129 @@ def test_results_table_observable_names_order() -> None:
         observable_names=names,
     )
     assert rt.observable_names == ["coverage", "rmse", "wall_time"]
+
+
+# -- Constraint dataclass ------------------------------------------------------
+
+
+def test_constraint_creation() -> None:
+    c = Constraint(name="min_cov", observable="coverage", op=">=", threshold=0.5)
+    assert c.name == "min_cov"
+    assert c.observable == "coverage"
+    assert c.op == ">="
+    assert c.threshold == pytest.approx(0.5)
+
+
+def test_constraint_frozen() -> None:
+    c = Constraint(name="c", observable="x", op=">=", threshold=0.0)
+    with pytest.raises(AttributeError):
+        c.threshold = 1.0  # type: ignore[misc]
+
+
+def test_constraint_invalid_op_raises() -> None:
+    with pytest.raises(ValueError, match="unsupported operator"):
+        Constraint(name="bad", observable="x", op="~", threshold=0.0)
+
+
+@pytest.mark.parametrize(
+    ("op", "value", "threshold", "expected"),
+    [
+        (">=", 0.6, 0.5, True),
+        (">=", 0.5, 0.5, True),
+        (">=", 0.4, 0.5, False),
+        ("<=", 0.4, 0.5, True),
+        ("<=", 0.5, 0.5, True),
+        ("<=", 0.6, 0.5, False),
+        (">", 0.6, 0.5, True),
+        (">", 0.5, 0.5, False),
+        ("<", 0.4, 0.5, True),
+        ("<", 0.5, 0.5, False),
+        ("==", 0.5, 0.5, True),
+        ("==", 0.6, 0.5, False),
+        ("!=", 0.6, 0.5, True),
+        ("!=", 0.5, 0.5, False),
+    ],
+)
+def test_constraint_check(
+    op: str, value: float, threshold: float, *, expected: bool
+) -> None:
+    c = Constraint(name="test", observable="x", op=op, threshold=threshold)
+    assert c.check(value) is expected
+
+
+# -- ResultsTable.feasible() ---------------------------------------------------
+
+
+@pytest.fixture
+def scored_table() -> ResultsTable:
+    """Table with 5 rows, 2 observables (coverage, cost).
+
+    Returns:
+        A ResultsTable with coverage and cost columns.
+    """
+    return ResultsTable(
+        configs=[{"a": i} for i in range(5)],
+        scores=np.array([
+            [0.9, 100.0],
+            [0.4, 50.0],
+            [0.6, 80.0],
+            [0.3, 30.0],
+            [0.7, 60.0],
+        ]),
+        observable_names=["coverage", "cost"],
+    )
+
+
+def test_feasible_single_constraint(scored_table: ResultsTable) -> None:
+    constraints = [Constraint("min_cov", "coverage", ">=", 0.5)]
+    mask = scored_table.feasible(constraints)
+    assert mask.dtype == np.bool_
+    expected = np.array([True, False, True, False, True])
+    np.testing.assert_array_equal(mask, expected)
+
+
+def test_feasible_multiple_constraints(scored_table: ResultsTable) -> None:
+    constraints = [
+        Constraint("min_cov", "coverage", ">=", 0.5),
+        Constraint("max_cost", "cost", "<=", 80.0),
+    ]
+    mask = scored_table.feasible(constraints)
+    expected = np.array([False, False, True, False, True])
+    np.testing.assert_array_equal(mask, expected)
+
+
+def test_feasible_all_pass(scored_table: ResultsTable) -> None:
+    constraints = [Constraint("low_bar", "coverage", ">=", 0.0)]
+    mask = scored_table.feasible(constraints)
+    assert mask.all()
+
+
+def test_feasible_none_pass(scored_table: ResultsTable) -> None:
+    constraints = [Constraint("high_bar", "coverage", ">", 1.0)]
+    mask = scored_table.feasible(constraints)
+    assert not mask.any()
+
+
+def test_feasible_annotation_column() -> None:
+    rt = ResultsTable(
+        configs=[{"a": 1}, {"a": 2}],
+        scores=np.array([[0.5], [0.6]]),
+        observable_names=["rmse"],
+        annotations=np.array([[10.0], [50.0]]),
+        annotation_names=["dollar_cost"],
+    )
+    constraints = [Constraint("budget", "dollar_cost", "<=", 20.0)]
+    mask = rt.feasible(constraints)
+    np.testing.assert_array_equal(mask, np.array([True, False]))
+
+
+def test_feasible_unknown_column_raises(scored_table: ResultsTable) -> None:
+    constraints = [Constraint("bad", "nonexistent", ">=", 0.0)]
+    with pytest.raises(KeyError, match="nonexistent"):
+        scored_table.feasible(constraints)
+
+
+def test_feasible_empty_constraints(scored_table: ResultsTable) -> None:
+    mask = scored_table.feasible([])
+    assert mask.all()
+    assert len(mask) == 5
diff --git a/tests/test_study.py b/tests/test_study.py
index 9ffa1ca..d2cd3aa 100644
--- a/tests/test_study.py
+++ b/tests/test_study.py
@@ -10,12 +10,19 @@
 from trade_study.design import Factor, FactorType
 from trade_study.protocols import (
     Annotation,
+    Constraint,
     Direction,
     Observable,
     ResultsTable,
     TrialResult,
 )
-from trade_study.study import Phase, Study, top_k_pareto_filter, weighted_sum_filter
+from trade_study.study import (
+    Phase,
+    Study,
+    feasibility_filter,
+    top_k_pareto_filter,
+    weighted_sum_filter,
+)
 # ---------------------------------------------------------------------------
 # Toy implementations (same pattern as test_runner)
 # ---------------------------------------------------------------------------
@@ -888,3 +895,99 @@
     study.run(callback=lambda i, n, r: calls.append((i, n, r)))
     # 5 trials in phase 1 + 2 trials in phase 2
     assert len(calls) == 7
+
+
+# ---------------------------------------------------------------------------
+# feasibility_filter (#74)
+# ---------------------------------------------------------------------------
+
+
+def test_feasibility_filter_returns_callable() -> None:
+    fn = feasibility_filter(constraints=[])
+    assert callable(fn)
+
+
+def test_feasibility_filter_keeps_feasible(
+    observables: list[Observable],
+) -> None:
+    rt = ResultsTable(
+        configs=[{"alpha": v} for v in [0.0, 0.25, 0.5, 0.75, 1.0]],
+        scores=np.array([
+            [0.5, 0.0],
+            [0.25, 2.5],
+            [0.0, 5.0],
+            [0.25, 7.5],
+            [0.5, 10.0],
+        ]),
+        observable_names=["error", "cost"],
+    )
+    fn = feasibility_filter([
+        Constraint("low_error", "error", "<=", 0.25),
+    ])
+    idx = fn(rt, observables)
+    assert set(idx.tolist()) == {1, 2, 3}
+
+
+def test_feasibility_filter_multiple_constraints(
+    observables: list[Observable],
+) -> None:
+    rt = ResultsTable(
+        configs=[{"alpha": v} for v in [0.0, 0.25, 0.5, 0.75, 1.0]],
+        scores=np.array([
+            [0.5, 0.0],
+            [0.25, 2.5],
+            [0.0, 5.0],
+            [0.25, 7.5],
+            [0.5, 10.0],
+        ]),
+        observable_names=["error", "cost"],
+    )
+    fn = feasibility_filter([
+        Constraint("low_error", "error", "<=", 0.25),
+        Constraint("low_cost", "cost", "<=", 5.0),
+    ])
+    idx = fn(rt, observables)
+    assert set(idx.tolist()) == {1, 2}
+
+
+def test_feasibility_filter_none_feasible(
+    observables: list[Observable],
+) -> None:
+    rt = ResultsTable(
+        configs=[{"alpha": 0.0}],
+        scores=np.array([[0.5, 0.0]]),
+        observable_names=["error", "cost"],
+    )
+    fn = feasibility_filter([
+        Constraint("impossible", "error", "<", 0.0),
+    ])
+    idx = fn(rt, observables)
+    assert len(idx) == 0
+
+
+def test_feasibility_filter_in_study(
+    world: _ToySimulator,
+    scorer: _ToyScorer,
+    observables: list[Observable],
+) -> None:
+    """feasibility_filter works as a Phase.filter_fn in a Study."""
+    grid = [{"alpha": v} for v in [0.0, 0.25, 0.5, 0.75, 1.0]]
+    study = Study(
+        world=world,
+        scorer=scorer,
+        observables=observables,
+        phases=[
+            Phase(
+                name="screen",
+                grid=grid,
+                filter_fn=feasibility_filter([
+                    Constraint("low_cost", "cost", "<=", 5.0),
+                ]),
+            ),
+            Phase(name="refine", grid="carry"),
+        ],
+    )
+    study.run()
+    final = study.results("refine")
+    # alpha=0.0 (cost=0), alpha=0.25 (cost=2.5), alpha=0.5 (cost=5.0) satisfy cost <= 5
+    assert final.scores.shape[0] == 3
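
Usage sketch (not part of the diff): a minimal, self-contained example of how the new pieces fit together, built from the same shapes the tests above use. The observable names "error" and "cost" and the threshold values are illustrative only.

    import numpy as np

    from trade_study import Constraint, ResultsTable, feasibility_filter

    # Feasibility limits name the observable (or annotation) column they test.
    constraints = [
        Constraint(name="low_error", observable="error", op="<=", threshold=0.25),
        Constraint(name="budget", observable="cost", op="<=", threshold=5.0),
    ]

    # A small table with two observable columns, as in the test fixtures.
    table = ResultsTable(
        configs=[{"alpha": v} for v in (0.0, 0.25, 0.5)],
        scores=np.array([[0.5, 0.0], [0.25, 2.5], [0.0, 5.0]]),
        observable_names=["error", "cost"],
    )

    mask = table.feasible(constraints)                  # array([False,  True,  True])
    keep = feasibility_filter(constraints)(table, [])   # indices of feasible rows: array([1, 2])

Inside a Study, the same filter factory is wired up as Phase.filter_fn on a screening phase, exactly as test_feasibility_filter_in_study does above.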