@@ -468,6 +468,128 @@ def gen_batch_initial_conditions(
     return batch_initial_conditions


+def gen_optimal_input_initial_conditions(
+    acq_function: AcquisitionFunction,
+    bounds: Tensor,
+    q: int,
+    num_restarts: int,
+    raw_samples: int,
+    fixed_features: dict[int, float] | None = None,
+    options: dict[str, bool | float | int] | None = None,
+    inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
+    equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
+):
+    r"""Generate a batch of initial conditions for random-restart optimization of
+    information-theoretic acquisition functions (PES & JES), where sampled optimizers
+    of the posterior constitute good initial guesses for further optimization. A
+    fraction of the initial samples (by default: 100%) is drawn as perturbations
+    around `acq_function.optimal_inputs`. On average, this drastically decreases the
+    runtime of acquisition function optimization and yields candidates with higher
+    acquisition function values. See https://github.com/pytorch/botorch/pull/2751
+    for more info.
+
+    Args:
+        acq_function: The acquisition function to be optimized.
+        bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`.
+        q: The number of candidates to consider.
+        num_restarts: The number of starting points for multistart acquisition
+            function optimization.
+        raw_samples: The number of raw samples to consider in the initialization
+            heuristic. Note: if `sample_around_best` is True (the default is False),
+            then `2 * raw_samples` samples are used.
+        fixed_features: A map `{feature_index: value}` for features that
+            should be fixed to a particular value during generation.
+        options: Options for initial condition generation. These contain all
+            settings for the standard heuristic initialization from
+            `gen_batch_initial_conditions`. In addition, they contain
+            `frac_random` (the fraction of points drawn fully at random, as opposed
+            to around the sampled optimizers of the posterior) and
+            `sample_around_best_sigma`, which dictates both the standard deviation
+            of the samples drawn around the posterior maximizers and of the samples
+            around the previous best points (if enabled).
+        inequality_constraints: A list of tuples `(indices, coefficients, rhs)`,
+            with each tuple encoding an inequality constraint of the form
+            `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
+        equality_constraints: A list of tuples `(indices, coefficients, rhs)`,
+            with each tuple encoding an equality constraint of the form
+            `\sum_i (X[indices[i]] * coefficients[i]) = rhs`.
+
+    Returns:
+        A `num_restarts x q x d` tensor of initial conditions.
+    """
+    options = options or {}
+    device = bounds.device
+    if not hasattr(acq_function, "optimal_inputs"):
+        raise AttributeError(
+            "gen_optimal_input_initial_conditions can only be used with "
+            "an AcquisitionFunction that has an optimal_inputs attribute."
+        )
+    frac_random: float = options.get("frac_random", 0.0)
+    if not 0 <= frac_random <= 1:
+        raise ValueError(
+            f"frac_random must take on values in [0, 1]. Value: {frac_random}"
+        )
+
+    batch_limit = options.get("batch_limit")
+    num_optima = acq_function.optimal_inputs.shape[:-1].numel()
+    suggestions = acq_function.optimal_inputs.reshape(num_optima, -1)
+    X = torch.empty(0, q, bounds.shape[1], dtype=bounds.dtype)
+    num_random = round(raw_samples * frac_random)
+    if num_random > 0:
+        X_rnd = sample_q_batches_from_polytope(
+            n=num_random,
+            q=q,
+            bounds=bounds,
+            n_burnin=options.get("n_burnin", 10000),
+            n_thinning=options.get("n_thinning", 32),
+            equality_constraints=equality_constraints,
+            inequality_constraints=inequality_constraints,
+        )
+        X = torch.cat((X, X_rnd))
+
+    if num_random < raw_samples:
+        X_perturbed = sample_points_around_best(
+            acq_function=acq_function,
+            n_discrete_points=q * (raw_samples - num_random),
+            sigma=options.get("sample_around_best_sigma", 1e-2),
+            bounds=bounds,
+            best_X=suggestions,
+        )
+        X_perturbed = X_perturbed.view(
+            raw_samples - num_random, q, bounds.shape[-1]
+        ).cpu()
+        X = torch.cat((X, X_perturbed))
+
+    if options.get("sample_around_best", False):
+        X_best = sample_points_around_best(
+            acq_function=acq_function,
+            n_discrete_points=q * raw_samples,
+            sigma=options.get("sample_around_best_sigma", 1e-2),
+            bounds=bounds,
+        )
+        X_best = X_best.view(raw_samples, q, bounds.shape[-1]).cpu()
+        X = torch.cat((X, X_best))
+
+    X = fix_features(X, fixed_features=fixed_features).cpu()
+    with torch.no_grad():
+        if batch_limit is None:
+            batch_limit = X.shape[0]
+        # Evaluate the acquisition function on `X` using `batch_limit`-sized
+        # chunks.
+        acq_vals = torch.cat(
+            [
+                acq_function(x_.to(device=device)).cpu()
+                for x_ in X.split(split_size=batch_limit, dim=0)
+            ],
+            dim=0,
+        )
+    idx = boltzmann_sample(
+        function_values=acq_vals,
+        num_samples=num_restarts,
+        eta=options.get("eta", 2.0),
+    )
+    return X[idx]
+
+
 def gen_one_shot_kg_initial_conditions(
     acq_function: qKnowledgeGradient,
     bounds: Tensor,
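
Usage note (illustrative, not part of the diff): the new initializer is intended to be passed to `optimize_acqf` via its `ic_generator` argument. The sketch below shows one way to wire it up for JES; `model` (a fitted BoTorch model) and `bounds` (a `2 x d` tensor) are placeholders, and the hyperparameter values are arbitrary.

    # Sketch only: exercise the new initializer end to end.
    from botorch.acquisition.joint_entropy_search import qJointEntropySearch
    from botorch.acquisition.utils import get_optimal_samples
    from botorch.optim import optimize_acqf
    from botorch.optim.initializers import gen_optimal_input_initial_conditions

    # Draw approximate posterior optimizers; these populate acqf.optimal_inputs.
    optimal_inputs, optimal_outputs = get_optimal_samples(
        model=model, bounds=bounds, num_optima=16
    )
    acqf = qJointEntropySearch(
        model=model,
        optimal_inputs=optimal_inputs,
        optimal_outputs=optimal_outputs,
    )
    candidate, acq_value = optimize_acqf(
        acq_function=acqf,
        bounds=bounds,
        q=1,
        num_restarts=8,
        raw_samples=256,
        ic_generator=gen_optimal_input_initial_conditions,
        # 25% of raw samples fully random, 75% perturbed around the optimizers.
        options={"frac_random": 0.25},
    )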
@@ -1136,6 +1258,7 @@ def sample_points_around_best(
     best_pct: float = 5.0,
     subset_sigma: float = 1e-1,
     prob_perturb: float | None = None,
+    best_X: Tensor | None = None,
 ) -> Tensor | None:
     r"""Find best points and sample nearby points.
 
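
The new `best_X` argument lets callers supply the incumbent points directly, bypassing the posterior-based selection in the hunk below; this is how `gen_optimal_input_initial_conditions` perturbs the sampled optimizers. A minimal sketch of such a call, reusing the placeholder names from the previous example:

    # Sketch only: perturb the sampled posterior optimizers instead of the
    # model-inferred best observed points.
    X_near_optima = sample_points_around_best(
        acq_function=acqf,
        n_discrete_points=128,
        sigma=1e-2,
        bounds=bounds,
        best_X=optimal_inputs.reshape(-1, bounds.shape[-1]),
    )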
@@ -1154,60 +1277,62 @@ def sample_points_around_best(
         An optional `n_discrete_points x d`-dim tensor containing the
         sampled points. This is None if no baseline points are found.
     """
-    X = get_X_baseline(acq_function=acq_function)
-    if X is None:
-        return
-    with torch.no_grad():
-        try:
-            posterior = acq_function.model.posterior(X)
-        except AttributeError:
-            warnings.warn(
-                "Failed to sample around previous best points.",
-                BotorchWarning,
-                stacklevel=3,
-            )
+    if best_X is None:
+        X = get_X_baseline(acq_function=acq_function)
+        if X is None:
             return
-        mean = posterior.mean
-        while mean.ndim > 2:
-            # take average over batch dims
-            mean = mean.mean(dim=0)
-        try:
-            f_pred = acq_function.objective(mean)
-        # Some acquisition functions do not have an objective
-        # and for some acquisition functions the objective is None
-        except (AttributeError, TypeError):
-            f_pred = mean
-        if hasattr(acq_function, "maximize"):
-            # make sure that the optimization direction is set properly
-            if not acq_function.maximize:
-                f_pred = -f_pred
-        try:
-            # handle constraints for EHVI-based acquisition functions
-            constraints = acq_function.constraints
-            if constraints is not None:
-                neg_violation = -torch.stack(
-                    [c(mean).clamp_min(0.0) for c in constraints], dim=-1
-                ).sum(dim=-1)
-                feas = neg_violation == 0
-                if feas.any():
-                    f_pred[~feas] = float("-inf")
-                else:
-                    # set objective equal to negative violation
-                    f_pred = neg_violation
-        except AttributeError:
-            pass
-        if f_pred.ndim == mean.ndim and f_pred.shape[-1] > 1:
-            # multi-objective
-            # find pareto set
-            is_pareto = is_non_dominated(f_pred)
-            best_X = X[is_pareto]
-        else:
-            if f_pred.shape[-1] == 1:
-                f_pred = f_pred.squeeze(-1)
-            n_best = max(1, round(X.shape[0] * best_pct / 100))
-            # the view() is to ensure that best_idcs is not a scalar tensor
-            best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
-            best_X = X[best_idcs]
+        with torch.no_grad():
+            try:
+                posterior = acq_function.model.posterior(X)
+            except AttributeError:
+                warnings.warn(
+                    "Failed to sample around previous best points.",
+                    BotorchWarning,
+                    stacklevel=3,
+                )
+                return
+            mean = posterior.mean
+            while mean.ndim > 2:
+                # take average over batch dims
+                mean = mean.mean(dim=0)
+            try:
+                f_pred = acq_function.objective(mean)
+            # Some acquisition functions do not have an objective
+            # and for some acquisition functions the objective is None
+            except (AttributeError, TypeError):
+                f_pred = mean
+            if hasattr(acq_function, "maximize"):
+                # make sure that the optimization direction is set properly
+                if not acq_function.maximize:
+                    f_pred = -f_pred
+            try:
+                # handle constraints for EHVI-based acquisition functions
+                constraints = acq_function.constraints
+                if constraints is not None:
+                    neg_violation = -torch.stack(
+                        [c(mean).clamp_min(0.0) for c in constraints], dim=-1
+                    ).sum(dim=-1)
+                    feas = neg_violation == 0
+                    if feas.any():
+                        f_pred[~feas] = float("-inf")
+                    else:
+                        # set objective equal to negative violation
+                        f_pred = neg_violation
+            except AttributeError:
+                pass
+            if f_pred.ndim == mean.ndim and f_pred.shape[-1] > 1:
+                # multi-objective
+                # find pareto set
+                is_pareto = is_non_dominated(f_pred)
+                best_X = X[is_pareto]
+            else:
+                if f_pred.shape[-1] == 1:
+                    f_pred = f_pred.squeeze(-1)
+                n_best = max(1, round(X.shape[0] * best_pct / 100))
+                # the view() is to ensure that best_idcs is not a scalar tensor
+                best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
+                best_X = X[best_idcs]
+
     use_perturbed_sampling = best_X.shape[-1] >= 20 or prob_perturb is not None
     n_trunc_normal_points = (
         n_discrete_points // 2 if use_perturbed_sampling else n_discrete_points
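
For intuition on the final selection step in `gen_optimal_input_initial_conditions`: `boltzmann_sample` (an existing helper in this module) picks `num_restarts` of the `raw_samples` candidate batches with probability increasing in acquisition value, so high-value starts are favored without discarding diversity. A rough functional approximation of its effect, not its actual implementation:

    import torch

    def boltzmann_select(
        acq_vals: torch.Tensor, num_samples: int, eta: float = 2.0
    ) -> torch.Tensor:
        # Standardize the acquisition values, then sample indices without
        # replacement with weights proportional to exp(eta * standardized value).
        vals = (acq_vals - acq_vals.mean()) / acq_vals.std().clamp_min(1e-9)
        return torch.multinomial(torch.exp(eta * vals), num_samples, replacement=False)

Larger `eta` (the `"eta"` option, default 2.0) concentrates the selection on the top candidates; smaller values behave more like uniform sampling.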