1414
1515from collections import deque
1616from dataclasses import dataclass
17- from typing import Any , Dict , List , Optional , Sequence , Tuple , Union
17+ from typing import Any , Dict , Iterator , List , Optional , Sequence , Tuple , Union
1818
1919import cirq
2020
2424import qsimcirq .qsim_circuit as qsimc
2525
2626
27- class QSimSimulatorState (cirq .StateVectorSimulatorState ):
28- def __init__ (self , qsim_data : np .ndarray , qubit_map : Dict [cirq .Qid , int ]):
29- state_vector = qsim_data .view (np .complex64 )
30- super ().__init__ (state_vector = state_vector , qubit_map = qubit_map )
31-
32-
33- @cirq .value_equality (unhashable = True )
34- class QSimSimulatorTrialResult (cirq .StateVectorMixin , cirq .SimulationTrialResult ):
35- def __init__ (
36- self ,
37- params : cirq .ParamResolver ,
38- measurements : Dict [str , np .ndarray ],
39- final_simulator_state : QSimSimulatorState ,
40- ):
41- super ().__init__ (
42- params = params ,
43- measurements = measurements ,
44- final_simulator_state = final_simulator_state ,
45- )
46-
47- # The following methods are (temporarily) copied here from
48- # cirq.StateVectorTrialResult due to incompatibility with the
49- # intermediate state simulation support which that class requires.
50- # TODO: remove these methods once inheritance is restored.
51-
52- @property
53- def final_state_vector (self ):
54- return self ._final_simulator_state .state_vector
55-
56- def state_vector (self ):
57- """Return the state vector at the end of the computation."""
58- return self ._final_simulator_state .state_vector .copy ()
59-
60- def _value_equality_values_ (self ):
61- measurements = {k : v .tolist () for k , v in sorted (self .measurements .items ())}
62- return (self .params , measurements , self ._final_simulator_state )
63-
64- def __str__ (self ) -> str :
65- samples = super ().__str__ ()
66- final = self .state_vector ()
67- if len ([1 for e in final if abs (e ) > 0.001 ]) < 16 :
68- state_vector = self .dirac_notation (3 )
69- else :
70- state_vector = str (final )
71- return f"measurements: { samples } \n output vector: { state_vector } "
72-
73- def _repr_pretty_ (self , p : Any , cycle : bool ) -> None :
74- """Text output in Jupyter."""
75- if cycle :
76- # There should never be a cycle. This is just in case.
77- p .text ("StateVectorTrialResult(...)" )
78- else :
79- p .text (str (self ))
80-
81- def __repr__ (self ) -> str :
82- return (
83- f"cirq.StateVectorTrialResult(params={ self .params !r} , "
84- f"measurements={ self .measurements !r} , "
85- f"final_simulator_state={ self ._final_simulator_state !r} )"
86- )
87-
88-
8927# This should probably live in Cirq...
9028# TODO: update to support CircuitOperations.
9129def _needs_trajectories (circuit : cirq .Circuit ) -> bool :
@@ -189,7 +127,7 @@ class MeasInfo:
189127class QSimSimulator (
190128 cirq .SimulatesSamples ,
191129 cirq .SimulatesAmplitudes ,
192- cirq .SimulatesFinalState ,
130+ cirq .SimulatesFinalState [ cirq . StateVectorTrialResult ] ,
193131 cirq .SimulatesExpectationValues ,
194132):
195133 def __init__ (
@@ -438,13 +376,13 @@ def _sample_measure_results(
438376
439377 return results
440378
441- def compute_amplitudes_sweep (
379+ def compute_amplitudes_sweep_iter (
442380 self ,
443381 program : cirq .Circuit ,
444382 bitstrings : Sequence [int ],
445383 params : cirq .Sweepable ,
446384 qubit_order : cirq .QubitOrderOrList = cirq .QubitOrder .DEFAULT ,
447- ) -> Sequence [Sequence [complex ]]:
385+ ) -> Iterator [Sequence [complex ]]:
448386 """Computes the desired amplitudes using qsim.
449387
450388 The initial state is assumed to be the all zeros state.
@@ -460,8 +398,8 @@ def compute_amplitudes_sweep(
460398 often used in specifying the initial state, i.e. the ordering of the
461399 computational basis states.
462400
463- Returns :
464- List of amplitudes .
401+ Yields :
402+ Amplitudes .
465403 """
466404
467405 # Add noise to the circuit if a noise model was provided.
@@ -484,7 +422,6 @@ def compute_amplitudes_sweep(
484422
485423 param_resolvers = cirq .to_resolvers (params )
486424
487- trials_results = []
488425 if _needs_trajectories (program ):
489426 translator_fn_name = "translate_cirq_to_qtrajectory"
490427 simulator_fn = self ._sim_module .qtrajectory_simulate
@@ -500,18 +437,15 @@ def compute_amplitudes_sweep(
500437 cirq_order ,
501438 )
502439 options ["s" ] = self .get_seed ()
503- amplitudes = simulator_fn (options )
504- trials_results .append (amplitudes )
440+ yield simulator_fn (options )
505441
506- return trials_results
507-
508- def simulate_sweep (
442+ def simulate_sweep_iter (
509443 self ,
510444 program : cirq .Circuit ,
511445 params : cirq .Sweepable ,
512446 qubit_order : cirq .QubitOrderOrList = cirq .QubitOrder .DEFAULT ,
513447 initial_state : Optional [Union [int , np .ndarray ]] = None ,
514- ) -> List [ "SimulationTrialResult" ]:
448+ ) -> Iterator [ cirq . StateVectorTrialResult ]:
515449 """Simulates the supplied Circuit.
516450
517451 This method returns a result which allows access to the entire
@@ -572,7 +506,6 @@ def simulate_sweep(
572506 f"Expected: { 2 ** num_qubits * 2 } Received: { len (input_vector )} "
573507 )
574508
575- trials_results = []
576509 if _needs_trajectories (program ):
577510 translator_fn_name = "translate_cirq_to_qtrajectory"
578511 fullstate_simulator_fn = self ._sim_module .qtrajectory_simulate_fullstate
@@ -589,33 +522,32 @@ def simulate_sweep(
589522 cirq_order ,
590523 )
591524 options ["s" ] = self .get_seed ()
592- qubit_map = {qubit : index for index , qubit in enumerate (qsim_order )}
593525
594526 if isinstance (initial_state , int ):
595527 qsim_state = fullstate_simulator_fn (options , initial_state )
596528 elif isinstance (initial_state , np .ndarray ):
597529 qsim_state = fullstate_simulator_fn (options , input_vector )
598530 assert qsim_state .dtype == np .float32
599531 assert qsim_state .ndim == 1
600- final_state = QSimSimulatorState (qsim_state , qubit_map )
532+
533+ final_state = cirq .StateVectorSimulationState (
534+ initial_state = qsim_state .view (np .complex64 ), qubits = cirq_order
535+ )
601536 # create result for this parameter
602537 # TODO: We need to support measurements.
603- result = QSimSimulatorTrialResult (
538+ yield cirq . StateVectorTrialResult (
604539 params = prs , measurements = {}, final_simulator_state = final_state
605540 )
606- trials_results .append (result )
607541
608- return trials_results
609-
610- def simulate_expectation_values_sweep (
542+ def simulate_expectation_values_sweep_iter (
611543 self ,
612544 program : cirq .Circuit ,
613545 observables : Union [cirq .PauliSumLike , List [cirq .PauliSumLike ]],
614546 params : cirq .Sweepable ,
615547 qubit_order : cirq .QubitOrderOrList = cirq .QubitOrder .DEFAULT ,
616548 initial_state : Any = None ,
617549 permit_terminal_measurements : bool = False ,
618- ) -> List [List [float ]]:
550+ ) -> Iterator [List [float ]]:
619551 """Simulates the supplied circuit and calculates exact expectation
620552 values for the given observables on its final state.
621553
@@ -638,8 +570,8 @@ def simulate_expectation_values_sweep(
638570 is set to True. This is meant to prevent measurements from
639571 ruining expectation value calculations.
640572
641- Returns :
642- A list of expectation values, with the value at index `n`
573+ Yields :
574+ Lists of expectation values, with the value at index `n`
643575 corresponding to `observables[n]` from the input.
644576
645577 Raises:
@@ -703,7 +635,6 @@ def simulate_expectation_values_sweep(
703635 f"Expected: { 2 ** num_qubits * 2 } Received: { len (input_vector )} "
704636 )
705637
706- results = []
707638 if _needs_trajectories (program ):
708639 translator_fn_name = "translate_cirq_to_qtrajectory"
709640 ev_simulator_fn = self ._sim_module .qtrajectory_simulate_expectation_values
@@ -724,9 +655,7 @@ def simulate_expectation_values_sweep(
724655 evs = ev_simulator_fn (options , opsums_and_qubit_counts , initial_state )
725656 elif isinstance (initial_state , np .ndarray ):
726657 evs = ev_simulator_fn (options , opsums_and_qubit_counts , input_vector )
727- results .append (evs )
728-
729- return results
658+ yield evs
730659
731660 def simulate_moment_expectation_values (
732661 self ,
@@ -870,20 +799,13 @@ def _translate_circuit(
870799 translator_fn_name : str ,
871800 qubit_order : cirq .QubitOrderOrList ,
872801 ):
873- # If the circuit is memoized, reuse the corresponding translated
874- # circuit.
875- translated_circuit = None
876- for original , translated , m_indices in self ._translated_circuits :
802+ # If the circuit is memoized, reuse the corresponding translated circuit.
803+ for original , translated , moment_indices in self ._translated_circuits :
877804 if original == circuit :
878- translated_circuit = translated
879- moment_indices = m_indices
880- break
881-
882- if translated_circuit is None :
883- translator_fn = getattr (circuit , translator_fn_name )
884- translated_circuit , moment_indices = translator_fn (qubit_order )
885- self ._translated_circuits .append (
886- (circuit , translated_circuit , moment_indices )
887- )
805+ return translated , moment_indices
806+
807+ translator_fn = getattr (circuit , translator_fn_name )
808+ translated , moment_indices = translator_fn (qubit_order )
809+ self ._translated_circuits .append ((circuit , translated , moment_indices ))
888810
889- return translated_circuit , moment_indices
811+ return translated , moment_indices
0 commit comments