Skip to content

Commit 1fc760a

Browse files
authored
Include &Packedinstruction in DAGCircuit::op_nodes iteration (Qiskit#13681)
Uses of the `DAGCircuit::op_nodes` iterator are almost invariably followed by indexing into the DAG to retrieve the node weight, which is dynamically asserted to be `NodeType::Operation`. Since the iterator already verifies this, we can just include the unwrapped `&PackedInstruction` in the iterator return to simply calling code and reduce the number of `unreachable!()`, `panic!()`, etcs. A new `op_node_indices` is provided for ease for cases where the iterator needs to be consumed to avoid being lifetime-tied to the DAG, such as when the DAG is going to be mutated based on the nodes. `topological_op_nodes` is left for a follow-up; it might be slightly more challenging because of the interaction with rustworkx, but it can also be done separately anyway.
1 parent d374427 commit 1fc760a

File tree

11 files changed

+110
-165
lines changed

11 files changed

+110
-165
lines changed

crates/accelerate/src/barrier_before_final_measurement.rs

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,24 +30,21 @@ pub fn barrier_before_final_measurements(
3030
dag: &mut DAGCircuit,
3131
label: Option<String>,
3232
) -> PyResult<()> {
33+
let is_exactly_final = |inst: &PackedInstruction| FINAL_OP_NAMES.contains(&inst.op.name());
3334
let final_ops: HashSet<NodeIndex> = dag
3435
.op_nodes(true)
35-
.filter(|node| {
36-
let NodeType::Operation(ref inst) = dag.dag()[*node] else {
37-
unreachable!();
38-
};
39-
if !FINAL_OP_NAMES.contains(&inst.op.name()) {
40-
return false;
36+
.filter_map(|(node, inst)| {
37+
if !is_exactly_final(inst) {
38+
return None;
4139
}
42-
let is_final_op = dag.bfs_successors(*node).all(|(_, child_successors)| {
43-
!child_successors.iter().any(|suc| match dag.dag()[*suc] {
44-
NodeType::Operation(ref suc_inst) => {
45-
!FINAL_OP_NAMES.contains(&suc_inst.op.name())
46-
}
47-
_ => false,
40+
dag.bfs_successors(node)
41+
.all(|(_, child_successors)| {
42+
child_successors.iter().all(|suc| match dag.dag()[*suc] {
43+
NodeType::Operation(ref suc_inst) => is_exactly_final(suc_inst),
44+
_ => true,
45+
})
4846
})
49-
});
50-
is_final_op
47+
.then_some(node)
5148
})
5249
.collect();
5350
if final_ops.is_empty() {

crates/accelerate/src/basis/basis_translator/compose_transforms.rs

Lines changed: 22 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use qiskit_circuit::parameter_table::ParameterUuid;
1818
use qiskit_circuit::Qubit;
1919
use qiskit_circuit::{
2020
circuit_data::CircuitData,
21-
dag_circuit::{DAGCircuit, NodeType},
21+
dag_circuit::DAGCircuit,
2222
operations::{Operation, Param},
2323
};
2424
use smallvec::SmallVec;
@@ -83,24 +83,18 @@ pub(super) fn compose_transforms<'a>(
8383
for (_, dag) in &mut mapped_instructions.values_mut() {
8484
let nodes_to_replace = dag
8585
.op_nodes(true)
86-
.filter_map(|node| {
87-
if let Some(NodeType::Operation(op)) = dag.dag().node_weight(node) {
88-
if (gate_name.as_str(), *gate_num_qubits)
89-
== (op.op.name(), op.op.num_qubits())
90-
{
91-
Some((
92-
node,
93-
op.params_view()
94-
.iter()
95-
.map(|x| x.clone_ref(py))
96-
.collect::<SmallVec<[Param; 3]>>(),
97-
))
98-
} else {
99-
None
100-
}
101-
} else {
102-
None
103-
}
86+
.filter(|(_, op)| {
87+
(op.op.num_qubits() == *gate_num_qubits)
88+
&& (op.op.name() == gate_name.as_str())
89+
})
90+
.map(|(node, op)| {
91+
(
92+
node,
93+
op.params_view()
94+
.iter()
95+
.map(|x| x.clone_ref(py))
96+
.collect::<SmallVec<[Param; 3]>>(),
97+
)
10498
})
10599
.collect::<Vec<_>>();
106100
for (node, params) in nodes_to_replace {
@@ -141,17 +135,15 @@ fn get_gates_num_params(
141135
dag: &DAGCircuit,
142136
example_gates: &mut HashMap<GateIdentifier, usize>,
143137
) -> PyResult<()> {
144-
for node in dag.op_nodes(true) {
145-
if let Some(NodeType::Operation(op)) = dag.dag().node_weight(node) {
146-
example_gates.insert(
147-
(op.op.name().to_string(), op.op.num_qubits()),
148-
op.params_view().len(),
149-
);
150-
if op.op.control_flow() {
151-
let blocks = op.op.blocks();
152-
for block in blocks {
153-
get_gates_num_params_circuit(&block, example_gates)?;
154-
}
138+
for (_, inst) in dag.op_nodes(true) {
139+
example_gates.insert(
140+
(inst.op.name().to_string(), inst.op.num_qubits()),
141+
inst.params_view().len(),
142+
);
143+
if inst.op.control_flow() {
144+
let blocks = inst.op.blocks();
145+
for block in blocks {
146+
get_gates_num_params_circuit(&block, example_gates)?;
155147
}
156148
}
157149
}

crates/accelerate/src/basis/basis_translator/mod.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,7 @@ fn extract_basis(
212212
basis: &mut HashSet<GateIdentifier>,
213213
min_qubits: usize,
214214
) -> PyResult<()> {
215-
for node in circuit.op_nodes(true) {
216-
let operation: &PackedInstruction = circuit.dag()[node].unwrap_operation();
215+
for (node, operation) in circuit.op_nodes(true) {
217216
if !circuit.has_calibration_for_index(py, node)?
218217
&& circuit.get_qargs(operation.qubits).len() >= min_qubits
219218
{
@@ -279,8 +278,7 @@ fn extract_basis_target(
279278
min_qubits: usize,
280279
qargs_with_non_global_operation: &HashMap<Option<Qargs>, HashSet<String>>,
281280
) -> PyResult<()> {
282-
for node in dag.op_nodes(true) {
283-
let node_obj: &PackedInstruction = dag.dag()[node].unwrap_operation();
281+
for (node, node_obj) in dag.op_nodes(true) {
284282
let qargs: &[Qubit] = dag.get_qargs(node_obj.qubits);
285283
if dag.has_calibration_for_index(py, node)? || qargs.len() < min_qubits {
286284
continue;

crates/accelerate/src/check_map.rs

Lines changed: 35 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use pyo3::prelude::*;
1616
use pyo3::wrap_pyfunction;
1717

1818
use qiskit_circuit::circuit_data::CircuitData;
19-
use qiskit_circuit::dag_circuit::{DAGCircuit, NodeType};
19+
use qiskit_circuit::dag_circuit::DAGCircuit;
2020
use qiskit_circuit::imports::CIRCUIT_TO_DAG;
2121
use qiskit_circuit::operations::{Operation, OperationRef};
2222
use qiskit_circuit::Qubit;
@@ -36,45 +36,43 @@ fn recurse<'py>(
3636
None => edge_set.contains(&[qubits[0].into(), qubits[1].into()]),
3737
}
3838
};
39-
for node in dag.op_nodes(false) {
40-
if let NodeType::Operation(inst) = &dag.dag()[node] {
41-
let qubits = dag.get_qargs(inst.qubits);
42-
if inst.op.control_flow() {
43-
if let OperationRef::Instruction(py_inst) = inst.op.view() {
44-
let raw_blocks = py_inst.instruction.getattr(py, "blocks")?;
45-
let circuit_to_dag = CIRCUIT_TO_DAG.get_bound(py);
46-
for raw_block in raw_blocks.bind(py).iter().unwrap() {
47-
let block_obj = raw_block?;
48-
let block = block_obj
49-
.getattr(intern!(py, "_data"))?
50-
.downcast::<CircuitData>()?
51-
.borrow();
52-
let new_dag: DAGCircuit =
53-
circuit_to_dag.call1((block_obj.clone(),))?.extract()?;
54-
let wire_map = (0..block.num_qubits())
55-
.map(|inner| {
56-
let outer = qubits[inner];
57-
match wire_map {
58-
Some(wire_map) => wire_map[outer.index()],
59-
None => outer,
60-
}
61-
})
62-
.collect::<Vec<_>>();
63-
let res = recurse(py, &new_dag, edge_set, Some(&wire_map))?;
64-
if res.is_some() {
65-
return Ok(res);
66-
}
39+
for (node, inst) in dag.op_nodes(false) {
40+
let qubits = dag.get_qargs(inst.qubits);
41+
if inst.op.control_flow() {
42+
if let OperationRef::Instruction(py_inst) = inst.op.view() {
43+
let raw_blocks = py_inst.instruction.getattr(py, "blocks")?;
44+
let circuit_to_dag = CIRCUIT_TO_DAG.get_bound(py);
45+
for raw_block in raw_blocks.bind(py).iter().unwrap() {
46+
let block_obj = raw_block?;
47+
let block = block_obj
48+
.getattr(intern!(py, "_data"))?
49+
.downcast::<CircuitData>()?
50+
.borrow();
51+
let new_dag: DAGCircuit =
52+
circuit_to_dag.call1((block_obj.clone(),))?.extract()?;
53+
let wire_map = (0..block.num_qubits())
54+
.map(|inner| {
55+
let outer = qubits[inner];
56+
match wire_map {
57+
Some(wire_map) => wire_map[outer.index()],
58+
None => outer,
59+
}
60+
})
61+
.collect::<Vec<_>>();
62+
let res = recurse(py, &new_dag, edge_set, Some(&wire_map))?;
63+
if res.is_some() {
64+
return Ok(res);
6765
}
6866
}
69-
} else if qubits.len() == 2
70-
&& (dag.calibrations_empty() || !dag.has_calibration_for_index(py, node)?)
71-
&& !check_qubits(qubits)
72-
{
73-
return Ok(Some((
74-
inst.op.name().to_string(),
75-
[qubits[0].0, qubits[1].0],
76-
)));
7767
}
68+
} else if qubits.len() == 2
69+
&& (dag.calibrations_empty() || !dag.has_calibration_for_index(py, node)?)
70+
&& !check_qubits(qubits)
71+
{
72+
return Ok(Some((
73+
inst.op.name().to_string(),
74+
[qubits[0].0, qubits[1].0],
75+
)));
7876
}
7977
}
8078
Ok(None)

crates/accelerate/src/filter_op_nodes.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ pub fn py_filter_op_nodes(
2929
predicate.call1((dag_op_node,))?.extract()
3030
};
3131
let mut remove_nodes: Vec<NodeIndex> = Vec::new();
32-
for node in dag.op_nodes(true) {
32+
for node in dag.op_node_indices(true) {
3333
if !callable(node)? {
3434
remove_nodes.push(node);
3535
}

crates/accelerate/src/gate_direction.rs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use qiskit_circuit::{
2323
circuit_instruction::CircuitInstruction,
2424
circuit_instruction::ExtraInstructionAttributes,
2525
converters::{circuit_to_dag, QuantumCircuitData},
26-
dag_circuit::{DAGCircuit, NodeType},
26+
dag_circuit::DAGCircuit,
2727
dag_node::{DAGNode, DAGOpNode},
2828
imports,
2929
imports::get_std_gate_class,
@@ -105,11 +105,7 @@ fn check_gate_direction<T>(
105105
where
106106
T: Fn(&PackedInstruction, &[Qubit]) -> bool,
107107
{
108-
for node in dag.op_nodes(false) {
109-
let NodeType::Operation(packed_inst) = &dag.dag()[node] else {
110-
panic!("PackedInstruction is expected");
111-
};
112-
108+
for (_, packed_inst) in dag.op_nodes(false) {
113109
let inst_qargs = dag.get_qargs(packed_inst.qubits);
114110

115111
if let OperationRef::Instruction(py_inst) = packed_inst.op.view() {
@@ -254,9 +250,7 @@ where
254250
let mut nodes_to_replace: Vec<(NodeIndex, DAGCircuit)> = Vec::new();
255251
let mut ops_to_replace: Vec<(NodeIndex, Vec<Bound<PyAny>>)> = Vec::new();
256252

257-
for node in dag.op_nodes(false) {
258-
let packed_inst = dag.dag()[node].unwrap_operation();
259-
253+
for (node, packed_inst) in dag.op_nodes(false) {
260254
let op_args = dag.get_qargs(packed_inst.qubits);
261255

262256
if let OperationRef::Instruction(py_inst) = packed_inst.op.view() {

crates/accelerate/src/gates_in_basis.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,7 @@ fn any_gate_missing_from_target(dag: &DAGCircuit, target: &Target) -> PyResult<b
7979
);
8080

8181
// Process the DAG.
82-
for gate in dag.op_nodes(true) {
83-
let gate = dag.dag()[gate].unwrap_operation();
82+
for (_, gate) in dag.op_nodes(true) {
8483
if is_universal(gate) {
8584
continue;
8685
}

crates/accelerate/src/remove_diagonal_gates_before_measure.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,7 @@ fn run_remove_diagonal_before_measure(dag: &mut DAGCircuit) -> PyResult<()> {
4747
static DIAGONAL_3Q_GATES: [StandardGate; 1] = [StandardGate::CCZGate];
4848

4949
let mut nodes_to_remove = Vec::new();
50-
for index in dag.op_nodes(true) {
51-
let node = &dag.dag()[index];
52-
let NodeType::Operation(inst) = node else {
53-
panic!()
54-
};
55-
50+
for (index, inst) in dag.op_nodes(true) {
5651
if inst.op.name() == "measure" {
5752
let predecessor = (dag.quantum_predecessors(index))
5853
.next()

crates/accelerate/src/remove_identity_equiv.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,7 @@ fn remove_identity_equiv(
7474
}
7575
};
7676

77-
for op_node in dag.op_nodes(false) {
78-
let inst = dag.dag()[op_node].unwrap_operation();
77+
for (op_node, inst) in dag.op_nodes(false) {
7978
match inst.op.view() {
8079
OperationRef::Standard(gate) => {
8180
let (dim, trace) = match gate {

crates/accelerate/src/split_2q_unitaries.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ pub fn split_2q_unitaries(
3131
if !dag.get_op_counts().contains_key("unitary") {
3232
return Ok(());
3333
}
34-
let nodes: Vec<NodeIndex> = dag.op_nodes(false).collect();
34+
let nodes: Vec<NodeIndex> = dag.op_node_indices(false).collect();
3535

3636
for node in nodes {
3737
if let NodeType::Operation(inst) = &dag.dag()[node] {

0 commit comments

Comments
 (0)