Skip to content
3 changes: 3 additions & 0 deletions docs/examples/tutorials/tuning-a-process.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ Running this code will execute an optimisation job and print out information on
!!! tip
Since [Optuna](https://optuna.org/) is used under the hood, you can configure the optional `algorithm` argument on the `Tuner` with additional configuration defined in [`OptunaSpec`][plugboard.schemas.OptunaSpec]. For example, the [`storage`](https://optuna.readthedocs.io/en/stable/reference/storages.html) argument allows you to save the optimisation results to a database or SQLite file. You can then use a tool like [Optuna Dashboard](https://optuna-dashboard.readthedocs.io/en/stable/getting-started.html) to study the optimisation output in more detail.

!!! tip
You can impose arbitary constraints on variables within a `Process`. In your `step` method you can raise a [`ConstraintError`][plugboard.exceptions.ConstraintError] to indicate to the `Tuner` that a constraint has been breached. This will cause the trial to be stopped, and the optimisation will continue trying to find parameters that don't cause the constraint violation.

## Using YAML config

Plugboard's YAML config supports an optional `tune` section, allowing you to define optimisation jobs alongside your model configuration:
Expand Down
6 changes: 6 additions & 0 deletions plugboard/exceptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,9 @@ class ValidationError(Exception):
"""Raised when an invalid `Process` or `Component` is encountered."""

pass


class ConstraintError(Exception):
"""Raised when a constraint is violated."""

pass
43 changes: 39 additions & 4 deletions plugboard/tune/tune.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Provides `Tuner` class for optimising Plugboard processes."""

from inspect import isfunction
import math
from pydoc import locate
import typing as _t

import ray.tune.search.optuna

from plugboard.component.component import Component, ComponentRegistry
from plugboard.exceptions import ConstraintError
from plugboard.process import Process, ProcessBuilder
from plugboard.schemas import (
Direction,
Expand Down Expand Up @@ -52,6 +54,8 @@ def __init__(
algorithm: Configuration for the underlying Optuna algorithm used for optimisation.
"""
self._logger = DI.logger.resolve_sync().bind(cls=self.__class__.__name__)
# Check that objective and mode are lists of the same length if multiple objectives are used
self._check_objective(objective, mode)
self._objective = objective if isinstance(objective, list) else [objective]
self._mode = [str(m) for m in mode] if isinstance(mode, list) else str(mode)
self._metric = (
Expand Down Expand Up @@ -79,6 +83,22 @@ def result_grid(self) -> ray.tune.ResultGrid:
raise ValueError("No result grid available. Run the optimisation job first.")
return self._result_grid

@classmethod
def _check_objective(
cls, objective: ObjectiveSpec | list[ObjectiveSpec], mode: Direction | list[Direction]
) -> None:
"""Check that the objective and mode are valid."""
if isinstance(objective, list):
if not isinstance(mode, list):
raise ValueError("If using multiple objectives, `mode` must also be a list.")
if len(objective) != len(mode):
raise ValueError(
"If using multiple objectives, `mode` and `objective` must be the same length."
)
else:
if isinstance(mode, list):
raise ValueError("If using a single objective, `mode` must not be a list.")

def _build_algorithm(
self, algorithm: _t.Optional[OptunaSpec] = None
) -> ray.tune.search.Searcher:
Expand Down Expand Up @@ -189,7 +209,7 @@ def run(self, spec: ProcessSpec) -> ray.tune.Result | list[ray.tune.Result]:
def _build_objective(
self, component_classes: dict[str, type[Component]], spec: ProcessSpec
) -> _t.Callable:
def fn(config: dict[str, _t.Any]) -> _t.Any: # pragma: no-cover
def fn(config: dict[str, _t.Any]) -> dict[str, _t.Any]: # pragma: no cover
# Recreate the ComponentRegistry in the Ray worker
for key, cls in component_classes.items():
ComponentRegistry.add(cls, key=key)
Expand All @@ -198,8 +218,23 @@ def fn(config: dict[str, _t.Any]) -> _t.Any: # pragma: no-cover
self._override_parameter(spec, self._parameters_dict[name], value)

process = ProcessBuilder.build(spec)
run_coro_sync(self._run_process(process))

return {obj.full_name: self._get_objective(process, obj) for obj in self._objective}
result = {}
try:
run_coro_sync(self._run_process(process))
result = {
obj.full_name: self._get_objective(process, obj) for obj in self._objective
}
except* ConstraintError as e:
modes = self._mode if isinstance(self._mode, list) else [self._mode]
Comment thread
toby-coleman marked this conversation as resolved.
self._logger.warning(
"Constraint violated during optimisation, stopping early",
constraint_error=str(e),
)
result = {
obj.full_name: math.inf if mode == "min" else -math.inf
for obj, mode in zip(self._objective, modes)
}
Comment thread
toby-coleman marked this conversation as resolved.

return result

return fn
61 changes: 59 additions & 2 deletions tests/integration/test_tuner.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,27 @@
"""Integration tests for the `Tuner` class."""

import math

import msgspec
import pytest

from plugboard.exceptions import ConstraintError
from plugboard.schemas import ConfigSpec, ConnectorBuilderSpec, ObjectiveSpec
from plugboard.schemas.tune import CategoricalParameterSpec, IntParameterSpec, OptunaSpec
from plugboard.tune import Tuner
from tests.integration.test_process_with_components_run import A, B, C # noqa: F401


class ConstrainedB(B):
"""Component with a constraint."""

async def step(self) -> None:
"""Override step to apply a constraint."""
if self.in_1 > 10:
raise ConstraintError("Input must not be greater than 10")
await super().step()


@pytest.fixture
def config() -> dict:
"""Loads the YAML config."""
Expand Down Expand Up @@ -59,10 +72,10 @@ async def test_tune(config: dict, mode: str, process_type: str, ray_ctx: None) -
assert not [t for t in result if t.error]
# Correct optimimum must be found (within tolerance)
if mode == "min":
assert best_result.config["a.iters"] <= tuner._parameters["a.iters"].lower + 1
assert best_result.config["a.iters"] <= tuner._parameters["a.iters"].lower + 2
assert best_result.metrics["c.in_1"] == best_result.config["a.iters"] - 1
else:
assert best_result.config["a.iters"] >= tuner._parameters["a.iters"].upper - 1
assert best_result.config["a.iters"] >= tuner._parameters["a.iters"].upper - 2
assert best_result.metrics["c.in_1"] == best_result.config["a.iters"] - 1


Expand Down Expand Up @@ -121,3 +134,47 @@ async def test_multi_objective_tune(config: dict, ray_ctx: None) -> None:
assert -1 in set(r.config["b.factor"] for r in best_result)
assert -1 in set(r.metrics["b.out_1"] for r in best_result)
assert 1 in set(r.metrics["c.in_1"] for r in best_result)


@pytest.mark.tuner
@pytest.mark.asyncio
async def test_tune_with_constraint(config: dict, ray_ctx: None) -> None:
"""Tests running of optimisation jobs with a constraint."""
spec = ConfigSpec.model_validate(config)
process_spec = spec.plugboard.process
# Replace component B with a constrained version
process_spec.args.components[1].type = "tests.integration.test_tuner.ConstrainedB"
tuner = Tuner(
objective=ObjectiveSpec(
object_type="component",
object_name="c",
field_type="field",
field_name="in_1",
),
parameters=[
IntParameterSpec(
object_type="component",
object_name="a",
field_type="arg",
field_name="iters",
lower=5,
upper=15,
)
],
num_samples=12,
mode="max",
max_concurrent=2,
algorithm=OptunaSpec(),
)
best_result = tuner.run(
spec=process_spec,
)
result = tuner.result_grid
# There must be no failed trials
assert not [t for t in result if t.error]
# Constraint must be respected
assert all(t.metrics["c.in_1"] <= 10 for t in result)
# Optimum must be less than or equal to 10
assert best_result.metrics["c.in_1"] <= 10
# If a.iters is greater than 11, the constraint will be violated
assert all(t.metrics["c.in_1"] == -math.inf for t in result if t.config["a.iters"] > 11)
Loading