1313# limitations under the License.
1414
1515import itertools
16- from typing import (
17- Any ,
18- Dict ,
19- Iterable ,
20- List ,
21- Mapping ,
22- Optional ,
23- Sequence ,
24- Tuple ,
25- TYPE_CHECKING ,
26- Union ,
27- )
16+ from collections import defaultdict
17+ from typing import Any , Dict , Iterable , List , Optional , Sequence , Tuple , TYPE_CHECKING , Union
2818
2919import numpy as np
3020
@@ -43,30 +33,32 @@ class _MeasurementQid(ops.Qid):
4333 Exactly one qubit will be created per qubit in the measurement gate.
4434 """
4535
46- def __init__ (self , key : Union [str , 'cirq.MeasurementKey' ], qid : 'cirq.Qid' ):
36+ def __init__ (self , key : Union [str , 'cirq.MeasurementKey' ], qid : 'cirq.Qid' , index : int = 0 ):
4737 """Initializes the qubit.
4838
4939 Args:
5040 key: The key of the measurement gate being deferred.
5141 qid: One qubit that is being measured. Each deferred measurement
5242 should create one new _MeasurementQid per qubit being measured
5343 by that gate.
44+ index: For repeated measurement keys, this represents the index of that measurement.
5445 """
5546 self ._key = value .MeasurementKey .parse_serialized (key ) if isinstance (key , str ) else key
5647 self ._qid = qid
48+ self ._index = index
5749
5850 @property
5951 def dimension (self ) -> int :
6052 return self ._qid .dimension
6153
6254 def _comparison_key (self ) -> Any :
63- return str (self ._key ), self ._qid ._comparison_key ()
55+ return str (self ._key ), self ._index , self . _qid ._comparison_key ()
6456
6557 def __str__ (self ) -> str :
66- return f"M('{ self ._key } ', q={ self ._qid } )"
58+ return f"M('{ self ._key } [ { self . _index } ] ', q={ self ._qid } )"
6759
6860 def __repr__ (self ) -> str :
69- return f'_MeasurementQid({ self ._key !r} , { self ._qid !r} )'
61+ return f'_MeasurementQid({ self ._key !r} , { self ._qid !r} , { self . _index } )'
7062
7163
7264@transformer_api .transformer
@@ -102,16 +94,18 @@ def defer_measurements(
10294
10395 circuit = transformer_primitives .unroll_circuit_op (circuit , deep = True , tags_to_check = None )
10496 terminal_measurements = {op for _ , op in find_terminal_measurements (circuit )}
105- measurement_qubits : Dict ['cirq.MeasurementKey' , List ['_MeasurementQid' ]] = {}
97+ measurement_qubits : Dict ['cirq.MeasurementKey' , List [Tuple ['cirq.Qid' , ...]]] = defaultdict (
98+ list
99+ )
106100
107101 def defer (op : 'cirq.Operation' , _ ) -> 'cirq.OP_TREE' :
108102 if op in terminal_measurements :
109103 return op
110104 gate = op .gate
111105 if isinstance (gate , ops .MeasurementGate ):
112106 key = value .MeasurementKey .parse_serialized (gate .key )
113- targets = [_MeasurementQid (key , q ) for q in op .qubits ]
114- measurement_qubits [key ] = targets
107+ targets = [_MeasurementQid (key , q , len ( measurement_qubits [ key ]) ) for q in op .qubits ]
108+ measurement_qubits [key ]. append ( tuple ( targets ))
115109 cxs = [_mod_add (q , target ) for q , target in zip (op .qubits , targets )]
116110 confusions = [
117111 _ConfusionChannel (m , [op .qubits [i ].dimension for i in indexes ]).on (
@@ -125,10 +119,24 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
125119 return [defer (op , None ) for op in protocols .decompose_once (op )]
126120 elif op .classical_controls :
127121 # Convert to a quantum control
128- keys = sorted (set (key for c in op .classical_controls for key in c .keys ))
129- for key in keys :
122+
123+ # First create a sorted set of the indexed keys for this control.
124+ keys = sorted (
125+ set (
126+ indexed_key
127+ for condition in op .classical_controls
128+ for indexed_key in (
129+ [(condition .key , condition .index )]
130+ if isinstance (condition , value .KeyCondition )
131+ else [(k , - 1 ) for k in condition .keys ]
132+ )
133+ )
134+ )
135+ for key , index in keys :
130136 if key not in measurement_qubits :
131137 raise ValueError (f'Deferred measurement for key={ key } not found.' )
138+ if index >= len (measurement_qubits [key ]) or index < - len (measurement_qubits [key ]):
139+ raise ValueError (f'Invalid index for { key } ' )
132140
133141 # Try every possible datastore state (exponential in the number of keys) against the
134142 # condition, and the ones that work are the control values for the new op.
@@ -140,12 +148,11 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
140148
141149 # Rearrange these into the format expected by SumOfProducts
142150 products = [
143- [i for key in keys for i in store .records [key ][ 0 ]]
151+ [val for k , i in keys for val in store .records [k ][ i ]]
144152 for store in compatible_datastores
145153 ]
146-
147154 control_values = ops .SumOfProducts (products )
148- qs = [q for key in keys for q in measurement_qubits [key ]]
155+ qs = [q for k , i in keys for q in measurement_qubits [k ][ i ]]
149156 return op .without_classical_controls ().controlled_by (* qs , control_values = control_values )
150157 return op
151158
@@ -155,14 +162,15 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
155162 tags_to_ignore = context .tags_to_ignore if context else (),
156163 raise_if_add_qubits = False ,
157164 ).unfreeze ()
158- for k , qubits in measurement_qubits .items ():
159- circuit .append (ops .measure (* qubits , key = k ))
165+ for k , qubits_list in measurement_qubits .items ():
166+ for qubits in qubits_list :
167+ circuit .append (ops .measure (* qubits , key = k ))
160168 return circuit
161169
162170
163171def _all_possible_datastore_states (
164- keys : Iterable ['cirq.MeasurementKey' ],
165- measurement_qubits : Mapping ['cirq.MeasurementKey' , Iterable [ 'cirq.Qid' ]],
172+ keys : Iterable [Tuple [ 'cirq.MeasurementKey' , int ] ],
173+ measurement_qubits : Dict ['cirq.MeasurementKey' , List [ Tuple [ 'cirq.Qid' , ...] ]],
166174) -> Iterable ['cirq.ClassicalDataStoreReader' ]:
167175 """The cartesian product of all possible DataStore states for the given keys."""
168176 # First we get the list of all possible values. So if we have a key mapped to qubits of shape
@@ -179,17 +187,28 @@ def _all_possible_datastore_states(
179187 # ((1, 1), (0,)),
180188 # ((1, 1), (1,)),
181189 # ((1, 1), (2,))]
182- all_values = itertools .product (
190+ all_possible_measurements = itertools .product (
183191 * [
184- tuple (itertools .product (* [range (q .dimension ) for q in measurement_qubits [k ]]))
185- for k in keys
192+ tuple (itertools .product (* [range (q .dimension ) for q in measurement_qubits [k ][ i ] ]))
193+ for k , i in keys
186194 ]
187195 )
188- # Then we create the ClassicalDataDictionaryStore for each of the above.
189- for sequences in all_values :
190- lookup = {k : [sequence ] for k , sequence in zip (keys , sequences )}
196+ # Then we create the ClassicalDataDictionaryStore for each of the above. A `measurement_list`
197+ # is a single row of the above example, and can be zipped with `keys`.
198+ for measurement_list in all_possible_measurements :
199+ # Initialize a set of measurement records for this iteration. This will have the same shape
200+ # as `measurement_qubits` but zeros for all measurements.
201+ records = {
202+ key : [(0 ,) * len (qubits ) for qubits in qubits_list ]
203+ for key , qubits_list in measurement_qubits .items ()
204+ }
205+ # Set the measurement values from the current row of the above, for each key/index we care
206+ # about.
207+ for (k , i ), measurement in zip (keys , measurement_list ):
208+ records [k ][i ] = measurement
209+ # Finally yield this sample to the consumer.
191210 yield value .ClassicalDataDictionaryStore (
192- _records = lookup , _measured_qubits = { k : [ tuple ( measurement_qubits [ k ])] for k in keys }
211+ _records = records , _measured_qubits = measurement_qubits
193212 )
194213
195214
0 commit comments