Support PositiveIndexKernel and dispatching via TaskParameter
#728
base: main
Changes from all commits
cc9803d
545727d
2020786
a2d42a6
32c5167
36702cb
45ece28
```diff
@@ -1,7 +1,9 @@
 """Categorical parameters."""

+import gc
 from enum import Enum
 from functools import cached_property
 from typing import Any

 import numpy as np
 import pandas as pd
@@ -16,6 +18,13 @@
 from baybe.utils.numerical import DTypeFloatNumpy


+class TaskCorrelation(Enum):
+    """Task correlation modes for TaskParameter."""
+
+    UNKNOWN = "unknown"
+    POSITIVE = "positive"
+
+
 def _convert_values(value, self, field) -> tuple[str, ...]:
     """Sort and convert values for categorical parameters."""
     value = nonstring_to_tuple(value, self, field)
```
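As a quick illustration of the new enum (re-declared here standalone, outside BayBE), members round-trip through their string values, which is what makes the mode convenient to (de)serialize:

```python
from enum import Enum


class TaskCorrelation(Enum):
    """Task correlation modes for TaskParameter."""

    UNKNOWN = "unknown"
    POSITIVE = "positive"


# Members can be recovered from their string values and compared by identity.
assert TaskCorrelation("positive") is TaskCorrelation.POSITIVE
assert TaskCorrelation.UNKNOWN.value == "unknown"
```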
```diff
@@ -87,6 +96,30 @@ class TaskParameter(CategoricalParameter):
     encoding: CategoricalEncoding = field(default=CategoricalEncoding.INT, init=False)
     # See base class.

+    task_correlation: TaskCorrelation = field(default=TaskCorrelation.POSITIVE)
```
Collaborator: The only big potential problem with this PR I can spot is the naming of this attribute. In isolation, the name is totally accurate and fine. But we already have plans to expand this attribute to potentially offer more choices. Now, of course, we could change the name of the attribute later, but since this is merged to main and potentially released before we have the other choices, we would introduce a breaking change that has to be deprecated. So it would be beneficial if we avoided that situation. Here are two proposals for how to do that: …

@AdrianSosic, do you agree with this issue of the attribute name?
```diff
+    """Task correlation. Defaults to positive correlation via PositiveIndexKernel."""
```
```diff
+    @task_correlation.validator
+    def _validate_task_correlation_active_values(  # noqa: DOC101, DOC103
+        self, _: Any, value: TaskCorrelation
+    ) -> None:
+        """Validate active values compatibility with task correlation mode.
+
+        Raises:
+            ValueError: If task_correlation is POSITIVE but active_values contains more
+                than one value.
+        """
+        # Check POSITIVE constraint: must have exactly one active value
+        # Note: _active_values is the internal field, could be None
+        if value == TaskCorrelation.POSITIVE and self._active_values is not None:
```
Collaborator: Use … instead. Why are you using `_active_values` here?
```diff
+            if len(self._active_values) > 1:
+                raise ValueError(
+                    f"Task correlation '{TaskCorrelation.POSITIVE.value}' requires "
+                    f"one active value, but {len(self._active_values)} were provided: "
+                    f"{self._active_values}. The POSITIVE mode uses the "
+                    f"PositiveIndexKernel which assumes a single target task."
+                )


 # Collect leftover original slotted classes processed by `attrs.define`
 gc.collect()
```
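The validator's core logic can be sketched standalone, without BayBE's `attrs` machinery (the function and argument names here are illustrative, not the library's API):

```python
from enum import Enum


class TaskCorrelation(Enum):
    UNKNOWN = "unknown"
    POSITIVE = "positive"


def validate_active_values(correlation, active_values):
    """Raise if POSITIVE correlation is combined with multiple active values."""
    if (
        correlation is TaskCorrelation.POSITIVE
        and active_values is not None
        and len(active_values) > 1
    ):
        raise ValueError(
            f"Task correlation '{correlation.value}' requires one active value, "
            f"but {len(active_values)} were provided: {active_values}."
        )


validate_active_values(TaskCorrelation.POSITIVE, ("target",))  # passes
validate_active_values(TaskCorrelation.UNKNOWN, ("a", "b"))  # passes
try:
    validate_active_values(TaskCorrelation.POSITIVE, ("a", "b"))
except ValueError as err:
    print("rejected:", err)
```

Note that `None` active values are deliberately allowed, matching the internal `_active_values` field, which may be unset.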
```diff
@@ -15,6 +15,7 @@
 from baybe.constraints.base import Constraint
 from baybe.parameters import TaskParameter
 from baybe.parameters.base import Parameter
+from baybe.parameters.categorical import TaskCorrelation
 from baybe.searchspace.continuous import SubspaceContinuous
 from baybe.searchspace.discrete import (
     MemorySize,
@@ -279,6 +280,45 @@ def n_tasks(self) -> int:
         except StopIteration:
             return 1

+    @property
+    def target_task_idxs(self) -> list[int] | None:
```
Collaborator: I would always prefer returning tuples in such cases, unless there is a limitation that it really must be a list.
```diff
+        """The indices of the target tasks in the computational representation.
+
+        Returns a list of integer indices corresponding to each active value in the
+        TaskParameter. Returns None when there are no task parameters.
+        """
+        # TODO [16932]: This approach only works for a single task parameter.
+        try:
+            task_param = next(
+                p for p in self.parameters if isinstance(p, TaskParameter)
+            )
+            comp_df = task_param.comp_df
+
+            # Extract computational representation indices for all active values
+            target_task_idxs = [
+                int(comp_df.loc[active_value].iloc[0])
+                for active_value in task_param.active_values
+            ]
+            return target_task_idxs
+
+        # When there are no task parameters, return None
+        except StopIteration:
+            return None
+
+    @property
+    def task_correlation(self) -> TaskCorrelation | None:
+        """The task correlation mode for this searchspace."""
+        # TODO [16932]: This approach only works for a single task parameter.
+        try:
+            task_param = next(
+                p for p in self.parameters if isinstance(p, TaskParameter)
+            )
+            return task_param.task_correlation
+
+        # When there are no task parameters, we return None
+        except StopIteration:
+            return None
+
     def get_comp_rep_parameter_indices(self, name: str, /) -> tuple[int, ...]:
         """Find a parameter's column indices in the computational representation.
```
```diff
@@ -10,6 +10,7 @@
 from typing_extensions import override

 from baybe.parameters.base import Parameter
+from baybe.parameters.categorical import TaskCorrelation
 from baybe.searchspace.core import SearchSpace
 from baybe.surrogates.base import Surrogate
 from baybe.surrogates.gaussian_process.kernel_factory import (
@@ -69,6 +70,16 @@ def parameter_bounds(self) -> Tensor:

         return torch.from_numpy(self.searchspace.scaling_bounds.values)

+    @property
+    def task_correlation(self) -> TaskCorrelation | None:
+        """Get the task correlation mode of the task parameter, if available."""
+        return self.searchspace.task_correlation
+
+    @property
+    def target_task_idxs(self) -> list[int] | None:
+        """Determine target task index for PositiveIndexKernel normalization."""
+        return self.searchspace.target_task_idxs
```
Collaborator (on lines +73 to +82): How necessary are these helpers? I can get them just via …
```diff
     def get_numerical_indices(self, n_inputs: int) -> tuple[int, ...]:
         """Get the indices of the regular numerical model inputs."""
         return tuple(i for i in range(n_inputs) if i != self.task_idx)
```
```diff
@@ -181,7 +192,17 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
         # create GP covariance
         if not context.is_multitask:
             covar_module = base_covar_module
-        else:
+        elif context.task_correlation == TaskCorrelation.POSITIVE:
```
```diff
+            task_covar_module = (
+                botorch.models.kernels.positive_index.PositiveIndexKernel(
+                    num_tasks=context.n_tasks,
+                    active_dims=context.task_idx,
+                    rank=context.n_tasks,  # TODO: make controllable
+                    target_task_index=context.target_task_idxs[0],
+                )
+            )
+            covar_module = base_covar_module * task_covar_module
+        elif context.task_correlation == TaskCorrelation.UNKNOWN:
```
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Comment on lines
+195
to
+205
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. just for our common understanding: these parts will eventually have to e outsourced to a |
||||||
```diff
             task_covar_module = gpytorch.kernels.IndexKernel(
                 num_tasks=context.n_tasks,
                 active_dims=context.task_idx,
```
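The dispatching above reduces to a three-way branch on the fitting context. A dependency-free sketch of the selection logic (kernel objects replaced by placeholder strings, since constructing the real gpytorch/botorch modules requires training data):

```python
from enum import Enum


class TaskCorrelation(Enum):
    UNKNOWN = "unknown"
    POSITIVE = "positive"


def select_covar_module(is_multitask, task_correlation):
    """Mirror the branching in _fit: pick the task kernel by correlation mode."""
    if not is_multitask:
        return "base"
    if task_correlation is TaskCorrelation.POSITIVE:
        return "base * PositiveIndexKernel"
    if task_correlation is TaskCorrelation.UNKNOWN:
        return "base * IndexKernel"
    raise ValueError(f"Unhandled task correlation: {task_correlation}")


print(select_covar_module(True, TaskCorrelation.POSITIVE))
# prints "base * PositiveIndexKernel"
```

An explicit `raise` on unhandled modes (instead of a bare `else`) keeps future enum members from silently falling through, which is the point of replacing the original `else:` branch.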
```diff
@@ -20,6 +20,7 @@

 from baybe.objectives import SingleTargetObjective
 from baybe.parameters import TaskParameter
+from baybe.parameters.categorical import TaskCorrelation
 from baybe.searchspace import SearchSpace
 from baybe.surrogates.gaussian_process.core import GaussianProcessSurrogate
 from benchmarks.definition import TransferLearningRegressionBenchmarkSettings
@@ -39,7 +40,12 @@ def __call__(self) -> pd.DataFrame:
 class SearchSpaceFactory(Protocol):
     """Protocol for SearchSpace creation used in TL regression benchmarks."""

-    def __call__(self, data: pd.DataFrame, use_task_parameter: bool) -> SearchSpace:
+    def __call__(
+        self,
+        data: pd.DataFrame,
+        use_task_parameter: bool,
+        task_correlation: TaskCorrelation = TaskCorrelation.UNKNOWN,
+    ) -> SearchSpace:
         """Create a SearchSpace for regression benchmark evaluation.

         Args:
@@ -48,6 +54,8 @@ def __call__(self, data: pd.DataFrame, use_task_parameter: bool) -> SearchSpace:
                 scenarios. If True, creates search space with TaskParameter for
                 TL models. If False, creates vanilla search space without
                 task parameter.
+            task_correlation: The task correlation mode (UNKNOWN or POSITIVE).
+                Only used when use_task_parameter is True.

         Returns:
             The TL and non-TL searchspaces for the benchmark.
```
```diff
@@ -100,12 +108,6 @@ def spearman_rho_score(x: np.ndarray, y: np.ndarray, /) -> float:
     return rho


-# Dictionary mapping transfer learning model names to their surrogate classes
-TL_MODELS = {
-    "index_kernel": GaussianProcessSurrogate,
-}
-
-
 # Regression metrics to evaluate model performance
 REGRESSION_METRICS = {
     root_mean_squared_error,
```
```diff
@@ -161,12 +163,17 @@ def run_tl_regression_benchmark(
     # Create search space without task parameter
     vanilla_searchspace = searchspace_factory(data=data, use_task_parameter=False)

-    # Create transfer learning search space (with task parameter)
-    tl_searchspace = searchspace_factory(data=data, use_task_parameter=True)
+    # Create transfer learning search spaces (with task parameter)
+    tl_index_searchspace = searchspace_factory(
+        data=data, use_task_parameter=True, task_correlation=TaskCorrelation.UNKNOWN
+    )
+    tl_pos_index_searchspace = searchspace_factory(
+        data=data, use_task_parameter=True, task_correlation=TaskCorrelation.POSITIVE
+    )

-    # Extract task parameter details
+    # Extract task parameter details (use index searchspace as reference)
     task_param = next(
-        p for p in tl_searchspace.parameters if isinstance(p, TaskParameter)
+        p for p in tl_index_searchspace.parameters if isinstance(p, TaskParameter)
     )
     name_task = task_param.name
```
```diff
@@ -234,16 +241,36 @@ def run_tl_regression_benchmark(
         result.update(metrics)
         results.append(result)

-        # Naive GP on full search space
+        # IndexKernel on full search space, no source data
+        metrics = _evaluate_model(
+            GaussianProcessSurrogate(),
+            target_train,
+            target_test,
+            tl_index_searchspace,
+            objective,
+        )
+        result = {
+            "scenario": "0_index",
+            "mc_iter": mc_iter,
+            "n_train_pts": n_train_pts,
+            "fraction_source": 0.0,
+            "n_source_pts": 0,
+            "n_test_pts": len(target_test),
+            "source_data_seed": settings.random_seed + mc_iter,
+        }
+        result.update(metrics)
+        results.append(result)
+
+        # PositiveIndexKernel on full search space, no source data
         metrics = _evaluate_model(
             GaussianProcessSurrogate(),
             target_train,
             target_test,
-            tl_searchspace,
+            tl_pos_index_searchspace,
             objective,
         )
         result = {
-            "scenario": "0_full_searchspace",
+            "scenario": "0_pos_index",
             "mc_iter": mc_iter,
             "n_train_pts": n_train_pts,
             "fraction_source": 0.0,
```
```diff
@@ -277,29 +304,47 @@

             combined_data = pd.concat([source_subset, target_train])

-            for model_suffix, model_class in TL_MODELS.items():
-                scenario_name = f"{int(100 * fraction_source)}_{model_suffix}"
-                model = model_class()
-
-                metrics = _evaluate_model(
-                    model,
-                    combined_data,
-                    target_test,
-                    tl_searchspace,
-                    objective,
-                )
-
-                result = {
-                    "scenario": scenario_name,
-                    "mc_iter": mc_iter,
-                    "n_train_pts": n_train_pts,
-                    "fraction_source": fraction_source,
-                    "n_source_pts": len(source_subset),
-                    "n_test_pts": len(target_test),
-                    "source_data_seed": settings.random_seed + mc_iter,
-                }
-                result.update(metrics)
-                results.append(result)
+            # Evaluate IndexKernel
+            scenario_name = f"{int(100 * fraction_source)}_index"
+            metrics = _evaluate_model(
+                GaussianProcessSurrogate(),
+                combined_data,
+                target_test,
+                tl_index_searchspace,
+                objective,
+            )
+            result = {
+                "scenario": scenario_name,
+                "mc_iter": mc_iter,
+                "n_train_pts": n_train_pts,
+                "fraction_source": fraction_source,
+                "n_source_pts": len(source_subset),
+                "n_test_pts": len(target_test),
+                "source_data_seed": settings.random_seed + mc_iter,
+            }
+            result.update(metrics)
+            results.append(result)
+
+            # Evaluate PositiveIndexKernel
+            scenario_name = f"{int(100 * fraction_source)}_pos_index"
+            metrics = _evaluate_model(
+                GaussianProcessSurrogate(),
+                combined_data,
+                target_test,
+                tl_pos_index_searchspace,
+                objective,
+            )
+            result = {
+                "scenario": scenario_name,
+                "mc_iter": mc_iter,
+                "n_train_pts": n_train_pts,
+                "fraction_source": fraction_source,
+                "n_source_pts": len(source_subset),
+                "n_test_pts": len(target_test),
+                "source_data_seed": settings.random_seed + mc_iter,
+            }
+            result.update(metrics)
+            results.append(result)
```
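The two scenario blocks above differ only in the search space and the scenario suffix; the repeated result-dict construction could be factored into a small helper. A sketch (the helper name and signature are hypothetical, not part of the benchmark module):

```python
def make_result(scenario, mc_iter, n_train_pts, fraction_source,
                n_source_pts, n_test_pts, seed, metrics):
    """Assemble one benchmark result row from the shared bookkeeping fields."""
    result = {
        "scenario": scenario,
        "mc_iter": mc_iter,
        "n_train_pts": n_train_pts,
        "fraction_source": fraction_source,
        "n_source_pts": n_source_pts,
        "n_test_pts": n_test_pts,
        "source_data_seed": seed,
    }
    # Merge in the evaluation metrics, exactly as result.update(metrics) does.
    result.update(metrics)
    return result


row = make_result("50_pos_index", 0, 20, 0.5, 10, 30, 1337, {"rmse": 0.12})
print(row["scenario"], row["rmse"])  # 50_pos_index 0.12
```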
Collaborator: Since you expanded the benchmarks: are they still feasible, or are they now timing out due to the longer runtime?
```diff
             pbar.update(1)
```
Collaborator: @kalama-ai @AdrianSosic, can you quickly comment on the state of this PR? If I remember correctly, this was one of the PRs that somewhat depend on Adrian's current refactoring work. Has this code here already been rebased, and is it thus ready for review? Or do I misremember?