Skip to content

Commit

Permalink
Merge pull request #22 from qutip/21-prevent-__time__-kwrd-loss
Browse files Browse the repository at this point in the history
Fix __time__ kwrd pop
  • Loading branch information
flowerthrower authored Oct 1, 2024
2 parents c3c3adb + cd61877 commit a5f0016
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
4 changes: 3 additions & 1 deletion src/qutip_qoc/_goat.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ def __init__(
self._var_t = "guess" in time_options

# num of params for each control function
self._para_counts = [len(v["guess"]) for v in control_parameters.values()]
self._para_counts = [
len(v["guess"]) for k, v in control_parameters.items() if k != "__time__"
]

# inferred attributes
self._tot_n_para = sum(self._para_counts) # excl. time
Expand Down
3 changes: 0 additions & 3 deletions src/qutip_qoc/_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,9 +283,6 @@ def _global_local_optimization(
_get_init_and_bounds_from_options(x0, control_parameters[key].get("guess"))
_get_init_and_bounds_from_options(bounds, control_parameters[key].get("bounds"))

_get_init_and_bounds_from_options(x0, time_options.get("guess", None))
_get_init_and_bounds_from_options(bounds, time_options.get("bounds", None))

optimizer_kwargs["x0"] = np.concatenate(x0)

multi_objective = _MultiObjective(
Expand Down
7 changes: 4 additions & 3 deletions src/qutip_qoc/pulse_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def optimize_pulses(
# create time interval
time_interval = _TimeInterval(tslots=tlist)

time_options = control_parameters.pop("__time__", {})
time_options = control_parameters.get("__time__", {})
if time_options: # convert to list of bounds if not already
if not isinstance(time_options["bounds"][0], (list, tuple)):
time_options["bounds"] = [time_options["bounds"]]
Expand All @@ -151,8 +151,9 @@ def optimize_pulses(
# extract guess and bounds for the control pulses
x0, bounds = [], []
for key in control_parameters.keys():
x0.append(control_parameters[key].get("guess"))
bounds.append(control_parameters[key].get("bounds"))
if key != "__time__":
x0.append(control_parameters[key].get("guess"))
bounds.append(control_parameters[key].get("bounds"))
try: # GRAPE, CRAB format
lbound = [b[0][0] for b in bounds]
ubound = [b[0][1] for b in bounds]
Expand Down

0 comments on commit a5f0016

Please # to comment.