diff --git a/docs/examples/tutorials/tuning-a-process.md b/docs/examples/tutorials/tuning-a-process.md index 071f76d9..4c8d3651 100644 --- a/docs/examples/tutorials/tuning-a-process.md +++ b/docs/examples/tutorials/tuning-a-process.md @@ -58,6 +58,9 @@ Running this code will execute an optimisation job and print out information on !!! tip Since [Optuna](https://optuna.org/) is used under the hood, you can configure the optional `algorithm` argument on the `Tuner` with additional configuration defined in [`OptunaSpec`][plugboard.schemas.OptunaSpec]. For example, the [`storage`](https://optuna.readthedocs.io/en/stable/reference/storages.html) argument allows you to save the optimisation results to a database or SQLite file. You can then use a tool like [Optuna Dashboard](https://optuna-dashboard.readthedocs.io/en/stable/getting-started.html) to study the optimisation output in more detail. +!!! tip + You can impose arbitary constraints on variables within a `Process`. In your `step` method you can raise a [`ConstraintError`][plugboard.exceptions.ConstraintError] to indicate to the `Tuner` that a constraint has been breached. This will cause the trial to be stopped, and the optimisation will continue trying to find parameters that don't cause the constraint violation. + ## Using YAML config Plugboard's YAML config supports an optional `tune` section, allowing you to define optimisation jobs alongside your model configuration: diff --git a/plugboard/exceptions/__init__.py b/plugboard/exceptions/__init__.py index 75fccd74..5dd488da 100644 --- a/plugboard/exceptions/__init__.py +++ b/plugboard/exceptions/__init__.py @@ -89,3 +89,9 @@ class ValidationError(Exception): """Raised when an invalid `Process` or `Component` is encountered.""" pass + + +class ConstraintError(Exception): + """Raised when a constraint is violated.""" + + pass diff --git a/plugboard/tune/tune.py b/plugboard/tune/tune.py index 6e0fd090..aa678fc4 100644 --- a/plugboard/tune/tune.py +++ b/plugboard/tune/tune.py @@ -1,12 +1,14 @@ """Provides `Tuner` class for optimising Plugboard processes.""" from inspect import isfunction +import math from pydoc import locate import typing as _t import ray.tune.search.optuna from plugboard.component.component import Component, ComponentRegistry +from plugboard.exceptions import ConstraintError from plugboard.process import Process, ProcessBuilder from plugboard.schemas import ( Direction, @@ -52,6 +54,8 @@ def __init__( algorithm: Configuration for the underlying Optuna algorithm used for optimisation. """ self._logger = DI.logger.resolve_sync().bind(cls=self.__class__.__name__) + # Check that objective and mode are lists of the same length if multiple objectives are used + self._check_objective(objective, mode) self._objective = objective if isinstance(objective, list) else [objective] self._mode = [str(m) for m in mode] if isinstance(mode, list) else str(mode) self._metric = ( @@ -79,6 +83,22 @@ def result_grid(self) -> ray.tune.ResultGrid: raise ValueError("No result grid available. Run the optimisation job first.") return self._result_grid + @classmethod + def _check_objective( + cls, objective: ObjectiveSpec | list[ObjectiveSpec], mode: Direction | list[Direction] + ) -> None: + """Check that the objective and mode are valid.""" + if isinstance(objective, list): + if not isinstance(mode, list): + raise ValueError("If using multiple objectives, `mode` must also be a list.") + if len(objective) != len(mode): + raise ValueError( + "If using multiple objectives, `mode` and `objective` must be the same length." + ) + else: + if isinstance(mode, list): + raise ValueError("If using a single objective, `mode` must not be a list.") + def _build_algorithm( self, algorithm: _t.Optional[OptunaSpec] = None ) -> ray.tune.search.Searcher: @@ -189,7 +209,7 @@ def run(self, spec: ProcessSpec) -> ray.tune.Result | list[ray.tune.Result]: def _build_objective( self, component_classes: dict[str, type[Component]], spec: ProcessSpec ) -> _t.Callable: - def fn(config: dict[str, _t.Any]) -> _t.Any: # pragma: no-cover + def fn(config: dict[str, _t.Any]) -> dict[str, _t.Any]: # pragma: no cover # Recreate the ComponentRegistry in the Ray worker for key, cls in component_classes.items(): ComponentRegistry.add(cls, key=key) @@ -198,8 +218,23 @@ def fn(config: dict[str, _t.Any]) -> _t.Any: # pragma: no-cover self._override_parameter(spec, self._parameters_dict[name], value) process = ProcessBuilder.build(spec) - run_coro_sync(self._run_process(process)) - - return {obj.full_name: self._get_objective(process, obj) for obj in self._objective} + result = {} + try: + run_coro_sync(self._run_process(process)) + result = { + obj.full_name: self._get_objective(process, obj) for obj in self._objective + } + except* ConstraintError as e: + modes = self._mode if isinstance(self._mode, list) else [self._mode] + self._logger.warning( + "Constraint violated during optimisation, stopping early", + constraint_error=str(e), + ) + result = { + obj.full_name: math.inf if mode == "min" else -math.inf + for obj, mode in zip(self._objective, modes) + } + + return result return fn diff --git a/tests/integration/test_tuner.py b/tests/integration/test_tuner.py index e3f997a4..4e35b645 100644 --- a/tests/integration/test_tuner.py +++ b/tests/integration/test_tuner.py @@ -1,14 +1,27 @@ """Integration tests for the `Tuner` class.""" +import math + import msgspec import pytest +from plugboard.exceptions import ConstraintError from plugboard.schemas import ConfigSpec, ConnectorBuilderSpec, ObjectiveSpec from plugboard.schemas.tune import CategoricalParameterSpec, IntParameterSpec, OptunaSpec from plugboard.tune import Tuner from tests.integration.test_process_with_components_run import A, B, C # noqa: F401 +class ConstrainedB(B): + """Component with a constraint.""" + + async def step(self) -> None: + """Override step to apply a constraint.""" + if self.in_1 > 10: + raise ConstraintError("Input must not be greater than 10") + await super().step() + + @pytest.fixture def config() -> dict: """Loads the YAML config.""" @@ -59,10 +72,10 @@ async def test_tune(config: dict, mode: str, process_type: str, ray_ctx: None) - assert not [t for t in result if t.error] # Correct optimimum must be found (within tolerance) if mode == "min": - assert best_result.config["a.iters"] <= tuner._parameters["a.iters"].lower + 1 + assert best_result.config["a.iters"] <= tuner._parameters["a.iters"].lower + 2 assert best_result.metrics["c.in_1"] == best_result.config["a.iters"] - 1 else: - assert best_result.config["a.iters"] >= tuner._parameters["a.iters"].upper - 1 + assert best_result.config["a.iters"] >= tuner._parameters["a.iters"].upper - 2 assert best_result.metrics["c.in_1"] == best_result.config["a.iters"] - 1 @@ -121,3 +134,47 @@ async def test_multi_objective_tune(config: dict, ray_ctx: None) -> None: assert -1 in set(r.config["b.factor"] for r in best_result) assert -1 in set(r.metrics["b.out_1"] for r in best_result) assert 1 in set(r.metrics["c.in_1"] for r in best_result) + + +@pytest.mark.tuner +@pytest.mark.asyncio +async def test_tune_with_constraint(config: dict, ray_ctx: None) -> None: + """Tests running of optimisation jobs with a constraint.""" + spec = ConfigSpec.model_validate(config) + process_spec = spec.plugboard.process + # Replace component B with a constrained version + process_spec.args.components[1].type = "tests.integration.test_tuner.ConstrainedB" + tuner = Tuner( + objective=ObjectiveSpec( + object_type="component", + object_name="c", + field_type="field", + field_name="in_1", + ), + parameters=[ + IntParameterSpec( + object_type="component", + object_name="a", + field_type="arg", + field_name="iters", + lower=5, + upper=15, + ) + ], + num_samples=12, + mode="max", + max_concurrent=2, + algorithm=OptunaSpec(), + ) + best_result = tuner.run( + spec=process_spec, + ) + result = tuner.result_grid + # There must be no failed trials + assert not [t for t in result if t.error] + # Constraint must be respected + assert all(t.metrics["c.in_1"] <= 10 for t in result) + # Optimum must be less than or equal to 10 + assert best_result.metrics["c.in_1"] <= 10 + # If a.iters is greater than 11, the constraint will be violated + assert all(t.metrics["c.in_1"] == -math.inf for t in result if t.config["a.iters"] > 11)