Skip to content

Commit d16e6c9

Browse files
Fix global phase handling in CommutativeCancellation (backport Qiskit#14956) (Qiskit#14961)
* Fix global phase handling in `CommutativeCancellation` (Qiskit#14956) * fix cc phase * Fix case if `P/U1` is the gate we reduce to * Fix 4-eps case and add tests Co-authored-by: Alexander Ivrii <[email protected]> * fix no var_z_gate case * review comments + reno * fix accumulation of T/S to P/U1 * Update releasenotes/notes/fix-commcanc-phase-f68fbb428363f081.yaml Co-authored-by: Alexander Ivrii <[email protected]> --------- Co-authored-by: Alexander Ivrii <[email protected]> (cherry picked from commit 1064921) # Conflicts: # crates/transpiler/src/passes/commutation_cancellation.rs * fix merge conflicts --------- Co-authored-by: Julien Gacon <[email protected]>
1 parent 0af043f commit d16e6c9

File tree

3 files changed

+126
-52
lines changed

3 files changed

+126
-52
lines changed

crates/transpiler/src/passes/commutation_cancellation.rs

Lines changed: 59 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
// copyright notice, and modified files need to carry a notice indicating
1111
// that they have been altered from the originals.
1212

13-
use std::f64::consts::PI;
13+
use std::f64::consts::{FRAC_PI_2, FRAC_PI_4, PI};
1414

1515
use hashbrown::{HashMap, HashSet};
1616
use pyo3::exceptions::PyRuntimeError;
@@ -19,19 +19,16 @@ use pyo3::{pyfunction, wrap_pyfunction, Bound, PyResult, Python};
1919
use rustworkx_core::petgraph::stable_graph::NodeIndex;
2020
use smallvec::{smallvec, SmallVec};
2121

22-
use crate::commutation_checker::CommutationChecker;
2322
use qiskit_circuit::dag_circuit::{DAGCircuit, NodeType, Wire};
2423
use qiskit_circuit::operations::{Operation, Param, StandardGate};
2524
use qiskit_circuit::Qubit;
25+
use qiskit_synthesis::QiskitError;
2626

2727
use super::analyze_commutations;
28-
use qiskit_synthesis::{euler_one_qubit_decomposer, QiskitError};
28+
use crate::commutation_checker::CommutationChecker;
2929

3030
const _CUTOFF_PRECISION: f64 = 1e-5;
3131
static ROTATION_GATES: [&str; 4] = ["p", "u1", "rz", "rx"];
32-
static HALF_TURNS: [&str; 2] = ["z", "x"];
33-
static QUARTER_TURNS: [&str; 1] = ["s"];
34-
static EIGHTH_TURNS: [&str; 1] = ["t"];
3532

3633
static VAR_Z_MAP: [(&str, StandardGate); 3] = [
3734
("rz", StandardGate::RZ),
@@ -93,14 +90,39 @@ pub fn cancel_commutations(
9390
.map(|(_, gate)| gate)
9491
})
9592
.or_else(|| {
93+
// Fallback to the first matching key from basis if there is no match in dag.op_names
9694
basis.iter().find_map(|g| {
9795
VAR_Z_MAP
9896
.iter()
9997
.find(|(key, _)| *key == g.as_str())
10098
.map(|(_, gate)| gate)
10199
})
102100
});
103-
// Fallback to the first matching key from basis if there is no match in dag.op_names
101+
102+
// RZ and P/U1 have a phase difference of angle/2, which we need to account for
103+
let z_phase_shift = match z_var_gate {
104+
Some(z_var_gate) => {
105+
if z_var_gate == &StandardGate::RZ {
106+
|gate_name: &str, angle: f64| -> f64 {
107+
if ["u1", "p"].contains(&gate_name) {
108+
angle / 2.
109+
} else {
110+
0.
111+
}
112+
}
113+
} else {
114+
|gate_name: &str, angle: f64| -> f64 {
115+
if gate_name == "rz" {
116+
-angle / 2.
117+
} else {
118+
0.
119+
}
120+
}
121+
}
122+
}
123+
// if there's no z_var_gate detected, Z rotations are not merged, so we have no phase shifts
124+
None => |_name: &str, _angle: f64| -> f64 { 0. },
125+
};
104126

105127
// Gate sets to be cancelled
106128
/* Traverse each qubit to generate the cancel dictionaries
@@ -219,38 +241,29 @@ pub fn cancel_commutations(
219241
};
220242
let node_op_name = node_op.op.name();
221243

222-
let node_angle = if ROTATION_GATES.contains(&node_op_name) {
223-
match node_op.params_view().first() {
224-
Some(Param::Float(f)) => Ok(*f),
244+
let (node_angle, phase_shift) = if ROTATION_GATES.contains(&node_op_name) {
245+
let node_angle = match node_op.params_view().first() {
246+
Some(Param::Float(f)) => *f,
225247
_ => return Err(QiskitError::new_err(format!(
226248
"Rotational gate with parameter expression encountered in cancellation {:?}",
227249
node_op.op
228250
)))
229-
}
230-
} else if HALF_TURNS.contains(&node_op_name) {
231-
Ok(PI)
232-
} else if QUARTER_TURNS.contains(&node_op_name) {
233-
Ok(PI / 2.0)
234-
} else if EIGHTH_TURNS.contains(&node_op_name) {
235-
Ok(PI / 4.0)
251+
};
252+
let phase_shift = z_phase_shift(node_op_name, node_angle);
253+
Ok((node_angle, phase_shift))
236254
} else {
237-
Err(PyRuntimeError::new_err(format!(
238-
"Angle for operation {} is not defined",
239-
node_op_name
240-
)))
241-
};
242-
total_angle += node_angle?;
243-
244-
let Param::Float(new_phase) = node_op
245-
.op
246-
.definition(node_op.params_view())
247-
.unwrap()
248-
.global_phase()
249-
.clone()
250-
else {
251-
unreachable!()
252-
};
253-
total_phase += new_phase
255+
match node_op_name {
256+
"t" => Ok((FRAC_PI_4, z_phase_shift("p", FRAC_PI_4))),
257+
"s" => Ok((FRAC_PI_2, z_phase_shift("p", FRAC_PI_2))),
258+
"z" => Ok((PI, z_phase_shift("p", PI))),
259+
"x" => Ok((PI, FRAC_PI_2)),
260+
_ => Err(PyRuntimeError::new_err(format!(
261+
"Angle for operation {node_op_name} is not defined"
262+
))),
263+
}
264+
}?;
265+
total_angle += node_angle;
266+
total_phase += phase_shift;
254267
}
255268

256269
let new_op = match cancel_key.gate {
@@ -259,24 +272,21 @@ pub fn cancel_commutations(
259272
_ => unreachable!(),
260273
};
261274

262-
let gate_angle = euler_one_qubit_decomposer::mod_2pi(total_angle, 0.);
275+
let pi_multiple = total_angle / PI;
263276

264-
let new_op_phase: f64 = if gate_angle.abs() > _CUTOFF_PRECISION {
265-
dag.insert_1q_on_incoming_qubit((*new_op, &[total_angle]), cancel_set[0]);
266-
let Param::Float(new_phase) = new_op
267-
.definition(&[Param::Float(total_angle)])
268-
.unwrap()
269-
.global_phase()
270-
.clone()
271-
else {
272-
unreachable!();
273-
};
274-
new_phase
277+
let mod4 = pi_multiple.rem_euclid(4.);
278+
if mod4 < _CUTOFF_PRECISION || (4. - mod4) < _CUTOFF_PRECISION {
279+
// if the angle is close to a 4-pi multiple (from above or below), then the
280+
// operator is equal to the identity
281+
} else if (mod4 - 2.).abs() < _CUTOFF_PRECISION {
282+
// a 2-pi multiple has a phase of pi: RX(2pi) = RZ(2pi) = -I = I exp(i pi)
283+
total_phase -= PI;
275284
} else {
276-
0.0
277-
};
285+
// any other is not the identity and we add the gate
286+
dag.insert_1q_on_incoming_qubit((*new_op, &[total_angle]), cancel_set[0]);
287+
}
278288

279-
dag.add_global_phase(&Param::Float(total_phase - new_op_phase))?;
289+
dag.add_global_phase(&Param::Float(total_phase))?;
280290

281291
for node in cancel_set {
282292
dag.remove_op_node(*node);
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
---
2+
fixes:
3+
- |
4+
Fixed several issues in the :class:`.CommutativeCancellation` transpiler pass (and thereby in
5+
:func:`.transpile`), where the global phase of the circuit was not updated correctly.
6+
In particular, merging an X-gate and an RX-gate introduced a phase mismatch,
7+
while removing a Pauli rotation gate with angle of the form
8+
:math:`(2 + 4k)\pi`, :math:`k \in \mathbb Z` incorrectly produced a phase shift of :math:`-1`.

test/python/transpiler/test_commutative_cancellation.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,61 @@ def test_all_gates(self):
8080
expected.global_phase = 0.5
8181
self.assertEqual(expected, new_circuit)
8282

83+
def test_2pi_multiples(self):
84+
"""Test 2pi multiples are handled with the correct phase they introduce."""
85+
for eps in [0, 1e-10, -1e-10]:
86+
for sign in [-1, 1]:
87+
qc = QuantumCircuit(1)
88+
qc.rz(sign * np.pi + eps, 0)
89+
qc.rz(sign * np.pi, 0)
90+
91+
with self.subTest(msg="single 2pi", sign=sign, eps=eps):
92+
tqc = CommutativeCancellation()(qc)
93+
self.assertEqual(0, len(tqc.count_ops()))
94+
self.assertAlmostEqual(np.pi, tqc.global_phase)
95+
96+
for sign_x in [-1, 1]:
97+
for sign_z in [-1, 1]:
98+
qc = QuantumCircuit(2)
99+
qc.rx(sign_x * np.pi + eps, 0)
100+
qc.rx(sign_x * np.pi, 0)
101+
qc.rz(sign_z * np.pi, 1)
102+
qc.rz(sign_z * np.pi, 1)
103+
104+
with self.subTest(msg="two 2pi", sign_x=sign_x, sign_z=sign_z, eps=eps):
105+
tqc = CommutativeCancellation()(qc)
106+
self.assertEqual(0, len(tqc.count_ops()))
107+
self.assertAlmostEqual(0, tqc.global_phase)
108+
109+
def test_4pi_multiples(self):
110+
"""Test 4pi multiples are removed w/o changing the global phase."""
111+
for eps in [0, 1e-10, -1e-10]:
112+
for sign in [-1, 1]:
113+
qc = QuantumCircuit(1)
114+
qc.rz(sign * np.pi + eps, 0)
115+
qc.rz(sign * 6 * np.pi, 0)
116+
qc.rz(sign * np.pi, 0)
117+
118+
with self.subTest(sign=sign, eps=eps):
119+
tqc = CommutativeCancellation()(qc)
120+
self.assertEqual(0, len(tqc.count_ops()))
121+
self.assertAlmostEqual(0, tqc.global_phase)
122+
123+
def test_fixed_rotation_accumulation(self):
124+
"""Test accumulating gates with fixed angles (T, S) works correctly."""
125+
cc = CommutativeCancellation()
126+
127+
# test for U1, P and RZ as target gate
128+
for gate_cls in [RZGate, PhaseGate, U1Gate]:
129+
qc = QuantumCircuit(1)
130+
gate = gate_cls(0.2)
131+
qc.append(gate, [0])
132+
qc.t(0)
133+
qc.s(0)
134+
135+
tqc = cc(qc)
136+
self.assertTrue(np.allclose(Operator(qc).data, Operator(tqc).data))
137+
83138
def test_commutative_circuit1(self):
84139
"""A simple circuit where three CNOTs commute, the first and the last cancel.
85140
@@ -149,9 +204,10 @@ def test_consecutive_cnots2(self):
149204
)
150205
)
151206
new_circuit = passmanager.run(circuit)
152-
expected = QuantumCircuit(qr)
207+
expected = QuantumCircuit(qr, global_phase=np.pi) # RX(2pi) = -I = exp(i pi) I
153208

154209
self.assertEqual(expected, new_circuit)
210+
self.assertTrue(np.allclose(Operator(circuit).data, Operator(expected).data))
155211

156212
def test_2_alternating_cnots(self):
157213
"""A simple circuit where nothing should be cancelled.
@@ -669,9 +725,9 @@ def test_simple_if_else(self):
669725
(test.clbits[0], True), base_test1.copy(), base_test2.copy(), test.qubits, test.clbits
670726
)
671727

672-
expected = QuantumCircuit(3, 3)
728+
expected = QuantumCircuit(3, 3, global_phase=np.pi / 2)
673729
expected.h(0)
674-
expected.rx(np.pi + 0.2, 0)
730+
expected.rx(np.pi + 0.2, 0) # transforming X into RX(pi) introduces a pi/2 global phase
675731
expected.measure(0, 0)
676732
expected.x(0)
677733

0 commit comments

Comments
 (0)