Skip to content

Commit 5b1d88a

Browse files
Fix pickling of SabreSwap (Qiskit#15074) (Qiskit#15076)
This commit fixes the support for pickling SabreSwap. In Qiskit#14317 a new RoutingTarget rust struct was added to encapsulate the target details, and this was exposed to Python so that the Python class for the transpiler pass was able to reuse the object between multiple runs. However, this new type didn't implement pickle support and it would cause a failure when trying to pickle a SabreSwap instance that had a routing target populated. This commit fixes this oversight and implements pickle support for the RoutingTarget so that SabreSwap can always be pickled. Fixes Qiskit#15071 (cherry picked from commit a816e64) Co-authored-by: Matthew Treinish <[email protected]>
1 parent c6d52a0 commit 5b1d88a

File tree

4 files changed

+117
-2
lines changed

4 files changed

+117
-2
lines changed

crates/transpiler/src/passes/sabre/neighbors.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ use rustworkx_core::petgraph::visit::*;
2626
/// small that a linear search is faster).
2727
#[derive(Clone, Debug)]
2828
pub struct Neighbors {
29-
neighbors: Vec<PhysicalQubit>,
30-
partition: Vec<usize>,
29+
pub(crate) neighbors: Vec<PhysicalQubit>,
30+
pub(crate) partition: Vec<usize>,
3131
}
3232
impl Neighbors {
3333
/// Construct the neighbor adjacency table from a coupling graph.

crates/transpiler/src/passes/sabre/route.rs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use std::num::NonZero;
1717

1818
use numpy::{PyArray2, ToPyArray};
1919
use pyo3::prelude::*;
20+
use pyo3::types::PyDict;
2021
use pyo3::Python;
2122

2223
use hashbrown::HashSet;
@@ -315,6 +316,50 @@ impl RoutingTarget {
315316
pub struct PyRoutingTarget(pub Option<RoutingTarget>);
316317
#[pymethods]
317318
impl PyRoutingTarget {
319+
#[new]
320+
fn py_new() -> Self {
321+
PyRoutingTarget(None)
322+
}
323+
324+
fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
325+
let out_dict = PyDict::new(py);
326+
out_dict.set_item(
327+
"neighbors",
328+
self.0.as_ref().map(|x| x.neighbors.neighbors.clone()),
329+
)?;
330+
out_dict.set_item(
331+
"partition",
332+
self.0.as_ref().map(|x| x.neighbors.partition.clone()),
333+
)?;
334+
Ok(out_dict)
335+
}
336+
337+
fn __setstate__(&mut self, value: Bound<PyDict>) -> PyResult<()> {
338+
let neighbors_array: Option<Vec<PhysicalQubit>> = value
339+
.get_item("neighbors")?
340+
.map(|x| x.extract())
341+
.transpose()?;
342+
if let Some(neighbors_array) = neighbors_array {
343+
let partition: Vec<usize> = value
344+
.get_item("partition")?
345+
.map(|x| x.extract())
346+
.transpose()?
347+
.unwrap();
348+
let neighbors = Neighbors {
349+
neighbors: neighbors_array,
350+
partition,
351+
};
352+
if self.0.is_none() {
353+
self.0 = Some(RoutingTarget::from_neighbors(neighbors));
354+
} else {
355+
self.0.as_mut().unwrap().distance =
356+
distance_matrix(&neighbors, usize::MAX, f64::NAN);
357+
self.0.as_mut().unwrap().neighbors = neighbors;
358+
}
359+
}
360+
Ok(())
361+
}
362+
318363
#[staticmethod]
319364
pub(crate) fn from_target(target: &Target) -> PyResult<Self> {
320365
let coupling = match target.coupling_graph() {
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
fixes:
3+
- |
4+
Fixed an issue with :mod:`pickle` support for the :class:`.SabreSwap` where a
5+
:class:`.SabreSwap` instance would error when being pickled after the
6+
:meth:`.SabreSwap.run` method was run.
7+
Fixed `#15071 <https://github.com/Qiskit/qiskit/issues/15071>`__.

test/python/transpiler/test_sabre_swap.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414

1515
import unittest
1616
import itertools
17+
import pickle
18+
from copy import deepcopy
19+
import io
1720

1821
import ddt
1922
import numpy.random
@@ -101,6 +104,66 @@ def looping_circuit(uphill_swaps=1, additional_local_minimum_gates=0):
101104
class TestSabreSwap(QiskitTestCase):
102105
"""Tests the SabreSwap pass."""
103106

107+
def test_sabre_swap_pickle(self):
108+
"""Test the pass can be pickled."""
109+
coupling = CouplingMap.from_ring(5)
110+
target = Target.from_configuration(["u", "cx"], coupling_map=coupling)
111+
sabre_swap = SabreSwap(target, "lookahead", seed=42, trials=1024)
112+
with io.BytesIO() as buf:
113+
pickle.dump(sabre_swap, buf)
114+
buf.seek(0)
115+
output = pickle.load(buf)
116+
self.assertIsInstance(output, SabreSwap)
117+
self.assertIsNone(output._routing_target)
118+
self.assertEqual(sabre_swap.heuristic, output.heuristic)
119+
self.assertEqual(sabre_swap.trials, output.trials)
120+
self.assertEqual(sabre_swap.fake_run, output.fake_run)
121+
122+
test_circuit = QuantumCircuit(5)
123+
test_circuit.cx(0, 1)
124+
test_circuit.cx(0, 2)
125+
test_circuit.cx(0, 3)
126+
test_circuit.cx(0, 4)
127+
before_result = sabre_swap(test_circuit)
128+
with io.BytesIO() as buf:
129+
pickle.dump(sabre_swap, buf)
130+
buf.seek(0)
131+
output = pickle.load(buf)
132+
self.assertIsInstance(output, SabreSwap)
133+
self.assertIsNotNone(output._routing_target)
134+
self.assertEqual(sabre_swap.heuristic, output.heuristic)
135+
self.assertEqual(sabre_swap.trials, output.trials)
136+
self.assertEqual(sabre_swap.fake_run, output.fake_run)
137+
after_result = output(test_circuit)
138+
self.assertEqual(before_result, after_result)
139+
140+
def test_sabre_swap_deepcopy(self):
141+
"""Test the pass can be deepcopied."""
142+
coupling = CouplingMap.from_ring(5)
143+
target = Target.from_configuration(["u", "cx"], coupling_map=coupling)
144+
sabre_swap = SabreSwap(target, "lookahead", seed=42, trials=1024)
145+
output = deepcopy(sabre_swap)
146+
self.assertIsInstance(output, SabreSwap)
147+
self.assertIsNone(output._routing_target)
148+
self.assertEqual(sabre_swap.heuristic, output.heuristic)
149+
self.assertEqual(sabre_swap.trials, output.trials)
150+
self.assertEqual(sabre_swap.fake_run, output.fake_run)
151+
152+
test_circuit = QuantumCircuit(5)
153+
test_circuit.cx(0, 1)
154+
test_circuit.cx(0, 2)
155+
test_circuit.cx(0, 3)
156+
test_circuit.cx(0, 4)
157+
before_result = sabre_swap(test_circuit)
158+
output = deepcopy(sabre_swap)
159+
self.assertIsInstance(output, SabreSwap)
160+
self.assertIsNotNone(output._routing_target)
161+
self.assertEqual(sabre_swap.heuristic, output.heuristic)
162+
self.assertEqual(sabre_swap.trials, output.trials)
163+
self.assertEqual(sabre_swap.fake_run, output.fake_run)
164+
after_result = output(test_circuit)
165+
self.assertEqual(before_result, after_result)
166+
104167
def test_trivial_case(self):
105168
"""Test that an already mapped circuit is unchanged.
106169
┌───┐┌───┐

0 commit comments

Comments
 (0)