@@ -96,8 +96,12 @@ def _value_equality_values_(self):
96
96
class OpDeserializerTest (tf .test .TestCase , parameterized .TestCase ):
97
97
"""Test OpDeserializer functionality."""
98
98
99
- @parameterized .parameters (TEST_CASES )
100
- def test_from_proto (self , val_type , val , arg_value ):
99
+ @parameterized .parameters ([
100
+ CASE + (x ,)
101
+ for CASE in TEST_CASES
102
+ for x in [cirq .GridQubit (1 , 2 ), cirq .LineQubit (4 )]
103
+ ])
104
+ def test_from_proto (self , val_type , val , arg_value , q ):
101
105
"""Test from proto under many cases."""
102
106
deserializer = op_deserializer .GateOpDeserializer (
103
107
serialized_gate_id = 'my_gate' ,
@@ -116,10 +120,9 @@ def test_from_proto(self, val_type, val, arg_value):
116
120
'my_val' : arg_value
117
121
},
118
122
'qubits' : [{
119
- 'id' : '1_2'
123
+ 'id' : '1_2' if isinstance ( q , cirq . GridQubit ) else '4'
120
124
}]
121
125
})
122
- q = cirq .GridQubit (1 , 2 )
123
126
result = deserializer .from_proto (serialized ,
124
127
arg_function_language = 'linear' )
125
128
self .assertEqual (result , GateWithAttribute (val )(q ))
@@ -260,7 +263,8 @@ def test_from_proto_function_argument_not_set(self):
260
263
_ = deserializer .from_proto (serialized ,
261
264
arg_function_language = 'linear' )
262
265
263
- def test_from_proto_value_func (self ):
266
+ @parameterized .parameters ([cirq .GridQubit (1 , 2 ), cirq .LineQubit (4 )])
267
+ def test_from_proto_value_func (self , q ):
264
268
"""Test value func deserialization in simple case."""
265
269
deserializer = op_deserializer .GateOpDeserializer (
266
270
serialized_gate_id = 'my_gate' ,
@@ -282,14 +286,14 @@ def test_from_proto_value_func(self):
282
286
}
283
287
},
284
288
'qubits' : [{
285
- 'id' : '1_2'
289
+ 'id' : '1_2' if isinstance ( q , cirq . GridQubit ) else '4'
286
290
}]
287
291
})
288
- q = cirq .GridQubit (1 , 2 )
289
292
result = deserializer .from_proto (serialized )
290
293
self .assertEqual (result , GateWithAttribute (1.125 )(q ))
291
294
292
- def test_from_proto_not_required_ok (self ):
295
+ @parameterized .parameters ([cirq .GridQubit (1 , 2 ), cirq .LineQubit (4 )])
296
+ def test_from_proto_not_required_ok (self , q ):
293
297
"""Deserialization succeeds for missing not required fields."""
294
298
deserializer = op_deserializer .GateOpDeserializer (
295
299
serialized_gate_id = 'my_gate' ,
@@ -315,13 +319,13 @@ def test_from_proto_not_required_ok(self):
315
319
}
316
320
},
317
321
'qubits' : [{
318
- 'id' : '1_2'
322
+ 'id' : '1_2' if isinstance ( q , cirq . GridQubit ) else '4'
319
323
}]
320
324
})
321
- q = cirq .GridQubit (1 , 2 )
322
325
result = deserializer .from_proto (serialized )
323
326
self .assertEqual (result , GateWithAttribute (0.125 )(q ))
324
327
328
+
325
329
def test_from_proto_missing_required_arg (self ):
326
330
"""Error raised when required field is missing."""
327
331
deserializer = op_deserializer .GateOpDeserializer (
@@ -382,7 +386,8 @@ def test_from_proto_required_arg_not_assigned(self):
382
386
with self .assertRaises (ValueError ):
383
387
deserializer .from_proto (serialized )
384
388
385
- def test_defaults (self ):
389
+ @parameterized .parameters ([cirq .GridQubit (1 , 2 ), cirq .LineQubit (4 )])
390
+ def test_defaults (self , q ):
386
391
"""Ensure default values still deserialize."""
387
392
deserializer = op_deserializer .GateOpDeserializer (
388
393
serialized_gate_id = 'my_gate' ,
@@ -402,13 +407,12 @@ def test_defaults(self):
402
407
},
403
408
'args' : {},
404
409
'qubits' : [{
405
- 'id' : '1_2'
410
+ 'id' : '1_2' if isinstance ( q , cirq . GridQubit ) else '4'
406
411
}]
407
412
})
408
413
g = GateWithAttribute (1.0 )
409
414
g .not_req = 'hello'
410
- self .assertEqual (deserializer .from_proto (serialized ),
411
- g (cirq .GridQubit (1 , 2 )))
415
+ self .assertEqual (deserializer .from_proto (serialized ), g (q ))
412
416
413
417
414
418
if __name__ == "__main__" :
0 commit comments