1010// copyright notice, and modified files need to carry a notice indicating
1111// that they have been altered from the originals.
1212
13- use std:: f64:: consts:: PI ;
13+ use std:: f64:: consts:: { FRAC_PI_2 , FRAC_PI_4 , PI } ;
1414
1515use hashbrown:: { HashMap , HashSet } ;
1616use pyo3:: exceptions:: PyRuntimeError ;
@@ -19,19 +19,16 @@ use pyo3::{pyfunction, wrap_pyfunction, Bound, PyResult, Python};
1919use rustworkx_core:: petgraph:: stable_graph:: NodeIndex ;
2020use smallvec:: { smallvec, SmallVec } ;
2121
22- use crate :: commutation_checker:: CommutationChecker ;
2322use qiskit_circuit:: dag_circuit:: { DAGCircuit , NodeType , Wire } ;
2423use qiskit_circuit:: operations:: { Operation , Param , StandardGate } ;
2524use qiskit_circuit:: Qubit ;
25+ use qiskit_synthesis:: QiskitError ;
2626
2727use super :: analyze_commutations;
28- use qiskit_synthesis :: { euler_one_qubit_decomposer , QiskitError } ;
28+ use crate :: commutation_checker :: CommutationChecker ;
2929
3030const _CUTOFF_PRECISION: f64 = 1e-5 ;
3131static ROTATION_GATES : [ & str ; 4 ] = [ "p" , "u1" , "rz" , "rx" ] ;
32- static HALF_TURNS : [ & str ; 2 ] = [ "z" , "x" ] ;
33- static QUARTER_TURNS : [ & str ; 1 ] = [ "s" ] ;
34- static EIGHTH_TURNS : [ & str ; 1 ] = [ "t" ] ;
3532
3633static VAR_Z_MAP : [ ( & str , StandardGate ) ; 3 ] = [
3734 ( "rz" , StandardGate :: RZ ) ,
@@ -93,14 +90,39 @@ pub fn cancel_commutations(
9390 . map ( |( _, gate) | gate)
9491 } )
9592 . or_else ( || {
93+ // Fallback to the first matching key from basis if there is no match in dag.op_names
9694 basis. iter ( ) . find_map ( |g| {
9795 VAR_Z_MAP
9896 . iter ( )
9997 . find ( |( key, _) | * key == g. as_str ( ) )
10098 . map ( |( _, gate) | gate)
10199 } )
102100 } ) ;
103- // Fallback to the first matching key from basis if there is no match in dag.op_names
101+
102+ // RZ and P/U1 have a phase difference of angle/2, which we need to account for
103+ let z_phase_shift = match z_var_gate {
104+ Some ( z_var_gate) => {
105+ if z_var_gate == & StandardGate :: RZ {
106+ |gate_name : & str , angle : f64 | -> f64 {
107+ if [ "u1" , "p" ] . contains ( & gate_name) {
108+ angle / 2.
109+ } else {
110+ 0.
111+ }
112+ }
113+ } else {
114+ |gate_name : & str , angle : f64 | -> f64 {
115+ if gate_name == "rz" {
116+ -angle / 2.
117+ } else {
118+ 0.
119+ }
120+ }
121+ }
122+ }
123+ // if there's no z_var_gate detected, Z rotations are not merged, so we have no phase shifts
124+ None => |_name : & str , _angle : f64 | -> f64 { 0. } ,
125+ } ;
104126
105127 // Gate sets to be cancelled
106128 /* Traverse each qubit to generate the cancel dictionaries
@@ -219,38 +241,29 @@ pub fn cancel_commutations(
219241 } ;
220242 let node_op_name = node_op. op . name ( ) ;
221243
222- let node_angle = if ROTATION_GATES . contains ( & node_op_name) {
223- match node_op. params_view ( ) . first ( ) {
224- Some ( Param :: Float ( f) ) => Ok ( * f ) ,
244+ let ( node_angle, phase_shift ) = if ROTATION_GATES . contains ( & node_op_name) {
245+ let node_angle = match node_op. params_view ( ) . first ( ) {
246+ Some ( Param :: Float ( f) ) => * f ,
225247 _ => return Err ( QiskitError :: new_err ( format ! (
226248 "Rotational gate with parameter expression encountered in cancellation {:?}" ,
227249 node_op. op
228250 ) ) )
229- }
230- } else if HALF_TURNS . contains ( & node_op_name) {
231- Ok ( PI )
232- } else if QUARTER_TURNS . contains ( & node_op_name) {
233- Ok ( PI / 2.0 )
234- } else if EIGHTH_TURNS . contains ( & node_op_name) {
235- Ok ( PI / 4.0 )
251+ } ;
252+ let phase_shift = z_phase_shift ( node_op_name, node_angle) ;
253+ Ok ( ( node_angle, phase_shift) )
236254 } else {
237- Err ( PyRuntimeError :: new_err ( format ! (
238- "Angle for operation {} is not defined" ,
239- node_op_name
240- ) ) )
241- } ;
242- total_angle += node_angle?;
243-
244- let Param :: Float ( new_phase) = node_op
245- . op
246- . definition ( node_op. params_view ( ) )
247- . unwrap ( )
248- . global_phase ( )
249- . clone ( )
250- else {
251- unreachable ! ( )
252- } ;
253- total_phase += new_phase
255+ match node_op_name {
256+ "t" => Ok ( ( FRAC_PI_4 , z_phase_shift ( "p" , FRAC_PI_4 ) ) ) ,
257+ "s" => Ok ( ( FRAC_PI_2 , z_phase_shift ( "p" , FRAC_PI_2 ) ) ) ,
258+ "z" => Ok ( ( PI , z_phase_shift ( "p" , PI ) ) ) ,
259+ "x" => Ok ( ( PI , FRAC_PI_2 ) ) ,
260+ _ => Err ( PyRuntimeError :: new_err ( format ! (
261+ "Angle for operation {node_op_name} is not defined"
262+ ) ) ) ,
263+ }
264+ } ?;
265+ total_angle += node_angle;
266+ total_phase += phase_shift;
254267 }
255268
256269 let new_op = match cancel_key. gate {
@@ -259,24 +272,21 @@ pub fn cancel_commutations(
259272 _ => unreachable ! ( ) ,
260273 } ;
261274
262- let gate_angle = euler_one_qubit_decomposer :: mod_2pi ( total_angle, 0. ) ;
275+ let pi_multiple = total_angle / PI ;
263276
264- let new_op_phase: f64 = if gate_angle. abs ( ) > _CUTOFF_PRECISION {
265- dag. insert_1q_on_incoming_qubit ( ( * new_op, & [ total_angle] ) , cancel_set[ 0 ] ) ;
266- let Param :: Float ( new_phase) = new_op
267- . definition ( & [ Param :: Float ( total_angle) ] )
268- . unwrap ( )
269- . global_phase ( )
270- . clone ( )
271- else {
272- unreachable ! ( ) ;
273- } ;
274- new_phase
277+ let mod4 = pi_multiple. rem_euclid ( 4. ) ;
278+ if mod4 < _CUTOFF_PRECISION || ( 4. - mod4) < _CUTOFF_PRECISION {
279+ // if the angle is close to a 4-pi multiple (from above or below), then the
280+ // operator is equal to the identity
281+ } else if ( mod4 - 2. ) . abs ( ) < _CUTOFF_PRECISION {
282+ // a 2-pi multiple has a phase of pi: RX(2pi) = RZ(2pi) = -I = I exp(i pi)
283+ total_phase -= PI ;
275284 } else {
276- 0.0
277- } ;
285+ // any other is not the identity and we add the gate
286+ dag. insert_1q_on_incoming_qubit ( ( * new_op, & [ total_angle] ) , cancel_set[ 0 ] ) ;
287+ }
278288
279- dag. add_global_phase ( & Param :: Float ( total_phase - new_op_phase ) ) ?;
289+ dag. add_global_phase ( & Param :: Float ( total_phase) ) ?;
280290
281291 for node in cancel_set {
282292 dag. remove_op_node ( * node) ;
0 commit comments