2525 results = sim.run(circuit)
2626"""
2727
28- from collections import defaultdict , Iterable
2928import math
30- from typing import Dict , Iterator , List , Sequence , Union , cast
29+ from collections import defaultdict , Iterable
30+ from typing import Dict , Iterator , List , Sequence , Set , Union , cast
3131from typing import Tuple # pylint: disable=unused-import
3232
33- import functools
3433import numpy as np
3534
3635from cirq .circuits import Circuit
@@ -176,24 +175,22 @@ def run_sweep(
176175 Circuit ) else program .to_circuit ()
177176 param_resolvers = self ._to_resolvers (params or ParamResolver ({}))
178177
178+ xmon_circuit , keys = self ._to_xmon_circuit (circuit )
179179 trial_results = [] # type: List[SimulatorTrialResult]
180180 for param_resolver in param_resolvers :
181- measurements = {} # type: Dict[str, List[np.ndarray]]
181+ measurements = {
182+ k : [] for k in keys } # type: Dict[str, List[np.ndarray]]
182183 final_states = [] # type: List[np.ndarray]
183184 for _ in range (repetitions ):
184- all_step_results = self .moment_steps (circuit ,
185- options or Options (),
186- qubits ,
187- initial_state ,
188- param_resolver )
189- final_step_result = functools .reduce (
190- StepResult .merge ,
191- all_step_results )
192- for k , v in final_step_result .measurements .items ():
193- if k not in measurements :
194- measurements [k ] = []
195- measurements [k ].append (np .array (v , dtype = bool ))
196- final_states .append (final_step_result .state ())
185+ all_step_results = simulator_iterator (xmon_circuit ,
186+ options or Options (),
187+ qubits ,
188+ initial_state ,
189+ param_resolver )
190+ for step_result in all_step_results :
191+ for k , v in step_result .measurements .items ():
192+ measurements [k ].append (np .array (v , dtype = bool ))
193+ final_states .append (step_result .state ())
197194 trial_results .append (SimulatorTrialResult (
198195 param_resolver ,
199196 repetitions ,
@@ -241,9 +238,18 @@ def moment_steps(
241238 each moment and returning a StepResult for each moment.
242239 """
243240 param_resolver = param_resolver or ParamResolver ({})
244- return simulator_iterator (program , options or Options (), qubits ,
241+ xmon_circuit , _ = self ._to_xmon_circuit (program )
242+ return simulator_iterator (xmon_circuit , options or Options (), qubits ,
245243 initial_state , param_resolver )
246244
245+ def _to_xmon_circuit (self , circuit : Circuit ) -> Tuple [Circuit , Set [str ]]:
246+ # TODO: Use one optimization pass.
247+ xmon_circuit = Circuit (circuit .moments )
248+ ConvertToXmonGates ().optimize_circuit (xmon_circuit )
249+ DropEmptyMoments ().optimize_circuit (xmon_circuit )
250+ keys = find_measurement_keys (xmon_circuit )
251+ return xmon_circuit , keys
252+
247253
248254def simulator_iterator (
249255 circuit : Circuit ,
@@ -258,7 +264,7 @@ def simulator_iterator(
258264 Simulator and use methods on that object to get an iterator.
259265
260266 Args:
261- circuit: The circuit to simulate.
267+ circuit: The circuit to simulate; must contain xmon gates only .
262268 options: Options configuring the simulation.
263269 qubits: If specified this list of qubits will be used to define
264270 a canonical ordering of the qubits. This canonical ordering
@@ -283,18 +289,12 @@ def simulator_iterator(
283289 qubits = list (circuit_qubits )
284290 qubit_map = {q : i for i , q in enumerate (qubits )}
285291
286- # TODO: Use one optimization pass.
287- circuit_copy = Circuit (circuit .moments )
288- ConvertToXmonGates ().optimize_circuit (circuit_copy )
289- DropEmptyMoments ().optimize_circuit (circuit_copy )
290- validate_unique_measurement_keys (circuit_copy )
291-
292292 with Stepper (num_qubits = len (qubits ),
293293 num_prefix_qubits = options .num_prefix_qubits ,
294294 initial_state = initial_state ,
295295 min_qubits_before_shard = options .min_qubits_before_shard
296296 ) as stepper :
297- for moment in circuit_copy .moments :
297+ for moment in circuit .moments :
298298 measurements = defaultdict (list ) # type: Dict[str, List[bool]]
299299 phase_map = {} # type: Dict[Tuple[int, ...], float]
300300 for op in moment .operations :
@@ -330,15 +330,16 @@ def simulator_iterator(
330330 yield StepResult (stepper , qubit_map , measurements )
331331
332332
333- def validate_unique_measurement_keys (circuit ) :
334- keys = set ()
333+ def find_measurement_keys (circuit : Circuit ) -> Set [ str ] :
334+ keys = set () # type: Set[str]
335335 for moment in circuit .moments :
336336 for op in moment .operations :
337337 if isinstance (op .gate , xmon_gates .XmonMeasurementGate ):
338338 key = op .gate .key
339339 if key in keys :
340340 raise ValueError ('Repeated Measurement key {}' .format (key ))
341341 keys .add (key )
342+ return keys
342343
343344
344345class StepResult :
@@ -402,26 +403,3 @@ def set_state(self, state: Union[int, np.ndarray]):
402403 dtype.
403404 """
404405 self ._stepper .reset_state (state )
405-
406- @staticmethod
407- def merge (a : 'StepResult' , b : 'StepResult' ) -> 'StepResult' :
408- """Merges measurement results of last_result into a new Result.
409-
410- The measurement results are merges such that measurements with
411- duplicate keys have the results of last_result before those of this
412- objects' results.
413-
414- Args:
415- a: First result to merge.
416- b: Second result to merge.
417-
418- Returns:
419- A new StepResult with merged measurements.
420- """
421- new_measurements = {} # type: Dict[str, list]
422- for d in [a .measurements , b .measurements ]:
423- for key , results in d .items ():
424- if key not in new_measurements :
425- new_measurements [key ] = []
426- new_measurements [key ].extend (results )
427- return StepResult (a ._stepper , a .qubit_map , new_measurements )
0 commit comments