
Commit 7e93bfd

saitcakmak authored and facebook-github-bot committed
Fix posterior with observation noise in batched MTGP models (#2782)
Summary: A `posterior` call with `observation_noise=True` would fail for a fully Bayesian MTGP model before this change. This diff updates the logic that applies the observation noise so that it takes the model's batch shape into account. Differential Revision: D71643890
1 parent 9c1c759 commit 7e93bfd
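
For context, here is a minimal sketch of the call path this commit fixes: a fully Bayesian multi-task GP evaluated with `observation_noise=True` on batched test points. The model class, fitting helper, and argument names below are not taken from this diff; they follow the BoTorch API as assumed, and shapes and hyperparameters are illustrative only.

```python
# Sketch only: exercises posterior(..., observation_noise=True) on a fully
# Bayesian (batched) multi-task GP. Names/arguments reflect the assumed
# current BoTorch API, not this diff; shapes are illustrative.
import torch

from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP

num_tasks, d, n = 2, 4, 20
train_X = torch.rand(n, d, dtype=torch.double)
task_col = torch.randint(num_tasks, (n, 1)).to(torch.double)
train_X = torch.cat([train_X, task_col], dim=-1)  # last column = task index
train_Y = torch.randn(n, 1, dtype=torch.double)
train_Yvar = torch.full_like(train_Y, 0.05)  # known observation noise

model = SaasFullyBayesianMultiTaskGP(
    train_X=train_X, train_Y=train_Y, train_Yvar=train_Yvar, task_feature=d
)
fit_fully_bayesian_model_nuts(
    model, warmup_steps=32, num_samples=32, thinning=16, disable_progbar=True
)

# Batched test points (test batch 3 x q=5) that include the task column.
# Per the commit summary, this call failed before the change because the
# per-task noise lookup did not account for the model's batch shape.
test_X = torch.rand(3, 5, d + 1, dtype=torch.double)
test_X[..., -1] = torch.randint(num_tasks, (3, 5)).to(torch.double)
posterior = model.posterior(test_X, observation_noise=True)
print(posterior.variance.shape)
```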

File tree

2 files changed (+42, -10 lines)


botorch/models/gpytorch.py (+15, -7)

@@ -47,7 +47,7 @@
 from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
 from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
 from linear_operator.operators import BlockDiagLinearOperator, CatLinearOperator
-from torch import Tensor
+from torch import broadcast_shapes, Tensor
 
 if TYPE_CHECKING:
     from botorch.posteriors.posterior_list import PosteriorList  # pragma: no cover
@@ -858,15 +858,23 @@ def _apply_noise(
         # get task features for training points
         train_task_features = self.train_inputs[0][..., self._task_feature]
         train_task_features = self._map_tasks(train_task_features).long()
-        noise_by_task = torch.zeros(self.num_tasks, dtype=X.dtype, device=X.device)
+        noise_by_task = torch.zeros(
+            *self.batch_shape, self.num_tasks, dtype=X.dtype, device=X.device
+        )
         for task_feature in unique_test_task_features:
             mask = train_task_features == task_feature
-            noise_by_task[task_feature] = self.likelihood.noise[mask].mean(
-                dim=-1, keepdim=True
-            )
+            noise_by_task[..., task_feature] = self.likelihood.noise[
+                ..., mask
+            ].mean(dim=-1)
         # noise_shape is `broadcast(test_batch_shape, model.batch_shape) x q`
-        noise_shape = X.shape[:-1]
-        observation_noise = noise_by_task[test_task_features].expand(noise_shape)
+        noise_shape = (
+            broadcast_shapes(X.shape[:-2], self.batch_shape) + X.shape[-2:-1]
+        )
+        # Expand and gather ensures we pick correct noise dimensions for
+        # batch evaluations of batched models.
+        observation_noise = noise_by_task.expand(*noise_shape[:-1], -1).gather(
+            dim=-1, index=test_task_features.expand(noise_shape)
+        )
         return self.likelihood(
             mvn,
             X,

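To make the new shape handling concrete, here is a small standalone sketch of the expand-and-gather pattern used above, written in plain PyTorch with invented names and shapes (it mirrors the logic in `_apply_noise`, but is not the library code):

```python
import torch
from torch import broadcast_shapes

# Illustrative shapes (all names invented for this sketch): a model batch of
# 2 MCMC samples, 3 tasks, and a test batch of shape (4, 1) with q = 5 points
# of dimension d = 6 each.
model_batch = torch.Size([2])
num_tasks, q = 3, 5
X = torch.rand(4, 1, q, 6)  # test_batch_shape x q x d

# Per-task noise with a leading model batch dimension: model_batch x num_tasks.
noise_by_task = torch.rand(*model_batch, num_tasks)

# Task index of each test point (broadcastable to test_batch_shape x q).
test_task_features = torch.randint(num_tasks, (q,))

# noise_shape is broadcast(test_batch_shape, model_batch) x q.
noise_shape = broadcast_shapes(X.shape[:-2], model_batch) + X.shape[-2:-1]
print(noise_shape)  # torch.Size([4, 2, 5])

# Expand the per-task noise over the broadcast batch shape, then gather the
# entry matching each test point's task index along the last dimension.
observation_noise = noise_by_task.expand(*noise_shape[:-1], -1).gather(
    dim=-1, index=test_task_features.expand(noise_shape)
)
print(observation_noise.shape)  # torch.Size([4, 2, 5])
```

The gather is what the old flat indexing (`noise_by_task[test_task_features]`) could not do once `noise_by_task` carried a model batch dimension: it selects the per-task entry independently within each batch element.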
test/models/test_fully_bayesian_multitask.py (+27, -3)

@@ -282,9 +282,33 @@ def test_fit_model(
         self.assertIsInstance(posterior, GaussianMixturePosterior)
         self.assertIsInstance(posterior, GaussianMixturePosterior)
 
-        test_X = torch.rand(*batch_shape, d, **tkwargs)
-        posterior = model.posterior(test_X)
-        self.assertIsInstance(posterior, GaussianMixturePosterior)
+        # Test with observation noise.
+        # Add task index to have variability in added noise.
+        task_idcs = torch.tensor(
+            [[i % self.num_tasks] for i in range(batch_shape[-1])]
+        )
+        test_X_w_task = torch.cat(
+            [test_X, task_idcs.expand(*batch_shape, 1)], dim=-1
+        )
+        noise_free_posterior = model.posterior(X=test_X_w_task)
+        noisy_posterior = model.posterior(X=test_X_w_task, observation_noise=True)
+        self.assertAllClose(noisy_posterior.mean, noise_free_posterior.mean)
+        added_noise = noisy_posterior.variance - noise_free_posterior.variance
+        self.assertTrue(torch.all(added_noise > 0.0))
+        if infer_noise is False:
+            # Check that correct noise was added.
+            train_tasks = train_X[..., 4]
+            mean_noise_by_task = torch.tensor(
+                [
+                    train_Yvar[train_tasks == i].mean(dim=0)
+                    for i in train_tasks.unique(sorted=True)
+                ]
+            )
+            expected_noise = mean_noise_by_task[task_idcs]
+            self.assertAllClose(
+                added_noise, expected_noise.expand_as(added_noise), atol=1e-4
+            )
+
         # Mean/variance
         expected_shape = (
             *batch_shape[: MCMC_DIM + 2],

0 commit comments
