When sampling from a prior that has been moved to the GPU, the sample lands on the correct device for only some priors, even though the state_dict has been updated correctly (as of #2550, which this issue seems related to, although as far as I can tell no regression was introduced). Each line below shows the prior, the device of a sample drawn from it, and its state_dict:
NormalPrior() cuda:0 {'loc': tensor(1., device='cuda:0'), 'scale': tensor(1., device='cuda:0')}
GammaPrior() cuda:0 {'concentration': tensor(1., device='cuda:0'), 'rate': tensor(1., device='cuda:0')}
HalfCauchyPrior() cpu {'_transformed_scale': tensor(1., device='cuda:0')}
HalfNormalPrior() cpu {'_transformed_scale': tensor(1., device='cuda:0')}
LogNormalPrior() cpu {'_transformed_loc': tensor(1., device='cuda:0'), '_transformed_scale': tensor(1., device='cuda:0')}
UniformPrior(low: 1.0, high: 2.0) cpu {}
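The HalfCauchyPrior, HalfNormalPrior and LogNormalPrior lines fit a simple pattern. If sample() reads tensors stored as plain attributes instead of registered buffers, nn.Module.to() leaves those tensors behind, because it converts only parameters and buffers. The state_dict moves and the sampling tensors do not. The minimal sketch below reproduces that split on a CPU-only machine by casting the dtype instead of moving the device. ToyPrior is a hypothetical stand-in, not a gpytorch class, and the sketch assumes, without showing, that the real priors store their tensors this way.

```python
import torch
from torch import nn


class ToyPrior(nn.Module):
    """Hypothetical prior holding both a registered buffer and a plain tensor."""

    def __init__(self, scale: float = 1.0) -> None:
        super().__init__()
        # Registered buffer: converted by Module.to() and listed in state_dict().
        self.register_buffer("_transformed_scale", torch.tensor(scale))
        # Plain attribute: Module.to() never sees it, so it keeps its original dtype/device.
        self.scale = torch.tensor(scale)

    def sample(self, sample_shape=torch.Size()) -> torch.Tensor:
        # Draw |N(0, 1)| * scale on the plain attribute's device and dtype.
        noise = torch.randn(sample_shape, dtype=self.scale.dtype, device=self.scale.device)
        return noise.abs() * self.scale


prior = ToyPrior().to(torch.float64)
# The sample is still float32, while the buffer in state_dict is now float64.
print(prior.sample().dtype, dict(prior.state_dict()))
```

On a GPU machine, calling .to("cuda") instead of .to(torch.float64) should produce the same cuda:0 / cpu split as in the listing above.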
This manifests itself in BoTorch when a LogNormal prior is in use. If the fit fails the first time, new initial hyperparameter values are sampled from the prior, which results in a device mismatch. In the reproducible example (only partially visible in the traceback below), I trigger this manually by setting optimizer_kwargs so that a warning is raised, and by setting warning_handler so that any warning triggers a retry.
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[2], line 34
16 model = SingleTaskGP(
17 train_x,
18 train_y,
(...)
30 )
31 )
33 mll = ExactMarginalLogLikelihood(model.likelihood, model)
---> 34 fit_gpytorch_mll(
35 mll,
36 optimizer_kwargs={"timeout_sec": 1e-3},
37 warning_handler=lambda _: False,
38 )
File .../python3.10/site-packages/botorch/fit.py:104, in fit_gpytorch_mll(mll, closure, optimizer, closure_kwargs, optimizer_kwargs, **kwargs)
101 if optimizer is not None: # defer to per-method defaults
102 kwargs["optimizer"] = optimizer
--> 104 return FitGPyTorchMLL(
105 mll,
106 type(mll.likelihood),
107 type(mll.model),
108 closure=closure,
109 closure_kwargs=closure_kwargs,
110 optimizer_kwargs=optimizer_kwargs,
111 **kwargs,
112 )
File .../python3.10/site-packages/botorch/utils/dispatcher.py:93, in Dispatcher.__call__(self, *args, **kwargs)
91 func = self.__getitem__(types=types)
92 try:
---> 93 return func(*args, **kwargs)
94 except MDNotImplementedError:
95 # Traverses registered methods in order, yields whenever a match is found
96 funcs = self.dispatch_iter(*types)
File .../python3.10/site-packages/botorch/fit.py:198, in _fit_fallback(mll, _, __, closure, optimizer, closure_kwargs, optimizer_kwargs, max_attempts, pick_best_of_all_attempts, warning_handler, caught_exception_types, **ignore)
195 ckpt_nograd = {name: ckpt[name] for name in params_nograd}
197 with parameter_rollback_ctx(params_nograd, checkpoint=ckpt_nograd):
--> 198 sample_all_priors(mll.model)
200 try:
201 # Fit the model
202 with catch_warnings(record=True) as warning_list, debug(True):
File .../python3.10/site-packages/botorch/optim/utils/model_utils.py:191, in sample_all_priors(model, max_retries)
186 raise RuntimeError(
187 "Failed to sample a feasible parameter value "
188 f"from the prior after {max_retries} attempts."
189 )
190 else:
--> 191 raise e
File .../python3.10/site-packages/botorch/optim/utils/model_utils.py:171, in sample_all_priors(model, max_retries)
166 prior_shape = prior._extended_shape()
167 if prior_shape.numel() == 1:
168 # For a univariate prior we can sample the size of the closure.
169 # Otherwise we will sample exactly the same value for all
170 # lengthscales where we commonly specify a univariate prior.
--> 171 setting_closure(module, prior.sample(closure(module).shape))
172 else:
173 closure_shape = closure(module).shape
File .../python3.10/site-packages/gpytorch/kernels/kernel.py:221, in Kernel._lengthscale_closure(self, m, v)
219 def _lengthscale_closure(self, m: Kernel, v: Tensor) -> Tensor:
220 # Used by the lengthscale_prior
--> 221 return m._set_lengthscale(v)
File .../python3.10/site-packages/gpytorch/kernels/kernel.py:231, in Kernel._set_lengthscale(self, value)
228 if not torch.is_tensor(value):
229 value = torch.as_tensor(value).to(self.raw_lengthscale)
--> 231 self.initialize(raw_lengthscale=self.raw_lengthscale_constraint.inverse_transform(value))
File .../python3.10/site-packages/gpytorch/module.py:103, in Module.initialize(self, **kwargs)
101 elif torch.is_tensor(val):
102 constraint = self.constraint_for_parameter_name(name)
--> 103 if constraint is not None and constraint.enforced and not constraint.check_raw(val):
104 raise RuntimeError(
105 "Attempting to manually set a parameter value that is out of bounds of "
106 f"its current constraints, {constraint}. "
107 "Most likely, you want to do the following:\n likelihood = GaussianLikelihood"
108 "(noise_constraint=gpytorch.constraints.GreaterThan(better_lower_bound))"
109 )
110 try:
File .../python3.10/site-packages/gpytorch/constraints/constraints.py:90, in Interval.check_raw(self, tensor)
88 def check_raw(self, tensor) -> bool:
89 return bool(
---> 90 torch.all((self.transform(tensor) <= self.upper_bound))
91 and torch.all(self.transform(tensor) >= self.lower_bound)
92 )
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
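The traceback shows that the mismatch only surfaces on a retry, inside sample_all_priors. At that point a prior sample is written back through setting_closure and checked against the parameter's constraint. A pre-flight check can catch the problem before fitting starts. The helper below is a hypothetical sketch: find_misplaced_priors is not a BoTorch or gpytorch function. It assumes, based on the sample_all_priors frame above, that model.named_priors() yields (name, module, prior, closure, setting_closure) tuples and that closure(module) returns the current hyperparameter value. It compares dtype as well as device so that it can be exercised without a GPU.

```python
from typing import Iterable, List, Tuple

import torch


def find_misplaced_priors(named_priors: Iterable[tuple]) -> List[Tuple[str, str, str]]:
    """Hypothetical pre-flight check (not a BoTorch/gpytorch API).

    Returns (name, sample placement, target placement) for every prior whose
    samples would land on a different device or dtype than the value that
    setting_closure would overwrite. Assumes items shaped like
    (name, module, prior, closure, setting_closure).
    """
    misplaced = []
    for name, module, prior, closure, _setting_closure in named_priors:
        target = closure(module)  # current hyperparameter value
        draw = prior.sample()  # a single draw shows where samples land
        if (draw.device, draw.dtype) != (target.device, target.dtype):
            misplaced.append(
                (name, f"{draw.device}/{draw.dtype}", f"{target.device}/{target.dtype}")
            )
    return misplaced


# Self-check with stub objects; with a real model you would pass
# model.named_priors() instead.
class _StubPrior:
    def __init__(self, dtype: torch.dtype) -> None:
        self.dtype = dtype

    def sample(self, sample_shape=torch.Size()) -> torch.Tensor:
        return torch.zeros(sample_shape, dtype=self.dtype)


_target = torch.ones(2, dtype=torch.float64)
stub_items = [
    ("ok_prior", None, _StubPrior(torch.float64), lambda m: _target, None),
    ("bad_prior", None, _StubPrior(torch.float32), lambda m: _target, None),
]
report = find_misplaced_priors(stub_items)
print(report)  # [('bad_prior', 'cpu/torch.float32', 'cpu/torch.float64')]
```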
Thanks for raising this. Yes, this has the same cause as #2550: .to() doesn't move these attributes to the GPU. @hvarfner, in case you have any immediate ideas on this: it basically looks like we just need to override the .to() method in the same way.
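One way to keep such attributes in sync is sketched below. It does not override .to() itself, as suggested above. Instead it overrides nn.Module._apply, the private hook that .to(), .cuda() and .double() all route through. This matters because calling .to() on a parent model reaches child modules through their _apply, not through their .to(). An override of .to() on the prior alone would therefore not fire when the whole model is moved. ToyPrior and its scale attribute are hypothetical, and this is not gpytorch's implementation.

```python
import torch
from torch import nn


class ToyPrior(nn.Module):
    """Hypothetical prior whose sample() reads a plain attribute, not a buffer."""

    def __init__(self, scale: float = 1.0) -> None:
        super().__init__()
        self.register_buffer("_transformed_scale", torch.tensor(scale))
        self.scale = torch.tensor(scale)  # plain attribute, ignored by Module.to()

    def sample(self, sample_shape=torch.Size()) -> torch.Tensor:
        noise = torch.randn(sample_shape, dtype=self.scale.dtype, device=self.scale.device)
        return noise.abs() * self.scale


class SyncedToyPrior(ToyPrior):
    # _apply is private and has gained arguments across PyTorch releases,
    # so forward whatever the caller passes.
    def _apply(self, fn, *args, **kwargs):
        # Let nn.Module convert parameters and buffers (and recurse) as usual.
        super()._apply(fn, *args, **kwargs)
        # Then push the plain attribute through the same conversion function.
        self.scale = fn(self.scale)
        return self


model = nn.Module()
model.prior = SyncedToyPrior()
model.to(torch.float64)  # moving the parent reaches the prior through _apply
print(model.prior.sample().dtype)  # torch.float64
```

When a tensor can live directly on the module, registering it with register_buffer(name, tensor, persistent=False) lets nn.Module move it without any override and keeps it out of the state_dict. The override is only needed when the tensor sits on an object that nn.Module does not manage, such as a wrapped torch.distributions instance.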