23
23
from google .protobuf import json_format
24
24
from tensorflow_quantum .core .proto import program_pb2
25
25
from tensorflow_quantum .core .serialize import op_serializer
26
+ from tensorflow_quantum .python import util
26
27
27
28
28
29
def op_proto (json ):
@@ -142,8 +143,12 @@ def get_val(op):
142
143
class OpSerializerTest (tf .test .TestCase , parameterized .TestCase ):
143
144
"""Test OpSerializer functions correctly."""
144
145
145
- @parameterized .parameters (TEST_CASES )
146
- def test_to_proto_attribute (self , val_type , val , arg_value ):
146
+ @parameterized .parameters ([
147
+ CASE + (x ,)
148
+ for CASE in TEST_CASES
149
+ for x in [cirq .GridQubit (1 , 2 ), cirq .LineQubit (4 )]
150
+ ])
151
+ def test_to_proto_attribute (self , val_type , val , arg_value , q ):
147
152
"""Test proto attribute serialization works."""
148
153
serializer = op_serializer .GateOpSerializer (
149
154
gate_type = GateWithAttribute ,
@@ -153,7 +158,6 @@ def test_to_proto_attribute(self, val_type, val, arg_value):
153
158
serialized_type = val_type ,
154
159
op_getter = 'val' )
155
160
])
156
- q = cirq .GridQubit (1 , 2 )
157
161
result = serializer .to_proto (GateWithAttribute (val )(q ),
158
162
arg_function_language = 'linear' )
159
163
expected = op_proto ({
@@ -164,13 +168,17 @@ def test_to_proto_attribute(self, val_type, val, arg_value):
164
168
'my_val' : arg_value
165
169
},
166
170
'qubits' : [{
167
- 'id' : '1_2'
171
+ 'id' : '1_2' if isinstance ( q , cirq . GridQubit ) else '4'
168
172
}]
169
173
})
170
174
self .assertEqual (result , expected )
171
175
172
- @parameterized .parameters (TEST_CASES )
173
- def test_to_proto_property (self , val_type , val , arg_value ):
176
+ @parameterized .parameters ([
177
+ CASE + (x ,)
178
+ for CASE in TEST_CASES
179
+ for x in [cirq .GridQubit (1 , 2 ), cirq .LineQubit (4 )]
180
+ ])
181
+ def test_to_proto_property (self , val_type , val , arg_value , q ):
174
182
"""Test proto property serialization works."""
175
183
serializer = op_serializer .GateOpSerializer (
176
184
gate_type = GateWithProperty ,
@@ -180,7 +188,6 @@ def test_to_proto_property(self, val_type, val, arg_value):
180
188
serialized_type = val_type ,
181
189
op_getter = 'val' )
182
190
])
183
- q = cirq .GridQubit (1 , 2 )
184
191
result = serializer .to_proto (GateWithProperty (val )(q ),
185
192
arg_function_language = 'linear' )
186
193
expected = op_proto ({
@@ -191,13 +198,17 @@ def test_to_proto_property(self, val_type, val, arg_value):
191
198
'my_val' : arg_value
192
199
},
193
200
'qubits' : [{
194
- 'id' : '1_2'
201
+ 'id' : '1_2' if isinstance ( q , cirq . GridQubit ) else '4'
195
202
}]
196
203
})
197
204
self .assertEqual (result , expected )
198
205
199
- @parameterized .parameters (TEST_CASES )
200
- def test_to_proto_callable (self , val_type , val , arg_value ):
206
+ @parameterized .parameters ([
207
+ CASE + (x ,)
208
+ for CASE in TEST_CASES
209
+ for x in [cirq .GridQubit (1 , 2 ), cirq .LineQubit (4 )]
210
+ ])
211
+ def test_to_proto_callable (self , val_type , val , arg_value , q ):
201
212
"""Test callable serialization works."""
202
213
serializer = op_serializer .GateOpSerializer (
203
214
gate_type = GateWithMethod ,
@@ -207,7 +218,6 @@ def test_to_proto_callable(self, val_type, val, arg_value):
207
218
serialized_type = val_type ,
208
219
op_getter = get_val )
209
220
])
210
- q = cirq .GridQubit (1 , 2 )
211
221
result = serializer .to_proto (GateWithMethod (val )(q ),
212
222
arg_function_language = 'linear' )
213
223
expected = op_proto ({
@@ -218,12 +228,13 @@ def test_to_proto_callable(self, val_type, val, arg_value):
218
228
'my_val' : arg_value
219
229
},
220
230
'qubits' : [{
221
- 'id' : '1_2'
231
+ 'id' : '1_2' if isinstance ( q , cirq . GridQubit ) else '4'
222
232
}]
223
233
})
224
234
self .assertEqual (result , expected )
225
235
226
- def test_to_proto_gate_predicate (self ):
236
+ @parameterized .parameters ([cirq .GridQubit (1 , 2 ), cirq .LineQubit (4 )])
237
+ def test_to_proto_gate_predicate (self , q ):
227
238
"""Test can_serialize works."""
228
239
serializer = op_serializer .GateOpSerializer (
229
240
gate_type = GateWithAttribute ,
@@ -234,15 +245,15 @@ def test_to_proto_gate_predicate(self):
234
245
op_getter = 'val' )
235
246
],
236
247
can_serialize_predicate = lambda x : x .gate .val == 1 )
237
- q = cirq .GridQubit (1 , 2 )
238
248
self .assertIsNone (serializer .to_proto (GateWithAttribute (0 )(q )))
239
249
self .assertIsNotNone (serializer .to_proto (GateWithAttribute (1 )(q )))
240
250
self .assertFalse (
241
251
serializer .can_serialize_operation (GateWithAttribute (0 )(q )))
242
252
self .assertTrue (
243
253
serializer .can_serialize_operation (GateWithAttribute (1 )(q )))
244
254
245
- def test_to_proto_gate_mismatch (self ):
255
+ @parameterized .parameters ([cirq .GridQubit (1 , 2 ), cirq .LineQubit (4 )])
256
+ def test_to_proto_gate_mismatch (self , q ):
246
257
"""Test proto gate mismatch errors."""
247
258
serializer = op_serializer .GateOpSerializer (
248
259
gate_type = GateWithProperty ,
@@ -252,13 +263,13 @@ def test_to_proto_gate_mismatch(self):
252
263
serialized_type = float ,
253
264
op_getter = 'val' )
254
265
])
255
- q = cirq .GridQubit (1 , 2 )
256
266
with self .assertRaisesRegex (
257
267
ValueError ,
258
268
expected_regex = 'GateWithAttribute.*GateWithProperty' ):
259
269
serializer .to_proto (GateWithAttribute (1.0 )(q ))
260
270
261
- def test_to_proto_unsupported_type (self ):
271
+ @parameterized .parameters ([cirq .GridQubit (1 , 2 ), cirq .LineQubit (4 )])
272
+ def test_to_proto_unsupported_type (self , q ):
262
273
"""Test proto unsupported types errors."""
263
274
serializer = op_serializer .GateOpSerializer (
264
275
gate_type = GateWithProperty ,
@@ -268,11 +279,11 @@ def test_to_proto_unsupported_type(self):
268
279
serialized_type = bytes ,
269
280
op_getter = 'val' )
270
281
])
271
- q = cirq .GridQubit (1 , 2 )
272
282
with self .assertRaisesRegex (ValueError , expected_regex = 'bytes' ):
273
283
serializer .to_proto (GateWithProperty (b's' )(q ))
274
284
275
- def test_to_proto_required_but_not_present (self ):
285
+ @parameterized .parameters ([cirq .GridQubit (1 , 2 ), cirq .LineQubit (4 )])
286
+ def test_to_proto_required_but_not_present (self , q ):
276
287
"""Test required and missing args errors."""
277
288
serializer = op_serializer .GateOpSerializer (
278
289
gate_type = GateWithProperty ,
@@ -282,11 +293,11 @@ def test_to_proto_required_but_not_present(self):
282
293
serialized_type = float ,
283
294
op_getter = lambda x : None )
284
295
])
285
- q = cirq .GridQubit (1 , 2 )
286
296
with self .assertRaisesRegex (ValueError , expected_regex = 'required' ):
287
297
serializer .to_proto (GateWithProperty (1.0 )(q ))
288
298
289
- def test_to_proto_no_getattr (self ):
299
+ @parameterized .parameters ([cirq .GridQubit (1 , 2 ), cirq .LineQubit (4 )])
300
+ def test_to_proto_no_getattr (self , q ):
290
301
"""Test no op getter fails."""
291
302
serializer = op_serializer .GateOpSerializer (
292
303
gate_type = GateWithProperty ,
@@ -296,11 +307,11 @@ def test_to_proto_no_getattr(self):
296
307
serialized_type = float ,
297
308
op_getter = 'nope' )
298
309
])
299
- q = cirq .GridQubit (1 , 2 )
300
310
with self .assertRaisesRegex (ValueError , expected_regex = 'does not have' ):
301
311
serializer .to_proto (GateWithProperty (1.0 )(q ))
302
312
303
- def test_to_proto_not_required_ok (self ):
313
+ @parameterized .parameters ([cirq .GridQubit (1 , 2 ), cirq .LineQubit (4 )])
314
+ def test_to_proto_not_required_ok (self , q ):
304
315
"""Test non require arg absense succeeds."""
305
316
serializer = op_serializer .GateOpSerializer (
306
317
gate_type = GateWithProperty ,
@@ -326,15 +337,19 @@ def test_to_proto_not_required_ok(self):
326
337
}
327
338
},
328
339
'qubits' : [{
329
- 'id' : '1_2'
340
+ 'id' : '1_2' if isinstance ( q , cirq . GridQubit ) else '4'
330
341
}]
331
342
})
332
343
333
- q = cirq .GridQubit (1 , 2 )
334
344
self .assertEqual (serializer .to_proto (GateWithProperty (0.125 )(q )),
335
345
expected )
336
346
337
347
@parameterized .parameters ([{
348
+ ** x ,
349
+ ** {
350
+ 'q' : q
351
+ }
352
+ } for x in [{
338
353
'val_type' : float ,
339
354
'val' : 's'
340
355
}, {
@@ -352,8 +367,8 @@ def test_to_proto_not_required_ok(self):
352
367
}, {
353
368
'val_type' : List [bool ],
354
369
'val' : (1.0 ,)
355
- }])
356
- def test_to_proto_type_mismatch (self , val_type , val ):
370
+ }] for q in [ cirq . GridQubit ( 1 , 2 ), cirq . LineQubit ( 4 )]] )
371
+ def test_to_proto_type_mismatch (self , val_type , val , q ):
357
372
"""Test type mismatch fails."""
358
373
serializer = op_serializer .GateOpSerializer (
359
374
gate_type = GateWithProperty ,
@@ -363,11 +378,11 @@ def test_to_proto_type_mismatch(self, val_type, val):
363
378
serialized_type = val_type ,
364
379
op_getter = 'val' )
365
380
])
366
- q = cirq .GridQubit (1 , 2 )
367
381
with self .assertRaisesRegex (ValueError , expected_regex = str (type (val ))):
368
382
serializer .to_proto (GateWithProperty (val )(q ))
369
383
370
- def test_can_serialize_operation_subclass (self ):
384
+ @parameterized .parameters ([cirq .GridQubit (1 , 2 ), cirq .LineQubit (4 )])
385
+ def test_can_serialize_operation_subclass (self , q ):
371
386
"""Test can serialize subclass."""
372
387
serializer = op_serializer .GateOpSerializer (
373
388
gate_type = GateWithAttribute ,
@@ -378,11 +393,11 @@ def test_can_serialize_operation_subclass(self):
378
393
op_getter = 'val' )
379
394
],
380
395
can_serialize_predicate = lambda x : x .gate .val == 1 )
381
- q = cirq .GridQubit (1 , 1 )
382
396
self .assertTrue (serializer .can_serialize_operation (SubclassGate (1 )(q )))
383
397
self .assertFalse (serializer .can_serialize_operation (SubclassGate (0 )(q )))
384
398
385
- def test_defaults_not_serialized (self ):
399
+ @parameterized .parameters ([cirq .GridQubit (1 , 2 ), cirq .LineQubit (4 )])
400
+ def test_defaults_not_serialized (self , q ):
386
401
"""Test defaults not serialized."""
387
402
serializer = op_serializer .GateOpSerializer (
388
403
gate_type = GateWithAttribute ,
@@ -393,7 +408,6 @@ def test_defaults_not_serialized(self):
393
408
default = 1.0 ,
394
409
op_getter = 'val' )
395
410
])
396
- q = cirq .GridQubit (1 , 2 )
397
411
no_default = op_proto ({
398
412
'gate' : {
399
413
'id' : 'my_gate'
@@ -406,7 +420,7 @@ def test_defaults_not_serialized(self):
406
420
}
407
421
},
408
422
'qubits' : [{
409
- 'id' : '1_2'
423
+ 'id' : '1_2' if isinstance ( q , cirq . GridQubit ) else '4'
410
424
}]
411
425
})
412
426
self .assertEqual (no_default ,
@@ -416,7 +430,7 @@ def test_defaults_not_serialized(self):
416
430
'id' : 'my_gate'
417
431
},
418
432
'qubits' : [{
419
- 'id' : '1_2'
433
+ 'id' : '1_2' if isinstance ( q , cirq . GridQubit ) else '4'
420
434
}]
421
435
})
422
436
self .assertEqual (with_default ,
0 commit comments