22 commits
1ae9bfd
Refactor shared validators and converters
Scienfitz Mar 20, 2026
764caca
Update permutation augmentation utility interface
Scienfitz Mar 20, 2026
99e1e60
Add mirror augmentation utility
Scienfitz Mar 20, 2026
221234b
Add Symmetry domain model
Scienfitz Mar 20, 2026
e9dff9c
Add Parameter.is_equivalent and apply in PermutationSymmetry
Scienfitz Mar 20, 2026
d1c88db
Integrate symmetries into surrogates and recommenders
Scienfitz Mar 20, 2026
68b46a9
Update constraints for symmetry support
Scienfitz Mar 20, 2026
241a9d3
Add hypothesis strategies for symmetries and conditions
Scienfitz Mar 20, 2026
d82f8de
Add symmetry tests
Scienfitz Mar 20, 2026
3d08a20
Add symmetry documentation
Scienfitz Mar 20, 2026
49a046e
Add symmetry example
Scienfitz Mar 20, 2026
bfb9fe9
Handle CompositeSurrogate in symmetry integration
Scienfitz Mar 20, 2026
7b9e134
Fix mypy errors in categorical validator and dependency type ignore
Scienfitz Mar 20, 2026
fb21269
Add symmetry validation tests
Scienfitz Mar 20, 2026
d058531
Update CHANGELOG
Scienfitz Mar 20, 2026
57d03e9
Replace deprecated set_random_seed with Settings in example
Scienfitz Mar 20, 2026
e8a97e8
Fix Sphinx cross-references for symmetry classes
Scienfitz Mar 20, 2026
aa783fe
Fix bug in permutation constraint
Scienfitz Mar 20, 2026
36c4779
Improve docstring
Scienfitz Apr 2, 2026
c853cac
Add docstrings to to_symmetries and to_symmetry methods
Scienfitz Apr 10, 2026
9f076b6
Add use_data_augmentation to symmetry summary
Scienfitz Apr 10, 2026
be1c293
Rework imports
Scienfitz Apr 10, 2026
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -13,14 +13,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `identify_non_dominated_configurations` method to `Campaign` and `Objective`
for determining the Pareto front
- Interpoint constraints for continuous search spaces
- Symmetry classes (`PermutationSymmetry`, `MirrorSymmetry`, `DependencySymmetry`)
for expressing invariances and configuring surrogate data augmentation
- `Parameter.is_equivalent` method for structural parameter comparison

### Breaking Changes
- `ContinuousLinearConstraint.to_botorch` now returns a collection of constraint tuples
instead of a single tuple (needed for interpoint constraints)
- `df_apply_permutation_augmentation` has a different interface and now expects
permutation groups instead of column groups

### Fixed
- `SHAPInsight` breaking with `numpy>=2.4` due to no longer accepted implicit array to
scalar conversion
- `DiscretePermutationInvarianceConstraint` no longer erroneously removes diagonal
points (e.g., where all permuted parameters have the same value)

### Changed
- The `Campaign.allow_*` flag mechanism is now based on `AutoBool` logic, providing
4 changes: 0 additions & 4 deletions baybe/constraints/base.py
@@ -37,10 +37,6 @@ class Constraint(ABC, SerialMixin):
eval_during_modeling: ClassVar[bool]
"""Class variable encoding whether the condition is evaluated during modeling."""

eval_during_augmentation: ClassVar[bool] = False
"""Class variable encoding whether the constraint could be considered during data
augmentation."""

numerical_only: ClassVar[bool] = False
"""Class variable encoding whether the constraint is valid only for numerical
parameters."""
2 changes: 1 addition & 1 deletion baybe/constraints/conditions.py
@@ -95,7 +95,7 @@ class Condition(ABC, SerialMixin):
"""Abstract base class for all conditions.

Conditions always evaluate an expression regarding a single parameter.
Conditions are part of constraints, a constraint can have multiple conditions.
Conditions are part of constraints and symmetries.
"""

@abstractmethod
94 changes: 65 additions & 29 deletions baybe/constraints/discrete.py
@@ -29,6 +29,9 @@
if TYPE_CHECKING:
import polars as pl

from baybe.symmetries.dependency import DependencySymmetry
from baybe.symmetries.permutation import PermutationSymmetry


@define
class DiscreteExcludeConstraint(DiscreteConstraint):
@@ -195,10 +198,6 @@ class DiscreteDependenciesConstraint(DiscreteConstraint):
a single constraint.
"""

# class variables
eval_during_augmentation: ClassVar[bool] = True
# See base class

# object variables
conditions: list[Condition] = field()
"""The list of individual conditions."""
@@ -271,39 +270,56 @@ def get_invalid(self, data: pd.DataFrame) -> pd.Index:

return inds_bad

def to_symmetries(
self, use_data_augmentation: bool = True
) -> tuple[DependencySymmetry, ...]:
"""Convert to :class:`~baybe.symmetries.dependency.DependencySymmetry` objects.

Create one symmetry object per dependency relationship, i.e., per
(parameter, condition, affected_parameters) triple.

Args:
use_data_augmentation: Flag indicating whether the resulting symmetry
objects should apply data augmentation. ``True`` means that
measurement augmentation will be performed by replacing inactive
affected parameter values with all possible values.

Returns:
A tuple of dependency symmetries, one for each dependency in the
constraint.
"""
from baybe.symmetries.dependency import DependencySymmetry

return tuple(
DependencySymmetry(
parameter_name=p,
condition=c,
affected_parameter_names=aps,
use_data_augmentation=use_data_augmentation,
)
for p, c, aps in zip(
self.parameters, self.conditions, self.affected_parameters, strict=True
)
)
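The augmentation behavior the docstring describes (replacing inactive affected-parameter values with all possible values) can be sketched outside BayBE. The function name, record layout, and value-space argument below are illustrative stand-ins, not library API:

```python
from itertools import product

def augment_dependency(rows, parameter, active_values, affected, value_space):
    """Replicate rows whose dependency parameter is inactive, filling the
    affected parameters with every possible value combination."""
    augmented = []
    for row in rows:
        augmented.append(row)
        if row[parameter] in active_values:
            continue  # affected parameters carry information; keep the row as-is
        # Inactive: the affected values are meaningless, so every combination
        # describes the same physical measurement.
        for combo in product(*(value_space[name] for name in affected)):
            candidate = {**row, **dict(zip(affected, combo))}
            if candidate != row:
                augmented.append(candidate)
    return augmented

rows = [{"solvent_on": "no", "solvent_frac": 0.1, "y": 1.0}]
out = augment_dependency(
    rows, "solvent_on", {"yes"}, ["solvent_frac"], {"solvent_frac": [0.1, 0.5]}
)
# out contains the original row plus a copy with solvent_frac == 0.5
```

Since the dependency parameter is inactive ("no"), the single measurement is expanded to cover every possible value of the affected parameter.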


@define
class DiscretePermutationInvarianceConstraint(DiscreteConstraint):
"""Constraint class for declaring that a set of parameters is permutation invariant.

More precisely, this means that, ``(val_from_param1, val_from_param2)`` is
equivalent to ``(val_from_param2, val_from_param1)``. Since it does not make sense
to have this constraint with duplicated labels, this implementation also internally
applies the :class:`baybe.constraints.discrete.DiscreteNoLabelDuplicatesConstraint`.
equivalent to ``(val_from_param2, val_from_param1)``.

*Note:* This constraint is evaluated during creation. In the future it might also be
evaluated during modeling to make use of the invariance.
"""

# class variables
eval_during_augmentation: ClassVar[bool] = True
# See base class

# object variables
dependencies: DiscreteDependenciesConstraint | None = field(default=None)
"""Dependencies connected with the invariant parameters."""

@override
def get_invalid(self, data: pd.DataFrame) -> pd.Index:
# Get indices of entries with duplicate label entries. These will also be
# dropped by this constraint.
mask_duplicate_labels = pd.Series(False, index=data.index)
mask_duplicate_labels[
DiscreteNoLabelDuplicatesConstraint(parameters=self.parameters).get_invalid(
data
)
] = True

# Merge a permutation invariant representation of all affected parameters with
# the other parameters and indicate duplicates. This ensures that variation in
# other parameters is also accounted for.
@@ -314,20 +330,14 @@ def get_invalid(self, data: pd.DataFrame) -> pd.Index:
data[self.parameters].apply(cast(Callable, frozenset), axis=1),
],
axis=1,
).loc[
~mask_duplicate_labels # only consider label-duplicate-free part
]
)
mask_duplicate_permutations = df_eval.duplicated(keep="first")

# Indices of entries with label-duplicates
inds_duplicate_labels = data.index[mask_duplicate_labels]

# Indices of duplicate permutations in the (already label-duplicate-free) data
inds_duplicate_permutations = df_eval.index[mask_duplicate_permutations]
# Indices of duplicate permutations
inds_invalid = data.index[mask_duplicate_permutations]

# If there are dependencies connected to the invariant parameters evaluate them
# here and remove resulting duplicates with a DependenciesConstraint
inds_invalid = inds_duplicate_labels.union(inds_duplicate_permutations)
if self.dependencies:
self.dependencies.permutation_invariant = True
inds_duplicate_independency_adjusted = self.dependencies.get_invalid(
@@ -337,6 +347,32 @@ def get_invalid(self, data: pd.DataFrame) -> pd.Index:

return inds_invalid
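The deduplication idea above, including the fix that diagonal points are no longer dropped, can be mimicked in plain Python. Names are illustrative; the library operates on pandas DataFrames and uses ``frozenset``, whereas this sketch uses a multiset key so that rows with repeated labels stay distinct:

```python
from collections import Counter

def permutation_duplicates(rows, permuted, others):
    """Return indices of rows that are permutations of an earlier row.

    Diagonal points (all permuted values equal) are valid rows and are only
    flagged when the exact same point was seen before.
    """
    seen = set()
    duplicates = []
    for i, row in enumerate(rows):
        # A multiset of the permuted values is invariant under permutation;
        # the remaining parameters must match exactly.
        key = (
            frozenset(Counter(row[p] for p in permuted).items()),
            tuple(row[o] for o in others),
        )
        if key in seen:
            duplicates.append(i)
        else:
            seen.add(key)
    return duplicates

rows = [
    {"a": 1, "b": 2, "x": 0},  # kept
    {"a": 2, "b": 1, "x": 0},  # permutation of row 0 -> duplicate
    {"a": 1, "b": 1, "x": 0},  # diagonal point -> kept
    {"a": 1, "b": 2, "x": 5},  # different other parameter -> kept
]
dups = permutation_duplicates(rows, ["a", "b"], ["x"])
# dups == [1]
```

Including the other parameters in the key mirrors the merge in ``get_invalid``: variation outside the permuted group must also be accounted for before a row can be called a duplicate.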

def to_symmetry(self, use_data_augmentation: bool = True) -> PermutationSymmetry:
"""Convert to a :class:`~baybe.symmetries.permutation.PermutationSymmetry`.

The constraint's parameters form the primary permutation group. If
dependencies are attached, their parameters are added as an additional
group that is permuted in lockstep.

Args:
use_data_augmentation: Flag indicating whether the resulting symmetry
object should apply data augmentation. ``True`` means that
measurement augmentation will be performed by generating all
permutations of parameter values within each group.

Returns:
The corresponding permutation symmetry.
"""
from baybe.symmetries.permutation import PermutationSymmetry

groups = [self.parameters]
if self.dependencies:
groups.append(list(self.dependencies.parameters))
return PermutationSymmetry(
permutation_groups=groups,
use_data_augmentation=use_data_augmentation,
)
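The lockstep semantics of multiple permutation groups described in the docstring can be sketched with ``itertools.permutations``. The function name and row layout are illustrative, not the library's augmentation utility:

```python
from itertools import permutations

def augment_lockstep(row, groups):
    """Generate all equivalent rows, applying the same index reordering to
    every group (e.g., permuting substances together with their amounts)."""
    size = len(groups[0])
    assert all(len(group) == size for group in groups), "groups must align"
    out = []
    for order in permutations(range(size)):
        new = dict(row)
        for group in groups:
            for target, source_index in zip(group, order):
                new[target] = row[group[source_index]]
        out.append(new)
    return out

row = {"s1": "A", "s2": "B", "c1": 0.2, "c2": 0.8}
# Permuting the substances also permutes their concentrations in lockstep.
rows = augment_lockstep(row, [["s1", "s2"], ["c1", "c2"]])
```

Permuting only one group while leaving the other fixed would produce rows describing different experiments, which is why dependent parameters are appended as an additional group rather than merged into the first.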


@define
class DiscreteCustomConstraint(DiscreteConstraint):
17 changes: 17 additions & 0 deletions baybe/parameters/base.py
@@ -7,6 +7,7 @@
from functools import cached_property
from typing import TYPE_CHECKING, Any, ClassVar

import attrs
import pandas as pd
from attrs import define, field
from attrs.converters import optional as optional_c
@@ -88,6 +89,22 @@ def to_searchspace(self) -> SearchSpace:

return SearchSpace.from_parameter(self)

def is_equivalent(self, other: Parameter) -> bool:
"""Check if this parameter is equivalent to another, ignoring the name.

Two parameters are considered equivalent if they have the same type and
all attributes are equal except for the name.

Args:
other: The parameter to compare against.

Returns:
``True`` if the parameters are equivalent, ``False`` otherwise.
"""
if type(self) is not type(other):
return False
return attrs.evolve(self, name=other.name) == other
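The ``attrs.evolve`` trick above has a stdlib analogue, ``dataclasses.replace``, which illustrates the semantics without BayBE (the ``Param`` class is a toy stand-in, not the real parameter hierarchy):

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class Param:
    name: str
    values: tuple

def is_equivalent(a: Param, b: Param) -> bool:
    """Same concrete type and all fields equal, ignoring the name."""
    if type(a) is not type(b):
        return False
    # Renaming `a` to `b`'s name makes the equality comparison name-agnostic.
    return replace(a, name=b.name) == b

p1 = Param("temperature", (10, 20))
p2 = Param("temp", (10, 20))
p3 = Param("temp", (10, 30))
# is_equivalent(p1, p2) -> True; is_equivalent(p1, p3) -> False
```

Evolving a copy rather than comparing fields one by one keeps the check correct for subclasses with extra attributes, since the generated equality covers all fields.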

@abstractmethod
def summary(self) -> dict:
"""Return a custom summarization of the parameter."""
14 changes: 5 additions & 9 deletions baybe/parameters/categorical.py
@@ -13,13 +13,7 @@
from baybe.parameters.enum import CategoricalEncoding
from baybe.parameters.validation import validate_unique_values
from baybe.settings import active_settings
from baybe.utils.conversion import nonstring_to_tuple


def _convert_values(value, self, field) -> tuple[str, ...]:
"""Sort and convert values for categorical parameters."""
value = nonstring_to_tuple(value, self, field)
return tuple(sorted(value, key=lambda x: (str(type(x)), x)))
from baybe.utils.conversion import normalize_convertible2str_sequence


def _validate_label_min_len(self, attr, value) -> None:
@@ -38,8 +32,10 @@ class CategoricalParameter(_DiscreteLabelLikeParameter):
# object variables
_values: tuple[str | bool, ...] = field(
alias="values",
converter=Converter(_convert_values, takes_self=True, takes_field=True), # type: ignore
validator=(
converter=Converter( # type: ignore[misc,call-overload] # mypy: Converter
normalize_convertible2str_sequence, takes_self=True, takes_field=True
),
validator=( # type: ignore[arg-type] # mypy: validator tuple
validate_unique_values,
deep_iterable(
member_validator=(instance_of((str, bool)), _validate_label_min_len),
3 changes: 2 additions & 1 deletion baybe/parameters/numerical.py
@@ -13,9 +13,10 @@

from baybe.exceptions import NumericalUnderflowError
from baybe.parameters.base import ContinuousParameter, DiscreteParameter
from baybe.parameters.validation import validate_is_finite, validate_unique_values
from baybe.parameters.validation import validate_unique_values
from baybe.settings import active_settings
from baybe.utils.interval import InfiniteIntervalError, Interval
from baybe.utils.validation import validate_is_finite


@define(frozen=True, slots=False)
17 changes: 0 additions & 17 deletions baybe/parameters/validation.py
@@ -1,9 +1,7 @@
"""Validation functionality for parameters."""

from collections.abc import Sequence
from typing import Any

import numpy as np
from attrs.validators import gt, instance_of, lt


@@ -28,18 +26,3 @@ def validate_decorrelation(obj: Any, attribute: Any, value: float) -> None:
if isinstance(value, float):
gt(0.0)(obj, attribute, value)
lt(1.0)(obj, attribute, value)


def validate_is_finite( # noqa: DOC101, DOC103
obj: Any, _: Any, value: Sequence[float]
) -> None:
"""Validate that ``value`` contains no infinity/nan.

Raises:
ValueError: If ``value`` contains infinity/nan.
"""
if not all(np.isfinite(value)):
raise ValueError(
f"Cannot assign the following values containing infinity/nan to "
f"parameter {obj.name}: {value}."
)
30 changes: 30 additions & 0 deletions baybe/recommenders/pure/bayesian/base.py
@@ -86,6 +86,22 @@ def _get_acquisition_function(self, objective: Objective) -> AcquisitionFunction
return qLogNEHVI() if objective.is_multi_output else qLogEI()
return self.acquisition_function

def _get_surrogate_for_augmentation(self) -> Surrogate | None:
"""Get the Surrogate instance for augmentation/validation, if available."""
from baybe.surrogates.composite import CompositeSurrogate, _ReplicationMapping

model = self._surrogate_model
if isinstance(model, Surrogate):
return model
if isinstance(model, CompositeSurrogate):
# All inner surrogates are copies of the same template
surrogates = model.surrogates
if isinstance(surrogates, _ReplicationMapping):
template = surrogates.template
if isinstance(template, Surrogate):
return template
return None
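The unwrapping logic can be illustrated with toy stand-ins (class names simplified from the diff's ``CompositeSurrogate``/``_ReplicationMapping``; these are not the actual BayBE classes):

```python
class Surrogate:
    """Toy stand-in for a single surrogate model."""

class ReplicationMapping:
    """Toy stand-in: one template surrogate replicated per target."""
    def __init__(self, template):
        self.template = template

class CompositeSurrogate:
    """Toy stand-in: holds per-target surrogates."""
    def __init__(self, surrogates):
        self.surrogates = surrogates

def surrogate_for_augmentation(model):
    """Unwrap the single underlying surrogate, if one exists.

    Heterogeneous composites (genuinely different per-target surrogates)
    yield None, so no augmentation is applied for them.
    """
    if isinstance(model, Surrogate):
        return model
    if isinstance(model, CompositeSurrogate):
        inner = model.surrogates
        # Only a replicated template is safely a single surrogate.
        if isinstance(inner, ReplicationMapping) and isinstance(
            inner.template, Surrogate
        ):
            return inner.template
    return None

single = Surrogate()
replicated = CompositeSurrogate(ReplicationMapping(single))
heterogeneous = CompositeSurrogate({"t1": Surrogate(), "t2": Surrogate()})
```

Returning the shared template is sound because every replicated copy carries the same symmetry configuration; for heterogeneous composites there is no single source of truth, hence the ``None``.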
Collaborator:

If I remember our logic correctly, we allow users to implement whatever surrogates they want as long as they implement the SurrogateProtocol. In particular, we do not force them to implement the Surrogate class or inherit from it. This means that this function returns None, hence no augmentation will be done. However, I do not think the user is informed at any point that no augmentation has been done in that case. Also, I think the same happens for a CompositeSurrogate that does not have a _ReplicationMapping but actually different surrogates.

Three questions:

  1. Do I understand this logic here correctly?
  2. Do we already have some sort of safeguard/warning informing the user that no augmentation will be done in this case?
  3. If not, can/do we want to add one?

Collaborator (Author):

Hmm, in the case you describe there is also no moment when the user ever assigned any symmetries, because their custom surrogate doesn't have this attribute, so why would there be an expectation of applied symmetries that has to be warned about?


def get_surrogate(
self,
searchspace: SearchSpace,
@@ -114,6 +130,13 @@ def _setup_botorch_acqf(
f"{len(objective.targets)}-target multi-output context."
)

# Perform data augmentation if configured
surrogate_for_augmentation = self._get_surrogate_for_augmentation()
if surrogate_for_augmentation is not None:
measurements = surrogate_for_augmentation.augment_measurements(
measurements, searchspace.parameters
)

surrogate = self.get_surrogate(searchspace, objective, measurements)
self._botorch_acqf = acqf.to_botorch(
surrogate,
@@ -156,6 +179,13 @@ def recommend(

validate_object_names(searchspace.parameters + objective.targets)

# Validate compatibility of surrogate symmetries with searchspace
surrogate_for_validation = self._get_surrogate_for_augmentation()
if surrogate_for_validation is not None:
for s in surrogate_for_validation.symmetries:
s.validate_searchspace_context(searchspace)
Collaborator (Author):

Important: Validation is so far only part of the recommend call here in the recommenders; it has not been included in the Campaign yet. This is due to two factors:

  • To properly validate the compatibility of symmetries and searchspace, there needs to be a mechanism that can iterate over all possible recommenders of a meta-recommender. Otherwise, this upfront validation already fails for the two-phase recommender if the second recommender has symmetries.
  • There would be double validation between the campaign and the recommend call, so the context info of whether validation was already performed needs to be passed somewhere. This is likely fixable with the settings mechanism, which is not yet available.

Collaborator (Author):

@AdrianSosic I see now that the 2nd point could be solved with the Settings mechanism, but I have no idea how to solve issue 1.

In the absence of that, it's not really possible to turn it into an upfront validation, so I would probably not change the validation for the moment unless you have a smarter idea.

Collaborator:

+1 for being pragmatic and not trying to come up with something potentially convoluted right now. Even if we find a better way for the validation later, including it is just a plain improvement without negative consequences to users, so we can add it later without problems.


# Experimental input validation
if (measurements is None) or measurements.empty:
raise NotImplementedError(
f"Recommenders of type '{BayesianRecommender.__name__}' do not support "
5 changes: 0 additions & 5 deletions baybe/searchspace/core.py
@@ -381,11 +381,6 @@ def transform(

return comp_rep

@property
def constraints_augmentable(self) -> tuple[Constraint, ...]:
"""The searchspace constraints that can be considered during augmentation."""
return tuple(c for c in self.constraints if c.eval_during_augmentation)

def get_parameters_by_name(self, names: Sequence[str]) -> tuple[Parameter, ...]:
"""Return parameters with the specified names.
