3434
3535from cirq .circuits import Circuit
3636from cirq .circuits .drop_empty_moments import DropEmptyMoments
37+ from cirq .extension import Extensions
3738from cirq .google import xmon_gates
39+ from cirq .google import xmon_gate_ext
3840from cirq .google .convert_to_xmon_gates import ConvertToXmonGates
3941from cirq .google .sim .xmon_stepper import Stepper
4042from cirq .ops import raw_types
@@ -119,6 +121,7 @@ def run(
119121 options : Options = None ,
120122 qubits : Sequence [raw_types .QubitId ] = None ,
121123 initial_state : Union [int , np .ndarray ] = 0 ,
124+ extensions : Extensions = None ,
122125 ) -> SimulatorTrialResult :
123126 """Simulates the entire supplied Circuit.
124127
@@ -135,12 +138,16 @@ def run(
135138 Otherwise if this is a np.ndarray it is the full initial
136139 state. In this case it must be the correct size, be normalized
137140 (an L2 norm of 1), and have a dtype of np.complex64.
141+ extensions: Extensions that will be applied while trying to
142+ decompose the circuit's gates into XmonGates. If None, this uses
143+ the default of xmon_gate_ext.
138144
139145 Returns:
140146 Results for this run.
141147 """
142148 return self .run_sweep (circuit , [param_resolver ], repetitions , options ,
143- qubits , initial_state )[0 ]
149+ qubits , initial_state ,
150+ extensions or xmon_gate_ext )[0 ]
144151
145152 def run_sweep (
146153 self ,
@@ -150,6 +157,7 @@ def run_sweep(
150157 options : Options = None ,
151158 qubits : Sequence [raw_types .QubitId ] = None ,
152159 initial_state : Union [int , np .ndarray ] = 0 ,
160+ extensions : Extensions = None
153161 ) -> List [SimulatorTrialResult ]:
154162 """Simulates the entire supplied Circuit.
155163
@@ -166,6 +174,9 @@ def run_sweep(
166174 Otherwise if this is a np.ndarray it is the full initial state.
167175 In this case it must be the correct size, be normalized (an L2
168176 norm of 1), and have a dtype of np.complex64.
177+ extensions: Extensions that will be applied while trying to
178+ decompose the circuit's gates into XmonGates. If None, this uses
179+ the default of xmon_gate_ext.
169180
170181 Returns:
171182 List of trial results for this run, one for each possible parameter
@@ -175,18 +186,20 @@ def run_sweep(
175186 Circuit ) else program .to_circuit ()
176187 param_resolvers = self ._to_resolvers (params or ParamResolver ({}))
177188
178- xmon_circuit , keys = self ._to_xmon_circuit (circuit )
189+ xmon_circuit , keys = self ._to_xmon_circuit (circuit ,
190+ extensions or xmon_gate_ext )
179191 trial_results = [] # type: List[SimulatorTrialResult]
180192 for param_resolver in param_resolvers :
181193 measurements = {
182194 k : [] for k in keys } # type: Dict[str, List[np.ndarray]]
183195 final_states = [] # type: List[np.ndarray]
184196 for _ in range (repetitions ):
185- all_step_results = simulator_iterator (xmon_circuit ,
186- options or Options (),
187- qubits ,
188- initial_state ,
189- param_resolver )
197+ all_step_results = simulator_iterator (
198+ xmon_circuit ,
199+ options or Options (),
200+ qubits ,
201+ initial_state ,
202+ param_resolver )
190203 for step_result in all_step_results :
191204 for k , v in step_result .measurements .items ():
192205 measurements [k ].append (np .array (v , dtype = bool ))
@@ -216,7 +229,8 @@ def moment_steps(
216229 options : 'Options' = None ,
217230 qubits : Sequence [raw_types .QubitId ] = None ,
218231 initial_state : Union [int , np .ndarray ]= 0 ,
219- param_resolver : ParamResolver = None ) -> Iterator ['StepResult' ]:
232+ param_resolver : ParamResolver = None ,
233+ extensions : Extensions = None ) -> Iterator ['StepResult' ]:
220234 """Returns an iterator of XmonStepResults for each moment simulated.
221235
222236 Args:
@@ -232,20 +246,25 @@ def moment_steps(
232246 norm of 1), and have a dtype of np.complex64.
233247 param_resolver: A ParamResolver for determining values of
234248 Symbols.
249+ extensions: Extensions that will be applied while trying to
250+ decompose the circuit's gates into XmonGates. If None, this
251+ uses the default of xmon_gate_ext.
235252
236253 Returns:
237254 SimulatorIterator that steps through the simulation, simulating
238255 each moment and returning a StepResult for each moment.
239256 """
240257 param_resolver = param_resolver or ParamResolver ({})
241- xmon_circuit , _ = self ._to_xmon_circuit (program )
258+ xmon_circuit , _ = self ._to_xmon_circuit (program ,
259+ extensions or xmon_gate_ext )
242260 return simulator_iterator (xmon_circuit , options or Options (), qubits ,
243261 initial_state , param_resolver )
244262
245- def _to_xmon_circuit (self , circuit : Circuit ) -> Tuple [Circuit , Set [str ]]:
263+ def _to_xmon_circuit (self , circuit : Circuit ,
264+ extensions : Extensions = None ) -> Tuple [Circuit , Set [str ]]:
246265 # TODO: Use one optimization pass.
247266 xmon_circuit = Circuit (circuit .moments )
248- ConvertToXmonGates ().optimize_circuit (xmon_circuit )
267+ ConvertToXmonGates (extensions ).optimize_circuit (xmon_circuit )
249268 DropEmptyMoments ().optimize_circuit (xmon_circuit )
250269 keys = find_measurement_keys (xmon_circuit )
251270 return xmon_circuit , keys
@@ -256,7 +275,7 @@ def simulator_iterator(
256275 options : 'Options' = Options (),
257276 qubits : Sequence [raw_types .QubitId ] = None ,
258277 initial_state : Union [int , np .ndarray ]= 0 ,
259- param_resolver : ParamResolver = ParamResolver ({})
278+ param_resolver : ParamResolver = ParamResolver ({}),
260279) -> Iterator ['StepResult' ]:
261280 """Iterator over TrialResults from Moments of a Circuit.
262281
0 commit comments