Skip to content

Commit 360508b

Browse files
authored
Create ConvertToXmonGates optimizer with support for 2-qubit operations (#256)
1 parent a2a5fc3 commit 360508b

File tree

8 files changed

+136
-68
lines changed

8 files changed

+136
-68
lines changed

cirq/google/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
from cirq.google.convert_to_xmon_gates import (
1616
ConvertToXmonGates,
17-
xmon_gate_ext,
1817
)
1918
from cirq.google.decompositions import (
2019
controlled_op_to_native_gates,
@@ -39,6 +38,9 @@
3938
from cirq.google.xmon_device import (
4039
XmonDevice,
4140
)
41+
from cirq.google.xmon_gate_extensions import (
42+
xmon_gate_ext,
43+
)
4244
from cirq.google.xmon_gates import (
4345
Exp11Gate,
4446
ExpWGate,

cirq/google/convert_to_xmon_gates.py

Lines changed: 83 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,30 +13,99 @@
1313
# limitations under the License.
1414

1515
from cirq import ops
16-
from cirq.circuits.optimization_pass import PointOptimizer, \
17-
PointOptimizationSummary
16+
from cirq.circuits.optimization_pass import (
17+
PointOptimizationSummary,
18+
PointOptimizer,
19+
)
20+
from cirq.extension import Extensions
21+
from cirq.google.decompositions import (
22+
single_qubit_matrix_to_native_gates,
23+
two_qubit_matrix_to_native_gates,
24+
)
1825
from cirq.google.xmon_gate_extensions import xmon_gate_ext
1926
from cirq.google.xmon_gates import XmonGate
2027

2128

2229
class ConvertToXmonGates(PointOptimizer):
23-
"""Pointwise converts a circuit to XmonGates if possible."""
30+
"""Attempts to convert strange gates into XmonGates.
2431
25-
def __init__(self, ignore_cast_failures=True):
26-
self.ignore_cast_failures = ignore_cast_failures
32+
First, checks if the given extensions are able to cast the gate into an
33+
XmonGate instance.
34+
35+
Second, checks if the given extensions are able to cast the gate into a
36+
CompositeGate instance. If so, recurses on the decomposition.
37+
38+
Third, checks if the given extensions are able to cast the gate into a
39+
KnownMatrixGate instance. If so, and the gate is a 1-qubit or 2-qubit
40+
gate, then performs circuit synthesis of the operation.
41+
42+
Fourth, if ignore_failures is set, gives up and returns the gate unchanged.
43+
Otherwise raises a TypeError.
44+
"""
45+
46+
def __init__(self,
47+
extensions: Extensions=None,
48+
ignore_failures=False) -> None:
49+
"""
50+
Args:
51+
extensions: The extensions instance to use when trying to
52+
cast gates to known types. Defaults to the standard xmon
53+
gate extension.
54+
ignore_failures: If set, gates that fail to convert are forwarded
55+
unchanged. If not set, conversion failures raise a TypeError.
56+
"""
57+
self.extensions = extensions or xmon_gate_ext
58+
self.ignore_failures = ignore_failures
59+
60+
def _convert_one(self, op: ops.Operation) -> ops.OP_TREE:
61+
# Already supported?
62+
if isinstance(op.gate, XmonGate):
63+
return op
64+
65+
# Maybe we know how to wrap it?
66+
xmon = self.extensions.try_cast(op.gate, XmonGate)
67+
if xmon is not None:
68+
return xmon.on(*op.qubits)
69+
70+
# Provides a decomposition?
71+
composite = self.extensions.try_cast(op.gate, ops.CompositeGate)
72+
if composite is not None:
73+
return composite.default_decompose(op.qubits)
74+
75+
# Known matrix?
76+
mat = self.extensions.try_cast(op.gate, ops.KnownMatrixGate)
77+
if mat is not None and len(op.qubits) == 1:
78+
gates = single_qubit_matrix_to_native_gates(mat.matrix())
79+
return [g.on(op.qubits[0]) for g in gates]
80+
if mat is not None and len(op.qubits) == 2:
81+
return two_qubit_matrix_to_native_gates(
82+
op.qubits[0],
83+
op.qubits[1],
84+
mat.matrix(),
85+
allow_partial_czs=True)
86+
87+
# Just let it be?
88+
if self.ignore_failures:
89+
return op
90+
91+
raise TypeError("Don't know how to work with {!r}. "
92+
"It isn't an XmonGate, "
93+
"1-qubit KnownMatrixGate, "
94+
"2-qubit KnownMatrixGate, "
95+
"or CompositeGate.".format(op))
96+
97+
def convert(self, op: ops.Operation) -> ops.OP_TREE:
98+
converted = self._convert_one(op)
99+
if converted is op:
100+
return converted
101+
return [self.convert(e) for e in ops.flatten_op_tree(converted)]
27102

28103
def optimization_at(self, circuit, index, op):
29-
if self.ignore_cast_failures:
30-
c = xmon_gate_ext.try_cast(op.gate, XmonGate)
31-
if c is None:
32-
return None
33-
else:
34-
c = xmon_gate_ext.cast(op.gate, XmonGate)
35-
36-
if c is op.gate:
104+
converted = self.convert(op)
105+
if converted is op:
37106
return None
38107

39108
return PointOptimizationSummary(
40109
clear_span=1,
41-
new_operations=ops.Operation(c, op.qubits),
110+
new_operations=converted,
42111
clear_qubits=op.qubits)

cirq/google/eject_z_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@
2020

2121
def assert_optimizes(before, after):
2222
pre_optimizations = [
23-
ConvertToXmonGates(),
23+
ConvertToXmonGates(ignore_failures=True)
2424
]
2525
followup_optimizations = [
26-
ConvertToXmonGates(),
26+
ConvertToXmonGates(ignore_failures=True),
2727
circuits.DropEmptyMoments()
2828
]
2929

cirq/google/engine/engine.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from google.protobuf.json_format import MessageToDict
3030

3131
from cirq.api.google.v1 import program_pb2
32-
from cirq.circuits import Circuit, ExpandComposite
32+
from cirq.circuits import Circuit
3333
from cirq.circuits.drop_empty_moments import DropEmptyMoments
3434
from cirq.devices import Device
3535
from cirq.google.convert_to_xmon_gates import ConvertToXmonGates
@@ -188,14 +188,9 @@ def run_sweep(self,
188188
if not device:
189189
raise TypeError('device is required when running a circuit')
190190
# Convert to a schedule.
191-
expand = ExpandComposite()
192-
convert = ConvertToXmonGates(ignore_cast_failures=False)
193-
drop = DropEmptyMoments()
194-
195191
circuit_copy = Circuit(program.moments)
196-
expand.optimize_circuit(circuit_copy)
197-
convert.optimize_circuit(circuit_copy)
198-
drop.optimize_circuit(circuit_copy)
192+
ConvertToXmonGates().optimize_circuit(circuit_copy)
193+
DropEmptyMoments().optimize_circuit(circuit_copy)
199194

200195
schedule = moment_by_moment_schedule(device, circuit_copy)
201196
elif isinstance(program, Schedule):

cirq/google/sim/xmon_simulator.py

Lines changed: 32 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -25,22 +25,21 @@
2525
results = sim.run(circuit)
2626
"""
2727

28-
import functools
29-
import math
30-
3128
from collections import defaultdict, Iterable
32-
from typing import Tuple # pylint: disable=unused-import
29+
import math
3330
from typing import Dict, Iterator, List, Sequence, Union, cast
31+
from typing import Tuple # pylint: disable=unused-import
3432

33+
import functools
3534
import numpy as np
3635

37-
from cirq.circuits import Circuit, ExpandComposite
36+
from cirq.circuits import Circuit
3837
from cirq.circuits.drop_empty_moments import DropEmptyMoments
39-
from cirq.ops import raw_types
40-
from cirq.schedules import Schedule
4138
from cirq.google import xmon_gates
4239
from cirq.google.convert_to_xmon_gates import ConvertToXmonGates
4340
from cirq.google.sim.xmon_stepper import Stepper
41+
from cirq.ops import raw_types
42+
from cirq.schedules import Schedule
4443
from cirq.study import ParamResolver, Sweep, Sweepable, TrialResult
4544

4645

@@ -85,7 +84,7 @@ class SimulatorTrialResult(TrialResult):
8584
Attributes:
8685
measurements: A dictionary from measurement gate key to measurement
8786
results ordered by the qubits acted upon by the measurement gate.
88-
final_state: The final state (wave function) of the system after
87+
final_states: The final states (wave function) of the system after
8988
the trial finishes.
9089
"""
9190

@@ -133,10 +132,10 @@ def run(
133132
a canonical ordering of the qubits. This canonical ordering
134133
is used to define the wave function.
135134
initial_state: If an int, the state is set to the computational
136-
basis state corresponding corresponding to this state. Otherwise
137-
if this is a np.ndarray it is the full initial state. In this
138-
case it must be the correct size, be normalized (an L2 norm of
139-
1), and have a dtype of np.complex64.
135+
basis state corresponding corresponding to this state.
136+
Otherwise if this is a np.ndarray it is the full initial
137+
state. In this case it must be the correct size, be normalized
138+
(an L2 norm of 1), and have a dtype of np.complex64.
140139
141140
Returns:
142141
Results for this run.
@@ -164,10 +163,10 @@ def run_sweep(
164163
a canonical ordering of the qubits. This canonical ordering
165164
is used to define the wave function.
166165
initial_state: If an int, the state is set to the computational
167-
basis state corresponding corresponding to this state. Otherwise
168-
if this is a np.ndarray it is the full initial state. In this
169-
case it must be the correct size, be normalized (an L2 norm of
170-
1), and have a dtype of np.complex64.
166+
basis state corresponding corresponding to this state.
167+
Otherwise if this is a np.ndarray it is the full initial state.
168+
In this case it must be the correct size, be normalized (an L2
169+
norm of 1), and have a dtype of np.complex64.
171170
172171
Returns:
173172
List of trial results for this run, one for each possible parameter
@@ -230,10 +229,10 @@ def moment_steps(
230229
a canonical ordering of the qubits. This canonical ordering
231230
is used to define the wave function.
232231
initial_state: If an int, the state is set to the computational
233-
basis state corresponding corresponding to this state. Otherwise
234-
if this is a np.ndarray it is the full initial state. In this
235-
case it must be the correct size, be normalized (an L2 norm of
236-
1), and have a dtype of np.complex64.
232+
basis state corresponding corresponding to this state.
233+
Otherwise if this is a np.ndarray it is the full initial state.
234+
In this case it must be the correct size, be normalized (an L2
235+
norm of 1), and have a dtype of np.complex64.
237236
param_resolver: A ParamResolver for determining values of
238237
Symbols.
239238
@@ -285,21 +284,16 @@ def simulator_iterator(
285284
qubit_map = {q: i for i, q in enumerate(qubits)}
286285

287286
# TODO: Use one optimization pass.
288-
expand = ExpandComposite()
289-
convert = ConvertToXmonGates(ignore_cast_failures=False)
290-
drop = DropEmptyMoments()
291-
292287
circuit_copy = Circuit(circuit.moments)
293-
expand.optimize_circuit(circuit_copy)
294-
convert.optimize_circuit(circuit_copy)
295-
drop.optimize_circuit(circuit_copy)
288+
ConvertToXmonGates().optimize_circuit(circuit_copy)
289+
DropEmptyMoments().optimize_circuit(circuit_copy)
296290
validate_unique_measurement_keys(circuit_copy)
297291

298-
with Stepper(
299-
num_qubits=len(qubits),
300-
num_prefix_qubits=options.num_prefix_qubits,
301-
initial_state=initial_state,
302-
min_qubits_before_shard=options.min_qubits_before_shard) as stepper:
292+
with Stepper(num_qubits=len(qubits),
293+
num_prefix_qubits=options.num_prefix_qubits,
294+
initial_state=initial_state,
295+
min_qubits_before_shard=options.min_qubits_before_shard
296+
) as stepper:
303297
for moment in circuit_copy.moments:
304298
measurements = defaultdict(list) # type: Dict[str, List[bool]]
305299
phase_map = {} # type: Dict[Tuple[int, ...], float]
@@ -330,9 +324,8 @@ def simulator_iterator(
330324
result = not result
331325
measurements[gate.key].append(result)
332326
else:
333-
raise TypeError(
334-
'Gate %s is not a gate supported by the xmon simulator.'
335-
% gate)
327+
raise TypeError('{!r} is not supported by the '
328+
'xmon simulator.'.format(gate))
336329
stepper.simulate_phases(phase_map)
337330
yield StepResult(stepper, qubit_map, measurements)
338331

@@ -414,9 +407,9 @@ def set_state(self, state: Union[int, np.ndarray]):
414407
def merge(a: 'StepResult', b: 'StepResult') -> 'StepResult':
415408
"""Merges measurement results of last_result into a new Result.
416409
417-
The measurement results are merges such that measurements with duplicate
418-
keys have the results of last_result before those of this objects
419-
results.
410+
The measurement results are merges such that measurements with
411+
duplicate keys have the results of last_result before those of this
412+
objects' results.
420413
421414
Args:
422415
a: First result to merge.
@@ -425,7 +418,7 @@ def merge(a: 'StepResult', b: 'StepResult') -> 'StepResult':
425418
Returns:
426419
A new StepResult with merged measurements.
427420
"""
428-
new_measurements = {} # type: Dict[str, np.ndarray]
421+
new_measurements = {} # type: Dict[str, list]
429422
for d in [a.measurements, b.measurements]:
430423
for key, results in d.items():
431424
if key not in new_measurements:

cirq/google/sim/xmon_simulator_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,13 @@
2222
import numpy as np
2323
import pytest
2424

25-
2625
from cirq.circuits import Circuit
27-
from cirq.linalg import allclose_up_to_global_phase
2826
from cirq.devices import UnconstrainedDevice
2927
from cirq.google import (
3028
ExpWGate, ExpZGate, Exp11Gate, XmonMeasurementGate, XmonQubit,
3129
)
3230
from cirq.google.sim import xmon_simulator
31+
from cirq.linalg import allclose_up_to_global_phase
3332
from cirq.ops import op_tree
3433
from cirq.ops import raw_types
3534
from cirq.ops.common_gates import CNOT, H, X, Y, Z, CZ

cirq/google/sim/xmon_stepper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def reset_state(self, reset_state):
256256
args.append(kwargs)
257257
self._pool.map(_reset_state, args)
258258

259-
def simulate_phases(self, phase_map: Dict[Tuple[int], float]):
259+
def simulate_phases(self, phase_map: Dict[Tuple[int, ...], float]):
260260
"""Simulate a set of phase gates on the xmon architecture.
261261
262262
Args:

cirq/google/xmon_gates.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ def to_proto(self, *qubits):
9999
op.measurement.key = self.key
100100
return op
101101

102+
def __repr__(self):
103+
return 'XmonMeasurementGate({})'.format(repr(self.key))
104+
102105

103106
class Exp11Gate(XmonGate,
104107
ops.AsciiDiagrammableGate,
@@ -302,8 +305,10 @@ def ascii_wire_symbols(self):
302305
return 'Z',
303306

304307
def ascii_exponent(self):
305-
if self.half_turns in [-0.5, -0.25, 0.25, 0.5]:
308+
if self.half_turns in [0.25, 0.5]:
306309
return 1
310+
if self.half_turns in [-0.5, -0.25]:
311+
return -1
307312
return self.half_turns
308313

309314
def try_cast_to(self, desired_type):
@@ -329,6 +334,11 @@ def matrix(self):
329334
raise ValueError("Don't have a known matrix.")
330335
return ops.RotZGate(half_turns=self.half_turns).matrix()
331336

337+
def trace_distance_bound(self):
338+
if isinstance(self.half_turns, Symbol):
339+
return 1
340+
return abs(self.half_turns) * 3.5
341+
332342
def to_proto(self, *qubits):
333343
if len(qubits) != 1:
334344
raise ValueError('Wrong number of qubits.')

0 commit comments

Comments
 (0)