Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Fix performance regression in looped QuantumCircuit.assign_parameters #13337

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion crates/circuit/src/circuit_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1024,13 +1024,19 @@ impl CircuitData {
}

/// Assign all uses of the circuit parameters as keys `mapping` to their corresponding values.
///
/// Any items in the mapping that are not present in the circuit are skipped; it's up to Python
/// space to turn extra bindings into an error, if they choose to do it.
fn assign_parameters_mapping(&mut self, mapping: Bound<PyAny>) -> PyResult<()> {
let py = mapping.py();
let mut items = Vec::new();
for item in mapping.call_method0("items")?.iter()? {
let (param_ob, value) = item?.extract::<(Py<PyAny>, AssignParam)>()?;
let uuid = ParameterUuid::from_parameter(param_ob.bind(py))?;
items.push((param_ob, value.0, self.param_table.pop(uuid)?));
// It's fine if the mapping contains parameters that we don't have - just skip those.
if let Ok(uses) = self.param_table.pop(uuid) {
items.push((param_ob, value.0, uses));
}
}
self.assign_parameters_inner(py, items)
}
Expand Down
78 changes: 32 additions & 46 deletions qiskit/circuit/quantumcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2710,7 +2710,10 @@ def has_parameter(self, name_or_param: str | Parameter, /) -> bool:
"""
if isinstance(name_or_param, str):
return self.get_parameter(name_or_param, None) is not None
return self.get_parameter(name_or_param.name) == name_or_param
return (
isinstance(name_or_param, Parameter)
and self.get_parameter(name_or_param.name, None) == name_or_param
)

@typing.overload
def get_var(self, name: str, default: T) -> Union[expr.Var, T]: ...
Expand Down Expand Up @@ -4348,38 +4351,57 @@ def assign_parameters( # pylint: disable=missing-raises-doc

if isinstance(parameters, collections.abc.Mapping):
raw_mapping = parameters if flat_input else self._unroll_param_dict(parameters)
our_parameters = self._data.unsorted_parameters()
if strict and (extras := raw_mapping.keys() - our_parameters):
if strict and (
extras := [
parameter for parameter in raw_mapping if not self.has_parameter(parameter)
]
):
raise CircuitError(
f"Cannot bind parameters ({', '.join(str(x) for x in extras)}) not present in"
" the circuit."
)
parameter_binds = _ParameterBindsDict(raw_mapping, our_parameters)
target._data.assign_parameters_mapping(parameter_binds)

def create_mapping_view():
return raw_mapping

target._data.assign_parameters_mapping(raw_mapping)
else:
parameter_binds = _ParameterBindsSequence(target._data.parameters, parameters)
# This should be cause a cache retrieval, since we warmed the cache. We need to keep
kevinhartman marked this conversation as resolved.
Show resolved Hide resolved
# hold of this so that if/when we lazily construct this mapping within the calibration
# assignments, we don't query the newly-bound version of the inner parameters.
initial_parameters = target._data.parameters

def create_mapping_view():
return dict(zip(initial_parameters, parameters))

target._data.assign_parameters_iterable(parameters)

# Finally, assign the parameters inside any of the calibrations. We don't track these in
# the `ParameterTable`, so we manually reconstruct things.
# the `ParameterTable`, so we manually reconstruct things. We only construct the mapping
# view on the first actual call, and cache the result.
mapping_view = None

def map_calibration(qubits, parameters, schedule):
nonlocal mapping_view
if mapping_view is None:
mapping_view = create_mapping_view()
modified = False
new_parameters = list(parameters)
for i, parameter in enumerate(new_parameters):
if not isinstance(parameter, ParameterExpression):
continue
if not (contained := parameter.parameters & parameter_binds.mapping.keys()):
if not (contained := parameter.parameters & mapping_view.keys()):
continue
for to_bind in contained:
parameter = parameter.assign(to_bind, parameter_binds.mapping[to_bind])
parameter = parameter.assign(to_bind, mapping_view[to_bind])
if not parameter.parameters:
parameter = parameter.numeric()
if isinstance(parameter, complex):
raise TypeError(f"Calibration cannot use complex number: '{parameter}'")
new_parameters[i] = parameter
modified = True
if modified:
schedule.assign_parameters(parameter_binds.mapping)
schedule.assign_parameters(mapping_view)
return (qubits, tuple(new_parameters)), schedule

target._calibrations = defaultdict(
Expand Down Expand Up @@ -6702,42 +6724,6 @@ def _validate_expr(circuit_scope: CircuitScopeInterface, node: expr.Expr) -> exp
return node


class _ParameterBindsDict:
__slots__ = ("mapping", "allowed_keys")

def __init__(self, mapping, allowed_keys):
self.mapping = mapping
self.allowed_keys = allowed_keys

def items(self):
"""Iterator through all the keys in the mapping that we care about. Wrapping the main
mapping allows us to avoid reconstructing a new 'dict', but just use the given 'mapping'
without any copy / reconstruction."""
for parameter, value in self.mapping.items():
if parameter in self.allowed_keys:
yield parameter, value


class _ParameterBindsSequence:
__slots__ = ("parameters", "values", "mapping_cache")

def __init__(self, parameters, values):
self.parameters = parameters
self.values = values
self.mapping_cache = None

def items(self):
"""Iterator through all the keys in the mapping that we care about."""
return zip(self.parameters, self.values)

@property
def mapping(self):
"""Cached version of a mapping. This is only generated on demand."""
if self.mapping_cache is None:
self.mapping_cache = dict(zip(self.parameters, self.values))
return self.mapping_cache


def _bit_argument_conversion(specifier, bit_sequence, bit_set, type_) -> list[Bit]:
"""Get the list of bits referred to by the specifier ``specifier``.

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
fixes:
- |
Fixed a performance regression in :meth:`.QuantumCircuit.assign_parameters` introduced in Qiskit
1.2.0 when calling the method in a tight loop, binding only a small number of parameters out of
a heavily parametric circuit on each iteration. If possible, it is still more performant to
call :meth:`~.QuantumCircuit.assign_parameters` only once, with all assignments at the same
time, as this reduces the proportion of time spent on input normalization and error-checking
overhead.