Skip to content

Commit f32a262

Browse files
maffooStrilanc
authored andcommitted
Simplify xmon_simulator with unique measurement keys (#296)
1 parent 5c98986 commit f32a262

File tree

2 files changed

+30
-53
lines changed

2 files changed

+30
-53
lines changed

cirq/examples/bernstein_vazirani.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ def bv_circuit(qubits: Sequence[cirq.QubitId], a: int) -> cirq.Circuit:
5252
# 4. Apply Hadamard gates to the outputs
5353
circuit.append(H_layer)
5454
# 5. Apply measurement layer
55-
circuit.append(cirq.ops.MeasurementGate('result').on(qubit)
56-
for qubit in qubits)
55+
circuit.append(cirq.ops.MeasurementGate('result').on(*qubits))
5756
return circuit
5857

5958

cirq/google/sim/xmon_simulator.py

Lines changed: 29 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,11 @@
2525
results = sim.run(circuit)
2626
"""
2727

28-
from collections import defaultdict, Iterable
2928
import math
30-
from typing import Dict, Iterator, List, Sequence, Union, cast
29+
from collections import defaultdict, Iterable
30+
from typing import Dict, Iterator, List, Sequence, Set, Union, cast
3131
from typing import Tuple # pylint: disable=unused-import
3232

33-
import functools
3433
import numpy as np
3534

3635
from cirq.circuits import Circuit
@@ -176,24 +175,22 @@ def run_sweep(
176175
Circuit) else program.to_circuit()
177176
param_resolvers = self._to_resolvers(params or ParamResolver({}))
178177

178+
xmon_circuit, keys = self._to_xmon_circuit(circuit)
179179
trial_results = [] # type: List[SimulatorTrialResult]
180180
for param_resolver in param_resolvers:
181-
measurements = {} # type: Dict[str, List[np.ndarray]]
181+
measurements = {
182+
k: [] for k in keys} # type: Dict[str, List[np.ndarray]]
182183
final_states = [] # type: List[np.ndarray]
183184
for _ in range(repetitions):
184-
all_step_results = self.moment_steps(circuit,
185-
options or Options(),
186-
qubits,
187-
initial_state,
188-
param_resolver)
189-
final_step_result = functools.reduce(
190-
StepResult.merge,
191-
all_step_results)
192-
for k, v in final_step_result.measurements.items():
193-
if k not in measurements:
194-
measurements[k] = []
195-
measurements[k].append(np.array(v, dtype=bool))
196-
final_states.append(final_step_result.state())
185+
all_step_results = simulator_iterator(xmon_circuit,
186+
options or Options(),
187+
qubits,
188+
initial_state,
189+
param_resolver)
190+
for step_result in all_step_results:
191+
for k, v in step_result.measurements.items():
192+
measurements[k].append(np.array(v, dtype=bool))
193+
final_states.append(step_result.state())
197194
trial_results.append(SimulatorTrialResult(
198195
param_resolver,
199196
repetitions,
@@ -241,9 +238,18 @@ def moment_steps(
241238
each moment and returning a StepResult for each moment.
242239
"""
243240
param_resolver = param_resolver or ParamResolver({})
244-
return simulator_iterator(program, options or Options(), qubits,
241+
xmon_circuit, _ = self._to_xmon_circuit(program)
242+
return simulator_iterator(xmon_circuit, options or Options(), qubits,
245243
initial_state, param_resolver)
246244

245+
def _to_xmon_circuit(self, circuit: Circuit) -> Tuple[Circuit, Set[str]]:
246+
# TODO: Use one optimization pass.
247+
xmon_circuit = Circuit(circuit.moments)
248+
ConvertToXmonGates().optimize_circuit(xmon_circuit)
249+
DropEmptyMoments().optimize_circuit(xmon_circuit)
250+
keys = find_measurement_keys(xmon_circuit)
251+
return xmon_circuit, keys
252+
247253

248254
def simulator_iterator(
249255
circuit: Circuit,
@@ -258,7 +264,7 @@ def simulator_iterator(
258264
Simulator and use methods on that object to get an iterator.
259265
260266
Args:
261-
circuit: The circuit to simulate.
267+
circuit: The circuit to simulate; must contain xmon gates only.
262268
options: Options configuring the simulation.
263269
qubits: If specified this list of qubits will be used to define
264270
a canonical ordering of the qubits. This canonical ordering
@@ -283,18 +289,12 @@ def simulator_iterator(
283289
qubits = list(circuit_qubits)
284290
qubit_map = {q: i for i, q in enumerate(qubits)}
285291

286-
# TODO: Use one optimization pass.
287-
circuit_copy = Circuit(circuit.moments)
288-
ConvertToXmonGates().optimize_circuit(circuit_copy)
289-
DropEmptyMoments().optimize_circuit(circuit_copy)
290-
validate_unique_measurement_keys(circuit_copy)
291-
292292
with Stepper(num_qubits=len(qubits),
293293
num_prefix_qubits=options.num_prefix_qubits,
294294
initial_state=initial_state,
295295
min_qubits_before_shard=options.min_qubits_before_shard
296296
) as stepper:
297-
for moment in circuit_copy.moments:
297+
for moment in circuit.moments:
298298
measurements = defaultdict(list) # type: Dict[str, List[bool]]
299299
phase_map = {} # type: Dict[Tuple[int, ...], float]
300300
for op in moment.operations:
@@ -330,15 +330,16 @@ def simulator_iterator(
330330
yield StepResult(stepper, qubit_map, measurements)
331331

332332

333-
def validate_unique_measurement_keys(circuit):
334-
keys = set()
333+
def find_measurement_keys(circuit: Circuit) -> Set[str]:
334+
keys = set() # type: Set[str]
335335
for moment in circuit.moments:
336336
for op in moment.operations:
337337
if isinstance(op.gate, xmon_gates.XmonMeasurementGate):
338338
key = op.gate.key
339339
if key in keys:
340340
raise ValueError('Repeated Measurement key {}'.format(key))
341341
keys.add(key)
342+
return keys
342343

343344

344345
class StepResult:
@@ -402,26 +403,3 @@ def set_state(self, state: Union[int, np.ndarray]):
402403
dtype.
403404
"""
404405
self._stepper.reset_state(state)
405-
406-
@staticmethod
407-
def merge(a: 'StepResult', b: 'StepResult') -> 'StepResult':
408-
"""Merges measurement results of last_result into a new Result.
409-
410-
The measurement results are merges such that measurements with
411-
duplicate keys have the results of last_result before those of this
412-
objects' results.
413-
414-
Args:
415-
a: First result to merge.
416-
b: Second result to merge.
417-
418-
Returns:
419-
A new StepResult with merged measurements.
420-
"""
421-
new_measurements = {} # type: Dict[str, list]
422-
for d in [a.measurements, b.measurements]:
423-
for key, results in d.items():
424-
if key not in new_measurements:
425-
new_measurements[key] = []
426-
new_measurements[key].extend(results)
427-
return StepResult(a._stepper, a.qubit_map, new_measurements)

0 commit comments

Comments
 (0)