From d04fb4a5bd4b637bf473fcca5461c22d6e204328 Mon Sep 17 00:00:00 2001
From: Sait Cakmak <saitcakmak@meta.com>
Date: Mon, 24 Mar 2025 19:16:19 -0700
Subject: [PATCH] Fix posterior with observation noise in batched MTGP models
 (#2782)

Summary:
Pull Request resolved: https://github.com/pytorch/botorch/pull/2782

Before this change, calling `posterior` with `observation_noise=True` failed on fully Bayesian MTGP models. This diff updates the logic that applies the noise (`_apply_noise`) to take the model's batch shape into account.

Differential Revision: D71643890
---
 botorch/models/gpytorch.py                   | 22 +++++++++-----
 test/models/test_fully_bayesian_multitask.py | 32 ++++++++++++++++++--
 2 files changed, 44 insertions(+), 10 deletions(-)

diff --git a/botorch/models/gpytorch.py b/botorch/models/gpytorch.py
index b6b490125f..2e157ef5dc 100644
--- a/botorch/models/gpytorch.py
+++ b/botorch/models/gpytorch.py
@@ -47,7 +47,7 @@
 from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
 from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
 from linear_operator.operators import BlockDiagLinearOperator, CatLinearOperator
-from torch import Tensor
+from torch import broadcast_shapes, Tensor
 
 if TYPE_CHECKING:
     from botorch.posteriors.posterior_list import PosteriorList  # pragma: no cover
@@ -858,15 +858,23 @@ def _apply_noise(
             # get task features for training points
             train_task_features = self.train_inputs[0][..., self._task_feature]
             train_task_features = self._map_tasks(train_task_features).long()
-            noise_by_task = torch.zeros(self.num_tasks, dtype=X.dtype, device=X.device)
+            noise_by_task = torch.zeros(
+                *self.batch_shape, self.num_tasks, dtype=X.dtype, device=X.device
+            )
             for task_feature in unique_test_task_features:
                 mask = train_task_features == task_feature
-                noise_by_task[task_feature] = self.likelihood.noise[mask].mean(
-                    dim=-1, keepdim=True
-                )
+                noise_by_task[..., task_feature] = self.likelihood.noise[
+                    ..., mask
+                ].mean(dim=-1)
             # noise_shape is `broadcast(test_batch_shape, model.batch_shape) x q`
-            noise_shape = X.shape[:-1]
-            observation_noise = noise_by_task[test_task_features].expand(noise_shape)
+            noise_shape = (
+                broadcast_shapes(X.shape[:-2], self.batch_shape) + X.shape[-2:-1]
+            )
+            # Expand and gather ensures we pick correct noise dimensions for
+            # batch evaluations of batched models.
+            observation_noise = noise_by_task.expand(*noise_shape[:-1], -1).gather(
+                dim=-1, index=test_task_features.expand(noise_shape)
+            )
             return self.likelihood(
                 mvn,
                 X,
diff --git a/test/models/test_fully_bayesian_multitask.py b/test/models/test_fully_bayesian_multitask.py
index 9c232e2a19..23d99493f7 100644
--- a/test/models/test_fully_bayesian_multitask.py
+++ b/test/models/test_fully_bayesian_multitask.py
@@ -282,9 +282,35 @@ def test_fit_model(
             self.assertIsInstance(posterior, GaussianMixturePosterior)
             self.assertIsInstance(posterior, GaussianMixturePosterior)
 
-            test_X = torch.rand(*batch_shape, d, **tkwargs)
-            posterior = model.posterior(test_X)
-            self.assertIsInstance(posterior, GaussianMixturePosterior)
+            # Test with observation noise.
+            # Add task index to have variability in added noise.
+            task_idcs = torch.tensor(
+                [[i % self.num_tasks] for i in range(batch_shape[-1])],
+                device=self.device,
+            )
+            test_X_w_task = torch.cat(
+                [test_X, task_idcs.expand(*batch_shape, 1)], dim=-1
+            )
+            noise_free_posterior = model.posterior(X=test_X_w_task)
+            noisy_posterior = model.posterior(X=test_X_w_task, observation_noise=True)
+            self.assertAllClose(noisy_posterior.mean, noise_free_posterior.mean)
+            added_noise = noisy_posterior.variance - noise_free_posterior.variance
+            self.assertTrue(torch.all(added_noise > 0.0))
+            if infer_noise is False:
+                # Check that correct noise was added.
+                train_tasks = train_X[..., 4]
+                mean_noise_by_task = torch.tensor(
+                    [
+                        train_Yvar[train_tasks == i].mean(dim=0)
+                        for i in train_tasks.unique(sorted=True)
+                    ],
+                    device=self.device,
+                )
+                expected_noise = mean_noise_by_task[task_idcs]
+                self.assertAllClose(
+                    added_noise, expected_noise.expand_as(added_noise), atol=1e-4
+                )
+
             # Mean/variance
             expected_shape = (
                 *batch_shape[: MCMC_DIM + 2],