Commit c157b57 (parent: 6b3002b)

Added entropy search acquisition function initializer, ensured unittest coverage.

4 files changed: +263 -53 lines
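
In brief: the new initializer, `gen_optimal_input_initial_conditions`, seeds the restart candidates for entropy-search acquisition functions from the sampled optima they carry in `optimal_inputs`. A `frac_random` share of the `raw_samples` candidates is drawn from the feasible polytope, the remainder is perturbed around the sampled optima (plus, optionally, around the incumbent best points), and the `num_restarts` initial conditions are then picked by Boltzmann sampling over the acquisition values. `optimize_acqf` now dispatches `qJointEntropySearch` to this initializer.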

botorch/optim/initializers.py (+141 -53)
```diff
@@ -468,6 +468,91 @@ def gen_batch_initial_conditions(
     return batch_initial_conditions


+def gen_optimal_input_initial_conditions(
+    acq_function: AcquisitionFunction,
+    bounds: Tensor,
+    q: int,
+    num_restarts: int,
+    raw_samples: int,
+    fixed_features: dict[int, float] | None = None,
+    options: dict[str, bool | float | int] | None = None,
+    inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
+    equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
+):
+    options = options or {}
+    device = bounds.device
+    if not hasattr(acq_function, "optimal_inputs"):
+        raise AttributeError(
+            "gen_optimal_input_initial_conditions can only be used with "
+            "an AcquisitionFunction that has an optimal_inputs attribute."
+        )
+    frac_random: float = options.get("frac_random", 0.0)
+    if not 0 <= frac_random <= 1:
+        raise ValueError(
+            f"frac_random must take on values in [0, 1]. Value: {frac_random}"
+        )
+
+    batch_limit = options.get("batch_limit")
+    num_optima = acq_function.optimal_inputs.shape[:-1].numel()
+    suggestions = acq_function.optimal_inputs.reshape(num_optima, -1)
+    X = torch.empty(0, q, bounds.shape[1], dtype=bounds.dtype)
+    num_random = round(raw_samples * frac_random)
+    if num_random > 0:
+        X_rnd = sample_q_batches_from_polytope(
+            n=num_random,
+            q=q,
+            bounds=bounds,
+            n_burnin=options.get("n_burnin", 10000),
+            n_thinning=options.get("n_thinning", 32),
+            equality_constraints=equality_constraints,
+            inequality_constraints=inequality_constraints,
+        )
+        X = torch.cat((X, X_rnd))
+
+    if num_random < raw_samples:
+        X_perturbed = sample_points_around_best(
+            acq_function=acq_function,
+            n_discrete_points=q * (raw_samples - num_random),
+            sigma=options.get("sample_around_best_sigma", 1e-2),
+            bounds=bounds,
+            best_X=suggestions,
+        )
+        X_perturbed = X_perturbed.view(
+            raw_samples - num_random, q, bounds.shape[-1]
+        ).cpu()
+        X = torch.cat((X, X_perturbed))
+
+    if options.get("sample_around_best", False):
+        X_best = sample_points_around_best(
+            acq_function=acq_function,
+            n_discrete_points=q * raw_samples,
+            sigma=options.get("sample_around_best_sigma", 1e-2),
+            bounds=bounds,
+        )
+        X_best = X_best.view(raw_samples, q, bounds.shape[-1]).cpu()
+        X = torch.cat((X, X_best))
+
+    with torch.no_grad():
+        if batch_limit is None:
+            batch_limit = X.shape[0]
+        # Evaluate the acquisition function on `X` using `batch_limit`-sized
+        # chunks.
+        acq_vals = torch.cat(
+            [
+                acq_function(x_.to(device=device)).cpu()
+                for x_ in X.split(split_size=batch_limit, dim=0)
+            ],
+            dim=0,
+        )
+    idx = boltzmann_sample(
+        function_values=acq_vals,
+        num_samples=num_restarts,
+        eta=options.get("eta", 2.0),
+    )
+    # set the respective initial conditions to the sampled optimizers
+    return X[idx]
+
+
 def gen_one_shot_kg_initial_conditions(
     acq_function: qKnowledgeGradient,
     bounds: Tensor,
@@ -1136,6 +1221,7 @@ def sample_points_around_best(
     best_pct: float = 5.0,
     subset_sigma: float = 1e-1,
     prob_perturb: float | None = None,
+    best_X: Tensor | None = None,
 ) -> Tensor | None:
     r"""Find best points and sample nearby points.
@@ -1154,60 +1240,62 @@ def sample_points_around_best(
         An optional `n_discrete_points x d`-dim tensor containing the
         sampled points. This is None if no baseline points are found.
     """
-    X = get_X_baseline(acq_function=acq_function)
-    if X is None:
-        return
-    with torch.no_grad():
-        try:
-            posterior = acq_function.model.posterior(X)
-        except AttributeError:
-            warnings.warn(
-                "Failed to sample around previous best points.",
-                BotorchWarning,
-                stacklevel=3,
-            )
-            return
-        mean = posterior.mean
-        while mean.ndim > 2:
-            # take average over batch dims
-            mean = mean.mean(dim=0)
-        try:
-            f_pred = acq_function.objective(mean)
-        # Some acquisition functions do not have an objective
-        # and for some acquisition functions the objective is None
-        except (AttributeError, TypeError):
-            f_pred = mean
-        if hasattr(acq_function, "maximize"):
-            # make sure that the optimiztaion direction is set properly
-            if not acq_function.maximize:
-                f_pred = -f_pred
-        try:
-            # handle constraints for EHVI-based acquisition functions
-            constraints = acq_function.constraints
-            if constraints is not None:
-                neg_violation = -torch.stack(
-                    [c(mean).clamp_min(0.0) for c in constraints], dim=-1
-                ).sum(dim=-1)
-                feas = neg_violation == 0
-                if feas.any():
-                    f_pred[~feas] = float("-inf")
-                else:
-                    # set objective equal to negative violation
-                    f_pred = neg_violation
-        except AttributeError:
-            pass
-        if f_pred.ndim == mean.ndim and f_pred.shape[-1] > 1:
-            # multi-objective
-            # find pareto set
-            is_pareto = is_non_dominated(f_pred)
-            best_X = X[is_pareto]
-        else:
-            if f_pred.shape[-1] == 1:
-                f_pred = f_pred.squeeze(-1)
-            n_best = max(1, round(X.shape[0] * best_pct / 100))
-            # the view() is to ensure that best_idcs is not a scalar tensor
-            best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
-            best_X = X[best_idcs]
+    if best_X is None:
+        X = get_X_baseline(acq_function=acq_function)
+        if X is None:
+            return
+        with torch.no_grad():
+            try:
+                posterior = acq_function.model.posterior(X)
+            except AttributeError:
+                warnings.warn(
+                    "Failed to sample around previous best points.",
+                    BotorchWarning,
+                    stacklevel=3,
+                )
+                return
+            mean = posterior.mean
+            while mean.ndim > 2:
+                # take average over batch dims
+                mean = mean.mean(dim=0)
+            try:
+                f_pred = acq_function.objective(mean)
+            # Some acquisition functions do not have an objective
+            # and for some acquisition functions the objective is None
+            except (AttributeError, TypeError):
+                f_pred = mean
+            if hasattr(acq_function, "maximize"):
+                # make sure that the optimization direction is set properly
+                if not acq_function.maximize:
+                    f_pred = -f_pred
+            try:
+                # handle constraints for EHVI-based acquisition functions
+                constraints = acq_function.constraints
+                if constraints is not None:
+                    neg_violation = -torch.stack(
+                        [c(mean).clamp_min(0.0) for c in constraints], dim=-1
+                    ).sum(dim=-1)
+                    feas = neg_violation == 0
+                    if feas.any():
+                        f_pred[~feas] = float("-inf")
+                    else:
+                        # set objective equal to negative violation
+                        f_pred = neg_violation
+            except AttributeError:
+                pass
+            if f_pred.ndim == mean.ndim and f_pred.shape[-1] > 1:
+                # multi-objective
+                # find pareto set
+                is_pareto = is_non_dominated(f_pred)
+                best_X = X[is_pareto]
+            else:
+                if f_pred.shape[-1] == 1:
+                    f_pred = f_pred.squeeze(-1)
+                n_best = max(1, round(X.shape[0] * best_pct / 100))
+                # the view() is to ensure that best_idcs is not a scalar tensor
+                best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
+                best_X = X[best_idcs]
+
     use_perturbed_sampling = best_X.shape[-1] >= 20 or prob_perturb is not None
     n_trunc_normal_points = (
         n_discrete_points // 2 if use_perturbed_sampling else n_discrete_points
```
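
For context, a usage sketch of the new initializer (not part of the commit; the model, data, and sampled optima below are illustrative stand-ins — in practice `optimal_inputs`/`optimal_outputs` would come from sampling the model's optima, e.g. via `get_optimal_samples`):

```python
import torch

from botorch.acquisition.joint_entropy_search import qJointEntropySearch
from botorch.models import SingleTaskGP
from botorch.optim.initializers import gen_optimal_input_initial_conditions

# Toy model on the unit square.
train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = torch.rand(8, 1, dtype=torch.double)
model = SingleTaskGP(train_X, train_Y)

# Placeholder "sampled optima"; real code would sample these from the model.
optimal_inputs = torch.rand(5, 2, dtype=torch.double)
optimal_outputs = torch.rand(5, 1, dtype=torch.double)
jes = qJointEntropySearch(
    model=model,
    optimal_inputs=optimal_inputs,
    optimal_outputs=optimal_outputs,
)

bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double)
# 25% random polytope samples, 75% perturbations around the sampled optima.
ics = gen_optimal_input_initial_conditions(
    acq_function=jes,
    bounds=bounds,
    q=2,
    num_restarts=4,
    raw_samples=32,
    options={"frac_random": 0.25},
)
print(ics.shape)  # torch.Size([4, 2, 2])
```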

botorch/optim/optimize.py (+4 -0)
```diff
@@ -20,6 +20,7 @@
     AcquisitionFunction,
     OneShotAcquisitionFunction,
 )
+from botorch.acquisition.joint_entropy_search import qJointEntropySearch
 from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
 from botorch.acquisition.multi_objective.hypervolume_knowledge_gradient import (
     qHypervolumeKnowledgeGradient,
@@ -33,6 +34,7 @@
     gen_batch_initial_conditions,
     gen_one_shot_hvkg_initial_conditions,
     gen_one_shot_kg_initial_conditions,
+    gen_optimal_input_initial_conditions,
     TGenInitialConditions,
 )
 from botorch.optim.stopping import ExpMAStoppingCriterion
@@ -174,6 +176,8 @@ def get_ic_generator(self) -> TGenInitialConditions:
             return gen_one_shot_kg_initial_conditions
         elif isinstance(self.acq_function, qHypervolumeKnowledgeGradient):
             return gen_one_shot_hvkg_initial_conditions
+        elif isinstance(self.acq_function, qJointEntropySearch):
+            return gen_optimal_input_initial_conditions
         return gen_batch_initial_conditions
```
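
With this dispatch in place, `optimize_acqf` should pick up the new initializer automatically for `qJointEntropySearch` whenever no explicit `ic_generator` or `batch_initial_conditions` is supplied. A minimal sketch, reusing the hypothetical `jes` and `bounds` from the example above:

```python
from botorch.optim import optimize_acqf

# get_ic_generator() now returns gen_optimal_input_initial_conditions
# for qJointEntropySearch, so no ic_generator needs to be passed.
candidates, acq_value = optimize_acqf(
    acq_function=jes,
    bounds=bounds,
    q=2,
    num_restarts=4,
    raw_samples=32,
)
```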
test/optim/test_initializers.py (+107 -0)
```diff
@@ -13,6 +13,7 @@
 import torch
 from botorch.acquisition.analytic import PosteriorMean
 from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction
+from botorch.acquisition.joint_entropy_search import qJointEntropySearch
 from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
 from botorch.acquisition.monte_carlo import (
     qExpectedImprovement,
@@ -34,6 +35,7 @@
     gen_batch_initial_conditions,
     gen_one_shot_hvkg_initial_conditions,
     gen_one_shot_kg_initial_conditions,
+    gen_optimal_input_initial_conditions,
     gen_value_function_initial_conditions,
     initialize_q_batch,
     initialize_q_batch_nonneg,
@@ -47,6 +49,7 @@
 )
 from botorch.sampling.normal import IIDNormalSampler
 from botorch.utils.sampling import manual_seed, unnormalize
+from botorch.utils.test_helpers import get_model
 from botorch.utils.testing import (
     _get_max_violation_of_bounds,
     _get_max_violation_of_constraints,
@@ -1074,6 +1077,110 @@ def test_gen_one_shot_kg_initial_conditions(self):
             )
             self.assertTrue(torch.all(ics[..., -n_value:, :] == 1))

+    def test_gen_optimal_input_initial_conditions(self):
+        num_restarts = 10
+        raw_samples = 16
+        q = 3
+        for dtype in (torch.float, torch.double):
+            model = get_model(
+                torch.rand(4, 2, dtype=dtype), torch.rand(4, 1, dtype=dtype)
+            )
+            optimal_inputs = torch.rand(5, 2, dtype=dtype)
+            optimal_outputs = torch.rand(5, 1, dtype=dtype)
+            jes = qJointEntropySearch(
+                model=model,
+                optimal_inputs=optimal_inputs,
+                optimal_outputs=optimal_outputs,
+            )
+            bounds = torch.tensor([[0, 0], [1, 1]], device=self.device, dtype=dtype)
+            # base case
+            ics = gen_optimal_input_initial_conditions(
+                acq_function=jes,
+                bounds=bounds,
+                q=q,
+                num_restarts=num_restarts,
+                raw_samples=raw_samples,
+            )
+            self.assertEqual(ics.shape, torch.Size([num_restarts, q, 2]))
+
+            # since we use sample_around_best, this should generate enough
+            # points despite num_restarts being larger than raw_samples
+            ics = gen_optimal_input_initial_conditions(
+                acq_function=jes,
+                bounds=bounds,
+                q=q,
+                num_restarts=15,
+                raw_samples=8,
+                options={"frac_random": 0.2, "sample_around_best": True},
+            )
+            self.assertEqual(ics.shape, torch.Size([15, q, 2]))
+
+            # test option error
+            with self.assertRaises(ValueError):
+                gen_optimal_input_initial_conditions(
+                    acq_function=jes,
+                    bounds=bounds,
+                    q=1,
+                    num_restarts=num_restarts,
+                    raw_samples=raw_samples,
+                    options={"frac_random": 2.0},
+                )
+
+            ei = qExpectedImprovement(model, 99.9)
+            with self.assertRaisesRegex(
+                AttributeError,
+                "gen_optimal_input_initial_conditions can only be used with "
+                "an AcquisitionFunction that has an optimal_inputs attribute.",
+            ):
+                gen_optimal_input_initial_conditions(
+                    acq_function=ei,
+                    bounds=bounds,
+                    q=1,
+                    num_restarts=num_restarts,
+                    raw_samples=raw_samples,
+                    options={"frac_random": 2.0},
+                )
+            # test generation logic
+            random_ics = torch.rand(raw_samples // 2, q, 2)
+            suggested_ics = torch.rand(raw_samples // 2 * q, 2)
+            with ExitStack() as es:
+                mock_random_ics = es.enter_context(
+                    mock.patch(
+                        "botorch.optim.initializers.sample_q_batches_from_polytope",
+                        return_value=random_ics,
+                    )
+                )
+                mock_suggested_ics = es.enter_context(
+                    mock.patch(
+                        "botorch.optim.initializers.sample_points_around_best",
+                        return_value=suggested_ics,
+                    )
+                )
+                mock_choose = es.enter_context(
+                    mock.patch(
+                        "torch.multinomial",
+                        return_value=torch.arange(0, 10),
+                    )
+                )
+
+                ics = gen_optimal_input_initial_conditions(
+                    acq_function=jes,
+                    bounds=bounds,
+                    q=q,
+                    num_restarts=num_restarts,
+                    raw_samples=raw_samples,
+                    options={"frac_random": 0.5},
+                )
+
+                mock_suggested_ics.assert_called_once()
+                mock_random_ics.assert_called_once()
+                mock_choose.assert_called_once()
+
+                expected_result = torch.cat(
+                    (random_ics, suggested_ics.view(raw_samples // 2, q, 2)[0:2])
+                )
+                self.assertTrue(torch.equal(ics, expected_result))
+

 class TestGenOneShotHVKGInitialConditions(BotorchTestCase):
     def test_gen_one_shot_hvkg_initial_conditions(self):
```
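
Note on the final assertion: the test patches `torch.multinomial`, the sampling primitive that `boltzmann_sample` draws through, so the "sampled" restart indices are a deterministic `arange(0, 10)` and the output can be compared exactly against the concatenation of the mocked random candidates and the first two blocks of the mocked perturbed candidates.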
