[RFC] Refactor Input Transforms #1176

Closed · wants to merge 1 commit

2 changes: 1 addition & 1 deletion botorch/acquisition/multi_objective/monte_carlo.py
@@ -42,7 +42,7 @@
 from botorch.exceptions.errors import UnsupportedError
 from botorch.exceptions.warnings import BotorchWarning
 from botorch.models.model import Model
-from botorch.models.transforms.input import InputPerturbation
+from botorch.models.transforms.input_augmentation import InputPerturbation
 from botorch.posteriors import DeterministicPosterior
 from botorch.posteriors.posterior import Posterior
 from botorch.sampling.samplers import MCSampler, SobolQMCNormalSampler
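A quick illustration of the import move above: `InputPerturbation` keeps its behavior and constructor, only its home module changes to the new `input_augmentation` namespace. A minimal sketch, assuming the module path from this diff (calling `.transform` directly to sidestep the train/eval gating of `forward`):

```python
import torch
from botorch.models.transforms.input_augmentation import InputPerturbation  # new path per this diff

# Each candidate point gets evaluated under a fixed set of input perturbations.
perturbation_set = torch.tensor([[0.0, 0.0], [0.05, 0.0], [0.0, 0.05]])
augment = InputPerturbation(perturbation_set=perturbation_set)

X = torch.rand(4, 2)          # q=4 candidates in d=2
X_aug = augment.transform(X)  # 12 x 2: each candidate repeated once per perturbation
```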
51 changes: 29 additions & 22 deletions botorch/models/gp_regression.py
@@ -16,8 +16,9 @@
 from botorch import settings
 from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
 from botorch.models.transforms.input import InputTransform
+from botorch.models.transforms.input_augmentation import InputAugmentationTransform
 from botorch.models.transforms.outcome import Log, OutcomeTransform
-from botorch.models.utils import fantasize as fantasize_flag, validate_input_scaling
+from botorch.models.utils import validate_input_scaling
 from botorch.sampling.samplers import MCSampler
 from botorch.utils.containers import TrainingData
 from gpytorch.constraints.constraints import GreaterThan
@@ -70,6 +71,7 @@ def __init__(
         mean_module: Optional[Mean] = None,
         outcome_transform: Optional[OutcomeTransform] = None,
         input_transform: Optional[InputTransform] = None,
+        input_augmentation_transform: Optional[InputAugmentationTransform] = None,
     ) -> None:
         r"""A single-task exact GP model.

@@ -88,6 +90,8 @@ def __init__(
                 `.posterior` on the model will be on the original scale).
             input_transform: An input transform that is applied in the model's
                 forward pass.
+            input_augmentation_transform: An input augmentation transform that is
+                applied in the `posterior` call.

         Example:
             >>> train_X = torch.rand(20, 2)
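To make the new keyword concrete, here is a hedged construction sketch; the `input_augmentation_transform` argument and the `input_augmentation` module path are what this diff adds, and the `model.posterior(...)` call at the end assumes the base-class wrapper this RFC implies (see the sketch in the gpytorch.py section below):

```python
import torch
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.transforms.input_augmentation import InputPerturbation

train_X = torch.rand(20, 2)
train_Y = torch.sin(train_X).sum(dim=-1, keepdim=True)

model = SingleTaskGP(
    train_X,
    train_Y,
    # Augmentation is applied in the `posterior` call, not in `forward`.
    input_augmentation_transform=InputPerturbation(
        perturbation_set=0.05 * torch.randn(8, 2)
    ),
)
posterior = model.posterior(torch.rand(5, 2))
```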
@@ -148,11 +152,11 @@ def __init__(
             self.outcome_transform = outcome_transform
         if input_transform is not None:
             self.input_transform = input_transform
+        if input_augmentation_transform is not None:
+            self.input_augmentation_transform = input_augmentation_transform
         self.to(train_X)

-    def forward(self, x: Tensor) -> MultivariateNormal:
-        if self.training:
-            x = self.transform_inputs(x)
+    def _forward(self, x: Tensor) -> MultivariateNormal:
         mean_x = self.mean_module(x)
         covar_x = self.covar_module(x)
         return MultivariateNormal(mean_x, covar_x)
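The `forward` → `_forward` rename removes the training-time `transform_inputs` call from every subclass. The natural reading (a sketch of the implied shared hook, not code contained in this PR) is that a single base-class `forward` owns the transform and delegates the GP computation:

```python
# Hypothetical base-class wrapper implied by the rename (not part of this diff).
def forward(self, x: Tensor) -> MultivariateNormal:
    if self.training:
        # Input transforms apply here at train time; at inference time they
        # move into the `posterior` path (see the gpytorch.py changes below).
        x = self.transform_inputs(x)
    return self._forward(x)
```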
@@ -191,6 +195,7 @@ def __init__(
         mean_module: Optional[Mean] = None,
         outcome_transform: Optional[OutcomeTransform] = None,
         input_transform: Optional[InputTransform] = None,
+        input_augmentation_transform: Optional[InputAugmentationTransform] = None,
         **kwargs: Any,
     ) -> None:
         r"""A single-task exact GP model using fixed noise levels.
@@ -210,6 +215,8 @@ def __init__(
                 `.posterior` on the model will be on the original scale).
             input_transform: An input transform that is applied in the model's
                 forward pass.
+            input_augmentation_transform: An input augmentation transform that is
+                applied in the `posterior` call.

         Example:
             >>> train_X = torch.rand(20, 2)
@@ -262,7 +269,8 @@ def __init__(
             self.input_transform = input_transform
         if outcome_transform is not None:
             self.outcome_transform = outcome_transform
-
+        if input_augmentation_transform is not None:
+            self.input_augmentation_transform = input_augmentation_transform
         self.to(train_X)

     def fantasize(
@@ -298,24 +306,19 @@ def fantasize(
             The constructed fantasy model.
         """
         propagate_grads = kwargs.pop("propagate_grads", False)
-        with fantasize_flag():
-            with settings.propagate_grads(propagate_grads):
-                post_X = self.posterior(
-                    X, observation_noise=observation_noise, **kwargs
-                )
-            Y_fantasized = sampler(post_X)  # num_fantasies x batch_shape x n' x m
-            # Use the mean of the previous noise values (TODO: be smarter here).
-            # noise should be batch_shape x q x m when X is batch_shape x q x d, and
-            # Y_fantasized is num_fantasies x batch_shape x q x m.
-            noise_shape = Y_fantasized.shape[1:]
-            noise = self.likelihood.noise.mean().expand(noise_shape)
-            return self.condition_on_observations(
-                X=self.transform_inputs(X), Y=Y_fantasized, noise=noise
-            )
+        with settings.propagate_grads(propagate_grads):
+            post_X = self._posterior(X, observation_noise=observation_noise, **kwargs)
+        Y_fantasized = sampler(post_X)  # num_fantasies x batch_shape x n' x m
+        # Use the mean of the previous noise values (TODO: be smarter here).
+        # noise should be batch_shape x q x m when X is batch_shape x q x d, and
+        # Y_fantasized is num_fantasies x batch_shape x q x m.
+        noise_shape = Y_fantasized.shape[1:]
+        noise = self.likelihood.noise.mean().expand(noise_shape)
+        return self.condition_on_observations(
+            X=self.transform_inputs(X), Y=Y_fantasized, noise=noise
+        )

-    def forward(self, x: Tensor) -> MultivariateNormal:
-        if self.training:
-            x = self.transform_inputs(x)
+    def _forward(self, x: Tensor) -> MultivariateNormal:
         mean_x = self.mean_module(x)
         covar_x = self.covar_module(x)
         return MultivariateNormal(mean_x, covar_x)
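`fantasize` now routes through `self._posterior` and drops the `fantasize_flag` context manager (its import is removed in the hunk at the top of this file). The public usage is unchanged; a hedged sketch, with `model` assumed to be a `FixedNoiseGP` and the sampler path taken from this PR's import block:

```python
import torch
from botorch.sampling.samplers import SobolQMCNormalSampler

sampler = SobolQMCNormalSampler(num_samples=16)
X_new = torch.rand(4, 2)  # q=4 new candidate points in d=2
fantasy_model = model.fantasize(X=X_new, sampler=sampler, observation_noise=True)
```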
@@ -370,6 +373,7 @@ def __init__(
         train_Yvar: Tensor,
         outcome_transform: Optional[OutcomeTransform] = None,
         input_transform: Optional[InputTransform] = None,
+        input_augmentation_transform: Optional[InputAugmentationTransform] = None,
     ) -> None:
         r"""A single-task exact GP model using a heteroskedastic noise model.

@@ -386,6 +390,8 @@
                 variances, which will happen after this transform is applied.
             input_transform: An input transform that is applied in the model's
                 forward pass.
+            input_augmentation_transform: An input augmentation transform that is
+                applied in the `posterior` call.

         Example:
             >>> train_X = torch.rand(20, 2)
@@ -419,6 +425,7 @@ def __init__(
             train_Y=train_Y,
             likelihood=likelihood,
             input_transform=input_transform,
+            input_augmentation_transform=input_augmentation_transform,
         )
         self.register_added_loss_term("noise_added_loss")
         self.update_added_loss_term(
29 changes: 5 additions & 24 deletions botorch/models/gpytorch.py
@@ -117,11 +117,10 @@ def num_outputs(self) -> int:
         r"""The number of outputs of the model."""
         return self._num_outputs

-    def posterior(
+    def _posterior(
         self,
         X: Tensor,
         observation_noise: Union[bool, Tensor] = False,
-        posterior_transform: Optional[PosteriorTransform] = None,
         **kwargs: Any,
     ) -> GPyTorchPosterior:
         r"""Computes the posterior over model outputs at the provided points.
@@ -133,17 +132,13 @@ def posterior(
             observation_noise: If True, add the observation noise from the
                 likelihood to the posterior. If a Tensor, use it directly as the
                 observation noise (must be of shape `(batch_shape) x q`).
-            posterior_transform: An optional PosteriorTransform.

         Returns:
             A `GPyTorchPosterior` object, representing a batch of `b` joint
             distributions over `q` points. Includes observation noise if
             specified.
         """
         self.eval()  # make sure model is in eval mode
-        # input transforms are applied at `posterior` in `eval` mode, and at
-        # `model.forward()` at the training time
-        X = self.transform_inputs(X)
         with gpt_posterior_settings():
             mvn = self(X)
             if observation_noise is not False:
@@ -158,8 +153,6 @@ def posterior(
         posterior = GPyTorchPosterior(mvn=mvn)
         if hasattr(self, "outcome_transform"):
             posterior = self.outcome_transform.untransform_posterior(posterior)
-        if posterior_transform is not None:
-            return posterior_transform(posterior)
         return posterior

     def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Model:
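The `posterior` → `_posterior` renames, together with the deleted `transform_inputs` and `posterior_transform` handling, only make sense with a single public wrapper sitting above them. This PR does not include that wrapper, so the following is purely my reading of where the removed logic would land, as a method-body sketch on the base `Model` class:

```python
# Hypothetical `Model.posterior` wrapper (assumed, not part of this diff).
def posterior(
    self,
    X: Tensor,
    posterior_transform: Optional[PosteriorTransform] = None,
    **kwargs: Any,
) -> Posterior:
    self.eval()  # subclass `_posterior` methods also set eval mode; harmless here
    # The input-transform call removed from each subclass `posterior` lands here...
    X = self.transform_inputs(X)
    # ...as does the new posterior-time augmentation (e.g. InputPerturbation).
    if hasattr(self, "input_augmentation_transform"):
        X = self.input_augmentation_transform(X)
    posterior = self._posterior(X, **kwargs)
    # The `posterior_transform` handling removed from each subclass lands here too.
    if posterior_transform is not None:
        posterior = posterior_transform(posterior)
    return posterior
```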
@@ -301,12 +294,11 @@ def _transform_tensor_args(
         )
         return X, Y.squeeze(-1), None if Yvar is None else Yvar.squeeze(-1)

-    def posterior(
+    def _posterior(
         self,
         X: Tensor,
         output_indices: Optional[List[int]] = None,
         observation_noise: Union[bool, Tensor] = False,
-        posterior_transform: Optional[PosteriorTransform] = None,
         **kwargs: Any,
     ) -> GPyTorchPosterior:
         r"""Computes the posterior over model outputs at the provided points.
@@ -323,17 +315,13 @@ def posterior(
             observation_noise: If True, add the observation noise from the
                 likelihood to the posterior. If a Tensor, use it directly as the
                 observation noise (must be of shape `(batch_shape) x q x m`).
-            posterior_transform: An optional PosteriorTransform.

         Returns:
             A `GPyTorchPosterior` object, representing `batch_shape` joint
             distributions over `q` points and the outputs selected by
             `output_indices` each. Includes observation noise if specified.
         """
         self.eval()  # make sure model is in eval mode
-        # input transforms are applied at `posterior` in `eval` mode, and at
-        # `model.forward()` at the training time
-        X = self.transform_inputs(X)
         with gpt_posterior_settings():
             # insert a dimension for the output dimension
             if self._num_outputs > 1:
@@ -369,8 +357,6 @@ def posterior(
         posterior = GPyTorchPosterior(mvn=mvn)
         if hasattr(self, "outcome_transform"):
             posterior = self.outcome_transform.untransform_posterior(posterior)
-        if posterior_transform is not None:
-            return posterior_transform(posterior)
         return posterior

     def condition_on_observations(
@@ -549,6 +535,8 @@ def posterior(
                 by `output_indices` each. Includes measurement noise if
                 `observation_noise` is specified.
         """
+        # TODO: Not sure if this needs special handling or is good with a `_`.
+        # Leaving untouched for now.
         self.eval()  # make sure model is in eval mode
         # input transforms are applied at `posterior` in `eval` mode, and at
         # `model.forward()` at the training time
@@ -622,12 +610,11 @@ class MultiTaskGPyTorchModel(GPyTorchModel, ABC):
     "long-format" multi-task GP in the style of `MultiTaskGP`.
     """

-    def posterior(
+    def _posterior(
         self,
         X: Tensor,
         output_indices: Optional[List[int]] = None,
         observation_noise: Union[bool, Tensor] = False,
-        posterior_transform: Optional[PosteriorTransform] = None,
         **kwargs: Any,
     ) -> GPyTorchPosterior:
         r"""Computes the posterior over model outputs at the provided points.
@@ -644,7 +631,6 @@ def posterior(
             observation_noise: If True, add observation noise from the respective
                 likelihoods. If a Tensor, specifies the observation noise levels
                 to add.
-            posterior_transform: An optional PosteriorTransform.

         Returns:
             A `GPyTorchPosterior` object, representing `batch_shape` joint
@@ -663,9 +649,6 @@ def posterior(
         X_full = _make_X_full(X=X, output_indices=output_indices, tf=self._task_feature)

         self.eval()  # make sure model is in eval mode
-        # input transforms are applied at `posterior` in `eval` mode, and at
-        # `model.forward()` at the training time
-        X_full = self.transform_inputs(X_full)
         with gpt_posterior_settings():
             mvn = self(X_full)
             if observation_noise is not False:
@@ -685,6 +668,4 @@ def posterior(
         posterior = GPyTorchPosterior(mvn=mtmvn)
         if hasattr(self, "outcome_transform"):
             posterior = self.outcome_transform.untransform_posterior(posterior)
-        if posterior_transform is not None:
-            return posterior_transform(posterior)
         return posterior