diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3ab1fa1c52..ca286f78fc 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,14 +13,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `identify_non_dominated_configurations` method to `Campaign` and `Objective`
for determining the Pareto front
- Interpoint constraints for continuous search spaces
+- Symmetry classes (`PermutationSymmetry`, `MirrorSymmetry`, `DependencySymmetry`)
+ for expressing invariances and configuring surrogate data augmentation
+- `Parameter.is_equivalent` method for structural parameter comparison
### Breaking Changes
- `ContinuousLinearConstraint.to_botorch` now returns a collection of constraint tuples
instead of a single tuple (needed for interpoint constraints)
+- `df_apply_permutation_augmentation` has a different interface and now expects
+ permutation groups instead of column groups
### Fixed
- `SHAPInsight` breaking with `numpy>=2.4` due to no longer accepted implicit array to
scalar conversion
+- `DiscretePermutationInvarianceConstraint` no longer erroneously removes diagonal
+ points (e.g., where all permuted parameters have the same value)
### Changed
- The `Campaign.allow_*` flag mechanism is now based on `AutoBool` logic, providing
diff --git a/baybe/constraints/base.py b/baybe/constraints/base.py
index 5c1a6d33ed..b38548efcc 100644
--- a/baybe/constraints/base.py
+++ b/baybe/constraints/base.py
@@ -37,10 +37,6 @@ class Constraint(ABC, SerialMixin):
eval_during_modeling: ClassVar[bool]
"""Class variable encoding whether the condition is evaluated during modeling."""
- eval_during_augmentation: ClassVar[bool] = False
- """Class variable encoding whether the constraint could be considered during data
- augmentation."""
-
numerical_only: ClassVar[bool] = False
"""Class variable encoding whether the constraint is valid only for numerical
parameters."""
diff --git a/baybe/constraints/conditions.py b/baybe/constraints/conditions.py
index a136f6e170..3dd52b5d78 100644
--- a/baybe/constraints/conditions.py
+++ b/baybe/constraints/conditions.py
@@ -95,7 +95,7 @@ class Condition(ABC, SerialMixin):
"""Abstract base class for all conditions.
Conditions always evaluate an expression regarding a single parameter.
- Conditions are part of constraints, a constraint can have multiple conditions.
+ Conditions are part of constraints and symmetries.
"""
@abstractmethod
diff --git a/baybe/constraints/discrete.py b/baybe/constraints/discrete.py
index 740e603f89..d291127e5e 100644
--- a/baybe/constraints/discrete.py
+++ b/baybe/constraints/discrete.py
@@ -29,6 +29,9 @@
if TYPE_CHECKING:
import polars as pl
+ from baybe.symmetries.dependency import DependencySymmetry
+ from baybe.symmetries.permutation import PermutationSymmetry
+
@define
class DiscreteExcludeConstraint(DiscreteConstraint):
@@ -195,10 +198,6 @@ class DiscreteDependenciesConstraint(DiscreteConstraint):
a single constraint.
"""
- # class variables
- eval_during_augmentation: ClassVar[bool] = True
- # See base class
-
# object variables
conditions: list[Condition] = field()
"""The list of individual conditions."""
@@ -271,39 +270,56 @@ def get_invalid(self, data: pd.DataFrame) -> pd.Index:
return inds_bad
+ def to_symmetries(
+ self, use_data_augmentation: bool = True
+ ) -> tuple[DependencySymmetry, ...]:
+ """Convert to :class:`~baybe.symmetries.dependency.DependencySymmetry` objects.
+
+ Create one symmetry object per dependency relationship, i.e., per
+ (parameter, condition, affected_parameters) triple.
+
+ Args:
+ use_data_augmentation: Flag indicating whether the resulting symmetry
+ objects should apply data augmentation. ``True`` means that
+ measurement augmentation will be performed by replacing inactive
+ affected parameter values with all possible values.
+
+ Returns:
+ A tuple of dependency symmetries, one for each dependency in the
+ constraint.
+ """
+ from baybe.symmetries.dependency import DependencySymmetry
+
+ return tuple(
+ DependencySymmetry(
+ parameter_name=p,
+ condition=c,
+ affected_parameter_names=aps,
+ use_data_augmentation=use_data_augmentation,
+ )
+ for p, c, aps in zip(
+ self.parameters, self.conditions, self.affected_parameters, strict=True
+ )
+ )
+
@define
class DiscretePermutationInvarianceConstraint(DiscreteConstraint):
"""Constraint class for declaring that a set of parameters is permutation invariant.
More precisely, this means that, ``(val_from_param1, val_from_param2)`` is
- equivalent to ``(val_from_param2, val_from_param1)``. Since it does not make sense
- to have this constraint with duplicated labels, this implementation also internally
- applies the :class:`baybe.constraints.discrete.DiscreteNoLabelDuplicatesConstraint`.
+ equivalent to ``(val_from_param2, val_from_param1)``.
*Note:* This constraint is evaluated during creation. In the future it might also be
evaluated during modeling to make use of the invariance.
"""
- # class variables
- eval_during_augmentation: ClassVar[bool] = True
- # See base class
-
# object variables
dependencies: DiscreteDependenciesConstraint | None = field(default=None)
"""Dependencies connected with the invariant parameters."""
@override
def get_invalid(self, data: pd.DataFrame) -> pd.Index:
- # Get indices of entries with duplicate label entries. These will also be
- # dropped by this constraint.
- mask_duplicate_labels = pd.Series(False, index=data.index)
- mask_duplicate_labels[
- DiscreteNoLabelDuplicatesConstraint(parameters=self.parameters).get_invalid(
- data
- )
- ] = True
-
# Merge a permutation invariant representation of all affected parameters with
# the other parameters and indicate duplicates. This ensures that variation in
# other parameters is also accounted for.
@@ -314,20 +330,14 @@ def get_invalid(self, data: pd.DataFrame) -> pd.Index:
data[self.parameters].apply(cast(Callable, frozenset), axis=1),
],
axis=1,
- ).loc[
- ~mask_duplicate_labels # only consider label-duplicate-free part
- ]
+ )
mask_duplicate_permutations = df_eval.duplicated(keep="first")
- # Indices of entries with label-duplicates
- inds_duplicate_labels = data.index[mask_duplicate_labels]
-
- # Indices of duplicate permutations in the (already label-duplicate-free) data
- inds_duplicate_permutations = df_eval.index[mask_duplicate_permutations]
+ # Indices of duplicate permutations
+ inds_invalid = data.index[mask_duplicate_permutations]
# If there are dependencies connected to the invariant parameters evaluate them
# here and remove resulting duplicates with a DependenciesConstraint
- inds_invalid = inds_duplicate_labels.union(inds_duplicate_permutations)
if self.dependencies:
self.dependencies.permutation_invariant = True
inds_duplicate_independency_adjusted = self.dependencies.get_invalid(
@@ -337,6 +347,32 @@ def get_invalid(self, data: pd.DataFrame) -> pd.Index:
return inds_invalid
+ def to_symmetry(self, use_data_augmentation: bool = True) -> PermutationSymmetry:
+ """Convert to a :class:`~baybe.symmetries.permutation.PermutationSymmetry`.
+
+ The constraint's parameters form the primary permutation group. If
+ dependencies are attached, their parameters are added as an additional
+ group that is permuted in lockstep.
+
+ Args:
+ use_data_augmentation: Flag indicating whether the resulting symmetry
+ object should apply data augmentation. ``True`` means that
+ measurement augmentation will be performed by generating all
+ permutations of parameter values within each group.
+
+ Returns:
+ The corresponding permutation symmetry.
+ """
+ from baybe.symmetries.permutation import PermutationSymmetry
+
+ groups = [self.parameters]
+ if self.dependencies:
+ groups.append(list(self.dependencies.parameters))
+ return PermutationSymmetry(
+ permutation_groups=groups,
+ use_data_augmentation=use_data_augmentation,
+ )
+
@define
class DiscreteCustomConstraint(DiscreteConstraint):
diff --git a/baybe/parameters/base.py b/baybe/parameters/base.py
index 2d4df2bc77..9ae13e9a98 100644
--- a/baybe/parameters/base.py
+++ b/baybe/parameters/base.py
@@ -7,6 +7,7 @@
from functools import cached_property
from typing import TYPE_CHECKING, Any, ClassVar
+import attrs
import pandas as pd
from attrs import define, field
from attrs.converters import optional as optional_c
@@ -88,6 +89,22 @@ def to_searchspace(self) -> SearchSpace:
return SearchSpace.from_parameter(self)
+ def is_equivalent(self, other: Parameter) -> bool:
+ """Check if this parameter is equivalent to another, ignoring the name.
+
+ Two parameters are considered equivalent if they have the same type and
+ all attributes are equal except for the name.
+
+ Args:
+ other: The parameter to compare against.
+
+ Returns:
+ ``True`` if the parameters are equivalent, ``False`` otherwise.
+ """
+ if type(self) is not type(other):
+ return False
+ return attrs.evolve(self, name=other.name) == other
+
@abstractmethod
def summary(self) -> dict:
"""Return a custom summarization of the parameter."""
diff --git a/baybe/parameters/categorical.py b/baybe/parameters/categorical.py
index bffecd5682..7ebe673f76 100644
--- a/baybe/parameters/categorical.py
+++ b/baybe/parameters/categorical.py
@@ -13,13 +13,7 @@
from baybe.parameters.enum import CategoricalEncoding
from baybe.parameters.validation import validate_unique_values
from baybe.settings import active_settings
-from baybe.utils.conversion import nonstring_to_tuple
-
-
-def _convert_values(value, self, field) -> tuple[str, ...]:
- """Sort and convert values for categorical parameters."""
- value = nonstring_to_tuple(value, self, field)
- return tuple(sorted(value, key=lambda x: (str(type(x)), x)))
+from baybe.utils.conversion import normalize_convertible2str_sequence
def _validate_label_min_len(self, attr, value) -> None:
@@ -38,8 +32,10 @@ class CategoricalParameter(_DiscreteLabelLikeParameter):
# object variables
_values: tuple[str | bool, ...] = field(
alias="values",
- converter=Converter(_convert_values, takes_self=True, takes_field=True), # type: ignore
- validator=(
+ converter=Converter( # type: ignore[misc,call-overload] # mypy: Converter
+ normalize_convertible2str_sequence, takes_self=True, takes_field=True
+ ),
+ validator=( # type: ignore[arg-type] # mypy: validator tuple
validate_unique_values,
deep_iterable(
member_validator=(instance_of((str, bool)), _validate_label_min_len),
diff --git a/baybe/parameters/numerical.py b/baybe/parameters/numerical.py
index ba210de244..e56ca2c3f9 100644
--- a/baybe/parameters/numerical.py
+++ b/baybe/parameters/numerical.py
@@ -13,9 +13,10 @@
from baybe.exceptions import NumericalUnderflowError
from baybe.parameters.base import ContinuousParameter, DiscreteParameter
-from baybe.parameters.validation import validate_is_finite, validate_unique_values
+from baybe.parameters.validation import validate_unique_values
from baybe.settings import active_settings
from baybe.utils.interval import InfiniteIntervalError, Interval
+from baybe.utils.validation import validate_is_finite
@define(frozen=True, slots=False)
diff --git a/baybe/parameters/validation.py b/baybe/parameters/validation.py
index 5c70ed5ecf..0367b2b30f 100644
--- a/baybe/parameters/validation.py
+++ b/baybe/parameters/validation.py
@@ -1,9 +1,7 @@
"""Validation functionality for parameters."""
-from collections.abc import Sequence
from typing import Any
-import numpy as np
from attrs.validators import gt, instance_of, lt
@@ -28,18 +26,3 @@ def validate_decorrelation(obj: Any, attribute: Any, value: float) -> None:
if isinstance(value, float):
gt(0.0)(obj, attribute, value)
lt(1.0)(obj, attribute, value)
-
-
-def validate_is_finite( # noqa: DOC101, DOC103
- obj: Any, _: Any, value: Sequence[float]
-) -> None:
- """Validate that ``value`` contains no infinity/nan.
-
- Raises:
- ValueError: If ``value`` contains infinity/nan.
- """
- if not all(np.isfinite(value)):
- raise ValueError(
- f"Cannot assign the following values containing infinity/nan to "
- f"parameter {obj.name}: {value}."
- )
diff --git a/baybe/recommenders/pure/bayesian/base.py b/baybe/recommenders/pure/bayesian/base.py
index 4ac5c1eed2..2d06db2edc 100644
--- a/baybe/recommenders/pure/bayesian/base.py
+++ b/baybe/recommenders/pure/bayesian/base.py
@@ -86,6 +86,22 @@ def _get_acquisition_function(self, objective: Objective) -> AcquisitionFunction
return qLogNEHVI() if objective.is_multi_output else qLogEI()
return self.acquisition_function
+ def _get_surrogate_for_augmentation(self) -> Surrogate | None:
+ """Get the Surrogate instance for augmentation/validation, if available."""
+ from baybe.surrogates.composite import CompositeSurrogate, _ReplicationMapping
+
+ model = self._surrogate_model
+ if isinstance(model, Surrogate):
+ return model
+ if isinstance(model, CompositeSurrogate):
+ # All inner surrogates are copies of the same template
+ surrogates = model.surrogates
+ if isinstance(surrogates, _ReplicationMapping):
+ template = surrogates.template
+ if isinstance(template, Surrogate):
+ return template
+ return None
+
def get_surrogate(
self,
searchspace: SearchSpace,
@@ -114,6 +130,13 @@ def _setup_botorch_acqf(
f"{len(objective.targets)}-target multi-output context."
)
+ # Perform data augmentation if configured
+ surrogate_for_augmentation = self._get_surrogate_for_augmentation()
+ if surrogate_for_augmentation is not None:
+ measurements = surrogate_for_augmentation.augment_measurements(
+ measurements, searchspace.parameters
+ )
+
surrogate = self.get_surrogate(searchspace, objective, measurements)
self._botorch_acqf = acqf.to_botorch(
surrogate,
@@ -156,6 +179,13 @@ def recommend(
validate_object_names(searchspace.parameters + objective.targets)
+ # Validate compatibility of surrogate symmetries with searchspace
+ surrogate_for_validation = self._get_surrogate_for_augmentation()
+ if surrogate_for_validation is not None:
+ for s in surrogate_for_validation.symmetries:
+ s.validate_searchspace_context(searchspace)
+
+ # Experimental input validation
if (measurements is None) or measurements.empty:
raise NotImplementedError(
f"Recommenders of type '{BayesianRecommender.__name__}' do not support "
diff --git a/baybe/searchspace/core.py b/baybe/searchspace/core.py
index 8b0da30c92..34e5f64047 100644
--- a/baybe/searchspace/core.py
+++ b/baybe/searchspace/core.py
@@ -381,11 +381,6 @@ def transform(
return comp_rep
- @property
- def constraints_augmentable(self) -> tuple[Constraint, ...]:
- """The searchspace constraints that can be considered during augmentation."""
- return tuple(c for c in self.constraints if c.eval_during_augmentation)
-
def get_parameters_by_name(self, names: Sequence[str]) -> tuple[Parameter, ...]:
"""Return parameters with the specified names.
diff --git a/baybe/surrogates/base.py b/baybe/surrogates/base.py
index 205e32f703..2b85b6d8e1 100644
--- a/baybe/surrogates/base.py
+++ b/baybe/surrogates/base.py
@@ -4,12 +4,13 @@
import gc
from abc import ABC, abstractmethod
-from collections.abc import Sequence
+from collections.abc import Iterable, Sequence
from enum import Enum, auto
from typing import TYPE_CHECKING, ClassVar, Literal, Protocol, TypeAlias
import pandas as pd
from attrs import define, field
+from attrs.validators import deep_iterable, instance_of
from joblib.hashing import hash
from typing_extensions import override
@@ -18,6 +19,7 @@
from baybe.parameters.base import Parameter
from baybe.searchspace import SearchSpace
from baybe.serialization.mixin import SerialMixin
+from baybe.symmetries import Symmetry
from baybe.utils.basic import classproperty
from baybe.utils.conversion import to_string
from baybe.utils.dataframe import handle_missing_values, to_tensor
@@ -90,6 +92,14 @@ class Surrogate(ABC, SurrogateProtocol, SerialMixin):
"""Class variable encoding whether or not the surrogate is multi-output
compatible."""
+ symmetries: tuple[Symmetry, ...] = field(
+ factory=tuple,
+ converter=tuple,
+ validator=deep_iterable(member_validator=instance_of(Symmetry)),
+ kw_only=True,
+ )
+ """Symmetries to be considered by the surrogate model."""
+
_searchspace: SearchSpace | None = field(init=False, default=None, eq=False)
"""The search space on which the surrogate operates. Available after fitting."""
@@ -115,6 +125,27 @@ class Surrogate(ABC, SurrogateProtocol, SerialMixin):
Scales a tensor containing target measurements in computational representation
to make them digestible for the model-specific, scale-agnostic posterior logic."""
+ def augment_measurements(
+ self,
+ measurements: pd.DataFrame,
+ parameters: Iterable[Parameter] | None = None,
+ ) -> pd.DataFrame:
+ """Apply data augmentation to measurements.
+
+ Args:
+ measurements: A dataframe with measurements.
+ parameters: Optional parameter objects carrying additional information.
+ Only required by specific augmentation implementations.
+
+ Returns:
+ A dataframe with the augmented measurements, including the original
+ ones.
+ """
+ for s in self.symmetries:
+ measurements = s.augment_measurements(measurements, parameters)
+
+ return measurements
+
@classproperty
def is_available(cls) -> bool:
"""Indicates if the surrogate class is available in the Python environment.
diff --git a/baybe/surrogates/gaussian_process/core.py b/baybe/surrogates/gaussian_process/core.py
index c0148aca55..a2fddc7a33 100644
--- a/baybe/surrogates/gaussian_process/core.py
+++ b/baybe/surrogates/gaussian_process/core.py
@@ -142,9 +142,10 @@ def _posterior(self, candidates_comp_scaled: Tensor, /) -> Posterior:
@override
def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
- import botorch
import gpytorch
import torch
+ from botorch.fit import fit_gpytorch_mll
+ from botorch.models import SingleTaskGP
from botorch.models.transforms import Normalize, Standardize
# FIXME[typing]: It seems there is currently no better way to inform the type
@@ -197,7 +198,7 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
likelihood.noise = torch.tensor([noise_prior[1]])
# construct and fit the Gaussian process
- self._model = botorch.models.SingleTaskGP(
+ self._model = SingleTaskGP(
train_x,
train_y,
input_transform=input_transform,
@@ -218,7 +219,7 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
self._model.likelihood, self._model
)
- botorch.fit.fit_gpytorch_mll(mll)
+ fit_gpytorch_mll(mll)
@override
def __str__(self) -> str:
diff --git a/baybe/symmetries/__init__.py b/baybe/symmetries/__init__.py
new file mode 100644
index 0000000000..abbeedd974
--- /dev/null
+++ b/baybe/symmetries/__init__.py
@@ -0,0 +1,13 @@
+"""Symmetry classes for expressing invariances of the modeling process."""
+
+from baybe.symmetries.base import Symmetry
+from baybe.symmetries.dependency import DependencySymmetry
+from baybe.symmetries.mirror import MirrorSymmetry
+from baybe.symmetries.permutation import PermutationSymmetry
+
+__all__ = [
+ "DependencySymmetry",
+ "MirrorSymmetry",
+ "PermutationSymmetry",
+ "Symmetry",
+]
diff --git a/baybe/symmetries/base.py b/baybe/symmetries/base.py
new file mode 100644
index 0000000000..780eae358d
--- /dev/null
+++ b/baybe/symmetries/base.py
@@ -0,0 +1,89 @@
+"""Base class for symmetries."""
+
+from __future__ import annotations
+
+import gc
+from abc import ABC, abstractmethod
+from collections.abc import Iterable
+from typing import TYPE_CHECKING
+
+import pandas as pd
+from attrs import define, field
+from attrs.validators import instance_of
+
+from baybe.exceptions import IncompatibleSearchSpaceError
+from baybe.serialization import SerialMixin
+
+if TYPE_CHECKING:
+ from baybe.parameters.base import Parameter
+ from baybe.searchspace import SearchSpace
+
+
+@define(frozen=True)
+class Symmetry(SerialMixin, ABC):
+ """Abstract base class for symmetries.
+
+ A ``Symmetry`` is a concept that can be used to configure the modeling process in
+ the presence of invariances.
+ """
+
+ use_data_augmentation: bool = field(
+ default=True, validator=instance_of(bool), kw_only=True
+ )
+ """Flag indicating whether data augmentation is to be used."""
+
+ @property
+ @abstractmethod
+ def parameter_names(self) -> tuple[str, ...]:
+ """The names of the parameters affected by the symmetry."""
+
+ def summary(self) -> dict:
+ """Return a custom summarization of the symmetry."""
+ symmetry_dict = dict(
+ Type=self.__class__.__name__,
+ Affected_Parameters=self.parameter_names,
+ Data_Augmentation=self.use_data_augmentation,
+ )
+ return symmetry_dict
+
+ @abstractmethod
+ def augment_measurements(
+ self,
+ measurements: pd.DataFrame,
+ parameters: Iterable[Parameter] | None = None,
+ ) -> pd.DataFrame:
+ """Augment the given measurements according to the symmetry.
+
+ Args:
+ measurements: The dataframe containing the measurements to be
+ augmented.
+ parameters: Optional parameter objects carrying additional information.
+ Only required by specific augmentation implementations.
+
+ Returns:
+ The augmented dataframe including the original measurements.
+ """
+
+ def validate_searchspace_context(self, searchspace: SearchSpace) -> None:
+ """Validate that the symmetry is compatible with the given searchspace.
+
+ Args:
+ searchspace: The searchspace to validate against.
+
+ Raises:
+ IncompatibleSearchSpaceError: If the symmetry affects parameters not
+ present in the searchspace.
+ """
+ parameters_missing = set(self.parameter_names).difference(
+ searchspace.parameter_names
+ )
+ if parameters_missing:
+ raise IncompatibleSearchSpaceError(
+ f"The symmetry of type '{self.__class__.__name__}' was set up with the "
+ f"following parameters that are not present in the search space: "
+ f"{parameters_missing}."
+ )
+
+
+# Collect leftover original slotted classes processed by `attrs.define`
+gc.collect()
diff --git a/baybe/symmetries/dependency.py b/baybe/symmetries/dependency.py
new file mode 100644
index 0000000000..9bd72ab8bf
--- /dev/null
+++ b/baybe/symmetries/dependency.py
@@ -0,0 +1,169 @@
+"""Dependency symmetry."""
+
+from __future__ import annotations
+
+import gc
+from collections.abc import Iterable
+from typing import TYPE_CHECKING, cast
+
+import numpy as np
+import pandas as pd
+from attrs import Converter, define, field
+from attrs.validators import deep_iterable, ge, instance_of, min_len
+from typing_extensions import override
+
+from baybe.constraints.conditions import Condition
+from baybe.exceptions import IncompatibleSearchSpaceError
+from baybe.symmetries.base import Symmetry
+from baybe.utils.augmentation import df_apply_dependency_augmentation
+from baybe.utils.conversion import normalize_convertible2str_sequence
+from baybe.utils.validation import validate_unique_values
+
+if TYPE_CHECKING:
+ from baybe.parameters.base import Parameter
+ from baybe.searchspace import SearchSpace
+
+
+@define(frozen=True)
+class DependencySymmetry(Symmetry):
+ """Class for representing dependency symmetries.
+
+ A dependency symmetry expresses that certain parameters are dependent on another
+    parameter having a specific value, as in "The value of parameter y only
+    matters if parameter x has the value 'on'". In this scenario, x is the
+    causing parameter and y depends on x.
+ """
+
+ _parameter_name: str = field(validator=instance_of(str), alias="parameter_name")
+ """The names of the causing parameter others are depending on."""
+
+ # object variables
+ condition: Condition = field(validator=instance_of(Condition))
+ """The condition specifying the active range of the causing parameter."""
+
+ affected_parameter_names: tuple[str, ...] = field(
+ converter=Converter( # type: ignore[misc,call-overload] # mypy: Converter
+ normalize_convertible2str_sequence, takes_self=True, takes_field=True
+ ),
+ validator=(
+ validate_unique_values,
+ deep_iterable(
+ member_validator=instance_of(str), iterable_validator=min_len(1)
+ ),
+ ),
+ )
+ """The parameters affected by the dependency."""
+
+ n_discretization_points: int = field(
+ default=3, validator=(instance_of(int), ge(2)), kw_only=True
+ )
+ """Number of points used when subsampling continuous parameter ranges."""
+
+ @override
+ @property
+ def parameter_names(self) -> tuple[str, ...]:
+ return (self._parameter_name,)
+
+ @override
+ def augment_measurements(
+ self,
+ measurements: pd.DataFrame,
+ parameters: Iterable[Parameter] | None = None,
+ ) -> pd.DataFrame:
+ # See base class.
+ if not self.use_data_augmentation:
+ return measurements
+
+ if parameters is None:
+ raise ValueError(
+ f"A '{self.__class__.__name__}' requires parameter objects "
+ f"for data augmentation."
+ )
+
+ from baybe.parameters.base import DiscreteParameter
+
+        # The 'causing' entry describes the parameter and the values
+        # for which one or more affected parameters become degenerate.
+        # The condition specifies for which values the affected parameter
+ # values are active, i.e. not degenerate. Hence, here we get the
+ # values that are not active, as rows containing them should be
+ # augmented.
+ param = next(
+ cast(DiscreteParameter, p)
+ for p in parameters
+ if p.name == self._parameter_name
+ )
+
+ causing_values = [
+ x
+ for x, flag in zip(
+ param.values,
+ ~self.condition.evaluate(pd.Series(param.values)),
+ strict=True,
+ )
+ if flag
+ ]
+ causing = (param.name, causing_values)
+
+ # The 'affected' entry describes the affected parameters and the
+ # values they are allowed to take, which are all degenerate if
+ # the corresponding condition for the causing parameter is met.
+ affected: list[tuple[str, tuple[float, ...]]] = []
+ for pn in self.affected_parameter_names:
+ p = next(p for p in parameters if p.name == pn)
+ if p.is_discrete:
+ # Use all values for augmentation
+ vals = cast(DiscreteParameter, p).values
+ else:
+ # Use linear subsample of parameter bounds interval for augmentation.
+ # Note: The original value will not necessarily be part of this.
+ vals = tuple(
+ np.linspace(
+ p.bounds.lower, # type: ignore[attr-defined]
+ p.bounds.upper, # type: ignore[attr-defined]
+ self.n_discretization_points,
+ )
+ )
+ affected.append((p.name, vals))
+
+ measurements = df_apply_dependency_augmentation(measurements, causing, affected)
+
+ return measurements
+
+ @override
+ def validate_searchspace_context(self, searchspace: SearchSpace) -> None:
+ """See base class.
+
+ Args:
+ searchspace: The searchspace to validate against.
+
+ Raises:
+ IncompatibleSearchSpaceError: If any of the affected parameters is
+ not present in the searchspace.
+ TypeError: If the causing parameter is not discrete.
+ """
+ super().validate_searchspace_context(searchspace)
+
+ # Affected parameters must be in the searchspace
+ parameters_missing = set(self.affected_parameter_names).difference(
+ searchspace.parameter_names
+ )
+ if parameters_missing:
+ raise IncompatibleSearchSpaceError(
+ f"The symmetry of type '{self.__class__.__name__}' was set up "
+ f"with at least one parameter which is not present in the "
+ f"search space: {parameters_missing}."
+ )
+
+ # Causing parameter must be discrete
+ param = searchspace.get_parameters_by_name(self._parameter_name)[0]
+ if not param.is_discrete:
+ raise TypeError(
+ f"In a '{self.__class__.__name__}', the causing parameter must "
+ f"be discrete. However, the parameter '{param.name}' is of "
+ f"type '{param.__class__.__name__}' and is not discrete."
+ )
+
+
+# Collect leftover original slotted classes processed by `attrs.define`
+gc.collect()
diff --git a/baybe/symmetries/mirror.py b/baybe/symmetries/mirror.py
new file mode 100644
index 0000000000..9a7ac7d06f
--- /dev/null
+++ b/baybe/symmetries/mirror.py
@@ -0,0 +1,86 @@
+"""Mirror symmetry."""
+
+from __future__ import annotations
+
+import gc
+from collections.abc import Iterable
+from typing import TYPE_CHECKING
+
+import pandas as pd
+from attrs import define, field
+from attrs.validators import instance_of
+from typing_extensions import override
+
+from baybe.symmetries.base import Symmetry
+from baybe.utils.augmentation import df_apply_mirror_augmentation
+from baybe.utils.validation import validate_is_finite
+
+if TYPE_CHECKING:
+ from baybe.parameters.base import Parameter
+ from baybe.searchspace import SearchSpace
+
+
+@define(frozen=True)
+class MirrorSymmetry(Symmetry):
+ """Class for representing mirror symmetries.
+
+    A mirror symmetry expresses that certain parameters can be reflected at a mirror
+ point without affecting the outcome of the model. For instance, when specified
+ for parameter ``x`` and mirror point ``c``, the symmetry expresses that
+ $f(..., c+x, ...) = f(..., c-x, ...)$.
+ """
+
+ _parameter_name: str = field(validator=instance_of(str), alias="parameter_name")
+ """The name of the single parameter affected by the symmetry."""
+
+ # object variables
+ mirror_point: float = field(
+ default=0.0, converter=float, validator=validate_is_finite, kw_only=True
+ )
+ """The mirror point."""
+
+ @override
+ @property
+ def parameter_names(self) -> tuple[str]:
+ return (self._parameter_name,)
+
+ @override
+ def augment_measurements(
+ self,
+ measurements: pd.DataFrame,
+ parameters: Iterable[Parameter] | None = None,
+ ) -> pd.DataFrame:
+ # See base class.
+
+ if not self.use_data_augmentation:
+ return measurements
+
+ measurements = df_apply_mirror_augmentation(
+ measurements, self._parameter_name, mirror_point=self.mirror_point
+ )
+
+ return measurements
+
+ @override
+ def validate_searchspace_context(self, searchspace: SearchSpace) -> None:
+ """See base class.
+
+ Args:
+ searchspace: The searchspace to validate against.
+
+ Raises:
+ TypeError: If the affected parameter is not numerical.
+ """
+ super().validate_searchspace_context(searchspace)
+
+ param = searchspace.get_parameters_by_name(self.parameter_names)[0]
+ if not param.is_numerical:
+ raise TypeError(
+ f"In a '{self.__class__.__name__}', the affected parameter must "
+ f"be numerical. However, the parameter '{param.name}' is of "
+ f"type '{param.__class__.__name__}' and is not numerical."
+ )
+
+
+# Collect leftover original slotted classes processed by `attrs.define`
+gc.collect()
diff --git a/baybe/symmetries/permutation.py b/baybe/symmetries/permutation.py
new file mode 100644
index 0000000000..deab356f9e
--- /dev/null
+++ b/baybe/symmetries/permutation.py
@@ -0,0 +1,158 @@
+"""Permutation symmetry."""
+
+from __future__ import annotations
+
+import gc
+from collections.abc import Iterable, Sequence
+from itertools import combinations
+from typing import TYPE_CHECKING, Any
+
+import pandas as pd
+from attrs import define, field
+from attrs.validators import deep_iterable, instance_of, min_len
+from typing_extensions import override
+
+from baybe.symmetries.base import Symmetry
+from baybe.utils.augmentation import df_apply_permutation_augmentation
+
+if TYPE_CHECKING:
+ from baybe.parameters.base import Parameter
+ from baybe.searchspace import SearchSpace
+
+
+def _convert_groups(
+ groups: Sequence[Sequence[str]],
+) -> tuple[tuple[str, ...], ...]:
+ """Convert nested sequences to a tuple of tuples, blocking bare strings."""
+ if isinstance(groups, str):
+ raise ValueError(
+ "The 'permutation_groups' argument must be a sequence of sequences, "
+ "not a string."
+ )
+ converted = []
+ for g in groups:
+ if isinstance(g, str):
+ raise ValueError(
+ "Each element in 'permutation_groups' must be a sequence of "
+ "parameter names, not a string."
+ )
+ converted.append(tuple(g))
+ return tuple(converted)
+
+
+@define(frozen=True)
+class PermutationSymmetry(Symmetry):
+ """Class for representing permutation symmetries.
+
+ A permutation symmetry expresses that certain parameters can be permuted without
+ affecting the outcome of the model. For instance, this is the case if
+ $f(x,y) = f(y,x)$.
+ """
+
+ permutation_groups: tuple[tuple[str, ...], ...] = field(
+ converter=_convert_groups,
+ validator=(
+ min_len(1),
+ deep_iterable(
+ member_validator=deep_iterable(
+ member_validator=instance_of(str),
+ iterable_validator=min_len(2),
+ ),
+ ),
+ ),
+ )
+ """The permutation groups. Each group contains parameter names that are permuted
+ in lockstep. All groups must have the same length."""
+
+ @permutation_groups.validator
+ def _validate_permutation_groups( # noqa: DOC101, DOC103
+ self, _: Any, groups: tuple[tuple[str, ...], ...]
+ ) -> None:
+ """Validate the permutation groups.
+
+ Raises:
+ ValueError: If the groups have different lengths.
+ ValueError: If any group contains duplicate parameters.
+ ValueError: If any parameter name appears in multiple groups.
+ """
+ # Ensure all groups have the same length
+ if len({len(g) for g in groups}) != 1:
+ lengths = {k + 1: len(g) for k, g in enumerate(groups)}
+ raise ValueError(
+ f"In a '{self.__class__.__name__}', all permutation groups "
+ f"must have the same length. Got group lengths: {lengths}."
+ )
+
+ # Ensure parameter names in each group are unique
+ for group in groups:
+ if len(set(group)) != len(group):
+ raise ValueError(
+ f"In a '{self.__class__.__name__}', all parameters being "
+ f"permuted with each other must be unique. However, the "
+ f"following group contains duplicates: {group}."
+ )
+
+ # Ensure there is no overlap between any permutation group
+ for a, b in combinations(groups, 2):
+ if overlap := set(a) & set(b):
+ raise ValueError(
+ f"In a '{self.__class__.__name__}', parameter names cannot "
+ f"appear in multiple permutation groups. However, the "
+ f"following parameter names appear in several groups: "
+ f"{overlap}."
+ )
+
+ @override
+ @property
+ def parameter_names(self) -> tuple[str, ...]:
+ return tuple(name for group in self.permutation_groups for name in group)
+
+ @override
+ def augment_measurements(
+ self,
+ measurements: pd.DataFrame,
+ parameters: Iterable[Parameter] | None = None,
+ ) -> pd.DataFrame:
+ # See base class.
+
+ if not self.use_data_augmentation:
+ return measurements
+
+ measurements = df_apply_permutation_augmentation(
+ measurements, self.permutation_groups
+ )
+
+ return measurements
+
+ @override
+ def validate_searchspace_context(self, searchspace: SearchSpace) -> None:
+ """See base class.
+
+ Args:
+ searchspace: The searchspace to validate against.
+
+ Raises:
+ ValueError: If parameters within a permutation group are not
+ equivalent (i.e., differ in type or specification).
+ """
+ super().validate_searchspace_context(searchspace)
+
+ # Ensure permuted parameters all have the same specification.
+ # Without this, it could be attempted to read in data that is not allowed
+ # for parameters that only allow a subset or different values compared to
+ # parameters they are being permuted with.
+ for group in self.permutation_groups:
+ params = searchspace.get_parameters_by_name(group)
+ ref = params[0]
+ for p in params[1:]:
+ if not ref.is_equivalent(p):
+ raise ValueError(
+ f"In a '{self.__class__.__name__}', all parameters "
+ f"within a permutation group must be equivalent. "
+ f"However, '{ref.name}' and '{p.name}' differ in "
+ f"their specification."
+ )
+
+
+# Collect leftover original slotted classes processed by `attrs.define`
+gc.collect()
diff --git a/baybe/utils/augmentation.py b/baybe/utils/augmentation.py
index b9fc2d14aa..6b1fd50d28 100644
--- a/baybe/utils/augmentation.py
+++ b/baybe/utils/augmentation.py
@@ -8,24 +8,29 @@
def df_apply_permutation_augmentation(
df: pd.DataFrame,
- column_groups: Sequence[Sequence[str]],
+ permutation_groups: Sequence[Sequence[str]],
) -> pd.DataFrame:
"""Augment a dataframe if permutation invariant columns are present.
+ Each group in ``permutation_groups`` contains the names of columns that are
+ permuted in lockstep. All groups must have the same length, and that length
+ must be at least 2 (otherwise there is nothing to permute).
+
Args:
df: The dataframe that should be augmented.
- column_groups: Sequences of permutation invariant columns. The n'th column in
- each group will be permuted together with each n'th column in the other
- groups.
+ permutation_groups: Groups of column names that are permuted in lockstep.
+ For example, ``[["A1", "A2"], ["B1", "B2"]]`` means that the columns
+ ``A1`` and ``A2`` are permuted, and ``B1`` and ``B2`` are permuted
+ in the same way.
Returns:
The augmented dataframe containing the original one. Augmented row indices are
identical with the index of their original row.
Raises:
- ValueError: If less than two column groups are given.
- ValueError: If a column group is empty.
- ValueError: If the column groups have differing amounts of entries.
+ ValueError: If no permutation groups are given.
+ ValueError: If any permutation group has fewer than two entries.
+ ValueError: If the permutation groups have differing amounts of entries.
Examples:
>>> df = pd.DataFrame({'A1':[1,2],'A2':[3,4], 'B1': [5, 6], 'B2': [7, 8]})
@@ -34,8 +39,8 @@ def df_apply_permutation_augmentation(
0 1 3 5 7
1 2 4 6 8
- >>> column_groups = [['A1'], ['A2']]
- >>> dfa = df_apply_permutation_augmentation(df, column_groups)
+ >>> groups = [['A1', 'A2']]
+ >>> dfa = df_apply_permutation_augmentation(df, groups)
>>> dfa
A1 A2 B1 B2
0 1 3 5 7
@@ -43,8 +48,8 @@ def df_apply_permutation_augmentation(
1 2 4 6 8
1 4 2 6 8
- >>> column_groups = [['A1', 'B1'], ['A2', 'B2']]
- >>> dfa = df_apply_permutation_augmentation(df, column_groups)
+ >>> groups = [['A1', 'A2'], ['B1', 'B2']]
+ >>> dfa = df_apply_permutation_augmentation(df, groups)
>>> dfa
A1 A2 B1 B2
0 1 3 5 7
@@ -53,36 +58,35 @@ def df_apply_permutation_augmentation(
1 4 2 8 6
"""
# Validation
- if len(column_groups) < 2:
- raise ValueError(
- "When augmenting permutation invariance, at least two column sequences "
- "must be given."
- )
+ if len(permutation_groups) < 1:
+ raise ValueError("Permutation augmentation requires at least one group.")
- if len({len(seq) for seq in column_groups}) != 1:
+ if len({len(seq) for seq in permutation_groups}) != 1:
raise ValueError(
- "Permutation augmentation can only work if the amount of columns in each "
- "sequence is the same."
+ "Permutation augmentation can only work if all groups have the same "
+ "number of entries."
)
- elif len(column_groups[0]) < 1:
+
+ if len(permutation_groups[0]) < 2:
raise ValueError(
- "Permutation augmentation can only work if each column group has at "
- "least one entry."
+ "Permutation augmentation can only work if each group has at "
+ "least two entries."
)
# Augmentation Loop
+ n_positions = len(permutation_groups[0])
new_rows: list[pd.DataFrame] = []
- idx_permutation = list(permutations(range(len(column_groups))))
+ idx_permutation = list(permutations(range(n_positions)))
for _, row in df.iterrows():
# For each row in the original df, collect all its permutations
to_add = []
for perm in idx_permutation:
new_row = row.copy()
- # Permute columns, this is done separately for each tuple of columns that
- # belong together
- for deps in map(list, zip(*column_groups)):
- new_row[deps] = row[[deps[k] for k in perm]]
+ # Permute columns within each group according to the permutation
+ for group in permutation_groups:
+ cols = list(group)
+ new_row[cols] = row[[cols[k] for k in perm]]
to_add.append(new_row)
@@ -94,6 +98,69 @@ def df_apply_permutation_augmentation(
return pd.concat(new_rows)
+def df_apply_mirror_augmentation(
+ df: pd.DataFrame,
+ column: str,
+ *,
+ mirror_point: float = 0.0,
+) -> pd.DataFrame:
+ """Augment a dataframe for a mirror invariant column.
+
+ Args:
+ df: The dataframe that should be augmented.
+ column: The name of the affected column.
+ mirror_point: The point along which to mirror the values. Points that have
+ exactly this value will not be augmented.
+
+ Returns:
+ The augmented dataframe containing the original one. Augmented row indices are
+ identical with the index of their original row.
+
+ Examples:
+ >>> df = pd.DataFrame({'A':[1, 0, -2], 'B': [3, 4, 5]})
+ >>> df
+ A B
+ 0 1 3
+ 1 0 4
+ 2 -2 5
+
+ >>> dfa = df_apply_mirror_augmentation(df, "A")
+ >>> dfa
+ A B
+ 0 1 3
+ 0 -1 3
+ 1 0 4
+ 2 -2 5
+ 2 2 5
+
+ >>> dfa = df_apply_mirror_augmentation(df, "A", mirror_point=1)
+ >>> dfa
+ A B
+ 0 1 3
+ 1 0 4
+ 1 2 4
+ 2 -2 5
+ 2 4 5
+ """
+ new_rows: list[pd.DataFrame] = []
+ for _, row in df.iterrows():
+ to_add = [row] # Always keep original row
+
+ # Create the augmented row by mirroring the point at the mirror point.
+ # x_mirrored = mirror_point + (mirror_point - x) = 2*mirror_point - x
+ if row[column] != mirror_point:
+ row_new = row.copy()
+ row_new[column] = 2.0 * mirror_point - row[column]
+ to_add.append(row_new)
+
+ # Store augmented rows, keeping the index of their original row
+ new_rows.append(
+ pd.DataFrame(to_add, columns=df.columns, index=[row.name] * len(to_add))
+ )
+
+ return pd.concat(new_rows)
+
+
def df_apply_dependency_augmentation(
df: pd.DataFrame,
causing: tuple[str, Sequence],
diff --git a/baybe/utils/conversion.py b/baybe/utils/conversion.py
index ed5f622532..846fc925ad 100644
--- a/baybe/utils/conversion.py
+++ b/baybe/utils/conversion.py
@@ -42,6 +42,18 @@ def nonstring_to_tuple(x: Sequence[_T], self: type, field: Attribute) -> tuple[_
return tuple(x)
+def normalize_convertible2str_sequence(
+ value: Sequence[str | bool], self: type, field: Attribute
+) -> tuple[str | bool, ...]:
+ """Sort and convert values for a sequence of string-convertible types.
+
+ If the sequence is a string itself, this is blocked to avoid unintended iteration
+ over its characters.
+ """
+ value = nonstring_to_tuple(value, self, field)
+ return tuple(sorted(value, key=lambda x: (str(type(x)), x)))
+
+
def _indent(text: str, amount: int = 3, ch: str = " ") -> str:
"""Indent a given text by a certain amount."""
padding = amount * ch
diff --git a/baybe/utils/dataframe.py b/baybe/utils/dataframe.py
index 84b831a406..595eba3fc3 100644
--- a/baybe/utils/dataframe.py
+++ b/baybe/utils/dataframe.py
@@ -192,6 +192,10 @@ def create_fake_input(
Raises:
ValueError: If less than one row was requested.
+
+ Note:
+ This function does not consider constraints and might provide unexpected or
+ invalid data if certain constraints are present.
"""
# Assert at least one fake entry is being generated
if n_rows < 1:
diff --git a/baybe/utils/validation.py b/baybe/utils/validation.py
index 93c87ab316..03a7462767 100644
--- a/baybe/utils/validation.py
+++ b/baybe/utils/validation.py
@@ -3,7 +3,8 @@
from __future__ import annotations
import math
-from collections.abc import Callable, Iterable
+from collections import Counter
+from collections.abc import Callable, Collection, Iterable, Sequence
from typing import TYPE_CHECKING, Any
import numpy as np
@@ -261,3 +262,34 @@ def preprocess_dataframe(
else:
targets = ()
return normalize_input_dtypes(df, [*searchspace.parameters, *targets])
+
+
+def validate_is_finite( # noqa: DOC101, DOC103
+ _: Any, attribute: Attribute, value: float | Sequence[float]
+) -> None:
+ """Validate that ``value`` is not infinity/nan or, for sequences, contains no such values.
+
+ Raises:
+ ValueError: If ``value`` contains infinity/nan.
+ """
+ if not np.isfinite(value).all():
+ raise ValueError(
+ f"Cannot assign the following values containing infinity/nan to "
+ f"'{attribute.alias}': {value}."
+ )
+
+
+def validate_unique_values( # noqa: DOC101, DOC103
+ _: Any, attribute: Attribute, value: Collection[str]
+) -> None:
+ """Validate that there are no duplicates in ``value``.
+
+ Raises:
+ ValueError: If there are duplicates in ``value``.
+ """
+ duplicates = [item for item, count in Counter(value).items() if count > 1]
+ if duplicates:
+ raise ValueError(
+ f"Entries appearing multiple times: {duplicates}. "
+ f"All entries of '{attribute.alias}' must be unique."
+ )
diff --git a/docs/_static/symmetries/augmentation.svg b/docs/_static/symmetries/augmentation.svg
new file mode 100644
index 0000000000..ac1ffefeeb
--- /dev/null
+++ b/docs/_static/symmetries/augmentation.svg
@@ -0,0 +1,162 @@
+
+
+
diff --git a/docs/scripts/build_examples.py b/docs/scripts/build_examples.py
index bfd62db2d5..ddf9be318c 100644
--- a/docs/scripts/build_examples.py
+++ b/docs/scripts/build_examples.py
@@ -31,8 +31,12 @@ def build_examples(destination_directory: Path, dummy: bool, remove_dir: bool):
else:
raise OSError("Destination directory exists but should not be removed.")
- # Copy the examples folder in the destination directory
- shutil.copytree("examples", destination_directory)
+ # Copy the examples folder in the destination directory. "__pycache__" might be
+ # present in the examples folder and needs to be ignored
+ def ignore_pycache(_, contents: list[str]):
+ return [item for item in contents if item == "__pycache__"]
+
+ shutil.copytree("examples", destination_directory, ignore=ignore_pycache)
# For the toctree of the top level example folder, we need to keep track of all
# folders. We thus write the header here and populate it during the execution of the
@@ -44,7 +48,7 @@ def build_examples(destination_directory: Path, dummy: bool, remove_dir: bool):
# This list contains the order of the examples as we want to have them in the end.
# The examples that should be the first ones are already included here and skipped
- # later on. ALl other are just included.
+ # later on. All others are just included.
ex_order = [
"Basics\n",
"Searchspaces\n",
diff --git a/docs/userguide/constraints.md b/docs/userguide/constraints.md
index e66be3051c..535320d399 100644
--- a/docs/userguide/constraints.md
+++ b/docs/userguide/constraints.md
@@ -399,6 +399,7 @@ DiscreteDependenciesConstraint(
```
An end to end example can be found [here](../../examples/Constraints_Discrete/dependency_constraints).
+For more information about the possibility of data augmentation, see [here](surrogate_data_augmentation).
### DiscretePermutationInvarianceConstraint
Permutation invariance, enabled by the
@@ -480,6 +481,7 @@ DiscretePermutationInvarianceConstraint(
The usage of `DiscretePermutationInvarianceConstraint` is also part of the
[example on slot-based mixtures](../../examples/Mixtures/slot_based).
+For more information about the possibility of data augmentation, see [here](surrogate_data_augmentation).
### DiscreteCardinalityConstraint
Like its [continuous cousin](#ContinuousCardinalityConstraint), the
diff --git a/docs/userguide/surrogates.md b/docs/userguide/surrogates.md
index 781051573f..7464cb2866 100644
--- a/docs/userguide/surrogates.md
+++ b/docs/userguide/surrogates.md
@@ -6,6 +6,7 @@ the utilization of custom models. All surrogate models are based upon the genera
[`Surrogate`](baybe.surrogates.base.Surrogate) class. Some models even support transfer
learning, as indicated by the `supports_transfer_learning` attribute.
+
## Available Models
BayBE provides a comprehensive selection of surrogate models, empowering you to choose
@@ -93,7 +94,6 @@ A noticeable difference to the replication approach is that manual assembly requ
the exact set of target variables to be known at the time the object is created.
-
## Extracting the Model for Advanced Study
In principle, the surrogate model does not need to be a persistent object during
@@ -117,6 +117,22 @@ shap_values = explainer(data)
shap.plots.bar(shap_values)
~~~
+
+(surrogate_data_augmentation)=
+## Data Augmentation
+In certain situations like [mixture modeling](/examples/Mixtures/slot_based),
+symmetries are present. Data augmentation is a model-agnostic way of enabling the
+surrogate model to learn such symmetries effectively, which might result in a better
+performance, similar to e.g. image classification models. BayBE
+[`Surrogate`](baybe.surrogates.base.Surrogate) models automatically perform data
+augmentation if
+{attr}`~baybe.surrogates.base.Surrogate.symmetries` with
+`use_data_augmentation=True` are present. This means you can add a data point in
+any acceptable representation and BayBE will train the model on this point plus
+augmented points that can be generated from it. To see the effect in practice, refer to
+[this example](/examples/Symmetries/permutation).
+
+
## Using Custom Models
BayBE goes one step further by allowing you to incorporate custom models based on the
diff --git a/docs/userguide/symmetries.md b/docs/userguide/symmetries.md
new file mode 100644
index 0000000000..d31c43d64d
--- /dev/null
+++ b/docs/userguide/symmetries.md
@@ -0,0 +1,51 @@
+# Symmetry
+{class}`~baybe.symmetries.base.Symmetry` is a concept tied to the structure of the searchspace.
+It is thus closely related to a {class}`~baybe.constraints.base.Constraint`, but has a
+different purpose in BayBE. If the searchspace is symmetric in any sense, you can
+exclude the degenerate parts via a constraint. But this would not change the modeling
+process. The role of a {class}`~baybe.symmetries.base.Symmetry` is exactly this: Influence how
+the surrogate model is constructed to include the knowledge about the symmetry. This
+can be applied independently of constraints. For an example of the influence of
+symmetries and constraints on the optimization of a permutation invariant function,
+[see here](/examples/Symmetries/permutation).
+
+## Definitions
+The following table summarizes available symmetries in BayBE:
+
+| Symmetry | Functional Definition | Corresponding Constraint |
+|:-----------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------|
+| {class}`~baybe.symmetries.permutation.PermutationSymmetry` | $f(x,y) = f(y,x)$ | {class}`~baybe.constraints.discrete.DiscretePermutationInvarianceConstraint` |
+| {class}`~baybe.symmetries.dependency.DependencySymmetry` | $f(x,y) = \begin{cases}f(x,y) & \text{if }c(x) \\f(x) & \text{otherwise}\end{cases}$ where $c(x)$ is a condition that is either true or false | {class}`~baybe.constraints.discrete.DiscreteDependenciesConstraint` |
+| {class}`~baybe.symmetries.mirror.MirrorSymmetry` | $f(x,y) = f(-x,y)$ | No constraint is available. Instead, the number range for that parameter can simply be restricted. |
+
+## Data Augmentation
+Data augmentation can be a powerful tool to improve the modeling process. It
+essentially changes the data that the model is fitted on by adding more points. The
+augmented points are constructed such that they represent a symmetric point compared
+with their original, which always corresponds to a different transformation depending
+on which symmetry is responsible.
+
+If the surrogate model receives such augmented points, it can learn the symmetry. This
+has the advantage that it can improve predictions for unseen points and is fully
+model-agnostic. Downsides are increased training time and potential computational
+challenges arising from a fit on substantially more points. It is thus possible to
+control the data augmentation behavior of any {class}`~baybe.symmetries.base.Symmetry` by
+setting its {attr}`~baybe.symmetries.base.Symmetry.use_data_augmentation` attribute
+(`True` by default).
+
+Below we illustrate the effect of data augmentation for the different symmetries
+supported by BayBE:
+
+
+
+## Invariant Kernels
+Some machine learning models can be constructed with architectures that automatically
+respect a symmetry, i.e. applying the model to an augmented point always produces the
+same output as the original point by construction.
+
+For Gaussian processes, this can be achieved by applying special kernels.
+```{admonition} Not Implemented Yet
+:class: warning
+Ideally, invariant kernels will be applied automatically when a corresponding symmetry has been
+configured for the surrogate model GP. This feature is not implemented yet.
+```
\ No newline at end of file
diff --git a/docs/userguide/userguide.md b/docs/userguide/userguide.md
index 059f6ca581..cefb6516f3 100644
--- a/docs/userguide/userguide.md
+++ b/docs/userguide/userguide.md
@@ -44,6 +44,7 @@ Serialization
Settings
Simulation
Surrogates
+Symmetries
Targets
Transformations
Transfer Learning
diff --git a/examples/Symmetries/Symmetries_Header.md b/examples/Symmetries/Symmetries_Header.md
new file mode 100644
index 0000000000..361d774e3d
--- /dev/null
+++ b/examples/Symmetries/Symmetries_Header.md
@@ -0,0 +1,3 @@
+# Symmetries
+
+These examples demonstrate the impact of symmetries.
\ No newline at end of file
diff --git a/examples/Symmetries/permutation.py b/examples/Symmetries/permutation.py
new file mode 100644
index 0000000000..a909aba0ad
--- /dev/null
+++ b/examples/Symmetries/permutation.py
@@ -0,0 +1,212 @@
+# # Optimizing a Permutation-Invariant Function
+
+# In this example, we explore BayBE's capabilities for handling optimization problems
+# with symmetry via automatic data augmentation and/or constraint.
+
+# ## Imports
+
+import os
+
+import numpy as np
+import pandas as pd
+import seaborn as sns
+from matplotlib import pyplot as plt
+from matplotlib.ticker import MaxNLocator
+
+from baybe import Campaign, Settings
+from baybe.constraints import DiscretePermutationInvarianceConstraint
+from baybe.parameters import NumericalDiscreteParameter
+from baybe.recommenders import (
+ BotorchRecommender,
+ TwoPhaseMetaRecommender,
+)
+from baybe.searchspace import SearchSpace
+from baybe.simulation import simulate_scenarios
+from baybe.surrogates import NGBoostSurrogate
+from baybe.targets import NumericalTarget
+
+# ## Settings
+
+Settings(random_seed=1337).activate()
+SMOKE_TEST = "SMOKE_TEST" in os.environ
+N_MC_ITERATIONS = 2 if SMOKE_TEST else 100
+N_DOE_ITERATIONS = 2 if SMOKE_TEST else 50
+
+# ## The Scenario
+
+# We will explore a 2-dimensional permutation-invariant function, i.e., it holds that
+# $f(x,y) = f(y,x)$. The function was crafted to exhibit no additional mirror symmetry
+# (a common way of also resulting in permutation invariance) and have multiple minima.
+# In practice, permutation invariance can arise e.g. for
+# [mixtures when modeled with a slot-based approach](/examples/Mixtures/slot_based).
+
+# There are several ways to handle such symmetries. The simplest one is
+# to augment your data. In the case of permutation invariance, augmentation means for
+# each measurement $(x,y)$ you also add a measurement with switched values, i.e., $(y,x)$.
+# This has the advantage that it is fully model-agnostic, but might
+# come at the expense of increased training time and efficiency due to the larger amount
+# of effective training points.
+
+
+LBOUND = -2.0
+UBOUND = 2.0
+
+
+def lookup(df: pd.DataFrame) -> pd.DataFrame:
+ """A lookup modeling a permutation-invariant 2D function with multiple minima."""
+ x = df["x"].values
+ y = df["y"].values
+ result = (
+ (x - y) ** 2
+ + (x**3 + y**3)
+ + ((x**2 - 1) ** 2 + (y**2 - 1) ** 2)
+ + np.sin(3 * (x + y)) ** 2
+ + np.sin(3 * np.abs(x - y)) ** 2
+ )
+
+ df_z = pd.DataFrame({"f": result}, index=df.index)
+ return df_z
+
+
+x = np.linspace(LBOUND, UBOUND, 25)
+y = np.linspace(LBOUND, UBOUND, 25)
+xx, yy = np.meshgrid(x, y)
+df_plot = lookup(pd.DataFrame({"x": xx.ravel(), "y": yy.ravel()}))
+zz = df_plot["f"].values.reshape(xx.shape)
+line_vals = np.linspace(LBOUND, UBOUND, 2)
+
+# fmt: off
+fig, axs = plt.subplots(1, 2, figsize=(15, 6))
+contour = axs[0].contourf(xx, yy, zz, levels=50, cmap="viridis")
+fig.colorbar(contour, ax=axs[0])
+axs[0].plot(line_vals, line_vals, "r--", alpha=0.5, linewidth=2)
+axs[0].set_title("Ground Truth: $f(x, y)$ = $f(y, x)$ (Permutation Invariant)")
+axs[0].set_xlabel("x")
+axs[0].set_ylabel("y");
+# fmt: on
+
+# The plots can be found at the bottom of this file.
+# The first subplot shows the function we want to minimize. The dashed red line
+# illustrates the permutation invariance, which is similar to a mirror-symmetry, just
+# not along any of the parameter axes but along the diagonal. We can also see several
+# local minima.
+
+# Such a situation can be challenging for optimization algorithms if no information
+# about the invariance is considered. For instance, if no
+# {class}`~baybe.constraints.discrete.DiscretePermutationInvarianceConstraint` was used
+# at all, BayBE would search for the optima across the entire 2D space. But it is clear
+# that the search can be restricted to the lower (or equivalently the upper) triangle
+# of the searchspace. This is exactly what
+# {class}`~baybe.constraints.discrete.DiscretePermutationInvarianceConstraint` does:
+# Remove entries that are "duplicated" in the sense of already being represented by
+# another invariant point.
+
+# If the surrogate is additionally configured with `symmetries` that use
+# `use_data_augmentation=True`, the model will be fit with an extended set of points,
+# including augmented ones. So as a user, you don't have to generate permutations and
+# add them manually. Depending on the surrogate model, this might have different
+# impacts. For example, we can expect a strong effect for tree-based models because their splits are
+# always parallel to the parameter axes. Thus, without augmented measurements, it is
+# easy to fall into suboptimal splits and overfit. We illustrate this by using the
+# {class}`~baybe.surrogates.ngboost.NGBoostSurrogate`.
+
+# ## The Optimization Problem
+
+p1 = NumericalDiscreteParameter("x", np.linspace(LBOUND, UBOUND, 51))
+p2 = NumericalDiscreteParameter("y", np.linspace(LBOUND, UBOUND, 51))
+objective = NumericalTarget("f", minimize=True).to_objective()
+
+# We set up a constrained and an unconstrained (plain) searchspace to demonstrate the
+# impact of the constraint on optimization performance.
+
+constraint = DiscretePermutationInvarianceConstraint(["x", "y"])
+searchspace_plain = SearchSpace.from_product([p1, p2])
+searchspace_constrained = SearchSpace.from_product([p1, p2], [constraint])
+
+print("Number of Points in the Searchspace")
+print(f"{'Without Constraint:':<35} {len(searchspace_plain.discrete.exp_rep)}")
+print(f"{'With Constraint:':<35} {len(searchspace_constrained.discrete.exp_rep)}")
+
+# We can see that the unconstrained searchspace has roughly twice as many points
+# compared to the constrained one. This is expected, as the
+# {class}`~baybe.constraints.discrete.DiscretePermutationInvarianceConstraint`
+# effectively models only one half of the parameter triangle. Note that the factor is
+# not exactly 2 due to the (still included) points on the diagonal.
+
+# BayBE can automatically perform this augmentation if configured to do so.
+# Specifically, surrogate models have the
+# {attr}`~baybe.surrogates.base.Surrogate.symmetries` attribute. If any of
+# these symmetries has `use_data_augmentation=True` (enabled by default),
+# BayBE will automatically augment measurements internally before performing the model
+# fit. To construct symmetries quickly, we use the `to_symmetry` method of the
+# constraint.
+
+symmetry = constraint.to_symmetry(use_data_augmentation=True)
+recommender_plain = TwoPhaseMetaRecommender(
+ recommender=BotorchRecommender(surrogate_model=NGBoostSurrogate())
+)
+recommender_symmetric = TwoPhaseMetaRecommender(
+ recommender=BotorchRecommender(
+ surrogate_model=NGBoostSurrogate(symmetries=[symmetry])
+ )
+)
+
+# The combination of constraint and augmentation settings results in four different
+# campaigns:
+
+campaign_plain = Campaign(searchspace_plain, objective, recommender_plain)
+campaign_c = Campaign(searchspace_constrained, objective, recommender_plain)
+campaign_s = Campaign(searchspace_plain, objective, recommender_symmetric)
+campaign_cs = Campaign(searchspace_constrained, objective, recommender_symmetric)
+
+# ## Simulating the Optimization Loop
+
+
+scenarios = {
+ "Unconstrained, Unsymmetric": campaign_plain,
+ "Constrained, Unsymmetric": campaign_c,
+ "Unconstrained, Symmetric": campaign_s,
+ "Constrained, Symmetric": campaign_cs,
+}
+
+results = simulate_scenarios(
+ scenarios,
+ lookup,
+ n_doe_iterations=N_DOE_ITERATIONS,
+ n_mc_iterations=N_MC_ITERATIONS,
+).rename(
+ columns={
+ "f_CumBest": "$f(x,y)$ (cumulative best)",
+ "Num_Experiments": "# Experiments",
+ }
+)
+
+# ## Results
+
+# Let us visualize the optimization process in the second subplot:
+
+sns.lineplot(
+ data=results,
+ x="# Experiments",
+ y="$f(x,y)$ (cumulative best)",
+ hue="Scenario",
+ marker="o",
+ ax=axs[1],
+)
+axs[1].xaxis.set_major_locator(MaxNLocator(integer=True))
+axs[1].set_ylim(axs[1].get_ylim()[0], 3)
+axs[1].set_title("Minimization Performance")
+plt.tight_layout()
+plt.show()
+
+# We find that the campaigns utilizing the permutation invariance constraint
+# perform better than the ones without. This can be attributed simply to the reduced
+# number of searchspace points they operate on. However, this effect is rather minor
+# compared to the effect of symmetry.
+
+# Furthermore, there is a strong impact on whether data augmentation is used or not,
+# the effect we expected for a tree-based surrogate model. Indeed, the campaign with
+# constraint but without augmentation is barely better than the campaign not utilizing
+# the constraint at all. Conversely, the data-augmented campaign has a clearly superior
+# performance. The best result is achieved by using both constraints and data
+# augmentation.
diff --git a/examples/Symmetries/permutation.svg b/examples/Symmetries/permutation.svg
new file mode 100644
index 0000000000..ab17b0ac9a
--- /dev/null
+++ b/examples/Symmetries/permutation.svg
@@ -0,0 +1,7525 @@
+
+
+
diff --git a/tests/conftest.py b/tests/conftest.py
index 882204c156..5cd9537b67 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -242,7 +242,7 @@ def fixture_parameters(
NumericalDiscreteParameter(
name="Fraction_1",
values=tuple(np.linspace(0, 100, n_grid_points)),
- tolerance=0.2,
+ tolerance=0.5,
),
NumericalDiscreteParameter(
name="Fraction_2",
diff --git a/tests/hypothesis_strategies/conditions.py b/tests/hypothesis_strategies/conditions.py
new file mode 100644
index 0000000000..fe90efdaa8
--- /dev/null
+++ b/tests/hypothesis_strategies/conditions.py
@@ -0,0 +1,24 @@
+"""Hypothesis strategies for conditions."""
+
+from typing import Any
+
+import hypothesis.strategies as st
+
+from baybe.constraints import SubSelectionCondition, ThresholdCondition
+from tests.hypothesis_strategies.basic import finite_floats
+
+
+def sub_selection_conditions(superset: list[Any] | None = None):
+ """Generate :class:`baybe.constraints.conditions.SubSelectionCondition`."""
+ if superset is None:
+ element_strategy = st.text()
+ else:
+ element_strategy = st.sampled_from(superset)
+ return st.builds(
+ SubSelectionCondition, st.lists(element_strategy, unique=True, min_size=1)
+ )
+
+
+def threshold_conditions():
+ """Generate :class:`baybe.constraints.conditions.ThresholdCondition`."""
+ return st.builds(ThresholdCondition, threshold=finite_floats())
diff --git a/tests/hypothesis_strategies/constraints.py b/tests/hypothesis_strategies/constraints.py
index e1f1014833..dfc065d817 100644
--- a/tests/hypothesis_strategies/constraints.py
+++ b/tests/hypothesis_strategies/constraints.py
@@ -1,14 +1,11 @@
"""Hypothesis strategies for constraints."""
from functools import partial
-from typing import Any
import hypothesis.strategies as st
from hypothesis import assume
from baybe.constraints.conditions import (
- SubSelectionCondition,
- ThresholdCondition,
_valid_logic_combiners,
)
from baybe.constraints.continuous import (
@@ -26,22 +23,10 @@
from baybe.parameters.base import DiscreteParameter
from baybe.parameters.numerical import NumericalDiscreteParameter
from tests.hypothesis_strategies.basic import finite_floats
-
-
-def sub_selection_conditions(superset: list[Any] | None = None):
- """Generate :class:`baybe.constraints.conditions.SubSelectionCondition`."""
- if superset is None:
- element_strategy = st.text()
- else:
- element_strategy = st.sampled_from(superset)
- return st.builds(
- SubSelectionCondition, st.lists(element_strategy, unique=True, min_size=1)
- )
-
-
-def threshold_conditions():
- """Generate :class:`baybe.constraints.conditions.ThresholdCondition`."""
- return st.builds(ThresholdCondition, threshold=finite_floats())
+from tests.hypothesis_strategies.conditions import (
+ sub_selection_conditions,
+ threshold_conditions,
+)
@st.composite
diff --git a/tests/hypothesis_strategies/symmetries.py b/tests/hypothesis_strategies/symmetries.py
new file mode 100644
index 0000000000..31e45aaa40
--- /dev/null
+++ b/tests/hypothesis_strategies/symmetries.py
@@ -0,0 +1,121 @@
+"""Hypothesis strategies for symmetries."""
+
+import hypothesis.strategies as st
+from hypothesis import assume
+
+from baybe.parameters.base import Parameter
+from baybe.symmetries import DependencySymmetry, MirrorSymmetry, PermutationSymmetry
+from tests.hypothesis_strategies.basic import finite_floats
+from tests.hypothesis_strategies.conditions import (
+ sub_selection_conditions,
+ threshold_conditions,
+)
+
+
+@st.composite
+def mirror_symmetries(draw: st.DrawFn, parameter_pool: list[Parameter] | None = None):
+ """Generate :class:`baybe.symmetries.MirrorSymmetry`."""
+ if parameter_pool is None:
+ parameter_name = draw(st.text(min_size=1))
+ else:
+ parameter = draw(st.sampled_from(parameter_pool))
+ assume(parameter.is_numerical)
+ parameter_name = parameter.name
+
+ return MirrorSymmetry(
+ parameter_name=parameter_name,
+ use_data_augmentation=draw(st.booleans()),
+ mirror_point=draw(finite_floats()),
+ )
+
+
+@st.composite
+def permutation_symmetries(
+ draw: st.DrawFn, parameter_pool: list[Parameter] | None = None
+):
+ """Generate :class:`baybe.symmetries.PermutationSymmetry`."""
+ if parameter_pool is None:
+ parameter_names_pool = draw(
+ st.lists(st.text(min_size=1), unique=True, min_size=2)
+ )
+ else:
+ parameter_names_pool = [p.name for p in parameter_pool]
+
+ # Draw the first (required) group with at least 2 elements
+ first_group = draw(
+ st.lists(st.sampled_from(parameter_names_pool), min_size=2, unique=True).map(
+ tuple
+ )
+ )
+ group_size = len(first_group)
+
+ # Determine how many additional groups we can have
+ remaining_names = [n for n in parameter_names_pool if n not in first_group]
+ max_additional = len(remaining_names) // group_size if group_size > 0 else 0
+ n_additional = draw(st.integers(min_value=0, max_value=max_additional))
+
+ # Draw additional groups of the same size from remaining names
+ all_groups = [first_group]
+ for _ in range(n_additional):
+ group = draw(
+ st.lists(
+ st.sampled_from(remaining_names),
+ unique=True,
+ min_size=group_size,
+ max_size=group_size,
+ ).map(tuple)
+ )
+ # Ensure no overlap with existing groups
+ for existing in all_groups:
+ assume(not set(group) & set(existing))
+ remaining_names = [n for n in remaining_names if n not in group]
+ all_groups.append(group)
+
+ return PermutationSymmetry(
+ permutation_groups=all_groups,
+ use_data_augmentation=draw(st.booleans()),
+ )
+
+
+@st.composite
+def dependency_symmetries(
+ draw: st.DrawFn, parameter_pool: list[Parameter] | None = None
+):
+ """Generate :class:`baybe.symmetries.DependencySymmetry`."""
+ if parameter_pool is None:
+ parameter_name = draw(st.text(min_size=1))
+ affected_strat = st.lists(
+ st.text(min_size=1).filter(lambda x: x != parameter_name),
+ min_size=1,
+ unique=True,
+ ).map(tuple)
+ else:
+ parameter = draw(st.sampled_from(parameter_pool))
+ assume(parameter.is_discrete)
+ parameter_name = parameter.name
+ affected_strat = st.lists(
+ st.sampled_from(parameter_pool)
+ .filter(lambda x: x.name != parameter_name)
+ .map(lambda x: x.name),
+ unique=True,
+ min_size=1,
+ max_size=len(parameter_pool) - 1,
+ ).map(tuple)
+
+ return DependencySymmetry(
+ parameter_name=parameter_name,
+ condition=draw(st.one_of(threshold_conditions(), sub_selection_conditions())),
+ affected_parameter_names=draw(affected_strat),
+ n_discretization_points=draw(st.integers(min_value=2)),
+ use_data_augmentation=draw(st.booleans()),
+ )
+
+
+symmetries = st.one_of(
+ [
+ mirror_symmetries(),
+ permutation_symmetries(),
+ dependency_symmetries(),
+ ]
+)
+"""A strategy that generates symmetries."""
diff --git a/tests/serialization/test_symmetry_serialization.py b/tests/serialization/test_symmetry_serialization.py
new file mode 100644
index 0000000000..e1f6ae9373
--- /dev/null
+++ b/tests/serialization/test_symmetry_serialization.py
@@ -0,0 +1,28 @@
+"""Symmetry serialization tests."""
+
+import hypothesis.strategies as st
+import pytest
+from hypothesis import given
+from pytest import param
+
+from tests.hypothesis_strategies.symmetries import (
+ dependency_symmetries,
+ mirror_symmetries,
+ permutation_symmetries,
+)
+from tests.serialization.utils import assert_roundtrip_consistency
+
+
+@pytest.mark.parametrize(
+ "strategy",
+ [
+ param(mirror_symmetries(), id="MirrorSymmetry"),
+ param(permutation_symmetries(), id="PermutationSymmetry"),
+ param(dependency_symmetries(), id="DependencySymmetry"),
+ ],
+)
+@given(data=st.data())
+def test_roundtrip(strategy: st.SearchStrategy, data: st.DataObject):
+ """A serialization roundtrip yields an equivalent object."""
+ symmetry = data.draw(strategy)
+ assert_roundtrip_consistency(symmetry)
diff --git a/tests/test_measurement_augmentation.py b/tests/test_measurement_augmentation.py
new file mode 100644
index 0000000000..fa7e76c289
--- /dev/null
+++ b/tests/test_measurement_augmentation.py
@@ -0,0 +1,143 @@
+"""Tests for augmentation of measurements."""
+
+import math
+from unittest.mock import patch
+
+import numpy as np
+import pandas as pd
+import pytest
+from attrs import evolve
+from pandas.testing import assert_frame_equal
+
+from baybe.acquisition import qLogEI
+from baybe.constraints import (
+ DiscretePermutationInvarianceConstraint,
+ ThresholdCondition,
+)
+from baybe.parameters import (
+ CategoricalParameter,
+ NumericalContinuousParameter,
+ NumericalDiscreteParameter,
+)
+from baybe.recommenders import BotorchRecommender
+from baybe.searchspace import SearchSpace
+from baybe.symmetries import DependencySymmetry, MirrorSymmetry
+from baybe.utils.dataframe import create_fake_input
+
+
+@pytest.mark.parametrize("mirror_aug", [True, False], ids=["mirror", "nomirror"])
+@pytest.mark.parametrize("perm_aug", [True, False], ids=["perm", "noperm"])
+@pytest.mark.parametrize("dep_aug", [True, False], ids=["dep", "nodep"])
+@pytest.mark.parametrize(
+ "constraint_names", [["Constraint_11", "Constraint_7"]], ids=["c"]
+)
+@pytest.mark.parametrize(
+ "parameter_names",
+ [
+ [
+ "Solvent_1",
+ "Solvent_2",
+ "Solvent_3",
+ "Fraction_1",
+ "Fraction_2",
+ "Fraction_3",
+ "Num_disc_2",
+ ]
+ ],
+ ids=["p"],
+)
+def test_measurement_augmentation(
+ parameters,
+ surrogate_model,
+ objective,
+ constraints,
+ dep_aug,
+ perm_aug,
+ mirror_aug,
+):
+ """Measurement augmentation is performed if configured."""
+ original_to_botorch = qLogEI.to_botorch
+ called_args_list = []
+
+ def spy(self, *args, **kwargs):
+ called_args_list.append((args, kwargs))
+ return original_to_botorch(self, *args, **kwargs)
+
+ with patch.object(qLogEI, "to_botorch", side_effect=spy, autospec=True):
+ # Basic setup
+ c_perm = next(
+ c
+ for c in constraints
+ if isinstance(c, DiscretePermutationInvarianceConstraint)
+ )
+ c_dep = c_perm.dependencies
+ s_perm = c_perm.to_symmetry(perm_aug)
+        s_deps = c_dep.to_symmetries(dep_aug)  # this is a tuple of multiple symmetries
+ s_mirror = MirrorSymmetry("Num_disc_2", use_data_augmentation=mirror_aug)
+ searchspace = SearchSpace.from_product(parameters, constraints)
+ surrogate = evolve(surrogate_model, symmetries=[*s_deps, s_perm, s_mirror])
+ recommender = BotorchRecommender(
+ surrogate_model=surrogate, acquisition_function=qLogEI()
+ )
+
+ # Perform call and watch measurements
+ measurements = create_fake_input(parameters, objective.targets, 5)
+ recommender.recommend(1, searchspace, objective, measurements)
+ measurements_passed = called_args_list[0][0][3] # take 4th arg from first call
+
+ # Create expectation
+ # We calculate how many degenerate points the augmentation should create:
+ # - n_dep: Product of the number of active values for all affected parameters
+ # - n_perm: Number of permutations possible
+ # - n_mirror: 2 if the row is not on the mirror point, else 1
+ # - If augmentation is turned off, the corresponding factor becomes 1
+ # We expect a given row to produce n_perm * (n_dep^k) * n_mirror points, where
+ # k is the number of "Fraction_*" parameters having the "causing" value 0.0. The
+ # total number of expected points after augmentation is the sum over the
+ # expectations for all rows.
+ dep_affected = [p for p in parameters if p.name in c_dep.affected_parameters[0]]
+ n_dep = math.prod(len(p.active_values) for p in dep_affected) if dep_aug else 1
+ n_perm = ( # number of permutations
+ math.prod(range(1, len(c_perm.parameters) + 1)) if perm_aug else 1
+ )
+ n_expected = 0
+ for _, row in measurements.iterrows():
+ n_mirror = (
+ 2 if (mirror_aug and row["Num_disc_2"] != s_mirror.mirror_point) else 1
+ )
+ k = row[c_dep.parameters].eq(0).sum()
+            n_expected += n_perm * np.power(n_dep, k) * n_mirror
+
+ # Check expectation
+ if any([dep_aug, perm_aug, mirror_aug]):
+ assert len(measurements_passed) == n_expected
+ else:
+ assert_frame_equal(measurements, measurements_passed)
+
+
+@pytest.mark.parametrize("n_points", [2, 5])
+@pytest.mark.parametrize("mixed", [True, False])
+def test_continuous_dependency_augmentation(n_points, mixed):
+ """Dependency augmentation with continuous affected parameters works correctly."""
+ df = pd.DataFrame({"n1": [1, 0], "cat1": ["a", "b"], "c1": [1, 2]})
+ ps = [
+ NumericalDiscreteParameter("n1", (0, 0.5, 1.0)),
+ CategoricalParameter("cat1", ("a", "b", "c")),
+ NumericalContinuousParameter("c1", (0, 10)),
+ ]
+
+ # Model "Affected parameters only matter if n1 is > 0"
+ s = DependencySymmetry(
+ "n1",
+ condition=ThresholdCondition(0.0, ">"),
+ affected_parameter_names=("cat1", "c1") if mixed else ("c1",),
+ n_discretization_points=n_points,
+ )
+ dfa = s.augment_measurements(df, ps)
+
+ # Calculate expectation. The first row is not affected and contributes one point.
+ # The second row is degenerate and contributes `n_points` values for c1 and 3
+ # values for cat1 (if part of the symmetry).
+ n_expected = 1 + n_points * (3 if mixed else 1)
+
+ assert len(dfa) == n_expected
diff --git a/tests/utils/test_augmentation.py b/tests/utils/test_augmentation.py
index 80ffff8a2a..4dd8a14d33 100644
--- a/tests/utils/test_augmentation.py
+++ b/tests/utils/test_augmentation.py
@@ -22,7 +22,7 @@
},
"index": [0, 1],
},
- [["A"], ["B"]],
+ [["A", "B"]],
{
"data": {
"A": [1, 2, 1, 2],
@@ -41,7 +41,7 @@
},
"index": [0, 1],
},
- [["A"], ["B"]],
+ [["A", "B"]],
{
"data": {
"A": [1, 2, 1, 2],
@@ -60,7 +60,7 @@
},
"index": [0, 1],
},
- [["A"], ["B"]],
+ [["A", "B"]],
{
"data": {
"A": [1, 2, 1, 2],
@@ -80,7 +80,7 @@
},
"index": [0, 1],
},
- [["A"], ["B"]],
+ [["A", "B"]],
{
"data": {
"A": [1, 2, 1, 2],
@@ -91,7 +91,7 @@
},
id="2inv+degen_target+degen",
),
- param( # 3 invariant groups with 1 entry each
+ param( # 3 invariant cols in one group
{
"data": {
"A": [1, 1],
@@ -101,7 +101,7 @@
},
"index": [0, 1],
},
- [["A"], ["B"], ["C"]],
+ [["A", "B", "C"]],
{
"data": {
"A": [1, 1, 2, 2, 3, 3, 1, 1, 4, 4, 5, 5],
@@ -113,7 +113,7 @@
},
id="3inv_1add",
),
- param( # 2 groups with 2 entries each, 2 additional columns
+ param( # 2 lockstep groups with 2 entries each, 2 additional columns
{
"data": {
"Slot1": ["s1", "s2"],
@@ -125,7 +125,7 @@
},
"index": [0, 1],
},
- [["Slot1", "Frac1"], ["Slot2", "Frac2"]],
+ [["Slot1", "Slot2"], ["Frac1", "Frac2"]],
{
"data": {
"Slot1": ["s1", "s2", "s2", "s4"],
@@ -139,7 +139,7 @@
},
id="2inv_2dependent_2add",
),
- param( # 2 groups with 3 entries each, 1 additional column
+ param( # 3 lockstep groups with 2 entries each, 1 additional column
{
"data": {
"Slot1": ["s1", "s2"],
@@ -152,7 +152,7 @@
},
"index": [0, 1],
},
- [["Slot1", "Frac1", "Temp1"], ["Slot2", "Frac2", "Temp2"]],
+ [["Slot1", "Slot2"], ["Frac1", "Frac2"], ["Temp1", "Temp2"]],
{
"data": {
"Slot1": ["s1", "s2", "s2", "s4"],
@@ -187,10 +187,9 @@ def test_df_permutation_aug(content, col_groups, content_expected):
@pytest.mark.parametrize(
("col_groups", "msg"),
[
- param([], "at least two column sequences", id="no_groups"),
- param([["A"]], "at least two column sequences", id="just_one_group"),
- param([["A"], ["B", "C"]], "the amount of columns in", id="different_lengths"),
- param([[], []], "each column group has", id="empty_group"),
+ param([], "at least one group", id="no_groups"),
+ param([["A"]], "at least two entries", id="group_too_small"),
+ param([["A", "B"], ["C"]], "same number of entries", id="different_lengths"),
],
)
def test_df_permutation_aug_invalid(col_groups, msg):
diff --git a/tests/validation/test_symmetry_validation.py b/tests/validation/test_symmetry_validation.py
new file mode 100644
index 0000000000..b4093eb77f
--- /dev/null
+++ b/tests/validation/test_symmetry_validation.py
@@ -0,0 +1,291 @@
+"""Validation tests for symmetry."""
+
+import numpy as np
+import pandas as pd
+import pytest
+from pytest import param
+
+from baybe.constraints import ThresholdCondition
+from baybe.exceptions import IncompatibleSearchSpaceError
+from baybe.parameters import (
+ CategoricalParameter,
+ NumericalContinuousParameter,
+ NumericalDiscreteParameter,
+)
+from baybe.recommenders import BotorchRecommender
+from baybe.searchspace import SearchSpace
+from baybe.surrogates import GaussianProcessSurrogate
+from baybe.symmetries import DependencySymmetry, MirrorSymmetry, PermutationSymmetry
+from baybe.targets import NumericalTarget
+from baybe.utils.dataframe import create_fake_input
+
+valid_config_mirror = {"parameter_name": "n1"}
+valid_config_perm = {
+ "permutation_groups": [["cat1", "cat2"], ["n1", "n2"]],
+}
+valid_config_dep = {
+ "parameter_name": "n1",
+ "condition": ThresholdCondition(0.0, ">="),
+ "affected_parameter_names": ["n2", "cat1"],
+}
+
+
+@pytest.mark.parametrize(
+ "cls, config, error, msg",
+ [
+ param(
+ MirrorSymmetry,
+ valid_config_mirror | {"mirror_point": np.inf},
+ ValueError,
+ "values containing infinity/nan to 'mirror_point': inf",
+ id="mirror_nonfinite",
+ ),
+ param(
+ PermutationSymmetry,
+ {"permutation_groups": [["cat1", "cat1"]]},
+ ValueError,
+ r"the following group contains duplicates",
+ id="perm_not_unique",
+ ),
+ param(
+ PermutationSymmetry,
+ {
+ "permutation_groups": [
+ ["cat1", "cat2", "cat3"],
+ ["n1", "n2", "n3", "n4"],
+ ]
+ },
+ ValueError,
+ "must have the same length",
+ id="perm_different_lengths",
+ ),
+ param(
+ PermutationSymmetry,
+ {"permutation_groups": [["cat1", "cat2"], ["cat1", "n2"]]},
+ ValueError,
+ r"following parameter names appear in several groups",
+ id="perm_overlap",
+ ),
+ param(
+ PermutationSymmetry,
+ {"permutation_groups": [["cat1"]]},
+ ValueError,
+ "must be >= 2",
+ id="perm_group_too_small",
+ ),
+ param(
+ PermutationSymmetry,
+ {"permutation_groups": []},
+ ValueError,
+ "must be >= 1",
+ id="perm_no_groups",
+ ),
+ param(
+ PermutationSymmetry,
+ {"permutation_groups": [[1, 2]]},
+ TypeError,
+ "must be ",
+ id="perm_not_str",
+ ),
+ param(
+ DependencySymmetry,
+ valid_config_dep | {"parameter_name": 1},
+ TypeError,
+ "must be ",
+ id="dep_param_not_str",
+ ),
+ param(
+ DependencySymmetry,
+ valid_config_dep | {"condition": 1},
+ TypeError,
+ "must be ",
+ id="dep_wrong_cond_type",
+ ),
+ param(
+ DependencySymmetry,
+ valid_config_dep | {"affected_parameter_names": []},
+ ValueError,
+ "Length of 'affected_parameter_names' must be >= 1",
+ id="dep_affected_empty",
+ ),
+ param(
+ DependencySymmetry,
+ valid_config_dep | {"affected_parameter_names": [1]},
+ TypeError,
+ "must be ",
+ id="dep_affected_wrong_type",
+ ),
+ param(
+ DependencySymmetry,
+ valid_config_dep | {"affected_parameter_names": ["a1", "a1"]},
+ ValueError,
+ r"Entries appearing multiple times: \['a1'\].",
+ id="dep_affected_not_unique",
+ ),
+ param(
+ PermutationSymmetry,
+ {"permutation_groups": "abc"},
+ ValueError,
+ "must be a sequence of sequences, not a string",
+ id="perm_groups_bare_string",
+ ),
+ param(
+ PermutationSymmetry,
+ {"permutation_groups": ["abc", "def"]},
+ ValueError,
+ "must be a sequence of parameter names, not a string",
+ id="perm_groups_inner_bare_string",
+ ),
+ param(
+ MirrorSymmetry,
+ {"parameter_name": 123},
+ TypeError,
+ "must be ",
+ id="mirror_param_not_str",
+ ),
+ param(
+ DependencySymmetry,
+ valid_config_dep | {"n_discretization_points": 3.5},
+ TypeError,
+ "must be ",
+ id="dep_n_discretization_not_int",
+ ),
+ param(
+ DependencySymmetry,
+ valid_config_dep | {"n_discretization_points": 1},
+ ValueError,
+ "must be >= 2",
+ id="dep_n_discretization_too_small",
+ ),
+ param(
+ DependencySymmetry,
+ valid_config_dep | {"affected_parameter_names": "abc"},
+ ValueError,
+ "must be a sequence but cannot be a string",
+ id="dep_affected_bare_string",
+ ),
+ param(
+ PermutationSymmetry,
+ {"permutation_groups": [["a", "b"]], "use_data_augmentation": 1},
+ TypeError,
+ "must be ",
+ id="use_aug_not_bool",
+ ),
+ ],
+)
+def test_configuration(cls, config, error, msg):
+ """Invalid configurations raise an expected error."""
+ with pytest.raises(error, match=msg):
+ cls(**config)
+
+
+_parameters = [
+ NumericalDiscreteParameter("n1", (-1, 0, 1)),
+ NumericalDiscreteParameter("n2", (-1, 0, 1)),
+ NumericalContinuousParameter("n1_not_discrete", (0.0, 10.0)),
+ NumericalContinuousParameter("n2_not_discrete", (0.0, 10.0)),
+ NumericalContinuousParameter("c1", (0.0, 10.0)),
+ NumericalContinuousParameter("c2", (0.0, 10.0)),
+ CategoricalParameter("cat1", ("a", "b", "c")),
+ CategoricalParameter("cat1_altered", ("a", "b")),
+ CategoricalParameter("cat2", ("a", "b", "c")),
+]
+
+
+@pytest.fixture
+def searchspace(parameter_names):
+ ps = tuple(p for p in _parameters if p.name in parameter_names)
+ return SearchSpace.from_product(ps)
+
+
+@pytest.mark.parametrize(
+ "parameter_names, symmetry, error, msg",
+ [
+ param(
+ ["cat1"],
+ MirrorSymmetry(parameter_name="cat1"),
+ TypeError,
+ "'cat1' is of type 'CategoricalParameter' and is not numerical",
+ id="mirror_not_numerical",
+ ),
+ param(
+ ["n1"],
+ MirrorSymmetry(parameter_name="n2"),
+ IncompatibleSearchSpaceError,
+ r"not present in the search space",
+ id="mirror_param_missing",
+ ),
+ param(
+ ["n2", "cat1"],
+ DependencySymmetry(**valid_config_dep),
+ IncompatibleSearchSpaceError,
+ r"not present in the search space",
+ id="dep_causing_missing",
+ ),
+ param(
+ ["n1", "cat1"],
+ DependencySymmetry(**valid_config_dep),
+ IncompatibleSearchSpaceError,
+ r"not present in the search space",
+ id="dep_affected_missing",
+ ),
+ param(
+ ["n1_not_discrete", "n2", "cat1"],
+ DependencySymmetry(
+ **valid_config_dep | {"parameter_name": "n1_not_discrete"}
+ ),
+ TypeError,
+ "must be discrete. However, the parameter 'n1_not_discrete'",
+ id="dep_causing_not_discrete",
+ ),
+ param(
+ ["cat1", "n1", "n2"],
+ PermutationSymmetry(**valid_config_perm),
+ IncompatibleSearchSpaceError,
+ r"not present in the search space",
+ id="perm_not_present",
+ ),
+ param(
+ ["cat1", "cat2", "n1", "n2"],
+ PermutationSymmetry(permutation_groups=[("cat1", "n1"), ("cat2", "n2")]),
+ ValueError,
+ r"differ in their specification",
+ id="perm_inconsistent_types",
+ ),
+ param(
+ ["cat1_altered", "cat2", "n1", "n2"],
+ PermutationSymmetry(
+ permutation_groups=[["cat1_altered", "cat2"], ["n1", "n2"]]
+ ),
+ ValueError,
+ r"differ in their specification",
+ id="perm_inconsistent_values",
+ ),
+ ],
+)
+def test_searchspace_context(searchspace, symmetry, error, msg):
+ """Configurations not compatible with the searchspace raise an expected error."""
+ recommender = BotorchRecommender(
+ surrogate_model=GaussianProcessSurrogate(symmetries=(symmetry,))
+ )
+ t = NumericalTarget("t")
+ measurements = create_fake_input(searchspace.parameters, [t])
+
+ with pytest.raises(error, match=msg):
+ recommender.recommend(
+ 1, searchspace, t.to_objective(), measurements=measurements
+ )
+
+
+def test_dependency_augmentation_requires_parameters():
+ """DependencySymmetry.augment_measurements raises when parameters is None."""
+ s = DependencySymmetry(**valid_config_dep)
+ df = pd.DataFrame({"n1": [0], "n2": [1], "cat1": ["a"]})
+ with pytest.raises(ValueError, match="requires parameter objects"):
+ s.augment_measurements(df)
+
+
+def test_surrogate_rejects_non_symmetry():
+ """Surrogate.symmetries rejects non-Symmetry members."""
+ with pytest.raises(TypeError, match="must be