Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 65 additions & 12 deletions src/trade_study/design.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,25 +228,20 @@ def screen(
run_fn: Callable that takes a config dict and returns a dict of
observable name → scalar score.
factors: List of continuous factors to screen.
method: Screening method ("morris" or "sobol").
n_trajectories: Number of Morris trajectories or Sobol samples.
method: Screening method (``"morris"`` or ``"sobol"``).
n_trajectories: Number of Morris trajectories. For Sobol, this
controls the base sample size *N*; the total number of model
evaluations is *N* x (num_vars + 2).
seed: Random seed.

Returns:
Dictionary mapping observable names to arrays of factor importance
(mu_star for Morris, S1 for Sobol), one value per factor.

Raises:
NotImplementedError: If method is not "morris".
ValueError: If no continuous factors are provided.
ValueError: If *method* is unknown or no continuous factors are
provided.
"""
from SALib.analyze import morris as morris_analyze # type: ignore[import-untyped]
from SALib.sample import morris as morris_sample # type: ignore[import-untyped]

if method != "morris":
msg = f"Screening method {method!r} not yet implemented"
raise NotImplementedError(msg)

continuous = [f for f in factors if f.factor_type == FactorType.CONTINUOUS]
if not continuous:
msg = "Screening requires at least one continuous factor"
Expand All @@ -257,9 +252,32 @@ def screen(
"names": [f.name for f in continuous],
"bounds": [list(f.bounds) for f in continuous if f.bounds is not None],
}

if method == "morris":
return _screen_morris(run_fn, problem, n_trajectories, seed)
if method == "sobol":
return _screen_sobol(run_fn, problem, n_trajectories, seed)

msg = f"Unknown screening method: {method!r}"
raise ValueError(msg)


def _screen_morris(
run_fn: Callable[[dict[str, Any]], dict[str, float]],
problem: dict[str, Any],
n_trajectories: int,
seed: int,
) -> dict[str, NDArray[np.floating[Any]]]:
"""Morris elementary-effects screening.

Returns:
Mapping from observable name to mu_star array.
"""
from SALib.analyze import morris as morris_analyze # type: ignore[import-untyped]
from SALib.sample import morris as morris_sample # type: ignore[import-untyped]

param_values = morris_sample.sample(problem, n_trajectories, seed=seed)

# Evaluate model at each sample point
results_by_obs: dict[str, list[float]] = {}
for row in param_values:
cfg = dict(zip(problem["names"], row, strict=True))
Expand All @@ -280,6 +298,41 @@ def screen(
return importance


def _screen_sobol(
run_fn: Callable[[dict[str, Any]], dict[str, float]],
problem: dict[str, Any],
n_samples: int,
seed: int,
) -> dict[str, NDArray[np.floating[Any]]]:
"""Sobol variance-based sensitivity analysis.

Returns:
Mapping from observable name to S1 (first-order) index array.
"""
from SALib.analyze import sobol as sobol_analyze
from SALib.sample import sobol as sobol_sample

param_values = sobol_sample.sample(problem, n_samples, seed=seed)

results_by_obs: dict[str, list[float]] = {}
for row in param_values:
cfg = dict(zip(problem["names"], row, strict=True))
scores = run_fn(cfg)
for obs_name, val in scores.items():
results_by_obs.setdefault(obs_name, []).append(val)

importance: dict[str, NDArray[np.floating[Any]]] = {}
for obs_name, vals in results_by_obs.items():
si = sobol_analyze.analyze(
problem,
np.array(vals),
seed=seed,
)
importance[obs_name] = np.asarray(si["S1"], dtype=np.float64)

return importance


def reduce_factors(
factors: list[Factor],
importance: dict[str, NDArray[np.floating[Any]]],
Expand Down
70 changes: 67 additions & 3 deletions tests/test_design.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,10 +314,10 @@ def multi_obs(cfg: dict[str, Any]) -> dict[str, float]:
assert result["obs2"].shape == (2,)


def test_screen_rejects_non_morris() -> None:
def test_screen_rejects_unknown_method() -> None:
factors = [Factor("x", FactorType.CONTINUOUS, bounds=(0.0, 1.0))]
with pytest.raises(NotImplementedError, match="not yet implemented"):
screen(lambda _c: {"y": 0.0}, factors, method="sobol")
with pytest.raises(ValueError, match="Unknown screening method"):
screen(lambda _c: {"y": 0.0}, factors, method="bogus")


def test_screen_rejects_no_continuous() -> None:
Expand All @@ -326,6 +326,70 @@ def test_screen_rejects_no_continuous() -> None:
screen(lambda _c: {"y": 0.0}, factors)


# ---------------------------------------------------------------------------
# screen — Sobol (#76)
# ---------------------------------------------------------------------------


def test_screen_sobol_returns_dict(continuous_factors: list[Factor]) -> None:
result = screen(
_linear_model,
continuous_factors,
method="sobol",
n_trajectories=64,
seed=0,
)
assert isinstance(result, dict)
assert "y" in result


def test_screen_sobol_importance_shape(continuous_factors: list[Factor]) -> None:
result = screen(
_linear_model,
continuous_factors,
method="sobol",
n_trajectories=64,
seed=0,
)
assert result["y"].shape == (2,)


def test_screen_sobol_detects_influential_factor(
continuous_factors: list[Factor],
) -> None:
"""Sobol S1 for alpha should dominate; beta should be near zero."""
result = screen(
_linear_model,
continuous_factors,
method="sobol",
n_trajectories=256,
seed=0,
)
assert result["y"][0] > result["y"][1]
assert result["y"][1] == pytest.approx(0.0, abs=0.1)


def test_screen_sobol_multiple_observables(
continuous_factors: list[Factor],
) -> None:
def multi_obs(cfg: dict[str, Any]) -> dict[str, float]:
return {
"obs1": cfg["alpha"],
"obs2": cfg["beta"],
}

result = screen(
multi_obs,
continuous_factors,
method="sobol",
n_trajectories=64,
seed=0,
)
assert set(result.keys()) == {"obs1", "obs2"}
assert result["obs1"].shape == (2,)
assert result["obs2"].shape == (2,)


# ---------------------------------------------------------------------------
# reduce_factors (#10)
# ---------------------------------------------------------------------------
Expand Down
Loading