Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/trade_study/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .io import load_results, save_results
from .protocols import (
Annotation,
Constraint,
Direction,
Observable,
ResultsTable,
Expand All @@ -19,11 +20,18 @@
)
from .runner import run_adaptive, run_grid
from .stacking import ensemble_predict, stack_bayesian, stack_scores
from .study import Phase, Study, top_k_pareto_filter, weighted_sum_filter
from .study import (
Phase,
Study,
feasibility_filter,
top_k_pareto_filter,
weighted_sum_filter,
)
from .viz import plot_calibration, plot_front, plot_parallel, plot_scores

__all__ = [
"Annotation",
"Constraint",
"Direction",
"Factor",
"FactorType",
Expand All @@ -39,6 +47,7 @@
"coverage_curve",
"ensemble_predict",
"extract_front",
"feasibility_filter",
"hypervolume",
"igd_plus",
"load_results",
Expand Down
96 changes: 96 additions & 0 deletions src/trade_study/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

from __future__ import annotations

import operator as _operator
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable

if TYPE_CHECKING:
from collections.abc import Callable

import numpy as np
from numpy.typing import NDArray

Expand Down Expand Up @@ -34,6 +37,61 @@ class Observable:
weight: float = 1.0


_OP_MAP: dict[str, Callable[[Any, Any], bool]] = {
">=": _operator.ge,
"<=": _operator.le,
">": _operator.gt,
"<": _operator.lt,
"==": _operator.eq,
"!=": _operator.ne,
}


@dataclass(frozen=True)
class Constraint:
"""Feasibility constraint on an observable or annotation.

A design is feasible when ``scores[observable] <op> threshold`` is
true.

Attributes:
name: Human-readable label (e.g. ``"min_conversion"``).
observable: Name of the observable or annotation column to test.
op: Comparison operator as a string (``">="`` ``"<="`` ``">"``
``"<"`` ``"=="`` ``"!="``).
threshold: Scalar threshold value.
"""

name: str
observable: str
op: str
threshold: float

def __post_init__(self) -> None:
"""Validate the comparison operator.

Raises:
ValueError: If *op* is not one of the supported operators.
"""
if self.op not in _OP_MAP:
msg = (
f"Constraint {self.name!r}: unsupported operator {self.op!r}. "
f"Use one of {sorted(_OP_MAP)}"
)
raise ValueError(msg)

def check(self, value: float) -> bool:
"""Test whether a scalar value satisfies the constraint.

Args:
value: Scalar score or annotation value to test.

Returns:
``True`` if the value satisfies the constraint.
"""
return bool(_OP_MAP[self.op](value, self.threshold))


@runtime_checkable
class Simulator(Protocol):
"""Protocol for generating ground truth and observations.
Expand Down Expand Up @@ -131,3 +189,41 @@ class ResultsTable:
annotations: NDArray[np.floating[Any]] | None = None # (n_trials, n_annotations)
annotation_names: list[str] = field(default_factory=list)
metadata: list[dict[str, Any]] = field(default_factory=list)

def feasible(self, constraints: list[Constraint]) -> NDArray[np.bool_]:
"""Return a boolean mask indicating which rows satisfy all constraints.

Each constraint references an observable or annotation column by
name. A row is feasible only when **every** constraint evaluates
to ``True``.

Args:
constraints: Constraint objects to evaluate.

Returns:
Boolean array of shape ``(n_trials,)``.

Raises:
KeyError: If a constraint references a column not found in
either ``observable_names`` or ``annotation_names``.
"""
import numpy as np

mask = np.ones(len(self.configs), dtype=np.bool_)
for con in constraints:
if con.observable in self.observable_names:
col_idx = self.observable_names.index(con.observable)
values = self.scores[:, col_idx]
elif (
con.observable in self.annotation_names and self.annotations is not None
):
col_idx = self.annotation_names.index(con.observable)
values = self.annotations[:, col_idx]
else:
msg = (
f"Constraint {con.name!r}: column {con.observable!r} "
f"not found in observables or annotations"
)
raise KeyError(msg)
mask &= _OP_MAP[con.op](values, con.threshold)
return mask
23 changes: 23 additions & 0 deletions src/trade_study/study.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from .protocols import (
Annotation,
Constraint,
Observable,
ResultsTable,
Scorer,
Expand Down Expand Up @@ -144,6 +145,28 @@ def _filter(
return _filter


def feasibility_filter(
constraints: list[Constraint],
) -> Callable[[ResultsTable, list[Observable]], NDArray[np.intp]]:
"""Create a filter that keeps only designs satisfying all constraints.

Args:
constraints: Constraint objects to evaluate against results.

Returns:
Filter function compatible with ``Phase.filter_fn``.
"""

def _filter(
results: ResultsTable,
_observables: list[Observable],
) -> NDArray[np.intp]:
mask = results.feasible(constraints)
return np.nonzero(mask)[0].astype(np.intp)

return _filter


@dataclass
class Study:
"""Multi-phase model criticism study.
Expand Down
134 changes: 133 additions & 1 deletion tests/test_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@
import pytest

from trade_study import Direction, Observable
from trade_study.protocols import Annotation, ResultsTable, Scorer, Simulator
from trade_study.protocols import (
Annotation,
Constraint,
ResultsTable,
Scorer,
Simulator,
)

# -- Direction enum ------------------------------------------------------------

Expand Down Expand Up @@ -298,3 +304,129 @@ def test_results_table_observable_names_order() -> None:
observable_names=names,
)
assert rt.observable_names == ["coverage", "rmse", "wall_time"]


