|
47 | 47 | from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
|
48 | 48 | from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
|
49 | 49 | from linear_operator.operators import BlockDiagLinearOperator, CatLinearOperator
|
50 |
| -from torch import Tensor |
| 50 | +from torch import broadcast_shapes, Tensor |
51 | 51 |
|
52 | 52 | if TYPE_CHECKING:
|
53 | 53 | from botorch.posteriors.posterior_list import PosteriorList # pragma: no cover
|
@@ -858,15 +858,23 @@ def _apply_noise(
|
858 | 858 | # get task features for training points
|
859 | 859 | train_task_features = self.train_inputs[0][..., self._task_feature]
|
860 | 860 | train_task_features = self._map_tasks(train_task_features).long()
|
861 |
| - noise_by_task = torch.zeros(self.num_tasks, dtype=X.dtype, device=X.device) |
| 861 | + noise_by_task = torch.zeros( |
| 862 | + *self.batch_shape, self.num_tasks, dtype=X.dtype, device=X.device |
| 863 | + ) |
862 | 864 | for task_feature in unique_test_task_features:
|
863 | 865 | mask = train_task_features == task_feature
|
864 |
| - noise_by_task[task_feature] = self.likelihood.noise[mask].mean( |
865 |
| - dim=-1, keepdim=True |
866 |
| - ) |
| 866 | + noise_by_task[..., task_feature] = self.likelihood.noise[ |
| 867 | + ..., mask |
| 868 | + ].mean(dim=-1) |
867 | 869 | # noise_shape is `broadcast(test_batch_shape, model.batch_shape) x q`
|
868 |
| - noise_shape = X.shape[:-1] |
869 |
| - observation_noise = noise_by_task[test_task_features].expand(noise_shape) |
| 870 | + noise_shape = ( |
| 871 | + broadcast_shapes(X.shape[:-2], self.batch_shape) + X.shape[-2:-1] |
| 872 | + ) |
| 873 | + # Expand and gather ensures we pick correct noise dimensions for |
| 874 | + # batch evaluations of batched models. |
| 875 | + observation_noise = noise_by_task.expand(*noise_shape[:-1], -1).gather( |
| 876 | + dim=-1, index=test_task_features.expand(noise_shape) |
| 877 | + ) |
870 | 878 | return self.likelihood(
|
871 | 879 | mvn,
|
872 | 880 | X,
|
|
0 commit comments