Enable fantasy models for multitask GPs Reborn #2317
Changes from all commits: 969a9ec, b9dc064, 6c2fd48, f50a9f8, 2f7a3cf, 28ee4ca
@@ -22,6 +22,8 @@
from torch import Tensor

from .. import settings
from ..distributions import MultitaskMultivariateNormal
from ..lazy import LazyEvaluatedKernelTensor
from ..utils.memoize import add_to_cache, cached, clear_cache_hook, pop_from_cache
@@ -134,16 +136,28 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
            A `DefaultPredictionStrategy` model with `n + m` training examples, where the `m` fantasy examples have
            been added and all test-time caches have been updated.
        """
        if not isinstance(full_output, MultitaskMultivariateNormal):
            target_batch_shape = targets.shape[:-1]
        else:
            target_batch_shape = targets.shape[:-2]

        full_mean, full_covar = full_output.mean, full_output.lazy_covariance_matrix

        batch_shape = full_inputs[0].shape[:-2]

        full_mean = full_mean.view(*batch_shape, -1)
        num_train = self.num_train

        if isinstance(full_output, MultitaskMultivariateNormal):
            num_tasks = full_output.event_shape[-1]
            full_mean = full_mean.view(*batch_shape, -1, num_tasks)
            fant_mean = full_mean[..., (num_train // num_tasks) :, :]
            full_targets = full_targets.view(*target_batch_shape, -1)
        else:
            full_mean = full_mean.view(*batch_shape, -1)
            fant_mean = full_mean[..., num_train:]

        # Evaluate fant x train and fant x fant covariance matrices, leave train x train unevaluated.
        fant_fant_covar = full_covar[..., num_train:, num_train:]
        fant_mean = full_mean[..., num_train:]
        mvn = self.train_prior_dist.__class__(fant_mean, fant_fant_covar)
        fant_likelihood = self.likelihood.get_fantasy_likelihood(**kwargs)
        mvn_obs = fant_likelihood(mvn, inputs, **kwargs)
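A standalone shape sketch of the multitask slicing above, with made-up sizes. This is not gpytorch internals; it only illustrates why the fantasy rows start at `num_train // num_tasks` once the mean is viewed with one column per task (the division by num_tasks in the diff implies that num_train counts flattened (point, task) entries):

# Standalone shape sketch; all sizes are invented for illustration.
import torch

n_train, n_fant, num_tasks = 15, 1, 2
num_train = n_train * num_tasks  # flattened (point, task) count, as the division by num_tasks implies

# mean over train + fantasy points, viewed with one column per task
full_mean = torch.randn((n_train + n_fant) * num_tasks).view(-1, num_tasks)

# slicing rows at num_train // num_tasks keeps exactly the fantasy points
fant_mean = full_mean[..., (num_train // num_tasks):, :]
assert fant_mean.shape == (n_fant, num_tasks)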
@@ -209,6 +223,9 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
            new_root = BatchRepeatLinearOperator(DenseLinearOperator(new_root), repeat_shape)
            # no need to repeat the covar cache, broadcasting will do the right thing

        if isinstance(full_output, MultitaskMultivariateNormal):
            full_mean = full_mean.view(*target_batch_shape, -1, num_tasks).contiguous()

        # Create new DefaultPredictionStrategy object
        fant_strat = self.__class__(
            train_inputs=full_inputs,
@@ -285,7 +302,11 @@ def exact_predictive_mean(self, test_mean: Tensor, test_train_covar: LinearOpera
        # NOTE TO FUTURE SELF:
        # You **cannot** use addmv here, because test_train_covar may not actually be a non-lazy tensor even for an exact
        # GP, and using addmv requires you to to_dense test_train_covar, which is obviously a huge no-no!
        res = (test_train_covar @ self.mean_cache.unsqueeze(-1)).squeeze(-1)

        if len(self.mean_cache.shape) == 4:
            res = (test_train_covar @ self.mean_cache.squeeze(1).unsqueeze(-1)).squeeze(-1)
        else:
            res = (test_train_covar @ self.mean_cache.unsqueeze(-1)).squeeze(-1)
        res = res + test_mean

        return res

Review thread on the `if len(self.mean_cache.shape) == 4:` line:

[PR author] The main issue I have with this PR is this line. While working with a simple BO loop using BoTorch to test my code changes and observe the shapes of everything going through the code, I found that sometimes `self.mean_cache` comes through with four dimensions (an extra singleton dimension). These observed shapes are also why my unit test gives `new_y` the extra singleton dimensions.

[PR author] I'm curious if anyone has any thoughts about this!

[Reviewer] Do you have a sense for why this is? Is this some insufficient invalidation of the `mean_cache`?

[PR author] Honestly I'm not 100% sure. This is called from `get_fantasy_model` with

# The below are torch.tensors, I just show their dimensions
inputs = [torch.Size([1, 1, 3])]
targets = torch.Size([5, 1, 1, 4])

and, inside `get_fantasy_strategy`,

# The below are torch.tensors, I just show their dimensions
inputs = [torch.Size([1, 1, 3])]
targets = torch.Size([5, 1, 1, 4])
full_inputs = [torch.Size([1, 6, 3])]
full_targets = torch.Size([5, 1, 6, 4])
full_output = MultitaskMultivariateNormal(loc: torch.Size([1, 24]))

The acquisition function driving the loop is

scal_transf = ScalarizedPosteriorTransform(weights=torch.tensor([1.0] + [0.0] * dim, dtype=torch.double))
# Define qKG acquisition function
qKG = qKnowledgeGradient(model, posterior_transform=scal_transf, num_fantasies=5)

Hopefully this helps! I'm not sure what the expected behavior should be, but please let me know how I can help.

[Reviewer] I wonder if the extra 1 appearing is just a soft incompatibility we never noticed where BoTorch requires an explicit task dim for the labels, but we don't in gpytorch. Indeed, my default is usually to have a single-dim label vector, so when writing the code something like this could have slipped by me.
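To make the thread above concrete, here is a tiny self-contained sketch of the guarded matmul. The tensor sizes are invented to mimic the reported shapes (a leading num_fantasies dimension plus an extra singleton); this is not gpytorch code:

# Invented shapes, only for illustrating the 4-dimensional mean_cache case.
import torch

num_fantasies, n_train, n_test = 5, 6, 4
test_train_covar = torch.randn(num_fantasies, 1, n_test, n_train)
mean_cache = torch.randn(num_fantasies, 1, 1, n_train)  # the 4-dim case from the thread

# mirror of the guarded matmul: drop the extra singleton dim before the product
if len(mean_cache.shape) == 4:
    res = (test_train_covar @ mean_cache.squeeze(1).unsqueeze(-1)).squeeze(-1)
else:
    res = (test_train_covar @ mean_cache.unsqueeze(-1)).squeeze(-1)
assert res.shape == (num_fantasies, 1, n_test)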
@@ -0,0 +1,62 @@
#!/usr/bin/env python3

import unittest
from math import pi

import torch

import gpytorch
from gpytorch.distributions import MultitaskMultivariateNormal
from gpytorch.kernels import ScaleKernel, RBFKernelGrad
from gpytorch.likelihoods import MultitaskGaussianLikelihood
from gpytorch.means import ConstantMeanGrad
from gpytorch.test.base_test_case import BaseTestCase

# Simple training data
num_train_samples = 15
num_fantasies = 10
dim = 1
train_X = torch.linspace(0, 1, num_train_samples).reshape(-1, 1)
train_Y = torch.hstack([
    torch.sin(train_X * (2 * pi)).reshape(-1, 1),
    (2 * pi) * torch.cos(train_X * (2 * pi)).reshape(-1, 1),
])


class GPWithDerivatives(gpytorch.models.ExactGP):
    def __init__(self, train_X, train_Y):
        likelihood = MultitaskGaussianLikelihood(num_tasks=1 + dim)
        super().__init__(train_X, train_Y, likelihood)
        self.mean_module = ConstantMeanGrad()
        self.base_kernel = RBFKernelGrad()
        self.covar_module = ScaleKernel(self.base_kernel)
        self._num_outputs = 1 + dim

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultitaskMultivariateNormal(mean_x, covar_x)


class TestDerivativeGPFutures(BaseTestCase, unittest.TestCase):

    # Inspired by test_lanczos_fantasy_model
    def test_derivative_gp_futures(self):
        model = GPWithDerivatives(train_X, train_Y)
        mll = gpytorch.mlls.sum_marginal_log_likelihood.ExactMarginalLogLikelihood(model.likelihood, model)

        mll.train()
        mll.eval()

        # get a posterior to fill in caches
        model(torch.randn(num_train_samples).reshape(-1, 1))

        new_x = torch.randn((1, 1, dim))
        new_y = torch.randn((num_fantasies, 1, 1, 1 + dim))

        # just check that this can run without error
        model.get_fantasy_model(new_x, new_y)


if __name__ == "__main__":
    unittest.main()
[Reviewer] On the `full_mean.view(*target_batch_shape, -1, num_tasks).contiguous()` line: instead of `.view().contiguous()`, can also just use `reshape()` here.
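A quick illustration of that suggestion, with made-up shapes rather than the PR's code: for a contiguous input the two spellings agree, and `reshape()` additionally copies internally in cases where a plain `view()` would fail.

# Made-up shapes; shows that reshape() matches the view().contiguous() pattern.
import torch

full_mean = torch.randn(5, 1, 24)                     # hypothetical batch of flattened multitask means
num_tasks = 4
a = full_mean.view(5, 1, -1, num_tasks).contiguous()  # pattern currently in the diff
b = full_mean.reshape(5, 1, -1, num_tasks)            # reviewer's suggestion
assert torch.equal(a, b)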