2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -33,7 +33,7 @@ jobs:
           XLA_FLAGS: "--xla_force_host_platform_device_count=4"
           JAX_PLATFORM_NAME: "cpu"
         run: |
-          pytest --cov --cov-branch --cov-report=xml
+          pytest --cov --cov-branch --cov-report=xml -m "not timing"
 
       - name: Upload results to Codecov
         uses: codecov/codecov-action@v5
2 changes: 1 addition & 1 deletion README.md
@@ -33,7 +33,7 @@ pip install --upgrade "jax[cpu]"
 ```
 ## 🥜 In a nutshell
 
-All optimizers follow the same stateless pattern: `Optimizer.init` returns a `(state, optimizer)` pair, and `optimizer.optimize` runs the search loop. Your objective function must have the signature `fn(key, params) -> scalar`.
+All optimizers follow the same stateless pattern: `Optimizer.init` returns a `(state, optimizer)` pair, and `optimizer.optimize` runs the search loop. Your objective function must have the signature `fn(key, params) -> scalar`. `params` can be any PyTree.
 
 ```python
 import jax
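For reviewers who want the README's pattern end to end: below is a minimal sketch of the API described above, not the README's own (truncated) example. The search space and objective are illustrative, and the module paths are assumptions inferred from the aliases used in `tests/test_bayes.py` in this PR.

```python
import jax

from hyperoptax import bayesian
from hyperoptax import spaces as sp  # assumed module path; tests alias it as `sp`

# Hypothetical 1-D search space; `params` can be any PyTree.
space = {"x": sp.LinearSpace(0.0, 1.0)}

# Objective signature: fn(key, params) -> scalar.
def objective(key, params):
    return -(params["x"] ** 2)  # maximised at x = 0

# Optimizer.init returns a (state, optimizer) pair...
state, optimizer = bayesian.BayesianSearch.init(space, n_max=20)

# ...and optimizer.optimize runs the search loop, returning the final
# state plus the history of evaluated params and results.
state, (params_hist, results_hist) = optimizer.optimize(
    state, jax.random.PRNGKey(0), objective
)
print(optimizer.best_result(state))
```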
9,393 changes: 9,393 additions & 0 deletions notebooks/bo_design_study.ipynb

Large diffs are not rendered by default.

495 changes: 495 additions & 0 deletions notebooks/high_dimensional.ipynb

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions pyproject.toml
@@ -54,6 +54,11 @@ docs = [
     "sphinx-autobuild"
 ]
 
+[tool.pytest.ini_options]
+markers = [
+    "timing: performance/throughput benchmarks (deselect with '-m not timing')",
+]
+
 [tool.ruff]
 line-length = 88
 target-version = "py310"
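The new marker section pairs with the CI change in `.github/workflows/test.yml` above. As a sketch (hypothetical test name; standard `pytest.mark` usage), a benchmark opts in like this and is deselected by `pytest -m "not timing"`:

```python
import pytest


@pytest.mark.timing  # marker registered in [tool.pytest.ini_options] above
def test_optimize_throughput():
    """Hypothetical benchmark; skipped in CI via -m "not timing"."""
    ...
```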
20 changes: 10 additions & 10 deletions src/hyperoptax/bayesian.py
@@ -41,19 +41,19 @@ class BayesianSearchState(base.OptimizerState):
 class BayesianSearch(base.Optimizer):
     """Bayesian optimisation with a Gaussian Process surrogate.
 
-    Uses a GP (Matérn 2.5 kernel by default) to model the objective and
+    Uses a GP (Matérn 0.5 kernel by default) to model the objective and
     selects the next batch of candidates by maximising an acquisition function
-    (EI by default). ARD length scales are tuned with Adam each iteration.
+    (PI by default). ARD length scales are tuned with Adam each iteration.
     Parallel batches are generated via the Kriging Believer hallucination
     strategy.
 
     Attributes:
         jitter: Small diagonal added to the kernel matrix for numerical
             stability (default ``1e-6``).
         kernel: Kernel function (default :class:`~hyperoptax.kernels.Matern`
-            with ``nu=2.5``).
+            with ``nu=0.5``).
         acquisition: Acquisition function (default
-            :class:`~hyperoptax.acquisition.EI` with ``xi=0.01``).
+            :class:`~hyperoptax.acquisition.PI` with ``xi=0.01``).
         n_candidates: Number of random candidates sampled per iteration for
             the discrete pre-selection step (default ``1000``).
         n_restarts: Number of L-BFGS restarts seeded from the top candidates
@@ -64,27 +64,27 @@ class BayesianSearch(base.Optimizer):
         n_warmup: Number of pure-random iterations before the GP is used
             (default ``1``).
         maximize: Set ``False`` to minimise the objective (default ``True``).
-        n_parallel: Number of parallel candidates per iteration (default ``1``).
+        n_parallel: Number of parallel candidates per iteration (default ``4``).
         hallucination: Hallucination strategy for Kriging Believer parallel
-            selection (default :class:`~hyperoptax.acquisition.MeanHallucination`).
+            selection (default :class:`~hyperoptax.acquisition.SampleHallucination`).
     """
 
     jitter: float = 1e-6
     kernel: kernels.BaseKernel = dataclasses.field(
-        default_factory=lambda: kernels.Matern(length_scale=1.0, nu=2.5)
+        default_factory=lambda: kernels.Matern(length_scale=1.0, nu=0.5)
     )
     acquisition: acq.BaseAcquisition = dataclasses.field(
-        default_factory=lambda: acq.EI(xi=0.01)
+        default_factory=lambda: acq.PI(xi=0.01)
     )
     n_candidates: int = 1000  # random candidates sampled for continuous spaces
     n_restarts: int = 2  # number of L-BFGS restarts (seeded from top candidates)
     n_lbfgs_steps: int = 10  # gradient steps per restart
     n_hparam_steps: int = 20  # Adam steps to tune log_length_scale each iteration
     n_warmup: int = 1  # pure-random evaluations before GP kicks in
     maximize: bool = True  # set False to minimize the objective
-    n_parallel: int = 1
+    n_parallel: int = 4
     hallucination: acq.BaseHallucination = dataclasses.field(
-        default_factory=acq.MeanHallucination
+        default_factory=acq.SampleHallucination
    )
 
     @classmethod
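Taken together, this file's hunks change the out-of-the-box configuration (Matérn `nu=0.5`, PI acquisition, `n_parallel=4`, sample-based hallucination). A hedged sketch of pinning either configuration from caller code follows; the module paths are assumptions based on this diff's identifiers, and it is assumed that `BayesianSearch.init` forwards the `kernel` and `hallucination` keywords the same way the tests below pass `n_parallel` and `acquisition`:

```python
from hyperoptax import acquisition as acq  # assumed path for acq.*
from hyperoptax import bayesian, kernels
from hyperoptax import spaces as sp  # assumed path for sp.*

space = {"x": sp.LinearSpace(0.0, 1.0)}

# Post-PR defaults: Matern(nu=0.5), PI(xi=0.01), n_parallel=4,
# SampleHallucination.
state, optimizer = bayesian.BayesianSearch.init(space, n_max=20)

# Pre-PR behaviour, restored by overriding each changed default:
state, optimizer = bayesian.BayesianSearch.init(
    space,
    n_max=20,
    kernel=kernels.Matern(length_scale=1.0, nu=2.5),
    acquisition=acq.EI(xi=0.01),
    n_parallel=1,
    hallucination=acq.MeanHallucination(),
)
```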
40 changes: 20 additions & 20 deletions tests/test_bayes.py
@@ -315,8 +315,8 @@ def test_2d_space(self):
         params = optimizer.get_next_params(state, self.key)
         assert "x" in params
         assert "y" in params
-        assert params["x"].shape == (1,)
-        assert params["y"].shape == (1,)
+        assert params["x"].shape == (4,)
+        assert params["y"].shape == (4,)
 
     def test_n_parallel_discrete(self):
         space = {"x": sp.DiscreteSpace([0.0, 0.25, 0.5, 0.75, 1.0])}
@@ -328,7 +328,7 @@ class TestBayesianSearchOptimize:
 class TestBayesianSearchOptimize:
     def test_optimize_returns_correct_shapes(self):
         space = {"x": sp.DiscreteSpace([0.0, 0.25, 0.5, 0.75, 1.0])}
-        state, optimizer = bayesian.BayesianSearch.init(space, n_max=5)
+        state, optimizer = bayesian.BayesianSearch.init(space, n_max=5, n_parallel=1)
         func = lambda key, config: -(config["x"] ** 2)
         state, (params_hist, results_hist) = optimizer.optimize(
             state, jax.random.PRNGKey(0), func
@@ -338,14 +338,14 @@ def test_optimize_returns_correct_shapes(self):
 
     def test_optimize_fills_state(self):
         space = {"x": sp.DiscreteSpace([0.0, 0.25, 0.5, 0.75, 1.0])}
-        state, optimizer = bayesian.BayesianSearch.init(space, n_max=5)
+        state, optimizer = bayesian.BayesianSearch.init(space, n_max=5, n_parallel=1)
         func = lambda key, config: -(config["x"] ** 2)
         state, _ = optimizer.optimize(state, jax.random.PRNGKey(0), func)
         assert int(state.mask.sum()) == 5
 
     def test_optimize_finds_optimum(self):
         space = {"x": sp.DiscreteSpace([0.0, 0.25, 0.5, 0.75, 1.0])}
-        state, optimizer = bayesian.BayesianSearch.init(space, n_max=20)
+        state, optimizer = bayesian.BayesianSearch.init(space, n_max=20, n_parallel=1)
         func = lambda key, config: -(config["x"] ** 2)
         state, (params_hist, results_hist) = optimizer.optimize(
             state, jax.random.PRNGKey(0), func
@@ -355,7 +355,7 @@ def test_optimize_fills_state(self):
 
     def test_optimize_with_array_result(self):
         space = {"x": sp.DiscreteSpace([0.0, 0.25, 0.5, 0.75, 1.0])}
-        state, optimizer = bayesian.BayesianSearch.init(space, n_max=3)
+        state, optimizer = bayesian.BayesianSearch.init(space, n_max=3, n_parallel=1)
         func = lambda key, config: jnp.array([-(config["x"] ** 2)])
         state, (_, results_hist) = optimizer.optimize(
             state, jax.random.PRNGKey(0), func
@@ -365,15 +365,15 @@ def test_optimize_with_array_result(self):
     def test_optimize_converges_toward_optimum(self):
         space = {"x": sp.DiscreteSpace([0.0, 0.25, 0.5, 0.75, 1.0])}
         state, optimizer = bayesian.BayesianSearch.init(
-            space, n_max=20, acquisition=acq.UCB()
+            space, n_max=20, n_parallel=1, acquisition=acq.UCB()
         )
         func = lambda key, config: -(config["x"] ** 2)
         state, _ = optimizer.optimize(state, jax.random.PRNGKey(0), func)
         assert float(jnp.min(state.X[:20, 0])) == pytest.approx(0.0)
 
     def test_optimize_continuous_space(self):
         space = {"x": sp.LinearSpace(0.0, 1.0)}
-        state, optimizer = bayesian.BayesianSearch.init(space, n_max=5)
+        state, optimizer = bayesian.BayesianSearch.init(space, n_max=5, n_parallel=1)
         func = lambda key, config: -(config["x"] ** 2)
         state, (params_hist, results_hist) = optimizer.optimize(
             state, jax.random.PRNGKey(0), func
@@ -384,7 +384,7 @@ def test_optimize_converges_toward_optimum(self):
     def test_optimize_continuous_with_ei_uses_observed_y_max(self):
         space = {"x": sp.LinearSpace(0.0, 1.0), "y": sp.LinearSpace(0.0, 1.0)}
         state, optimizer = bayesian.BayesianSearch.init(
-            space, n_max=20, acquisition=acq.EI()
+            space, n_max=20, n_parallel=1, acquisition=acq.EI()
         )
         state = optimizer.update_state(
             state, jax.random.PRNGKey(0), jnp.array([100.0]), jnp.array([[0.5, 0.5]])
@@ -394,10 +394,10 @@ def test_optimize_continuous_with_ei_uses_observed_y_max(self):
 
     def test_optimize_minimize(self):
         space = {"x": sp.DiscreteSpace([0.0, 0.25, 0.5, 0.75, 1.0])}
-        state, optimizer = bayesian.BayesianSearch.init(space, n_max=10, maximize=False)
+        state, optimizer = bayesian.BayesianSearch.init(space, n_max=20, n_parallel=1, maximize=False)
         func = lambda key, config: config["x"] ** 2
         state, _ = optimizer.optimize(state, jax.random.PRNGKey(0), func)
-        assert int(state.mask.sum()) == 10
+        assert int(state.mask.sum()) == 20
         assert float(optimizer.best_result(state)) == pytest.approx(0.0)
 
     def test_optimize_n_parallel_fills_buffer(self):
@@ -458,7 +458,7 @@ def test_best_params_minimize(self):
         assert float(params["x"]) == pytest.approx(0.5)
 
     def test_best_result_after_full_optimize(self):
-        state, optimizer = bayesian.BayesianSearch.init(self.space, n_max=10)
+        state, optimizer = bayesian.BayesianSearch.init(self.space, n_max=10, n_parallel=1)
         func = lambda key, config: -(config["x"] ** 2)
         state, _ = optimizer.optimize(state, jax.random.PRNGKey(0), func)
         assert float(optimizer.best_result(state)) == pytest.approx(
@@ -509,7 +509,7 @@ def test_log_length_scale_unchanged_with_single_observation(self):
 
     def test_tuned_length_scale_used_in_gp(self):
         state, optimizer = bayesian.BayesianSearch.init(
-            self.space, n_max=10, n_hparam_steps=20
+            self.space, n_max=10, n_parallel=1, n_hparam_steps=20
         )
         func = lambda key, config: -(config["x"] ** 2)
         state, _ = optimizer.optimize(state, jax.random.PRNGKey(0), func)
@@ -551,7 +551,7 @@ def test_get_next_params_mixed(self):
             "lr": sp.LogSpace(1e-4, 1e-1),
             "layers": sp.DiscreteSpace([1, 2, 3, 4]),
         }
-        state, optimizer = bayesian.BayesianSearch.init(space, n_max=20)
+        state, optimizer = bayesian.BayesianSearch.init(space, n_max=20, n_parallel=1)
         params = optimizer.get_next_params(state, jax.random.PRNGKey(0))
         assert "lr" in params and "layers" in params
         assert params["lr"].shape == (1,)
@@ -593,7 +593,7 @@ def test_two_arg_passes(self):
 class TestBayesianSearchOptimizeScan:
     def test_optimize_scan_runs(self):
         space = {"x": sp.LinearSpace(0.0, 1.0)}
-        state, optimizer = bayesian.BayesianSearch.init(space, n_max=5)
+        state, optimizer = bayesian.BayesianSearch.init(space, n_max=5, n_parallel=1)
         func = lambda key, config: -(config["x"] ** 2)
         state, (params_hist, results_hist) = optimizer.optimize_scan(
             state, jax.random.PRNGKey(0), func
@@ -719,9 +719,9 @@ def test_lbfgs_improves_over_seed(self):
                 state, key, jnp.array([r]), jnp.array([[x, y_]])
             )
         params = optimizer.get_next_params(state, key)
-        assert params["x"].shape == (1,)
-        assert 0.0 <= float(params["x"][0]) <= 1.0
-        assert 0.0 <= float(params["y"][0]) <= 1.0
+        assert params["x"].shape == (4,)
+        assert all(0.0 <= float(v) <= 1.0 for v in params["x"])
+        assert all(0.0 <= float(v) <= 1.0 for v in params["y"])


class TestKrigingBelieverHallucination:
@@ -749,10 +749,10 @@ def _make_state_with_obs(
             state = optimizer.update_state(state, key, jnp.array([float(i)]), x)
         return state, optimizer
 
-    def test_default_hallucination_is_mean(self):
+    def test_default_hallucination_is_sample(self):
         space = {"x": sp.LinearSpace(0.0, 1.0)}
         _, optimizer = bayesian.BayesianSearch.init(space)
-        assert isinstance(optimizer.hallucination, acq.MeanHallucination)
+        assert isinstance(optimizer.hallucination, acq.SampleHallucination)
 
     def test_mean_hallucination_optimize_runs(self):
         space = {"x": sp.LinearSpace(0.0, 1.0)}