Skip to content

Commit 93d796f

Browse files
authored
More conservative caching in the CommutationChecker (Qiskit#13600)
* conservative commutation check * tests and reno * reno in the right location * more tests for custom gates
1 parent 86a5325 commit 93d796f

File tree

4 files changed

+117
-31
lines changed

4 files changed

+117
-31
lines changed

crates/accelerate/src/commutation_checker.rs

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,24 +28,29 @@ use qiskit_circuit::circuit_instruction::{ExtraInstructionAttributes, OperationF
2828
use qiskit_circuit::dag_node::DAGOpNode;
2929
use qiskit_circuit::imports::QI_OPERATOR;
3030
use qiskit_circuit::operations::OperationRef::{Gate as PyGateType, Operation as PyOperationType};
31-
use qiskit_circuit::operations::{Operation, OperationRef, Param, StandardGate};
31+
use qiskit_circuit::operations::{
32+
get_standard_gate_names, Operation, OperationRef, Param, StandardGate,
33+
};
3234
use qiskit_circuit::{BitType, Clbit, Qubit};
3335

3436
use crate::unitary_compose;
3537
use crate::QiskitError;
3638

39+
const TWOPI: f64 = 2.0 * std::f64::consts::PI;
40+
41+
// These gates do not commute with other gates, we do not check them.
3742
static SKIPPED_NAMES: [&str; 4] = ["measure", "reset", "delay", "initialize"];
38-
static NO_CACHE_NAMES: [&str; 2] = ["annotated", "linear_function"];
43+
44+
// We keep a hash-set of operations eligible for commutation checking. This is because checking
45+
// eligibility is not for free.
3946
static SUPPORTED_OP: Lazy<HashSet<&str>> = Lazy::new(|| {
4047
HashSet::from([
4148
"rxx", "ryy", "rzz", "rzx", "h", "x", "y", "z", "sx", "sxdg", "t", "tdg", "s", "sdg", "cx",
4249
"cy", "cz", "swap", "iswap", "ecr", "ccx", "cswap",
4350
])
4451
});
4552

46-
const TWOPI: f64 = 2.0 * std::f64::consts::PI;
47-
48-
// map rotation gates to their generators, or to ``None`` if we cannot currently efficiently
53+
// Map rotation gates to their generators, or to ``None`` if we cannot currently efficiently
4954
// represent the generator in Rust and store the commutation relation in the commutation dictionary
5055
static SUPPORTED_ROTATIONS: Lazy<HashMap<&str, Option<OperationRef>>> = Lazy::new(|| {
5156
HashMap::from([
@@ -322,15 +327,17 @@ impl CommutationChecker {
322327
(qargs1, qargs2)
323328
};
324329

325-
let skip_cache: bool = NO_CACHE_NAMES.contains(&first_op.name()) ||
326-
NO_CACHE_NAMES.contains(&second_op.name()) ||
327-
// Skip params that do not evaluate to floats for caching and commutation library
328-
first_params.iter().any(|p| !matches!(p, Param::Float(_))) ||
329-
second_params.iter().any(|p| !matches!(p, Param::Float(_)))
330-
&& !SUPPORTED_OP.contains(op1.name())
331-
&& !SUPPORTED_OP.contains(op2.name());
332-
333-
if skip_cache {
330+
// For our cache to work correctly, we require the gate's definition to only depend on the
331+
// ``params`` attribute. This cannot be guaranteed for custom gates, so we only check
332+
// the cache for our standard gates, which we know are defined by the ``params`` AND
333+
// that the ``params`` are float-only at this point.
334+
let whitelist = get_standard_gate_names();
335+
let check_cache = whitelist.contains(&first_op.name())
336+
&& whitelist.contains(&second_op.name())
337+
&& first_params.iter().all(|p| matches!(p, Param::Float(_)))
338+
&& second_params.iter().all(|p| matches!(p, Param::Float(_)));
339+
340+
if !check_cache {
334341
return self.commute_matmul(
335342
py,
336343
first_op,
@@ -630,21 +637,24 @@ fn map_rotation<'a>(
630637
) -> (&'a OperationRef<'a>, &'a [Param], bool) {
631638
let name = op.name();
632639
if let Some(generator) = SUPPORTED_ROTATIONS.get(name) {
633-
// if the rotation angle is below the tolerance, the gate is assumed to
640+
// If the rotation angle is below the tolerance, the gate is assumed to
634641
// commute with everything, and we simply return the operation with the flag that
635-
// it commutes trivially
642+
// it commutes trivially.
636643
if let Param::Float(angle) = params[0] {
637644
if (angle % TWOPI).abs() < tol {
638645
return (op, params, true);
639646
};
640647
};
641648

642-
// otherwise, we check if a generator is given -- if not, we'll just return the operation
643-
// itself (e.g. RXX does not have a generator and is just stored in the commutations
644-
// dictionary)
649+
// Otherwise we need to cover two cases -- either a generator is given, in which case
650+
// we return it, or we don't have a generator yet, but we know we have the operation
651+
// stored in the commutation library. For example, RXX does not have a generator in Rust
652+
// yet (PauliGate is not in Rust currently), but it is stored in the library, so we
653+
// can strip the parameters and just return the gate.
645654
if let Some(gate) = generator {
646655
return (gate, &[], false);
647656
};
657+
return (op, &[], false);
648658
}
649659
(op, params, false)
650660
}

crates/circuit/src/operations.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,11 @@ static STANDARD_GATE_NAME: [&str; STANDARD_GATE_SIZE] = [
431431
"rcccx", // 51 ("rc3x")
432432
];
433433

434+
/// Get a slice of all standard gate names.
435+
pub fn get_standard_gate_names() -> &'static [&'static str] {
436+
&STANDARD_GATE_NAME
437+
}
438+
434439
impl StandardGate {
435440
pub fn create_py_op(
436441
&self,
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
---
2+
fixes:
3+
- |
4+
Commutation relations of :class:`~.circuit.Instruction`\ s with float-only ``params``
5+
were eagerly cached by the :class:`.CommutationChecker`, using the ``params`` as key to
6+
query the relation. This could lead to faulty results, if the instruction's definition
7+
depended on additional information that just the :attr:`~.circuit.Instruction.params`
8+
attribute, such as e.g. the case for :class:`.PauliEvolutionGate`.
9+
This behavior is now fixed, and the commutation checker only conservatively caches
10+
commutations for Qiskit-native standard gates. This can incur a performance cost if you were
11+
relying on your custom gates being cached, however, we cannot guarantee safe caching for
12+
custom gates, as they might rely on information beyond :attr:`~.circuit.Instruction.params`.
13+
- |
14+
Fixed a bug in the :class:`.CommmutationChecker`, where checking commutation of instruction
15+
with non-numeric values in the :attr:`~.circuit.Instruction.params` attribute (such as the
16+
:class:`.PauliGate`) could raise an error.
17+
Fixed `#13570 <https://github.com/Qiskit/qiskit/issues/13570>`__.
18+

test/python/circuit/test_commutation_checker.py

Lines changed: 65 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
Parameter,
2828
QuantumRegister,
2929
Qubit,
30+
QuantumCircuit,
3031
)
3132
from qiskit.circuit.commutation_library import SessionCommutationChecker as scc
3233
from qiskit.circuit.library import (
@@ -37,9 +38,11 @@
3738
CRYGate,
3839
CRZGate,
3940
CXGate,
41+
CUGate,
4042
LinearFunction,
4143
MCXGate,
4244
Measure,
45+
PauliGate,
4346
PhaseGate,
4447
Reset,
4548
RXGate,
@@ -82,6 +85,22 @@ def to_matrix(self):
8285
return np.array([[1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0], [0, 1, 0, 0]], dtype=complex)
8386

8487

88+
class MyEvilRXGate(Gate):
89+
"""A RX gate designed to annoy the caching mechanism (but a realistic gate nevertheless)."""
90+
91+
def __init__(self, evil_input_not_in_param: float):
92+
"""
93+
Args:
94+
evil_input_not_in_param: The RX rotation angle.
95+
"""
96+
self.value = evil_input_not_in_param
97+
super().__init__("<evil laugh here>", 1, [])
98+
99+
def _define(self):
100+
self.definition = QuantumCircuit(1)
101+
self.definition.rx(self.value, 0)
102+
103+
85104
@ddt
86105
class TestCommutationChecker(QiskitTestCase):
87106
"""Test CommutationChecker class."""
@@ -137,7 +156,7 @@ def test_standard_gates_commutations(self):
137156
def test_caching_positive_results(self):
138157
"""Check that hashing positive results in commutativity checker works as expected."""
139158
scc.clear_cached_commutations()
140-
self.assertTrue(scc.commute(ZGate(), [0], [], NewGateCX(), [0, 1], []))
159+
self.assertTrue(scc.commute(ZGate(), [0], [], CUGate(1, 2, 3, 0), [0, 1], []))
141160
self.assertGreater(scc.num_cached_entries(), 0)
142161

143162
def test_caching_lookup_with_non_overlapping_qubits(self):
@@ -150,27 +169,29 @@ def test_caching_lookup_with_non_overlapping_qubits(self):
150169
def test_caching_store_and_lookup_with_non_overlapping_qubits(self):
151170
"""Check that commutations storing and lookup with non-overlapping qubits works as expected."""
152171
scc_lenm = scc.num_cached_entries()
153-
self.assertTrue(scc.commute(NewGateCX(), [0, 2], [], CXGate(), [0, 1], []))
154-
self.assertFalse(scc.commute(NewGateCX(), [0, 1], [], CXGate(), [1, 2], []))
155-
self.assertTrue(scc.commute(NewGateCX(), [1, 4], [], CXGate(), [1, 6], []))
156-
self.assertFalse(scc.commute(NewGateCX(), [5, 3], [], CXGate(), [3, 1], []))
172+
cx_like = CUGate(np.pi, 0, np.pi, 0)
173+
self.assertTrue(scc.commute(cx_like, [0, 2], [], CXGate(), [0, 1], []))
174+
self.assertFalse(scc.commute(cx_like, [0, 1], [], CXGate(), [1, 2], []))
175+
self.assertTrue(scc.commute(cx_like, [1, 4], [], CXGate(), [1, 6], []))
176+
self.assertFalse(scc.commute(cx_like, [5, 3], [], CXGate(), [3, 1], []))
157177
self.assertEqual(scc.num_cached_entries(), scc_lenm + 2)
158178

159179
def test_caching_negative_results(self):
160180
"""Check that hashing negative results in commutativity checker works as expected."""
161181
scc.clear_cached_commutations()
162-
self.assertFalse(scc.commute(XGate(), [0], [], NewGateCX(), [0, 1], []))
182+
self.assertFalse(scc.commute(XGate(), [0], [], CUGate(1, 2, 3, 0), [0, 1], []))
163183
self.assertGreater(scc.num_cached_entries(), 0)
164184

165185
def test_caching_different_qubit_sets(self):
166186
"""Check that hashing same commutativity results over different qubit sets works as expected."""
167187
scc.clear_cached_commutations()
168188
# All the following should be cached in the same way
169189
# though each relation gets cached twice: (A, B) and (B, A)
170-
scc.commute(XGate(), [0], [], NewGateCX(), [0, 1], [])
171-
scc.commute(XGate(), [10], [], NewGateCX(), [10, 20], [])
172-
scc.commute(XGate(), [10], [], NewGateCX(), [10, 5], [])
173-
scc.commute(XGate(), [5], [], NewGateCX(), [5, 7], [])
190+
cx_like = CUGate(np.pi, 0, np.pi, 0)
191+
scc.commute(XGate(), [0], [], cx_like, [0, 1], [])
192+
scc.commute(XGate(), [10], [], cx_like, [10, 20], [])
193+
scc.commute(XGate(), [10], [], cx_like, [10, 5], [])
194+
scc.commute(XGate(), [5], [], cx_like, [5, 7], [])
174195
self.assertEqual(scc.num_cached_entries(), 1)
175196

176197
def test_zero_rotations(self):
@@ -377,12 +398,14 @@ def test_serialization(self):
377398
"""Test that the commutation checker is correctly serialized"""
378399
import pickle
379400

401+
cx_like = CUGate(np.pi, 0, np.pi, 0)
402+
380403
scc.clear_cached_commutations()
381-
self.assertTrue(scc.commute(ZGate(), [0], [], NewGateCX(), [0, 1], []))
404+
self.assertTrue(scc.commute(ZGate(), [0], [], cx_like, [0, 1], []))
382405
cc2 = pickle.loads(pickle.dumps(scc))
383406
self.assertEqual(cc2.num_cached_entries(), 1)
384407
dop1 = DAGOpNode(ZGate(), qargs=[0], cargs=[])
385-
dop2 = DAGOpNode(NewGateCX(), qargs=[0, 1], cargs=[])
408+
dop2 = DAGOpNode(cx_like, qargs=[0, 1], cargs=[])
386409
cc2.commute_nodes(dop1, dop2)
387410
dop1 = DAGOpNode(ZGate(), qargs=[0], cargs=[])
388411
dop2 = DAGOpNode(CXGate(), qargs=[0, 1], cargs=[])
@@ -430,6 +453,36 @@ def test_rotation_mod_2pi(self, gate_cls):
430453
scc.commute(generic_gate, [0], [], gate, list(range(gate.num_qubits)), [])
431454
)
432455

456+
def test_custom_gate(self):
457+
"""Test a custom gate."""
458+
my_cx = NewGateCX()
459+
460+
self.assertTrue(scc.commute(my_cx, [0, 1], [], XGate(), [1], []))
461+
self.assertFalse(scc.commute(my_cx, [0, 1], [], XGate(), [0], []))
462+
self.assertTrue(scc.commute(my_cx, [0, 1], [], ZGate(), [0], []))
463+
464+
self.assertFalse(scc.commute(my_cx, [0, 1], [], my_cx, [1, 0], []))
465+
self.assertTrue(scc.commute(my_cx, [0, 1], [], my_cx, [0, 1], []))
466+
467+
def test_custom_gate_caching(self):
468+
"""Test a custom gate is correctly handled on consecutive runs."""
469+
470+
all_commuter = MyEvilRXGate(0) # this will commute with anything
471+
some_rx = MyEvilRXGate(1.6192) # this should not commute with H
472+
473+
# the order here is important: we're testing whether the gate that commutes with
474+
# everything is used after the first commutation check, regardless of the internal
475+
# gate parameters
476+
self.assertTrue(scc.commute(all_commuter, [0], [], HGate(), [0], []))
477+
self.assertFalse(scc.commute(some_rx, [0], [], HGate(), [0], []))
478+
479+
def test_nonfloat_param(self):
480+
"""Test commutation-checking on a gate that has non-float ``params``."""
481+
pauli_gate = PauliGate("XX")
482+
rx_gate_theta = RXGate(Parameter("Theta"))
483+
self.assertTrue(scc.commute(pauli_gate, [0, 1], [], rx_gate_theta, [0], []))
484+
self.assertTrue(scc.commute(rx_gate_theta, [0], [], pauli_gate, [0, 1], []))
485+
433486

434487
if __name__ == "__main__":
435488
unittest.main()

0 commit comments

Comments
 (0)