## 🐛 Bug

I have a polynomial function whose output I want to bring into a desired range. The function outputs values from -40 to 15, and I want its output to fall in the 7.5 to 9.5 range. I have defined a separate objective function that outputs 1 when the value of the polynomial is within that range and decreases linearly to 0 outside it. I pass that objective to the acquisition function so that the suggested points yield polynomial values within that range. Still, the values of the polynomial function remain the same. Could you please suggest where exactly I am going wrong?
## To Reproduce

**Code snippet to reproduce**
```python
import torch
from botorch.acquisition.monte_carlo import MCAcquisitionObjective


class RangeObjective(MCAcquisitionObjective):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def forward(self, samples: torch.Tensor, **kwargs) -> torch.Tensor:
        output = torch.zeros_like(samples)
        # Between 7.0 and 7.5, linearly scale from 0 to 1
        mask1 = (samples > 7.0) & (samples <= 7.5)
        output[mask1] = (samples[mask1] - 7.0) / 0.5
        # Between 7.5 and 9.5, return 1
        mask2 = (samples > 7.5) & (samples < 9.5)
        output[mask2] = 1
        # At exactly 7.5 and 9.5
        mask_exact = (samples == 7.5) | (samples == 9.5)
        output[mask_exact] = 1
        # Between 9.5 and 10.0, linearly scale from 1 to 0
        mask3 = (samples > 9.5) & (samples < 10.0)
        output[mask3] = (10.0 - samples[mask3]) / 0.5
        return output.squeeze(-1)
```
```python
from scipy.optimize import LinearConstraint

# Imports the snippet relies on but did not show.
# fit_gpytorch_model is the older name; recent BoTorch releases use fit_gpytorch_mll.
from botorch.fit import fit_gpytorch_model
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood

# NOTE: lower_bound, upper_bound, bounds and adh1_func are used below
# but their definitions are not included in the issue.

# Number of initial random points and iterations for Bayesian Optimization
N_INIT = 2
N_ITER = 100

# Generate initial data within specified bounds for each feature
train_x = torch.rand(N_INIT, 10) * 100
for i in range(10):
    train_x[:, i] = torch.rand(N_INIT) * (upper_bound[i] - lower_bound[i]) + lower_bound[i]

# Normalize the initial data so that the sum of features equals 100
train_x *= 100 / train_x.sum(dim=1, keepdim=True)
train_y = adh1_func(train_x)
train_y.requires_grad_(False)  # Observations typically do not require gradient

# Define constraints: sum of elements in each candidate should be 100
A = torch.ones((1, 10))  # Coefficient matrix for the linear constraint
lower_bounds = torch.tensor([100.0])  # Lower bounds for the constraint
upper_bounds = torch.tensor([100.0])  # Upper bounds for the constraint

# Convert A, lower_bounds, and upper_bounds to numpy for scipy compatibility
A_np = A.numpy()
lower_bounds_np = lower_bounds.numpy()
upper_bounds_np = upper_bounds.numpy()

# Create a scipy LinearConstraint object
linear_constraint = LinearConstraint(A=A_np, lb=lower_bounds_np, ub=upper_bounds_np)

from botorch.acquisition.objective import GenericMCObjective
from botorch.acquisition.monte_carlo import qExpectedImprovement
from botorch.sampling.normal import SobolQMCNormalSampler

# Objective function instance
objective_func = RangeObjective()

for iteration in range(N_ITER):
    # Fit the GP model
    gp = SingleTaskGP(train_x, train_y)
    mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
    fit_gpytorch_model(mll)

    # Define the range-based objective
    objective_func = RangeObjective()

    # Find the current maximum 'y' considering the defined objective
    # print("train y before passing to obj func", train_y)
    transformed_y = objective_func(train_y)
    # print("train y after passing to obj func", transformed_y)
    current_max = transformed_y.max()
    # print("current max", current_max)
    sampler = SobolQMCNormalSampler(sample_shape=torch.Size([1]))

    # Initialize the acquisition function
    qEI = qExpectedImprovement(
        model=gp,
        best_f=current_max,
        sampler=None,
        objective=objective_func,
    )

    # Optimize the acquisition function to find new candidates
    candidate, acq_value = optimize_acqf(
        acq_function=qEI,
        bounds=bounds,
        q=1,  # The number of candidates to generate
        num_restarts=5,
        raw_samples=512,  # The number of raw samples to consider
        options={"constraints": linear_constraint},
    )

    # Evaluate the objective function at the new candidate
    new_y = adh1_func(candidate)
    new_y_transformed = objective_func(new_y)

    # Update training data
    train_x = torch.cat([train_x, candidate])
    train_y = torch.cat([train_y, new_y])
    print(f"Iteration {iteration+1}, new point = {candidate.numpy()}, objective = {new_y.item()}")
```
Here you can see that the values vary over quite a large range.

## Expected Behavior

I want the values to be between 7.5 and 9.5.
Hi @Yash-Pisat. It looks like the RangeObjective has a derivative of 0 whenever the samples fall outside the interval from 7 to 10. As a result, the optimization does nothing when the samples are outside of this range. You need to define an objective whose gradients point toward the desired range whenever the samples are outside of it.
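As a rough illustration of that idea (not code from this thread), here is a minimal sketch of one possible objective: it is 0 inside the target range and equals the negative distance to the range outside it, so its gradient is nonzero everywhere outside the range. The function name `range_distance_objective` and the `lower`/`upper` defaults are assumptions made for this example, and it assumes a single-output model.

```python
import torch


def range_distance_objective(
    samples: torch.Tensor, X=None, lower: float = 7.5, upper: float = 9.5
) -> torch.Tensor:
    """Return 0 inside [lower, upper] and minus the distance to it outside.

    Hypothetical helper for illustration; assumes the trailing output
    dimension of `samples` has size 1 (single-output model).
    """
    y = samples.squeeze(-1)
    below = torch.clamp(lower - y, min=0.0)  # positive only when y < lower
    above = torch.clamp(y - upper, min=0.0)  # positive only when y > upper
    return -(below + above)


# Check values and gradients at points below, inside, and above the range.
y = torch.tensor([[-40.0], [8.0], [15.0]], requires_grad=True)
obj = range_distance_objective(y)
obj.sum().backward()
print(obj)     # objective value per sample
print(y.grad)  # gradient of the objective with respect to each sample
```

The gradient is +1 below the range, -1 above it, and 0 inside it, so samples outside the range are always pushed toward it. A callable with this `(samples, X)` signature could be wrapped as `GenericMCObjective(range_distance_objective)` and passed as `objective=` to `qExpectedImprovement`; in that case `best_f` would need to be computed with the same objective, for example `range_distance_objective(train_y).max()`.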