diff --git a/tensorflow_quantum/python/util.py b/tensorflow_quantum/python/util.py
index 92ebeabee..ebf632694 100644
--- a/tensorflow_quantum/python/util.py
+++ b/tensorflow_quantum/python/util.py
@@ -14,6 +14,7 @@
 # ==============================================================================
 """A collection of helper functions which are useful in places in TFQ."""
 
+import concurrent.futures
 import itertools
 import numbers
 import random
@@ -262,13 +263,16 @@ def random_pauli_sums(qubits, max_sum_length, n_sums):
 # There are no native convertible ops inside of this function.
 @tf.autograph.experimental.do_not_convert
 def convert_to_tensor(items_to_convert, deterministic_proto_serialize=False):
-    """Convert lists of tfq supported primitives to tensor representations.
+    """Convert lists of tfq supported primitives to tensor representations using parallel processing.
 
     Recursively convert a nested lists of `cirq.PauliSum` or `cirq.Circuit`
-    objects to a `tf.Tensor` representation. Note that cirq serialization only
-    supports `cirq.GridQubit`s so we also require that input circuits and
-    pauli sums are defined only on `cirq.GridQubit`s.
+    objects to a `tf.Tensor` representation using multiple processes. Note that
+    cirq serialization only supports `cirq.GridQubit`s so we also require that
+    input circuits and pauli sums are defined only on `cirq.GridQubit`s.
+
+    The function uses `concurrent.futures.ProcessPoolExecutor` to parallelize the
+    conversion of the items. Each item in the list is processed independently,
+    which allows for significant speedup with large lists of items.
 
     >>> my_qubits = cirq.GridQubit.rect(1, 2)
     >>> my_circuits = [cirq.Circuit(cirq.X(my_qubits[0])),
@@ -292,7 +296,7 @@ def convert_to_tensor(items_to_convert, deterministic_proto_serialize=False):
 
     Args:
         items_to_convert: Python `list` or nested `list` of `cirq.Circuit`
-            or `cirq.Paulisum` objects. Must be recangular.
+            or `cirq.PauliSum` objects. Must be rectangular.
         deterministic_proto_serialize: Whether to use a deterministic
             serialization when calling SerializeToString().
     Returns:
@@ -306,31 +310,30 @@ def convert_to_tensor(items_to_convert, deterministic_proto_serialize=False):
     # `cirq.Circuit`s and `cirq.PauliSum`s (they are iterable).
     # This code is safe for nested lists of depth less than the recursion limit,
     # which is deeper than any practical use the author can think of.
-    def recur(items_to_convert, curr_type=None):
-        tensored_items = []
-        for item in items_to_convert:
-            if isinstance(item, (list, np.ndarray, tuple)):
-                tensored_items.append(recur(item, curr_type))
-            elif isinstance(item, (cirq.PauliSum, cirq.PauliString)) and\
-                    not curr_type == cirq.Circuit:
-                curr_type = cirq.PauliSum
-                tensored_items.append(
-                    serializer.serialize_paulisum(item).SerializeToString(
-                        deterministic=deterministic_proto_serialize))
-            elif isinstance(item, cirq.Circuit) and\
-                    not curr_type == cirq.PauliSum:
-                curr_type = cirq.Circuit
-                tensored_items.append(
-                    serializer.serialize_circuit(item).SerializeToString(
-                        deterministic=deterministic_proto_serialize))
-            else:
-                raise TypeError("Incompatible item passed into "
-                                "convert_to_tensor. Tensor detected type: {}. "
-                                "got: {}".format(curr_type, type(item)))
-        return tensored_items
+    def convert_item(item, curr_type=None):
+        if isinstance(item, (list, np.ndarray, tuple)):
+            return [convert_item(i, curr_type) for i in item]
+        elif isinstance(item, (cirq.PauliSum, cirq.PauliString)) and\
+                not curr_type == cirq.Circuit:
+            curr_type = cirq.PauliSum
+            return serializer.serialize_paulisum(item).SerializeToString(
+                deterministic=deterministic_proto_serialize)
+        elif isinstance(item, cirq.Circuit) and\
+                not curr_type == cirq.PauliSum:
+            curr_type = cirq.Circuit
+            return serializer.serialize_circuit(item).SerializeToString(
+                deterministic=deterministic_proto_serialize)
+        else:
+            raise TypeError("Incompatible item passed into "
+                            "convert_to_tensor. Tensor detected type: {}. "
+                            "got: {}".format(curr_type, type(item)))
+
+    with concurrent.futures.ProcessPoolExecutor() as executor:
+        tensored_items = list(executor.map(convert_item, items_to_convert,
+                                           [None]*len(items_to_convert)))
 
     # This will catch impossible dimensions
-    return tf.convert_to_tensor(recur(items_to_convert))
+    return tf.convert_to_tensor(tensored_items)
 
 
 def _parse_single(item):