diff --git a/CHANGELOG.md b/CHANGELOG.md index 3ab1fa1c52..ca286f78fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,14 +13,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `identify_non_dominated_configurations` method to `Campaign` and `Objective` for determining the Pareto front - Interpoint constraints for continuous search spaces +- Symmetry classes (`PermutationSymmetry`, `MirrorSymmetry`, `DependencySymmetry`) + for expressing invariances and configuring surrogate data augmentation +- `Parameter.is_equivalent` method for structural parameter comparison ### Breaking Changes - `ContinuousLinearConstraint.to_botorch` now returns a collection of constraint tuples instead of a single tuple (needed for interpoint constraints) +- `df_apply_permutation_augmentation` has a different interface and now expects + permutation groups instead of column groups ### Fixed - `SHAPInsight` breaking with `numpy>=2.4` due to no longer accepted implicit array to scalar conversion +- `DiscretePermutationInvarianceConstraint` no longer erroneously removes diagonal + points (e.g., where all permuted parameters have the same value) ### Changed - The `Campaign.allow_*` flag mechanism is now based on `AutoBool` logic, providing diff --git a/baybe/constraints/base.py b/baybe/constraints/base.py index 5c1a6d33ed..b38548efcc 100644 --- a/baybe/constraints/base.py +++ b/baybe/constraints/base.py @@ -37,10 +37,6 @@ class Constraint(ABC, SerialMixin): eval_during_modeling: ClassVar[bool] """Class variable encoding whether the condition is evaluated during modeling.""" - eval_during_augmentation: ClassVar[bool] = False - """Class variable encoding whether the constraint could be considered during data - augmentation.""" - numerical_only: ClassVar[bool] = False """Class variable encoding whether the constraint is valid only for numerical parameters.""" diff --git a/baybe/constraints/conditions.py b/baybe/constraints/conditions.py index 
a136f6e170..3dd52b5d78 100644 --- a/baybe/constraints/conditions.py +++ b/baybe/constraints/conditions.py @@ -95,7 +95,7 @@ class Condition(ABC, SerialMixin): """Abstract base class for all conditions. Conditions always evaluate an expression regarding a single parameter. - Conditions are part of constraints, a constraint can have multiple conditions. + Conditions are part of constraints and symmetries. """ @abstractmethod diff --git a/baybe/constraints/discrete.py b/baybe/constraints/discrete.py index 740e603f89..d291127e5e 100644 --- a/baybe/constraints/discrete.py +++ b/baybe/constraints/discrete.py @@ -29,6 +29,9 @@ if TYPE_CHECKING: import polars as pl + from baybe.symmetries.dependency import DependencySymmetry + from baybe.symmetries.permutation import PermutationSymmetry + @define class DiscreteExcludeConstraint(DiscreteConstraint): @@ -195,10 +198,6 @@ class DiscreteDependenciesConstraint(DiscreteConstraint): a single constraint. """ - # class variables - eval_during_augmentation: ClassVar[bool] = True - # See base class - # object variables conditions: list[Condition] = field() """The list of individual conditions.""" @@ -271,39 +270,56 @@ def get_invalid(self, data: pd.DataFrame) -> pd.Index: return inds_bad + def to_symmetries( + self, use_data_augmentation: bool = True + ) -> tuple[DependencySymmetry, ...]: + """Convert to :class:`~baybe.symmetries.dependency.DependencySymmetry` objects. + + Create one symmetry object per dependency relationship, i.e., per + (parameter, condition, affected_parameters) triple. + + Args: + use_data_augmentation: Flag indicating whether the resulting symmetry + objects should apply data augmentation. ``True`` means that + measurement augmentation will be performed by replacing inactive + affected parameter values with all possible values. + + Returns: + A tuple of dependency symmetries, one for each dependency in the + constraint. 
+ """ + from baybe.symmetries.dependency import DependencySymmetry + + return tuple( + DependencySymmetry( + parameter_name=p, + condition=c, + affected_parameter_names=aps, + use_data_augmentation=use_data_augmentation, + ) + for p, c, aps in zip( + self.parameters, self.conditions, self.affected_parameters, strict=True + ) + ) + @define class DiscretePermutationInvarianceConstraint(DiscreteConstraint): """Constraint class for declaring that a set of parameters is permutation invariant. More precisely, this means that, ``(val_from_param1, val_from_param2)`` is - equivalent to ``(val_from_param2, val_from_param1)``. Since it does not make sense - to have this constraint with duplicated labels, this implementation also internally - applies the :class:`baybe.constraints.discrete.DiscreteNoLabelDuplicatesConstraint`. + equivalent to ``(val_from_param2, val_from_param1)``. *Note:* This constraint is evaluated during creation. In the future it might also be evaluated during modeling to make use of the invariance. """ - # class variables - eval_during_augmentation: ClassVar[bool] = True - # See base class - # object variables dependencies: DiscreteDependenciesConstraint | None = field(default=None) """Dependencies connected with the invariant parameters.""" @override def get_invalid(self, data: pd.DataFrame) -> pd.Index: - # Get indices of entries with duplicate label entries. These will also be - # dropped by this constraint. - mask_duplicate_labels = pd.Series(False, index=data.index) - mask_duplicate_labels[ - DiscreteNoLabelDuplicatesConstraint(parameters=self.parameters).get_invalid( - data - ) - ] = True - # Merge a permutation invariant representation of all affected parameters with # the other parameters and indicate duplicates. This ensures that variation in # other parameters is also accounted for. 
@@ -314,20 +330,14 @@ def get_invalid(self, data: pd.DataFrame) -> pd.Index: data[self.parameters].apply(cast(Callable, frozenset), axis=1), ], axis=1, - ).loc[ - ~mask_duplicate_labels # only consider label-duplicate-free part - ] + ) mask_duplicate_permutations = df_eval.duplicated(keep="first") - # Indices of entries with label-duplicates - inds_duplicate_labels = data.index[mask_duplicate_labels] - - # Indices of duplicate permutations in the (already label-duplicate-free) data - inds_duplicate_permutations = df_eval.index[mask_duplicate_permutations] + # Indices of duplicate permutations + inds_invalid = data.index[mask_duplicate_permutations] # If there are dependencies connected to the invariant parameters evaluate them # here and remove resulting duplicates with a DependenciesConstraint - inds_invalid = inds_duplicate_labels.union(inds_duplicate_permutations) if self.dependencies: self.dependencies.permutation_invariant = True inds_duplicate_independency_adjusted = self.dependencies.get_invalid( @@ -337,6 +347,32 @@ def get_invalid(self, data: pd.DataFrame) -> pd.Index: return inds_invalid + def to_symmetry(self, use_data_augmentation: bool = True) -> PermutationSymmetry: + """Convert to a :class:`~baybe.symmetries.permutation.PermutationSymmetry`. + + The constraint's parameters form the primary permutation group. If + dependencies are attached, their parameters are added as an additional + group that is permuted in lockstep. + + Args: + use_data_augmentation: Flag indicating whether the resulting symmetry + object should apply data augmentation. ``True`` means that + measurement augmentation will be performed by generating all + permutations of parameter values within each group. + + Returns: + The corresponding permutation symmetry. 
+ """ + from baybe.symmetries.permutation import PermutationSymmetry + + groups = [self.parameters] + if self.dependencies: + groups.append(list(self.dependencies.parameters)) + return PermutationSymmetry( + permutation_groups=groups, + use_data_augmentation=use_data_augmentation, + ) + @define class DiscreteCustomConstraint(DiscreteConstraint): diff --git a/baybe/parameters/base.py b/baybe/parameters/base.py index 2d4df2bc77..9ae13e9a98 100644 --- a/baybe/parameters/base.py +++ b/baybe/parameters/base.py @@ -7,6 +7,7 @@ from functools import cached_property from typing import TYPE_CHECKING, Any, ClassVar +import attrs import pandas as pd from attrs import define, field from attrs.converters import optional as optional_c @@ -88,6 +89,22 @@ def to_searchspace(self) -> SearchSpace: return SearchSpace.from_parameter(self) + def is_equivalent(self, other: Parameter) -> bool: + """Check if this parameter is equivalent to another, ignoring the name. + + Two parameters are considered equivalent if they have the same type and + all attributes are equal except for the name. + + Args: + other: The parameter to compare against. + + Returns: + ``True`` if the parameters are equivalent, ``False`` otherwise. 
+ """ + if type(self) is not type(other): + return False + return attrs.evolve(self, name=other.name) == other + @abstractmethod def summary(self) -> dict: """Return a custom summarization of the parameter.""" diff --git a/baybe/parameters/categorical.py b/baybe/parameters/categorical.py index bffecd5682..7ebe673f76 100644 --- a/baybe/parameters/categorical.py +++ b/baybe/parameters/categorical.py @@ -13,13 +13,7 @@ from baybe.parameters.enum import CategoricalEncoding from baybe.parameters.validation import validate_unique_values from baybe.settings import active_settings -from baybe.utils.conversion import nonstring_to_tuple - - -def _convert_values(value, self, field) -> tuple[str, ...]: - """Sort and convert values for categorical parameters.""" - value = nonstring_to_tuple(value, self, field) - return tuple(sorted(value, key=lambda x: (str(type(x)), x))) +from baybe.utils.conversion import normalize_convertible2str_sequence def _validate_label_min_len(self, attr, value) -> None: @@ -38,8 +32,10 @@ class CategoricalParameter(_DiscreteLabelLikeParameter): # object variables _values: tuple[str | bool, ...] 
= field( alias="values", - converter=Converter(_convert_values, takes_self=True, takes_field=True), # type: ignore - validator=( + converter=Converter( # type: ignore[misc,call-overload] # mypy: Converter + normalize_convertible2str_sequence, takes_self=True, takes_field=True + ), + validator=( # type: ignore[arg-type] # mypy: validator tuple validate_unique_values, deep_iterable( member_validator=(instance_of((str, bool)), _validate_label_min_len), diff --git a/baybe/parameters/numerical.py b/baybe/parameters/numerical.py index ba210de244..e56ca2c3f9 100644 --- a/baybe/parameters/numerical.py +++ b/baybe/parameters/numerical.py @@ -13,9 +13,10 @@ from baybe.exceptions import NumericalUnderflowError from baybe.parameters.base import ContinuousParameter, DiscreteParameter -from baybe.parameters.validation import validate_is_finite, validate_unique_values +from baybe.parameters.validation import validate_unique_values from baybe.settings import active_settings from baybe.utils.interval import InfiniteIntervalError, Interval +from baybe.utils.validation import validate_is_finite @define(frozen=True, slots=False) diff --git a/baybe/parameters/validation.py b/baybe/parameters/validation.py index 5c70ed5ecf..0367b2b30f 100644 --- a/baybe/parameters/validation.py +++ b/baybe/parameters/validation.py @@ -1,9 +1,7 @@ """Validation functionality for parameters.""" -from collections.abc import Sequence from typing import Any -import numpy as np from attrs.validators import gt, instance_of, lt @@ -28,18 +26,3 @@ def validate_decorrelation(obj: Any, attribute: Any, value: float) -> None: if isinstance(value, float): gt(0.0)(obj, attribute, value) lt(1.0)(obj, attribute, value) - - -def validate_is_finite( # noqa: DOC101, DOC103 - obj: Any, _: Any, value: Sequence[float] -) -> None: - """Validate that ``value`` contains no infinity/nan. - - Raises: - ValueError: If ``value`` contains infinity/nan. 
- """ - if not all(np.isfinite(value)): - raise ValueError( - f"Cannot assign the following values containing infinity/nan to " - f"parameter {obj.name}: {value}." - ) diff --git a/baybe/recommenders/pure/bayesian/base.py b/baybe/recommenders/pure/bayesian/base.py index 4ac5c1eed2..2d06db2edc 100644 --- a/baybe/recommenders/pure/bayesian/base.py +++ b/baybe/recommenders/pure/bayesian/base.py @@ -86,6 +86,22 @@ def _get_acquisition_function(self, objective: Objective) -> AcquisitionFunction return qLogNEHVI() if objective.is_multi_output else qLogEI() return self.acquisition_function + def _get_surrogate_for_augmentation(self) -> Surrogate | None: + """Get the Surrogate instance for augmentation/validation, if available.""" + from baybe.surrogates.composite import CompositeSurrogate, _ReplicationMapping + + model = self._surrogate_model + if isinstance(model, Surrogate): + return model + if isinstance(model, CompositeSurrogate): + # All inner surrogates are copies of the same template + surrogates = model.surrogates + if isinstance(surrogates, _ReplicationMapping): + template = surrogates.template + if isinstance(template, Surrogate): + return template + return None + def get_surrogate( self, searchspace: SearchSpace, @@ -114,6 +130,13 @@ def _setup_botorch_acqf( f"{len(objective.targets)}-target multi-output context." 
) + # Perform data augmentation if configured + surrogate_for_augmentation = self._get_surrogate_for_augmentation() + if surrogate_for_augmentation is not None: + measurements = surrogate_for_augmentation.augment_measurements( + measurements, searchspace.parameters + ) + surrogate = self.get_surrogate(searchspace, objective, measurements) self._botorch_acqf = acqf.to_botorch( surrogate, @@ -156,6 +179,13 @@ def recommend( validate_object_names(searchspace.parameters + objective.targets) + # Validate compatibility of surrogate symmetries with searchspace + surrogate_for_validation = self._get_surrogate_for_augmentation() + if surrogate_for_validation is not None: + for s in surrogate_for_validation.symmetries: + s.validate_searchspace_context(searchspace) + + # Experimental input validation if (measurements is None) or measurements.empty: raise NotImplementedError( f"Recommenders of type '{BayesianRecommender.__name__}' do not support " diff --git a/baybe/searchspace/core.py b/baybe/searchspace/core.py index 8b0da30c92..34e5f64047 100644 --- a/baybe/searchspace/core.py +++ b/baybe/searchspace/core.py @@ -381,11 +381,6 @@ def transform( return comp_rep - @property - def constraints_augmentable(self) -> tuple[Constraint, ...]: - """The searchspace constraints that can be considered during augmentation.""" - return tuple(c for c in self.constraints if c.eval_during_augmentation) - def get_parameters_by_name(self, names: Sequence[str]) -> tuple[Parameter, ...]: """Return parameters with the specified names. 
diff --git a/baybe/surrogates/base.py b/baybe/surrogates/base.py index 205e32f703..2b85b6d8e1 100644 --- a/baybe/surrogates/base.py +++ b/baybe/surrogates/base.py @@ -4,12 +4,13 @@ import gc from abc import ABC, abstractmethod -from collections.abc import Sequence +from collections.abc import Iterable, Sequence from enum import Enum, auto from typing import TYPE_CHECKING, ClassVar, Literal, Protocol, TypeAlias import pandas as pd from attrs import define, field +from attrs.validators import deep_iterable, instance_of from joblib.hashing import hash from typing_extensions import override @@ -18,6 +19,7 @@ from baybe.parameters.base import Parameter from baybe.searchspace import SearchSpace from baybe.serialization.mixin import SerialMixin +from baybe.symmetries import Symmetry from baybe.utils.basic import classproperty from baybe.utils.conversion import to_string from baybe.utils.dataframe import handle_missing_values, to_tensor @@ -90,6 +92,14 @@ class Surrogate(ABC, SurrogateProtocol, SerialMixin): """Class variable encoding whether or not the surrogate is multi-output compatible.""" + symmetries: tuple[Symmetry, ...] = field( + factory=tuple, + converter=tuple, + validator=deep_iterable(member_validator=instance_of(Symmetry)), + kw_only=True, + ) + """Symmetries to be considered by the surrogate model.""" + _searchspace: SearchSpace | None = field(init=False, default=None, eq=False) """The search space on which the surrogate operates. Available after fitting.""" @@ -115,6 +125,27 @@ class Surrogate(ABC, SurrogateProtocol, SerialMixin): Scales a tensor containing target measurements in computational representation to make them digestible for the model-specific, scale-agnostic posterior logic.""" + def augment_measurements( + self, + measurements: pd.DataFrame, + parameters: Iterable[Parameter] | None = None, + ) -> pd.DataFrame: + """Apply data augmentation to measurements. + + Args: + measurements: A dataframe with measurements. 
+ parameters: Optional parameter objects carrying additional information. + Only required by specific augmentation implementations. + + Returns: + A dataframe with the augmented measurements, including the original + ones. + """ + for s in self.symmetries: + measurements = s.augment_measurements(measurements, parameters) + + return measurements + @classproperty def is_available(cls) -> bool: """Indicates if the surrogate class is available in the Python environment. diff --git a/baybe/surrogates/gaussian_process/core.py b/baybe/surrogates/gaussian_process/core.py index c0148aca55..a2fddc7a33 100644 --- a/baybe/surrogates/gaussian_process/core.py +++ b/baybe/surrogates/gaussian_process/core.py @@ -142,9 +142,10 @@ def _posterior(self, candidates_comp_scaled: Tensor, /) -> Posterior: @override def _fit(self, train_x: Tensor, train_y: Tensor) -> None: - import botorch import gpytorch import torch + from botorch.fit import fit_gpytorch_mll + from botorch.models import SingleTaskGP from botorch.models.transforms import Normalize, Standardize # FIXME[typing]: It seems there is currently no better way to inform the type @@ -197,7 +198,7 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None: likelihood.noise = torch.tensor([noise_prior[1]]) # construct and fit the Gaussian process - self._model = botorch.models.SingleTaskGP( + self._model = SingleTaskGP( train_x, train_y, input_transform=input_transform, @@ -218,7 +219,7 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None: self._model.likelihood, self._model ) - botorch.fit.fit_gpytorch_mll(mll) + fit_gpytorch_mll(mll) @override def __str__(self) -> str: diff --git a/baybe/symmetries/__init__.py b/baybe/symmetries/__init__.py new file mode 100644 index 0000000000..abbeedd974 --- /dev/null +++ b/baybe/symmetries/__init__.py @@ -0,0 +1,13 @@ +"""Symmetry classes for expressing invariances of the modeling process.""" + +from baybe.symmetries.base import Symmetry +from baybe.symmetries.dependency import 
DependencySymmetry +from baybe.symmetries.mirror import MirrorSymmetry +from baybe.symmetries.permutation import PermutationSymmetry + +__all__ = [ + "DependencySymmetry", + "MirrorSymmetry", + "PermutationSymmetry", + "Symmetry", +] diff --git a/baybe/symmetries/base.py b/baybe/symmetries/base.py new file mode 100644 index 0000000000..780eae358d --- /dev/null +++ b/baybe/symmetries/base.py @@ -0,0 +1,89 @@ +"""Base class for symmetries.""" + +from __future__ import annotations + +import gc +from abc import ABC, abstractmethod +from collections.abc import Iterable +from typing import TYPE_CHECKING + +import pandas as pd +from attrs import define, field +from attrs.validators import instance_of + +from baybe.exceptions import IncompatibleSearchSpaceError +from baybe.serialization import SerialMixin + +if TYPE_CHECKING: + from baybe.parameters.base import Parameter + from baybe.searchspace import SearchSpace + + +@define(frozen=True) +class Symmetry(SerialMixin, ABC): + """Abstract base class for symmetries. + + A ``Symmetry`` is a concept that can be used to configure the modeling process in + the presence of invariances. + """ + + use_data_augmentation: bool = field( + default=True, validator=instance_of(bool), kw_only=True + ) + """Flag indicating whether data augmentation is to be used.""" + + @property + @abstractmethod + def parameter_names(self) -> tuple[str, ...]: + """The names of the parameters affected by the symmetry.""" + + def summary(self) -> dict: + """Return a custom summarization of the symmetry.""" + symmetry_dict = dict( + Type=self.__class__.__name__, + Affected_Parameters=self.parameter_names, + Data_Augmentation=self.use_data_augmentation, + ) + return symmetry_dict + + @abstractmethod + def augment_measurements( + self, + measurements: pd.DataFrame, + parameters: Iterable[Parameter] | None = None, + ) -> pd.DataFrame: + """Augment the given measurements according to the symmetry. 
+ + Args: + measurements: The dataframe containing the measurements to be + augmented. + parameters: Optional parameter objects carrying additional information. + Only required by specific augmentation implementations. + + Returns: + The augmented dataframe including the original measurements. + """ + + def validate_searchspace_context(self, searchspace: SearchSpace) -> None: + """Validate that the symmetry is compatible with the given searchspace. + + Args: + searchspace: The searchspace to validate against. + + Raises: + IncompatibleSearchSpaceError: If the symmetry affects parameters not + present in the searchspace. + """ + parameters_missing = set(self.parameter_names).difference( + searchspace.parameter_names + ) + if parameters_missing: + raise IncompatibleSearchSpaceError( + f"The symmetry of type '{self.__class__.__name__}' was set up with the " + f"following parameters that are not present in the search space: " + f"{parameters_missing}." + ) + + +# Collect leftover original slotted classes processed by `attrs.define` +gc.collect() diff --git a/baybe/symmetries/dependency.py b/baybe/symmetries/dependency.py new file mode 100644 index 0000000000..9bd72ab8bf --- /dev/null +++ b/baybe/symmetries/dependency.py @@ -0,0 +1,169 @@ +"""Dependency symmetry.""" + +from __future__ import annotations + +import gc +from collections.abc import Iterable +from typing import TYPE_CHECKING, cast + +import numpy as np +import pandas as pd +from attrs import Converter, define, field +from attrs.validators import deep_iterable, ge, instance_of, min_len +from typing_extensions import override + +from baybe.constraints.conditions import Condition +from baybe.exceptions import IncompatibleSearchSpaceError +from baybe.symmetries.base import Symmetry +from baybe.utils.augmentation import df_apply_dependency_augmentation +from baybe.utils.conversion import normalize_convertible2str_sequence +from baybe.utils.validation import validate_unique_values + +if TYPE_CHECKING: + from 
baybe.parameters.base import Parameter + from baybe.searchspace import SearchSpace + + +@define(frozen=True) +class DependencySymmetry(Symmetry): + """Class for representing dependency symmetries. + + A dependency symmetry expresses that certain parameters are dependent on another + parameter having a specific value. For instance, the situation "The value of + parameter y only matters if parameter x has the value 'on'.". In this scenario x + is the causing parameter and y depends on x. + """ + + _parameter_name: str = field(validator=instance_of(str), alias="parameter_name") + """The names of the causing parameter others are depending on.""" + + # object variables + condition: Condition = field(validator=instance_of(Condition)) + """The condition specifying the active range of the causing parameter.""" + + affected_parameter_names: tuple[str, ...] = field( + converter=Converter( # type: ignore[misc,call-overload] # mypy: Converter + normalize_convertible2str_sequence, takes_self=True, takes_field=True + ), + validator=( + validate_unique_values, + deep_iterable( + member_validator=instance_of(str), iterable_validator=min_len(1) + ), + ), + ) + """The parameters affected by the dependency.""" + + n_discretization_points: int = field( + default=3, validator=(instance_of(int), ge(2)), kw_only=True + ) + """Number of points used when subsampling continuous parameter ranges.""" + + @override + @property + def parameter_names(self) -> tuple[str, ...]: + return (self._parameter_name,) + + @override + def augment_measurements( + self, + measurements: pd.DataFrame, + parameters: Iterable[Parameter] | None = None, + ) -> pd.DataFrame: + # See base class. + if not self.use_data_augmentation: + return measurements + + if parameters is None: + raise ValueError( + f"A '{self.__class__.__name__}' requires parameter objects " + f"for data augmentation." 
+ ) + + from baybe.parameters.base import DiscreteParameter + + # The 'causing' entry describes the parameters and the value + # for which one or more affected parameters become degenerate. + # 'cond' specifies for which values the affected parameter + # values are active, i.e. not degenerate. Hence, here we get the + # values that are not active, as rows containing them should be + # augmented. + param = next( + cast(DiscreteParameter, p) + for p in parameters + if p.name == self._parameter_name + ) + + causing_values = [ + x + for x, flag in zip( + param.values, + ~self.condition.evaluate(pd.Series(param.values)), + strict=True, + ) + if flag + ] + causing = (param.name, causing_values) + + # The 'affected' entry describes the affected parameters and the + # values they are allowed to take, which are all degenerate if + # the corresponding condition for the causing parameter is met. + affected: list[tuple[str, tuple[float, ...]]] = [] + for pn in self.affected_parameter_names: + p = next(p for p in parameters if p.name == pn) + if p.is_discrete: + # Use all values for augmentation + vals = cast(DiscreteParameter, p).values + else: + # Use linear subsample of parameter bounds interval for augmentation. + # Note: The original value will not necessarily be part of this. + vals = tuple( + np.linspace( + p.bounds.lower, # type: ignore[attr-defined] + p.bounds.upper, # type: ignore[attr-defined] + self.n_discretization_points, + ) + ) + affected.append((p.name, vals)) + + measurements = df_apply_dependency_augmentation(measurements, causing, affected) + + return measurements + + @override + def validate_searchspace_context(self, searchspace: SearchSpace) -> None: + """See base class. + + Args: + searchspace: The searchspace to validate against. + + Raises: + IncompatibleSearchSpaceError: If any of the affected parameters is + not present in the searchspace. + TypeError: If the causing parameter is not discrete. 
+ """ + super().validate_searchspace_context(searchspace) + + # Affected parameters must be in the searchspace + parameters_missing = set(self.affected_parameter_names).difference( + searchspace.parameter_names + ) + if parameters_missing: + raise IncompatibleSearchSpaceError( + f"The symmetry of type '{self.__class__.__name__}' was set up " + f"with at least one parameter which is not present in the " + f"search space: {parameters_missing}." + ) + + # Causing parameter must be discrete + param = searchspace.get_parameters_by_name(self._parameter_name)[0] + if not param.is_discrete: + raise TypeError( + f"In a '{self.__class__.__name__}', the causing parameter must " + f"be discrete. However, the parameter '{param.name}' is of " + f"type '{param.__class__.__name__}' and is not discrete." + ) + + +# Collect leftover original slotted classes processed by `attrs.define` +gc.collect() diff --git a/baybe/symmetries/mirror.py b/baybe/symmetries/mirror.py new file mode 100644 index 0000000000..9a7ac7d06f --- /dev/null +++ b/baybe/symmetries/mirror.py @@ -0,0 +1,86 @@ +"""Mirror symmetry.""" + +from __future__ import annotations + +import gc +from collections.abc import Iterable +from typing import TYPE_CHECKING + +import pandas as pd +from attrs import define, field +from attrs.validators import instance_of +from typing_extensions import override + +from baybe.symmetries.base import Symmetry +from baybe.utils.augmentation import df_apply_mirror_augmentation +from baybe.utils.validation import validate_is_finite + +if TYPE_CHECKING: + from baybe.parameters.base import Parameter + from baybe.searchspace import SearchSpace + + +@define(frozen=True) +class MirrorSymmetry(Symmetry): + """Class for representing mirror symmetries. + + A mirror symmetry expresses that certain parameters can be inflected at a mirror + point without affecting the outcome of the model. 
For instance, when specified + for parameter ``x`` and mirror point ``c``, the symmetry expresses that + $f(..., c+x, ...) = f(..., c-x, ...)$. + """ + + _parameter_name: str = field(validator=instance_of(str), alias="parameter_name") + """The name of the single parameter affected by the symmetry.""" + + # object variables + mirror_point: float = field( + default=0.0, converter=float, validator=validate_is_finite, kw_only=True + ) + """The mirror point.""" + + @override + @property + def parameter_names(self) -> tuple[str]: + return (self._parameter_name,) + + @override + def augment_measurements( + self, + measurements: pd.DataFrame, + parameters: Iterable[Parameter] | None = None, + ) -> pd.DataFrame: + # See base class. + + if not self.use_data_augmentation: + return measurements + + measurements = df_apply_mirror_augmentation( + measurements, self._parameter_name, mirror_point=self.mirror_point + ) + + return measurements + + @override + def validate_searchspace_context(self, searchspace: SearchSpace) -> None: + """See base class. + + Args: + searchspace: The searchspace to validate against. + + Raises: + TypeError: If the affected parameter is not numerical. + """ + super().validate_searchspace_context(searchspace) + + param = searchspace.get_parameters_by_name(self.parameter_names)[0] + if not param.is_numerical: + raise TypeError( + f"In a '{self.__class__.__name__}', the affected parameter must " + f"be numerical. However, the parameter '{param.name}' is of " + f"type '{param.__class__.__name__}' and is not numerical." 
+ ) + + +# Collect leftover original slotted classes processed by `attrs.define` +gc.collect() diff --git a/baybe/symmetries/permutation.py b/baybe/symmetries/permutation.py new file mode 100644 index 0000000000..deab356f9e --- /dev/null +++ b/baybe/symmetries/permutation.py @@ -0,0 +1,158 @@ +"""Permutation symmetry.""" + +from __future__ import annotations + +import gc +from collections.abc import Iterable, Sequence +from itertools import combinations +from typing import TYPE_CHECKING, Any + +import pandas as pd +from attrs import define, field +from attrs.validators import deep_iterable, instance_of, min_len +from typing_extensions import override + +from baybe.symmetries.base import Symmetry +from baybe.utils.augmentation import df_apply_permutation_augmentation + +if TYPE_CHECKING: + from baybe.parameters.base import Parameter + from baybe.searchspace import SearchSpace + + +def _convert_groups( + groups: Sequence[Sequence[str]], +) -> tuple[tuple[str, ...], ...]: + """Convert nested sequences to a tuple of tuples, blocking bare strings.""" + if isinstance(groups, str): + raise ValueError( + "The 'permutation_groups' argument must be a sequence of sequences, " + "not a string." + ) + converted = [] + for g in groups: + if isinstance(g, str): + raise ValueError( + "Each element in 'permutation_groups' must be a sequence of " + "parameter names, not a string." + ) + converted.append(tuple(g)) + return tuple(converted) + + +@define(frozen=True) +class PermutationSymmetry(Symmetry): + """Class for representing permutation symmetries. + + A permutation symmetry expresses that certain parameters can be permuted without + affecting the outcome of the model. For instance, this is the case if + $f(x,y) = f(y,x)$. + """ + + permutation_groups: tuple[tuple[str, ...], ...] 
= field( + converter=_convert_groups, + validator=( + min_len(1), + deep_iterable( + member_validator=deep_iterable( + member_validator=instance_of(str), + iterable_validator=min_len(2), + ), + ), + ), + ) + """The permutation groups. Each group contains parameter names that are permuted + in lockstep. All groups must have the same length.""" + + @permutation_groups.validator + def _validate_permutation_groups( # noqa: DOC101, DOC103 + self, _: Any, groups: tuple[tuple[str, ...], ...] + ) -> None: + """Validate the permutation groups. + + Raises: + ValueError: If the groups have different lengths. + ValueError: If any group contains duplicate parameters. + ValueError: If any parameter name appears in multiple groups. + """ + # Ensure all groups have the same length + if len({len(g) for g in groups}) != 1: + lengths = {k + 1: len(g) for k, g in enumerate(groups)} + raise ValueError( + f"In a '{self.__class__.__name__}', all permutation groups " + f"must have the same length. Got group lengths: {lengths}." + ) + + # Ensure parameter names in each group are unique + for group in groups: + if len(set(group)) != len(group): + raise ValueError( + f"In a '{self.__class__.__name__}', all parameters being " + f"permuted with each other must be unique. However, the " + f"following group contains duplicates: {group}." + ) + + # Ensure there is no overlap between any permutation group + for a, b in combinations(groups, 2): + if overlap := set(a) & set(b): + raise ValueError( + f"In a '{self.__class__.__name__}', parameter names cannot " + f"appear in multiple permutation groups. However, the " + f"following parameter names appear in several groups: " + f"{overlap}." 
+ ) + + @override + @property + def parameter_names(self) -> tuple[str, ...]: + return tuple(name for group in self.permutation_groups for name in group) + + @override + def augment_measurements( + self, + measurements: pd.DataFrame, + parameters: Iterable[Parameter] | None = None, + ) -> pd.DataFrame: + # See base class. + + if not self.use_data_augmentation: + return measurements + + measurements = df_apply_permutation_augmentation( + measurements, self.permutation_groups + ) + + return measurements + + @override + def validate_searchspace_context(self, searchspace: SearchSpace) -> None: + """See base class. + + Args: + searchspace: The searchspace to validate against. + + Raises: + ValueError: If parameters within a permutation group are not + equivalent (i.e., differ in type or specification). + """ + super().validate_searchspace_context(searchspace) + + # Ensure permuted parameters all have the same specification. + # Without this, it could be attempted to read in data that is not allowed + # for parameters that only allow a subset or different values compared to + # parameters they are being permuted with. + for group in self.permutation_groups: + params = searchspace.get_parameters_by_name(group) + ref = params[0] + for p in params[1:]: + if not ref.is_equivalent(p): + raise ValueError( + f"In a '{self.__class__.__name__}', all parameters " + f"within a permutation group must be equivalent. " + f"However, '{ref.name}' and '{p.name}' differ in " + f"their specification." 
+ ) + + +# Collect leftover original slotted classes processed by `attrs.define` +gc.collect() diff --git a/baybe/utils/augmentation.py b/baybe/utils/augmentation.py index b9fc2d14aa..6b1fd50d28 100644 --- a/baybe/utils/augmentation.py +++ b/baybe/utils/augmentation.py @@ -8,24 +8,29 @@ def df_apply_permutation_augmentation( df: pd.DataFrame, - column_groups: Sequence[Sequence[str]], + permutation_groups: Sequence[Sequence[str]], ) -> pd.DataFrame: """Augment a dataframe if permutation invariant columns are present. + Each group in ``permutation_groups`` contains the names of columns that are + permuted in lockstep. All groups must have the same length, and that length + must be at least 2 (otherwise there is nothing to permute). + Args: df: The dataframe that should be augmented. - column_groups: Sequences of permutation invariant columns. The n'th column in - each group will be permuted together with each n'th column in the other - groups. + permutation_groups: Groups of column names that are permuted in lockstep. + For example, ``[["A1", "A2"], ["B1", "B2"]]`` means that the columns + ``A1`` and ``A2`` are permuted, and ``B1`` and ``B2`` are permuted + in the same way. Returns: The augmented dataframe containing the original one. Augmented row indices are identical with the index of their original row. Raises: - ValueError: If less than two column groups are given. - ValueError: If a column group is empty. - ValueError: If the column groups have differing amounts of entries. + ValueError: If no permutation groups are given. + ValueError: If any permutation group has fewer than two entries. + ValueError: If the permutation groups have differing amounts of entries. 
Examples: >>> df = pd.DataFrame({'A1':[1,2],'A2':[3,4], 'B1': [5, 6], 'B2': [7, 8]}) @@ -34,8 +39,8 @@ def df_apply_permutation_augmentation( 0 1 3 5 7 1 2 4 6 8 - >>> column_groups = [['A1'], ['A2']] - >>> dfa = df_apply_permutation_augmentation(df, column_groups) + >>> groups = [['A1', 'A2']] + >>> dfa = df_apply_permutation_augmentation(df, groups) >>> dfa A1 A2 B1 B2 0 1 3 5 7 @@ -43,8 +48,8 @@ def df_apply_permutation_augmentation( 1 2 4 6 8 1 4 2 6 8 - >>> column_groups = [['A1', 'B1'], ['A2', 'B2']] - >>> dfa = df_apply_permutation_augmentation(df, column_groups) + >>> groups = [['A1', 'A2'], ['B1', 'B2']] + >>> dfa = df_apply_permutation_augmentation(df, groups) >>> dfa A1 A2 B1 B2 0 1 3 5 7 @@ -53,36 +58,35 @@ def df_apply_permutation_augmentation( 1 4 2 8 6 """ # Validation - if len(column_groups) < 2: - raise ValueError( - "When augmenting permutation invariance, at least two column sequences " - "must be given." - ) + if len(permutation_groups) < 1: + raise ValueError("Permutation augmentation requires at least one group.") - if len({len(seq) for seq in column_groups}) != 1: + if len({len(seq) for seq in permutation_groups}) != 1: raise ValueError( - "Permutation augmentation can only work if the amount of columns in each " - "sequence is the same." + "Permutation augmentation can only work if all groups have the same " + "number of entries." ) - elif len(column_groups[0]) < 1: + + if len(permutation_groups[0]) < 2: raise ValueError( - "Permutation augmentation can only work if each column group has at " - "least one entry." + "Permutation augmentation can only work if each group has at " + "least two entries." 
) # Augmentation Loop + n_positions = len(permutation_groups[0]) new_rows: list[pd.DataFrame] = [] - idx_permutation = list(permutations(range(len(column_groups)))) + idx_permutation = list(permutations(range(n_positions))) for _, row in df.iterrows(): # For each row in the original df, collect all its permutations to_add = [] for perm in idx_permutation: new_row = row.copy() - # Permute columns, this is done separately for each tuple of columns that - # belong together - for deps in map(list, zip(*column_groups)): - new_row[deps] = row[[deps[k] for k in perm]] + # Permute columns within each group according to the permutation + for group in permutation_groups: + cols = list(group) + new_row[cols] = row[[cols[k] for k in perm]] to_add.append(new_row) @@ -94,6 +98,69 @@ def df_apply_permutation_augmentation( return pd.concat(new_rows) +def df_apply_mirror_augmentation( + df: pd.DataFrame, + column: str, + *, + mirror_point: float = 0.0, +) -> pd.DataFrame: + """Augment a dataframe for a mirror invariant column. + + Args: + df: The dataframe that should be augmented. + column: The name of the affected column. + mirror_point: The point along which to mirror the values. Points that have + exactly this value will not be augmented. + + Returns: + The augmented dataframe containing the original one. Augmented row indices are + identical with the index of their original row. + + Examples: + >>> df = pd.DataFrame({'A':[1, 0, -2], 'B': [3, 4, 5]}) + >>> df + A B + 0 1 3 + 1 0 4 + 2 -2 5 + + >>> dfa = df_apply_mirror_augmentation(df, "A") + >>> dfa + A B + 0 1 3 + 0 -1 3 + 1 0 4 + 2 -2 5 + 2 2 5 + + >>> dfa = df_apply_mirror_augmentation(df, "A", mirror_point=1) + >>> dfa + A B + 0 1 3 + 1 0 4 + 1 2 4 + 2 -2 5 + 2 4 5 + """ + new_rows: list[pd.DataFrame] = [] + for _, row in df.iterrows(): + to_add = [row] # Always keep original row + + # Create the augmented row by mirroring the point at the mirror point. 
+ # x_mirrored = mirror_point + (mirror_point - x) = 2*mirror_point - x + if row[column] != mirror_point: + row_new = row.copy() + row_new[column] = 2.0 * mirror_point - row[column] + to_add.append(row_new) + + # Store augmented rows, keeping the index of their original row + new_rows.append( + pd.DataFrame(to_add, columns=df.columns, index=[row.name] * len(to_add)) + ) + + return pd.concat(new_rows) + + def df_apply_dependency_augmentation( df: pd.DataFrame, causing: tuple[str, Sequence], diff --git a/baybe/utils/conversion.py b/baybe/utils/conversion.py index ed5f622532..846fc925ad 100644 --- a/baybe/utils/conversion.py +++ b/baybe/utils/conversion.py @@ -42,6 +42,18 @@ def nonstring_to_tuple(x: Sequence[_T], self: type, field: Attribute) -> tuple[_ return tuple(x) +def normalize_convertible2str_sequence( + value: Sequence[str | bool], self: type, field: Attribute +) -> tuple[str | bool, ...]: + """Sort and convert values for a sequence of string-convertible types. + + If the sequence is a string itself, this is blocked to avoid unintended iteration + over its characters. + """ + value = nonstring_to_tuple(value, self, field) + return tuple(sorted(value, key=lambda x: (str(type(x)), x))) + + def _indent(text: str, amount: int = 3, ch: str = " ") -> str: """Indent a given text by a certain amount.""" padding = amount * ch diff --git a/baybe/utils/dataframe.py b/baybe/utils/dataframe.py index 84b831a406..595eba3fc3 100644 --- a/baybe/utils/dataframe.py +++ b/baybe/utils/dataframe.py @@ -192,6 +192,10 @@ def create_fake_input( Raises: ValueError: If less than one row was requested. + + Note: + This function does not consider constraints and might provide unexpected or + invalid data if certain constraints are present. 
""" # Assert at least one fake entry is being generated if n_rows < 1: diff --git a/baybe/utils/validation.py b/baybe/utils/validation.py index 93c87ab316..03a7462767 100644 --- a/baybe/utils/validation.py +++ b/baybe/utils/validation.py @@ -3,7 +3,8 @@ from __future__ import annotations import math -from collections.abc import Callable, Iterable +from collections import Counter +from collections.abc import Callable, Collection, Iterable, Sequence from typing import TYPE_CHECKING, Any import numpy as np @@ -261,3 +262,34 @@ def preprocess_dataframe( else: targets = () return normalize_input_dtypes(df, [*searchspace.parameters, *targets]) + + +def validate_is_finite( # noqa: DOC101, DOC103 + _: Any, attribute: Attribute, value: float | Sequence[float] +) -> None: + """Validate that ``value`` contains no resp. is not infinity/nan. + + Raises: + ValueError: If ``value`` contains infinity/nan. + """ + if not np.isfinite(value).all(): + raise ValueError( + f"Cannot assign the following values containing infinity/nan to " + f"'{attribute.alias}': {value}." + ) + + +def validate_unique_values( # noqa: DOC101, DOC103 + _: Any, attribute: Attribute, value: Collection[str] +) -> None: + """Validate that there are no duplicates in ``value``. + + Raises: + ValueError: If there are duplicates in ``value``. + """ + duplicates = [item for item, count in Counter(value).items() if count > 1] + if duplicates: + raise ValueError( + f"Entries appearing multiple times: {duplicates}. " + f"All entries of '{attribute.alias}' must be unique." 
+ ) diff --git a/docs/_static/symmetries/augmentation.svg b/docs/_static/symmetries/augmentation.svg new file mode 100644 index 0000000000..ac1ffefeeb --- /dev/null +++ b/docs/_static/symmetries/augmentation.svg @@ -0,0 +1,162 @@ + + + + + + + + + + + + + + + + + + + original points + + augmented points + guide line + + + + + + + + + + + x + y + + + + PermutationSymmetry + + + + + mirror_point + + + x + y + + + + + MirrorSymmetry + + + + + + + + + + + + DependencySymmetry with c(x) = x < p + discrete y + + + + p + + + + + + x + y + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + y.values + + continuous y with n_discretization_points=4 + + p + + + + + + + x + y + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + y.bounds.upper + y.bounds.lower + + + diff --git a/docs/scripts/build_examples.py b/docs/scripts/build_examples.py index bfd62db2d5..ddf9be318c 100644 --- a/docs/scripts/build_examples.py +++ b/docs/scripts/build_examples.py @@ -31,8 +31,12 @@ def build_examples(destination_directory: Path, dummy: bool, remove_dir: bool): else: raise OSError("Destination directory exists but should not be removed.") - # Copy the examples folder in the destination directory - shutil.copytree("examples", destination_directory) + # Copy the examples folder in the destination directory. "__pycache__" might be + # present in the examples folder and needs to be ignored + def ignore_pycache(_, contents: list[str]): + return [item for item in contents if item == "__pycache__"] + + shutil.copytree("examples", destination_directory, ignore=ignore_pycache) # For the toctree of the top level example folder, we need to keep track of all # folders. We thus write the header here and populate it during the execution of the @@ -44,7 +48,7 @@ def build_examples(destination_directory: Path, dummy: bool, remove_dir: bool): # This list contains the order of the examples as we want to have them in the end. 
# The examples that should be the first ones are already included here and skipped - # later on. ALl other are just included. + # later on. All others are just included. ex_order = [ "Basics\n", "Searchspaces\n", diff --git a/docs/userguide/constraints.md b/docs/userguide/constraints.md index e66be3051c..535320d399 100644 --- a/docs/userguide/constraints.md +++ b/docs/userguide/constraints.md @@ -399,6 +399,7 @@ DiscreteDependenciesConstraint( ``` An end to end example can be found [here](../../examples/Constraints_Discrete/dependency_constraints). +For more information about the possibility of data augmentation, see [here](surrogate_data_augmentation). ### DiscretePermutationInvarianceConstraint Permutation invariance, enabled by the @@ -480,6 +481,7 @@ DiscretePermutationInvarianceConstraint( The usage of `DiscretePermutationInvarianceConstraint` is also part of the [example on slot-based mixtures](../../examples/Mixtures/slot_based). +For more information about the possibility of data augmentation, see [here](surrogate_data_augmentation). ### DiscreteCardinalityConstraint Like its [continuous cousin](#ContinuousCardinalityConstraint), the diff --git a/docs/userguide/surrogates.md b/docs/userguide/surrogates.md index 781051573f..7464cb2866 100644 --- a/docs/userguide/surrogates.md +++ b/docs/userguide/surrogates.md @@ -6,6 +6,7 @@ the utilization of custom models. All surrogate models are based upon the genera [`Surrogate`](baybe.surrogates.base.Surrogate) class. Some models even support transfer learning, as indicated by the `supports_transfer_learning` attribute. + ## Available Models BayBE provides a comprehensive selection of surrogate models, empowering you to choose @@ -93,7 +94,6 @@ A noticeable difference to the replication approach is that manual assembly requ the exact set of target variables to be known at the time the object is created. 
- ## Extracting the Model for Advanced Study In principle, the surrogate model does not need to be a persistent object during @@ -117,6 +117,22 @@ shap_values = explainer(data) shap.plots.bar(shap_values) ~~~ + +(surrogate_data_augmentation)= +## Data Augmentation +In certain situations like [mixture modeling](/examples/Mixtures/slot_based), +symmetries are present. Data augmentation is a model-agnostic way of enabling the +surrogate model to learn such symmetries effectively, which might result in a better +performance, similar as e.g. for image classification models. BayBE +`Surrogate`[baybe.surrogates.base.Surrogate] models automatically perform data +augmentation if +{attr}`~baybe.surrogates.base.Surrogate.symmetries` with +`use_data_augmentation=True` are present. This means you can add a data point in +any acceptable representation and BayBE will train the model on this point plus +augmented points that can be generated from it. To see the effect in practice, refer to +[this example](/examples/Symmetries/permutation). + + ## Using Custom Models BayBE goes one step further by allowing you to incorporate custom models based on the diff --git a/docs/userguide/symmetries.md b/docs/userguide/symmetries.md new file mode 100644 index 0000000000..d31c43d64d --- /dev/null +++ b/docs/userguide/symmetries.md @@ -0,0 +1,51 @@ +# Symmetry +{class}`~baybe.symmetries.base.Symmetry` is a concept tied to the structure of the searchspace. +It is thus closely related to a {class}`~baybe.constraints.base.Constraint`, but has a +different purpose in BayBE. If the searchspace is symmetric in any sense, you can +exclude the degenerate parts via a constraint. But this would not change the modeling +process. The role of a {class}`~baybe.symmetries.base.Symmetry` is exactly this: Influence how +the surrogate model is constructed to include the knowledge about the symmetry. This +can be applied independently of constraints. 
For an example of the influence of +symmetries and constraints on the optimization of a permutation invariant function, +[see here](/examples/Symmetries/permutation). + +## Definitions +The following table summarizes available symmetries in BayBE: + +| Symmetry | Functional Definition | Corresponding Constraint | +|:-----------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------| +| {class}`~baybe.symmetries.permutation.PermutationSymmetry` | $f(x,y) = f(y,x)$ | {class}`~baybe.constraints.discrete.DiscretePermutationInvarianceConstraint` | +| {class}`~baybe.symmetries.dependency.DependencySymmetry` | $f(x,y) = \begin{cases}f(x,y) & \text{if }c(x) \\f(x) & \text{otherwise}\end{cases}$
where $c(x)$ is a condition that is either true or false | {class}`~baybe.constraints.discrete.DiscreteDependenciesConstraint` | +| {class}`~baybe.symmetries.mirror.MirrorSymmetry` | $f(x,y) = f(-x,y)$ | No constraint is available. Instead, the number range for that parameter can simply be restricted. | + +## Data Augmentation +This can be a powerful tool to improve the modeling process. Data augmentation +essentially changes the data that the model is fitted on by adding more points. The +augmented points are constructed such that they represent a symmetric point compared +with their original, which always corresponds to a different transformation depending +on which symmetry is responsible. + +If the surrogate model receives such augmented points, it can learn the symmetry. This +has the advantage that it can improve predictions for unseen points and is fully +model-agnostic. Downsides are increased training time and potential computational +challenges arising from a fit on substantially more points. It is thus possible to +control the data augmentation behavior of any {class}`~baybe.symmetries.base.Symmetry` by +setting its {attr}`~baybe.symmetries.base.Symmetry.use_data_augmentation` attribute +(`True` by default). + +Below we illustrate the effect of data augmentation for the different symmetries +supported by BayBE: + +![Symmetry and Data Augmentation](../_static/symmetries/augmentation.svg) + +## Invariant Kernels +Some machine learning models can be constructed with architectures that automatically +respect a symmetry, i.e. applying the model to an augmented point always produces the +same output as the original point by construction. + +For Gaussian processes, this can be achieved by applying special kernels. +```{admonition} Not Implemented Yet +:class: warning +Ideally, invariant kernels will be applied automatically when a corresponding symmetry has been +configured for the surrogate model GP. This feature is not implemented yet. 
+``` \ No newline at end of file diff --git a/docs/userguide/userguide.md b/docs/userguide/userguide.md index 059f6ca581..cefb6516f3 100644 --- a/docs/userguide/userguide.md +++ b/docs/userguide/userguide.md @@ -44,6 +44,7 @@ Serialization Settings Simulation Surrogates +Symmetries Targets Transformations Transfer Learning diff --git a/examples/Symmetries/Symmetries_Header.md b/examples/Symmetries/Symmetries_Header.md new file mode 100644 index 0000000000..361d774e3d --- /dev/null +++ b/examples/Symmetries/Symmetries_Header.md @@ -0,0 +1,3 @@ +# Symmetries + +These examples demonstrate the impact of Symmetries. \ No newline at end of file diff --git a/examples/Symmetries/permutation.py b/examples/Symmetries/permutation.py new file mode 100644 index 0000000000..a909aba0ad --- /dev/null +++ b/examples/Symmetries/permutation.py @@ -0,0 +1,212 @@ +# # Optimizing a Permutation-Invariant Function + +# In this example, we explore BayBE's capabilities for handling optimization problems +# with symmetry via automatic data augmentation and/or constraint. + +# ## Imports + +import os + +import numpy as np +import pandas as pd +import seaborn as sns +from matplotlib import pyplot as plt +from matplotlib.ticker import MaxNLocator + +from baybe import Campaign, Settings +from baybe.constraints import DiscretePermutationInvarianceConstraint +from baybe.parameters import NumericalDiscreteParameter +from baybe.recommenders import ( + BotorchRecommender, + TwoPhaseMetaRecommender, +) +from baybe.searchspace import SearchSpace +from baybe.simulation import simulate_scenarios +from baybe.surrogates import NGBoostSurrogate +from baybe.targets import NumericalTarget + +# ## Settings + +Settings(random_seed=1337).activate() +SMOKE_TEST = "SMOKE_TEST" in os.environ +N_MC_ITERATIONS = 2 if SMOKE_TEST else 100 +N_DOE_ITERATIONS = 2 if SMOKE_TEST else 50 + +# ## The Scenario + +# We will explore a 2-dimensional permutation-invariant function, i.e., it holds that +# $f(x,y) = f(y,x)$. 
The function was crafted to exhibit no additional mirror symmetry
+# (a common way of also resulting in permutation invariance) and have multiple minima.
+# In practice, permutation invariance can arise e.g. for
+# [mixtures when modeled with a slot-based approach](/examples/Mixtures/slot_based).
+
+# There are several ways to handle such symmetries. The simplest one is
+# to augment your data. In the case of permutation invariance, augmentation means for
+# each measurement $(x,y)$ you also add a measurement with switched values, i.e., $(y,x)$.
+# This has the advantage that it is fully model-agnostic, but might
+# come at the expense of increased training time and efficiency due to the larger number
+# of effective training points.
+
+
+LBOUND = -2.0
+UBOUND = 2.0
+
+
+def lookup(df: pd.DataFrame) -> pd.DataFrame:
+    """A lookup modeling a permutation-invariant 2D function with multiple minima."""
+    x = df["x"].values
+    y = df["y"].values
+    result = (
+        (x - y) ** 2
+        + (x**3 + y**3)
+        + ((x**2 - 1) ** 2 + (y**2 - 1) ** 2)
+        + np.sin(3 * (x + y)) ** 2
+        + np.sin(3 * np.abs(x - y)) ** 2
+    )
+
+    df_z = pd.DataFrame({"f": result}, index=df.index)
+    return df_z
+
+
+x = np.linspace(LBOUND, UBOUND, 25)
+y = np.linspace(LBOUND, UBOUND, 25)
+xx, yy = np.meshgrid(x, y)
+df_plot = lookup(pd.DataFrame({"x": xx.ravel(), "y": yy.ravel()}))
+zz = df_plot["f"].values.reshape(xx.shape)
+line_vals = np.linspace(LBOUND, UBOUND, 2)
+
+# fmt: off
+fig, axs = plt.subplots(1, 2, figsize=(15, 6))
+contour = axs[0].contourf(xx, yy, zz, levels=50, cmap="viridis")
+fig.colorbar(contour, ax=axs[0])
+axs[0].plot(line_vals, line_vals, "r--", alpha=0.5, linewidth=2)
+axs[0].set_title("Ground Truth: $f(x, y)$ = $f(y, x)$ (Permutation Invariant)")
+axs[0].set_xlabel("x")
+axs[0].set_ylabel("y");
+# fmt: on
+
+# The plots can be found at the bottom of this file.
+# The first subplot shows the function we want to minimize. 
The dashed red line
+# illustrates the permutation invariance, which is similar to a mirror-symmetry, just
+# not along any of the parameter axes but along the diagonal. We can also see several
+# local minima.
+
+# Such a situation can be challenging for optimization algorithms if no information
+# about the invariance is considered. For instance, if no
+# {class}`~baybe.constraints.discrete.DiscretePermutationInvarianceConstraint` was used
+# at all, BayBE would search for the optima across the entire 2D space. But it is clear
+# that the search can be restricted to the lower (or equivalently the upper) triangle
+# of the searchspace. This is exactly what
+# {class}`~baybe.constraints.discrete.DiscretePermutationInvarianceConstraint` does:
+# Remove entries that are "duplicated" in the sense of already being represented by
+# another invariant point.
+
+# If the surrogate is additionally configured with `symmetries` that use
+# `use_data_augmentation=True`, the model will be fit with an extended set of points,
+# including augmented ones. So as a user, you don't have to generate permutations and
+# add them manually. Depending on the surrogate model, this might have different
+# impacts. For example, we can expect a strong effect for tree-based models because their splits are
+# always parallel to the parameter axes. Thus, without augmented measurements, it is
+# easy to fall into suboptimal splits and overfit. We illustrate this by using the
+# {class}`~baybe.surrogates.ngboost.NGBoostSurrogate`.
+
+# ## The Optimization Problem
+
+p1 = NumericalDiscreteParameter("x", np.linspace(LBOUND, UBOUND, 51))
+p2 = NumericalDiscreteParameter("y", np.linspace(LBOUND, UBOUND, 51))
+objective = NumericalTarget("f", minimize=True).to_objective()
+
+# We set up a constrained and an unconstrained (plain) searchspace to demonstrate the
+# impact of the constraint on optimization performance. 
+
+constraint = DiscretePermutationInvarianceConstraint(["x", "y"])
+searchspace_plain = SearchSpace.from_product([p1, p2])
+searchspace_constrained = SearchSpace.from_product([p1, p2], [constraint])
+
+print("Number of Points in the Searchspace")
+print(f"{'Without Constraint:':<35} {len(searchspace_plain.discrete.exp_rep)}")
+print(f"{'With Constraint:':<35} {len(searchspace_constrained.discrete.exp_rep)}")
+
+# We can see that the unconstrained searchspace has roughly twice as many points
+# compared to the constrained one. This is expected, as the
+# {class}`~baybe.constraints.discrete.DiscretePermutationInvarianceConstraint`
+# effectively keeps only one triangular half of the parameter space. Note that the factor is
+# not exactly 2 due to the (still included) points on the diagonal.
+
+# BayBE can automatically perform this augmentation if configured to do so.
+# Specifically, surrogate models have the
+# {attr}`~baybe.surrogates.base.Surrogate.symmetries` attribute. If any of
+# these symmetries has `use_data_augmentation=True` (enabled by default),
+# BayBE will automatically augment measurements internally before performing the model
+# fit. To construct symmetries quickly, we use the `to_symmetry` method of the
+# constraint. 
+ +symmetry = constraint.to_symmetry(use_data_augmentation=True) +recommender_plain = TwoPhaseMetaRecommender( + recommender=BotorchRecommender(surrogate_model=NGBoostSurrogate()) +) +recommender_symmetric = TwoPhaseMetaRecommender( + recommender=BotorchRecommender( + surrogate_model=NGBoostSurrogate(symmetries=[symmetry]) + ) +) + +# The combination of constraint and augmentation settings results in four different +# campaigns: + +campaign_plain = Campaign(searchspace_plain, objective, recommender_plain) +campaign_c = Campaign(searchspace_constrained, objective, recommender_plain) +campaign_s = Campaign(searchspace_plain, objective, recommender_symmetric) +campaign_cs = Campaign(searchspace_constrained, objective, recommender_symmetric) + +# ## Simulating the Optimization Loop + + +scenarios = { + "Unconstrained, Unsymmetric": campaign_plain, + "Constrained, Unsymmetric": campaign_c, + "Unconstrained, Symmetric": campaign_s, + "Constrained, Symmetric": campaign_cs, +} + +results = simulate_scenarios( + scenarios, + lookup, + n_doe_iterations=N_DOE_ITERATIONS, + n_mc_iterations=N_MC_ITERATIONS, +).rename( + columns={ + "f_CumBest": "$f(x,y)$ (cumulative best)", + "Num_Experiments": "# Experiments", + } +) + +# ## Results + +# Let us visualize the optimization process in the second subplot: + +sns.lineplot( + data=results, + x="# Experiments", + y="$f(x,y)$ (cumulative best)", + hue="Scenario", + marker="o", + ax=axs[1], +) +axs[1].xaxis.set_major_locator(MaxNLocator(integer=True)) +axs[1].set_ylim(axs[1].get_ylim()[0], 3) +axs[1].set_title("Minimization Performance") +plt.tight_layout() +plt.show() + +# We find that the campaigns utilizing the permutation invariance constraint +# perform better than the ones without. This can be attributed simply to the reduced +# number of searchspace points they operate on. However, this effect is rather minor +# compared to the effect of symmetry. 
+ +# Furthermore, there is a strong impact on whether data augmentation is used or not, +# the effect we expected for a tree-based surrogate model. Indeed, the campaign with +# constraint but without augmentation is barely better than the campaign not utilizing +# the constraint at all. Conversely, the data-augmented campaign has a clearly superior +# performance. The best result is achieved by using both constraints and data +# augmentation. diff --git a/examples/Symmetries/permutation.svg b/examples/Symmetries/permutation.svg new file mode 100644 index 0000000000..ab17b0ac9a --- /dev/null +++ b/examples/Symmetries/permutation.svg @@ -0,0 +1,7525 @@ + + + + + + + + 2026-03-20T19:16:57.550802 + image/svg+xml + + + Matplotlib v3.10.8, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/conftest.py b/tests/conftest.py index 882204c156..5cd9537b67 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -242,7 +242,7 @@ def fixture_parameters( NumericalDiscreteParameter( name="Fraction_1", values=tuple(np.linspace(0, 100, n_grid_points)), - 
tolerance=0.2, + tolerance=0.5, ), NumericalDiscreteParameter( name="Fraction_2", diff --git a/tests/hypothesis_strategies/conditions.py b/tests/hypothesis_strategies/conditions.py new file mode 100644 index 0000000000..fe90efdaa8 --- /dev/null +++ b/tests/hypothesis_strategies/conditions.py @@ -0,0 +1,24 @@ +"""Hypothesis strategies for conditions.""" + +from typing import Any + +import hypothesis.strategies as st + +from baybe.constraints import SubSelectionCondition, ThresholdCondition +from tests.hypothesis_strategies.basic import finite_floats + + +def sub_selection_conditions(superset: list[Any] | None = None): + """Generate :class:`baybe.constraints.conditions.SubSelectionCondition`.""" + if superset is None: + element_strategy = st.text() + else: + element_strategy = st.sampled_from(superset) + return st.builds( + SubSelectionCondition, st.lists(element_strategy, unique=True, min_size=1) + ) + + +def threshold_conditions(): + """Generate :class:`baybe.constraints.conditions.ThresholdCondition`.""" + return st.builds(ThresholdCondition, threshold=finite_floats()) diff --git a/tests/hypothesis_strategies/constraints.py b/tests/hypothesis_strategies/constraints.py index e1f1014833..dfc065d817 100644 --- a/tests/hypothesis_strategies/constraints.py +++ b/tests/hypothesis_strategies/constraints.py @@ -1,14 +1,11 @@ """Hypothesis strategies for constraints.""" from functools import partial -from typing import Any import hypothesis.strategies as st from hypothesis import assume from baybe.constraints.conditions import ( - SubSelectionCondition, - ThresholdCondition, _valid_logic_combiners, ) from baybe.constraints.continuous import ( @@ -26,22 +23,10 @@ from baybe.parameters.base import DiscreteParameter from baybe.parameters.numerical import NumericalDiscreteParameter from tests.hypothesis_strategies.basic import finite_floats - - -def sub_selection_conditions(superset: list[Any] | None = None): - """Generate 
:class:`baybe.constraints.conditions.SubSelectionCondition`.""" - if superset is None: - element_strategy = st.text() - else: - element_strategy = st.sampled_from(superset) - return st.builds( - SubSelectionCondition, st.lists(element_strategy, unique=True, min_size=1) - ) - - -def threshold_conditions(): - """Generate :class:`baybe.constraints.conditions.ThresholdCondition`.""" - return st.builds(ThresholdCondition, threshold=finite_floats()) +from tests.hypothesis_strategies.conditions import ( + sub_selection_conditions, + threshold_conditions, +) @st.composite diff --git a/tests/hypothesis_strategies/symmetries.py b/tests/hypothesis_strategies/symmetries.py new file mode 100644 index 0000000000..31e45aaa40 --- /dev/null +++ b/tests/hypothesis_strategies/symmetries.py @@ -0,0 +1,121 @@ +"""Hypothesis strategies for symmetries.""" + +import hypothesis.strategies as st +from hypothesis import assume + +from baybe.parameters.base import Parameter +from baybe.symmetries import DependencySymmetry, MirrorSymmetry, PermutationSymmetry +from tests.hypothesis_strategies.basic import finite_floats +from tests.hypothesis_strategies.conditions import ( + sub_selection_conditions, + threshold_conditions, +) + + +@st.composite +def mirror_symmetries(draw: st.DrawFn, parameter_pool: list[Parameter] | None = None): + """Generate :class:`baybe.symmetries.MirrorSymmetry`.""" + if parameter_pool is None: + parameter_name = draw(st.text(min_size=1)) + else: + parameter = draw(st.sampled_from(parameter_pool)) + assume(parameter.is_numerical) + parameter_name = parameter.name + + return MirrorSymmetry( + parameter_name=parameter_name, + use_data_augmentation=draw(st.booleans()), + mirror_point=draw(finite_floats()), + ) + + +@st.composite +def permutation_symmetries( + draw: st.DrawFn, parameter_pool: list[Parameter] | None = None +): + """Generate :class:`baybe.symmetries.PermutationSymmetry`.""" + if parameter_pool is None: + parameter_names_pool = draw( + 
st.lists(st.text(min_size=1), unique=True, min_size=2) + ) + else: + parameter_names_pool = [p.name for p in parameter_pool] + + # Draw the first (required) group with at least 2 elements + first_group = draw( + st.lists(st.sampled_from(parameter_names_pool), min_size=2, unique=True).map( + tuple + ) + ) + group_size = len(first_group) + + # Determine how many additional groups we can have + remaining_names = [n for n in parameter_names_pool if n not in first_group] + max_additional = len(remaining_names) // group_size if group_size > 0 else 0 + n_additional = draw(st.integers(min_value=0, max_value=max_additional)) + + # Draw additional groups of the same size from remaining names + all_groups = [first_group] + for _ in range(n_additional): + group = draw( + st.lists( + st.sampled_from(remaining_names), + unique=True, + min_size=group_size, + max_size=group_size, + ).map(tuple) + ) + # Ensure no overlap with existing groups + for existing in all_groups: + assume(not set(group) & set(existing)) + remaining_names = [n for n in remaining_names if n not in group] + all_groups.append(group) + + return PermutationSymmetry( + permutation_groups=all_groups, + use_data_augmentation=draw(st.booleans()), + ) + + +@st.composite +def dependency_symmetries( + draw: st.DrawFn, parameter_pool: list[Parameter] | None = None +): + """Generate :class:`baybe.symmetries.DependencySymmetry`.""" + if parameter_pool is None: + parameter_name = draw(st.text(min_size=1)) + affected_strat = st.lists( + st.text(min_size=1).filter(lambda x: x != parameter_name), + min_size=1, + unique=True, + ).map(tuple) + else: + parameter = draw(st.sampled_from(parameter_pool)) + assume(parameter.is_discrete) + parameter_name = parameter.name + affected_strat = st.lists( + st.sampled_from(parameter_pool) + .filter(lambda x: x.name != parameter_name) + .map(lambda x: x.name), + unique=True, + min_size=1, + max_size=len(parameter_pool) - 1, + ).map(tuple) + + return DependencySymmetry( + 
parameter_name=parameter_name, + condition=draw(st.one_of(threshold_conditions(), sub_selection_conditions())), + affected_parameter_names=draw(affected_strat), + n_discretization_points=draw(st.integers(min_value=2)), + use_data_augmentation=draw(st.booleans()), + ) + + +symmetries = st.one_of( + [ + mirror_symmetries(), + permutation_symmetries(), + dependency_symmetries(), + ] +) +"""A strategy that generates symmetries.""" diff --git a/tests/serialization/test_symmetry_serialization.py b/tests/serialization/test_symmetry_serialization.py new file mode 100644 index 0000000000..e1f6ae9373 --- /dev/null +++ b/tests/serialization/test_symmetry_serialization.py @@ -0,0 +1,28 @@ +"""Symmetry serialization tests.""" + +import hypothesis.strategies as st +import pytest +from hypothesis import given +from pytest import param + +from tests.hypothesis_strategies.symmetries import ( + dependency_symmetries, + mirror_symmetries, + permutation_symmetries, +) +from tests.serialization.utils import assert_roundtrip_consistency + + +@pytest.mark.parametrize( + "strategy", + [ + param(mirror_symmetries(), id="MirrorSymmetry"), + param(permutation_symmetries(), id="PermutationSymmetry"), + param(dependency_symmetries(), id="DependencySymmetry"), + ], +) +@given(data=st.data()) +def test_roundtrip(strategy: st.SearchStrategy, data: st.DataObject): + """A serialization roundtrip yields an equivalent object.""" + symmetry = data.draw(strategy) + assert_roundtrip_consistency(symmetry) diff --git a/tests/test_measurement_augmentation.py b/tests/test_measurement_augmentation.py new file mode 100644 index 0000000000..fa7e76c289 --- /dev/null +++ b/tests/test_measurement_augmentation.py @@ -0,0 +1,143 @@ +"""Tests for augmentation of measurements.""" + +import math +from unittest.mock import patch + +import numpy as np +import pandas as pd +import pytest +from attrs import evolve +from pandas.testing import assert_frame_equal + +from baybe.acquisition import qLogEI +from 
baybe.constraints import ( + DiscretePermutationInvarianceConstraint, + ThresholdCondition, +) +from baybe.parameters import ( + CategoricalParameter, + NumericalContinuousParameter, + NumericalDiscreteParameter, +) +from baybe.recommenders import BotorchRecommender +from baybe.searchspace import SearchSpace +from baybe.symmetries import DependencySymmetry, MirrorSymmetry +from baybe.utils.dataframe import create_fake_input + + +@pytest.mark.parametrize("mirror_aug", [True, False], ids=["mirror", "nomirror"]) +@pytest.mark.parametrize("perm_aug", [True, False], ids=["perm", "noperm"]) +@pytest.mark.parametrize("dep_aug", [True, False], ids=["dep", "nodep"]) +@pytest.mark.parametrize( + "constraint_names", [["Constraint_11", "Constraint_7"]], ids=["c"] +) +@pytest.mark.parametrize( + "parameter_names", + [ + [ + "Solvent_1", + "Solvent_2", + "Solvent_3", + "Fraction_1", + "Fraction_2", + "Fraction_3", + "Num_disc_2", + ] + ], + ids=["p"], +) +def test_measurement_augmentation( + parameters, + surrogate_model, + objective, + constraints, + dep_aug, + perm_aug, + mirror_aug, +): + """Measurement augmentation is performed if configured.""" + original_to_botorch = qLogEI.to_botorch + called_args_list = [] + + def spy(self, *args, **kwargs): + called_args_list.append((args, kwargs)) + return original_to_botorch(self, *args, **kwargs) + + with patch.object(qLogEI, "to_botorch", side_effect=spy, autospec=True): + # Basic setup + c_perm = next( + c + for c in constraints + if isinstance(c, DiscretePermutationInvarianceConstraint) + ) + c_dep = c_perm.dependencies + s_perm = c_perm.to_symmetry(perm_aug) + s_deps = c_dep.to_symmetries(dep_aug) # this is a tuple of multiple + s_mirror = MirrorSymmetry("Num_disc_2", use_data_augmentation=mirror_aug) + searchspace = SearchSpace.from_product(parameters, constraints) + surrogate = evolve(surrogate_model, symmetries=[*s_deps, s_perm, s_mirror]) + recommender = BotorchRecommender( + surrogate_model=surrogate, 
acquisition_function=qLogEI() + ) + + # Perform call and watch measurements + measurements = create_fake_input(parameters, objective.targets, 5) + recommender.recommend(1, searchspace, objective, measurements) + measurements_passed = called_args_list[0][0][3] # take 4th arg from first call + + # Create expectation + # We calculate how many degenerate points the augmentation should create: + # - n_dep: Product of the number of active values for all affected parameters + # - n_perm: Number of permutations possible + # - n_mirror: 2 if the row is not on the mirror point, else 1 + # - If augmentation is turned off, the corresponding factor becomes 1 + # We expect a given row to produce n_perm * (n_dep^k) * n_mirror points, where + # k is the number of "Fraction_*" parameters having the "causing" value 0.0. The + # total number of expected points after augmentation is the sum over the + # expectations for all rows. + dep_affected = [p for p in parameters if p.name in c_dep.affected_parameters[0]] + n_dep = math.prod(len(p.active_values) for p in dep_affected) if dep_aug else 1 + n_perm = ( # number of permutations + math.prod(range(1, len(c_perm.parameters) + 1)) if perm_aug else 1 + ) + n_expected = 0 + for _, row in measurements.iterrows(): + n_mirror = ( + 2 if (mirror_aug and row["Num_disc_2"] != s_mirror.mirror_point) else 1 + ) + k = row[c_dep.parameters].eq(0).sum() + n_expected += n_perm * np.pow(n_dep, k) * n_mirror + + # Check expectation + if any([dep_aug, perm_aug, mirror_aug]): + assert len(measurements_passed) == n_expected + else: + assert_frame_equal(measurements, measurements_passed) + + +@pytest.mark.parametrize("n_points", [2, 5]) +@pytest.mark.parametrize("mixed", [True, False]) +def test_continuous_dependency_augmentation(n_points, mixed): + """Dependency augmentation with continuous affected parameters works correctly.""" + df = pd.DataFrame({"n1": [1, 0], "cat1": ["a", "b"], "c1": [1, 2]}) + ps = [ + NumericalDiscreteParameter("n1", (0, 0.5, 
1.0)), + CategoricalParameter("cat1", ("a", "b", "c")), + NumericalContinuousParameter("c1", (0, 10)), + ] + + # Model "Affected parameters only matter if n1 is > 0" + s = DependencySymmetry( + "n1", + condition=ThresholdCondition(0.0, ">"), + affected_parameter_names=("cat1", "c1") if mixed else ("c1",), + n_discretization_points=n_points, + ) + dfa = s.augment_measurements(df, ps) + + # Calculate expectation. The first row is not affected and contributes one point. + # The second row is degenerate and contributes `n_points` values for c1 and 3 + # values for cat1 (if part of the symmetry). + n_expected = 1 + n_points * (3 if mixed else 1) + + assert len(dfa) == n_expected diff --git a/tests/utils/test_augmentation.py b/tests/utils/test_augmentation.py index 80ffff8a2a..4dd8a14d33 100644 --- a/tests/utils/test_augmentation.py +++ b/tests/utils/test_augmentation.py @@ -22,7 +22,7 @@ }, "index": [0, 1], }, - [["A"], ["B"]], + [["A", "B"]], { "data": { "A": [1, 2, 1, 2], @@ -41,7 +41,7 @@ }, "index": [0, 1], }, - [["A"], ["B"]], + [["A", "B"]], { "data": { "A": [1, 2, 1, 2], @@ -60,7 +60,7 @@ }, "index": [0, 1], }, - [["A"], ["B"]], + [["A", "B"]], { "data": { "A": [1, 2, 1, 2], @@ -80,7 +80,7 @@ }, "index": [0, 1], }, - [["A"], ["B"]], + [["A", "B"]], { "data": { "A": [1, 2, 1, 2], @@ -91,7 +91,7 @@ }, id="2inv+degen_target+degen", ), - param( # 3 invariant groups with 1 entry each + param( # 3 invariant cols in one group { "data": { "A": [1, 1], @@ -101,7 +101,7 @@ }, "index": [0, 1], }, - [["A"], ["B"], ["C"]], + [["A", "B", "C"]], { "data": { "A": [1, 1, 2, 2, 3, 3, 1, 1, 4, 4, 5, 5], @@ -113,7 +113,7 @@ }, id="3inv_1add", ), - param( # 2 groups with 2 entries each, 2 additional columns + param( # 2 lockstep groups with 2 entries each, 2 additional columns { "data": { "Slot1": ["s1", "s2"], @@ -125,7 +125,7 @@ }, "index": [0, 1], }, - [["Slot1", "Frac1"], ["Slot2", "Frac2"]], + [["Slot1", "Slot2"], ["Frac1", "Frac2"]], { "data": { "Slot1": ["s1", "s2", "s2", 
"s4"], @@ -139,7 +139,7 @@ }, id="2inv_2dependent_2add", ), - param( # 2 groups with 3 entries each, 1 additional column + param( # 3 lockstep groups with 2 entries each, 1 additional column { "data": { "Slot1": ["s1", "s2"], @@ -152,7 +152,7 @@ }, "index": [0, 1], }, - [["Slot1", "Frac1", "Temp1"], ["Slot2", "Frac2", "Temp2"]], + [["Slot1", "Slot2"], ["Frac1", "Frac2"], ["Temp1", "Temp2"]], { "data": { "Slot1": ["s1", "s2", "s2", "s4"], @@ -187,10 +187,9 @@ def test_df_permutation_aug(content, col_groups, content_expected): @pytest.mark.parametrize( ("col_groups", "msg"), [ - param([], "at least two column sequences", id="no_groups"), - param([["A"]], "at least two column sequences", id="just_one_group"), - param([["A"], ["B", "C"]], "the amount of columns in", id="different_lengths"), - param([[], []], "each column group has", id="empty_group"), + param([], "at least one group", id="no_groups"), + param([["A"]], "at least two entries", id="group_too_small"), + param([["A", "B"], ["C"]], "same number of entries", id="different_lengths"), ], ) def test_df_permutation_aug_invalid(col_groups, msg): diff --git a/tests/validation/test_symmetry_validation.py b/tests/validation/test_symmetry_validation.py new file mode 100644 index 0000000000..b4093eb77f --- /dev/null +++ b/tests/validation/test_symmetry_validation.py @@ -0,0 +1,291 @@ +"""Validation tests for symmetry.""" + +import numpy as np +import pandas as pd +import pytest +from pytest import param + +from baybe.constraints import ThresholdCondition +from baybe.exceptions import IncompatibleSearchSpaceError +from baybe.parameters import ( + CategoricalParameter, + NumericalContinuousParameter, + NumericalDiscreteParameter, +) +from baybe.recommenders import BotorchRecommender +from baybe.searchspace import SearchSpace +from baybe.surrogates import GaussianProcessSurrogate +from baybe.symmetries import DependencySymmetry, MirrorSymmetry, PermutationSymmetry +from baybe.targets import NumericalTarget +from 
baybe.utils.dataframe import create_fake_input + +valid_config_mirror = {"parameter_name": "n1"} +valid_config_perm = { + "permutation_groups": [["cat1", "cat2"], ["n1", "n2"]], +} +valid_config_dep = { + "parameter_name": "n1", + "condition": ThresholdCondition(0.0, ">="), + "affected_parameter_names": ["n2", "cat1"], +} + + +@pytest.mark.parametrize( + "cls, config, error, msg", + [ + param( + MirrorSymmetry, + valid_config_mirror | {"mirror_point": np.inf}, + ValueError, + "values containing infinity/nan to 'mirror_point': inf", + id="mirror_nonfinite", + ), + param( + PermutationSymmetry, + {"permutation_groups": [["cat1", "cat1"]]}, + ValueError, + r"the following group contains duplicates", + id="perm_not_unique", + ), + param( + PermutationSymmetry, + { + "permutation_groups": [ + ["cat1", "cat2", "cat3"], + ["n1", "n2", "n3", "n4"], + ] + }, + ValueError, + "must have the same length", + id="perm_different_lengths", + ), + param( + PermutationSymmetry, + {"permutation_groups": [["cat1", "cat2"], ["cat1", "n2"]]}, + ValueError, + r"following parameter names appear in several groups", + id="perm_overlap", + ), + param( + PermutationSymmetry, + {"permutation_groups": [["cat1"]]}, + ValueError, + "must be >= 2", + id="perm_group_too_small", + ), + param( + PermutationSymmetry, + {"permutation_groups": []}, + ValueError, + "must be >= 1", + id="perm_no_groups", + ), + param( + PermutationSymmetry, + {"permutation_groups": [[1, 2]]}, + TypeError, + "must be ", + id="perm_not_str", + ), + param( + DependencySymmetry, + valid_config_dep | {"parameter_name": 1}, + TypeError, + "must be ", + id="dep_param_not_str", + ), + param( + DependencySymmetry, + valid_config_dep | {"condition": 1}, + TypeError, + "must be ", + id="dep_wrong_cond_type", + ), + param( + DependencySymmetry, + valid_config_dep | {"affected_parameter_names": []}, + ValueError, + "Length of 'affected_parameter_names' must be >= 1", + id="dep_affected_empty", + ), + param( + DependencySymmetry, + 
valid_config_dep | {"affected_parameter_names": [1]}, + TypeError, + "must be ", + id="dep_affected_wrong_type", + ), + param( + DependencySymmetry, + valid_config_dep | {"affected_parameter_names": ["a1", "a1"]}, + ValueError, + r"Entries appearing multiple times: \['a1'\].", + id="dep_affected_not_unique", + ), + param( + PermutationSymmetry, + {"permutation_groups": "abc"}, + ValueError, + "must be a sequence of sequences, not a string", + id="perm_groups_bare_string", + ), + param( + PermutationSymmetry, + {"permutation_groups": ["abc", "def"]}, + ValueError, + "must be a sequence of parameter names, not a string", + id="perm_groups_inner_bare_string", + ), + param( + MirrorSymmetry, + {"parameter_name": 123}, + TypeError, + "must be ", + id="mirror_param_not_str", + ), + param( + DependencySymmetry, + valid_config_dep | {"n_discretization_points": 3.5}, + TypeError, + "must be ", + id="dep_n_discretization_not_int", + ), + param( + DependencySymmetry, + valid_config_dep | {"n_discretization_points": 1}, + ValueError, + "must be >= 2", + id="dep_n_discretization_too_small", + ), + param( + DependencySymmetry, + valid_config_dep | {"affected_parameter_names": "abc"}, + ValueError, + "must be a sequence but cannot be a string", + id="dep_affected_bare_string", + ), + param( + PermutationSymmetry, + {"permutation_groups": [["a", "b"]], "use_data_augmentation": 1}, + TypeError, + "must be ", + id="use_aug_not_bool", + ), + ], +) +def test_configuration(cls, config, error, msg): + """Invalid configurations raise an expected error.""" + with pytest.raises(error, match=msg): + cls(**config) + + +_parameters = [ + NumericalDiscreteParameter("n1", (-1, 0, 1)), + NumericalDiscreteParameter("n2", (-1, 0, 1)), + NumericalContinuousParameter("n1_not_discrete", (0.0, 10.0)), + NumericalContinuousParameter("n2_not_discrete", (0.0, 10.0)), + NumericalContinuousParameter("c1", (0.0, 10.0)), + NumericalContinuousParameter("c2", (0.0, 10.0)), + CategoricalParameter("cat1", ("a", 
"b", "c")), + CategoricalParameter("cat1_altered", ("a", "b")), + CategoricalParameter("cat2", ("a", "b", "c")), +] + + +@pytest.fixture +def searchspace(parameter_names): + ps = tuple(p for p in _parameters if p.name in parameter_names) + return SearchSpace.from_product(ps) + + +@pytest.mark.parametrize( + "parameter_names, symmetry, error, msg", + [ + param( + ["cat1"], + MirrorSymmetry(parameter_name="cat1"), + TypeError, + "'cat1' is of type 'CategoricalParameter' and is not numerical", + id="mirror_not_numerical", + ), + param( + ["n1"], + MirrorSymmetry(parameter_name="n2"), + IncompatibleSearchSpaceError, + r"not present in the search space", + id="mirror_param_missing", + ), + param( + ["n2", "cat1"], + DependencySymmetry(**valid_config_dep), + IncompatibleSearchSpaceError, + r"not present in the search space", + id="dep_causing_missing", + ), + param( + ["n1", "cat1"], + DependencySymmetry(**valid_config_dep), + IncompatibleSearchSpaceError, + r"not present in the search space", + id="dep_affected_missing", + ), + param( + ["n1_not_discrete", "n2", "cat1"], + DependencySymmetry( + **valid_config_dep | {"parameter_name": "n1_not_discrete"} + ), + TypeError, + "must be discrete. 
However, the parameter 'n1_not_discrete'", + id="dep_causing_not_discrete", + ), + param( + ["cat1", "n1", "n2"], + PermutationSymmetry(**valid_config_perm), + IncompatibleSearchSpaceError, + r"not present in the search space", + id="perm_not_present", + ), + param( + ["cat1", "cat2", "n1", "n2"], + PermutationSymmetry(permutation_groups=[("cat1", "n1"), ("cat2", "n2")]), + ValueError, + r"differ in their specification", + id="perm_inconsistent_types", + ), + param( + ["cat1_altered", "cat2", "n1", "n2"], + PermutationSymmetry( + permutation_groups=[["cat1_altered", "cat2"], ["n1", "n2"]] + ), + ValueError, + r"differ in their specification", + id="perm_inconsistent_values", + ), + ], +) +def test_searchspace_context(searchspace, symmetry, error, msg): + """Configurations not compatible with the searchspace raise an expected error.""" + recommender = BotorchRecommender( + surrogate_model=GaussianProcessSurrogate(symmetries=(symmetry,)) + ) + t = NumericalTarget("t") + measurements = create_fake_input(searchspace.parameters, [t]) + + with pytest.raises(error, match=msg): + recommender.recommend( + 1, searchspace, t.to_objective(), measurements=measurements + ) + + +def test_dependency_augmentation_requires_parameters(): + """DependencySymmetry.augment_measurements raises when parameters is None.""" + s = DependencySymmetry(**valid_config_dep) + df = pd.DataFrame({"n1": [0], "n2": [1], "cat1": ["a"]}) + with pytest.raises(ValueError, match="requires parameter objects"): + s.augment_measurements(df) + + +def test_surrogate_rejects_non_symmetry(): + """Surrogate.symmetries rejects non-Symmetry members.""" + with pytest.raises(TypeError, match="must be