diff --git a/baybe/surrogates/gaussian_process/core.py b/baybe/surrogates/gaussian_process/core.py
index c0148aca55..7e41049665 100644
--- a/baybe/surrogates/gaussian_process/core.py
+++ b/baybe/surrogates/gaussian_process/core.py
@@ -24,9 +24,11 @@
     DefaultKernelFactory,
     _default_noise_factory,
 )
+from baybe.surrogates.gaussian_process.prior_modules import PriorMean
 from baybe.utils.conversion import to_string
 
 if TYPE_CHECKING:
+    from botorch.models import SingleTaskGP
     from botorch.models.gpytorch import GPyTorchModel
     from botorch.models.transforms.input import InputTransform
     from botorch.models.transforms.outcome import OutcomeTransform
@@ -113,11 +115,60 @@ class GaussianProcessSurrogate(Surrogate):
     _model = field(init=False, default=None, eq=False)
     """The actual model."""
 
+    # Transfer learning fields
+    _prior_gp: SingleTaskGP | None = field(init=False, default=None, eq=False)
+    """Prior GP to extract mean/covariance for transfer learning."""
+
     @staticmethod
     def from_preset(preset: GaussianProcessPreset) -> GaussianProcessSurrogate:
         """Create a Gaussian process surrogate from one of the defined presets."""
         return make_gp_from_preset(preset)
 
+    @classmethod
+    def from_prior(
+        cls,
+        prior_gp: GaussianProcessSurrogate,
+        kernel_factory: KernelFactory | None = None,
+        **kwargs,
+    ) -> GaussianProcessSurrogate:
+        """Create a GP surrogate using a prior GP's predictions as the mean function.
+
+        Transfers knowledge by using the prior GP's posterior mean predictions
+        as the mean function for a new GP, while learning covariance from scratch.
+
+        Args:
+            prior_gp: Fitted GaussianProcessSurrogate to use as prior
+            kernel_factory: Kernel factory for covariance components
+            **kwargs: Additional arguments for GaussianProcessSurrogate constructor
+
+        Returns:
+            New GaussianProcessSurrogate instance with transfer learning
+
+        Raises:
+            ValueError: If prior_gp is not a GaussianProcessSurrogate or is not fitted
+        """
+        from copy import deepcopy
+
+        # Validate prior GP is fitted
+        if not isinstance(prior_gp, cls):
+            raise ValueError(
+                "prior_gp must be a fitted GaussianProcessSurrogate instance"
+            )
+        if prior_gp._model is None:
+            raise ValueError("Prior GP must be fitted before use")
+
+        # Configure kernel factory (always needed since we only do mean transfer now)
+        if kernel_factory is None:
+            kernel_factory = DefaultKernelFactory()
+
+        # Create new surrogate instance
+        instance = cls(kernel_or_factory=kernel_factory, **kwargs)
+
+        # Configure for transfer learning - store the BoTorch model
+        instance._prior_gp = deepcopy(prior_gp.to_botorch())
+
+        return instance
+
     @override
     def to_botorch(self) -> GPyTorchModel:
         return self._model
@@ -152,21 +203,30 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
         assert self._searchspace is not None
         context = _ModelContext(self._searchspace)
 
         numerical_idxs = context.get_numerical_indices(train_x.shape[-1])
 
-        # For GPs, we let botorch handle the scaling. See [Scaling Workaround] above.
-        input_transform = Normalize(
-            train_x.shape[-1],
-            bounds=context.parameter_bounds,
-            indices=list(numerical_idxs),
-        )
-        outcome_transform = Standardize(train_y.shape[-1])
-
         # extract the batch shape of the training data
         batch_shape = train_x.shape[:-2]
 
+        # Configure input/output transforms
+        if self._prior_gp is not None and hasattr(self._prior_gp, "input_transform"):
+            # Use prior's transforms for consistency in transfer learning
+            input_transform = self._prior_gp.input_transform
+            outcome_transform = self._prior_gp.outcome_transform
+        else:
+            # For GPs, we let botorch handle scaling. See [Scaling Workaround] above.
+            input_transform = Normalize(
+                train_x.shape[-1],
+                bounds=context.parameter_bounds,
+                indices=list(numerical_idxs),
+            )
+            outcome_transform = Standardize(train_y.shape[-1])
+
         # create GP mean
-        mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)
+        if self._prior_gp is not None:
+            mean_module = PriorMean(self._prior_gp, batch_shape=batch_shape)
+        else:
+            mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)
 
         # define the covariance module for the numeric dimensions
         base_covar_module = self.kernel_factory(
diff --git a/baybe/surrogates/gaussian_process/prior_modules.py b/baybe/surrogates/gaussian_process/prior_modules.py
new file mode 100644
index 0000000000..d718b3c7a6
--- /dev/null
+++ b/baybe/surrogates/gaussian_process/prior_modules.py
@@ -0,0 +1,55 @@
+"""Prior modules for Gaussian process transfer learning."""
+
+from __future__ import annotations
+
+from copy import deepcopy
+from typing import Any
+
+import gpytorch
+import torch
+from botorch.models import SingleTaskGP
+from torch import Tensor
+
+
+class PriorMean(gpytorch.means.Mean):
+    """GPyTorch mean module using a trained GP as prior mean.
+
+    This mean module wraps a trained Gaussian Process and uses its predictions
+    as the mean function for another GP.
+
+    Args:
+        gp: Trained Gaussian Process to use as mean function.
+        batch_shape: Batch shape for the mean module.
+        **kwargs: Additional keyword arguments.
+    """
+
+    def __init__(
+        self, gp: SingleTaskGP, batch_shape: torch.Size = torch.Size(), **kwargs: Any
+    ) -> None:
+        super().__init__()
+
+        # Deep copy and freeze the GP
+        self.gp: SingleTaskGP = deepcopy(gp)
+        self.batch_shape: torch.Size = batch_shape
+
+        # Freeze the copied parameters so they are never trained
+        for param in self.gp.parameters():
+            param.requires_grad = False
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Compute the mean function using the wrapped GP.
+
+        Args:
+            x: Input tensor for which to compute the mean.
+
+        Returns:
+            Mean predictions from the wrapped GP.
+        """
+        self.gp.eval()
+        self.gp.likelihood.eval()
+        with torch.no_grad(), gpytorch.settings.fast_pred_var():
+            mean = self.gp(x).mean.detach()
+
+        # Broadcast (not reshape) across any training batch dimensions:
+        # reshape fails when batch_shape adds dimensions, since element
+        # counts then differ; expand performs the intended broadcasting.
+        target_shape = torch.broadcast_shapes(self.batch_shape, x.shape[:-1])
+        return mean.expand(target_shape)