[Bug] KeOps RBF kernel not properly equipped with prediction variance? #2566

Open
matthieudelsart opened this issue Aug 13, 2024 · 0 comments
matthieudelsart commented Aug 13, 2024

🐛 Bug

Unlike the standard RBF kernel, gpytorch.kernels.keops.RBFKernel raises an error when the predicted variance is requested. The failure seems similar to this one.
The same error occurs when requesting the standard deviation, confidence intervals, etc.

To reproduce

Code snippet to reproduce

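import torch
import gpytorch

# train_x and train_y are the training tensors, assumed to already be on the GPU
# (their construction is not shown in the report).

class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)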
        self.mean_module = gpytorch.means.LinearMean(input_size=train_x.size(-1))
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.keops.RBFKernel(ard_num_dims=train_x.size(-1)))

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

likelihood = gpytorch.likelihoods.GaussianLikelihood().cuda()  # I am actually using FixedNoiseGaussianLikelihood, but the same issue seems to occur with other likelihoods too
model = ExactGPModel(train_x, train_y, likelihood).cuda()

model.train()
likelihood.train()

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model)

training_iter = 10

for i in range(training_iter):
    optimizer.zero_grad()
    output = model(train_x)
    loss = -mll(output, train_y)
    loss.backward()
    print('Iter %d/%d - Loss: %.3f - Lengthscale_0: %.3f - Outputscale: %.3f' % (
        i + 1, training_iter, loss.item(),
        model.covar_module.base_kernel.lengthscale[0, 0].item(),
        model.covar_module.outputscale.item(),
    ))
    optimizer.step()

model.eval()
likelihood.eval()

with torch.no_grad(), gpytorch.settings.fast_pred_var():
    observed_pred = model.likelihood(model(train_x))
    pred_mean = observed_pred.mean
    pred_variance = observed_pred.variance

Stack trace/error message

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[16], line 45
     43 observed_pred = model.likelihood(model(train_x))
     44 pred_mean = observed_pred.mean
---> 45 pred_variance = observed_pred.variance

File /opt/conda/envs/pytorch/lib/python3.11/site-packages/gpytorch/distributions/multivariate_normal.py:309, in MultivariateNormal.variance(self)
    305 @property
    306 def variance(self) -> Tensor:
    307     if self.islazy:
    308         # overwrite this since torch MVN uses unbroadcasted_scale_tril for this
--> 309         diag = self.lazy_covariance_matrix.diagonal(dim1=-1, dim2=-2)
    310         diag = diag.view(diag.shape[:-1] + self._event_shape)
    311         variance = diag.expand(self._batch_shape + self._event_shape)

File /opt/conda/envs/pytorch/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py:1411, in LinearOperator.diagonal(self, offset, dim1, dim2)
   1409 elif not self.is_square:
   1410     raise RuntimeError("LinearOperator#diagonal is only implemented for square operators.")
-> 1411 return self._diagonal()

File /opt/conda/envs/pytorch/lib/python3.11/site-packages/linear_operator/operators/sum_linear_operator.py:29, in SumLinearOperator._diagonal(self)
     28 def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]:
---> 29     return sum(linear_op._diagonal().contiguous() for linear_op in self.linear_ops)

File /opt/conda/envs/pytorch/lib/python3.11/site-packages/linear_operator/operators/sum_linear_operator.py:29, in <genexpr>(.0)
     28 def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]:
---> 29     return sum(linear_op._diagonal().contiguous() for linear_op in self.linear_ops)

File /opt/conda/envs/pytorch/lib/python3.11/site-packages/linear_operator/operators/sum_linear_operator.py:29, in SumLinearOperator._diagonal(self)
     28 def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]:
---> 29     return sum(linear_op._diagonal().contiguous() for linear_op in self.linear_ops)

File /opt/conda/envs/pytorch/lib/python3.11/site-packages/linear_operator/operators/sum_linear_operator.py:29, in <genexpr>(.0)
     28 def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]:
---> 29     return sum(linear_op._diagonal().contiguous() for linear_op in self.linear_ops)

File /opt/conda/envs/pytorch/lib/python3.11/site-packages/gpytorch/utils/memoize.py:59, in _cached.<locals>.g(self, *args, **kwargs)
     57 kwargs_pkl = pickle.dumps(kwargs)
     58 if not _is_in_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl):
---> 59     return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
     60 return _get_from_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl)

File /opt/conda/envs/pytorch/lib/python3.11/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:25, in recall_grad_state.<locals>.wrapped(self, *args, **kwargs)
     22 @functools.wraps(method)
     23 def wrapped(self, *args, **kwargs):
     24     with torch.set_grad_enabled(self._is_grad_enabled):
---> 25         output = method(self, *args, **kwargs)
     26     return output

File /opt/conda/envs/pytorch/lib/python3.11/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:126, in LazyEvaluatedKernelTensor._diagonal(self)
    124     expected_shape = self.shape[:-1]
    125     if res.shape != expected_shape:
--> 126         raise RuntimeError(
    127             "The kernel {} is not equipped to handle and diag. Expected size {}. "
    128             "Got size {}".format(self.kernel.__class__.__name__, expected_shape, res.shape)
    129         )
    131 if isinstance(res, LinearOperator):
    132     res = res.to_dense()

RuntimeError: The kernel ScaleKernel is not equipped to handle and diag. Expected size torch.Size([45159]). Got size torch.Size([45159, 45159])
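
To make the shape mismatch in this message concrete, here is a dependency-free sketch of what a kernel's diagonal path is expected to return. It is plain Python, not GPyTorch or KeOps code; the helper names `rbf`, `rbf_full`, and `rbf_diag` and the sample inputs are hypothetical and exist only for illustration. For n inputs, the diagonal path should produce n values, whereas the error above reports an n × n result.

```python
import math

def rbf(x1, x2, lengthscale=1.0):
    """Squared-exponential kernel value between two points (lists of floats)."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x1, x2))
    return math.exp(-0.5 * sq_dist / lengthscale ** 2)

def rbf_full(xs, lengthscale=1.0):
    """Full n x n kernel matrix as nested lists."""
    return [[rbf(xi, xj, lengthscale) for xj in xs] for xi in xs]

def rbf_diag(xs, lengthscale=1.0):
    """Only the n diagonal entries k(x_i, x_i), without building the matrix."""
    return [rbf(x, x, lengthscale) for x in xs]

xs = [[0.0, 1.0], [2.0, -1.0], [0.5, 0.5]]  # hypothetical inputs, n = 3
full = rbf_full(xs)
diag = rbf_diag(xs)

# The diagonal path yields n values that agree with the matrix diagonal.
print(len(diag), diag == [full[i][i] for i in range(len(xs))])
```

In this sketch the diagonal is computed directly from each point paired with itself, so no n × n matrix is ever built. That is the property that matters at the scale in the report (n = 45159).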

Expected Behavior

pred_variance should be a tensor holding the predicted variance at each point, as it is when using the standard RBF kernel.
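
For reference, the per-point quantity expected here is the standard exact-GP predictive variance with Gaussian observation noise. The notation is an assumption introduced for illustration and does not come from the report: k is the (scaled) kernel, X the training inputs, K = k(X, X), sigma^2 the noise variance added by the likelihood, and x_* a single prediction point.

```latex
\operatorname{Var}\left[y_* \mid X, y\right]
  = k(x_*, x_*)
  - k(x_*, X)\,\bigl(K + \sigma^2 I\bigr)^{-1} k(X, x_*)
  + \sigma^2
```

Evaluated at n prediction points, this gives n scalars, which matches the expected size torch.Size([45159]) quoted in the error message.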

System information

  • GPyTorch Version: 1.12
  • PyTorch Version: 2.3.0
  • Computer info:
    • MacBook Pro M2, macOS Sonoma 14.4.1
    • Service: Connected to Amazon EC2
    • Instance: g5.xlarge, Amazon Linux 2
@matthieudelsart matthieudelsart changed the title [Bug] Keops RBF kernel not equipped with covariance matrix? [Bug] KeOps RBF kernel not properly equipped with prediction variance? Aug 13, 2024