Skip to content

Commit 5dc7f35

Browse files
authored
Fix merge_interactions failing to convert into xmon gates (#696)
- Updated the cost metric to include presence-of-non-xmon - Converting gates to xmon before returning Fixes #490
1 parent 8750e8c commit 5dc7f35

File tree

3 files changed

+143
-117
lines changed

3 files changed

+143
-117
lines changed

cirq/google/merge_interactions.py

Lines changed: 48 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,12 @@
1818

1919
import numpy as np
2020

21-
from cirq import ops
22-
from cirq.circuits import (
23-
Circuit,
24-
PointOptimizer,
25-
PointOptimizationSummary,
26-
)
21+
from cirq import circuits, ops
2722
from cirq.extension import Extensions
28-
from cirq.google.decompositions import two_qubit_matrix_to_native_gates
23+
from cirq.google import decompositions, convert_to_xmon_gates, xmon_gates
2924

3025

31-
class MergeInteractions(PointOptimizer):
26+
class MergeInteractions(circuits.PointOptimizer):
3227
"""Combines series of adjacent one and two-qubit gates operating on a pair
3328
of qubits."""
3429

@@ -40,74 +35,88 @@ def __init__(self,
4035
self.allow_partial_czs = allow_partial_czs
4136
self.extensions = extensions or Extensions()
4237

43-
def optimization_at(self, circuit, index, op):
38+
def optimization_at(self,
39+
circuit: circuits.Circuit,
40+
index: int,
41+
op: ops.Operation
42+
) -> Optional[circuits.PointOptimizationSummary]:
4443
if len(op.qubits) != 2:
4544
return None
4645

47-
interaction_count, indices, matrix = (
46+
old_operations, indices, matrix = (
4847
self._scan_two_qubit_ops_into_matrix(circuit, index, op.qubits))
49-
if interaction_count <= 1:
50-
return None
5148

5249
# Find a max-3-cz construction.
53-
operations = two_qubit_matrix_to_native_gates(
50+
new_operations = decompositions.two_qubit_matrix_to_native_gates(
5451
op.qubits[0],
5552
op.qubits[1],
5653
matrix,
5754
self.allow_partial_czs,
5855
self.tolerance)
5956

60-
# TODO: don't replace if there's no benefit in CZ depth.
57+
old_interaction_count = len([old_op for old_op in old_operations
58+
if len(old_op.qubits) == 2])
59+
new_interaction_count = len([new_op for new_op in new_operations
60+
if len(new_op.qubits) == 2])
61+
switch_to_new = False
62+
switch_to_new |= new_interaction_count < old_interaction_count
63+
switch_to_new |= any(not xmon_gates.XmonGate.is_xmon_op(old_op)
64+
for old_op in old_operations)
65+
if not switch_to_new:
66+
return None
67+
68+
converter = convert_to_xmon_gates.ConvertToXmonGates()
69+
new_xmon_operations = [converter.convert(new_op)
70+
for new_op in new_operations]
6171

62-
return PointOptimizationSummary(
72+
return circuits.PointOptimizationSummary(
6373
clear_span=max(indices) + 1 - index,
6474
clear_qubits=op.qubits,
65-
new_operations=operations)
75+
new_operations=new_xmon_operations)
6676

6777
def _op_to_matrix(self,
68-
op: ops.Operation,
78+
op: Optional[ops.Operation],
6979
qubits: Tuple[ops.QubitId, ...]
70-
) -> Optional[Tuple[np.ndarray, bool]]:
80+
) -> Optional[np.ndarray]:
7181
"""Determines the effect of an operation on the given qubits.
7282
73-
The operation must be a 1-qubit operation on one of the given qubits,
74-
or a 2-qubit operation on both of the given qubits. Also, the operation
75-
must have a known matrix. Otherwise None is returned.
83+
If the operation is a 1-qubit operation on one of the given qubits,
84+
or a 2-qubit operation on both of the given qubits, and also the
85+
operation has a known matrix, then a matrix is returned. Otherwise None
86+
is returned.
7687
7788
Args:
7889
op: The operation to understand.
7990
qubits: The qubits we care about. Order determines matrix tensor
8091
order.
8192
8293
Returns:
83-
None, or else a tuple containing a matrix equivalent to the effect
84-
of the operation and a boolean indicating if the operation is a
85-
2-qubit interaction.
94+
None, or else a matrix equivalent to the effect of the operation.
8695
"""
8796
q1, q2 = qubits
8897

8998
known = self.extensions.try_cast(ops.KnownMatrix, op)
90-
if known is None:
99+
if known is None or op is None:
91100
return None
92101
m = known.matrix()
93102

94103
if op.qubits == qubits:
95-
return m, True
104+
return m
96105
if op.qubits == (q2, q1):
97-
return MergeInteractions._flip_kron_order(m), True
106+
return MergeInteractions._flip_kron_order(m)
98107
if op.qubits == (q1,):
99-
return np.kron(m, np.eye(2)), False
108+
return np.kron(m, np.eye(2))
100109
if op.qubits == (q2,):
101-
return np.kron(np.eye(2), m), False
110+
return np.kron(np.eye(2), m)
102111

103112
return None
104113

105114
def _scan_two_qubit_ops_into_matrix(
106115
self,
107-
circuit: Circuit,
116+
circuit: circuits.Circuit,
108117
index: Optional[int],
109118
qubits: Tuple[ops.QubitId, ...]
110-
) -> Tuple[int, List[int], np.ndarray]:
119+
) -> Tuple[List[ops.Operation], List[int], np.ndarray]:
111120
"""Accumulates operations affecting the given pair of qubits.
112121
113122
The scan terminates when it hits the end of the circuit, finds an
@@ -121,37 +130,36 @@ def _scan_two_qubit_ops_into_matrix(
121130
122131
Returns:
123132
A tuple containing:
124-
0. The number of 2-qubit operations that were scanned.
133+
0. The operations.
125134
1. The moment indices those operations were on.
126135
2. A matrix equivalent to the effect of the scanned operations.
127136
"""
128137

129138
product = np.eye(4, dtype=np.complex128)
130-
interaction_count = 0
139+
all_operations = []
131140
touched_indices = []
132141

133142
while index is not None:
134-
operations = {circuit.operation_at(q, index) for q in qubits}
143+
operations = list({circuit.operation_at(q, index) for q in qubits})
135144
op_data = [
136145
self._op_to_matrix(op, qubits)
137146
for op in operations
138-
if op
139147
]
140148

141149
# Stop at any non-constant or non-local interaction.
142150
if any(e is None for e in op_data):
143151
break
144-
present_op_data = cast(List[Tuple[np.ndarray, bool]], op_data)
152+
present_ops = [op for op in operations if op]
153+
present_op_data = cast(List[np.ndarray], op_data)
145154

146-
for op_mat, interacts in present_op_data:
155+
for op_mat in present_op_data:
147156
product = np.dot(op_mat, product)
148-
if interacts:
149-
interaction_count += 1
157+
all_operations.extend(present_ops)
150158

151159
touched_indices.append(index)
152160
index = circuit.next_moment_operating_on(qubits, index + 1)
153161

154-
return interaction_count, touched_indices, product
162+
return all_operations, touched_indices, product
155163

156164
@staticmethod
157165
def _flip_kron_order(mat4x4: np.ndarray) -> np.ndarray:

cirq/google/merge_interactions_test.py

Lines changed: 83 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -12,134 +12,140 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from cirq import testing
16-
from cirq import circuits
17-
from cirq import ops
18-
from cirq.google import ExpZGate, MergeInteractions, MergeRotations
19-
from cirq.value import Symbol
15+
import cirq
16+
import cirq.google as cg
2017

2118

22-
def assert_optimizes(before, after):
23-
opt = MergeInteractions()
24-
opt.optimize_circuit(before)
19+
def assert_optimizes(before: cirq.Circuit, expected: cirq.Circuit):
20+
actual = cirq.Circuit(before)
21+
opt = cg.MergeInteractions()
22+
opt.optimize_circuit(actual)
2523

2624
# Ignore differences that would be caught by follow-up optimizations.
2725
followup_optimizations = [
28-
MergeRotations(),
29-
circuits.DropNegligible(),
30-
circuits.DropEmptyMoments()
26+
cg.MergeRotations(),
27+
cg.EjectFullW(),
28+
cg.EjectZ(),
29+
cirq.DropNegligible(),
30+
cirq.DropEmptyMoments()
3131
]
3232
for post in followup_optimizations:
33-
post.optimize_circuit(before)
34-
post.optimize_circuit(after)
33+
post.optimize_circuit(actual)
34+
post.optimize_circuit(expected)
3535

36-
if before != after:
36+
if actual != expected:
3737
# coverage: ignore
3838
print('ACTUAL')
39-
print(before)
39+
print(actual)
4040
print('EXPECTED')
41-
print(after)
42-
assert before == after
41+
print(expected)
42+
assert actual == expected
4343

4444

4545
def assert_optimization_not_broken(circuit):
4646
"""Check that the unitary matrix for the input circuit is the same (up to
4747
global phase and rounding error) as the unitary matrix of the optimized
4848
circuit."""
4949
u_before = circuit.to_unitary_matrix()
50-
MergeInteractions().optimize_circuit(circuit)
50+
cg.MergeInteractions().optimize_circuit(circuit)
5151
u_after = circuit.to_unitary_matrix()
5252

53-
testing.assert_allclose_up_to_global_phase(u_before, u_after, atol=1e-8)
53+
cirq.testing.assert_allclose_up_to_global_phase(
54+
u_before, u_after, atol=1e-8)
5455

5556

5657
def test_clears_paired_cnot():
57-
q0 = ops.QubitId()
58-
q1 = ops.QubitId()
58+
a, b = cirq.LineQubit.range(2)
5959
assert_optimizes(
60-
before=circuits.Circuit([
61-
circuits.Moment([ops.CNOT(q0, q1)]),
62-
circuits.Moment([ops.CNOT(q0, q1)]),
60+
before=cirq.Circuit([
61+
cirq.Moment([cirq.CNOT(a, b)]),
62+
cirq.Moment([cirq.CNOT(a, b)]),
6363
]),
64-
after=circuits.Circuit())
64+
expected=cirq.Circuit())
6565

6666

6767
def test_ignores_czs_separated_by_parameterized():
68-
q0 = ops.QubitId()
69-
q1 = ops.QubitId()
68+
a, b = cirq.LineQubit.range(2)
7069
assert_optimizes(
71-
before=circuits.Circuit([
72-
circuits.Moment([ops.CZ(q0, q1)]),
73-
circuits.Moment([ExpZGate(
74-
half_turns=Symbol('boo'))(q0)]),
75-
circuits.Moment([ops.CZ(q0, q1)]),
70+
before=cirq.Circuit([
71+
cirq.Moment([cirq.CZ(a, b)]),
72+
cirq.Moment([cg.ExpZGate(
73+
half_turns=cirq.Symbol('boo'))(a)]),
74+
cirq.Moment([cirq.CZ(a, b)]),
7675
]),
77-
after=circuits.Circuit([
78-
circuits.Moment([ops.CZ(q0, q1)]),
79-
circuits.Moment([ExpZGate(
80-
half_turns=Symbol('boo'))(q0)]),
81-
circuits.Moment([ops.CZ(q0, q1)]),
76+
expected=cirq.Circuit([
77+
cirq.Moment([cirq.CZ(a, b)]),
78+
cirq.Moment([cg.ExpZGate(
79+
half_turns=cirq.Symbol('boo'))(a)]),
80+
cirq.Moment([cirq.CZ(a, b)]),
8281
]))
8382

8483

8584
def test_ignores_czs_separated_by_outer_cz():
86-
q00 = ops.QubitId()
87-
q01 = ops.QubitId()
88-
q10 = ops.QubitId()
85+
q00 = cirq.GridQubit(0, 0)
86+
q01 = cirq.GridQubit(0, 1)
87+
q10 = cirq.GridQubit(1, 0)
8988
assert_optimizes(
90-
before=circuits.Circuit([
91-
circuits.Moment([ops.CZ(q00, q01)]),
92-
circuits.Moment([ops.CZ(q00, q10)]),
93-
circuits.Moment([ops.CZ(q00, q01)]),
89+
before=cirq.Circuit([
90+
cirq.Moment([cirq.CZ(q00, q01)]),
91+
cirq.Moment([cirq.CZ(q00, q10)]),
92+
cirq.Moment([cirq.CZ(q00, q01)]),
9493
]),
95-
after=circuits.Circuit([
96-
circuits.Moment([ops.CZ(q00, q01)]),
97-
circuits.Moment([ops.CZ(q00, q10)]),
98-
circuits.Moment([ops.CZ(q00, q01)]),
94+
expected=cirq.Circuit([
95+
cirq.Moment([cirq.CZ(q00, q01)]),
96+
cirq.Moment([cirq.CZ(q00, q10)]),
97+
cirq.Moment([cirq.CZ(q00, q01)]),
9998
]))
10099

101100

102101
def test_cnots_separated_by_single_gates_correct():
103-
q0 = ops.QubitId()
104-
q1 = ops.QubitId()
102+
a, b = cirq.LineQubit.range(2)
105103
assert_optimization_not_broken(
106-
circuits.Circuit.from_ops(
107-
ops.CNOT(q0, q1),
108-
ops.H(q1),
109-
ops.CNOT(q0, q1),
104+
cirq.Circuit.from_ops(
105+
cirq.CNOT(a, b),
106+
cirq.H(b),
107+
cirq.CNOT(a, b),
110108
))
111109

112110

113111
def test_czs_separated_by_single_gates_correct():
114-
q0 = ops.QubitId()
115-
q1 = ops.QubitId()
112+
a, b = cirq.LineQubit.range(2)
116113
assert_optimization_not_broken(
117-
circuits.Circuit.from_ops(
118-
ops.CZ(q0, q1),
119-
ops.X(q1),
120-
ops.X(q1),
121-
ops.X(q1),
122-
ops.CZ(q0, q1),
114+
cirq.Circuit.from_ops(
115+
cirq.CZ(a, b),
116+
cirq.X(b),
117+
cirq.X(b),
118+
cirq.X(b),
119+
cirq.CZ(a, b),
123120
))
124121

125122

126123
def test_inefficient_circuit_correct():
127124
t = 0.1
128125
v = 0.11
129-
q0 = ops.QubitId()
130-
q1 = ops.QubitId()
126+
a, b = cirq.LineQubit.range(2)
131127
assert_optimization_not_broken(
132-
circuits.Circuit.from_ops(
133-
ops.H(q1),
134-
ops.CNOT(q0, q1),
135-
ops.H(q1),
136-
ops.CNOT(q0, q1),
137-
ops.CNOT(q1, q0),
138-
ops.H(q0),
139-
ops.CNOT(q0, q1),
140-
ops.Z(q0)**t, ops.Z(q1)**-t,
141-
ops.CNOT(q0, q1),
142-
ops.H(q0), ops.Z(q1)**v,
143-
ops.CNOT(q0, q1),
144-
ops.Z(q0)**-v, ops.Z(q1)**-v,
128+
cirq.Circuit.from_ops(
129+
cirq.H(b),
130+
cirq.CNOT(a, b),
131+
cirq.H(b),
132+
cirq.CNOT(a, b),
133+
cirq.CNOT(b, a),
134+
cirq.H(a),
135+
cirq.CNOT(a, b),
136+
cirq.Z(a)**t, cirq.Z(b)**-t,
137+
cirq.CNOT(a, b),
138+
cirq.H(a), cirq.Z(b)**v,
139+
cirq.CNOT(a, b),
140+
cirq.Z(a)**-v, cirq.Z(b)**-v,
145141
))
142+
143+
144+
def test_optimizes_single_iswap():
145+
a, b = cirq.LineQubit.range(2)
146+
c = cirq.Circuit.from_ops(cirq.ISWAP(a, b))
147+
assert_optimization_not_broken(c)
148+
cg.MergeInteractions().optimize_circuit(c)
149+
assert len([1 for op in c.all_operations() if len(op.qubits) == 2]) == 2
150+
assert all(cg.XmonGate.is_xmon_op(op)
151+
for op in c.all_operations())

0 commit comments

Comments
 (0)