[Bug] Some priors do not respect selected device #2581

Open
slishak-PX opened this issue Sep 9, 2024 · 2 comments

slishak-PX commented Sep 9, 2024

🐛 Bug

When sampling from a prior that has been moved to the GPU, only some priors use the correct device, even though the state_dict has been updated correctly (as of #2550, which this issue seems related to, although as far as I can tell no regression was introduced):

from gpytorch import priors

for prior in (
    priors.NormalPrior(1.0, 1.0),
    priors.GammaPrior(1.0, 1.0),
    priors.HalfCauchyPrior(1.0, 1.0),
    priors.HalfNormalPrior(1.0, 1.0),
    priors.LogNormalPrior(1.0, 1.0),
    priors.UniformPrior(1.0, 2.0),
):
    prior.to("cuda:0")
    samples = prior.rsample()
    print(f"{str(prior):<35} {str(samples.device):<8} {dict(prior.state_dict())}")

Output:

NormalPrior()                       cuda:0   {'loc': tensor(1., device='cuda:0'), 'scale': tensor(1., device='cuda:0')}
GammaPrior()                        cuda:0   {'concentration': tensor(1., device='cuda:0'), 'rate': tensor(1., device='cuda:0')}
HalfCauchyPrior()                   cpu      {'_transformed_scale': tensor(1., device='cuda:0')}
HalfNormalPrior()                   cpu      {'_transformed_scale': tensor(1., device='cuda:0')}
LogNormalPrior()                    cpu      {'_transformed_loc': tensor(1., device='cuda:0'), '_transformed_scale': tensor(1., device='cuda:0')}
UniformPrior(low: 1.0, high: 2.0)   cpu      {}
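
A minimal sketch of one mechanism that would produce this pattern for the transformed priors, written in plain PyTorch rather than taken from the GPyTorch source. It assumes the affected priors keep their parameters both as registered buffers and as separate tensors inside a wrapped distribution object. ToyPrior is a hypothetical stand-in, and a dtype change is used in place of a device move so that the snippet runs without a GPU. It does not cover UniformPrior, whose state_dict is empty.

```python
import torch
from torch import nn
from torch.distributions import Normal


class ToyPrior(nn.Module):
    """Hypothetical stand-in for a prior; not part of GPyTorch."""

    def __init__(self, loc, scale):
        super().__init__()
        loc = torch.as_tensor(loc, dtype=torch.float32)
        scale = torch.as_tensor(scale, dtype=torch.float32)
        # Registered buffers: listed in state_dict() and converted by .to().
        self.register_buffer("loc", loc.clone())
        self.register_buffer("scale", scale.clone())
        # Plain attribute: holds its own tensors, which .to() never touches.
        self.base_dist = Normal(loc, scale)

    def rsample(self, sample_shape=torch.Size()):
        # Sampling uses the wrapped distribution, not the buffers.
        return self.base_dist.rsample(sample_shape)


prior = ToyPrior(1.0, 1.0)
prior.to(torch.float64)  # dtype change as a CPU-only stand-in for .to("cuda:0")

state_dtype = prior.state_dict()["loc"].dtype
sample_dtype = prior.rsample().dtype
print(state_dtype, sample_dtype)
```

In this toy case `.to()` converts the registered buffers, so the state_dict looks right, while the wrapped distribution keeps sampling from its original tensors.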

This manifests itself in BoTorch when a LogNormalPrior is in use. If the fit fails the first time, new initial hyperparameter values are sampled from the prior, which results in a device mismatch. In the reproducible example below, I trigger this manually by setting optimizer_kwargs so that a warning is raised, and warning_handler so that any warning triggers a retry.

To reproduce

** Code snippet to reproduce **

import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.models.transforms import Normalize, Standardize
from gpytorch import kernels, priors
from gpytorch.mlls import ExactMarginalLogLikelihood

n_inputs = 4
n_outputs = 2
n_train = 256
device = torch.device("cuda:0")

train_x = torch.rand(n_train, n_inputs, dtype=torch.float64, device=device)
train_y = torch.randn(n_train, n_outputs, dtype=torch.float64, device=device)

model = SingleTaskGP(
    train_x, 
    train_y, 
    input_transform=Normalize(n_inputs),
    outcome_transform=Standardize(m=n_outputs),
    covar_module=kernels.ScaleKernel(
        base_kernel=kernels.MaternKernel(
            nu=2.5,
            ard_num_dims=n_inputs,
            batch_shape=torch.Size([n_outputs]),
            lengthscale_prior=priors.LogNormalPrior(0.5, 0.5),
        ),
        outputscale_prior=priors.GammaPrior(2.0, 0.15),
        batch_shape=torch.Size([n_outputs]),
    )
)

mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_mll(
    mll, 
    optimizer_kwargs={"timeout_sec": 1e-3}, 
    warning_handler=lambda _: False,
)

** Stack trace/error message **

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 34
     16 model = SingleTaskGP(
     17     train_x, 
     18     train_y, 
   (...)
     30     )
     31 )
     33 mll = ExactMarginalLogLikelihood(model.likelihood, model)
---> 34 fit_gpytorch_mll(
     35     mll, 
     36     optimizer_kwargs={"timeout_sec": 1e-3}, 
     37     warning_handler=lambda _: False,
     38 )

File .../python3.10/site-packages/botorch/fit.py:104, in fit_gpytorch_mll(mll, closure, optimizer, closure_kwargs, optimizer_kwargs, **kwargs)
    101 if optimizer is not None:  # defer to per-method defaults
    102     kwargs["optimizer"] = optimizer
--> 104 return FitGPyTorchMLL(
    105     mll,
    106     type(mll.likelihood),
    107     type(mll.model),
    108     closure=closure,
    109     closure_kwargs=closure_kwargs,
    110     optimizer_kwargs=optimizer_kwargs,
    111     **kwargs,
    112 )

File .../python3.10/site-packages/botorch/utils/dispatcher.py:93, in Dispatcher.__call__(self, *args, **kwargs)
     91 func = self.__getitem__(types=types)
     92 try:
---> 93     return func(*args, **kwargs)
     94 except MDNotImplementedError:
     95     # Traverses registered methods in order, yields whenever a match is found
     96     funcs = self.dispatch_iter(*types)

File .../python3.10/site-packages/botorch/fit.py:198, in _fit_fallback(mll, _, __, closure, optimizer, closure_kwargs, optimizer_kwargs, max_attempts, pick_best_of_all_attempts, warning_handler, caught_exception_types, **ignore)
    195         ckpt_nograd = {name: ckpt[name] for name in params_nograd}
    197     with parameter_rollback_ctx(params_nograd, checkpoint=ckpt_nograd):
--> 198         sample_all_priors(mll.model)
    200 try:
    201     # Fit the model
    202     with catch_warnings(record=True) as warning_list, debug(True):

File .../python3.10/site-packages/botorch/optim/utils/model_utils.py:191, in sample_all_priors(model, max_retries)
    186         raise RuntimeError(
    187             "Failed to sample a feasible parameter value "
    188             f"from the prior after {max_retries} attempts."
    189         )
    190 else:
--> 191     raise e

File .../python3.10/site-packages/botorch/optim/utils/model_utils.py:171, in sample_all_priors(model, max_retries)
    166 prior_shape = prior._extended_shape()
    167 if prior_shape.numel() == 1:
    168     # For a univariate prior we can sample the size of the closure.
    169     # Otherwise we will sample exactly the same value for all
    170     # lengthscales where we commonly specify a univariate prior.
--> 171     setting_closure(module, prior.sample(closure(module).shape))
    172 else:
    173     closure_shape = closure(module).shape

File .../python3.10/site-packages/gpytorch/kernels/kernel.py:221, in Kernel._lengthscale_closure(self, m, v)
    219 def _lengthscale_closure(self, m: Kernel, v: Tensor) -> Tensor:
    220     # Used by the lengthscale_prior
--> 221     return m._set_lengthscale(v)

File .../python3.10/site-packages/gpytorch/kernels/kernel.py:231, in Kernel._set_lengthscale(self, value)
    228 if not torch.is_tensor(value):
    229     value = torch.as_tensor(value).to(self.raw_lengthscale)
--> 231 self.initialize(raw_lengthscale=self.raw_lengthscale_constraint.inverse_transform(value))

File .../python3.10/site-packages/gpytorch/module.py:103, in Module.initialize(self, **kwargs)
    101 elif torch.is_tensor(val):
    102     constraint = self.constraint_for_parameter_name(name)
--> 103     if constraint is not None and constraint.enforced and not constraint.check_raw(val):
    104         raise RuntimeError(
    105             "Attempting to manually set a parameter value that is out of bounds of "
    106             f"its current constraints, {constraint}. "
    107             "Most likely, you want to do the following:\n likelihood = GaussianLikelihood"
    108             "(noise_constraint=gpytorch.constraints.GreaterThan(better_lower_bound))"
    109         )
    110     try:

File .../python3.10/site-packages/gpytorch/constraints/constraints.py:90, in Interval.check_raw(self, tensor)
     88 def check_raw(self, tensor) -> bool:
     89     return bool(
---> 90         torch.all((self.transform(tensor) <= self.upper_bound))
     91         and torch.all(self.transform(tensor) >= self.lower_bound)
     92     )

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Expected Behavior

Samples drawn from a prior should be on the device that the prior was moved to with .to().

System information

  • GPyTorch Version: 1.14.dev2+g83332c2c (latest main)
  • PyTorch Version: '2.0.1+cu117'
  • Computer OS: Linux

slishak-PX added the bug label Sep 9, 2024

Balandat (Collaborator) commented

Thanks for raising this. Yes, this has the same cause as #2550: .to() doesn't move the attributes over to the GPU. @hvarfner, in case you have any immediate ideas on this: it basically looks like we just need to override the .to() method in the same way.
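
The idea of overriding the move logic can be sketched on a toy module in plain PyTorch. This is not the change made in #2550 and not GPyTorch code. ToyPrior is hypothetical, and it hooks nn.Module._apply (the private method that .to(), .cuda() and .double() go through) to rebuild the wrapped distribution from the converted buffers. A dtype change stands in for a device move so that the snippet runs on CPU.

```python
import torch
from torch import nn
from torch.distributions import Normal


class ToyPrior(nn.Module):
    """Hypothetical prior-like module; not part of GPyTorch."""

    def __init__(self, loc, scale):
        super().__init__()
        self.register_buffer("loc", torch.as_tensor(loc, dtype=torch.float32))
        self.register_buffer("scale", torch.as_tensor(scale, dtype=torch.float32))
        self.base_dist = Normal(self.loc, self.scale)

    def _apply(self, fn, *args, **kwargs):
        # Let nn.Module convert the registered buffers first ...
        module = super()._apply(fn, *args, **kwargs)
        # ... then rebuild the distribution so it uses the converted tensors.
        self.base_dist = Normal(self.loc, self.scale)
        return module

    def rsample(self, sample_shape=torch.Size()):
        return self.base_dist.rsample(sample_shape)


prior = ToyPrior(1.0, 1.0).to(torch.float64)
sample = prior.rsample(torch.Size([3]))
print(sample.dtype, sample.shape)
```

Rebuilding from the buffers keeps a single source of truth for the parameters, at the cost of relying on a private hook whose signature has changed between PyTorch releases (hence the pass-through `*args, **kwargs`).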

hvarfner (Contributor) commented

@Balandat Interesting. I'm not sure whether we need to override .to() or just modify save/load_state_dict(), but I'll have a look.
