Skip to content

Commit d67c818

Browse files
Fix global phase update in BasisTranslator Pass (Qiskit#14078)
* add missing branch for global phase update in basis_translator * Adding test * reno
1 parent cd05386 commit d67c818

File tree

3 files changed

+52
-29
lines changed

3 files changed

+52
-29
lines changed

crates/accelerate/src/basis/basis_translator/mod.rs

Lines changed: 37 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -762,38 +762,46 @@ fn replace_node(
762762
)?;
763763
}
764764

765-
if let Param::ParameterExpression(old_phase) = target_dag.global_phase() {
766-
let bound_old_phase = old_phase.bind(py);
767-
let bind_dict = PyDict::new(py);
768-
for key in target_dag.global_phase().iter_parameters(py)? {
769-
let key = key?;
770-
bind_dict.set_item(&key, parameter_map.get_item(&key)?)?;
771-
}
772-
let mut new_phase: Bound<PyAny>;
773-
if bind_dict.values().iter().any(|param| {
774-
param
775-
.is_instance(PARAMETER_EXPRESSION.get_bound(py))
776-
.is_ok_and(|x| x)
777-
}) {
778-
new_phase = bound_old_phase.clone();
779-
for key_val in bind_dict.items() {
780-
new_phase =
781-
new_phase.call_method1(intern!(py, "assign"), key_val.downcast()?)?;
765+
match target_dag.global_phase() {
766+
Param::ParameterExpression(old_phase) => {
767+
let bound_old_phase = old_phase.bind(py);
768+
let bind_dict = PyDict::new(py);
769+
for key in target_dag.global_phase().iter_parameters(py)? {
770+
let key = key?;
771+
bind_dict.set_item(&key, parameter_map.get_item(&key)?)?;
782772
}
783-
} else {
784-
new_phase = bound_old_phase.call_method1(intern!(py, "bind"), (bind_dict,))?;
785-
}
786-
if !new_phase.getattr(intern!(py, "parameters"))?.is_truthy()? {
787-
new_phase = new_phase.call_method0(intern!(py, "numeric"))?;
788-
if new_phase.is_instance(&PyComplex::type_object(py))? {
789-
return Err(TranspilerError::new_err(format!(
790-
"Global phase must be real, but got {}",
791-
new_phase.repr()?
792-
)));
773+
let mut new_phase: Bound<PyAny>;
774+
if bind_dict.values().iter().any(|param| {
775+
param
776+
.is_instance(PARAMETER_EXPRESSION.get_bound(py))
777+
.is_ok_and(|x| x)
778+
}) {
779+
new_phase = bound_old_phase.clone();
780+
for key_val in bind_dict.items() {
781+
new_phase =
782+
new_phase.call_method1(intern!(py, "assign"), key_val.downcast()?)?;
783+
}
784+
} else {
785+
new_phase = bound_old_phase.call_method1(intern!(py, "bind"), (bind_dict,))?;
793786
}
787+
if !new_phase.getattr(intern!(py, "parameters"))?.is_truthy()? {
788+
new_phase = new_phase.call_method0(intern!(py, "numeric"))?;
789+
if new_phase.is_instance(&PyComplex::type_object(py))? {
790+
return Err(TranspilerError::new_err(format!(
791+
"Global phase must be real, but got {}",
792+
new_phase.repr()?
793+
)));
794+
}
795+
}
796+
let new_phase: Param = new_phase.extract()?;
797+
dag.add_global_phase(&new_phase)?;
794798
}
795-
let new_phase: Param = new_phase.extract()?;
796-
dag.add_global_phase(&new_phase)?;
799+
800+
Param::Float(_) => {
801+
dag.add_global_phase(target_dag.global_phase())?;
802+
}
803+
804+
_ => {}
797805
}
798806
}
799807

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
fixes:
3+
- |
4+
Fixed a problem in :class:`.BasisTranslator` transpiler pass, where the global
5+
phase of the DAG was not updated correctly.
6+
Fixed `#14074 <https://github.com/Qiskit/qiskit/issues/14074>`__.

test/python/transpiler/test_basis_translator.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,6 +1008,15 @@ def test_global_phase(self):
10081008
)
10091009
self.assertEqual(Operator(dag_to_circuit(out_dag)), Operator(expected))
10101010

1011+
def test_rx_to_rz(self):
1012+
"""Verify global phase is updated correctly in basis translation.
1013+
See https://github.com/Qiskit/qiskit/issues/14074."""
1014+
theta = 0.5 * pi
1015+
circ = QuantumCircuit(1)
1016+
circ.rx(theta, 0)
1017+
out_circ = BasisTranslator(std_eqlib, ["h", "rz"])(circ)
1018+
self.assertEqual(Operator(circ), Operator(out_circ))
1019+
10111020
def test_skip_target_basis_equivalences_1(self):
10121021
"""Test that BasisTranslator skips gates in the target_basis - #6085"""
10131022
circ = QuantumCircuit()

0 commit comments

Comments
 (0)