From d04fb4a5bd4b637bf473fcca5461c22d6e204328 Mon Sep 17 00:00:00 2001
From: Sait Cakmak <saitcakmak@meta.com>
Date: Mon, 24 Mar 2025 19:16:19 -0700
Subject: [PATCH] Fix posterior with observation noise in batched MTGP models
 (#2782)

Summary:
Pull Request resolved: https://github.com/pytorch/botorch/pull/2782

A posterior call with `observation_noise=True` would fail with a fully
Bayesian MTGP model before this change. This diff updates the logic that
applies the noise to take the batch shape into account. A usage sketch of
the fixed call pattern follows the diff.

Differential Revision: D71643890
---
 botorch/models/gpytorch.py                   | 22 +++++++++-----
 test/models/test_fully_bayesian_multitask.py | 32 ++++++++++++++++++--
 2 files changed, 44 insertions(+), 10 deletions(-)

diff --git a/botorch/models/gpytorch.py b/botorch/models/gpytorch.py
index b6b490125f..2e157ef5dc 100644
--- a/botorch/models/gpytorch.py
+++ b/botorch/models/gpytorch.py
@@ -47,7 +47,7 @@
 from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
 from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
 from linear_operator.operators import BlockDiagLinearOperator, CatLinearOperator
-from torch import Tensor
+from torch import broadcast_shapes, Tensor
 
 if TYPE_CHECKING:
     from botorch.posteriors.posterior_list import PosteriorList  # pragma: no cover
@@ -858,15 +858,23 @@ def _apply_noise(
             # get task features for training points
             train_task_features = self.train_inputs[0][..., self._task_feature]
             train_task_features = self._map_tasks(train_task_features).long()
-            noise_by_task = torch.zeros(self.num_tasks, dtype=X.dtype, device=X.device)
+            noise_by_task = torch.zeros(
+                *self.batch_shape, self.num_tasks, dtype=X.dtype, device=X.device
+            )
             for task_feature in unique_test_task_features:
                 mask = train_task_features == task_feature
-                noise_by_task[task_feature] = self.likelihood.noise[mask].mean(
-                    dim=-1, keepdim=True
-                )
+                noise_by_task[..., task_feature] = self.likelihood.noise[
+                    ..., mask
+                ].mean(dim=-1)
             # noise_shape is `broadcast(test_batch_shape, model.batch_shape) x q`
-            noise_shape = X.shape[:-1]
-            observation_noise = noise_by_task[test_task_features].expand(noise_shape)
+            noise_shape = (
+                broadcast_shapes(X.shape[:-2], self.batch_shape) + X.shape[-2:-1]
+            )
+            # Expand and gather ensures we pick correct noise dimensions for
+            # batch evaluations of batched models.
+            observation_noise = noise_by_task.expand(*noise_shape[:-1], -1).gather(
+                dim=-1, index=test_task_features.expand(noise_shape)
+            )
             return self.likelihood(
                 mvn,
                 X,
diff --git a/test/models/test_fully_bayesian_multitask.py b/test/models/test_fully_bayesian_multitask.py
index 9c232e2a19..23d99493f7 100644
--- a/test/models/test_fully_bayesian_multitask.py
+++ b/test/models/test_fully_bayesian_multitask.py
@@ -282,9 +282,35 @@ def test_fit_model(
             self.assertIsInstance(posterior, GaussianMixturePosterior)
             self.assertIsInstance(posterior, GaussianMixturePosterior)
 
-            test_X = torch.rand(*batch_shape, d, **tkwargs)
-            posterior = model.posterior(test_X)
-            self.assertIsInstance(posterior, GaussianMixturePosterior)
+            # Test with observation noise.
+            # Add task index to have variability in added noise.
+            task_idcs = torch.tensor(
+                [[i % self.num_tasks] for i in range(batch_shape[-1])],
+                device=self.device,
+            )
+            test_X_w_task = torch.cat(
+                [test_X, task_idcs.expand(*batch_shape, 1)], dim=-1
+            )
+            noise_free_posterior = model.posterior(X=test_X_w_task)
+            noisy_posterior = model.posterior(X=test_X_w_task, observation_noise=True)
+            self.assertAllClose(noisy_posterior.mean, noise_free_posterior.mean)
+            added_noise = noisy_posterior.variance - noise_free_posterior.variance
+            self.assertTrue(torch.all(added_noise > 0.0))
+            if infer_noise is False:
+                # Check that correct noise was added.
+                train_tasks = train_X[..., 4]
+                mean_noise_by_task = torch.tensor(
+                    [
+                        train_Yvar[train_tasks == i].mean(dim=0)
+                        for i in train_tasks.unique(sorted=True)
+                    ],
+                    device=self.device,
+                )
+                expected_noise = mean_noise_by_task[task_idcs]
+                self.assertAllClose(
+                    added_noise, expected_noise.expand_as(added_noise), atol=1e-4
+                )
+
             # Mean/variance
             expected_shape = (
                 *batch_shape[: MCMC_DIM + 2],
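
Usage sketch of the call pattern this change fixes (not part of the patch): a
minimal example, assuming the public BoTorch entry points
`SaasFullyBayesianMultiTaskGP`, `fit_fully_bayesian_model_nuts`, and
`Model.posterior`; the problem sizes, tensor shapes, and MCMC settings below
are illustrative only.

    # A minimal sketch; sizes, shapes, and MCMC settings are illustrative only.
    import torch

    from botorch.fit import fit_fully_bayesian_model_nuts
    from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP

    n, d, num_tasks = 20, 3, 2
    # Training inputs with the task index appended as the last column.
    train_X = torch.rand(n, d, dtype=torch.double)
    task_idcs = (torch.arange(n) % num_tasks).unsqueeze(-1).to(train_X)
    train_X = torch.cat([train_X, task_idcs], dim=-1)
    train_Y = torch.randn(n, 1, dtype=torch.double)
    # Known observation noise, so the FixedNoiseGaussianLikelihood path touched
    # by this diff is exercised.
    train_Yvar = 0.01 + 0.01 * torch.rand_like(train_Y)

    model = SaasFullyBayesianMultiTaskGP(
        train_X=train_X,
        train_Y=train_Y,
        train_Yvar=train_Yvar,
        task_feature=d,  # column index of the task feature in train_X
    )
    # Small MCMC budget just to obtain a batched (ensemble) model quickly.
    fit_fully_bayesian_model_nuts(
        model, warmup_steps=32, num_samples=32, thinning=16, disable_progbar=True
    )

    # Batched test points of shape 2 x 4 x (d + 1), again with a task column, so
    # the test batch shape has to broadcast against the model's MCMC batch shape.
    test_X = torch.rand(2, 4, d, dtype=torch.double)
    test_tasks = (torch.arange(4) % num_tasks).reshape(1, 4, 1).expand(2, 4, 1)
    test_X = torch.cat([test_X, test_tasks.to(test_X)], dim=-1)

    noise_free = model.posterior(X=test_X)
    # This call failed for batched fully Bayesian MTGPs before this change.
    noisy = model.posterior(X=test_X, observation_noise=True)
    # Observation noise leaves the posterior mean unchanged and inflates the
    # variance.
    assert torch.allclose(noisy.mean, noise_free.mean)
    assert (noisy.variance > noise_free.variance).all()

The expand-and-gather added in `_apply_noise` is what lets the per-task noise
lookup broadcast the test batch shape (2 x 4 above) against the model's MCMC
batch dimension instead of assuming an unbatched model.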