[Bug]: Drastically different behavior when toggling sequential
#2755
Replies: 7 comments
-
It’s a bit hard to see the difference in the results. Can you plot them on
the same image and describe what you see as the difference? Are the
differences in the identified Pareto frontier (PF) any larger than what
you’d expect from replicating the sequential or non-sequential runs
multiple times?
Sequential should be approximately batch-size times more expensive (i.e.,
10x here), since it optimizes the location of each of the q points in a
sequential-greedy fashion, whereas non-sequential optimizes all of them in
parallel. The convergence behavior might also differ, so one method’s
L-BFGS-B optimizer could converge in fewer or more iterations depending on
the circumstance. So a 20x speed-up does not seem surprising. Sequential is
an approximation to the full joint problem, so if joint optimization does
identify the true maximizer of the joint problem, it often has better
black-box optimization performance.
Sequential greedy does break the problem down in a way that can be easier
to solve with qEI, so efficiency with respect to the actual BO task was
previously better with sequential=True. But with qLogEI-based acquisition
functions, which don’t suffer from vanishing gradients, we generally find
that joint optimization works better. I know internally we have switched
some of our applications to defaulting to joint optimization; I am not sure
we have tested it thoroughly enough to make it the default for BoTorch and
Ax.
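To sanity-check the expected cost ratio empirically, one can time both modes on the same problem. A minimal sketch, assuming the acqf and bounds objects built in the repro script quoted below (timings are machine- and seed-dependent):

from time import perf_counter

from botorch.optim import optimize_acqf

# Time sequential-greedy vs. joint optimization of the same batch.
# `acqf` and `bounds` are assumed to come from the repro script.
for seq in (True, False):
    t = perf_counter()
    optimize_acqf(
        acqf, bounds, q=10, num_restarts=5, raw_samples=20, sequential=seq
    )
    print(f"sequential={seq}: {perf_counter() - t:.1f}s")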
On Wed, Feb 19, 2025 at 3:10 AM AdrianSosic wrote:
What happened?
I've recently been testing a few things for multi-output modeling and
stumbled over some very weird (unexpected?) behavior regarding the
sequential flag of optimize_acqf:
- Even for my very simple toy problem below, I get significantly different
results when toggling the flag.
- The runtime difference is *tremendous!* For my example, sequential=True
takes roughly 6s whereas sequential=False runs for about 130s.
Here are the corresponding plots:
sequential=True: seq_true.png
<https://github.com/user-attachments/assets/32920b38-3de3-42b6-9c72-848ff1f8e914>
sequential=False: seq_false.png
<https://github.com/user-attachments/assets/f9944b40-5363-4c76-a38b-b3e67a1fc975>
What is interesting to note here
The achieved acquisition values of the batches (shown in the legend) are
roughly identical for both settings, so the two optimization strategies
seem to have ended up in two different but equivalent (in terms of function
value) local optima. From a pure acqf perspective, this suggests that both
solutions are equally good, even though sequential=True clearly gives the
better qualitative result. Perhaps you can comment on this?
Also, I have no good explanation for the runtime difference. Is this
expected? If so, is there a reason why sequential=False is the default?
Please provide a minimal, reproducible example of the unexpected behavior.
from time import perf_counter

import gpytorch
import numpy as np
import torch
from botorch.acquisition import qLogExpectedImprovement
from botorch.acquisition.multi_objective import qLogNoisyExpectedHypervolumeImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.models.transforms import Normalize, Standardize
from botorch.models.utils.gpytorch_modules import (
    get_gaussian_likelihood_with_gamma_prior,
    get_matern_kernel_with_gamma_prior,
)
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood
from matplotlib import pyplot as plt

torch.manual_seed(1337)
torch.set_default_dtype(torch.float64)

########################################################################################
SEQUENTIAL = False  # <-- switch this to True
########################################################################################

BATCH_SIZE = 10
N_TRAINING_DATA = 100
N_GRID_POINTS = 100
CENTER_Y0 = torch.tensor([-0.5, -0.5])
CENTER_Y1 = torch.tensor([0.5, 0.5])


def fun(x: torch.Tensor) -> torch.Tensor:
    """Two-output toy objective: negated squared distances to the two centers."""
    y0 = -(x - CENTER_Y0).pow(2).sum(dim=1)
    y1 = -(x - CENTER_Y1).pow(2).sum(dim=1)
    return torch.stack([y0, y1], dim=1)


def recommend(train_X, train_Y):
    """Fit a GP and optimize qLogEI (one output) or qLogNEHVI (two outputs)."""
    mean_module = gpytorch.means.ConstantMean()
    covar_module = get_matern_kernel_with_gamma_prior(2)
    likelihood = get_gaussian_likelihood_with_gamma_prior()
    model = SingleTaskGP(
        train_X,
        train_Y,
        input_transform=Normalize(d=2),
        outcome_transform=Standardize(m=train_Y.shape[1]),
        mean_module=mean_module,
        covar_module=covar_module,
        likelihood=likelihood,
    )
    mll = ExactMarginalLogLikelihood(model.likelihood, model)
    fit_gpytorch_mll(mll)
    if train_Y.shape[1] == 1:
        acqf = qLogExpectedImprovement(model, train_Y.max())
    else:
        acqf = qLogNoisyExpectedHypervolumeImprovement(
            model, ref_point=train_Y.min(dim=0)[0], X_baseline=train_X
        )
    bounds = torch.tensor([[-1.0, 1.0], [-1.0, 1.0]]).T
    rec, _ = optimize_acqf(
        acqf, bounds, BATCH_SIZE, num_restarts=5, raw_samples=20, sequential=SEQUENTIAL
    )
    return rec, acqf(rec).item()


train_X = torch.rand([N_TRAINING_DATA, 2]) * 2 - 1
train_Y = fun(train_X)

# Optimize each objective individually, then the Pareto problem, and time it.
t = perf_counter()
rec_y0, val_y0 = recommend(train_X, train_Y[:, :1])
rec_y1, val_y1 = recommend(train_X, train_Y[:, 1:])
rec_p, val_p = recommend(train_X, train_Y)
print(perf_counter() - t)

out_y0 = fun(rec_y0)
out_y1 = fun(rec_y1)
out_p = fun(rec_p)

# Evaluate the objectives on a grid for the contour plots.
x0_mesh, x1_mesh = torch.meshgrid(
    torch.linspace(-1.0, 1.0, N_GRID_POINTS),
    torch.linspace(-1.0, 1.0, N_GRID_POINTS),
    indexing="ij",
)
y = fun(torch.stack([x0_mesh.ravel(), x1_mesh.ravel()], dim=1))
y0_mesh = torch.reshape(y[:, 0], x0_mesh.shape)
y1_mesh = torch.reshape(y[:, 1], x1_mesh.shape)

fig, axs = plt.subplots(1, 2, figsize=(10, 5))

# Left panel: recommendations in input space.
plt.sca(axs[0])
plt.contour(x0_mesh, x1_mesh, y0_mesh, colors="tab:red", alpha=0.2)
plt.contour(x0_mesh, x1_mesh, y1_mesh, colors="tab:blue", alpha=0.2)
plt.plot(*np.c_[CENTER_Y0, CENTER_Y1], "k", label="frontier")
plt.plot(train_X[:, 0], train_X[:, 1], "o", color="0.7", markersize=2, label="training")
plt.plot(
    rec_y0[:, 0], rec_y0[:, 1], "o", color="tab:red", label=f"single_y0: {val_y0:.3f}"
)
plt.plot(
    rec_y1[:, 0], rec_y1[:, 1], "o", color="tab:blue", label=f"single_y1: {val_y1:.3f}"
)
plt.plot(
    rec_p[:, 0], rec_p[:, 1], "o", color="tab:purple", label=f"pareto: {val_p:.3f}"
)
plt.legend(loc="upper left")
plt.axis("square")
plt.axis([-1, 1, -1, 1])

# Right panel: recommendations in outcome space.
plt.sca(axs[1])
frontier = fun(torch.from_numpy(np.linspace(CENTER_Y0, CENTER_Y1)))
plt.plot(*frontier.T, "k", label="frontier")
plt.plot(train_Y[:, 0], train_Y[:, 1], "o", color="0.7", markersize=2, label="training")
plt.plot(out_y0[:, 0], out_y0[:, 1], "o", color="tab:red", label="single_y0")
plt.plot(out_y1[:, 0], out_y1[:, 1], "o", color="tab:blue", label="single_y1")
plt.plot(out_p[:, 0], out_p[:, 1], "o", color="tab:purple", label="pareto")
plt.legend(loc="lower left")
plt.axis("square")

plt.tight_layout()
plt.show()
Please paste any relevant traceback/logs produced by the example provided.
BoTorch Version
0.13.0
Python Version
3.10
Operating System
macOS
Code of Conduct
- I agree to follow BoTorch's Code of Conduct
-
It’s expected behavior that sequential=False is very slow when using hypervolume-based acquisition functions, and that sequential-greedy and joint optimization provide similar acquisition values.
For background, joint optimization (sequential=False) attempts to solve the problem exactly, while sequential-greedy optimization is an approximation. With sequential=True, optimize_acqf chooses each candidate in a batch one at a time; each point is chosen to maximize the acquisition value of just that point, treating any previously chosen points in the batch as pending.
Regarding the similarity of the acquisition values, Wilson et al. (2018) bound the regret associated with a sequential-greedy approach to maximizing qEI and show that it's not so bad. While using log acquisition functions changes things, some (unpublished) empirical evaluations I ran showed that using sequential=True still generally looked better.
With multiple outcomes, joint optimization is prohibitively expensive because hypervolume improvement computations have exponential time complexity in the batch size, whereas the complexity is only polynomial in the batch size with the greedy variant. Daulton et al. (2021) have a good discussion on this.
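In code, the greedy scheme is roughly the following loop. This is a simplified sketch of what optimize_acqf(..., sequential=True) does internally (the real implementation in botorch.optim.optimize additionally handles restart generation and various edge cases); acqf and bounds are assumed to come from the repro script above:

import torch

from botorch.optim import optimize_acqf

# Pick q candidates one at a time, treating earlier picks as pending.
q = 10
chosen = []
for _ in range(q):
    cand, _ = optimize_acqf(acqf, bounds, q=1, num_restarts=5, raw_samples=20)
    chosen.append(cand)
    acqf.set_X_pending(torch.cat(chosen, dim=0))  # condition on earlier picks
batch = torch.cat(chosen, dim=0)
acqf.set_X_pending(None)  # reset pending points when done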
-
As Eytan points out, the two approaches solve two versions of one problem.
Let's assume that the two optimization problems yield solutions of comparable quality, which is questionable on its own, and focus on the total cost.
So, this really depends on the problem setup. @esantorella linked a couple of papers that look into the difference between joint and sequential optimization. I've benchmarked the two relatively recently on a few vanilla problems, and found …
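To make the cost asymmetry concrete, a back-of-envelope sketch (based on the inclusion-exclusion formulation discussed in the papers linked above, not a profile of the actual implementation):

# Joint optimization of a hypervolume-based AF touches on the order of
# 2^q - 1 point subsets per acquisition evaluation, while the greedy
# variant solves q single-point problems.
q = 10
print(2**q - 1)  # ~1023 inclusion-exclusion terms per joint evaluation
print(q)         # 10 cheap greedy subproblems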
-
For single-objective optimization, the overall AF values of the batch will be similar (between joint and sequential) even if only one point from the batch has high expected improvement, since the contributions from the rest have little effect. So here, joint optimization finds one point that is on the optimum, and then the remaining points have little effect on the AF value, and it is hard to find settings for those points that yield improvements in the overall AF value. Sequential is still able to make incremental improvements to the overall AF value by selecting new starting points that are promising for the new candidate being optimized, and we see that the overall batch AF values from sequential are indeed better.
This likely isn’t too big of a problem in general, because typically we don’t know where the optimum is. But here, the model does know fairly well where the optimum is, since there are 100 training points (with noiseless observations) and the functions are relatively simple. So once one candidate is chosen on the optimum, it is hard to find improvement when jointly optimizing the other candidates.
+1 to Elizabeth's answer about slowness due to exponential complexity with respect to the batch size when using joint optimization with qLogEHVI. The joint optimization problem is also very hard, so sequential works better in terms of performance too. See purple (sequential) vs green (joint) in Figure 2b of https://arxiv.org/abs/2006.05078.
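One way to see this flattening effect directly (a hedged sketch reusing the acqf object from the repro script; the exact numbers are hypothetical):

import torch

# With one point pinned near the known optimum, resampling the other
# nine points barely changes the joint acquisition value of the batch.
best = torch.tensor([[-0.5, -0.5]])  # optimum of y0 in the toy problem
for _ in range(3):
    rest = torch.rand(9, 2) * 2 - 1
    batch = torch.cat([best, rest], dim=0)
    print(acqf(batch.unsqueeze(0)).item())  # roughly constant across draws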
-
First and foremost: thank you all very much for taking the time to quickly write such in-depth answers, I really appreciate this a lot! 👏🏼 👏🏼 I probably should have mentioned that the general difference between …
Also, while the remaining arguments brought up make sense to me, there are still a couple of unanswered questions in my head. I'm not expecting that one of you has a clear answer, but if you do (for any of the points), I'd love to hear it. Let me just quickly dump my brain, in random order:
-
Thank you to everyone for the explanations about the algorithm and scaling. That was very helpful indeed. I would like to re-emphasize our question about the results, since @eytan also asked about the difference we see between the plots: both plots are done in a setting where plenty of data points were already measured. Now if we compare the right plot for sequential and non-sequential, the results differ drastically.
In your view, is this expected? Have you observed similar behavior? Would you be worried about this? Would this be a reason to set sequential=True? Basically, we are trying to understand if there is anything that we did wrong, or whether this is just a reality we have to live with (and probably pragmatically setting sequential=True).
-
My hunch is that the choice of reference point is causing this. The reference point is very far from the Pareto frontier, which puts more weight on the extrema. If the model is somewhat uncertain about the function value at the purple point in the lower right (for sequential=False), …
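If you want to test that hunch, one option is to place the reference point just below the observed Pareto front rather than at train_Y.min(dim=0). A sketch using BoTorch's infer_reference_point utility (whether this actually removes the discrepancy here is untested):

from botorch.utils.multi_objective.hypervolume import infer_reference_point
from botorch.utils.multi_objective.pareto import is_non_dominated

# Restrict to non-dominated observations, then infer a nearby reference
# point instead of using the componentwise minimum of train_Y.
pareto_Y = train_Y[is_non_dominated(train_Y)]
ref_point = infer_reference_point(pareto_Y)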