diff --git a/CHANGELOG.md b/CHANGELOG.md index 3ab1fa1c52..3da863c0f6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,8 +13,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `identify_non_dominated_configurations` method to `Campaign` and `Objective` for determining the Pareto front - Interpoint constraints for continuous search spaces +- `has_polars_implementation` property on `DiscreteConstraint` +- `allow_missing` flag on `DiscreteConstraint.get_invalid` and `get_valid` + +### Changed +- Discrete search space construction now applies constraints incrementally during + Cartesian product building, significantly reducing memory usage and construction + time for constrained spaces +- Polars path in discrete search space construction now builds the Cartesian product + only for parameters involved in Polars-capable constraints, merging the rest + incrementally via pandas ### Breaking Changes +- `parameter_cartesian_prod_pandas` and `parameter_cartesian_prod_polars` moved + from `baybe.searchspace.discrete` to `baybe.searchspace.utils` - `ContinuousLinearConstraint.to_botorch` now returns a collection of constraint tuples instead of a single tuple (needed for interpoint constraints) diff --git a/baybe/constraints/base.py b/baybe/constraints/base.py index 5c1a6d33ed..e1fa0e47d3 100644 --- a/baybe/constraints/base.py +++ b/baybe/constraints/base.py @@ -20,6 +20,7 @@ from baybe.serialization.core import ( converter, ) +from baybe.utils.basic import classproperty if TYPE_CHECKING: import polars as pl @@ -81,6 +82,17 @@ def is_discrete(self) -> bool: """Boolean indicating if this is a constraint over discrete parameters.""" return isinstance(self, DiscreteConstraint) + @property + def _required_parameters(self) -> set[str]: + """All parameter names needed for full constraint evaluation. + + For most constraints, this is simply the set of names from + :attr:`~baybe.constraints.base.Constraint.parameters`. 
+ Constraints with additional parameter references (e.g., affected + parameters in dependency constraints) override this to include those. + """ + return set(self.parameters) + @define class DiscreteConstraint(Constraint, ABC): @@ -97,29 +109,76 @@ class DiscreteConstraint(Constraint, ABC): eval_during_modeling: ClassVar[bool] = False # See base class. - def get_valid(self, df: pd.DataFrame, /) -> pd.Index: + def get_valid( + self, df: pd.DataFrame, /, *, allow_missing: bool = False + ) -> pd.Index: """Get the indices of dataframe entries that are valid under the constraint. Args: df: A dataframe where each row represents a parameter configuration. + allow_missing: If ``False``, a :class:`ValueError` is raised when + the dataframe is missing required parameter columns. If + ``True``, the constraint performs partial filtering on the + available columns. Returns: The dataframe indices of rows that fulfill the constraint. """ - invalid = self.get_invalid(df) + invalid = self.get_invalid(df, allow_missing=allow_missing) return df.index.drop(invalid) - @abstractmethod - def get_invalid(self, data: pd.DataFrame) -> pd.Index: + def get_invalid( + self, data: pd.DataFrame, /, *, allow_missing: bool = False + ) -> pd.Index: """Get the indices of dataframe entries that are invalid under the constraint. Args: - data: A dataframe where each row represents a parameter configuration. + data: A dataframe where each row represents a parameter + configuration. + allow_missing: If ``False``, a :class:`ValueError` is raised when + the dataframe is missing required parameter columns. If + ``True``, the constraint performs partial filtering on the + available columns, returning an empty index when insufficient + columns are present. + + Raises: + ValueError: If ``allow_missing`` is ``False`` and the dataframe + is missing required parameter columns. Returns: The dataframe indices of rows that violate the constraint. """ # TODO: Should switch backends (pandas/polars/...) 
behind the scenes + if not allow_missing: + if missing := self._required_parameters - set(data.columns): + raise ValueError( + f"'{self.__class__.__name__}' requires columns {missing} " + f"which are missing from the dataframe." + ) + return self._get_invalid(data) + + @abstractmethod + def _get_invalid(self, data: pd.DataFrame) -> pd.Index: + """Get the indices of invalid entries (implementation for subclasses). + + Subclasses implement this method with their specific filtering logic. + When the dataframe contains only a subset of the constraint's + parameters, implementations should return an empty index if they + cannot perform useful filtering. + + Args: + data: A dataframe where each row represents a parameter + configuration. May contain all or a subset of the constraint's + parameters. + + Returns: + The dataframe indices of rows that violate the constraint. + """ + + @classproperty + def has_polars_implementation(cls) -> bool: + """Whether this constraint class has a Polars implementation.""" + return cls.get_invalid_polars is not DiscreteConstraint.get_invalid_polars def get_invalid_polars(self) -> pl.Expr: """Translate the constraint to Polars expression identifying undesired rows. 
diff --git a/baybe/constraints/discrete.py b/baybe/constraints/discrete.py index 740e603f89..c66d4f8f83 100644 --- a/baybe/constraints/discrete.py +++ b/baybe/constraints/discrete.py @@ -42,12 +42,19 @@ class DiscreteExcludeConstraint(DiscreteConstraint): """Operator encoding how to combine the individual conditions.""" @override - def get_invalid(self, data: pd.DataFrame) -> pd.Index: - satisfied = [ - cond.evaluate(data[self.parameters[k]]) - for k, cond in enumerate(self.conditions) - ] + def _get_invalid(self, data: pd.DataFrame) -> pd.Index: + pairs = [(p, c) for p, c in zip(self.parameters, self.conditions) if p in data] + if not pairs: + return pd.Index([]) + + # Only the OR combiner supports incremental filtering: a single + # true condition is sufficient to mark a row as invalid. + if self.combiner != "OR" and len(pairs) < len(self.parameters): + return pd.Index([]) + + satisfied = [cond.evaluate(data[p]) for p, cond in pairs] res = reduce(_valid_logic_combiners[self.combiner], satisfied) + return data.index[res] @override @@ -78,7 +85,13 @@ class DiscreteSumConstraint(DiscreteConstraint): """The condition modeled by this constraint.""" @override - def get_invalid(self, data: pd.DataFrame) -> pd.Index: + def _get_invalid(self, data: pd.DataFrame) -> pd.Index: + # IMPROVE: Look-ahead filtering would be possible if parameter + # value ranges (min/max) were available to the constraint, allowing + # bound-based pruning of partial sums before all parameters are + # present. 
+ if not set(self.parameters) <= set(data.columns): + return pd.Index([]) evaluate_data = data[self.parameters].sum(axis=1) mask_bad = ~self.condition.evaluate(evaluate_data) @@ -106,7 +119,13 @@ class DiscreteProductConstraint(DiscreteConstraint): """The condition that is used for this constraint.""" @override - def get_invalid(self, data: pd.DataFrame) -> pd.Index: + def _get_invalid(self, data: pd.DataFrame) -> pd.Index: + # IMPROVE: Look-ahead filtering would be possible if parameter + # value ranges (min/max) were available to the constraint, allowing + # bound-based pruning of partial products before all parameters are + # present. + if not set(self.parameters) <= set(data.columns): + return pd.Index([]) evaluate_data = data[self.parameters].prod(axis=1) mask_bad = ~self.condition.evaluate(evaluate_data) @@ -140,8 +159,11 @@ class DiscreteNoLabelDuplicatesConstraint(DiscreteConstraint): """ @override - def get_invalid(self, data: pd.DataFrame) -> pd.Index: - mask_bad = data[self.parameters].nunique(axis=1) != len(self.parameters) + def _get_invalid(self, data: pd.DataFrame) -> pd.Index: + params = [p for p in self.parameters if p in data] + if len(params) < 2: + return pd.Index([]) + mask_bad = data[params].nunique(axis=1) != len(params) return data.index[mask_bad] @@ -158,6 +180,7 @@ def get_invalid_polars(self) -> pl.Expr: return expr +@define class DiscreteLinkedParametersConstraint(DiscreteConstraint): """Constraint class for linking the values of parameters. 
@@ -168,8 +191,11 @@ class DiscreteLinkedParametersConstraint(DiscreteConstraint): """ @override - def get_invalid(self, data: pd.DataFrame) -> pd.Index: - mask_bad = data[self.parameters].nunique(axis=1) != 1 + def _get_invalid(self, data: pd.DataFrame) -> pd.Index: + params = [p for p in self.parameters if p in set(data.columns)] + if len(params) < 2: + return pd.Index([]) + mask_bad = data[params].nunique(axis=1) != 1 return data.index[mask_bad] @@ -228,8 +254,19 @@ def _validate_affected_parameters( # noqa: DOC101, DOC103 f"the conditions list." ) + @property + @override + def _required_parameters(self) -> set[str]: + """See base class.""" + params = set(self.parameters) + for group in self.affected_parameters: + params.update(group) + return params + @override - def get_invalid(self, data: pd.DataFrame) -> pd.Index: + def _get_invalid(self, data: pd.DataFrame) -> pd.Index: + if not self._required_parameters <= set(data.columns): + return pd.Index([]) # Create data copy and mark entries where the dependency conditions are negative # with a dummy value to cause degeneracy. censored_data = data.copy() @@ -293,28 +330,45 @@ class DiscretePermutationInvarianceConstraint(DiscreteConstraint): dependencies: DiscreteDependenciesConstraint | None = field(default=None) """Dependencies connected with the invariant parameters.""" + @property @override - def get_invalid(self, data: pd.DataFrame) -> pd.Index: + def _required_parameters(self) -> set[str]: + """See base class.""" + params = set(self.parameters) + if self.dependencies: + params.update(self.dependencies._required_parameters) + return params + + @override + def _get_invalid(self, data: pd.DataFrame) -> pd.Index: + cols = set(data.columns) + params = [p for p in self.parameters if p in cols] + if len(params) < 2: + return pd.Index([]) + # When dependencies exist, permutation dedup on a partial set of + # parameters is not safe because the dependency logic can change + # which permutations are equivalent. 
In this case, only the + # label-dedup part (which is always safe incrementally) is applied. + if self.dependencies: + if not self._required_parameters <= cols: + return DiscreteNoLabelDuplicatesConstraint( + parameters=params + ).get_invalid(data) + # Get indices of entries with duplicate label entries. These will also be # dropped by this constraint. mask_duplicate_labels = pd.Series(False, index=data.index) mask_duplicate_labels[ - DiscreteNoLabelDuplicatesConstraint(parameters=self.parameters).get_invalid( - data - ) + DiscreteNoLabelDuplicatesConstraint(parameters=params).get_invalid(data) ] = True # Merge a permutation invariant representation of all affected parameters with # the other parameters and indicate duplicates. This ensures that variation in # other parameters is also accounted for. - other_params = data.columns.drop(self.parameters).tolist() - df_eval = pd.concat( - [ - data[other_params].copy(), - data[self.parameters].apply(cast(Callable, frozenset), axis=1), - ], - axis=1, - ).loc[ + other_params = data.columns.drop(params).tolist() + frozen = data[params].apply(cast(Callable, frozenset), axis=1) + parts = [data[other_params].copy(), frozen] if other_params else [frozen] + df_eval = pd.concat(parts, axis=1).loc[ ~mask_duplicate_labels # only consider label-duplicate-free part ] mask_duplicate_permutations = df_eval.duplicated(keep="first") @@ -349,7 +403,9 @@ class DiscreteCustomConstraint(DiscreteConstraint): you want to keep/remove.""" @override - def get_invalid(self, data: pd.DataFrame) -> pd.Index: + def _get_invalid(self, data: pd.DataFrame) -> pd.Index: + if not set(self.parameters) <= set(data.columns): + return pd.Index([]) mask_bad = ~self.validator(data[self.parameters]) return data.index[mask_bad] @@ -364,10 +420,21 @@ class DiscreteCardinalityConstraint(CardinalityConstraint, DiscreteConstraint): # See base class. 
@override - def get_invalid(self, data: pd.DataFrame) -> pd.Index: - non_zeros = (data[self.parameters] != 0.0).sum(axis=1) + def _get_invalid(self, data: pd.DataFrame) -> pd.Index: + cols = set(data.columns) + params = [p for p in self.parameters if p in cols] + if not params: + return pd.Index([]) + all_present = len(params) == len(self.parameters) + + non_zeros = (data[params] != 0.0).sum(axis=1) + # The max_cardinality check is safe on a partial subset: the nonzero + # count can only increase as more parameters are added. mask_bad = non_zeros > self.max_cardinality - mask_bad |= non_zeros < self.min_cardinality + # The min_cardinality check can only be applied when all parameters + # are present, since missing parameters could still add nonzero values. + if all_present: + mask_bad |= non_zeros < self.min_cardinality return data.index[mask_bad] diff --git a/baybe/constraints/validation.py b/baybe/constraints/validation.py index 51a1a7a918..bd2bc7a89f 100644 --- a/baybe/constraints/validation.py +++ b/baybe/constraints/validation.py @@ -54,29 +54,30 @@ def validate_constraints( # noqa: DOC101, DOC103 ] for constraint in constraints: - if not all(p in param_names_all for p in constraint.parameters): + if not all(p in param_names_all for p in constraint._required_parameters): raise ValueError( f"You are trying to create a constraint with at least one parameter " f"name that does not exist in the list of defined parameters. " - f"Parameter list of the affected constraint: {constraint.parameters}" + f"Parameter list of the affected constraint: " + f"{constraint._required_parameters}" ) if constraint.is_continuous and any( - p in param_names_discrete for p in constraint.parameters + p in param_names_discrete for p in constraint._required_parameters ): raise ValueError( f"You are trying to initialize a continuous constraint over a " f"parameter that is discrete. 
Parameter list of the affected " - f"constraint: {constraint.parameters}" + f"constraint: {constraint._required_parameters}" ) if constraint.is_discrete and any( - p in param_names_continuous for p in constraint.parameters + p in param_names_continuous for p in constraint._required_parameters ): raise ValueError( f"You are trying to initialize a discrete constraint over a parameter " f"that is continuous. Parameter list of the affected constraint: " - f"{constraint.parameters}" + f"{constraint._required_parameters}" ) if constraint.numerical_only and any( diff --git a/baybe/searchspace/discrete.py b/baybe/searchspace/discrete.py index efae2cfc6b..9046812775 100644 --- a/baybe/searchspace/discrete.py +++ b/baybe/searchspace/discrete.py @@ -3,8 +3,8 @@ from __future__ import annotations import gc +import warnings from collections.abc import Collection, Sequence -from itertools import compress from math import prod from typing import TYPE_CHECKING, Any @@ -24,6 +24,10 @@ ) from baybe.parameters.base import DiscreteParameter from baybe.parameters.utils import get_parameters_from_dataframe, sort_parameters +from baybe.searchspace.utils import ( + parameter_cartesian_prod_pandas_constrained, + parameter_cartesian_prod_polars, +) from baybe.searchspace.validation import validate_parameter_names, validate_parameters from baybe.serialization import SerialMixin, converter, select_constructor_hook from baybe.settings import active_settings @@ -189,18 +193,37 @@ def from_product( ) if active_settings.use_polars_for_constraints: - lazy_df = parameter_cartesian_prod_polars(parameters) - lazy_df, mask_missing = _apply_constraint_filter_polars( - lazy_df, constraints + # Partition constraints by Polars support + polars_constraints = [c for c in constraints if c.has_polars_implementation] + pandas_constraints = [ + c for c in constraints if not c.has_polars_implementation + ] + + # Determine which parameters are needed by Polars-capable constraints + polars_param_names: set[str] = 
set() +        for c in polars_constraints: +            polars_param_names.update(c._required_parameters) +        polars_params = [p for p in parameters if p.name in polars_param_names] +        remaining_params = [ +            p for p in parameters if p.name not in polars_param_names +        ] + +        if polars_params: +            # Build Polars product only for relevant parameters and filter +            lazy_df = parameter_cartesian_prod_polars(polars_params) +            lazy_df, _ = _apply_constraint_filter_polars( +                lazy_df, polars_constraints +            ) +            initial_df = lazy_df.collect().to_pandas() +        else: +            initial_df = None + +        # Merge remaining parameters with pandas-only constraint filtering +        df = parameter_cartesian_prod_pandas_constrained( +            remaining_params, pandas_constraints, initial_df=initial_df +        ) -        df_records = lazy_df.collect().to_dicts() -        df = pd.DataFrame.from_records(df_records) else: -        df = parameter_cartesian_prod_pandas(parameters) -        mask_missing = [True] * len(constraints) - -    # Gather and use constraints not yet applied -    _apply_constraint_filter_pandas(df, list(compress(constraints, mask_missing))) +        df = parameter_cartesian_prod_pandas_constrained(parameters, constraints) return SubspaceDiscrete( parameters=parameters, @@ -358,10 +381,16 @@ def from_simplex( f"parameters: {overlap}." ) -    # Construct the product part of the space -    product_space = parameter_cartesian_prod_pandas(product_parameters) -    if not simplex_parameters: -        return cls(parameters=product_parameters, exp_rep=product_space) +    # Handle degenerate simplex cases +    if len(simplex_parameters) < 2: +        warnings.warn( +            f"'{cls.from_simplex.__name__}' was called with less than 2 " +            f"simplex parameters, so smart simplex construction has no effect. "
+ f"Consider using '{cls.from_product.__name__}' instead.", + UserWarning, + ) + if len(simplex_parameters) < 1: + return cls.from_product(product_parameters, constraints) # Validate non-negativity min_values = [min(p.values) for p in simplex_parameters] @@ -471,12 +500,10 @@ def drop_invalid( if boundary_only: drop_invalid(exp_rep, max_sum, boundary_only=True) - # Augment the Cartesian product created from all other parameter types - if product_parameters: - exp_rep = pd.merge(exp_rep, product_space, how="cross") - - # Remove entries that violate parameter constraints: - _apply_constraint_filter_pandas(exp_rep, constraints) + # Merge product parameters and apply constraints incrementally + exp_rep = parameter_cartesian_prod_pandas_constrained( + product_parameters, constraints, initial_df=exp_rep + ) return cls( parameters=[*simplex_parameters, *product_parameters], @@ -688,59 +715,6 @@ def _apply_constraint_filter_polars( return ldf, mask_missing -def parameter_cartesian_prod_polars( - parameters: Sequence[DiscreteParameter], -) -> pl.LazyFrame: - """Create the Cartesian product of discrete parameter values using Polars. - - Args: - parameters: List of discrete parameter objects. - - Returns: - A lazy dataframe containing all possible discrete parameter value combinations. - """ - from baybe._optional.polars import polars as pl - - if not parameters: - return pl.LazyFrame() - - # Convert each parameter to a lazy dataframe for cross-join operation - param_frames = [pl.LazyFrame({p.name: p.active_values}) for p in parameters] - - # Handling edge cases - if len(param_frames) == 1: - return param_frames[0] - - # Cross-join parameters - res = param_frames[0] - for frame in param_frames[1:]: - res = res.join(frame, how="cross", force_parallel=True) - - return res - - -def parameter_cartesian_prod_pandas( - parameters: Sequence[DiscreteParameter], -) -> pd.DataFrame: - """Create the Cartesian product of discrete parameter values using Pandas. 
- - Args: - parameters: List of discrete parameter objects. - - Returns: - A dataframe containing all possible discrete parameter value combinations. - """ - if not parameters: - return pd.DataFrame() - - index = pd.MultiIndex.from_product( - [p.active_values for p in parameters], names=[p.name for p in parameters] - ) - ret = pd.DataFrame(index=index).reset_index() - - return ret - - def validate_simplex_subspace_from_config(specs: dict, _) -> None: """Validate the discrete space while skipping costly creation steps.""" # Validate product inputs without constructing it diff --git a/baybe/searchspace/utils.py b/baybe/searchspace/utils.py new file mode 100644 index 0000000000..9a6ebbb4dd --- /dev/null +++ b/baybe/searchspace/utils.py @@ -0,0 +1,219 @@ +"""Utilities for search space construction.""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING + +import pandas as pd + +from baybe.constraints.base import DiscreteConstraint +from baybe.parameters.base import DiscreteParameter + +if TYPE_CHECKING: + import polars as pl + + +def optimize_parameter_order( + parameters: Sequence[DiscreteParameter], + constraints: Sequence[DiscreteConstraint], +) -> list[DiscreteParameter]: + """Determine a heuristic parameter ordering for incremental space construction. + + Parameters involved in constraints are placed first, ordered so that the + parameters completing the most constraints come earliest. Parameters not + involved in any constraint are placed last. + + Args: + parameters: The discrete parameters. + constraints: The discrete constraints. + + Returns: + The parameters in an order optimized for incremental constraint + filtering. 
+ """ + if not constraints: + return list(parameters) + + # Compute which parameter names each constraint needs + constraint_params = [c._required_parameters for c in constraints] + + # Separate constrained from unconstrained parameters + all_constrained_names = set().union(*constraint_params) + constrained = [p for p in parameters if p.name in all_constrained_names] + unconstrained = [p for p in parameters if p.name not in all_constrained_names] + + # Greedy ordering: at each step, pick the parameter whose addition + # completes (is the last missing parameter for) the most constraints. + # Ties are broken by picking the parameter with fewest active values + # (smallest expansion factor during cross-merging). + ordered: list[DiscreteParameter] = [] + available: set[str] = set() + remaining = list(constrained) + + while remaining: + best_param = None + best_score = (-1, float("inf")) # (completions, -active_values) + + for param in remaining: + candidate_available = available | {param.name} + completions = sum( + 1 + for cp in constraint_params + if cp <= candidate_available and not cp <= available + ) + n_values = len(param.active_values) + score = (completions, -n_values) + if score > best_score: + best_score = score + best_param = param + + assert best_param is not None + ordered.append(best_param) + available.add(best_param.name) + remaining.remove(best_param) + + # Unconstrained parameters go last + ordered.extend(unconstrained) + return ordered + + +def parameter_cartesian_prod_polars( + parameters: Sequence[DiscreteParameter], +) -> pl.LazyFrame: + """Create the Cartesian product of discrete parameter values using Polars. + + Args: + parameters: List of discrete parameter objects. + + Returns: + A lazy dataframe containing all possible discrete parameter value combinations. 
+ """ + from baybe._optional.polars import polars as pl + + if not parameters: + return pl.LazyFrame() + + # Convert each parameter to a lazy dataframe for cross-join operation + param_frames = [pl.LazyFrame({p.name: p.active_values}) for p in parameters] + + # Handling edge cases + if len(param_frames) == 1: + return param_frames[0] + + # Cross-join parameters + res = param_frames[0] + for frame in param_frames[1:]: + res = res.join(frame, how="cross", force_parallel=True) + + return res + + +def parameter_cartesian_prod_pandas( + parameters: Sequence[DiscreteParameter], +) -> pd.DataFrame: + """Create the Cartesian product of discrete parameter values using Pandas. + + Args: + parameters: List of discrete parameter objects. + + Returns: + A dataframe containing all possible discrete parameter value combinations. + """ + if not parameters: + return pd.DataFrame() + + index = pd.MultiIndex.from_product( + [p.active_values for p in parameters], names=[p.name for p in parameters] + ) + ret = pd.DataFrame(index=index).reset_index() + + return ret + + +def parameter_cartesian_prod_pandas_constrained( + parameters: Sequence[DiscreteParameter], + constraints: Sequence[DiscreteConstraint], + initial_df: pd.DataFrame | None = None, +) -> pd.DataFrame: + """Build a Cartesian product of parameters with incremental constraint filtering. + + Instead of creating the full Cartesian product and then filtering, this + function cross-merges parameters one by one, applying constraint filters + as early as possible. This significantly reduces memory usage and + construction time for highly constrained spaces. + + Parameters are ordered so that constrained parameters come first, enabling + constraints to fire early when the intermediate dataframe is still small. + + Args: + parameters: The discrete parameters to combine. + constraints: The discrete constraints to apply. + initial_df: An optional starting dataframe. 
When provided, the given +            parameters are cross-merged into it (its columns count as already +            available for constraint evaluation). + +    Returns: +        A dataframe containing all valid parameter combinations. +    """ +    # Filter to constraints that should be applied during creation +    filtering_constraints = [c for c in constraints if c.eval_during_creation] + +    # Fast path: no constraints and no initial dataframe +    if not filtering_constraints and initial_df is None: +        return parameter_cartesian_prod_pandas(parameters) + +    # Compute optimal parameter order +    ordered_params = optimize_parameter_order(parameters, filtering_constraints) + +    # Determine which parameter names each constraint needs for completion +    pending: list[tuple[DiscreteConstraint, set[str]]] = [ +        (c, c._required_parameters) for c in filtering_constraints +    ] + +    # Initialize the dataframe +    if initial_df is not None: +        df = initial_df +    else: +        df = pd.DataFrame() + +    # Original column order for final reindexing +    original_columns = (list(initial_df.columns) if initial_df is not None else []) + [ +        p.name for p in parameters +    ] + +    # Incremental cross-merge loop +    for param in ordered_params: +        param_df = pd.DataFrame({param.name: param.active_values}) +        if df.shape[1] == 0: +            df = param_df +        else: +            df = pd.merge(df, param_df, how="cross") + +        available = set(df.columns) +        still_pending: list[tuple[DiscreteConstraint, set[str]]] = [] + +        for constraint, all_params in pending: +            idxs = constraint.get_invalid(df, allow_missing=True) +            df.drop(index=idxs, inplace=True) + +            if not (all_params <= available): +                still_pending.append((constraint, all_params)) + +        pending = still_pending + +    # Apply any remaining constraints whose parameters were already present +    # in the initial_df (i.e., no new parameters were needed to complete them) +    if pending and not df.empty: +        available = set(df.columns) +        for constraint, all_params in pending: +            if all_params <= available: +                idxs = constraint.get_invalid(df) + 
df.drop(index=idxs, inplace=True) + + # Reorder columns and reset index + if original_columns: + df = df[original_columns] + df.reset_index(drop=True, inplace=True) + + return df diff --git a/docs/conf.py b/docs/conf.py index 1927455f49..ce8ef93e2a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -139,6 +139,7 @@ (r"py:obj", "baybe.acquisition.base.*.supports_multi_output"), (r"py:obj", "baybe.acquisition.base.*.is_analytic"), (r"py:obj", "baybe.surrogates.*.is_available"), + (r"py:obj", r"baybe.constraints.*.has_polars_implementation"), # KMedoids (r"py:.*", r".*clustering_algorithms.*KMedoids.*"), (r"ref:.*", r".*clustering_algorithms.*KMedoids.*"), diff --git a/tests/constraints/test_constrained_cartesian_product.py b/tests/constraints/test_constrained_cartesian_product.py new file mode 100644 index 0000000000..9786c6b32e --- /dev/null +++ b/tests/constraints/test_constrained_cartesian_product.py @@ -0,0 +1,245 @@ +"""Tests comparing naive vs incremental constrained Cartesian product construction.""" + +from __future__ import annotations + +from collections.abc import Sequence +from functools import partial + +import numpy as np +import pytest +from pandas.testing import assert_frame_equal + +from baybe.constraints import ( + DiscreteCardinalityConstraint, + DiscreteDependenciesConstraint, + DiscreteExcludeConstraint, + DiscreteLinkedParametersConstraint, + DiscreteNoLabelDuplicatesConstraint, + DiscretePermutationInvarianceConstraint, + DiscreteSumConstraint, + SubSelectionCondition, + ThresholdCondition, +) +from baybe.constraints.base import DiscreteConstraint +from baybe.parameters import CategoricalParameter, NumericalDiscreteParameter +from baybe.parameters.base import DiscreteParameter +from baybe.searchspace.discrete import _apply_constraint_filter_pandas +from baybe.searchspace.utils import ( + parameter_cartesian_prod_pandas, + parameter_cartesian_prod_pandas_constrained, +) + + +def _no_constraints_scenario() -> tuple[ + Sequence[DiscreteParameter], 
Sequence[DiscreteConstraint] +]: + params = [ + CategoricalParameter(name="A", values=["a1", "a2"]), + CategoricalParameter(name="B", values=["b1", "b2", "b3"]), + ] + return params, [] + + +def _no_label_duplicates_scenario() -> tuple[ + Sequence[DiscreteParameter], Sequence[DiscreteConstraint] +]: + values = ["x", "y", "z", "w"] + params = [CategoricalParameter(name=f"P{i}", values=values) for i in range(4)] + constraints = [ + DiscreteNoLabelDuplicatesConstraint(parameters=[p.name for p in params]) + ] + return params, constraints + + +def _linked_parameters_scenario() -> tuple[ + Sequence[DiscreteParameter], Sequence[DiscreteConstraint] +]: + values = ["a", "b", "c"] + params = [CategoricalParameter(name=f"P{i}", values=values) for i in range(3)] + constraints = [ + DiscreteLinkedParametersConstraint(parameters=[p.name for p in params]) + ] + return params, constraints + + +def _exclude_scenario( + combiner: str, +) -> tuple[Sequence[DiscreteParameter], Sequence[DiscreteConstraint]]: + params = [ + CategoricalParameter(name="A", values=["a1", "a2", "a3"]), + CategoricalParameter(name="B", values=["b1", "b2", "b3"]), + CategoricalParameter(name="C", values=["c1", "c2", "c3"]), + ] + constraints = [ + DiscreteExcludeConstraint( + parameters=["A", "B"], + conditions=[ + SubSelectionCondition(selection=["a1"]), + SubSelectionCondition(selection=["b1"]), + ], + combiner=combiner, + ) + ] + return params, constraints + + +def _cardinality_scenario() -> tuple[ + Sequence[DiscreteParameter], Sequence[DiscreteConstraint] +]: + params = [ + NumericalDiscreteParameter(name=f"P{i}", values=[0.0, 1.0, 2.0]) + for i in range(4) + ] + constraints = [ + DiscreteCardinalityConstraint( + parameters=[p.name for p in params], + min_cardinality=1, + max_cardinality=2, + ) + ] + return params, constraints + + +def _sum_scenario() -> tuple[Sequence[DiscreteParameter], Sequence[DiscreteConstraint]]: + params = [ + NumericalDiscreteParameter(name=f"P{i}", values=[0.0, 25.0, 50.0, 75.0, 
100.0]) + for i in range(3) + ] + constraints = [ + DiscreteSumConstraint( + parameters=[p.name for p in params], + condition=ThresholdCondition(threshold=100, operator="=", tolerance=0.1), + ) + ] + return params, constraints + + +def _dependencies_scenario() -> tuple[ + Sequence[DiscreteParameter], Sequence[DiscreteConstraint] +]: + params = [ + CategoricalParameter(name="Switch", values=["on", "off"]), + CategoricalParameter(name="Label", values=["a", "b", "c"]), + NumericalDiscreteParameter(name="Amount", values=[1.0, 2.0, 3.0]), + ] + constraints = [ + DiscreteDependenciesConstraint( + parameters=["Switch"], + conditions=[SubSelectionCondition(selection=["on"])], + affected_parameters=[["Label"]], + ) + ] + return params, constraints + + +def _permutation_invariance_scenario() -> tuple[ + Sequence[DiscreteParameter], Sequence[DiscreteConstraint] +]: + values = ["a", "b", "c", "d"] + params = [CategoricalParameter(name=f"P{i}", values=values) for i in range(3)] + constraints = [ + DiscretePermutationInvarianceConstraint(parameters=[p.name for p in params]) + ] + return params, constraints + + +def _permutation_invariance_with_dependencies_scenario() -> tuple[ + Sequence[DiscreteParameter], Sequence[DiscreteConstraint] +]: + solvents = ["water", "ethanol", "methanol", "acetone"] + labels = [ + CategoricalParameter(name=f"Slot{i}_Label", values=solvents) for i in range(3) + ] + amounts = [ + NumericalDiscreteParameter( + name=f"Slot{i}_Amount", values=list(np.linspace(0, 100, 5)) + ) + for i in range(3) + ] + params = labels + amounts + label_names = [lbl.name for lbl in labels] + amount_names = [amt.name for amt in amounts] + + constraints = [ + DiscretePermutationInvarianceConstraint( + parameters=label_names, + dependencies=DiscreteDependenciesConstraint( + parameters=amount_names, + conditions=[ + ThresholdCondition(threshold=0.0, operator=">"), + ThresholdCondition(threshold=0.0, operator=">"), + ThresholdCondition(threshold=0.0, operator=">"), + ], + 
affected_parameters=[[n] for n in label_names], + ), + ), + DiscreteSumConstraint( + parameters=amount_names, + condition=ThresholdCondition(threshold=100, operator="=", tolerance=0.1), + ), + DiscreteNoLabelDuplicatesConstraint(parameters=label_names), + ] + return params, constraints + + +def _mixed_scenario() -> tuple[ + Sequence[DiscreteParameter], Sequence[DiscreteConstraint] +]: + params = [ + CategoricalParameter(name="Cat1", values=["a", "b", "c"]), + CategoricalParameter(name="Cat2", values=["a", "b", "c"]), + CategoricalParameter(name="Cat3", values=["a", "b", "c"]), + NumericalDiscreteParameter(name="Num1", values=[0.0, 50.0, 100.0]), + NumericalDiscreteParameter(name="Num2", values=[0.0, 50.0, 100.0]), + ] + constraints = [ + DiscreteNoLabelDuplicatesConstraint(parameters=["Cat1", "Cat2", "Cat3"]), + DiscreteSumConstraint( + parameters=["Num1", "Num2"], + condition=ThresholdCondition(threshold=100, operator="<="), + ), + ] + return params, constraints + + +@pytest.mark.parametrize( + "scenario", + [ + pytest.param(_no_constraints_scenario, id="no_constraints"), + pytest.param(_no_label_duplicates_scenario, id="no_label_duplicates"), + pytest.param(_linked_parameters_scenario, id="linked_parameters"), + pytest.param(partial(_exclude_scenario, "OR"), id="exclude_or"), + pytest.param(partial(_exclude_scenario, "AND"), id="exclude_and"), + pytest.param(_cardinality_scenario, id="cardinality"), + pytest.param(_sum_scenario, id="sum"), + pytest.param(_dependencies_scenario, id="dependencies"), + pytest.param(_permutation_invariance_scenario, id="permutation_invariance"), + pytest.param( + _permutation_invariance_with_dependencies_scenario, + id="permutation_invariance_with_deps", + ), + pytest.param(_mixed_scenario, id="mixed"), + ], +) +def test_constrained_cartesian_product(scenario): + """Verify incremental and naive product construction produce identical results.""" + parameters, constraints = scenario() + + # Naive approach: full product then filter + 
df_naive = parameter_cartesian_prod_pandas(parameters) + _apply_constraint_filter_pandas(df_naive, constraints) + + # Incremental approach + df_incremental = parameter_cartesian_prod_pandas_constrained( + parameters, constraints + ) + + # Column order must be identical + assert list(df_incremental.columns) == list(df_naive.columns) + + # Content must be identical (row order may differ) + cols = df_naive.columns.tolist() + assert_frame_equal( + df_incremental.sort_values(cols).reset_index(drop=True), + df_naive.sort_values(cols).reset_index(drop=True), + ) diff --git a/tests/constraints/test_constraints_polars.py b/tests/constraints/test_constraints_polars.py index dda694804f..9301d9cdf8 100644 --- a/tests/constraints/test_constraints_polars.py +++ b/tests/constraints/test_constraints_polars.py @@ -7,6 +7,8 @@ from baybe.searchspace.discrete import ( _apply_constraint_filter_pandas, _apply_constraint_filter_polars, +) +from baybe.searchspace.utils import ( parameter_cartesian_prod_pandas, parameter_cartesian_prod_polars, ) diff --git a/tests/hypothesis_strategies/alternative_creation/test_searchspace.py b/tests/hypothesis_strategies/alternative_creation/test_searchspace.py index e408358c9a..662e898134 100644 --- a/tests/hypothesis_strategies/alternative_creation/test_searchspace.py +++ b/tests/hypothesis_strategies/alternative_creation/test_searchspace.py @@ -116,7 +116,7 @@ def test_discrete_searchspace_creation_from_degenerate_dataframe(): @given( parameters=st.lists( numerical_discrete_parameters(min_value=0.0, max_value=1.0), - min_size=1, + min_size=2, max_size=5, unique_by=lambda x: x.name, ) @@ -157,7 +157,6 @@ def test_discrete_space_creation_from_simplex_inner(parameters, boundary_only): [ param([p_d1, p_d2], [p_t1, p_t2], 6 * 4, id="both"), param([p_d1, p_d2], [], 6, id="simplex-only"), - param([], [p_t1, p_t2], 4, id="task_only"), ], ) def test_discrete_space_creation_from_simplex_mixed( diff --git a/tests/test_searchspace.py b/tests/test_searchspace.py 
index 8b394db683..0579a82d44 100644 --- a/tests/test_searchspace.py +++ b/tests/test_searchspace.py @@ -29,7 +29,7 @@ SubspaceContinuous, SubspaceDiscrete, ) -from baybe.searchspace.discrete import ( +from baybe.searchspace.utils import ( parameter_cartesian_prod_pandas, parameter_cartesian_prod_polars, ) @@ -139,6 +139,35 @@ def test_invalid_simplex_creating_with_overlapping_parameters(): ) +@pytest.mark.parametrize( + ("simplex_parameters", "expected_len"), + [ + pytest.param([], 3, id="zero_simplex"), + pytest.param( + [NumericalDiscreteParameter("x", values=[0.0, 0.5, 1.0, 1.5, 2.0])], + 9, + id="one_simplex", + ), + ], +) +def test_from_simplex_with_degenerate_parameter_count(simplex_parameters, expected_len): + """Calling from_simplex with fewer than 2 simplex parameters emits a warning.""" + product_parameters = [CategoricalParameter(name="C", values=["a", "b", "c"])] + + with pytest.warns(UserWarning, match="less than 2 simplex parameters"): + subspace = SubspaceDiscrete.from_simplex( + max_sum=1.0, + simplex_parameters=simplex_parameters, + product_parameters=product_parameters, + ) + + assert len(subspace.exp_rep) == expected_len + + if simplex_parameters: + simplex_cols = [p.name for p in simplex_parameters] + assert all(subspace.exp_rep[simplex_cols].sum(axis=1) <= 1.0) + + def test_continuous_searchspace_creation_from_bounds(): """A purely continuous search space is created from example bounds.""" parameters = (