# -- Constraint dataclass ------------------------------------------------------


def test_constraint_creation() -> None:
c = Constraint(name="min_cov", observable="coverage", op=">=", threshold=0.5)
assert c.name == "min_cov"
assert c.observable == "coverage"
assert c.op == ">="
assert c.threshold == pytest.approx(0.5)


def test_constraint_frozen() -> None:
c = Constraint(name="c", observable="x", op=">=", threshold=0.0)
with pytest.raises(AttributeError):
c.threshold = 1.0 # type: ignore[misc]


def test_constraint_invalid_op_raises() -> None:
with pytest.raises(ValueError, match="unsupported operator"):
Constraint(name="bad", observable="x", op="~", threshold=0.0)


@pytest.mark.parametrize(
("op", "value", "threshold", "expected"),
[
(">=", 0.6, 0.5, True),
(">=", 0.5, 0.5, True),
(">=", 0.4, 0.5, False),
("<=", 0.4, 0.5, True),
("<=", 0.5, 0.5, True),
("<=", 0.6, 0.5, False),
(">", 0.6, 0.5, True),
(">", 0.5, 0.5, False),
("<", 0.4, 0.5, True),
("<", 0.5, 0.5, False),
("==", 0.5, 0.5, True),
("==", 0.6, 0.5, False),
("!=", 0.6, 0.5, True),
("!=", 0.5, 0.5, False),
],
)
def test_constraint_check(
op: str, value: float, threshold: float, *, expected: bool
) -> None:
c = Constraint(name="test", observable="x", op=op, threshold=threshold)
assert c.check(value) is expected


# -- ResultsTable.feasible() ---------------------------------------------------


@pytest.fixture
def scored_table() -> ResultsTable:
"""Table with 5 rows, 2 observables (coverage, cost).

Returns:
A ResultsTable with coverage and cost columns.
"""
return ResultsTable(
configs=[{"a": i} for i in range(5)],
scores=np.array([
[0.9, 100.0],
[0.4, 50.0],
[0.6, 80.0],
[0.3, 30.0],
[0.7, 60.0],
]),
observable_names=["coverage", "cost"],
)


def test_feasible_single_constraint(scored_table: ResultsTable) -> None:
constraints = [Constraint("min_cov", "coverage", ">=", 0.5)]
mask = scored_table.feasible(constraints)
assert mask.dtype == np.bool_
expected = np.array([True, False, True, False, True])
np.testing.assert_array_equal(mask, expected)


def test_feasible_multiple_constraints(scored_table: ResultsTable) -> None:
constraints = [
Constraint("min_cov", "coverage", ">=", 0.5),
Constraint("max_cost", "cost", "<=", 80.0),
]
mask = scored_table.feasible(constraints)
expected = np.array([False, False, True, False, True])
np.testing.assert_array_equal(mask, expected)


def test_feasible_all_pass(scored_table: ResultsTable) -> None:
constraints = [Constraint("low_bar", "coverage", ">=", 0.0)]
mask = scored_table.feasible(constraints)
assert mask.all()


def test_feasible_none_pass(scored_table: ResultsTable) -> None:
constraints = [Constraint("high_bar", "coverage", ">", 1.0)]
mask = scored_table.feasible(constraints)
assert not mask.any()


def test_feasible_annotation_column() -> None:
rt = ResultsTable(
configs=[{"a": 1}, {"a": 2}],
scores=np.array([[0.5], [0.6]]),
observable_names=["rmse"],
annotations=np.array([[10.0], [50.0]]),
annotation_names=["dollar_cost"],
)
constraints = [Constraint("budget", "dollar_cost", "<=", 20.0)]
mask = rt.feasible(constraints)
np.testing.assert_array_equal(mask, np.array([True, False]))


def test_feasible_unknown_column_raises(scored_table: ResultsTable) -> None:
constraints = [Constraint("bad", "nonexistent", ">=", 0.0)]
with pytest.raises(KeyError, match="nonexistent"):
scored_table.feasible(constraints)


def test_feasible_empty_constraints(scored_table: ResultsTable) -> None:
mask = scored_table.feasible([])
assert mask.all()
assert len(mask) == 5
Loading
Loading