2727 Parameter ,
2828 QuantumRegister ,
2929 Qubit ,
30+ QuantumCircuit ,
3031)
3132from qiskit .circuit .commutation_library import SessionCommutationChecker as scc
3233from qiskit .circuit .library import (
3738 CRYGate ,
3839 CRZGate ,
3940 CXGate ,
41+ CUGate ,
4042 LinearFunction ,
4143 MCXGate ,
4244 Measure ,
45+ PauliGate ,
4346 PhaseGate ,
4447 Reset ,
4548 RXGate ,
@@ -82,6 +85,22 @@ def to_matrix(self):
8285 return np .array ([[1 , 0 , 0 , 0 ], [0 , 0 , 0 , 1 ], [0 , 0 , 1 , 0 ], [0 , 1 , 0 , 0 ]], dtype = complex )
8386
8487
88+ class MyEvilRXGate (Gate ):
89+ """A RX gate designed to annoy the caching mechanism (but a realistic gate nevertheless)."""
90+
91+ def __init__ (self , evil_input_not_in_param : float ):
92+ """
93+ Args:
94+ evil_input_not_in_param: The RX rotation angle.
95+ """
96+ self .value = evil_input_not_in_param
97+ super ().__init__ ("<evil laugh here>" , 1 , [])
98+
99+ def _define (self ):
100+ self .definition = QuantumCircuit (1 )
101+ self .definition .rx (self .value , 0 )
102+
103+
85104@ddt
86105class TestCommutationChecker (QiskitTestCase ):
87106 """Test CommutationChecker class."""
@@ -137,7 +156,7 @@ def test_standard_gates_commutations(self):
137156 def test_caching_positive_results (self ):
138157 """Check that hashing positive results in commutativity checker works as expected."""
139158 scc .clear_cached_commutations ()
140- self .assertTrue (scc .commute (ZGate (), [0 ], [], NewGateCX ( ), [0 , 1 ], []))
159+ self .assertTrue (scc .commute (ZGate (), [0 ], [], CUGate ( 1 , 2 , 3 , 0 ), [0 , 1 ], []))
141160 self .assertGreater (scc .num_cached_entries (), 0 )
142161
143162 def test_caching_lookup_with_non_overlapping_qubits (self ):
@@ -150,27 +169,29 @@ def test_caching_lookup_with_non_overlapping_qubits(self):
150169 def test_caching_store_and_lookup_with_non_overlapping_qubits (self ):
151170 """Check that commutations storing and lookup with non-overlapping qubits works as expected."""
152171 scc_lenm = scc .num_cached_entries ()
153- self .assertTrue (scc .commute (NewGateCX (), [0 , 2 ], [], CXGate (), [0 , 1 ], []))
154- self .assertFalse (scc .commute (NewGateCX (), [0 , 1 ], [], CXGate (), [1 , 2 ], []))
155- self .assertTrue (scc .commute (NewGateCX (), [1 , 4 ], [], CXGate (), [1 , 6 ], []))
156- self .assertFalse (scc .commute (NewGateCX (), [5 , 3 ], [], CXGate (), [3 , 1 ], []))
172+ cx_like = CUGate (np .pi , 0 , np .pi , 0 )
173+ self .assertTrue (scc .commute (cx_like , [0 , 2 ], [], CXGate (), [0 , 1 ], []))
174+ self .assertFalse (scc .commute (cx_like , [0 , 1 ], [], CXGate (), [1 , 2 ], []))
175+ self .assertTrue (scc .commute (cx_like , [1 , 4 ], [], CXGate (), [1 , 6 ], []))
176+ self .assertFalse (scc .commute (cx_like , [5 , 3 ], [], CXGate (), [3 , 1 ], []))
157177 self .assertEqual (scc .num_cached_entries (), scc_lenm + 2 )
158178
159179 def test_caching_negative_results (self ):
160180 """Check that hashing negative results in commutativity checker works as expected."""
161181 scc .clear_cached_commutations ()
162- self .assertFalse (scc .commute (XGate (), [0 ], [], NewGateCX ( ), [0 , 1 ], []))
182+ self .assertFalse (scc .commute (XGate (), [0 ], [], CUGate ( 1 , 2 , 3 , 0 ), [0 , 1 ], []))
163183 self .assertGreater (scc .num_cached_entries (), 0 )
164184
165185 def test_caching_different_qubit_sets (self ):
166186 """Check that hashing same commutativity results over different qubit sets works as expected."""
167187 scc .clear_cached_commutations ()
168188 # All the following should be cached in the same way
169189 # though each relation gets cached twice: (A, B) and (B, A)
170- scc .commute (XGate (), [0 ], [], NewGateCX (), [0 , 1 ], [])
171- scc .commute (XGate (), [10 ], [], NewGateCX (), [10 , 20 ], [])
172- scc .commute (XGate (), [10 ], [], NewGateCX (), [10 , 5 ], [])
173- scc .commute (XGate (), [5 ], [], NewGateCX (), [5 , 7 ], [])
190+ cx_like = CUGate (np .pi , 0 , np .pi , 0 )
191+ scc .commute (XGate (), [0 ], [], cx_like , [0 , 1 ], [])
192+ scc .commute (XGate (), [10 ], [], cx_like , [10 , 20 ], [])
193+ scc .commute (XGate (), [10 ], [], cx_like , [10 , 5 ], [])
194+ scc .commute (XGate (), [5 ], [], cx_like , [5 , 7 ], [])
174195 self .assertEqual (scc .num_cached_entries (), 1 )
175196
176197 def test_zero_rotations (self ):
@@ -377,12 +398,14 @@ def test_serialization(self):
377398 """Test that the commutation checker is correctly serialized"""
378399 import pickle
379400
401+ cx_like = CUGate (np .pi , 0 , np .pi , 0 )
402+
380403 scc .clear_cached_commutations ()
381- self .assertTrue (scc .commute (ZGate (), [0 ], [], NewGateCX () , [0 , 1 ], []))
404+ self .assertTrue (scc .commute (ZGate (), [0 ], [], cx_like , [0 , 1 ], []))
382405 cc2 = pickle .loads (pickle .dumps (scc ))
383406 self .assertEqual (cc2 .num_cached_entries (), 1 )
384407 dop1 = DAGOpNode (ZGate (), qargs = [0 ], cargs = [])
385- dop2 = DAGOpNode (NewGateCX () , qargs = [0 , 1 ], cargs = [])
408+ dop2 = DAGOpNode (cx_like , qargs = [0 , 1 ], cargs = [])
386409 cc2 .commute_nodes (dop1 , dop2 )
387410 dop1 = DAGOpNode (ZGate (), qargs = [0 ], cargs = [])
388411 dop2 = DAGOpNode (CXGate (), qargs = [0 , 1 ], cargs = [])
@@ -430,6 +453,36 @@ def test_rotation_mod_2pi(self, gate_cls):
430453 scc .commute (generic_gate , [0 ], [], gate , list (range (gate .num_qubits )), [])
431454 )
432455
456+ def test_custom_gate (self ):
457+ """Test a custom gate."""
458+ my_cx = NewGateCX ()
459+
460+ self .assertTrue (scc .commute (my_cx , [0 , 1 ], [], XGate (), [1 ], []))
461+ self .assertFalse (scc .commute (my_cx , [0 , 1 ], [], XGate (), [0 ], []))
462+ self .assertTrue (scc .commute (my_cx , [0 , 1 ], [], ZGate (), [0 ], []))
463+
464+ self .assertFalse (scc .commute (my_cx , [0 , 1 ], [], my_cx , [1 , 0 ], []))
465+ self .assertTrue (scc .commute (my_cx , [0 , 1 ], [], my_cx , [0 , 1 ], []))
466+
467+ def test_custom_gate_caching (self ):
468+ """Test a custom gate is correctly handled on consecutive runs."""
469+
470+ all_commuter = MyEvilRXGate (0 ) # this will commute with anything
471+ some_rx = MyEvilRXGate (1.6192 ) # this should not commute with H
472+
473+ # the order here is important: we're testing whether the gate that commutes with
474+ # everything is used after the first commutation check, regardless of the internal
475+ # gate parameters
476+ self .assertTrue (scc .commute (all_commuter , [0 ], [], HGate (), [0 ], []))
477+ self .assertFalse (scc .commute (some_rx , [0 ], [], HGate (), [0 ], []))
478+
479+ def test_nonfloat_param (self ):
480+ """Test commutation-checking on a gate that has non-float ``params``."""
481+ pauli_gate = PauliGate ("XX" )
482+ rx_gate_theta = RXGate (Parameter ("Theta" ))
483+ self .assertTrue (scc .commute (pauli_gate , [0 , 1 ], [], rx_gate_theta , [0 ], []))
484+ self .assertTrue (scc .commute (rx_gate_theta , [0 ], [], pauli_gate , [0 , 1 ], []))
485+
433486
434487if __name__ == "__main__" :
435488 unittest .main ()
0 commit comments