Skip to content

Commit 2c7632a

Browse files
Merge pull request #555 from quantumlib/u/maffoo/state-vector
Update for compatibility with cirq 1.0
2 parents b4fdca1 + d2f9c9b commit 2c7632a

File tree

5 files changed

+55
-128
lines changed

5 files changed

+55
-128
lines changed

qsimcirq/__init__.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,7 @@ def _load_qsim_custatevec():
5353
qsim_custatevec = _load_qsim_custatevec()
5454

5555
from .qsim_circuit import add_op_to_opstring, add_op_to_circuit, QSimCircuit
56-
from .qsim_simulator import (
57-
QSimOptions,
58-
QSimSimulatorState,
59-
QSimSimulatorTrialResult,
60-
QSimSimulator,
61-
)
56+
from .qsim_simulator import QSimOptions, QSimSimulator
6257
from .qsimh_simulator import QSimhSimulator
6358

6459
from qsimcirq._version import (

qsimcirq/qsim_circuit.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import warnings
16-
from typing import Dict, Union
16+
from typing import Dict, List, Optional, Sequence, Tuple, Union
1717

1818
import cirq
1919
import numpy as np
@@ -210,26 +210,33 @@ def _has_cirq_gate_kind(op: cirq.Operation):
210210
return any(t in TYPE_TRANSLATOR for t in type(op.gate).mro())
211211

212212

213-
def _control_details(gate: cirq.ControlledGate, qubits):
214-
control_qubits = []
215-
control_values = []
213+
def _control_details(
214+
gate: cirq.ControlledGate, qubits: Sequence[cirq.Qid]
215+
) -> Tuple[List[cirq.Qid], List[int]]:
216+
control_qubits: List[cirq.Qid] = []
217+
control_values: List[int] = []
216218
# TODO: support qudit control
217-
for i, cvs in enumerate(gate.control_values):
219+
assignments = list(gate.control_values.expand())
220+
if len(qubits) > 1 and len(assignments) > 1:
221+
raise ValueError(
222+
f"Cannot translate controlled gate with multiple assignments for multiple qubits: {gate}"
223+
)
224+
for q, cvs in zip(qubits, zip(*assignments)):
218225
if 0 in cvs and 1 in cvs:
219226
# This qubit does not affect control.
220227
continue
221-
elif 0 not in cvs and 1 not in cvs:
222-
# This gate will never trigger.
223-
warnings.warn(f"Gate has no valid control value: {gate}", RuntimeWarning)
224-
return (None, None)
228+
elif any(cv not in (0, 1) for cv in cvs):
229+
raise ValueError(
230+
f"Cannot translate control values other than 0 and 1: cvs={cvs}"
231+
)
225232
# Either 0 or 1 is in cvs, but not both.
226-
control_qubits.append(qubits[i])
233+
control_qubits.append(q)
227234
if 0 in cvs:
228235
control_values.append(0)
229236
elif 1 in cvs:
230237
control_values.append(1)
231238

232-
return (control_qubits, control_values)
239+
return control_qubits, control_values
233240

234241

235242
def add_op_to_opstring(
@@ -271,7 +278,9 @@ def add_op_to_circuit(
271278
qsim_qubits = qubits
272279
is_controlled = isinstance(qsim_gate, cirq.ControlledGate)
273280
if is_controlled:
274-
control_qubits, control_values = _control_details(qsim_gate, qubits)
281+
control_qubits, control_values = _control_details(
282+
qsim_gate, qubits[: qsim_gate.num_controls()]
283+
)
275284
if control_qubits is None:
276285
# This gate has no valid control, and will be omitted.
277286
return

qsimcirq/qsim_simulator.py

Lines changed: 27 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from collections import deque
1616
from dataclasses import dataclass
17-
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
17+
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union
1818

1919
import cirq
2020

@@ -24,68 +24,6 @@
2424
import qsimcirq.qsim_circuit as qsimc
2525

2626

27-
class QSimSimulatorState(cirq.StateVectorSimulatorState):
28-
def __init__(self, qsim_data: np.ndarray, qubit_map: Dict[cirq.Qid, int]):
29-
state_vector = qsim_data.view(np.complex64)
30-
super().__init__(state_vector=state_vector, qubit_map=qubit_map)
31-
32-
33-
@cirq.value_equality(unhashable=True)
34-
class QSimSimulatorTrialResult(cirq.StateVectorMixin, cirq.SimulationTrialResult):
35-
def __init__(
36-
self,
37-
params: cirq.ParamResolver,
38-
measurements: Dict[str, np.ndarray],
39-
final_simulator_state: QSimSimulatorState,
40-
):
41-
super().__init__(
42-
params=params,
43-
measurements=measurements,
44-
final_simulator_state=final_simulator_state,
45-
)
46-
47-
# The following methods are (temporarily) copied here from
48-
# cirq.StateVectorTrialResult due to incompatibility with the
49-
# intermediate state simulation support which that class requires.
50-
# TODO: remove these methods once inheritance is restored.
51-
52-
@property
53-
def final_state_vector(self):
54-
return self._final_simulator_state.state_vector
55-
56-
def state_vector(self):
57-
"""Return the state vector at the end of the computation."""
58-
return self._final_simulator_state.state_vector.copy()
59-
60-
def _value_equality_values_(self):
61-
measurements = {k: v.tolist() for k, v in sorted(self.measurements.items())}
62-
return (self.params, measurements, self._final_simulator_state)
63-
64-
def __str__(self) -> str:
65-
samples = super().__str__()
66-
final = self.state_vector()
67-
if len([1 for e in final if abs(e) > 0.001]) < 16:
68-
state_vector = self.dirac_notation(3)
69-
else:
70-
state_vector = str(final)
71-
return f"measurements: {samples}\noutput vector: {state_vector}"
72-
73-
def _repr_pretty_(self, p: Any, cycle: bool) -> None:
74-
"""Text output in Jupyter."""
75-
if cycle:
76-
# There should never be a cycle. This is just in case.
77-
p.text("StateVectorTrialResult(...)")
78-
else:
79-
p.text(str(self))
80-
81-
def __repr__(self) -> str:
82-
return (
83-
f"cirq.StateVectorTrialResult(params={self.params!r}, "
84-
f"measurements={self.measurements!r}, "
85-
f"final_simulator_state={self._final_simulator_state!r})"
86-
)
87-
88-
8927
# This should probably live in Cirq...
9028
# TODO: update to support CircuitOperations.
9129
def _needs_trajectories(circuit: cirq.Circuit) -> bool:
@@ -189,7 +127,7 @@ class MeasInfo:
189127
class QSimSimulator(
190128
cirq.SimulatesSamples,
191129
cirq.SimulatesAmplitudes,
192-
cirq.SimulatesFinalState,
130+
cirq.SimulatesFinalState[cirq.StateVectorTrialResult],
193131
cirq.SimulatesExpectationValues,
194132
):
195133
def __init__(
@@ -438,13 +376,13 @@ def _sample_measure_results(
438376

439377
return results
440378

441-
def compute_amplitudes_sweep(
379+
def compute_amplitudes_sweep_iter(
442380
self,
443381
program: cirq.Circuit,
444382
bitstrings: Sequence[int],
445383
params: cirq.Sweepable,
446384
qubit_order: cirq.QubitOrderOrList = cirq.QubitOrder.DEFAULT,
447-
) -> Sequence[Sequence[complex]]:
385+
) -> Iterator[Sequence[complex]]:
448386
"""Computes the desired amplitudes using qsim.
449387
450388
The initial state is assumed to be the all zeros state.
@@ -460,8 +398,8 @@ def compute_amplitudes_sweep(
460398
often used in specifying the initial state, i.e. the ordering of the
461399
computational basis states.
462400
463-
Returns:
464-
List of amplitudes.
401+
Yields:
402+
Amplitudes.
465403
"""
466404

467405
# Add noise to the circuit if a noise model was provided.
@@ -484,7 +422,6 @@ def compute_amplitudes_sweep(
484422

485423
param_resolvers = cirq.to_resolvers(params)
486424

487-
trials_results = []
488425
if _needs_trajectories(program):
489426
translator_fn_name = "translate_cirq_to_qtrajectory"
490427
simulator_fn = self._sim_module.qtrajectory_simulate
@@ -500,18 +437,15 @@ def compute_amplitudes_sweep(
500437
cirq_order,
501438
)
502439
options["s"] = self.get_seed()
503-
amplitudes = simulator_fn(options)
504-
trials_results.append(amplitudes)
440+
yield simulator_fn(options)
505441

506-
return trials_results
507-
508-
def simulate_sweep(
442+
def simulate_sweep_iter(
509443
self,
510444
program: cirq.Circuit,
511445
params: cirq.Sweepable,
512446
qubit_order: cirq.QubitOrderOrList = cirq.QubitOrder.DEFAULT,
513447
initial_state: Optional[Union[int, np.ndarray]] = None,
514-
) -> List["SimulationTrialResult"]:
448+
) -> Iterator[cirq.StateVectorTrialResult]:
515449
"""Simulates the supplied Circuit.
516450
517451
This method returns a result which allows access to the entire
@@ -572,7 +506,6 @@ def simulate_sweep(
572506
f"Expected: {2**num_qubits * 2} Received: {len(input_vector)}"
573507
)
574508

575-
trials_results = []
576509
if _needs_trajectories(program):
577510
translator_fn_name = "translate_cirq_to_qtrajectory"
578511
fullstate_simulator_fn = self._sim_module.qtrajectory_simulate_fullstate
@@ -589,33 +522,32 @@ def simulate_sweep(
589522
cirq_order,
590523
)
591524
options["s"] = self.get_seed()
592-
qubit_map = {qubit: index for index, qubit in enumerate(qsim_order)}
593525

594526
if isinstance(initial_state, int):
595527
qsim_state = fullstate_simulator_fn(options, initial_state)
596528
elif isinstance(initial_state, np.ndarray):
597529
qsim_state = fullstate_simulator_fn(options, input_vector)
598530
assert qsim_state.dtype == np.float32
599531
assert qsim_state.ndim == 1
600-
final_state = QSimSimulatorState(qsim_state, qubit_map)
532+
533+
final_state = cirq.StateVectorSimulationState(
534+
initial_state=qsim_state.view(np.complex64), qubits=cirq_order
535+
)
601536
# create result for this parameter
602537
# TODO: We need to support measurements.
603-
result = QSimSimulatorTrialResult(
538+
yield cirq.StateVectorTrialResult(
604539
params=prs, measurements={}, final_simulator_state=final_state
605540
)
606-
trials_results.append(result)
607541

608-
return trials_results
609-
610-
def simulate_expectation_values_sweep(
542+
def simulate_expectation_values_sweep_iter(
611543
self,
612544
program: cirq.Circuit,
613545
observables: Union[cirq.PauliSumLike, List[cirq.PauliSumLike]],
614546
params: cirq.Sweepable,
615547
qubit_order: cirq.QubitOrderOrList = cirq.QubitOrder.DEFAULT,
616548
initial_state: Any = None,
617549
permit_terminal_measurements: bool = False,
618-
) -> List[List[float]]:
550+
) -> Iterator[List[float]]:
619551
"""Simulates the supplied circuit and calculates exact expectation
620552
values for the given observables on its final state.
621553
@@ -638,8 +570,8 @@ def simulate_expectation_values_sweep(
638570
is set to True. This is meant to prevent measurements from
639571
ruining expectation value calculations.
640572
641-
Returns:
642-
A list of expectation values, with the value at index `n`
573+
Yields:
574+
Lists of expectation values, with the value at index `n`
643575
corresponding to `observables[n]` from the input.
644576
645577
Raises:
@@ -703,7 +635,6 @@ def simulate_expectation_values_sweep(
703635
f"Expected: {2**num_qubits * 2} Received: {len(input_vector)}"
704636
)
705637

706-
results = []
707638
if _needs_trajectories(program):
708639
translator_fn_name = "translate_cirq_to_qtrajectory"
709640
ev_simulator_fn = self._sim_module.qtrajectory_simulate_expectation_values
@@ -724,9 +655,7 @@ def simulate_expectation_values_sweep(
724655
evs = ev_simulator_fn(options, opsums_and_qubit_counts, initial_state)
725656
elif isinstance(initial_state, np.ndarray):
726657
evs = ev_simulator_fn(options, opsums_and_qubit_counts, input_vector)
727-
results.append(evs)
728-
729-
return results
658+
yield evs
730659

731660
def simulate_moment_expectation_values(
732661
self,
@@ -870,20 +799,13 @@ def _translate_circuit(
870799
translator_fn_name: str,
871800
qubit_order: cirq.QubitOrderOrList,
872801
):
873-
# If the circuit is memoized, reuse the corresponding translated
874-
# circuit.
875-
translated_circuit = None
876-
for original, translated, m_indices in self._translated_circuits:
802+
# If the circuit is memoized, reuse the corresponding translated circuit.
803+
for original, translated, moment_indices in self._translated_circuits:
877804
if original == circuit:
878-
translated_circuit = translated
879-
moment_indices = m_indices
880-
break
881-
882-
if translated_circuit is None:
883-
translator_fn = getattr(circuit, translator_fn_name)
884-
translated_circuit, moment_indices = translator_fn(qubit_order)
885-
self._translated_circuits.append(
886-
(circuit, translated_circuit, moment_indices)
887-
)
805+
return translated, moment_indices
806+
807+
translator_fn = getattr(circuit, translator_fn_name)
808+
translated, moment_indices = translator_fn(qubit_order)
809+
self._translated_circuits.append((circuit, translated, moment_indices))
888810

889-
return translated_circuit, moment_indices
811+
return translated, moment_indices

qsimcirq_tests/qsimcirq_test.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -839,9 +839,10 @@ def test_control_values():
839839
cirq.X(qubits[2]).controlled_by(*qubits[:2], control_values=[1, 2]),
840840
)
841841
qsimSim = qsimcirq.QSimSimulator()
842-
with pytest.warns(RuntimeWarning, match="Gate has no valid control value"):
843-
result = qsimSim.simulate(cirq_circuit, qubit_order=qubits)
844-
assert result.state_vector()[0] == 1
842+
with pytest.raises(
843+
ValueError, match="Cannot translate control values other than 0 and 1"
844+
):
845+
_ = qsimSim.simulate(cirq_circuit, qubit_order=qubits)
845846

846847

847848
def test_control_limits():
@@ -1659,7 +1660,7 @@ def test_cirq_qsim_all_supported_gates():
16591660
qsim_result = qsim_simulator.simulate(circuit)
16601661

16611662
assert cirq.linalg.allclose_up_to_global_phase(
1662-
qsim_result.state_vector(), cirq_result.state_vector()
1663+
qsim_result.state_vector(), cirq_result.state_vector(), atol=1e-5
16631664
)
16641665

16651666

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
absl-py
2-
cirq-core
2+
cirq-core~=1.0
33
numpy~=1.21
44
pybind11
55
typing_extensions

0 commit comments

Comments
 (0)