Skip to content

Commit

Permalink
Use correct precision in botorch wrapper (#417)
Browse files Browse the repository at this point in the history
This PR enables the use of `DTypeFloatTorch` in the BoTorch Wrapper.

Previously, this wrapper converted anything into `float64` precision.
This can cause troubles an imprecision during the corresponding
`forward` call in that function.
  • Loading branch information
AVHopp authored Nov 11, 2024
2 parents e2dd8b9 + 4fe0f5e commit 3b7c1f2
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion baybe/utils/botorch_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import torch
from botorch.test_functions import SyntheticTestFunction

from baybe.utils.torch import DTypeFloatTorch


def botorch_function_wrapper(test_function: SyntheticTestFunction):
"""Turn a BoTorch test function into a format accepted by lookup in simulations.
Expand All @@ -19,7 +21,7 @@ def botorch_function_wrapper(test_function: SyntheticTestFunction):

def wrapper(*x: float) -> float:
# Cast the provided list of floats to a tensor.
x_tensor = torch.tensor(x)
x_tensor = torch.tensor(x, dtype=DTypeFloatTorch)
result = test_function.forward(x_tensor)
# We do not need to return a tuple here.
return float(result)
Expand Down

0 comments on commit 3b7c1f2

Please # to comment.