From 1ae9bfd190edc0fe11fda8a367f38279eeebe1aa Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Fri, 20 Mar 2026 17:34:56 +0100 Subject: [PATCH 01/23] Refactor shared validators and converters --- baybe/parameters/categorical.py | 12 ++++-------- baybe/parameters/numerical.py | 3 ++- baybe/parameters/validation.py | 17 ----------------- baybe/utils/conversion.py | 12 ++++++++++++ baybe/utils/validation.py | 34 ++++++++++++++++++++++++++++++++- 5 files changed, 51 insertions(+), 27 deletions(-) diff --git a/baybe/parameters/categorical.py b/baybe/parameters/categorical.py index bffecd5682..85058f6b29 100644 --- a/baybe/parameters/categorical.py +++ b/baybe/parameters/categorical.py @@ -13,13 +13,7 @@ from baybe.parameters.enum import CategoricalEncoding from baybe.parameters.validation import validate_unique_values from baybe.settings import active_settings -from baybe.utils.conversion import nonstring_to_tuple - - -def _convert_values(value, self, field) -> tuple[str, ...]: - """Sort and convert values for categorical parameters.""" - value = nonstring_to_tuple(value, self, field) - return tuple(sorted(value, key=lambda x: (str(type(x)), x))) +from baybe.utils.conversion import normalize_convertible2str_sequence def _validate_label_min_len(self, attr, value) -> None: @@ -38,7 +32,9 @@ class CategoricalParameter(_DiscreteLabelLikeParameter): # object variables _values: tuple[str | bool, ...] 
= field( alias="values", - converter=Converter(_convert_values, takes_self=True, takes_field=True), # type: ignore + converter=Converter( # type: ignore[misc,call-overload] # mypy: Converter + normalize_convertible2str_sequence, takes_self=True, takes_field=True + ), validator=( validate_unique_values, deep_iterable( diff --git a/baybe/parameters/numerical.py b/baybe/parameters/numerical.py index ba210de244..e56ca2c3f9 100644 --- a/baybe/parameters/numerical.py +++ b/baybe/parameters/numerical.py @@ -13,9 +13,10 @@ from baybe.exceptions import NumericalUnderflowError from baybe.parameters.base import ContinuousParameter, DiscreteParameter -from baybe.parameters.validation import validate_is_finite, validate_unique_values +from baybe.parameters.validation import validate_unique_values from baybe.settings import active_settings from baybe.utils.interval import InfiniteIntervalError, Interval +from baybe.utils.validation import validate_is_finite @define(frozen=True, slots=False) diff --git a/baybe/parameters/validation.py b/baybe/parameters/validation.py index 5c70ed5ecf..0367b2b30f 100644 --- a/baybe/parameters/validation.py +++ b/baybe/parameters/validation.py @@ -1,9 +1,7 @@ """Validation functionality for parameters.""" -from collections.abc import Sequence from typing import Any -import numpy as np from attrs.validators import gt, instance_of, lt @@ -28,18 +26,3 @@ def validate_decorrelation(obj: Any, attribute: Any, value: float) -> None: if isinstance(value, float): gt(0.0)(obj, attribute, value) lt(1.0)(obj, attribute, value) - - -def validate_is_finite( # noqa: DOC101, DOC103 - obj: Any, _: Any, value: Sequence[float] -) -> None: - """Validate that ``value`` contains no infinity/nan. - - Raises: - ValueError: If ``value`` contains infinity/nan. - """ - if not all(np.isfinite(value)): - raise ValueError( - f"Cannot assign the following values containing infinity/nan to " - f"parameter {obj.name}: {value}." 
- ) diff --git a/baybe/utils/conversion.py b/baybe/utils/conversion.py index ed5f622532..846fc925ad 100644 --- a/baybe/utils/conversion.py +++ b/baybe/utils/conversion.py @@ -42,6 +42,18 @@ def nonstring_to_tuple(x: Sequence[_T], self: type, field: Attribute) -> tuple[_ return tuple(x) +def normalize_convertible2str_sequence( + value: Sequence[str | bool], self: type, field: Attribute +) -> tuple[str | bool, ...]: + """Sort and convert values for a sequence of string-convertible types. + + If the sequence is a string itself, this is blocked to avoid unintended iteration + over its characters. + """ + value = nonstring_to_tuple(value, self, field) + return tuple(sorted(value, key=lambda x: (str(type(x)), x))) + + def _indent(text: str, amount: int = 3, ch: str = " ") -> str: """Indent a given text by a certain amount.""" padding = amount * ch diff --git a/baybe/utils/validation.py b/baybe/utils/validation.py index 93c87ab316..03a7462767 100644 --- a/baybe/utils/validation.py +++ b/baybe/utils/validation.py @@ -3,7 +3,8 @@ from __future__ import annotations import math -from collections.abc import Callable, Iterable +from collections import Counter +from collections.abc import Callable, Collection, Iterable, Sequence from typing import TYPE_CHECKING, Any import numpy as np @@ -261,3 +262,34 @@ def preprocess_dataframe( else: targets = () return normalize_input_dtypes(df, [*searchspace.parameters, *targets]) + + +def validate_is_finite( # noqa: DOC101, DOC103 + _: Any, attribute: Attribute, value: float | Sequence[float] +) -> None: + """Validate that ``value`` contains no resp. is not infinity/nan. + + Raises: + ValueError: If ``value`` contains infinity/nan. + """ + if not np.isfinite(value).all(): + raise ValueError( + f"Cannot assign the following values containing infinity/nan to " + f"'{attribute.alias}': {value}." 
+ ) + + +def validate_unique_values( # noqa: DOC101, DOC103 + _: Any, attribute: Attribute, value: Collection[str] +) -> None: + """Validate that there are no duplicates in ``value``. + + Raises: + ValueError: If there are duplicates in ``value``. + """ + duplicates = [item for item, count in Counter(value).items() if count > 1] + if duplicates: + raise ValueError( + f"Entries appearing multiple times: {duplicates}. " + f"All entries of '{attribute.alias}' must be unique." + ) From 764caca1d4034bd7b51000eb19ce6a8dcae4eef8 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Fri, 20 Mar 2026 17:37:31 +0100 Subject: [PATCH 02/23] Update permutation augmentation utility interface --- baybe/utils/augmentation.py | 58 +++++++++++++++++--------------- tests/utils/test_augmentation.py | 27 +++++++-------- 2 files changed, 44 insertions(+), 41 deletions(-) diff --git a/baybe/utils/augmentation.py b/baybe/utils/augmentation.py index b9fc2d14aa..249b2a01e1 100644 --- a/baybe/utils/augmentation.py +++ b/baybe/utils/augmentation.py @@ -8,24 +8,29 @@ def df_apply_permutation_augmentation( df: pd.DataFrame, - column_groups: Sequence[Sequence[str]], + permutation_groups: Sequence[Sequence[str]], ) -> pd.DataFrame: """Augment a dataframe if permutation invariant columns are present. + Each group in ``permutation_groups`` contains the names of columns that are + permuted in lockstep. All groups must have the same length, and that length + must be at least 2 (otherwise there is nothing to permute). + Args: df: The dataframe that should be augmented. - column_groups: Sequences of permutation invariant columns. The n'th column in - each group will be permuted together with each n'th column in the other - groups. + permutation_groups: Groups of column names that are permuted in lockstep. + For example, ``[["A1", "A2"], ["B1", "B2"]]`` means that the columns + ``A1`` and ``A2`` are permuted, and ``B1`` and ``B2`` are permuted + in the same way. 
Returns: The augmented dataframe containing the original one. Augmented row indices are identical with the index of their original row. Raises: - ValueError: If less than two column groups are given. - ValueError: If a column group is empty. - ValueError: If the column groups have differing amounts of entries. + ValueError: If no permutation groups are given. + ValueError: If any permutation group has fewer than two entries. + ValueError: If the permutation groups have differing amounts of entries. Examples: >>> df = pd.DataFrame({'A1':[1,2],'A2':[3,4], 'B1': [5, 6], 'B2': [7, 8]}) @@ -34,8 +39,8 @@ def df_apply_permutation_augmentation( 0 1 3 5 7 1 2 4 6 8 - >>> column_groups = [['A1'], ['A2']] - >>> dfa = df_apply_permutation_augmentation(df, column_groups) + >>> groups = [['A1', 'A2']] + >>> dfa = df_apply_permutation_augmentation(df, groups) >>> dfa A1 A2 B1 B2 0 1 3 5 7 @@ -43,8 +48,8 @@ def df_apply_permutation_augmentation( 1 2 4 6 8 1 4 2 6 8 - >>> column_groups = [['A1', 'B1'], ['A2', 'B2']] - >>> dfa = df_apply_permutation_augmentation(df, column_groups) + >>> groups = [['A1', 'A2'], ['B1', 'B2']] + >>> dfa = df_apply_permutation_augmentation(df, groups) >>> dfa A1 A2 B1 B2 0 1 3 5 7 @@ -53,36 +58,35 @@ def df_apply_permutation_augmentation( 1 4 2 8 6 """ # Validation - if len(column_groups) < 2: - raise ValueError( - "When augmenting permutation invariance, at least two column sequences " - "must be given." - ) + if len(permutation_groups) < 1: + raise ValueError("Permutation augmentation requires at least one group.") - if len({len(seq) for seq in column_groups}) != 1: + if len({len(seq) for seq in permutation_groups}) != 1: raise ValueError( - "Permutation augmentation can only work if the amount of columns in each " - "sequence is the same." + "Permutation augmentation can only work if all groups have the same " + "number of entries." 
) - elif len(column_groups[0]) < 1: + + if len(permutation_groups[0]) < 2: raise ValueError( - "Permutation augmentation can only work if each column group has at " - "least one entry." + "Permutation augmentation can only work if each group has at " + "least two entries." ) # Augmentation Loop + n_positions = len(permutation_groups[0]) new_rows: list[pd.DataFrame] = [] - idx_permutation = list(permutations(range(len(column_groups)))) + idx_permutation = list(permutations(range(n_positions))) for _, row in df.iterrows(): # For each row in the original df, collect all its permutations to_add = [] for perm in idx_permutation: new_row = row.copy() - # Permute columns, this is done separately for each tuple of columns that - # belong together - for deps in map(list, zip(*column_groups)): - new_row[deps] = row[[deps[k] for k in perm]] + # Permute columns within each group according to the permutation + for group in permutation_groups: + cols = list(group) + new_row[cols] = row[[cols[k] for k in perm]] to_add.append(new_row) diff --git a/tests/utils/test_augmentation.py b/tests/utils/test_augmentation.py index 80ffff8a2a..4dd8a14d33 100644 --- a/tests/utils/test_augmentation.py +++ b/tests/utils/test_augmentation.py @@ -22,7 +22,7 @@ }, "index": [0, 1], }, - [["A"], ["B"]], + [["A", "B"]], { "data": { "A": [1, 2, 1, 2], @@ -41,7 +41,7 @@ }, "index": [0, 1], }, - [["A"], ["B"]], + [["A", "B"]], { "data": { "A": [1, 2, 1, 2], @@ -60,7 +60,7 @@ }, "index": [0, 1], }, - [["A"], ["B"]], + [["A", "B"]], { "data": { "A": [1, 2, 1, 2], @@ -80,7 +80,7 @@ }, "index": [0, 1], }, - [["A"], ["B"]], + [["A", "B"]], { "data": { "A": [1, 2, 1, 2], @@ -91,7 +91,7 @@ }, id="2inv+degen_target+degen", ), - param( # 3 invariant groups with 1 entry each + param( # 3 invariant cols in one group { "data": { "A": [1, 1], @@ -101,7 +101,7 @@ }, "index": [0, 1], }, - [["A"], ["B"], ["C"]], + [["A", "B", "C"]], { "data": { "A": [1, 1, 2, 2, 3, 3, 1, 1, 4, 4, 5, 5], @@ -113,7 +113,7 @@ }, 
id="3inv_1add", ), - param( # 2 groups with 2 entries each, 2 additional columns + param( # 2 lockstep groups with 2 entries each, 2 additional columns { "data": { "Slot1": ["s1", "s2"], @@ -125,7 +125,7 @@ }, "index": [0, 1], }, - [["Slot1", "Frac1"], ["Slot2", "Frac2"]], + [["Slot1", "Slot2"], ["Frac1", "Frac2"]], { "data": { "Slot1": ["s1", "s2", "s2", "s4"], @@ -139,7 +139,7 @@ }, id="2inv_2dependent_2add", ), - param( # 2 groups with 3 entries each, 1 additional column + param( # 3 lockstep groups with 2 entries each, 1 additional column { "data": { "Slot1": ["s1", "s2"], @@ -152,7 +152,7 @@ }, "index": [0, 1], }, - [["Slot1", "Frac1", "Temp1"], ["Slot2", "Frac2", "Temp2"]], + [["Slot1", "Slot2"], ["Frac1", "Frac2"], ["Temp1", "Temp2"]], { "data": { "Slot1": ["s1", "s2", "s2", "s4"], @@ -187,10 +187,9 @@ def test_df_permutation_aug(content, col_groups, content_expected): @pytest.mark.parametrize( ("col_groups", "msg"), [ - param([], "at least two column sequences", id="no_groups"), - param([["A"]], "at least two column sequences", id="just_one_group"), - param([["A"], ["B", "C"]], "the amount of columns in", id="different_lengths"), - param([[], []], "each column group has", id="empty_group"), + param([], "at least one group", id="no_groups"), + param([["A"]], "at least two entries", id="group_too_small"), + param([["A", "B"], ["C"]], "same number of entries", id="different_lengths"), ], ) def test_df_permutation_aug_invalid(col_groups, msg): From 99e1e60c85a3d44bdfd7a3292f80a40c4131cfa8 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Fri, 20 Mar 2026 17:38:39 +0100 Subject: [PATCH 03/23] Add mirror augmentation utility --- baybe/utils/augmentation.py | 63 +++++++++++++++++++++++++++++++++++++ baybe/utils/dataframe.py | 4 +++ 2 files changed, 67 insertions(+) diff --git a/baybe/utils/augmentation.py b/baybe/utils/augmentation.py index 249b2a01e1..6b1fd50d28 100644 --- a/baybe/utils/augmentation.py +++ b/baybe/utils/augmentation.py @@ -98,6 +98,69 @@ def 
df_apply_permutation_augmentation( return pd.concat(new_rows) +def df_apply_mirror_augmentation( + df: pd.DataFrame, + column: str, + *, + mirror_point: float = 0.0, +) -> pd.DataFrame: + """Augment a dataframe for a mirror invariant column. + + Args: + df: The dataframe that should be augmented. + column: The name of the affected column. + mirror_point: The point along which to mirror the values. Points that have + exactly this value will not be augmented. + + Returns: + The augmented dataframe containing the original one. Augmented row indices are + identical with the index of their original row. + + Examples: + >>> df = pd.DataFrame({'A':[1, 0, -2], 'B': [3, 4, 5]}) + >>> df + A B + 0 1 3 + 1 0 4 + 2 -2 5 + + >>> dfa = df_apply_mirror_augmentation(df, "A") + >>> dfa + A B + 0 1 3 + 0 -1 3 + 1 0 4 + 2 -2 5 + 2 2 5 + + >>> dfa = df_apply_mirror_augmentation(df, "A", mirror_point=1) + >>> dfa + A B + 0 1 3 + 1 0 4 + 1 2 4 + 2 -2 5 + 2 4 5 + """ + new_rows: list[pd.DataFrame] = [] + for _, row in df.iterrows(): + to_add = [row] # Always keep original row + + # Create the augmented row by mirroring the point at the mirror point. + # x_mirrored = mirror_point + (mirror_point - x) = 2*mirror_point - x + if row[column] != mirror_point: + row_new = row.copy() + row_new[column] = 2.0 * mirror_point - row[column] + to_add.append(row_new) + + # Store augmented rows, keeping the index of their original row + new_rows.append( + pd.DataFrame(to_add, columns=df.columns, index=[row.name] * len(to_add)) + ) + + return pd.concat(new_rows) + + def df_apply_dependency_augmentation( df: pd.DataFrame, causing: tuple[str, Sequence], diff --git a/baybe/utils/dataframe.py b/baybe/utils/dataframe.py index 84b831a406..595eba3fc3 100644 --- a/baybe/utils/dataframe.py +++ b/baybe/utils/dataframe.py @@ -192,6 +192,10 @@ def create_fake_input( Raises: ValueError: If less than one row was requested. 
+ + Note: + This function does not consider constraints and might provide unexpected or + invalid data if certain constraints are present. """ # Assert at least one fake entry is being generated if n_rows < 1: From 221234be182a6f7e560cd2cc5f1dc902cf768538 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Fri, 20 Mar 2026 17:40:03 +0100 Subject: [PATCH 04/23] Add Symmetry domain model --- baybe/symmetries/__init__.py | 13 +++ baybe/symmetries/base.py | 87 ++++++++++++++++ baybe/symmetries/dependency.py | 169 ++++++++++++++++++++++++++++++ baybe/symmetries/mirror.py | 86 +++++++++++++++ baybe/symmetries/permutation.py | 179 ++++++++++++++++++++++++++++++++ 5 files changed, 534 insertions(+) create mode 100644 baybe/symmetries/__init__.py create mode 100644 baybe/symmetries/base.py create mode 100644 baybe/symmetries/dependency.py create mode 100644 baybe/symmetries/mirror.py create mode 100644 baybe/symmetries/permutation.py diff --git a/baybe/symmetries/__init__.py b/baybe/symmetries/__init__.py new file mode 100644 index 0000000000..abbeedd974 --- /dev/null +++ b/baybe/symmetries/__init__.py @@ -0,0 +1,13 @@ +"""Symmetry classes for expressing invariances of the modeling process.""" + +from baybe.symmetries.base import Symmetry +from baybe.symmetries.dependency import DependencySymmetry +from baybe.symmetries.mirror import MirrorSymmetry +from baybe.symmetries.permutation import PermutationSymmetry + +__all__ = [ + "DependencySymmetry", + "MirrorSymmetry", + "PermutationSymmetry", + "Symmetry", +] diff --git a/baybe/symmetries/base.py b/baybe/symmetries/base.py new file mode 100644 index 0000000000..96d8d80324 --- /dev/null +++ b/baybe/symmetries/base.py @@ -0,0 +1,87 @@ +"""Base class for symmetries.""" + +from __future__ import annotations + +import gc +from abc import ABC, abstractmethod +from collections.abc import Iterable +from typing import TYPE_CHECKING + +import pandas as pd +from attrs import define, field +from attrs.validators import instance_of + 
+from baybe.exceptions import IncompatibleSearchSpaceError +from baybe.serialization import SerialMixin + +if TYPE_CHECKING: + from baybe.parameters.base import Parameter + from baybe.searchspace import SearchSpace + + +@define(frozen=True) +class Symmetry(SerialMixin, ABC): + """Abstract base class for symmetries. + + A ``Symmetry`` is a concept that can be used to configure the modeling process in + the presence of invariances. + """ + + use_data_augmentation: bool = field( + default=True, validator=instance_of(bool), kw_only=True + ) + """Flag indicating whether data augmentation is to be used.""" + + @property + @abstractmethod + def parameter_names(self) -> tuple[str, ...]: + """The names of the parameters affected by the symmetry.""" + + def summary(self) -> dict: + """Return a custom summarization of the symmetry.""" + symmetry_dict = dict( + Type=self.__class__.__name__, Affected_Parameters=self.parameter_names + ) + return symmetry_dict + + @abstractmethod + def augment_measurements( + self, + measurements: pd.DataFrame, + parameters: Iterable[Parameter] | None = None, + ) -> pd.DataFrame: + """Augment the given measurements according to the symmetry. + + Args: + measurements: The dataframe containing the measurements to be + augmented. + parameters: Corresponding parameter objects carrying additional + information (not needed by all augmentation types). + + Returns: + The augmented dataframe including the original measurements. + """ + + def validate_searchspace_context(self, searchspace: SearchSpace) -> None: + """Validate that the symmetry is compatible with the given searchspace. + + Args: + searchspace: The searchspace to validate against. + + Raises: + IncompatibleSearchSpaceError: If the symmetry affects parameters not + present in the searchspace. 
+ """ + parameters_missing = set(self.parameter_names).difference( + searchspace.parameter_names + ) + if parameters_missing: + raise IncompatibleSearchSpaceError( + f"The symmetry of type '{self.__class__.__name__}' was set up with the " + f"following parameters that are not present in the search space: " + f"{parameters_missing}." + ) + + +# Collect leftover original slotted classes processed by `attrs.define` +gc.collect() diff --git a/baybe/symmetries/dependency.py b/baybe/symmetries/dependency.py new file mode 100644 index 0000000000..3fdadaeb94 --- /dev/null +++ b/baybe/symmetries/dependency.py @@ -0,0 +1,169 @@ +"""Dependency symmetry.""" + +from __future__ import annotations + +import gc +from collections.abc import Iterable +from typing import TYPE_CHECKING, cast + +import numpy as np +import pandas as pd +from attrs import Converter, define, field +from attrs.validators import deep_iterable, ge, instance_of, min_len +from typing_extensions import override + +from baybe.constraints.conditions import Condition +from baybe.exceptions import IncompatibleSearchSpaceError +from baybe.symmetries.base import Symmetry +from baybe.utils.augmentation import df_apply_dependency_augmentation +from baybe.utils.conversion import normalize_convertible2str_sequence +from baybe.utils.validation import validate_unique_values + +if TYPE_CHECKING: + from baybe.parameters.base import Parameter + from baybe.searchspace import SearchSpace + + +@define(frozen=True) +class DependencySymmetry(Symmetry): + """Class for representing dependency symmetries. + + A dependency symmetry expresses that certain parameters are dependent on another + parameter having a specific value. For instance, the situation "The value of + parameter y only matters if parameter x has the value 'on'.". In this scenario x + is the causing parameter and y depends on x. 
+ """ + + _parameter_name: str = field(validator=instance_of(str), alias="parameter_name") + """The names of the causing parameter others are depending on.""" + + # object variables + condition: Condition = field(validator=instance_of(Condition)) + """The condition specifying the active range of the causing parameter.""" + + affected_parameter_names: tuple[str, ...] = field( + converter=Converter( # type: ignore[misc,call-overload] # mypy: Converter + normalize_convertible2str_sequence, takes_self=True, takes_field=True + ), + validator=( # type: ignore + validate_unique_values, + deep_iterable( + member_validator=instance_of(str), iterable_validator=min_len(1) + ), + ), + ) + """The parameters affected by the dependency.""" + + n_discretization_points: int = field( + default=3, validator=(instance_of(int), ge(2)), kw_only=True + ) + """Number of points used when subsampling continuous parameter ranges.""" + + @override + @property + def parameter_names(self) -> tuple[str, ...]: + return (self._parameter_name,) + + @override + def augment_measurements( + self, + measurements: pd.DataFrame, + parameters: Iterable[Parameter] | None = None, + ) -> pd.DataFrame: + # See base class. + if not self.use_data_augmentation: + return measurements + + if parameters is None: + raise ValueError( + f"A '{self.__class__.__name__}' requires parameter objects " + f"for data augmentation." + ) + + from baybe.parameters.base import DiscreteParameter + + # The 'causing' entry describes the parameters and the value + # for which one or more affected parameters become degenerate. + # 'cond' specifies for which values the affected parameter + # values are active, i.e. not degenerate. Hence, here we get the + # values that are not active, as rows containing them should be + # augmented. 
+ param = next( + cast(DiscreteParameter, p) + for p in parameters + if p.name == self._parameter_name + ) + + causing_values = [ + x + for x, flag in zip( + param.values, + ~self.condition.evaluate(pd.Series(param.values)), + strict=True, + ) + if flag + ] + causing = (param.name, causing_values) + + # The 'affected' entry describes the affected parameters and the + # values they are allowed to take, which are all degenerate if + # the corresponding condition for the causing parameter is met. + affected: list[tuple[str, tuple[float, ...]]] = [] + for pn in self.affected_parameter_names: + p = next(p for p in parameters if p.name == pn) + if p.is_discrete: + # Use all values for augmentation + vals = cast(DiscreteParameter, p).values + else: + # Use linear subsample of parameter bounds interval for augmentation. + # Note: The original value will not necessarily be part of this. + vals = tuple( + np.linspace( + p.bounds.lower, # type: ignore[attr-defined] + p.bounds.upper, # type: ignore[attr-defined] + self.n_discretization_points, + ) + ) + affected.append((p.name, vals)) + + measurements = df_apply_dependency_augmentation(measurements, causing, affected) + + return measurements + + @override + def validate_searchspace_context(self, searchspace: SearchSpace) -> None: + """See base class. + + Args: + searchspace: The searchspace to validate against. + + Raises: + IncompatibleSearchSpaceError: If any of the affected parameters is + not present in the searchspace. + TypeError: If the causing parameter is not discrete. + """ + super().validate_searchspace_context(searchspace) + + # Affected parameters must be in the searchspace + parameters_missing = set(self.affected_parameter_names).difference( + searchspace.parameter_names + ) + if parameters_missing: + raise IncompatibleSearchSpaceError( + f"The symmetry of type '{self.__class__.__name__}' was set up " + f"with at least one parameter which is not present in the " + f"search space: {parameters_missing}." 
+ ) + + # Causing parameter must be discrete + param = searchspace.get_parameters_by_name(self._parameter_name)[0] + if not param.is_discrete: + raise TypeError( + f"In a '{self.__class__.__name__}', the causing parameter must " + f"be discrete. However, the parameter '{param.name}' is of " + f"type '{param.__class__.__name__}' and is not discrete." + ) + + +# Collect leftover original slotted classes processed by `attrs.define` +gc.collect() diff --git a/baybe/symmetries/mirror.py b/baybe/symmetries/mirror.py new file mode 100644 index 0000000000..9a7ac7d06f --- /dev/null +++ b/baybe/symmetries/mirror.py @@ -0,0 +1,86 @@ +"""Mirror symmetry.""" + +from __future__ import annotations + +import gc +from collections.abc import Iterable +from typing import TYPE_CHECKING + +import pandas as pd +from attrs import define, field +from attrs.validators import instance_of +from typing_extensions import override + +from baybe.symmetries.base import Symmetry +from baybe.utils.augmentation import df_apply_mirror_augmentation +from baybe.utils.validation import validate_is_finite + +if TYPE_CHECKING: + from baybe.parameters.base import Parameter + from baybe.searchspace import SearchSpace + + +@define(frozen=True) +class MirrorSymmetry(Symmetry): + """Class for representing mirror symmetries. + + A mirror symmetry expresses that certain parameters can be inflected at a mirror + point without affecting the outcome of the model. For instance, when specified + for parameter ``x`` and mirror point ``c``, the symmetry expresses that + $f(..., c+x, ...) = f(..., c-x, ...)$. 
+ """ + + _parameter_name: str = field(validator=instance_of(str), alias="parameter_name") + """The name of the single parameter affected by the symmetry.""" + + # object variables + mirror_point: float = field( + default=0.0, converter=float, validator=validate_is_finite, kw_only=True + ) + """The mirror point.""" + + @override + @property + def parameter_names(self) -> tuple[str]: + return (self._parameter_name,) + + @override + def augment_measurements( + self, + measurements: pd.DataFrame, + parameters: Iterable[Parameter] | None = None, + ) -> pd.DataFrame: + # See base class. + + if not self.use_data_augmentation: + return measurements + + measurements = df_apply_mirror_augmentation( + measurements, self._parameter_name, mirror_point=self.mirror_point + ) + + return measurements + + @override + def validate_searchspace_context(self, searchspace: SearchSpace) -> None: + """See base class. + + Args: + searchspace: The searchspace to validate against. + + Raises: + TypeError: If the affected parameter is not numerical. + """ + super().validate_searchspace_context(searchspace) + + param = searchspace.get_parameters_by_name(self.parameter_names)[0] + if not param.is_numerical: + raise TypeError( + f"In a '{self.__class__.__name__}', the affected parameter must " + f"be numerical. However, the parameter '{param.name}' is of " + f"type '{param.__class__.__name__}' and is not numerical." 
+ ) + + +# Collect leftover original slotted classes processed by `attrs.define` +gc.collect() diff --git a/baybe/symmetries/permutation.py b/baybe/symmetries/permutation.py new file mode 100644 index 0000000000..03d1554874 --- /dev/null +++ b/baybe/symmetries/permutation.py @@ -0,0 +1,179 @@ +"""Permutation symmetry.""" + +from __future__ import annotations + +import gc +from collections.abc import Iterable, Sequence +from itertools import combinations +from typing import TYPE_CHECKING, Any, cast + +import pandas as pd +from attrs import define, field +from attrs.validators import deep_iterable, instance_of, min_len +from typing_extensions import override + +from baybe.symmetries.base import Symmetry +from baybe.utils.augmentation import df_apply_permutation_augmentation + +if TYPE_CHECKING: + from baybe.parameters.base import Parameter + from baybe.searchspace import SearchSpace + + +def _convert_groups( + groups: Sequence[Sequence[str]], +) -> tuple[tuple[str, ...], ...]: + """Convert nested sequences to a tuple of tuples, blocking bare strings.""" + if isinstance(groups, str): + raise ValueError( + "The 'permutation_groups' argument must be a sequence of sequences, " + "not a string." + ) + converted = [] + for g in groups: + if isinstance(g, str): + raise ValueError( + "Each element in 'permutation_groups' must be a sequence of " + "parameter names, not a string." + ) + converted.append(tuple(g)) + return tuple(converted) + + +@define(frozen=True) +class PermutationSymmetry(Symmetry): + """Class for representing permutation symmetries. + + A permutation symmetry expresses that certain parameters can be permuted without + affecting the outcome of the model. For instance, this is the case if + $f(x,y) = f(y,x)$. + """ + + permutation_groups: tuple[tuple[str, ...], ...] 
= field( + converter=_convert_groups, + validator=( + min_len(1), + deep_iterable( + member_validator=deep_iterable( + member_validator=instance_of(str), + iterable_validator=min_len(2), + ), + ), + ), + ) + """The permutation groups. Each group contains parameter names that are permuted + in lockstep. All groups must have the same length.""" + + @permutation_groups.validator + def _validate_permutation_groups( # noqa: DOC101, DOC103 + self, _: Any, groups: tuple[tuple[str, ...], ...] + ) -> None: + """Validate the permutation groups. + + Raises: + ValueError: If the groups have different lengths. + ValueError: If any group contains duplicate parameters. + ValueError: If any parameter name appears in multiple groups. + """ + # Ensure all groups have the same length + if len({len(g) for g in groups}) != 1: + lengths = {k + 1: len(g) for k, g in enumerate(groups)} + raise ValueError( + f"In a '{self.__class__.__name__}', all permutation groups " + f"must have the same length. Got group lengths: {lengths}." + ) + + # Ensure parameter names in each group are unique + for group in groups: + if len(set(group)) != len(group): + raise ValueError( + f"In a '{self.__class__.__name__}', all parameters being " + f"permuted with each other must be unique. However, the " + f"following group contains duplicates: {group}." + ) + + # Ensure there is no overlap between any permutation group + for a, b in combinations(groups, 2): + if overlap := set(a) & set(b): + raise ValueError( + f"In a '{self.__class__.__name__}', parameter names cannot " + f"appear in multiple permutation groups. However, the " + f"following parameter names appear in several groups: " + f"{overlap}." 
+ ) + + @override + @property + def parameter_names(self) -> tuple[str, ...]: + return tuple(name for group in self.permutation_groups for name in group) + + @override + def augment_measurements( + self, + measurements: pd.DataFrame, + parameters: Iterable[Parameter] | None = None, + ) -> pd.DataFrame: + # See base class. + + if not self.use_data_augmentation: + return measurements + + measurements = df_apply_permutation_augmentation( + measurements, self.permutation_groups + ) + + return measurements + + @override + def validate_searchspace_context(self, searchspace: SearchSpace) -> None: + """See base class. + + Args: + searchspace: The searchspace to validate against. + + Raises: + TypeError: If parameters within a permutation group do not have the + same type. + ValueError: If parameters within a permutation group do not have a + compatible set of values. + """ + super().validate_searchspace_context(searchspace) + + # Ensure permuted parameters all have the same specification. + # Without this, it could be attempted to read in data that is not allowed + # for parameters that only allow a subset or different values compared to + # parameters they are being permuted with. + for group in self.permutation_groups: + params = searchspace.get_parameters_by_name(group) + + # All parameters in a group must be of the same type + if len(types := {type(p).__name__ for p in params}) != 1: + raise TypeError( + f"In a '{self.__class__.__name__}', all parameters being " + f"permuted with each other must have the same type. " + f"However, the following multiple types were found in the " + f"permutation group {group}: {types}." + ) + + # All parameters in a group must have the same values. Numerical + # parameters are not considered here since technically for them + # this restriction is not required as all numbers can be added if + # the tolerance is configured accordingly. 
+ if all(p.is_discrete and not p.is_numerical for p in params): + from baybe.parameters.base import DiscreteParameter + + ref_vals = set(cast(DiscreteParameter, params[0]).values) + if any( + set(cast(DiscreteParameter, p).values) != ref_vals + for p in params[1:] + ): + raise ValueError( + f"The parameter group {group} contains " + f"parameters with different values. All " + f"parameters in a group must have the same " + f"specification." + ) + + +# Collect leftover original slotted classes processed by `attrs.define` +gc.collect() From e9dff9c85f0fbbf046b504589f451b9b1e38c061 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Fri, 20 Mar 2026 17:41:00 +0100 Subject: [PATCH 05/23] Add Parameter.is_equivalent and apply in PermutationSymmetry --- baybe/parameters/base.py | 17 ++++++++++++++ baybe/symmetries/permutation.py | 41 ++++++++------------------------- 2 files changed, 27 insertions(+), 31 deletions(-) diff --git a/baybe/parameters/base.py b/baybe/parameters/base.py index 2d4df2bc77..9ae13e9a98 100644 --- a/baybe/parameters/base.py +++ b/baybe/parameters/base.py @@ -7,6 +7,7 @@ from functools import cached_property from typing import TYPE_CHECKING, Any, ClassVar +import attrs import pandas as pd from attrs import define, field from attrs.converters import optional as optional_c @@ -88,6 +89,22 @@ def to_searchspace(self) -> SearchSpace: return SearchSpace.from_parameter(self) + def is_equivalent(self, other: Parameter) -> bool: + """Check if this parameter is equivalent to another, ignoring the name. + + Two parameters are considered equivalent if they have the same type and + all attributes are equal except for the name. + + Args: + other: The parameter to compare against. + + Returns: + ``True`` if the parameters are equivalent, ``False`` otherwise. 
+ """ + if type(self) is not type(other): + return False + return attrs.evolve(self, name=other.name) == other + @abstractmethod def summary(self) -> dict: """Return a custom summarization of the parameter.""" diff --git a/baybe/symmetries/permutation.py b/baybe/symmetries/permutation.py index 03d1554874..deab356f9e 100644 --- a/baybe/symmetries/permutation.py +++ b/baybe/symmetries/permutation.py @@ -5,7 +5,7 @@ import gc from collections.abc import Iterable, Sequence from itertools import combinations -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any import pandas as pd from attrs import define, field @@ -132,10 +132,8 @@ def validate_searchspace_context(self, searchspace: SearchSpace) -> None: searchspace: The searchspace to validate against. Raises: - TypeError: If parameters within a permutation group do not have the - same type. - ValueError: If parameters within a permutation group do not have a - compatible set of values. + ValueError: If parameters within a permutation group are not + equivalent (i.e., differ in type or specification). """ super().validate_searchspace_context(searchspace) @@ -145,33 +143,14 @@ def validate_searchspace_context(self, searchspace: SearchSpace) -> None: # parameters they are being permuted with. for group in self.permutation_groups: params = searchspace.get_parameters_by_name(group) - - # All parameters in a group must be of the same type - if len(types := {type(p).__name__ for p in params}) != 1: - raise TypeError( - f"In a '{self.__class__.__name__}', all parameters being " - f"permuted with each other must have the same type. " - f"However, the following multiple types were found in the " - f"permutation group {group}: {types}." - ) - - # All parameters in a group must have the same values. Numerical - # parameters are not considered here since technically for them - # this restriction is not required as all numbers can be added if - # the tolerance is configured accordingly. 
- if all(p.is_discrete and not p.is_numerical for p in params): - from baybe.parameters.base import DiscreteParameter - - ref_vals = set(cast(DiscreteParameter, params[0]).values) - if any( - set(cast(DiscreteParameter, p).values) != ref_vals - for p in params[1:] - ): + ref = params[0] + for p in params[1:]: + if not ref.is_equivalent(p): raise ValueError( - f"The parameter group {group} contains " - f"parameters with different values. All " - f"parameters in a group must have the same " - f"specification." + f"In a '{self.__class__.__name__}', all parameters " + f"within a permutation group must be equivalent. " + f"However, '{ref.name}' and '{p.name}' differ in " + f"their specification." ) From d1c88db96512929c340b47dbe88aae2179108500 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Fri, 20 Mar 2026 17:43:14 +0100 Subject: [PATCH 06/23] Integrate symmetries into surrogates and recommenders --- baybe/recommenders/pure/bayesian/base.py | 10 +++++++ baybe/surrogates/base.py | 33 ++++++++++++++++++++++- baybe/surrogates/gaussian_process/core.py | 2 +- 3 files changed, 43 insertions(+), 2 deletions(-) diff --git a/baybe/recommenders/pure/bayesian/base.py b/baybe/recommenders/pure/bayesian/base.py index 4ac5c1eed2..093fd89600 100644 --- a/baybe/recommenders/pure/bayesian/base.py +++ b/baybe/recommenders/pure/bayesian/base.py @@ -114,6 +114,10 @@ def _setup_botorch_acqf( f"{len(objective.targets)}-target multi-output context." 
) + # Perform data augmentation if configured + if hasattr(s := self._surrogate_model, "augment_measurements"): + measurements = s.augment_measurements(measurements, searchspace.parameters) + surrogate = self.get_surrogate(searchspace, objective, measurements) self._botorch_acqf = acqf.to_botorch( surrogate, @@ -156,6 +160,12 @@ def recommend( validate_object_names(searchspace.parameters + objective.targets) + # Validate compatibility of surrogate symmetries with searchspace + if hasattr(self._surrogate_model, "symmetries"): + for s in self._surrogate_model.symmetries: + s.validate_searchspace_context(searchspace) + + # Experimental input validation if (measurements is None) or measurements.empty: raise NotImplementedError( f"Recommenders of type '{BayesianRecommender.__name__}' do not support " diff --git a/baybe/surrogates/base.py b/baybe/surrogates/base.py index 205e32f703..5e93b48919 100644 --- a/baybe/surrogates/base.py +++ b/baybe/surrogates/base.py @@ -4,12 +4,13 @@ import gc from abc import ABC, abstractmethod -from collections.abc import Sequence +from collections.abc import Iterable, Sequence from enum import Enum, auto from typing import TYPE_CHECKING, ClassVar, Literal, Protocol, TypeAlias import pandas as pd from attrs import define, field +from attrs.validators import deep_iterable, instance_of from joblib.hashing import hash from typing_extensions import override @@ -18,6 +19,7 @@ from baybe.parameters.base import Parameter from baybe.searchspace import SearchSpace from baybe.serialization.mixin import SerialMixin +from baybe.symmetries import Symmetry from baybe.utils.basic import classproperty from baybe.utils.conversion import to_string from baybe.utils.dataframe import handle_missing_values, to_tensor @@ -90,6 +92,14 @@ class Surrogate(ABC, SurrogateProtocol, SerialMixin): """Class variable encoding whether or not the surrogate is multi-output compatible.""" + symmetries: tuple[Symmetry, ...] 
= field( + factory=tuple, + converter=tuple, + validator=deep_iterable(member_validator=instance_of(Symmetry)), + kw_only=True, + ) + """Symmetries to be considered by the surrogate model.""" + _searchspace: SearchSpace | None = field(init=False, default=None, eq=False) """The search space on which the surrogate operates. Available after fitting.""" @@ -115,6 +125,27 @@ class Surrogate(ABC, SurrogateProtocol, SerialMixin): Scales a tensor containing target measurements in computational representation to make them digestible for the model-specific, scale-agnostic posterior logic.""" + def augment_measurements( + self, + measurements: pd.DataFrame, + parameters: Iterable[Parameter] | None = None, + ) -> pd.DataFrame: + """Apply data augmentation to measurements. + + Args: + measurements: A dataframe with measurements. + parameters: Parameter objects carrying additional information (might + not be needed by all augmentation implementations). + + Returns: + A dataframe with the augmented measurements, including the original + ones. + """ + for s in self.symmetries: + measurements = s.augment_measurements(measurements, parameters) + + return measurements + @classproperty def is_available(cls) -> bool: """Indicates if the surrogate class is available in the Python environment. 
diff --git a/baybe/surrogates/gaussian_process/core.py b/baybe/surrogates/gaussian_process/core.py index c0148aca55..bfffaf5a48 100644 --- a/baybe/surrogates/gaussian_process/core.py +++ b/baybe/surrogates/gaussian_process/core.py @@ -142,7 +142,7 @@ def _posterior(self, candidates_comp_scaled: Tensor, /) -> Posterior: @override def _fit(self, train_x: Tensor, train_y: Tensor) -> None: - import botorch + import botorch.models.transforms import gpytorch import torch from botorch.models.transforms import Normalize, Standardize From 68b46a992719168606fb3e81745a0c4299824817 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Fri, 20 Mar 2026 17:44:21 +0100 Subject: [PATCH 07/23] Update constraints for symmetry support --- baybe/constraints/base.py | 4 ---- baybe/constraints/conditions.py | 2 +- baybe/constraints/discrete.py | 40 ++++++++++++++++++++++++++------- baybe/searchspace/core.py | 5 ----- 4 files changed, 33 insertions(+), 18 deletions(-) diff --git a/baybe/constraints/base.py b/baybe/constraints/base.py index 5c1a6d33ed..b38548efcc 100644 --- a/baybe/constraints/base.py +++ b/baybe/constraints/base.py @@ -37,10 +37,6 @@ class Constraint(ABC, SerialMixin): eval_during_modeling: ClassVar[bool] """Class variable encoding whether the condition is evaluated during modeling.""" - eval_during_augmentation: ClassVar[bool] = False - """Class variable encoding whether the constraint could be considered during data - augmentation.""" - numerical_only: ClassVar[bool] = False """Class variable encoding whether the constraint is valid only for numerical parameters.""" diff --git a/baybe/constraints/conditions.py b/baybe/constraints/conditions.py index a136f6e170..3dd52b5d78 100644 --- a/baybe/constraints/conditions.py +++ b/baybe/constraints/conditions.py @@ -95,7 +95,7 @@ class Condition(ABC, SerialMixin): """Abstract base class for all conditions. Conditions always evaluate an expression regarding a single parameter. 
- Conditions are part of constraints, a constraint can have multiple conditions. + Conditions are part of constraints and symmetries. """ @abstractmethod diff --git a/baybe/constraints/discrete.py b/baybe/constraints/discrete.py index 740e603f89..91f5d8147d 100644 --- a/baybe/constraints/discrete.py +++ b/baybe/constraints/discrete.py @@ -29,6 +29,8 @@ if TYPE_CHECKING: import polars as pl + from baybe.symmetries import DependencySymmetry, PermutationSymmetry + @define class DiscreteExcludeConstraint(DiscreteConstraint): @@ -195,10 +197,6 @@ class DiscreteDependenciesConstraint(DiscreteConstraint): a single constraint. """ - # class variables - eval_during_augmentation: ClassVar[bool] = True - # See base class - # object variables conditions: list[Condition] = field() """The list of individual conditions.""" @@ -271,6 +269,24 @@ def get_invalid(self, data: pd.DataFrame) -> pd.Index: return inds_bad + def to_symmetries( + self, use_data_augmentation=True + ) -> tuple[DependencySymmetry, ...]: + """Convert to a :class:`~baybe.symmetries.DependencySymmetry`.""" + from baybe.symmetries import DependencySymmetry + + return tuple( + DependencySymmetry( + parameter_name=p, + condition=c, + affected_parameter_names=aps, + use_data_augmentation=use_data_augmentation, + ) + for p, c, aps in zip( + self.parameters, self.conditions, self.affected_parameters, strict=True + ) + ) + @define class DiscretePermutationInvarianceConstraint(DiscreteConstraint): @@ -285,10 +301,6 @@ class DiscretePermutationInvarianceConstraint(DiscreteConstraint): evaluated during modeling to make use of the invariance. 
""" - # class variables - eval_during_augmentation: ClassVar[bool] = True - # See base class - # object variables dependencies: DiscreteDependenciesConstraint | None = field(default=None) """Dependencies connected with the invariant parameters.""" @@ -337,6 +349,18 @@ def get_invalid(self, data: pd.DataFrame) -> pd.Index: return inds_invalid + def to_symmetry(self, use_data_augmentation=True) -> PermutationSymmetry: + """Convert to a :class:`~baybe.symmetries.PermutationSymmetry`.""" + from baybe.symmetries import PermutationSymmetry + + groups = [self.parameters] + if self.dependencies: + groups.append(list(self.dependencies.parameters)) + return PermutationSymmetry( + permutation_groups=groups, + use_data_augmentation=use_data_augmentation, + ) + @define class DiscreteCustomConstraint(DiscreteConstraint): diff --git a/baybe/searchspace/core.py b/baybe/searchspace/core.py index 8b0da30c92..34e5f64047 100644 --- a/baybe/searchspace/core.py +++ b/baybe/searchspace/core.py @@ -381,11 +381,6 @@ def transform( return comp_rep - @property - def constraints_augmentable(self) -> tuple[Constraint, ...]: - """The searchspace constraints that can be considered during augmentation.""" - return tuple(c for c in self.constraints if c.eval_during_augmentation) - def get_parameters_by_name(self, names: Sequence[str]) -> tuple[Parameter, ...]: """Return parameters with the specified names. 
From 241a9d3259ab8b036fbe0ea3478bce2717cc2195 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Fri, 20 Mar 2026 17:45:00 +0100 Subject: [PATCH 08/23] Add hypothesis strategies for symmetries and conditions --- tests/hypothesis_strategies/conditions.py | 24 ++++ tests/hypothesis_strategies/constraints.py | 23 +--- tests/hypothesis_strategies/symmetries.py | 121 +++++++++++++++++++++ 3 files changed, 149 insertions(+), 19 deletions(-) create mode 100644 tests/hypothesis_strategies/conditions.py create mode 100644 tests/hypothesis_strategies/symmetries.py diff --git a/tests/hypothesis_strategies/conditions.py b/tests/hypothesis_strategies/conditions.py new file mode 100644 index 0000000000..fe90efdaa8 --- /dev/null +++ b/tests/hypothesis_strategies/conditions.py @@ -0,0 +1,24 @@ +"""Hypothesis strategies for conditions.""" + +from typing import Any + +import hypothesis.strategies as st + +from baybe.constraints import SubSelectionCondition, ThresholdCondition +from tests.hypothesis_strategies.basic import finite_floats + + +def sub_selection_conditions(superset: list[Any] | None = None): + """Generate :class:`baybe.constraints.conditions.SubSelectionCondition`.""" + if superset is None: + element_strategy = st.text() + else: + element_strategy = st.sampled_from(superset) + return st.builds( + SubSelectionCondition, st.lists(element_strategy, unique=True, min_size=1) + ) + + +def threshold_conditions(): + """Generate :class:`baybe.constraints.conditions.ThresholdCondition`.""" + return st.builds(ThresholdCondition, threshold=finite_floats()) diff --git a/tests/hypothesis_strategies/constraints.py b/tests/hypothesis_strategies/constraints.py index e1f1014833..dfc065d817 100644 --- a/tests/hypothesis_strategies/constraints.py +++ b/tests/hypothesis_strategies/constraints.py @@ -1,14 +1,11 @@ """Hypothesis strategies for constraints.""" from functools import partial -from typing import Any import hypothesis.strategies as st from hypothesis import assume from 
baybe.constraints.conditions import ( - SubSelectionCondition, - ThresholdCondition, _valid_logic_combiners, ) from baybe.constraints.continuous import ( @@ -26,22 +23,10 @@ from baybe.parameters.base import DiscreteParameter from baybe.parameters.numerical import NumericalDiscreteParameter from tests.hypothesis_strategies.basic import finite_floats - - -def sub_selection_conditions(superset: list[Any] | None = None): - """Generate :class:`baybe.constraints.conditions.SubSelectionCondition`.""" - if superset is None: - element_strategy = st.text() - else: - element_strategy = st.sampled_from(superset) - return st.builds( - SubSelectionCondition, st.lists(element_strategy, unique=True, min_size=1) - ) - - -def threshold_conditions(): - """Generate :class:`baybe.constraints.conditions.ThresholdCondition`.""" - return st.builds(ThresholdCondition, threshold=finite_floats()) +from tests.hypothesis_strategies.conditions import ( + sub_selection_conditions, + threshold_conditions, +) @st.composite diff --git a/tests/hypothesis_strategies/symmetries.py b/tests/hypothesis_strategies/symmetries.py new file mode 100644 index 0000000000..31e45aaa40 --- /dev/null +++ b/tests/hypothesis_strategies/symmetries.py @@ -0,0 +1,121 @@ +"""Hypothesis strategies for symmetries.""" + +import hypothesis.strategies as st +from hypothesis import assume + +from baybe.parameters.base import Parameter +from baybe.symmetries import DependencySymmetry, MirrorSymmetry, PermutationSymmetry +from tests.hypothesis_strategies.basic import finite_floats +from tests.hypothesis_strategies.conditions import ( + sub_selection_conditions, + threshold_conditions, +) + + +@st.composite +def mirror_symmetries(draw: st.DrawFn, parameter_pool: list[Parameter] | None = None): + """Generate :class:`baybe.symmetries.MirrorSymmetry`.""" + if parameter_pool is None: + parameter_name = draw(st.text(min_size=1)) + else: + parameter = draw(st.sampled_from(parameter_pool)) + assume(parameter.is_numerical) + 
parameter_name = parameter.name + + return MirrorSymmetry( + parameter_name=parameter_name, + use_data_augmentation=draw(st.booleans()), + mirror_point=draw(finite_floats()), + ) + + +@st.composite +def permutation_symmetries( + draw: st.DrawFn, parameter_pool: list[Parameter] | None = None +): + """Generate :class:`baybe.symmetries.PermutationSymmetry`.""" + if parameter_pool is None: + parameter_names_pool = draw( + st.lists(st.text(min_size=1), unique=True, min_size=2) + ) + else: + parameter_names_pool = [p.name for p in parameter_pool] + + # Draw the first (required) group with at least 2 elements + first_group = draw( + st.lists(st.sampled_from(parameter_names_pool), min_size=2, unique=True).map( + tuple + ) + ) + group_size = len(first_group) + + # Determine how many additional groups we can have + remaining_names = [n for n in parameter_names_pool if n not in first_group] + max_additional = len(remaining_names) // group_size if group_size > 0 else 0 + n_additional = draw(st.integers(min_value=0, max_value=max_additional)) + + # Draw additional groups of the same size from remaining names + all_groups = [first_group] + for _ in range(n_additional): + group = draw( + st.lists( + st.sampled_from(remaining_names), + unique=True, + min_size=group_size, + max_size=group_size, + ).map(tuple) + ) + # Ensure no overlap with existing groups + for existing in all_groups: + assume(not set(group) & set(existing)) + remaining_names = [n for n in remaining_names if n not in group] + all_groups.append(group) + + return PermutationSymmetry( + permutation_groups=all_groups, + use_data_augmentation=draw(st.booleans()), + ) + + +@st.composite +def dependency_symmetries( + draw: st.DrawFn, parameter_pool: list[Parameter] | None = None +): + """Generate :class:`baybe.symmetries.DependencySymmetry`.""" + if parameter_pool is None: + parameter_name = draw(st.text(min_size=1)) + affected_strat = st.lists( + st.text(min_size=1).filter(lambda x: x != parameter_name), + min_size=1, + 
unique=True, + ).map(tuple) + else: + parameter = draw(st.sampled_from(parameter_pool)) + assume(parameter.is_discrete) + parameter_name = parameter.name + affected_strat = st.lists( + st.sampled_from(parameter_pool) + .filter(lambda x: x.name != parameter_name) + .map(lambda x: x.name), + unique=True, + min_size=1, + max_size=len(parameter_pool) - 1, + ).map(tuple) + + return DependencySymmetry( + parameter_name=parameter_name, + condition=draw(st.one_of(threshold_conditions(), sub_selection_conditions())), + affected_parameter_names=draw(affected_strat), + n_discretization_points=draw(st.integers(min_value=2)), + use_data_augmentation=draw(st.booleans()), + ) + + +symmetries = st.one_of( + [ + mirror_symmetries(), + permutation_symmetries(), + dependency_symmetries(), + ] +) +"""A strategy that generates symmetries.""" From d82f8de310e3acf51b62849b0ce7ec93f167988d Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Fri, 20 Mar 2026 17:45:37 +0100 Subject: [PATCH 09/23] Add symmetry tests --- tests/conftest.py | 2 +- .../test_symmetry_serialization.py | 28 +++ tests/test_measurement_augmentation.py | 143 +++++++++++ tests/validation/test_symmetry_validation.py | 227 ++++++++++++++++++ 4 files changed, 399 insertions(+), 1 deletion(-) create mode 100644 tests/serialization/test_symmetry_serialization.py create mode 100644 tests/test_measurement_augmentation.py create mode 100644 tests/validation/test_symmetry_validation.py diff --git a/tests/conftest.py b/tests/conftest.py index 882204c156..5cd9537b67 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -242,7 +242,7 @@ def fixture_parameters( NumericalDiscreteParameter( name="Fraction_1", values=tuple(np.linspace(0, 100, n_grid_points)), - tolerance=0.2, + tolerance=0.5, ), NumericalDiscreteParameter( name="Fraction_2", diff --git a/tests/serialization/test_symmetry_serialization.py b/tests/serialization/test_symmetry_serialization.py new file mode 100644 index 0000000000..e1f6ae9373 --- /dev/null +++ 
b/tests/serialization/test_symmetry_serialization.py @@ -0,0 +1,28 @@ +"""Symmetry serialization tests.""" + +import hypothesis.strategies as st +import pytest +from hypothesis import given +from pytest import param + +from tests.hypothesis_strategies.symmetries import ( + dependency_symmetries, + mirror_symmetries, + permutation_symmetries, +) +from tests.serialization.utils import assert_roundtrip_consistency + + +@pytest.mark.parametrize( + "strategy", + [ + param(mirror_symmetries(), id="MirrorSymmetry"), + param(permutation_symmetries(), id="PermutationSymmetry"), + param(dependency_symmetries(), id="DependencySymmetry"), + ], +) +@given(data=st.data()) +def test_roundtrip(strategy: st.SearchStrategy, data: st.DataObject): + """A serialization roundtrip yields an equivalent object.""" + symmetry = data.draw(strategy) + assert_roundtrip_consistency(symmetry) diff --git a/tests/test_measurement_augmentation.py b/tests/test_measurement_augmentation.py new file mode 100644 index 0000000000..fa7e76c289 --- /dev/null +++ b/tests/test_measurement_augmentation.py @@ -0,0 +1,143 @@ +"""Tests for augmentation of measurements.""" + +import math +from unittest.mock import patch + +import numpy as np +import pandas as pd +import pytest +from attrs import evolve +from pandas.testing import assert_frame_equal + +from baybe.acquisition import qLogEI +from baybe.constraints import ( + DiscretePermutationInvarianceConstraint, + ThresholdCondition, +) +from baybe.parameters import ( + CategoricalParameter, + NumericalContinuousParameter, + NumericalDiscreteParameter, +) +from baybe.recommenders import BotorchRecommender +from baybe.searchspace import SearchSpace +from baybe.symmetries import DependencySymmetry, MirrorSymmetry +from baybe.utils.dataframe import create_fake_input + + +@pytest.mark.parametrize("mirror_aug", [True, False], ids=["mirror", "nomirror"]) +@pytest.mark.parametrize("perm_aug", [True, False], ids=["perm", "noperm"]) +@pytest.mark.parametrize("dep_aug", 
[True, False], ids=["dep", "nodep"]) +@pytest.mark.parametrize( + "constraint_names", [["Constraint_11", "Constraint_7"]], ids=["c"] +) +@pytest.mark.parametrize( + "parameter_names", + [ + [ + "Solvent_1", + "Solvent_2", + "Solvent_3", + "Fraction_1", + "Fraction_2", + "Fraction_3", + "Num_disc_2", + ] + ], + ids=["p"], +) +def test_measurement_augmentation( + parameters, + surrogate_model, + objective, + constraints, + dep_aug, + perm_aug, + mirror_aug, +): + """Measurement augmentation is performed if configured.""" + original_to_botorch = qLogEI.to_botorch + called_args_list = [] + + def spy(self, *args, **kwargs): + called_args_list.append((args, kwargs)) + return original_to_botorch(self, *args, **kwargs) + + with patch.object(qLogEI, "to_botorch", side_effect=spy, autospec=True): + # Basic setup + c_perm = next( + c + for c in constraints + if isinstance(c, DiscretePermutationInvarianceConstraint) + ) + c_dep = c_perm.dependencies + s_perm = c_perm.to_symmetry(perm_aug) + s_deps = c_dep.to_symmetries(dep_aug) # this is a tuple of multiple + s_mirror = MirrorSymmetry("Num_disc_2", use_data_augmentation=mirror_aug) + searchspace = SearchSpace.from_product(parameters, constraints) + surrogate = evolve(surrogate_model, symmetries=[*s_deps, s_perm, s_mirror]) + recommender = BotorchRecommender( + surrogate_model=surrogate, acquisition_function=qLogEI() + ) + + # Perform call and watch measurements + measurements = create_fake_input(parameters, objective.targets, 5) + recommender.recommend(1, searchspace, objective, measurements) + measurements_passed = called_args_list[0][0][3] # take 4th arg from first call + + # Create expectation + # We calculate how many degenerate points the augmentation should create: + # - n_dep: Product of the number of active values for all affected parameters + # - n_perm: Number of permutations possible + # - n_mirror: 2 if the row is not on the mirror point, else 1 + # - If augmentation is turned off, the corresponding factor becomes 
1 + # We expect a given row to produce n_perm * (n_dep^k) * n_mirror points, where + # k is the number of "Fraction_*" parameters having the "causing" value 0.0. The + # total number of expected points after augmentation is the sum over the + # expectations for all rows. + dep_affected = [p for p in parameters if p.name in c_dep.affected_parameters[0]] + n_dep = math.prod(len(p.active_values) for p in dep_affected) if dep_aug else 1 + n_perm = ( # number of permutations + math.prod(range(1, len(c_perm.parameters) + 1)) if perm_aug else 1 + ) + n_expected = 0 + for _, row in measurements.iterrows(): + n_mirror = ( + 2 if (mirror_aug and row["Num_disc_2"] != s_mirror.mirror_point) else 1 + ) + k = row[c_dep.parameters].eq(0).sum() + n_expected += n_perm * np.pow(n_dep, k) * n_mirror + + # Check expectation + if any([dep_aug, perm_aug, mirror_aug]): + assert len(measurements_passed) == n_expected + else: + assert_frame_equal(measurements, measurements_passed) + + +@pytest.mark.parametrize("n_points", [2, 5]) +@pytest.mark.parametrize("mixed", [True, False]) +def test_continuous_dependency_augmentation(n_points, mixed): + """Dependency augmentation with continuous affected parameters works correctly.""" + df = pd.DataFrame({"n1": [1, 0], "cat1": ["a", "b"], "c1": [1, 2]}) + ps = [ + NumericalDiscreteParameter("n1", (0, 0.5, 1.0)), + CategoricalParameter("cat1", ("a", "b", "c")), + NumericalContinuousParameter("c1", (0, 10)), + ] + + # Model "Affected parameters only matter if n1 is > 0" + s = DependencySymmetry( + "n1", + condition=ThresholdCondition(0.0, ">"), + affected_parameter_names=("cat1", "c1") if mixed else ("c1",), + n_discretization_points=n_points, + ) + dfa = s.augment_measurements(df, ps) + + # Calculate expectation. The first row is not affected and contributes one point. + # The second row is degenerate and contributes `n_points` values for c1 and 3 + # values for cat1 (if part of the symmetry). 
+ n_expected = 1 + n_points * (3 if mixed else 1) + + assert len(dfa) == n_expected diff --git a/tests/validation/test_symmetry_validation.py b/tests/validation/test_symmetry_validation.py new file mode 100644 index 0000000000..a6c3749027 --- /dev/null +++ b/tests/validation/test_symmetry_validation.py @@ -0,0 +1,227 @@ +"""Validation tests for symmetry.""" + +import numpy as np +import pytest +from pytest import param + +from baybe.constraints import ThresholdCondition +from baybe.exceptions import IncompatibleSearchSpaceError +from baybe.parameters import ( + CategoricalParameter, + NumericalContinuousParameter, + NumericalDiscreteParameter, +) +from baybe.recommenders import BotorchRecommender +from baybe.searchspace import SearchSpace +from baybe.surrogates import GaussianProcessSurrogate +from baybe.symmetries import DependencySymmetry, MirrorSymmetry, PermutationSymmetry +from baybe.targets import NumericalTarget +from baybe.utils.dataframe import create_fake_input + +valid_config_mirror = {"parameter_name": "n1"} +valid_config_perm = { + "permutation_groups": [["cat1", "cat2"], ["n1", "n2"]], +} +valid_config_dep = { + "parameter_name": "n1", + "condition": ThresholdCondition(0.0, ">="), + "affected_parameter_names": ["n2", "cat1"], +} + + +@pytest.mark.parametrize( + "cls, config, error, msg", + [ + param( + MirrorSymmetry, + valid_config_mirror | {"mirror_point": np.inf}, + ValueError, + "values containing infinity/nan to attribute 'mirror_point': inf", + id="mirror_nonfinite", + ), + param( + PermutationSymmetry, + {"permutation_groups": [["cat1", "cat1"]]}, + ValueError, + r"the following group contains duplicates", + id="perm_not_unique", + ), + param( + PermutationSymmetry, + { + "permutation_groups": [ + ["cat1", "cat2", "cat3"], + ["n1", "n2", "n3", "n4"], + ] + }, + ValueError, + "must have the same length", + id="perm_different_lengths", + ), + param( + PermutationSymmetry, + {"permutation_groups": [["cat1", "cat2"], ["cat1", "n2"]]}, + ValueError, 
+ r"following parameter names appear in several groups", + id="perm_overlap", + ), + param( + PermutationSymmetry, + {"permutation_groups": [["cat1"]]}, + ValueError, + "must be >= 2", + id="perm_group_too_small", + ), + param( + PermutationSymmetry, + {"permutation_groups": []}, + ValueError, + "must be >= 1", + id="perm_no_groups", + ), + param( + PermutationSymmetry, + {"permutation_groups": [[1, 2]]}, + TypeError, + "must be ", + id="perm_not_str", + ), + param( + DependencySymmetry, + valid_config_dep | {"parameter_name": 1}, + TypeError, + "must be ", + id="dep_param_not_str", + ), + param( + DependencySymmetry, + valid_config_dep | {"condition": 1}, + TypeError, + "must be ", + id="dep_wrong_cond_type", + ), + param( + DependencySymmetry, + valid_config_dep | {"affected_parameter_names": []}, + ValueError, + "Length of 'affected_parameter_names' must be >= 1", + id="dep_affected_empty", + ), + param( + DependencySymmetry, + valid_config_dep | {"affected_parameter_names": [1]}, + TypeError, + "must be ", + id="dep_affected_wrong_type", + ), + param( + DependencySymmetry, + valid_config_dep | {"affected_parameter_names": ["a1", "a1"]}, + ValueError, + r"Entries appearing multiple times: \['a1'\].", + id="dep_affected_not_unique", + ), + ], +) +def test_configuration(cls, config, error, msg): + """Invalid configurations raise an expected error.""" + with pytest.raises(error, match=msg): + cls(**config) + + +_parameters = [ + NumericalDiscreteParameter("n1", (-1, 0, 1)), + NumericalDiscreteParameter("n2", (-1, 0, 1)), + NumericalContinuousParameter("n1_not_discrete", (0.0, 10.0)), + NumericalContinuousParameter("n2_not_discrete", (0.0, 10.0)), + NumericalContinuousParameter("c1", (0.0, 10.0)), + NumericalContinuousParameter("c2", (0.0, 10.0)), + CategoricalParameter("cat1", ("a", "b", "c")), + CategoricalParameter("cat1_altered", ("a", "b")), + CategoricalParameter("cat2", ("a", "b", "c")), +] + + +@pytest.fixture +def searchspace(parameter_names): + ps = 
tuple(p for p in _parameters if p.name in parameter_names) + return SearchSpace.from_product(ps) + + +@pytest.mark.parametrize( + "parameter_names, symmetry, error, msg", + [ + param( + ["cat1"], + MirrorSymmetry(parameter_name="cat1"), + TypeError, + "'cat1' is of type 'CategoricalParameter' and is not numerical", + id="mirror_not_numerical", + ), + param( + ["n1"], + MirrorSymmetry(parameter_name="n2"), + IncompatibleSearchSpaceError, + r"not present in the search space", + id="mirror_param_missing", + ), + param( + ["n2", "cat1"], + DependencySymmetry(**valid_config_dep), + IncompatibleSearchSpaceError, + r"not present in the search space", + id="dep_causing_missing", + ), + param( + ["n1", "cat1"], + DependencySymmetry(**valid_config_dep), + IncompatibleSearchSpaceError, + r"not present in the search space", + id="dep_affected_missing", + ), + param( + ["n1_not_discrete", "n2", "cat1"], + DependencySymmetry( + **valid_config_dep | {"parameter_name": "n1_not_discrete"} + ), + TypeError, + "must be discrete. 
However, the parameter 'n1_not_discrete'", + id="dep_causing_not_discrete", + ), + param( + ["cat1", "n1", "n2"], + PermutationSymmetry(**valid_config_perm), + IncompatibleSearchSpaceError, + r"not present in the search space", + id="perm_not_present", + ), + param( + ["cat1", "cat2", "n1", "n2"], + PermutationSymmetry(permutation_groups=[("cat1", "n1"), ("cat2", "n2")]), + ValueError, + r"differ in their specification", + id="perm_inconsistent_types", + ), + param( + ["cat1_altered", "cat2", "n1", "n2"], + PermutationSymmetry( + permutation_groups=[["cat1_altered", "cat2"], ["n1", "n2"]] + ), + ValueError, + r"differ in their specification", + id="perm_inconsistent_values", + ), + ], +) +def test_searchspace_context(searchspace, symmetry, error, msg): + """Configurations not compatible with the searchspace raise an expected error.""" + recommender = BotorchRecommender( + surrogate_model=GaussianProcessSurrogate(symmetries=(symmetry,)) + ) + t = NumericalTarget("t") + measurements = create_fake_input(searchspace.parameters, [t]) + + with pytest.raises(error, match=msg): + recommender.recommend( + 1, searchspace, t.to_objective(), measurements=measurements + ) From 3d08a209fbc1820755e15c103d915c2a8fb2fbf9 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Fri, 20 Mar 2026 17:46:56 +0100 Subject: [PATCH 10/23] Add symmetry documentation --- docs/_static/symmetries/augmentation.svg | 162 +++++++++++++++++++++++ docs/scripts/build_examples.py | 10 +- docs/userguide/constraints.md | 2 + docs/userguide/surrogates.md | 18 ++- docs/userguide/symmetries.md | 51 +++++++ docs/userguide/userguide.md | 1 + 6 files changed, 240 insertions(+), 4 deletions(-) create mode 100644 docs/_static/symmetries/augmentation.svg create mode 100644 docs/userguide/symmetries.md diff --git a/docs/_static/symmetries/augmentation.svg b/docs/_static/symmetries/augmentation.svg new file mode 100644 index 0000000000..ac1ffefeeb --- /dev/null +++ b/docs/_static/symmetries/augmentation.svg @@ -0,0 
+1,162 @@ + + + + + + + + + + + + + + + + + + + original points + + augmented points + guide line + + + + + + + + + + + x + y + + + + PermutationSymmetry + + + + + mirror_point + + + x + y + + + + + MirrorSymmetry + + + + + + + + + + + + DependencySymmetry with c(x) = x < p + discrete y + + + + p + + + + + + x + y + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + y.values + + continuous y with n_discretization_points=4 + + p + + + + + + + x + y + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + y.bounds.upper + y.bounds.lower + + + diff --git a/docs/scripts/build_examples.py b/docs/scripts/build_examples.py index bfd62db2d5..ddf9be318c 100644 --- a/docs/scripts/build_examples.py +++ b/docs/scripts/build_examples.py @@ -31,8 +31,12 @@ def build_examples(destination_directory: Path, dummy: bool, remove_dir: bool): else: raise OSError("Destination directory exists but should not be removed.") - # Copy the examples folder in the destination directory - shutil.copytree("examples", destination_directory) + # Copy the examples folder in the destination directory. "__pycache__" might be + # present in the examples folder and needs to be ignored + def ignore_pycache(_, contents: list[str]): + return [item for item in contents if item == "__pycache__"] + + shutil.copytree("examples", destination_directory, ignore=ignore_pycache) # For the toctree of the top level example folder, we need to keep track of all # folders. We thus write the header here and populate it during the execution of the @@ -44,7 +48,7 @@ def build_examples(destination_directory: Path, dummy: bool, remove_dir: bool): # This list contains the order of the examples as we want to have them in the end. # The examples that should be the first ones are already included here and skipped - # later on. ALl other are just included. + # later on. All others are just included. 
ex_order = [ "Basics\n", "Searchspaces\n", diff --git a/docs/userguide/constraints.md b/docs/userguide/constraints.md index e66be3051c..535320d399 100644 --- a/docs/userguide/constraints.md +++ b/docs/userguide/constraints.md @@ -399,6 +399,7 @@ DiscreteDependenciesConstraint( ``` An end to end example can be found [here](../../examples/Constraints_Discrete/dependency_constraints). +For more information about the possibility of data augmentation, see [here](surrogate_data_augmentation). ### DiscretePermutationInvarianceConstraint Permutation invariance, enabled by the @@ -480,6 +481,7 @@ DiscretePermutationInvarianceConstraint( The usage of `DiscretePermutationInvarianceConstraint` is also part of the [example on slot-based mixtures](../../examples/Mixtures/slot_based). +For more information about the possibility of data augmentation, see [here](surrogate_data_augmentation). ### DiscreteCardinalityConstraint Like its [continuous cousin](#ContinuousCardinalityConstraint), the diff --git a/docs/userguide/surrogates.md b/docs/userguide/surrogates.md index 781051573f..7464cb2866 100644 --- a/docs/userguide/surrogates.md +++ b/docs/userguide/surrogates.md @@ -6,6 +6,7 @@ the utilization of custom models. All surrogate models are based upon the genera [`Surrogate`](baybe.surrogates.base.Surrogate) class. Some models even support transfer learning, as indicated by the `supports_transfer_learning` attribute. + ## Available Models BayBE provides a comprehensive selection of surrogate models, empowering you to choose @@ -93,7 +94,6 @@ A noticeable difference to the replication approach is that manual assembly requ the exact set of target variables to be known at the time the object is created. 
- ## Extracting the Model for Advanced Study In principle, the surrogate model does not need to be a persistent object during @@ -117,6 +117,22 @@ shap_values = explainer(data) shap.plots.bar(shap_values) ~~~ + +(surrogate_data_augmentation)= +## Data Augmentation +In certain situations like [mixture modeling](/examples/Mixtures/slot_based), +symmetries are present. Data augmentation is a model-agnostic way of enabling the +surrogate model to learn such symmetries effectively, which might result in a better +performance, similar to, e.g., image classification models. BayBE +[`Surrogate`](baybe.surrogates.base.Surrogate) models automatically perform data +augmentation if +{attr}`~baybe.surrogates.base.Surrogate.symmetries` with +`use_data_augmentation=True` are present. This means you can add a data point in +any acceptable representation and BayBE will train the model on this point plus +augmented points that can be generated from it. To see the effect in practice, refer to +[this example](/examples/Symmetries/permutation). + + ## Using Custom Models BayBE goes one step further by allowing you to incorporate custom models based on the diff --git a/docs/userguide/symmetries.md b/docs/userguide/symmetries.md new file mode 100644 index 0000000000..aa8c085202 --- /dev/null +++ b/docs/userguide/symmetries.md @@ -0,0 +1,51 @@ +# Symmetry +{class}`~baybe.symmetries.Symmetry` is a concept tied to the structure of the searchspace. +It is thus closely related to a {class}`~baybe.constraints.base.Constraint`, but has a +different purpose in BayBE. If the searchspace is symmetric in any sense, you can +exclude the degenerate parts via a constraint. But this would not change the modeling +process. The role of a {class}`~baybe.symmetries.Symmetry` is exactly this: Influence how +the surrogate model is constructed to include the knowledge about the symmetry. This +can be applied independently of constraints. 
For an example of the influence of +symmetries and constraints on the optimization of a permutation invariant function, +[see here](/examples/Symmetries/permutation). + +## Definitions +The following table summarizes available symmetries in BayBE: + +| Symmetry | Functional Definition | Corresponding Constraint | +|:-----------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------| +| {class}`~baybe.symmetries.PermutationSymmetry` | $f(x,y) = f(y,x)$ | {class}`~baybe.constraints.discrete.DiscretePermutationInvarianceConstraint` | +| {class}`~baybe.symmetries.DependencySymmetry` | $f(x,y) = \begin{cases}f(x,y) & \text{if }c(x) \\f(x) & \text{otherwise}\end{cases}$
where $c(x)$ is a condition that is either true or false | {class}`~baybe.constraints.discrete.DiscreteDependenciesConstraint` | +| {class}`~baybe.symmetries.MirrorSymmetry` | $f(x,y) = f(-x,y)$ | No constraint is available. Instead, the number range for that parameter can simply be restricted. | + +## Data Augmentation +This can be a powerful tool to improve the modeling process. Data augmentation +essentially changes the data that the model is fitted on by adding more points. The +augmented points are constructed such that they represent a symmetric point compared +with their original, which always corresponds to a different transformation depending +on which symmetry is responsible. + +If the surrogate model receives such augmented points, it can learn the symmetry. This +has the advantage that it can improve predictions for unseen points and is fully +model-agnostic. Downsides are increased training time and potential computational +challenges arising from a fit on substantially more points. It is thus possible to +control the data augmentation behavior of any {class}`~baybe.symmetries.Symmetry` by +setting its {attr}`~baybe.symmetries.Symmetry.use_data_augmentation` attribute +(`True` by default). + +Below we illustrate the effect of data augmentation for the different symmetries +supported by BayBE: + +![Symmetry and Data Augmentation](../_static/symmetries/augmentation.svg) + +## Invariant Kernels +Some machine learning models can be constructed with architectures that automatically +respect a symmetry, i.e. applying the model to an augmented point always produces the +same output as the original point by construction. + +For Gaussian processes, this can be achieved by applying special kernels. +```{admonition} Not Implemented Yet +:class: warning +Ideally, invariant kernels will be applied automatically when a corresponding symmetry has been +configured for the surrogate model GP. This feature is not implemented yet. 
+``` \ No newline at end of file diff --git a/docs/userguide/userguide.md b/docs/userguide/userguide.md index 059f6ca581..cefb6516f3 100644 --- a/docs/userguide/userguide.md +++ b/docs/userguide/userguide.md @@ -44,6 +44,7 @@ Serialization Settings Simulation Surrogates +Symmetries Targets Transformations Transfer Learning From 49a046e6a93feb45958c26cec640ca204037a8da Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Fri, 20 Mar 2026 17:47:06 +0100 Subject: [PATCH 11/23] Add symmetry example --- examples/Symmetries/Symmetries_Header.md | 3 + examples/Symmetries/permutation.py | 215 +++++ examples/Symmetries/permutation.svg | 957 +++++++++++++++++++++++ 3 files changed, 1175 insertions(+) create mode 100644 examples/Symmetries/Symmetries_Header.md create mode 100644 examples/Symmetries/permutation.py create mode 100644 examples/Symmetries/permutation.svg diff --git a/examples/Symmetries/Symmetries_Header.md b/examples/Symmetries/Symmetries_Header.md new file mode 100644 index 0000000000..361d774e3d --- /dev/null +++ b/examples/Symmetries/Symmetries_Header.md @@ -0,0 +1,3 @@ +# Symmetries + +These examples demonstrate the impact of Symmetries. \ No newline at end of file diff --git a/examples/Symmetries/permutation.py b/examples/Symmetries/permutation.py new file mode 100644 index 0000000000..b8e86801e6 --- /dev/null +++ b/examples/Symmetries/permutation.py @@ -0,0 +1,215 @@ +# # Optimizing a Permutation-Invariant Function + +# In this example, we explore BayBE's capabilities for handling optimization problems +# with symmetry via automatic data augmentation and / or constraint. 
+ +# ## Imports + +import os + +import numpy as np +import pandas as pd +import seaborn as sns +from matplotlib import pyplot as plt +from matplotlib.ticker import MaxNLocator + +from baybe import Campaign +from baybe.constraints import DiscretePermutationInvarianceConstraint +from baybe.parameters import NumericalDiscreteParameter +from baybe.recommenders import ( + BotorchRecommender, + TwoPhaseMetaRecommender, +) +from baybe.searchspace import SearchSpace +from baybe.simulation import simulate_scenarios +from baybe.surrogates import NGBoostSurrogate +from baybe.targets import NumericalTarget +from baybe.utils.random import set_random_seed + +# ## Settings + +set_random_seed(1337) +SMOKE_TEST = "SMOKE_TEST" in os.environ +N_MC_ITERATIONS = 2 if SMOKE_TEST else 100 +N_DOE_ITERATIONS = 2 if SMOKE_TEST else 50 + +# ## The Scenario + +# We will explore a 2-dimensional permutation-invariant function, i.e., it holds that +# $f(x,y) = f(y,x)$. The function was crafted to exhibit no additional mirror symmetry +# (a common way of also resulting in permutation invariance) and have multiple minima. +# In practice, permutation invariance can arise e.g. for +# [mixtures when modeled with a slot-based approach](/examples/Mixtures/slot_based). + +# There are several ways to handle such symmetries. The simplest one is +# to augment your data. In the case of permutation invariance, augmentation means for +# each measurement $(x,y)$ you also add a measurement with switched values, i.e., $(y,x)$. +# This has the advantage that it is fully model-agnostic, but might +# come at the expense of increased training time and efficiency due to the larger amount +# of effective training points. 
+ + +LBOUND = -2.0 +UBOUND = 2.0 + + +def lookup(df: pd.DataFrame) -> pd.DataFrame: + """A lookup modeling a permutation-invariant 2D function with multiple minima.""" + x = df["x"].values + y = df["y"].values + result = ( + (x - y) ** 2 + + (x**3 + y**3) + + ((x**2 - 1) ** 2 + (y**2 - 1) ** 2) + + np.sin(3 * (x + y)) ** 2 + + np.sin(3 * np.abs(x - y)) ** 2 + ) + + df_z = pd.DataFrame({"f": result}, index=df.index) + return df_z + + +# Grid and dataframe for plotting +x = np.linspace(LBOUND, UBOUND, 25) +y = np.linspace(LBOUND, UBOUND, 25) +xx, yy = np.meshgrid(x, y) +df_plot = lookup(pd.DataFrame({"x": xx.ravel(), "y": yy.ravel()})) +zz = df_plot["f"].values.reshape(xx.shape) +line_vals = np.linspace(LBOUND, UBOUND, 2) + +# Plot the contour and diagonal +# fmt: off +fig, axs = plt.subplots(1, 2, figsize=(15, 6)) +contour = axs[0].contourf(xx, yy, zz, levels=50, cmap="viridis") +fig.colorbar(contour, ax=axs[0]) +axs[0].plot(line_vals, line_vals, "r--", alpha=0.5, linewidth=2) +axs[0].set_title("Ground Truth: $f(x, y)$ = $f(y, x)$ (Permutation Invariant)") +axs[0].set_xlabel("x") +axs[0].set_ylabel("y"); +# fmt: on + +# The plots can be found at the bottom of this file. +# The first subplot shows the function we want to minimize. The dashed red line +# illustrates the permutation invariance, which is similar to a mirror-symmetry, just +# not along any of the parameter axes but along the diagonal. We can also see several +# local minima. + +# Such a situation can be challenging for optimization algorithms if no information +# about the invariance is considered. For instance, if no +# {class}`~baybe.constraints.discrete.DiscretePermutationInvarianceConstraint` was used +# at all, BayBE would search for the optima across the entire 2D space. But it is clear +# that the search can be restricted to the lower (or equivalently the upper) triangle +# of the searchspace. 
This is exactly what +# {class}`~baybe.constraints.discrete.DiscretePermutationInvarianceConstraint` does: +# Remove entries that are "duplicated" in the sense of already being represented by +# another invariant point. + +# If the surrogate is additionally configured with `symmetries` that use +# `use_data_augmentation=True`, the model will be fit with an extended set of points, +# including augmented ones. So as a user, you don't have to generate permutations and +# add them manually. Depending on the surrogate model, this might have different +# impacts. For example, we can expect a strong effect for tree-based models because their splits are +# always parallel to the parameter axes. Thus, without augmented measurements, it is +# easy to fall into suboptimal splits and overfit. We illustrate this by using the +# {class}`~baybe.surrogates.ngboost.NGBoostSurrogate`. + +# ## The Optimization Problem + +p1 = NumericalDiscreteParameter("x", np.linspace(LBOUND, UBOUND, 51)) +p2 = NumericalDiscreteParameter("y", np.linspace(LBOUND, UBOUND, 51)) +objective = NumericalTarget("f", minimize=True).to_objective() + +# We set up a constrained and an unconstrained (plain) searchspace to demonstrate the +# impact of the constraint on optimization performance. + +constraint = DiscretePermutationInvarianceConstraint(["x", "y"]) +searchspace_plain = SearchSpace.from_product([p1, p2]) +searchspace_constrained = SearchSpace.from_product([p1, p2], [constraint]) + +print("Number of Points in the Searchspace") +print(f"{'Without Constraint:':<35} {len(searchspace_plain.discrete.exp_rep)}") +print(f"{'With Constraint:':<35} {len(searchspace_constrained.discrete.exp_rep)}") + +# We can see that the unconstrained searchspace has roughly twice as many points +# compared to the constrained one. This is expected, as the +# {class}`~baybe.constraints.discrete.DiscretePermutationInvarianceConstraint` +# effectively models only one half of the parameter triangle. 
Note that the factor is +# not exactly 2 due to the (still included) points on the diagonal. + +# BayBE can automatically perform this augmentation if configured to do so. +# Specifically, surrogate models have the +# {attr}`~baybe.surrogates.base.Surrogate.symmetries` attribute. If any of +# these symmetries has `use_data_augmentation=True` (enabled by default), +# BayBE will automatically augment measurements internally before performing the model +# fit. To construct symmetries quickly, we use the `to_symmetry` method of the +# constraint. + +symmetry = constraint.to_symmetry(use_data_augmentation=True) +recommender_plain = TwoPhaseMetaRecommender( + recommender=BotorchRecommender(surrogate_model=NGBoostSurrogate()) +) +recommender_symmetric = TwoPhaseMetaRecommender( + recommender=BotorchRecommender( + surrogate_model=NGBoostSurrogate(symmetries=[symmetry]) + ) +) + +# The combination of constraint and augmentation settings results in four different +# campaigns: + +campaign_plain = Campaign(searchspace_plain, objective, recommender_plain) +campaign_c = Campaign(searchspace_constrained, objective, recommender_plain) +campaign_s = Campaign(searchspace_plain, objective, recommender_symmetric) +campaign_cs = Campaign(searchspace_constrained, objective, recommender_symmetric) + +# ## Simulating the Optimization Loop + + +scenarios = { + "Unconstrained, Unsymmetric": campaign_plain, + "Constrained, Unsymmetric": campaign_c, + "Unconstrained, Symmetric": campaign_s, + "Constrained, Symmetric": campaign_cs, +} + +results = simulate_scenarios( + scenarios, + lookup, + n_doe_iterations=N_DOE_ITERATIONS, + n_mc_iterations=N_MC_ITERATIONS, +).rename( + columns={ + "f_CumBest": "$f(x,y)$ (cumulative best)", + "Num_Experiments": "# Experiments", + } +) + +# ## Results + +# Let us visualize the optimization process in the second subplot: + +sns.lineplot( + data=results, + x="# Experiments", + y="$f(x,y)$ (cumulative best)", + hue="Scenario", + marker="o", + ax=axs[1], +) 
+axs[1].xaxis.set_major_locator(MaxNLocator(integer=True)) +axs[1].set_ylim(axs[1].get_ylim()[0], 3) +axs[1].set_title("Minimization Performance") +plt.tight_layout() +plt.show() + +# We find that the campaigns utilizing the permutation invariance constraint +# perform better than the ones without. This can be attributed simply to the reduced +# number of searchspace points they operate on. However, this effect is rather minor +# compared to the effect of symmetry. + +# Furthermore, there is a strong impact on whether data augmentation is used or not, +# the effect we expected for a tree-based surrogate model. Indeed, the campaign with +# constraint but without augmentation is barely better than the campaign not utilizing +# the constraint at all. Conversely, the data-augmented campaign has a clearly superior +# performance. The best result is achieved by using both constraints and data +# augmentation. diff --git a/examples/Symmetries/permutation.svg b/examples/Symmetries/permutation.svg new file mode 100644 index 0000000000..c3c84a909c --- /dev/null +++ b/examples/Symmetries/permutation.svg @@ -0,0 +1,957 @@ + + + + + + + 2025-10-31T12:45:43.974408 + image/svg+xml + + + Matplotlib v3.10.7, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + From bfb9fe9a5e05ac33da4da197b27f778b939d18c3 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Fri, 20 Mar 2026 17:53:37 +0100 Subject: [PATCH 12/23] Handle CompositeSurrogate in symmetry integration The _autoreplicate converter on main wraps surrogates in a CompositeSurrogate. Access the inner template for symmetry validation and augmentation. 
--- baybe/recommenders/pure/bayesian/base.py | 28 +++++++++++++++++--- tests/validation/test_symmetry_validation.py | 2 +- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/baybe/recommenders/pure/bayesian/base.py b/baybe/recommenders/pure/bayesian/base.py index 093fd89600..2d06db2edc 100644 --- a/baybe/recommenders/pure/bayesian/base.py +++ b/baybe/recommenders/pure/bayesian/base.py @@ -86,6 +86,22 @@ def _get_acquisition_function(self, objective: Objective) -> AcquisitionFunction return qLogNEHVI() if objective.is_multi_output else qLogEI() return self.acquisition_function + def _get_surrogate_for_augmentation(self) -> Surrogate | None: + """Get the Surrogate instance for augmentation/validation, if available.""" + from baybe.surrogates.composite import CompositeSurrogate, _ReplicationMapping + + model = self._surrogate_model + if isinstance(model, Surrogate): + return model + if isinstance(model, CompositeSurrogate): + # All inner surrogates are copies of the same template + surrogates = model.surrogates + if isinstance(surrogates, _ReplicationMapping): + template = surrogates.template + if isinstance(template, Surrogate): + return template + return None + def get_surrogate( self, searchspace: SearchSpace, @@ -115,8 +131,11 @@ def _setup_botorch_acqf( ) # Perform data augmentation if configured - if hasattr(s := self._surrogate_model, "augment_measurements"): - measurements = s.augment_measurements(measurements, searchspace.parameters) + surrogate_for_augmentation = self._get_surrogate_for_augmentation() + if surrogate_for_augmentation is not None: + measurements = surrogate_for_augmentation.augment_measurements( + measurements, searchspace.parameters + ) surrogate = self.get_surrogate(searchspace, objective, measurements) self._botorch_acqf = acqf.to_botorch( @@ -161,8 +180,9 @@ def recommend( validate_object_names(searchspace.parameters + objective.targets) # Validate compatibility of surrogate symmetries with searchspace - if 
hasattr(self._surrogate_model, "symmetries"): - for s in self._surrogate_model.symmetries: + surrogate_for_validation = self._get_surrogate_for_augmentation() + if surrogate_for_validation is not None: + for s in surrogate_for_validation.symmetries: s.validate_searchspace_context(searchspace) # Experimental input validation diff --git a/tests/validation/test_symmetry_validation.py b/tests/validation/test_symmetry_validation.py index a6c3749027..52311dd11d 100644 --- a/tests/validation/test_symmetry_validation.py +++ b/tests/validation/test_symmetry_validation.py @@ -36,7 +36,7 @@ MirrorSymmetry, valid_config_mirror | {"mirror_point": np.inf}, ValueError, - "values containing infinity/nan to attribute 'mirror_point': inf", + "values containing infinity/nan to 'mirror_point': inf", id="mirror_nonfinite", ), param( From 7b9e1348fd8afed80b0f4ff4a0c6311fc9d1614f Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Fri, 20 Mar 2026 17:55:14 +0100 Subject: [PATCH 13/23] Fix mypy errors in categorical validator and dependency type ignore --- baybe/parameters/categorical.py | 2 +- baybe/symmetries/dependency.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/baybe/parameters/categorical.py b/baybe/parameters/categorical.py index 85058f6b29..7ebe673f76 100644 --- a/baybe/parameters/categorical.py +++ b/baybe/parameters/categorical.py @@ -35,7 +35,7 @@ class CategoricalParameter(_DiscreteLabelLikeParameter): converter=Converter( # type: ignore[misc,call-overload] # mypy: Converter normalize_convertible2str_sequence, takes_self=True, takes_field=True ), - validator=( + validator=( # type: ignore[arg-type] # mypy: validator tuple validate_unique_values, deep_iterable( member_validator=(instance_of((str, bool)), _validate_label_min_len), diff --git a/baybe/symmetries/dependency.py b/baybe/symmetries/dependency.py index 3fdadaeb94..9bd72ab8bf 100644 --- a/baybe/symmetries/dependency.py +++ b/baybe/symmetries/dependency.py @@ -45,7 +45,7 @@ class 
DependencySymmetry(Symmetry): converter=Converter( # type: ignore[misc,call-overload] # mypy: Converter normalize_convertible2str_sequence, takes_self=True, takes_field=True ), - validator=( # type: ignore + validator=( validate_unique_values, deep_iterable( member_validator=instance_of(str), iterable_validator=min_len(1) From fb21269514bd0cc9b040eeeba07d70187940b3b2 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Fri, 20 Mar 2026 18:28:12 +0100 Subject: [PATCH 14/23] Add symmetry validation tests --- tests/validation/test_symmetry_validation.py | 64 ++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/tests/validation/test_symmetry_validation.py b/tests/validation/test_symmetry_validation.py index 52311dd11d..b4093eb77f 100644 --- a/tests/validation/test_symmetry_validation.py +++ b/tests/validation/test_symmetry_validation.py @@ -1,6 +1,7 @@ """Validation tests for symmetry.""" import numpy as np +import pandas as pd import pytest from pytest import param @@ -121,6 +122,55 @@ r"Entries appearing multiple times: \['a1'\].", id="dep_affected_not_unique", ), + param( + PermutationSymmetry, + {"permutation_groups": "abc"}, + ValueError, + "must be a sequence of sequences, not a string", + id="perm_groups_bare_string", + ), + param( + PermutationSymmetry, + {"permutation_groups": ["abc", "def"]}, + ValueError, + "must be a sequence of parameter names, not a string", + id="perm_groups_inner_bare_string", + ), + param( + MirrorSymmetry, + {"parameter_name": 123}, + TypeError, + "must be ", + id="mirror_param_not_str", + ), + param( + DependencySymmetry, + valid_config_dep | {"n_discretization_points": 3.5}, + TypeError, + "must be ", + id="dep_n_discretization_not_int", + ), + param( + DependencySymmetry, + valid_config_dep | {"n_discretization_points": 1}, + ValueError, + "must be >= 2", + id="dep_n_discretization_too_small", + ), + param( + DependencySymmetry, + valid_config_dep | {"affected_parameter_names": "abc"}, + ValueError, + "must be a 
sequence but cannot be a string", + id="dep_affected_bare_string", + ), + param( + PermutationSymmetry, + {"permutation_groups": [["a", "b"]], "use_data_augmentation": 1}, + TypeError, + "must be ", + id="use_aug_not_bool", + ), ], ) def test_configuration(cls, config, error, msg): @@ -225,3 +275,17 @@ def test_searchspace_context(searchspace, symmetry, error, msg): recommender.recommend( 1, searchspace, t.to_objective(), measurements=measurements ) + + +def test_dependency_augmentation_requires_parameters(): + """DependencySymmetry.augment_measurements raises when parameters is None.""" + s = DependencySymmetry(**valid_config_dep) + df = pd.DataFrame({"n1": [0], "n2": [1], "cat1": ["a"]}) + with pytest.raises(ValueError, match="requires parameter objects"): + s.augment_measurements(df) + + +def test_surrogate_rejects_non_symmetry(): + """Surrogate.symmetries rejects non-Symmetry members.""" + with pytest.raises(TypeError, match="must be Date: Fri, 20 Mar 2026 18:28:37 +0100 Subject: [PATCH 15/23] Update CHANGELOG --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3ab1fa1c52..fd632487a8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,10 +13,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `identify_non_dominated_configurations` method to `Campaign` and `Objective` for determining the Pareto front - Interpoint constraints for continuous search spaces +- Symmetry classes (`PermutationSymmetry`, `MirrorSymmetry`, `DependencySymmetry`) + for expressing invariances and configuring surrogate data augmentation +- `Parameter.is_equivalent` method for structural parameter comparison ### Breaking Changes - `ContinuousLinearConstraint.to_botorch` now returns a collection of constraint tuples instead of a single tuple (needed for interpoint constraints) +- `df_apply_permutation_augmentation` has a different interface and now expects + permutation groups instead of column groups 
### Fixed - `SHAPInsight` breaking with `numpy>=2.4` due to no longer accepted implicit array to From 57d03e954857aaa73615b0ff389bf1653f047eff Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Fri, 20 Mar 2026 19:07:04 +0100 Subject: [PATCH 16/23] Replace deprecated set_random_seed with Settings in example --- examples/Symmetries/permutation.py | 5 +- examples/Symmetries/permutation.svg | 8442 ++++++++++++++++++++++++--- 2 files changed, 7507 insertions(+), 940 deletions(-) diff --git a/examples/Symmetries/permutation.py b/examples/Symmetries/permutation.py index b8e86801e6..aa338dd1ef 100644 --- a/examples/Symmetries/permutation.py +++ b/examples/Symmetries/permutation.py @@ -13,7 +13,7 @@ from matplotlib import pyplot as plt from matplotlib.ticker import MaxNLocator -from baybe import Campaign +from baybe import Campaign, Settings from baybe.constraints import DiscretePermutationInvarianceConstraint from baybe.parameters import NumericalDiscreteParameter from baybe.recommenders import ( @@ -24,11 +24,10 @@ from baybe.simulation import simulate_scenarios from baybe.surrogates import NGBoostSurrogate from baybe.targets import NumericalTarget -from baybe.utils.random import set_random_seed # ## Settings -set_random_seed(1337) +Settings(random_seed=1337).activate() SMOKE_TEST = "SMOKE_TEST" in os.environ N_MC_ITERATIONS = 2 if SMOKE_TEST else 100 N_DOE_ITERATIONS = 2 if SMOKE_TEST else 50 diff --git a/examples/Symmetries/permutation.svg b/examples/Symmetries/permutation.svg index c3c84a909c..ab17b0ac9a 100644 --- a/examples/Symmetries/permutation.svg +++ b/examples/Symmetries/permutation.svg @@ -1,14 +1,16 @@ - - + + + - + - 2025-10-31T12:45:43.974408 + 2026-03-20T19:16:57.550802 image/svg+xml - Matplotlib v3.10.7, https://matplotlib.org/ + Matplotlib v3.10.8, https://matplotlib.org/ @@ -17,941 +19,7507 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + - - + + From e8a97e86b58d2f013141aa535229f82ecdafc78c Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Fri, 20 Mar 2026 19:41:39 +0100 Subject: [PATCH 17/23] Fix Sphinx cross-references for symmetry classes Use full module paths (e.g., baybe.symmetries.base.Symmetry) instead of short paths via __init__.py re-exports, which Sphinx cannot resolve. --- baybe/constraints/discrete.py | 11 ++++++----- docs/userguide/symmetries.md | 14 +++++++------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/baybe/constraints/discrete.py b/baybe/constraints/discrete.py index 91f5d8147d..9832a4ce53 100644 --- a/baybe/constraints/discrete.py +++ b/baybe/constraints/discrete.py @@ -29,7 +29,8 @@ if TYPE_CHECKING: import polars as pl - from baybe.symmetries import DependencySymmetry, PermutationSymmetry + from baybe.symmetries.dependency import DependencySymmetry + from baybe.symmetries.permutation import PermutationSymmetry @define @@ -272,8 +273,8 @@ def get_invalid(self, data: pd.DataFrame) -> pd.Index: def to_symmetries( self, use_data_augmentation=True ) -> tuple[DependencySymmetry, ...]: - """Convert to a :class:`~baybe.symmetries.DependencySymmetry`.""" - from baybe.symmetries import DependencySymmetry + """Convert to a :class:`~baybe.symmetries.dependency.DependencySymmetry`.""" + from baybe.symmetries.dependency import DependencySymmetry return tuple( 
DependencySymmetry( @@ -350,8 +351,8 @@ def get_invalid(self, data: pd.DataFrame) -> pd.Index: return inds_invalid def to_symmetry(self, use_data_augmentation=True) -> PermutationSymmetry: - """Convert to a :class:`~baybe.symmetries.PermutationSymmetry`.""" - from baybe.symmetries import PermutationSymmetry + """Convert to a :class:`~baybe.symmetries.permutation.PermutationSymmetry`.""" + from baybe.symmetries.permutation import PermutationSymmetry groups = [self.parameters] if self.dependencies: diff --git a/docs/userguide/symmetries.md b/docs/userguide/symmetries.md index aa8c085202..d31c43d64d 100644 --- a/docs/userguide/symmetries.md +++ b/docs/userguide/symmetries.md @@ -1,9 +1,9 @@ # Symmetry -{class}`~baybe.symmetries.Symmetry` is a concept tied to the structure of the searchspace. +{class}`~baybe.symmetries.base.Symmetry` is a concept tied to the structure of the searchspace. It is thus closely related to a {class}`~baybe.constraints.base.Constraint`, but has a different purpose in BayBE. If the searchspace is symmetric in any sense, you can exclude the degenerate parts via a constraint. But this would not change the modeling -process. The role of a {class}`~baybe.symmetries.Symmetry` is exactly this: Influence how +process. The role of a {class}`~baybe.symmetries.base.Symmetry` is exactly this: Influence how the surrogate model is constructed to include the knowledge about the symmetry. This can be applied independently of constraints. 
For an example of the influence of symmetries and constraints on the optimization of a permutation invariant function, @@ -14,9 +14,9 @@ The following table summarizes available symmetries in BayBE: | Symmetry | Functional Definition | Corresponding Constraint | |:-----------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------| -| {class}`~baybe.symmetries.PermutationSymmetry` | $f(x,y) = f(y,x)$ | {class}`~baybe.constraints.discrete.DiscretePermutationInvarianceConstraint` | -| {class}`~baybe.symmetries.DependencySymmetry` | $f(x,y) = \begin{cases}f(x,y) & \text{if }c(x) \\f(x) & \text{otherwise}\end{cases}$
where $c(x)$ is a condition that is either true or false | {class}`~baybe.constraints.discrete.DiscreteDependenciesConstraint` | -| {class}`~baybe.symmetries.MirrorSymmetry` | $f(x,y) = f(-x,y)$ | No constraint is available. Instead, the number range for that parameter can simply be restricted. | +| {class}`~baybe.symmetries.permutation.PermutationSymmetry` | $f(x,y) = f(y,x)$ | {class}`~baybe.constraints.discrete.DiscretePermutationInvarianceConstraint` | +| {class}`~baybe.symmetries.dependency.DependencySymmetry` | $f(x,y) = \begin{cases}f(x,y) & \text{if }c(x) \\f(x) & \text{otherwise}\end{cases}$
where $c(x)$ is a condition that is either true or false | {class}`~baybe.constraints.discrete.DiscreteDependenciesConstraint` | +| {class}`~baybe.symmetries.mirror.MirrorSymmetry` | $f(x,y) = f(-x,y)$ | No constraint is available. Instead, the number range for that parameter can simply be restricted. | ## Data Augmentation This can be a powerful tool to improve the modeling process. Data augmentation @@ -29,8 +29,8 @@ If the surrogate model receives such augmented points, it can learn the symmetry has the advantage that it can improve predictions for unseen points and is fully model-agnostic. Downsides are increased training time and potential computational challenges arising from a fit on substantially more points. It is thus possible to -control the data augmentation behavior of any {class}`~baybe.symmetries.Symmetry` by -setting its {attr}`~baybe.symmetries.Symmetry.use_data_augmentation` attribute +control the data augmentation behavior of any {class}`~baybe.symmetries.base.Symmetry` by +setting its {attr}`~baybe.symmetries.base.Symmetry.use_data_augmentation` attribute (`True` by default). 
Below we illustrate the effect of data augmentation for the different symmetries From aa783fe35389163904aef3d126bd508b0f57a66e Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Fri, 20 Mar 2026 20:59:21 +0100 Subject: [PATCH 18/23] Fix bug in permutation constraint `DiscretePermutationInvarianceConstraint` was always internally applying a DiscreteNoLabelDuplicates constraint to remove the diagonal elements, which is not correct and can always be achieved separately by explicitly using `DiscreteNoLabelDuplicates` --- CHANGELOG.md | 2 ++ baybe/constraints/discrete.py | 25 ++++--------------------- 2 files changed, 6 insertions(+), 21 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fd632487a8..ca286f78fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - `SHAPInsight` breaking with `numpy>=2.4` due to no longer accepted implicit array to scalar conversion +- `DiscretePermutationInvarianceConstraint` no longer erroneously removes diagonal + points (e.g., where all permuted parameters have the same value) ### Changed - The `Campaign.allow_*` flag mechanism is now based on `AutoBool` logic, providing diff --git a/baybe/constraints/discrete.py b/baybe/constraints/discrete.py index 9832a4ce53..a5df0eae13 100644 --- a/baybe/constraints/discrete.py +++ b/baybe/constraints/discrete.py @@ -294,9 +294,7 @@ class DiscretePermutationInvarianceConstraint(DiscreteConstraint): """Constraint class for declaring that a set of parameters is permutation invariant. More precisely, this means that, ``(val_from_param1, val_from_param2)`` is - equivalent to ``(val_from_param2, val_from_param1)``. Since it does not make sense - to have this constraint with duplicated labels, this implementation also internally - applies the :class:`baybe.constraints.discrete.DiscreteNoLabelDuplicatesConstraint`. + equivalent to ``(val_from_param2, val_from_param1)``. 
*Note:* This constraint is evaluated during creation. In the future it might also be evaluated during modeling to make use of the invariance. @@ -308,15 +306,6 @@ class DiscretePermutationInvarianceConstraint(DiscreteConstraint): @override def get_invalid(self, data: pd.DataFrame) -> pd.Index: - # Get indices of entries with duplicate label entries. These will also be - # dropped by this constraint. - mask_duplicate_labels = pd.Series(False, index=data.index) - mask_duplicate_labels[ - DiscreteNoLabelDuplicatesConstraint(parameters=self.parameters).get_invalid( - data - ) - ] = True - # Merge a permutation invariant representation of all affected parameters with # the other parameters and indicate duplicates. This ensures that variation in # other parameters is also accounted for. @@ -327,20 +316,14 @@ def get_invalid(self, data: pd.DataFrame) -> pd.Index: data[self.parameters].apply(cast(Callable, frozenset), axis=1), ], axis=1, - ).loc[ - ~mask_duplicate_labels # only consider label-duplicate-free part - ] + ) mask_duplicate_permutations = df_eval.duplicated(keep="first") - # Indices of entries with label-duplicates - inds_duplicate_labels = data.index[mask_duplicate_labels] - - # Indices of duplicate permutations in the (already label-duplicate-free) data - inds_duplicate_permutations = df_eval.index[mask_duplicate_permutations] + # Indices of duplicate permutations + inds_invalid = data.index[mask_duplicate_permutations] # If there are dependencies connected to the invariant parameters evaluate them # here and remove resulting duplicates with a DependenciesConstraint - inds_invalid = inds_duplicate_labels.union(inds_duplicate_permutations) if self.dependencies: self.dependencies.permutation_invariant = True inds_duplicate_independency_adjusted = self.dependencies.get_invalid( From 36c47791ab30cf6aa65073a81cebb471875dfd60 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Thu, 2 Apr 2026 02:10:24 +0200 Subject: [PATCH 19/23] Improve docstring Co-authored-by: 
Alexander V. Hopp --- baybe/surrogates/base.py | 4 ++-- baybe/symmetries/base.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/baybe/surrogates/base.py b/baybe/surrogates/base.py index 5e93b48919..2b85b6d8e1 100644 --- a/baybe/surrogates/base.py +++ b/baybe/surrogates/base.py @@ -134,8 +134,8 @@ def augment_measurements( Args: measurements: A dataframe with measurements. - parameters: Parameter objects carrying additional information (might - not be needed by all augmentation implementations). + parameters: Optional parameter objects carrying additional information. + Only required by specific augmentation implementations. Returns: A dataframe with the augmented measurements, including the original diff --git a/baybe/symmetries/base.py b/baybe/symmetries/base.py index 96d8d80324..c41e4e5e3b 100644 --- a/baybe/symmetries/base.py +++ b/baybe/symmetries/base.py @@ -55,8 +55,8 @@ def augment_measurements( Args: measurements: The dataframe containing the measurements to be augmented. - parameters: Corresponding parameter objects carrying additional - information (not needed by all augmentation types). + parameters: Optional parameter objects carrying additional information. + Only required by specific augmentation implementations. Returns: The augmented dataframe including the original measurements. 
From c853cac4cd2f994a292f236691489145df626029 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Fri, 10 Apr 2026 20:51:09 +0200 Subject: [PATCH 20/23] Add docstrings to to_symmetries and to_symmetry methods --- baybe/constraints/discrete.py | 36 +++++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/baybe/constraints/discrete.py b/baybe/constraints/discrete.py index a5df0eae13..d291127e5e 100644 --- a/baybe/constraints/discrete.py +++ b/baybe/constraints/discrete.py @@ -271,9 +271,23 @@ def get_invalid(self, data: pd.DataFrame) -> pd.Index: return inds_bad def to_symmetries( - self, use_data_augmentation=True + self, use_data_augmentation: bool = True ) -> tuple[DependencySymmetry, ...]: - """Convert to a :class:`~baybe.symmetries.dependency.DependencySymmetry`.""" + """Convert to :class:`~baybe.symmetries.dependency.DependencySymmetry` objects. + + Create one symmetry object per dependency relationship, i.e., per + (parameter, condition, affected_parameters) triple. + + Args: + use_data_augmentation: Flag indicating whether the resulting symmetry + objects should apply data augmentation. ``True`` means that + measurement augmentation will be performed by replacing inactive + affected parameter values with all possible values. + + Returns: + A tuple of dependency symmetries, one for each dependency in the + constraint. + """ from baybe.symmetries.dependency import DependencySymmetry return tuple( @@ -333,8 +347,22 @@ def get_invalid(self, data: pd.DataFrame) -> pd.Index: return inds_invalid - def to_symmetry(self, use_data_augmentation=True) -> PermutationSymmetry: - """Convert to a :class:`~baybe.symmetries.permutation.PermutationSymmetry`.""" + def to_symmetry(self, use_data_augmentation: bool = True) -> PermutationSymmetry: + """Convert to a :class:`~baybe.symmetries.permutation.PermutationSymmetry`. + + The constraint's parameters form the primary permutation group. 
If + dependencies are attached, their parameters are added as an additional + group that is permuted in lockstep. + + Args: + use_data_augmentation: Flag indicating whether the resulting symmetry + object should apply data augmentation. ``True`` means that + measurement augmentation will be performed by generating all + permutations of parameter values within each group. + + Returns: + The corresponding permutation symmetry. + """ from baybe.symmetries.permutation import PermutationSymmetry groups = [self.parameters] From 9f076b6762ea7026b17a4aa6d303c5295ff5d09d Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Fri, 10 Apr 2026 20:58:08 +0200 Subject: [PATCH 21/23] Add use_data_augmentation to symmetry summary --- baybe/symmetries/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/baybe/symmetries/base.py b/baybe/symmetries/base.py index c41e4e5e3b..780eae358d 100644 --- a/baybe/symmetries/base.py +++ b/baybe/symmetries/base.py @@ -40,7 +40,9 @@ def parameter_names(self) -> tuple[str, ...]: def summary(self) -> dict: """Return a custom summarization of the symmetry.""" symmetry_dict = dict( - Type=self.__class__.__name__, Affected_Parameters=self.parameter_names + Type=self.__class__.__name__, + Affected_Parameters=self.parameter_names, + Data_Augmentation=self.use_data_augmentation, ) return symmetry_dict From be1c293b552b02b0385f0d7cac6598787d0ec329 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Fri, 10 Apr 2026 21:12:14 +0200 Subject: [PATCH 22/23] Rework imports --- baybe/surrogates/gaussian_process/core.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/baybe/surrogates/gaussian_process/core.py b/baybe/surrogates/gaussian_process/core.py index bfffaf5a48..a2fddc7a33 100644 --- a/baybe/surrogates/gaussian_process/core.py +++ b/baybe/surrogates/gaussian_process/core.py @@ -142,9 +142,10 @@ def _posterior(self, candidates_comp_scaled: Tensor, /) -> Posterior: @override def _fit(self, train_x: Tensor, 
train_y: Tensor) -> None: - import botorch.models.transforms import gpytorch import torch + from botorch.fit import fit_gpytorch_mll + from botorch.models import SingleTaskGP from botorch.models.transforms import Normalize, Standardize # FIXME[typing]: It seems there is currently no better way to inform the type @@ -197,7 +198,7 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None: likelihood.noise = torch.tensor([noise_prior[1]]) # construct and fit the Gaussian process - self._model = botorch.models.SingleTaskGP( + self._model = SingleTaskGP( train_x, train_y, input_transform=input_transform, @@ -218,7 +219,7 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None: self._model.likelihood, self._model ) - botorch.fit.fit_gpytorch_mll(mll) + fit_gpytorch_mll(mll) @override def __str__(self) -> str: From f47c573bac2c6d9d3d912adc71ca9ae6bbf566fd Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Fri, 24 Apr 2026 10:21:40 +0200 Subject: [PATCH 23/23] Improve example --- examples/Symmetries/permutation.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/Symmetries/permutation.py b/examples/Symmetries/permutation.py index aa338dd1ef..a909aba0ad 100644 --- a/examples/Symmetries/permutation.py +++ b/examples/Symmetries/permutation.py @@ -1,7 +1,7 @@ # # Optimizing a Permutation-Invariant Function # In this example, we explore BayBE's capabilities for handling optimization problems -# with symmetry via automatic data augmentation and / or constraint. +# with symmetry via automatic data augmentation and/or constraint. 
# ## Imports @@ -68,7 +68,6 @@ def lookup(df: pd.DataFrame) -> pd.DataFrame: return df_z -# Grid and dataframe for plotting x = np.linspace(LBOUND, UBOUND, 25) y = np.linspace(LBOUND, UBOUND, 25) xx, yy = np.meshgrid(x, y) @@ -76,7 +75,6 @@ def lookup(df: pd.DataFrame) -> pd.DataFrame: zz = df_plot["f"].values.reshape(xx.shape) line_vals = np.linspace(LBOUND, UBOUND, 2) -# Plot the contour and diagonal # fmt: off fig, axs = plt.subplots(1, 2, figsize=(15, 6)) contour = axs[0].contourf(xx, yy, zz, levels=50, cmap="viridis")