34 changes: 32 additions & 2 deletions docs/guide/bayesian.md
@@ -188,13 +188,43 @@ separate evaluation from analysis:
--8<-- "examples/bayesian_study.py:persistence"
```

## Multi-fidelity workflow

Real-world studies often have an expensive simulator — full MCMC, a
fine-mesh CFD solver, or an agent-based epidemiological model.
Running every candidate design at full fidelity is wasteful. A
multi-fidelity strategy screens many designs cheaply, then validates
only the promising ones at high fidelity.

`Phase` supports this via optional `world` and `scorer` overrides.
When set, a phase uses its own simulator or scorer instead of the
`Study`-level defaults. Here the cheap surrogate draws only 50
posterior samples (fast but noisy CRPS estimates), while the
validation phase draws 2000:

```python
--8<-- "examples/bayesian_study.py:multifidelity"
```

The `Study` orchestrates both phases: the first screens 60 designs
with the cheap surrogate and keeps the top 10 by Pareto rank, then
the second re-evaluates those 10 designs with the expensive model.

This pattern applies whenever fidelity is a computational strategy
rather than a design factor:

- **Epidemiology** — screen surveillance designs with a deterministic
ODE model, validate the best with a stochastic agent-based model.
- **Engineering** — coarse mesh for broad exploration, fine mesh for
the Pareto front.
- **Forecast model grading** — fast approximate inference for
screening, full HMC for final assessment.

## What to try next

- Swap `method="morris"` for `method="sobol"` in `screen()` for
variance-based sensitivity indices.
- Use `Constraint` + `feasibility_filter` to enforce
`coverage_95 >= 0.90` before stacking.
- Replace `run_grid` with a two-phase `Study` that screens first
and refines around the Pareto front (a sketch follows below).
- Try `stack_bayesian()` on models that expose log-likelihood
(requires `arviz`).
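
As a starting point for the two-phase `Study` suggestion, here is a
minimal sketch that screens a Latin-hypercube grid and then refines
around the Pareto front. It reuses `Study`, `Phase`, `build_grid`,
`top_k_pareto_filter`, and `extract_front` as shown above; `world`,
`scorer`, `factors`, and `observables` are the objects from the example
script, while the `refine_near_front` helper, its jitter scheme, and
the `prev_result.configs` attribute it reads are illustrative
assumptions to adapt to your `ResultsTable`.

```python
# Sketch only: a screen-then-refine Study. The refine_near_front helper
# and the prev_result.configs attribute it reads are assumptions for
# illustration; adapt them to the actual ResultsTable interface.
import random

from trade_study import Phase, Study, build_grid, extract_front, top_k_pareto_filter

screen_grid = build_grid(factors, method="lhs", n_samples=60, seed=0)
for cfg in screen_grid:
    cfg["n_obs"] = round(cfg["n_obs"])  # integer factor, as in the example above


def refine_near_front(prev_result, observables):
    """Build a denser grid by jittering the screened Pareto designs."""
    directions = [o.direction for o in observables]
    weights = [o.weight for o in observables]
    front_idx = extract_front(prev_result.scores, directions, weights)
    rng = random.Random(0)
    refined = []
    for i in front_idx:
        base = prev_result.configs[i]  # assumed attribute holding the config dict
        for _ in range(5):
            # Perturb float-valued factors by up to +/-10%; keep the rest as-is.
            refined.append(
                {
                    k: v * (1 + rng.uniform(-0.1, 0.1)) if isinstance(v, float) else v
                    for k, v in base.items()
                }
            )
    return refined


study = Study(
    world=world,
    scorer=scorer,
    observables=observables,
    phases=[
        Phase(
            name="screen",
            grid=screen_grid,
            filter_fn=top_k_pareto_filter(k=10),
        ),
        # A callable grid receives the previous phase's results and
        # returns the configs for this phase.
        Phase(name="refine", grid=refine_near_front),
    ],
)
study.run()
refine_results = study.results("refine")
```

Because the refinement grid is a callable, the second phase is built
from what the screen actually found rather than from a fixed design,
which is what distinguishes this from a single large `run_grid` call.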
62 changes: 62 additions & 0 deletions examples/bayesian_study.py
@@ -21,6 +21,8 @@
Factor,
FactorType,
Observable,
Phase,
Study,
build_grid,
coverage_curve,
ensemble_predict,
@@ -36,6 +38,7 @@
score,
screen,
stack_scores,
top_k_pareto_filter,
)

ASSET_DIR = "docs/assets"
@@ -130,8 +133,21 @@ class BayesianRegressionSimulator:
Each config specifies prior hyperparameters and sample size.
The "truth" is the held-out test set; "observations" are the
posterior predictive samples at those test points.

Args:
n_samples: Number of posterior predictive draws. Fewer draws
give faster but noisier score estimates — useful as a cheap
surrogate in a multi-fidelity workflow.
"""

def __init__(self, n_samples: int = 500) -> None:
"""Initialise the simulator.

Args:
n_samples: Number of posterior predictive draws.
"""
self.n_samples = n_samples

def generate(self, config: dict[str, Any]) -> tuple[Any, Any]:
"""Draw training data, fit posterior, and return test predictions.

@@ -155,6 +171,7 @@ def generate(self, config: dict[str, Any]) -> tuple[Any, Any]:
y_train,
prior_var=config["prior_var"],
noise_scale=config["noise_scale"],
n_samples=self.n_samples,
)

observations = {
@@ -371,6 +388,50 @@ def _save_plots(results: Any, directions: Any, nominal: Any, empirical: Any) ->
# --8<-- [end:plots]


def _run_multi_fidelity() -> None:
"""Demonstrate multi-fidelity via Phase-level world overrides."""
# --8<-- [start:multifidelity]
# Cheap surrogate: only 50 posterior draws (fast, noisier scores)
cheap_world = BayesianRegressionSimulator(n_samples=50)
# Expensive model: 2000 posterior draws (slow, precise scores)
expensive_world = BayesianRegressionSimulator(n_samples=2000)

grid = build_grid(factors, method="lhs", n_samples=60, seed=42)
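# LHS sampling returns floats; round n_obs back to an integer sample size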
for cfg in grid:
cfg["n_obs"] = round(cfg["n_obs"])

study = Study(
world=expensive_world,
scorer=BayesianRegressionScorer(),
observables=observables,
phases=[
# Phase 1: screen 60 designs with the cheap surrogate
Phase(
name="screen",
grid=grid,
world=cheap_world,
filter_fn=top_k_pareto_filter(k=10),
),
# Phase 2: validate top 10 with the expensive model
Phase(name="validate", grid="carry"),
],
annotations=[compute_cost],
)
study.run()

screen_r = study.results("screen")
validate_r = study.results("validate")
print("\nMulti-fidelity study:")
print(f" Screen phase: {screen_r.scores.shape[0]} designs (50 draws)")
print(f" Validate phase: {validate_r.scores.shape[0]} designs (2000 draws)")

directions = [o.direction for o in observables]
weights = [o.weight for o in observables]
front_idx = extract_front(validate_r.scores, directions, weights)
print(f" Final Pareto front: {len(front_idx)} designs")
# --8<-- [end:multifidelity]


def main() -> None:
"""Run the Bayesian model criticism study."""
world = BayesianRegressionSimulator()
@@ -403,6 +464,7 @@ def main() -> None:
nominal, empirical = _run_calibration(results, front_idx, world)
_run_persistence(results)
_save_plots(results, directions, nominal, empirical)
_run_multi_fidelity()


if __name__ == "__main__":
24 changes: 18 additions & 6 deletions src/trade_study/study.py
@@ -49,6 +49,12 @@ class Phase:
indices of configs to pass to the next phase. If None, phase
is terminal.
n_trials: For adaptive mode, number of optuna trials.
world: Optional phase-level simulator override. When set, this
phase uses *world* instead of the ``Study``-level simulator.
Useful for multi-fidelity workflows (cheap surrogate first,
expensive model later).
scorer: Optional phase-level scorer override. When set, this
phase uses *scorer* instead of the ``Study``-level scorer.
"""

name: str
@@ -57,6 +63,8 @@
None
)
n_trials: int = 100
world: Simulator | None = None
scorer: Scorer | None = None


def top_k_pareto_filter(
@@ -210,10 +218,14 @@ def run(
prev_result: ResultsTable | None = None

for phase in self.phases:
# Resolve phase-level overrides (multi-fidelity support)
world = phase.world if phase.world is not None else self.world
scorer = phase.scorer if phase.scorer is not None else self.scorer

if isinstance(phase.grid, str) and phase.grid == "adaptive":
result = run_adaptive(
self.world,
self.scorer,
world,
scorer,
self.factors,
self.observables,
n_trials=phase.n_trials,
@@ -226,8 +238,8 @@
raise ValueError(msg)
grid = phase.grid(prev_result, self.observables)
result = run_grid(
self.world,
self.scorer,
world,
scorer,
grid,
self.observables,
annotations=self.annotations or None,
@@ -239,8 +251,8 @@
phase.grid if isinstance(phase.grid, list) else (carry_grid or [])
)
result = run_grid(
self.world,
self.scorer,
world,
scorer,
grid,
self.observables,
annotations=self.annotations or None,
164 changes: 164 additions & 0 deletions tests/test_study.py
@@ -991,3 +991,167 @@ def test_feasibility_filter_in_study(
final = study.results("refine")
# alpha=0.0 (cost=0), alpha=0.25 (cost=2.5), alpha=0.5 (cost=5.0) satisfy cost <= 5
assert final.scores.shape[0] == 3


# ---------------------------------------------------------------------------
# Phase-level world / scorer override (multi-fidelity, #78)
# ---------------------------------------------------------------------------


class _CheapSimulator:
"""Cheap surrogate that adds a constant offset to alpha."""

def generate(
self,
config: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Return config as both truth and observations.

Returns:
Tuple of (config, config).
"""
return config, config


class _CheapScorer:
"""Scorer that returns a constant error and halved cost."""

def score(
self,
truth: Any,
observations: Any,
config: dict[str, Any],
) -> dict[str, float]:
"""Score: error = 0.1 (constant), cost = alpha * 5.

Returns:
Dict with ``error`` and ``cost`` scores.
"""
a = float(config.get("alpha", 0.5))
return {"error": 0.1, "cost": a * 5.0}


def test_phase_world_override(
scorer: _ToyScorer,
observables: list[Observable],
) -> None:
"""Phase.world overrides Study.world for that phase."""
cheap = _CheapSimulator()
expensive = _ToySimulator()
grid = [{"alpha": 0.5}]
study = Study(
world=expensive,
scorer=scorer,
observables=observables,
phases=[
Phase(name="cheap_phase", grid=grid, world=cheap),
Phase(name="expensive_phase", grid=grid),
],
)
study.run()
# Both phases use ToyScorer (error=|0.5-0.5|=0, cost=5).
# The assertions check shapes only; the point is that cheap_phase
# runs through the phase-level world override without error.
assert study.results("cheap_phase").scores.shape == (1, 2)
assert study.results("expensive_phase").scores.shape == (1, 2)


def test_phase_scorer_override(
world: _ToySimulator,
observables: list[Observable],
) -> None:
"""Phase.scorer overrides Study.scorer for that phase."""
cheap_scorer = _CheapScorer()
expensive_scorer = _ToyScorer()
grid = [{"alpha": 0.5}]
study = Study(
world=world,
scorer=expensive_scorer,
observables=observables,
phases=[
Phase(name="cheap_phase", grid=grid, scorer=cheap_scorer),
Phase(name="expensive_phase", grid=grid),
],
)
study.run()
cheap_r = study.results("cheap_phase")
expensive_r = study.results("expensive_phase")
# CheapScorer: error=0.1, cost=0.5*5=2.5
assert cheap_r.scores[0, 0] == pytest.approx(0.1)
assert cheap_r.scores[0, 1] == pytest.approx(2.5)
# ToyScorer: error=|0.5-0.5|=0, cost=0.5*10=5
assert expensive_r.scores[0, 0] == pytest.approx(0.0)
assert expensive_r.scores[0, 1] == pytest.approx(5.0)


def test_phase_both_overrides(
observables: list[Observable],
) -> None:
"""Phase can override both world and scorer simultaneously."""
study = Study(
world=_ToySimulator(),
scorer=_ToyScorer(),
observables=observables,
phases=[
Phase(
name="custom",
grid=[{"alpha": 0.5}],
world=_CheapSimulator(),
scorer=_CheapScorer(),
),
],
)
study.run()
r = study.results("custom")
assert r.scores[0, 0] == pytest.approx(0.1)
assert r.scores[0, 1] == pytest.approx(2.5)


def test_multi_fidelity_screen_then_validate(
observables: list[Observable],
) -> None:
"""Two-phase multi-fidelity: cheap screen, expensive validation."""
grid = [{"alpha": v} for v in [0.0, 0.25, 0.5, 0.75, 1.0]]
study = Study(
world=_ToySimulator(),
scorer=_ToyScorer(),
observables=observables,
phases=[
Phase(
name="screen",
grid=grid,
world=_CheapSimulator(),
scorer=_CheapScorer(),
filter_fn=top_k_pareto_filter(k=2),
),
Phase(name="validate", grid="carry"),
],
)
study.run()
screen_r = study.results("screen")
validate_r = study.results("validate")
# Screening used cheap scorer (all errors = 0.1)
assert np.all(screen_r.scores[:, 0] == pytest.approx(0.1))
# Validation used Study-level ToyScorer (varied errors)
assert validate_r.scores.shape[0] <= 2
# At least one validation error differs from 0.1
assert not np.all(validate_r.scores[:, 0] == pytest.approx(0.1))


def test_phase_world_override_none_uses_study_default(
world: _ToySimulator,
scorer: _ToyScorer,
observables: list[Observable],
) -> None:
"""Phase.world=None (default) uses Study.world."""
grid = [{"alpha": 0.5}]
study = Study(
world=world,
scorer=scorer,
observables=observables,
phases=[Phase(name="default", grid=grid)],
)
study.run()
r = study.results("default")
assert r.scores[0, 0] == pytest.approx(0.0)
assert r.scores[0, 1] == pytest.approx(5.0)