Skip to content

Commit faee5f9

Browse files
Add support for cirq.LineQubit to op_serializer.py
1 parent 534f65d commit faee5f9

File tree

3 files changed

+57
-37
lines changed

3 files changed

+57
-37
lines changed

tensorflow_quantum/core/serialize/BUILD

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ py_test(
2727
srcs_version = "PY3",
2828
deps = [
2929
":op_serializer",
30-
"//tensorflow_quantum/core/proto:program_py_proto"
30+
"//tensorflow_quantum/core/proto:program_py_proto",
31+
"//tensorflow_quantum/python:util",
3132
],
3233
)
3334

tensorflow_quantum/core/serialize/op_serializer.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""op_serializer.py adapated from Cirq release 0.9.0"""
1515

1616
from typing import List
17+
import cirq
1718
import sympy
1819
import numpy as np
1920
from tensorflow_quantum.core.proto import program_pb2
@@ -31,7 +32,11 @@
3132

3233
def qubit_to_proto(qubit):
3334
"""Return proto representation of a GridQubit."""
34-
return '{}_{}'.format(qubit.row, qubit.col)
35+
if isinstance(qubit, cirq.GridQubit):
36+
return '{}_{}'.format(qubit.row, qubit.col)
37+
if isinstance(qubit, cirq.LineQubit):
38+
return '{}'.format(qubit.x)
39+
raise ValueError('Unsupported qubit type:' + str(type(qubit)))
3540

3641

3742
def _arg_to_proto(value, *, arg_function_language, out=None):
@@ -193,7 +198,7 @@ def to_proto(
193198

194199
msg.gate.id = self.serialized_gate_id
195200
for qubit in op.qubits:
196-
msg.qubits.add().id = '{}_{}'.format(qubit.row, qubit.col)
201+
msg.qubits.add().id = qubit_to_proto(qubit)
197202
for arg in self.args:
198203
value = self._value_from_gate(op, arg)
199204
if value is not None and (not arg.default or value != arg.default):

tensorflow_quantum/core/serialize/op_serializer_test.py

Lines changed: 48 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from google.protobuf import json_format
2424
from tensorflow_quantum.core.proto import program_pb2
2525
from tensorflow_quantum.core.serialize import op_serializer
26+
from tensorflow_quantum.python import util
2627

2728

2829
def op_proto(json):
@@ -142,8 +143,12 @@ def get_val(op):
142143
class OpSerializerTest(tf.test.TestCase, parameterized.TestCase):
143144
"""Test OpSerializer functions correctly."""
144145

145-
@parameterized.parameters(TEST_CASES)
146-
def test_to_proto_attribute(self, val_type, val, arg_value):
146+
@parameterized.parameters([
147+
CASE + (x,)
148+
for CASE in TEST_CASES
149+
for x in [cirq.GridQubit(1, 2), cirq.LineQubit(4)]
150+
])
151+
def test_to_proto_attribute(self, val_type, val, arg_value, q):
147152
"""Test proto attribute serialization works."""
148153
serializer = op_serializer.GateOpSerializer(
149154
gate_type=GateWithAttribute,
@@ -153,7 +158,6 @@ def test_to_proto_attribute(self, val_type, val, arg_value):
153158
serialized_type=val_type,
154159
op_getter='val')
155160
])
156-
q = cirq.GridQubit(1, 2)
157161
result = serializer.to_proto(GateWithAttribute(val)(q),
158162
arg_function_language='linear')
159163
expected = op_proto({
@@ -164,13 +168,17 @@ def test_to_proto_attribute(self, val_type, val, arg_value):
164168
'my_val': arg_value
165169
},
166170
'qubits': [{
167-
'id': '1_2'
171+
'id': '1_2' if isinstance(q, cirq.GridQubit) else '4'
168172
}]
169173
})
170174
self.assertEqual(result, expected)
171175

172-
@parameterized.parameters(TEST_CASES)
173-
def test_to_proto_property(self, val_type, val, arg_value):
176+
@parameterized.parameters([
177+
CASE + (x,)
178+
for CASE in TEST_CASES
179+
for x in [cirq.GridQubit(1, 2), cirq.LineQubit(4)]
180+
])
181+
def test_to_proto_property(self, val_type, val, arg_value, q):
174182
"""Test proto property serialization works."""
175183
serializer = op_serializer.GateOpSerializer(
176184
gate_type=GateWithProperty,
@@ -180,7 +188,6 @@ def test_to_proto_property(self, val_type, val, arg_value):
180188
serialized_type=val_type,
181189
op_getter='val')
182190
])
183-
q = cirq.GridQubit(1, 2)
184191
result = serializer.to_proto(GateWithProperty(val)(q),
185192
arg_function_language='linear')
186193
expected = op_proto({
@@ -191,13 +198,17 @@ def test_to_proto_property(self, val_type, val, arg_value):
191198
'my_val': arg_value
192199
},
193200
'qubits': [{
194-
'id': '1_2'
201+
'id': '1_2' if isinstance(q, cirq.GridQubit) else '4'
195202
}]
196203
})
197204
self.assertEqual(result, expected)
198205

199-
@parameterized.parameters(TEST_CASES)
200-
def test_to_proto_callable(self, val_type, val, arg_value):
206+
@parameterized.parameters([
207+
CASE + (x,)
208+
for CASE in TEST_CASES
209+
for x in [cirq.GridQubit(1, 2), cirq.LineQubit(4)]
210+
])
211+
def test_to_proto_callable(self, val_type, val, arg_value, q):
201212
"""Test callable serialization works."""
202213
serializer = op_serializer.GateOpSerializer(
203214
gate_type=GateWithMethod,
@@ -207,7 +218,6 @@ def test_to_proto_callable(self, val_type, val, arg_value):
207218
serialized_type=val_type,
208219
op_getter=get_val)
209220
])
210-
q = cirq.GridQubit(1, 2)
211221
result = serializer.to_proto(GateWithMethod(val)(q),
212222
arg_function_language='linear')
213223
expected = op_proto({
@@ -218,12 +228,13 @@ def test_to_proto_callable(self, val_type, val, arg_value):
218228
'my_val': arg_value
219229
},
220230
'qubits': [{
221-
'id': '1_2'
231+
'id': '1_2' if isinstance(q, cirq.GridQubit) else '4'
222232
}]
223233
})
224234
self.assertEqual(result, expected)
225235

226-
def test_to_proto_gate_predicate(self):
236+
@parameterized.parameters([cirq.GridQubit(1, 2), cirq.LineQubit(4)])
237+
def test_to_proto_gate_predicate(self, q):
227238
"""Test can_serialize works."""
228239
serializer = op_serializer.GateOpSerializer(
229240
gate_type=GateWithAttribute,
@@ -234,15 +245,15 @@ def test_to_proto_gate_predicate(self):
234245
op_getter='val')
235246
],
236247
can_serialize_predicate=lambda x: x.gate.val == 1)
237-
q = cirq.GridQubit(1, 2)
238248
self.assertIsNone(serializer.to_proto(GateWithAttribute(0)(q)))
239249
self.assertIsNotNone(serializer.to_proto(GateWithAttribute(1)(q)))
240250
self.assertFalse(
241251
serializer.can_serialize_operation(GateWithAttribute(0)(q)))
242252
self.assertTrue(
243253
serializer.can_serialize_operation(GateWithAttribute(1)(q)))
244254

245-
def test_to_proto_gate_mismatch(self):
255+
@parameterized.parameters([cirq.GridQubit(1, 2), cirq.LineQubit(4)])
256+
def test_to_proto_gate_mismatch(self, q):
246257
"""Test proto gate mismatch errors."""
247258
serializer = op_serializer.GateOpSerializer(
248259
gate_type=GateWithProperty,
@@ -252,13 +263,13 @@ def test_to_proto_gate_mismatch(self):
252263
serialized_type=float,
253264
op_getter='val')
254265
])
255-
q = cirq.GridQubit(1, 2)
256266
with self.assertRaisesRegex(
257267
ValueError,
258268
expected_regex='GateWithAttribute.*GateWithProperty'):
259269
serializer.to_proto(GateWithAttribute(1.0)(q))
260270

261-
def test_to_proto_unsupported_type(self):
271+
@parameterized.parameters([cirq.GridQubit(1, 2), cirq.LineQubit(4)])
272+
def test_to_proto_unsupported_type(self, q):
262273
"""Test proto unsupported types errors."""
263274
serializer = op_serializer.GateOpSerializer(
264275
gate_type=GateWithProperty,
@@ -268,11 +279,11 @@ def test_to_proto_unsupported_type(self):
268279
serialized_type=bytes,
269280
op_getter='val')
270281
])
271-
q = cirq.GridQubit(1, 2)
272282
with self.assertRaisesRegex(ValueError, expected_regex='bytes'):
273283
serializer.to_proto(GateWithProperty(b's')(q))
274284

275-
def test_to_proto_required_but_not_present(self):
285+
@parameterized.parameters([cirq.GridQubit(1, 2), cirq.LineQubit(4)])
286+
def test_to_proto_required_but_not_present(self, q):
276287
"""Test required and missing args errors."""
277288
serializer = op_serializer.GateOpSerializer(
278289
gate_type=GateWithProperty,
@@ -282,11 +293,11 @@ def test_to_proto_required_but_not_present(self):
282293
serialized_type=float,
283294
op_getter=lambda x: None)
284295
])
285-
q = cirq.GridQubit(1, 2)
286296
with self.assertRaisesRegex(ValueError, expected_regex='required'):
287297
serializer.to_proto(GateWithProperty(1.0)(q))
288298

289-
def test_to_proto_no_getattr(self):
299+
@parameterized.parameters([cirq.GridQubit(1, 2), cirq.LineQubit(4)])
300+
def test_to_proto_no_getattr(self, q):
290301
"""Test no op getter fails."""
291302
serializer = op_serializer.GateOpSerializer(
292303
gate_type=GateWithProperty,
@@ -296,11 +307,11 @@ def test_to_proto_no_getattr(self):
296307
serialized_type=float,
297308
op_getter='nope')
298309
])
299-
q = cirq.GridQubit(1, 2)
300310
with self.assertRaisesRegex(ValueError, expected_regex='does not have'):
301311
serializer.to_proto(GateWithProperty(1.0)(q))
302312

303-
def test_to_proto_not_required_ok(self):
313+
@parameterized.parameters([cirq.GridQubit(1, 2), cirq.LineQubit(4)])
314+
def test_to_proto_not_required_ok(self, q):
304315
"""Test non require arg absense succeeds."""
305316
serializer = op_serializer.GateOpSerializer(
306317
gate_type=GateWithProperty,
@@ -326,15 +337,19 @@ def test_to_proto_not_required_ok(self):
326337
}
327338
},
328339
'qubits': [{
329-
'id': '1_2'
340+
'id': '1_2' if isinstance(q, cirq.GridQubit) else '4'
330341
}]
331342
})
332343

333-
q = cirq.GridQubit(1, 2)
334344
self.assertEqual(serializer.to_proto(GateWithProperty(0.125)(q)),
335345
expected)
336346

337347
@parameterized.parameters([{
348+
**x,
349+
**{
350+
'q': q
351+
}
352+
} for x in [{
338353
'val_type': float,
339354
'val': 's'
340355
}, {
@@ -352,8 +367,8 @@ def test_to_proto_not_required_ok(self):
352367
}, {
353368
'val_type': List[bool],
354369
'val': (1.0,)
355-
}])
356-
def test_to_proto_type_mismatch(self, val_type, val):
370+
}] for q in [cirq.GridQubit(1, 2), cirq.LineQubit(4)]])
371+
def test_to_proto_type_mismatch(self, val_type, val, q):
357372
"""Test type mismatch fails."""
358373
serializer = op_serializer.GateOpSerializer(
359374
gate_type=GateWithProperty,
@@ -363,11 +378,11 @@ def test_to_proto_type_mismatch(self, val_type, val):
363378
serialized_type=val_type,
364379
op_getter='val')
365380
])
366-
q = cirq.GridQubit(1, 2)
367381
with self.assertRaisesRegex(ValueError, expected_regex=str(type(val))):
368382
serializer.to_proto(GateWithProperty(val)(q))
369383

370-
def test_can_serialize_operation_subclass(self):
384+
@parameterized.parameters([cirq.GridQubit(1, 2), cirq.LineQubit(4)])
385+
def test_can_serialize_operation_subclass(self, q):
371386
"""Test can serialize subclass."""
372387
serializer = op_serializer.GateOpSerializer(
373388
gate_type=GateWithAttribute,
@@ -378,11 +393,11 @@ def test_can_serialize_operation_subclass(self):
378393
op_getter='val')
379394
],
380395
can_serialize_predicate=lambda x: x.gate.val == 1)
381-
q = cirq.GridQubit(1, 1)
382396
self.assertTrue(serializer.can_serialize_operation(SubclassGate(1)(q)))
383397
self.assertFalse(serializer.can_serialize_operation(SubclassGate(0)(q)))
384398

385-
def test_defaults_not_serialized(self):
399+
@parameterized.parameters([cirq.GridQubit(1, 2), cirq.LineQubit(4)])
400+
def test_defaults_not_serialized(self, q):
386401
"""Test defaults not serialized."""
387402
serializer = op_serializer.GateOpSerializer(
388403
gate_type=GateWithAttribute,
@@ -393,7 +408,6 @@ def test_defaults_not_serialized(self):
393408
default=1.0,
394409
op_getter='val')
395410
])
396-
q = cirq.GridQubit(1, 2)
397411
no_default = op_proto({
398412
'gate': {
399413
'id': 'my_gate'
@@ -406,7 +420,7 @@ def test_defaults_not_serialized(self):
406420
}
407421
},
408422
'qubits': [{
409-
'id': '1_2'
423+
'id': '1_2' if isinstance(q, cirq.GridQubit) else '4'
410424
}]
411425
})
412426
self.assertEqual(no_default,
@@ -416,7 +430,7 @@ def test_defaults_not_serialized(self):
416430
'id': 'my_gate'
417431
},
418432
'qubits': [{
419-
'id': '1_2'
433+
'id': '1_2' if isinstance(q, cirq.GridQubit) else '4'
420434
}]
421435
})
422436
self.assertEqual(with_default,

0 commit comments

Comments
 (0)