-
Notifications
You must be signed in to change notification settings - Fork 76
Botorch preset #757
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev/gp
Are you sure you want to change the base?
Botorch preset #757
Changes from all commits
59e9697
bcc51bb
9c6051c
447cf6a
5845f0f
741997b
762b8b9
d6f3128
6909be7
13b1349
aca3f69
e1c4150
6ad33f4
69c41da
a259c8b
9a32c53
4f4fd55
3c5ec16
2f8be3a
680d8a1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
| """Custom GPyTorch components.""" | ||
|
|
||
| import torch | ||
| from botorch.models.multitask import _compute_multitask_mean | ||
| from botorch.models.utils.gpytorch_modules import MIN_INFERRED_NOISE_LEVEL | ||
|
Comment on lines
+4
to
+5
|
||
| from gpytorch.constraints import GreaterThan | ||
| from gpytorch.likelihoods.hadamard_gaussian_likelihood import HadamardGaussianLikelihood | ||
| from gpytorch.means import MultitaskMean | ||
| from gpytorch.means.multitask_mean import Mean | ||
| from gpytorch.priors import LogNormalPrior | ||
| from torch import Tensor | ||
| from torch.nn import Module | ||
|
|
||
|
|
||
| class HadamardConstantMean(Mean): | ||
| """A GPyTorch mean function implementing BoTorch's multitask mean logic. | ||
|
|
||
| While GPyTorch already provides a :class:`~gpytorch.means.MultitaskMean` class, it | ||
| computes mean values for all (input, task)-pairs (where input means all parameters | ||
| except the task parameter), i.e. it intrinsically applies a Cartesian expansion. | ||
| However, for the regular transfer learning setting, we only need the means for the | ||
| pairs that are actually observed/requested. BoTorch subselects the relevant means | ||
| from the GPyTorch output in `MultiTaskGP.forward`, i.e. it uses a class-based | ||
| approach to define its special logic for the multitask case. In contrast, BayBE uses | ||
| a composition approach, which is more flexible but requires that the logic is | ||
| injected via a self-contained `Mean` object, which is what this class provides. | ||
|
|
||
| Note: | ||
| Analogous to GPyTorch's | ||
| https://github.com/cornellius-gp/gpytorch/blob/main/gpytorch/likelihoods/hadamard_gaussian_likelihood.py | ||
| but where the logic is applied to the mean function, i.e. we learn a different | ||
| (constant) mean for each task. | ||
| """ | ||
|
|
||
| def __init__(self, mean_module: Module, num_tasks: int, task_feature: int): | ||
| super().__init__() | ||
| self.multitask_mean = MultitaskMean(mean_module, num_tasks=num_tasks) | ||
| self.task_feature = task_feature | ||
|
|
||
| def forward(self, x: Tensor) -> Tensor: | ||
| # Adapted from https://github.com/meta-pytorch/botorch/blob/e0f4f5b941b5949a4a1171bf8d4ee9f74f146f3a/botorch/models/multitask.py#L397 | ||
|
|
||
| # Convert task feature to positive index | ||
| task_feature = self.task_feature % x.shape[-1] | ||
|
|
||
| # Split input into task and non-task components | ||
| x_before = x[..., :task_feature] | ||
| task_idcs = x[..., task_feature : task_feature + 1] | ||
| x_after = x[..., task_feature + 1 :] | ||
|
|
||
| return _compute_multitask_mean( | ||
| self.multitask_mean, x_before, task_idcs, x_after | ||
| ) | ||
|
|
||
|
|
||
| def make_botorch_multitask_likelihood( | ||
| num_tasks: int, task_feature: int | ||
| ) -> HadamardGaussianLikelihood: | ||
| """Adapted from :class:`botorch.models.multitask.MultiTaskGP`.""" | ||
| noise_prior = LogNormalPrior(loc=-4.0, scale=1.0) | ||
| return HadamardGaussianLikelihood( | ||
| num_tasks=num_tasks, | ||
| batch_shape=torch.Size(), | ||
| noise_prior=noise_prior, | ||
| noise_constraint=GreaterThan( | ||
| MIN_INFERRED_NOISE_LEVEL, | ||
| transform=None, | ||
| initial_value=noise_prior.mode, | ||
| ), | ||
| task_feature_index=task_feature, | ||
| ) | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -2,18 +2,24 @@ | |||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| from __future__ import annotations | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| from abc import ABC, abstractmethod | ||||||||||||||||||||||||||||||||||||||||||
| from collections.abc import Iterable | ||||||||||||||||||||||||||||||||||||||||||
| from functools import partial | ||||||||||||||||||||||||||||||||||||||||||
| from typing import TYPE_CHECKING | ||||||||||||||||||||||||||||||||||||||||||
| from typing import TYPE_CHECKING, ClassVar | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| from attrs import define, field | ||||||||||||||||||||||||||||||||||||||||||
| from attrs.converters import optional | ||||||||||||||||||||||||||||||||||||||||||
| from attrs.validators import is_callable | ||||||||||||||||||||||||||||||||||||||||||
| from typing_extensions import override | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| from baybe.exceptions import IncompatibleSearchSpaceError | ||||||||||||||||||||||||||||||||||||||||||
| from baybe.kernels.base import Kernel | ||||||||||||||||||||||||||||||||||||||||||
| from baybe.kernels.composite import ProductKernel | ||||||||||||||||||||||||||||||||||||||||||
| from baybe.parameters.categorical import TaskParameter | ||||||||||||||||||||||||||||||||||||||||||
| from baybe.parameters.enum import ParameterKind | ||||||||||||||||||||||||||||||||||||||||||
| from baybe.parameters.selectors import ( | ||||||||||||||||||||||||||||||||||||||||||
| ParameterSelectorProtocol, | ||||||||||||||||||||||||||||||||||||||||||
| TypeSelector, | ||||||||||||||||||||||||||||||||||||||||||
| to_parameter_selector, | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| from baybe.searchspace.core import SearchSpace | ||||||||||||||||||||||||||||||||||||||||||
| from baybe.surrogates.gaussian_process.components.generic import ( | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -27,6 +33,8 @@ | |||||||||||||||||||||||||||||||||||||||||
| from gpytorch.kernels import Kernel as GPyTorchKernel | ||||||||||||||||||||||||||||||||||||||||||
| from torch import Tensor | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| from baybe.parameters.base import Parameter | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| KernelFactoryProtocol = GPComponentFactoryProtocol[Kernel | GPyTorchKernel] | ||||||||||||||||||||||||||||||||||||||||||
| PlainKernelFactory = PlainGPComponentFactory[Kernel | GPyTorchKernel] | ||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -35,6 +43,80 @@ | |||||||||||||||||||||||||||||||||||||||||
| PlainKernelFactory = PlainGPComponentFactory[Kernel] | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| @define | ||||||||||||||||||||||||||||||||||||||||||
| class _KernelFactory(KernelFactoryProtocol, ABC): | ||||||||||||||||||||||||||||||||||||||||||
| """Base class for kernel factories.""" | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # For internal use only: sanity check mechanism to remind developers of new | ||||||||||||||||||||||||||||||||||||||||||
| # factories to actually use the parameter selector when it is provided | ||||||||||||||||||||||||||||||||||||||||||
| # TODO: Perhaps we can find a more elegant way to enforce this by design | ||||||||||||||||||||||||||||||||||||||||||
| _uses_parameter_names: ClassVar[bool] = False | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| supported_parameter_kinds: ClassVar[ParameterKind] = ParameterKind.REGULAR | ||||||||||||||||||||||||||||||||||||||||||
| """The parameter kinds supported by the kernel factory.""" | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| parameter_selector: ParameterSelectorProtocol | None = field( | ||||||||||||||||||||||||||||||||||||||||||
| default=None, converter=optional(to_parameter_selector) | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| """An optional selector to specify which parameters are considered by the kernel.""" | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def get_parameter_names(self, searchspace: SearchSpace) -> tuple[str, ...] | None: | ||||||||||||||||||||||||||||||||||||||||||
| """Get the names of the parameters to be considered by the kernel.""" | ||||||||||||||||||||||||||||||||||||||||||
| if self.parameter_selector is None: | ||||||||||||||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| return tuple( | ||||||||||||||||||||||||||||||||||||||||||
| p.name for p in searchspace.parameters if self.parameter_selector(p) | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def _validate_parameter_kinds(self, parameters: Iterable[Parameter]) -> None: | ||||||||||||||||||||||||||||||||||||||||||
| """Validate that the given parameters are supported by the factory. | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||
| parameters: The parameters to validate. | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Raises: | ||||||||||||||||||||||||||||||||||||||||||
| IncompatibleSearchSpaceError: If unsupported parameter kinds are found. | ||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||
| if unsupported := [ | ||||||||||||||||||||||||||||||||||||||||||
| p.name for p in parameters if not (p.kind & self.supported_parameter_kinds) | ||||||||||||||||||||||||||||||||||||||||||
| ]: | ||||||||||||||||||||||||||||||||||||||||||
| raise IncompatibleSearchSpaceError( | ||||||||||||||||||||||||||||||||||||||||||
| f"'{type(self).__name__}' does not support parameter kind(s) for " | ||||||||||||||||||||||||||||||||||||||||||
| f"parameter(s) {unsupported}. Supported kinds: " | ||||||||||||||||||||||||||||||||||||||||||
| f"{self.supported_parameter_kinds}." | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| @override | ||||||||||||||||||||||||||||||||||||||||||
| def __call__( | ||||||||||||||||||||||||||||||||||||||||||
| self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor | ||||||||||||||||||||||||||||||||||||||||||
| ) -> Kernel | GPyTorchKernel: | ||||||||||||||||||||||||||||||||||||||||||
| """Construct the kernel, validating parameter kinds before construction.""" | ||||||||||||||||||||||||||||||||||||||||||
| if self.parameter_selector is not None: | ||||||||||||||||||||||||||||||||||||||||||
| params = [p for p in searchspace.parameters if self.parameter_selector(p)] | ||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||
| params = list(searchspace.parameters) | ||||||||||||||||||||||||||||||||||||||||||
| self._validate_parameter_kinds(params) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| return self._make(searchspace, train_x, train_y) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| @abstractmethod | ||||||||||||||||||||||||||||||||||||||||||
| def _make( | ||||||||||||||||||||||||||||||||||||||||||
| self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor | ||||||||||||||||||||||||||||||||||||||||||
| ) -> Kernel | GPyTorchKernel: | ||||||||||||||||||||||||||||||||||||||||||
| """Construct the kernel.""" | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def __attrs_post_init__(self): | ||||||||||||||||||||||||||||||||||||||||||
| if self.parameter_selector is not None and not self._uses_parameter_names: | ||||||||||||||||||||||||||||||||||||||||||
| raise AssertionError( | ||||||||||||||||||||||||||||||||||||||||||
| f"A `parameter_selector` was provided to " | ||||||||||||||||||||||||||||||||||||||||||
| f"`{type(self).__name__}`, but the class does not set " | ||||||||||||||||||||||||||||||||||||||||||
| f"`_uses_parameter_names = True`. Subclasses that accept a " | ||||||||||||||||||||||||||||||||||||||||||
| f"parameter selector must explicitly set this flag to confirm " | ||||||||||||||||||||||||||||||||||||||||||
| f"they actually use the selected parameter names." | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| @define | ||||||||||||||||||||||||||||||||||||||||||
| class ICMKernelFactory(KernelFactoryProtocol): | ||||||||||||||||||||||||||||||||||||||||||
| """A kernel factory that constructs an ICM kernel for transfer learning. | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -76,6 +158,43 @@ def _default_task_kernel_factory(self) -> KernelFactoryProtocol: | |||||||||||||||||||||||||||||||||||||||||
| def __call__( | ||||||||||||||||||||||||||||||||||||||||||
| self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor | ||||||||||||||||||||||||||||||||||||||||||
| ) -> Kernel: | ||||||||||||||||||||||||||||||||||||||||||
| if searchspace.task_idx is None: | ||||||||||||||||||||||||||||||||||||||||||
| raise IncompatibleSearchSpaceError( | ||||||||||||||||||||||||||||||||||||||||||
| f"'{type(self).__name__}' can only be used with a searchspace that " | ||||||||||||||||||||||||||||||||||||||||||
| f"contains a '{TaskParameter.__name__}'." | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| base_kernel = self.base_kernel_factory(searchspace, train_x, train_y) | ||||||||||||||||||||||||||||||||||||||||||
| task_kernel = self.task_kernel_factory(searchspace, train_x, train_y) | ||||||||||||||||||||||||||||||||||||||||||
| return ProductKernel([base_kernel, task_kernel]) | ||||||||||||||||||||||||||||||||||||||||||
| if isinstance(base_kernel, Kernel): | ||||||||||||||||||||||||||||||||||||||||||
| base_kernel = base_kernel.to_gpytorch(searchspace) | ||||||||||||||||||||||||||||||||||||||||||
| if isinstance(task_kernel, Kernel): | ||||||||||||||||||||||||||||||||||||||||||
| task_kernel = task_kernel.to_gpytorch(searchspace) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Ensure correct partitioning between base and task kernels active dimensions | ||||||||||||||||||||||||||||||||||||||||||
| all_idcs = set(range(len(searchspace.comp_rep_columns))) | ||||||||||||||||||||||||||||||||||||||||||
| allowed_task_idcs = {searchspace.task_idx} | ||||||||||||||||||||||||||||||||||||||||||
| allowed_base_idcs = all_idcs - allowed_task_idcs | ||||||||||||||||||||||||||||||||||||||||||
| base_idcs = ( | ||||||||||||||||||||||||||||||||||||||||||
| set(dims) | ||||||||||||||||||||||||||||||||||||||||||
| if (dims := base_kernel.active_dims.tolist()) is not None | ||||||||||||||||||||||||||||||||||||||||||
| else None | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| task_idcs = ( | ||||||||||||||||||||||||||||||||||||||||||
| set(dims) | ||||||||||||||||||||||||||||||||||||||||||
| if (dims := task_kernel.active_dims.tolist()) is not None | ||||||||||||||||||||||||||||||||||||||||||
| else None | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+178
to
+186
|
||||||||||||||||||||||||||||||||||||||||||
| base_idcs = ( | |
| set(dims) | |
| if (dims := base_kernel.active_dims.tolist()) is not None | |
| else None | |
| ) | |
| task_idcs = ( | |
| set(dims) | |
| if (dims := task_kernel.active_dims.tolist()) is not None | |
| else None | |
| base_active_dims = base_kernel.active_dims | |
| task_active_dims = task_kernel.active_dims | |
| base_idcs = ( | |
| all_idcs | |
| if base_active_dims is None | |
| else set(base_active_dims.tolist()) | |
| ) | |
| task_idcs = ( | |
| all_idcs | |
| if task_active_dims is None | |
| else set(task_active_dims.tolist()) |
Copilot
AI
Apr 17, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The subset check for base-kernel active_dims is incorrect: base_idcs > allowed_base_idcs checks for a strict superset, not “not a subset”. This will miss invalid cases (e.g. {0, task_idx}) and potentially flag none. Use a proper subset validation (e.g. not base_idcs <= allowed_base_idcs) and consider a clearer error if active_dims is None (meaning “all dims”).
| if base_idcs is not None and (base_idcs > allowed_base_idcs): | |
| raise ValueError( | |
| if base_idcs is None: | |
| raise ValueError( | |
| "The base kernel's 'active_dims' must be restricted to the non-task " | |
| f"indices {allowed_base_idcs}; got None, which means all dimensions." | |
| ) | |
| if not base_idcs <= allowed_base_idcs: | |
| raise ValueError( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There’s trailing whitespace at the end of this line (after “but”). This tends to cause noisy diffs and may violate whitespace linting in some setups; consider removing it.