|
| 1 | +// This code is part of Qiskit. |
| 2 | +// |
| 3 | +// (C) Copyright IBM 2024 |
| 4 | +// |
| 5 | +// This code is licensed under the Apache License, Version 2.0. You may |
| 6 | +// obtain a copy of this license in the LICENSE.txt file in the root directory |
| 7 | +// of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. |
| 8 | +// |
| 9 | +// Any modifications or derivative works of this code must retain this |
| 10 | +// copyright notice, and modified files need to carry a notice indicating |
| 11 | +// that they have been altered from the originals. |
| 12 | + |
| 13 | +/// Type alias for a node representation. |
| 14 | +/// Each node is represented as a tuple containing: |
| 15 | +/// - Node id (usize) |
| 16 | +/// - List of involved qubit indices (Vec<VirtualQubit>) |
| 17 | +/// - Set of involved classical bit indices (HashSet<usize>) |
| 18 | +/// - Directive flag (bool) |
| 19 | +type Nodes = (usize, Vec<VirtualQubit>, HashSet<usize>, bool); |
| 20 | + |
| 21 | +/// Type alias for a block representation. |
| 22 | +/// Each block is represented by a tuple containing: |
| 23 | +/// - A boolean indicating the presence of a center (bool) |
| 24 | +/// - A list of nodes (Vec<Nodes>) |
| 25 | +type Block = (bool, Vec<Nodes>); |
| 26 | + |
| 27 | +use crate::nlayout::PhysicalQubit; |
| 28 | +use crate::nlayout::VirtualQubit; |
| 29 | +use crate::sabre::sabre_dag::SabreDAG; |
| 30 | +use crate::sabre::swap_map::SwapMap; |
| 31 | +use crate::sabre::BlockResult; |
| 32 | +use crate::sabre::NodeBlockResults; |
| 33 | +use crate::sabre::SabreResult; |
| 34 | +use hashbrown::HashMap; |
| 35 | +use hashbrown::HashSet; |
| 36 | +use numpy::IntoPyArray; |
| 37 | +use pyo3::prelude::*; |
| 38 | + |
| 39 | +/// Python function to perform star prerouting on a SabreDAG. |
| 40 | +/// This function processes star blocks and updates the DAG and qubit mapping. |
| 41 | +#[pyfunction] |
| 42 | +#[pyo3(text_signature = "(dag, blocks, processing_order, /)")] |
| 43 | +fn star_preroute( |
| 44 | + py: Python, |
| 45 | + dag: &mut SabreDAG, |
| 46 | + blocks: Vec<Block>, |
| 47 | + processing_order: Vec<Nodes>, |
| 48 | +) -> (SwapMap, PyObject, NodeBlockResults, PyObject) { |
| 49 | + let mut qubit_mapping: Vec<usize> = (0..dag.num_qubits).collect(); |
| 50 | + let mut processed_block_ids: HashSet<usize> = HashSet::with_capacity(blocks.len()); |
| 51 | + let last_2q_gate = processing_order.iter().rev().find(|node| node.1.len() == 2); |
| 52 | + let mut is_first_star = true; |
| 53 | + |
| 54 | + // Structures for SabreResult |
| 55 | + let mut out_map: HashMap<usize, Vec<[PhysicalQubit; 2]>> = |
| 56 | + HashMap::with_capacity(dag.dag.node_count()); |
| 57 | + let mut gate_order: Vec<usize> = Vec::with_capacity(dag.dag.node_count()); |
| 58 | + let node_block_results: HashMap<usize, Vec<BlockResult>> = HashMap::new(); |
| 59 | + |
| 60 | + // Create a HashMap to store the node-to-block mapping |
| 61 | + let mut node_to_block: HashMap<usize, usize> = HashMap::with_capacity(processing_order.len()); |
| 62 | + for (block_id, block) in blocks.iter().enumerate() { |
| 63 | + for node in &block.1 { |
| 64 | + node_to_block.insert(node.0, block_id); |
| 65 | + } |
| 66 | + } |
| 67 | + // Store nodes where swaps will be placed. |
| 68 | + let mut swap_locations: Vec<&Nodes> = Vec::with_capacity(processing_order.len()); |
| 69 | + |
| 70 | + // Process blocks, gathering swap locations and updating the gate order |
| 71 | + for node in &processing_order { |
| 72 | + if let Some(&block_id) = node_to_block.get(&node.0) { |
| 73 | + // Skip if the block has already been processed |
| 74 | + if !processed_block_ids.insert(block_id) { |
| 75 | + continue; |
| 76 | + } |
| 77 | + process_block( |
| 78 | + &blocks[block_id], |
| 79 | + last_2q_gate, |
| 80 | + &mut is_first_star, |
| 81 | + &mut gate_order, |
| 82 | + &mut swap_locations, |
| 83 | + ); |
| 84 | + } else { |
| 85 | + // Apply operation for nodes not part of any block |
| 86 | + gate_order.push(node.0); |
| 87 | + } |
| 88 | + } |
| 89 | + |
| 90 | + // Apply the swaps based on the gathered swap locations and gate order |
| 91 | + for (index, node_id) in gate_order.iter().enumerate() { |
| 92 | + for swap_location in &swap_locations { |
| 93 | + if *node_id == swap_location.0 { |
| 94 | + if let Some(next_node_id) = gate_order.get(index + 1) { |
| 95 | + apply_swap( |
| 96 | + &mut qubit_mapping, |
| 97 | + &swap_location.1, |
| 98 | + *next_node_id, |
| 99 | + &mut out_map, |
| 100 | + ); |
| 101 | + } |
| 102 | + } |
| 103 | + } |
| 104 | + } |
| 105 | + |
| 106 | + let res = SabreResult { |
| 107 | + map: SwapMap { map: out_map }, |
| 108 | + node_order: gate_order, |
| 109 | + node_block_results: NodeBlockResults { |
| 110 | + results: node_block_results, |
| 111 | + }, |
| 112 | + }; |
| 113 | + |
| 114 | + let final_res = ( |
| 115 | + res.map, |
| 116 | + res.node_order.into_pyarray_bound(py).into(), |
| 117 | + res.node_block_results, |
| 118 | + qubit_mapping.into_pyarray_bound(py).into(), |
| 119 | + ); |
| 120 | + |
| 121 | + final_res |
| 122 | +} |
| 123 | + |
| 124 | +/// Processes a star block, applying operations and handling swaps. |
| 125 | +/// |
| 126 | +/// Args: |
| 127 | +/// |
| 128 | +/// * `block` - A tuple containing a boolean indicating the presence of a center and a vector of nodes representing the star block. |
| 129 | +/// * `last_2q_gate` - The last two-qubit gate in the processing order. |
| 130 | +/// * `is_first_star` - A mutable reference to a boolean indicating if this is the first star block being processed. |
| 131 | +/// * `gate_order` - A mutable reference to the gate order vector. |
| 132 | +/// * `swap_locations` - A mutable reference to the nodes where swaps will be placed after |
| 133 | +fn process_block<'a>( |
| 134 | + block: &'a Block, |
| 135 | + last_2q_gate: Option<&'a Nodes>, |
| 136 | + is_first_star: &mut bool, |
| 137 | + gate_order: &mut Vec<usize>, |
| 138 | + swap_locations: &mut Vec<&'a Nodes>, |
| 139 | +) { |
| 140 | + let (has_center, sequence) = block; |
| 141 | + |
| 142 | + // If the block contains exactly 2 nodes, apply them directly |
| 143 | + if sequence.len() == 2 { |
| 144 | + for inner_node in sequence { |
| 145 | + gate_order.push(inner_node.0); |
| 146 | + } |
| 147 | + return; |
| 148 | + } |
| 149 | + |
| 150 | + let mut prev_qargs = None; |
| 151 | + let mut swap_source = false; |
| 152 | + |
| 153 | + // Process each node in the block |
| 154 | + for inner_node in sequence.iter() { |
| 155 | + // Apply operation directly if it's a single-qubit operation or the same as previous qargs |
| 156 | + if inner_node.1.len() == 1 || prev_qargs == Some(&inner_node.1) { |
| 157 | + gate_order.push(inner_node.0); |
| 158 | + continue; |
| 159 | + } |
| 160 | + |
| 161 | + // If this is the first star and no swap source has been identified, set swap_source |
| 162 | + if *is_first_star && !swap_source { |
| 163 | + swap_source = *has_center; |
| 164 | + gate_order.push(inner_node.0); |
| 165 | + prev_qargs = Some(&inner_node.1); |
| 166 | + continue; |
| 167 | + } |
| 168 | + |
| 169 | + // Place 2q-gate and subsequent swap gate |
| 170 | + gate_order.push(inner_node.0); |
| 171 | + |
| 172 | + if inner_node != last_2q_gate.unwrap() && inner_node.1.len() == 2 { |
| 173 | + swap_locations.push(inner_node); |
| 174 | + } |
| 175 | + prev_qargs = Some(&inner_node.1); |
| 176 | + } |
| 177 | + *is_first_star = false; |
| 178 | +} |
| 179 | + |
| 180 | +/// Applies a swap operation to the DAG and updates the qubit mapping. |
| 181 | +/// |
| 182 | +/// # Args: |
| 183 | +/// |
| 184 | +/// * `qubit_mapping` - A mutable reference to the qubit mapping vector. |
| 185 | +/// * `qargs` - Qubit indices for the swap operation (node before the swap) |
| 186 | +/// * `next_node_id` - ID of the next node in the gate order (node after the swap) |
| 187 | +/// * `out_map` - A mutable reference to the output map. |
| 188 | +fn apply_swap( |
| 189 | + qubit_mapping: &mut [usize], |
| 190 | + qargs: &[VirtualQubit], |
| 191 | + next_node_id: usize, |
| 192 | + out_map: &mut HashMap<usize, Vec<[PhysicalQubit; 2]>>, |
| 193 | +) { |
| 194 | + if qargs.len() == 2 { |
| 195 | + let idx0 = qargs[0].index(); |
| 196 | + let idx1 = qargs[1].index(); |
| 197 | + |
| 198 | + // Update the `qubit_mapping` and `out_map` to reflect the swap operation |
| 199 | + qubit_mapping.swap(idx0, idx1); |
| 200 | + out_map.insert( |
| 201 | + next_node_id, |
| 202 | + vec![[ |
| 203 | + PhysicalQubit::new(qubit_mapping[idx0].try_into().unwrap()), |
| 204 | + PhysicalQubit::new(qubit_mapping[idx1].try_into().unwrap()), |
| 205 | + ]], |
| 206 | + ); |
| 207 | + } |
| 208 | +} |
| 209 | + |
| 210 | +#[pymodule] |
| 211 | +pub fn star_prerouting(m: &Bound<PyModule>) -> PyResult<()> { |
| 212 | + m.add_wrapped(wrap_pyfunction!(star_preroute))?; |
| 213 | + Ok(()) |
| 214 | +} |
0 commit comments