diff --git a/src/trade_study/design.py b/src/trade_study/design.py index 39c577d..f407f65 100644 --- a/src/trade_study/design.py +++ b/src/trade_study/design.py @@ -228,8 +228,10 @@ def screen( run_fn: Callable that takes a config dict and returns a dict of observable name → scalar score. factors: List of continuous factors to screen. - method: Screening method ("morris" or "sobol"). - n_trajectories: Number of Morris trajectories or Sobol samples. + method: Screening method (``"morris"`` or ``"sobol"``). + n_trajectories: Number of Morris trajectories. For Sobol, this + controls the base sample size *N*; the total number of model + evaluations is *N* x (num_vars + 2). seed: Random seed. Returns: @@ -237,16 +239,9 @@ def screen( (mu_star for Morris, S1 for Sobol), one value per factor. Raises: - NotImplementedError: If method is not "morris". - ValueError: If no continuous factors are provided. + ValueError: If *method* is unknown or no continuous factors are + provided. """ - from SALib.analyze import morris as morris_analyze # type: ignore[import-untyped] - from SALib.sample import morris as morris_sample # type: ignore[import-untyped] - - if method != "morris": - msg = f"Screening method {method!r} not yet implemented" - raise NotImplementedError(msg) - continuous = [f for f in factors if f.factor_type == FactorType.CONTINUOUS] if not continuous: msg = "Screening requires at least one continuous factor" @@ -257,9 +252,32 @@ def screen( "names": [f.name for f in continuous], "bounds": [list(f.bounds) for f in continuous if f.bounds is not None], } + + if method == "morris": + return _screen_morris(run_fn, problem, n_trajectories, seed) + if method == "sobol": + return _screen_sobol(run_fn, problem, n_trajectories, seed) + + msg = f"Unknown screening method: {method!r}" + raise ValueError(msg) + + +def _screen_morris( + run_fn: Callable[[dict[str, Any]], dict[str, float]], + problem: dict[str, Any], + n_trajectories: int, + seed: int, +) -> dict[str, 
def _screen_sobol(
    run_fn: Callable[[dict[str, Any]], dict[str, float]],
    problem: dict[str, Any],
    n_samples: int,
    seed: int,
) -> dict[str, NDArray[np.floating[Any]]]:
    """Sobol variance-based sensitivity analysis (first-order indices only).

    Only ``S1`` is reported, so second-order terms are switched off in both
    the sampler and the analyzer. This keeps the cost at
    ``n_samples * (num_vars + 2)`` calls to *run_fn* instead of the
    ``n_samples * (2 * num_vars + 2)`` that SALib's default second-order
    design requires.

    Args:
        run_fn: Callable that takes a config dict (factor name -> sampled
            value) and returns a dict of observable name -> scalar score.
        problem: SALib problem definition. Its ``"names"`` entry is used
            here to label the sampled values; the remaining keys are
            consumed by SALib.
        n_samples: Base sample size *N* handed to the Sobol sampler.
        seed: Seed forwarded to both the sampler and the analyzer.

    Returns:
        Mapping from observable name to S1 (first-order) index array, one
        value per factor.
    """
    from SALib.analyze import sobol as sobol_analyze
    from SALib.sample import sobol as sobol_sample

    # The sampler and the analyzer must agree on calc_second_order: the
    # analyzer derives the expected number of model outputs from it.
    param_values = sobol_sample.sample(
        problem,
        n_samples,
        calc_second_order=False,
        seed=seed,
    )

    # Evaluate the model once per sample row, collecting each observable
    # into its own column of outputs.
    results_by_obs: dict[str, list[float]] = {}
    for row in param_values:
        cfg = dict(zip(problem["names"], row, strict=True))
        for obs_name, val in run_fn(cfg).items():
            results_by_obs.setdefault(obs_name, []).append(val)

    importance: dict[str, NDArray[np.floating[Any]]] = {}
    for obs_name, vals in results_by_obs.items():
        si = sobol_analyze.analyze(
            problem,
            np.array(vals),
            calc_second_order=False,
            seed=seed,
        )
        importance[obs_name] = np.asarray(si["S1"], dtype=np.float64)

    return importance
test_screen_rejects_unknown_method() -> None: factors = [Factor("x", FactorType.CONTINUOUS, bounds=(0.0, 1.0))] - with pytest.raises(NotImplementedError, match="not yet implemented"): - screen(lambda _c: {"y": 0.0}, factors, method="sobol") + with pytest.raises(ValueError, match="Unknown screening method"): + screen(lambda _c: {"y": 0.0}, factors, method="bogus") def test_screen_rejects_no_continuous() -> None: @@ -326,6 +326,70 @@ def test_screen_rejects_no_continuous() -> None: screen(lambda _c: {"y": 0.0}, factors) +# --------------------------------------------------------------------------- +# screen — Sobol (#76) +# --------------------------------------------------------------------------- + + +def test_screen_sobol_returns_dict(continuous_factors: list[Factor]) -> None: + result = screen( + _linear_model, + continuous_factors, + method="sobol", + n_trajectories=64, + seed=0, + ) + assert isinstance(result, dict) + assert "y" in result + + +def test_screen_sobol_importance_shape(continuous_factors: list[Factor]) -> None: + result = screen( + _linear_model, + continuous_factors, + method="sobol", + n_trajectories=64, + seed=0, + ) + assert result["y"].shape == (2,) + + +def test_screen_sobol_detects_influential_factor( + continuous_factors: list[Factor], +) -> None: + """Sobol S1 for alpha should dominate; beta should be near zero.""" + result = screen( + _linear_model, + continuous_factors, + method="sobol", + n_trajectories=256, + seed=0, + ) + assert result["y"][0] > result["y"][1] + assert result["y"][1] == pytest.approx(0.0, abs=0.1) + + +def test_screen_sobol_multiple_observables( + continuous_factors: list[Factor], +) -> None: + def multi_obs(cfg: dict[str, Any]) -> dict[str, float]: + return { + "obs1": cfg["alpha"], + "obs2": cfg["beta"], + } + + result = screen( + multi_obs, + continuous_factors, + method="sobol", + n_trajectories=64, + seed=0, + ) + assert set(result.keys()) == {"obs1", "obs2"} + assert result["obs1"].shape == (2,) + assert 
result["obs2"].shape == (2,) + + # --------------------------------------------------------------------------- # reduce_factors (#10) # ---------------------------------------------------------------------------