Skip to content

Commit 1ec824c

Browse files
committed
Added entropy search acquisition function initializer, ensured unittest coverage
1 parent 9d39469 commit 1ec824c

File tree

4 files changed

+300
-53
lines changed

4 files changed

+300
-53
lines changed

botorch/optim/initializers.py

+178-53
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,128 @@ def gen_batch_initial_conditions(
468468
return batch_initial_conditions
469469

470470

471+
def gen_optimal_input_initial_conditions(
472+
acq_function: AcquisitionFunction,
473+
bounds: Tensor,
474+
q: int,
475+
num_restarts: int,
476+
raw_samples: int,
477+
fixed_features: dict[int, float] | None = None,
478+
options: dict[str, bool | float | int] | None = None,
479+
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
480+
equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
481+
):
482+
r"""Generate a batch of initial conditions for random-restart optimziation of
483+
information-theoretic acquisition functions (PES & JES), where sampled optimizers
484+
of the posterior constitute good initial guesses for further optimization. A
485+
fraction of initial samples (by default: 100%) are drawn as perturbations around
486+
`acq.optimal_inputs`. On average, this drastically decreases the runtime of
487+
acquisition function optimization and yields higher-valued candidates by acquisition
488+
function value. See https://github.com/pytorch/botorch/pull/2751 for more info.
489+
490+
Args:
491+
acq_function: The acquisition function to be optimized.
492+
bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`.
493+
q: The number of candidates to consider.
494+
num_restarts: The number of starting points for multistart acquisition
495+
function optimization.
496+
raw_samples: The number of raw samples to consider in the initialization
497+
heuristic. Note: if `sample_around_best` is True (the default is False),
498+
then `2 * raw_samples` samples are used.
499+
fixed_features: A map `{feature_index: value}` for features that
500+
should be fixed to a particular value during generation.
501+
options: Options for initial condition generation. These contain all
502+
settings for the standard heuristic initialization from
503+
`gen_batch_initial_conditions`. In addition, they contain
504+
`frac_random` (the fraction of points drawn fully at random as opposed
505+
to around the drawn optimizers from the posterior).
506+
`sample_around_best_sigma` dictates both the standard deviation of the
507+
samples drawn from posterior maximizers, and the samples from previous
508+
best (if enabled).
509+
inequality constraints: A list of tuples (indices, coefficients, rhs),
510+
with each tuple encoding an inequality constraint of the form
511+
`\sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
512+
equality constraints: A list of tuples (indices, coefficients, rhs),
513+
with each tuple encoding an inequality constraint of the form
514+
`\sum_i (X[indices[i]] * coefficients[i]) = rhs`.
515+
516+
Returns:
517+
A `num_restarts x q x d` tensor of initial conditions.
518+
"""
519+
options = options or {}
520+
device = bounds.device
521+
if not hasattr(acq_function, "optimal_inputs"):
522+
raise AttributeError(
523+
"gen_optimal_input_initial_conditions can only be used with "
524+
"an AcquisitionFunction that has an optimal_inputs attribute."
525+
)
526+
frac_random: float = options.get("frac_random", 0.0)
527+
if not 0 <= frac_random <= 1:
528+
raise ValueError(
529+
f"frac_random must take on values in (0,1). Value: {frac_random}"
530+
)
531+
532+
batch_limit = options.get("batch_limit")
533+
num_optima = acq_function.optimal_inputs.shape[:-1].numel()
534+
suggestions = acq_function.optimal_inputs.reshape(num_optima, -1)
535+
X = torch.empty(0, q, bounds.shape[1], dtype=bounds.dtype)
536+
num_random = round(raw_samples * frac_random)
537+
if num_random > 0:
538+
X_rnd = sample_q_batches_from_polytope(
539+
n=num_random,
540+
q=q,
541+
bounds=bounds,
542+
n_burnin=options.get("n_burnin", 10000),
543+
n_thinning=options.get("n_thinning", 32),
544+
equality_constraints=equality_constraints,
545+
inequality_constraints=inequality_constraints,
546+
)
547+
X = torch.cat((X, X_rnd))
548+
549+
if num_random < raw_samples:
550+
X_perturbed = sample_points_around_best(
551+
acq_function=acq_function,
552+
n_discrete_points=q * (raw_samples - num_random),
553+
sigma=options.get("sample_around_best_sigma", 1e-2),
554+
bounds=bounds,
555+
best_X=suggestions,
556+
)
557+
X_perturbed = X_perturbed.view(
558+
raw_samples - num_random, q, bounds.shape[-1]
559+
).cpu()
560+
X = torch.cat((X, X_perturbed))
561+
562+
if options.get("sample_around_best", False):
563+
X_best = sample_points_around_best(
564+
acq_function=acq_function,
565+
n_discrete_points=q * raw_samples,
566+
sigma=options.get("sample_around_best_sigma", 1e-2),
567+
bounds=bounds,
568+
)
569+
X_best = X_best.view(raw_samples, q, bounds.shape[-1]).cpu()
570+
X = torch.cat((X, X_best))
571+
572+
X_rnd = fix_features(X, fixed_features=fixed_features).cpu()
573+
with torch.no_grad():
574+
if batch_limit is None:
575+
batch_limit = X.shape[0]
576+
# Evaluate the acquisition function on `X_rnd` using `batch_limit`
577+
# sized chunks.
578+
acq_vals = torch.cat(
579+
[
580+
acq_function(x_.to(device=device)).cpu()
581+
for x_ in X.split(split_size=batch_limit, dim=0)
582+
],
583+
dim=0,
584+
)
585+
idx = boltzmann_sample(
586+
function_values=acq_vals,
587+
num_samples=num_restarts,
588+
eta=options.get("eta", 2.0),
589+
)
590+
return X[idx]
591+
592+
471593
def gen_one_shot_kg_initial_conditions(
472594
acq_function: qKnowledgeGradient,
473595
bounds: Tensor,
@@ -1136,6 +1258,7 @@ def sample_points_around_best(
11361258
best_pct: float = 5.0,
11371259
subset_sigma: float = 1e-1,
11381260
prob_perturb: float | None = None,
1261+
best_X: Tensor | None = None,
11391262
) -> Tensor | None:
11401263
r"""Find best points and sample nearby points.
11411264
@@ -1154,60 +1277,62 @@ def sample_points_around_best(
11541277
An optional `n_discrete_points x d`-dim tensor containing the
11551278
sampled points. This is None if no baseline points are found.
11561279
"""
1157-
X = get_X_baseline(acq_function=acq_function)
1158-
if X is None:
1159-
return
1160-
with torch.no_grad():
1161-
try:
1162-
posterior = acq_function.model.posterior(X)
1163-
except AttributeError:
1164-
warnings.warn(
1165-
"Failed to sample around previous best points.",
1166-
BotorchWarning,
1167-
stacklevel=3,
1168-
)
1280+
if best_X is None:
1281+
X = get_X_baseline(acq_function=acq_function)
1282+
if X is None:
11691283
return
1170-
mean = posterior.mean
1171-
while mean.ndim > 2:
1172-
# take average over batch dims
1173-
mean = mean.mean(dim=0)
1174-
try:
1175-
f_pred = acq_function.objective(mean)
1176-
# Some acquisition functions do not have an objective
1177-
# and for some acquisition functions the objective is None
1178-
except (AttributeError, TypeError):
1179-
f_pred = mean
1180-
if hasattr(acq_function, "maximize"):
1181-
# make sure that the optimiztaion direction is set properly
1182-
if not acq_function.maximize:
1183-
f_pred = -f_pred
1184-
try:
1185-
# handle constraints for EHVI-based acquisition functions
1186-
constraints = acq_function.constraints
1187-
if constraints is not None:
1188-
neg_violation = -torch.stack(
1189-
[c(mean).clamp_min(0.0) for c in constraints], dim=-1
1190-
).sum(dim=-1)
1191-
feas = neg_violation == 0
1192-
if feas.any():
1193-
f_pred[~feas] = float("-inf")
1194-
else:
1195-
# set objective equal to negative violation
1196-
f_pred = neg_violation
1197-
except AttributeError:
1198-
pass
1199-
if f_pred.ndim == mean.ndim and f_pred.shape[-1] > 1:
1200-
# multi-objective
1201-
# find pareto set
1202-
is_pareto = is_non_dominated(f_pred)
1203-
best_X = X[is_pareto]
1204-
else:
1205-
if f_pred.shape[-1] == 1:
1206-
f_pred = f_pred.squeeze(-1)
1207-
n_best = max(1, round(X.shape[0] * best_pct / 100))
1208-
# the view() is to ensure that best_idcs is not a scalar tensor
1209-
best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
1210-
best_X = X[best_idcs]
1284+
with torch.no_grad():
1285+
try:
1286+
posterior = acq_function.model.posterior(X)
1287+
except AttributeError:
1288+
warnings.warn(
1289+
"Failed to sample around previous best points.",
1290+
BotorchWarning,
1291+
stacklevel=3,
1292+
)
1293+
return
1294+
mean = posterior.mean
1295+
while mean.ndim > 2:
1296+
# take average over batch dims
1297+
mean = mean.mean(dim=0)
1298+
try:
1299+
f_pred = acq_function.objective(mean)
1300+
# Some acquisition functions do not have an objective
1301+
# and for some acquisition functions the objective is None
1302+
except (AttributeError, TypeError):
1303+
f_pred = mean
1304+
if hasattr(acq_function, "maximize"):
1305+
# make sure that the optimiztaion direction is set properly
1306+
if not acq_function.maximize:
1307+
f_pred = -f_pred
1308+
try:
1309+
# handle constraints for EHVI-based acquisition functions
1310+
constraints = acq_function.constraints
1311+
if constraints is not None:
1312+
neg_violation = -torch.stack(
1313+
[c(mean).clamp_min(0.0) for c in constraints], dim=-1
1314+
).sum(dim=-1)
1315+
feas = neg_violation == 0
1316+
if feas.any():
1317+
f_pred[~feas] = float("-inf")
1318+
else:
1319+
# set objective equal to negative violation
1320+
f_pred = neg_violation
1321+
except AttributeError:
1322+
pass
1323+
if f_pred.ndim == mean.ndim and f_pred.shape[-1] > 1:
1324+
# multi-objective
1325+
# find pareto set
1326+
is_pareto = is_non_dominated(f_pred)
1327+
best_X = X[is_pareto]
1328+
else:
1329+
if f_pred.shape[-1] == 1:
1330+
f_pred = f_pred.squeeze(-1)
1331+
n_best = max(1, round(X.shape[0] * best_pct / 100))
1332+
# the view() is to ensure that best_idcs is not a scalar tensor
1333+
best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
1334+
best_X = X[best_idcs]
1335+
12111336
use_perturbed_sampling = best_X.shape[-1] >= 20 or prob_perturb is not None
12121337
n_trunc_normal_points = (
12131338
n_discrete_points // 2 if use_perturbed_sampling else n_discrete_points

botorch/optim/optimize.py

+4
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
AcquisitionFunction,
2121
OneShotAcquisitionFunction,
2222
)
23+
from botorch.acquisition.joint_entropy_search import qJointEntropySearch
2324
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
2425
from botorch.acquisition.multi_objective.hypervolume_knowledge_gradient import (
2526
qHypervolumeKnowledgeGradient,
@@ -33,6 +34,7 @@
3334
gen_batch_initial_conditions,
3435
gen_one_shot_hvkg_initial_conditions,
3536
gen_one_shot_kg_initial_conditions,
37+
gen_optimal_input_initial_conditions,
3638
TGenInitialConditions,
3739
)
3840
from botorch.optim.stopping import ExpMAStoppingCriterion
@@ -174,6 +176,8 @@ def get_ic_generator(self) -> TGenInitialConditions:
174176
return gen_one_shot_kg_initial_conditions
175177
elif isinstance(self.acq_function, qHypervolumeKnowledgeGradient):
176178
return gen_one_shot_hvkg_initial_conditions
179+
elif isinstance(self.acq_function, qJointEntropySearch):
180+
return gen_optimal_input_initial_conditions
177181
return gen_batch_initial_conditions
178182

179183

0 commit comments

Comments
 (0)