Skip to content

Commit 3344433

Browse files
authored
Merge pull request #632 from tensorflow/line_ser
Add support for cirq.LineQubit to op_serializer.py
2 parents 8941f33 + afb0608 commit 3344433

File tree

3 files changed

+55
-37
lines changed

3 files changed

+55
-37
lines changed

tensorflow_quantum/core/serialize/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ py_test(
2727
srcs_version = "PY3",
2828
deps = [
2929
":op_serializer",
30-
"//tensorflow_quantum/core/proto:program_py_proto"
30+
"//tensorflow_quantum/core/proto:program_py_proto",
3131
],
3232
)
3333

tensorflow_quantum/core/serialize/op_serializer.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""op_serializer.py adapated from Cirq release 0.9.0"""
1515

1616
from typing import List
17+
import cirq
1718
import sympy
1819
import numpy as np
1920
from tensorflow_quantum.core.proto import program_pb2
@@ -31,7 +32,11 @@
3132

3233
def qubit_to_proto(qubit):
3334
"""Return proto representation of a GridQubit."""
34-
return '{}_{}'.format(qubit.row, qubit.col)
35+
if isinstance(qubit, cirq.GridQubit):
36+
return '{}_{}'.format(qubit.row, qubit.col)
37+
if isinstance(qubit, cirq.LineQubit):
38+
return '{}'.format(qubit.x)
39+
raise ValueError('Unsupported qubit type:' + str(type(qubit)))
3540

3641

3742
def _arg_to_proto(value, *, arg_function_language, out=None):
@@ -193,7 +198,7 @@ def to_proto(
193198

194199
msg.gate.id = self.serialized_gate_id
195200
for qubit in op.qubits:
196-
msg.qubits.add().id = '{}_{}'.format(qubit.row, qubit.col)
201+
msg.qubits.add().id = qubit_to_proto(qubit)
197202
for arg in self.args:
198203
value = self._value_from_gate(op, arg)
199204
if value is not None and (not arg.default or value != arg.default):

tensorflow_quantum/core/serialize/op_serializer_test.py

Lines changed: 47 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,12 @@ def get_val(op):
142142
class OpSerializerTest(tf.test.TestCase, parameterized.TestCase):
143143
"""Test OpSerializer functions correctly."""
144144

145-
@parameterized.parameters(TEST_CASES)
146-
def test_to_proto_attribute(self, val_type, val, arg_value):
145+
@parameterized.parameters([
146+
CASE + (x,)
147+
for CASE in TEST_CASES
148+
for x in [cirq.GridQubit(1, 2), cirq.LineQubit(4)]
149+
])
150+
def test_to_proto_attribute(self, val_type, val, arg_value, q):
147151
"""Test proto attribute serialization works."""
148152
serializer = op_serializer.GateOpSerializer(
149153
gate_type=GateWithAttribute,
@@ -153,7 +157,6 @@ def test_to_proto_attribute(self, val_type, val, arg_value):
153157
serialized_type=val_type,
154158
op_getter='val')
155159
])
156-
q = cirq.GridQubit(1, 2)
157160
result = serializer.to_proto(GateWithAttribute(val)(q),
158161
arg_function_language='linear')
159162
expected = op_proto({
@@ -164,13 +167,17 @@ def test_to_proto_attribute(self, val_type, val, arg_value):
164167
'my_val': arg_value
165168
},
166169
'qubits': [{
167-
'id': '1_2'
170+
'id': '1_2' if isinstance(q, cirq.GridQubit) else '4'
168171
}]
169172
})
170173
self.assertEqual(result, expected)
171174

172-
@parameterized.parameters(TEST_CASES)
173-
def test_to_proto_property(self, val_type, val, arg_value):
175+
@parameterized.parameters([
176+
CASE + (x,)
177+
for CASE in TEST_CASES
178+
for x in [cirq.GridQubit(1, 2), cirq.LineQubit(4)]
179+
])
180+
def test_to_proto_property(self, val_type, val, arg_value, q):
174181
"""Test proto property serialization works."""
175182
serializer = op_serializer.GateOpSerializer(
176183
gate_type=GateWithProperty,
@@ -180,7 +187,6 @@ def test_to_proto_property(self, val_type, val, arg_value):
180187
serialized_type=val_type,
181188
op_getter='val')
182189
])
183-
q = cirq.GridQubit(1, 2)
184190
result = serializer.to_proto(GateWithProperty(val)(q),
185191
arg_function_language='linear')
186192
expected = op_proto({
@@ -191,13 +197,17 @@ def test_to_proto_property(self, val_type, val, arg_value):
191197
'my_val': arg_value
192198
},
193199
'qubits': [{
194-
'id': '1_2'
200+
'id': '1_2' if isinstance(q, cirq.GridQubit) else '4'
195201
}]
196202
})
197203
self.assertEqual(result, expected)
198204

199-
@parameterized.parameters(TEST_CASES)
200-
def test_to_proto_callable(self, val_type, val, arg_value):
205+
@parameterized.parameters([
206+
CASE + (x,)
207+
for CASE in TEST_CASES
208+
for x in [cirq.GridQubit(1, 2), cirq.LineQubit(4)]
209+
])
210+
def test_to_proto_callable(self, val_type, val, arg_value, q):
201211
"""Test callable serialization works."""
202212
serializer = op_serializer.GateOpSerializer(
203213
gate_type=GateWithMethod,
@@ -207,7 +217,6 @@ def test_to_proto_callable(self, val_type, val, arg_value):
207217
serialized_type=val_type,
208218
op_getter=get_val)
209219
])
210-
q = cirq.GridQubit(1, 2)
211220
result = serializer.to_proto(GateWithMethod(val)(q),
212221
arg_function_language='linear')
213222
expected = op_proto({
@@ -218,12 +227,13 @@ def test_to_proto_callable(self, val_type, val, arg_value):
218227
'my_val': arg_value
219228
},
220229
'qubits': [{
221-
'id': '1_2'
230+
'id': '1_2' if isinstance(q, cirq.GridQubit) else '4'
222231
}]
223232
})
224233
self.assertEqual(result, expected)
225234

226-
def test_to_proto_gate_predicate(self):
235+
@parameterized.parameters([cirq.GridQubit(1, 2), cirq.LineQubit(4)])
236+
def test_to_proto_gate_predicate(self, q):
227237
"""Test can_serialize works."""
228238
serializer = op_serializer.GateOpSerializer(
229239
gate_type=GateWithAttribute,
@@ -234,15 +244,15 @@ def test_to_proto_gate_predicate(self):
234244
op_getter='val')
235245
],
236246
can_serialize_predicate=lambda x: x.gate.val == 1)
237-
q = cirq.GridQubit(1, 2)
238247
self.assertIsNone(serializer.to_proto(GateWithAttribute(0)(q)))
239248
self.assertIsNotNone(serializer.to_proto(GateWithAttribute(1)(q)))
240249
self.assertFalse(
241250
serializer.can_serialize_operation(GateWithAttribute(0)(q)))
242251
self.assertTrue(
243252
serializer.can_serialize_operation(GateWithAttribute(1)(q)))
244253

245-
def test_to_proto_gate_mismatch(self):
254+
@parameterized.parameters([cirq.GridQubit(1, 2), cirq.LineQubit(4)])
255+
def test_to_proto_gate_mismatch(self, q):
246256
"""Test proto gate mismatch errors."""
247257
serializer = op_serializer.GateOpSerializer(
248258
gate_type=GateWithProperty,
@@ -252,13 +262,13 @@ def test_to_proto_gate_mismatch(self):
252262
serialized_type=float,
253263
op_getter='val')
254264
])
255-
q = cirq.GridQubit(1, 2)
256265
with self.assertRaisesRegex(
257266
ValueError,
258267
expected_regex='GateWithAttribute.*GateWithProperty'):
259268
serializer.to_proto(GateWithAttribute(1.0)(q))
260269

261-
def test_to_proto_unsupported_type(self):
270+
@parameterized.parameters([cirq.GridQubit(1, 2), cirq.LineQubit(4)])
271+
def test_to_proto_unsupported_type(self, q):
262272
"""Test proto unsupported types errors."""
263273
serializer = op_serializer.GateOpSerializer(
264274
gate_type=GateWithProperty,
@@ -268,11 +278,11 @@ def test_to_proto_unsupported_type(self):
268278
serialized_type=bytes,
269279
op_getter='val')
270280
])
271-
q = cirq.GridQubit(1, 2)
272281
with self.assertRaisesRegex(ValueError, expected_regex='bytes'):
273282
serializer.to_proto(GateWithProperty(b's')(q))
274283

275-
def test_to_proto_required_but_not_present(self):
284+
@parameterized.parameters([cirq.GridQubit(1, 2), cirq.LineQubit(4)])
285+
def test_to_proto_required_but_not_present(self, q):
276286
"""Test required and missing args errors."""
277287
serializer = op_serializer.GateOpSerializer(
278288
gate_type=GateWithProperty,
@@ -282,11 +292,11 @@ def test_to_proto_required_but_not_present(self):
282292
serialized_type=float,
283293
op_getter=lambda x: None)
284294
])
285-
q = cirq.GridQubit(1, 2)
286295
with self.assertRaisesRegex(ValueError, expected_regex='required'):
287296
serializer.to_proto(GateWithProperty(1.0)(q))
288297

289-
def test_to_proto_no_getattr(self):
298+
@parameterized.parameters([cirq.GridQubit(1, 2), cirq.LineQubit(4)])
299+
def test_to_proto_no_getattr(self, q):
290300
"""Test no op getter fails."""
291301
serializer = op_serializer.GateOpSerializer(
292302
gate_type=GateWithProperty,
@@ -296,11 +306,11 @@ def test_to_proto_no_getattr(self):
296306
serialized_type=float,
297307
op_getter='nope')
298308
])
299-
q = cirq.GridQubit(1, 2)
300309
with self.assertRaisesRegex(ValueError, expected_regex='does not have'):
301310
serializer.to_proto(GateWithProperty(1.0)(q))
302311

303-
def test_to_proto_not_required_ok(self):
312+
@parameterized.parameters([cirq.GridQubit(1, 2), cirq.LineQubit(4)])
313+
def test_to_proto_not_required_ok(self, q):
304314
"""Test non require arg absense succeeds."""
305315
serializer = op_serializer.GateOpSerializer(
306316
gate_type=GateWithProperty,
@@ -326,15 +336,19 @@ def test_to_proto_not_required_ok(self):
326336
}
327337
},
328338
'qubits': [{
329-
'id': '1_2'
339+
'id': '1_2' if isinstance(q, cirq.GridQubit) else '4'
330340
}]
331341
})
332342

333-
q = cirq.GridQubit(1, 2)
334343
self.assertEqual(serializer.to_proto(GateWithProperty(0.125)(q)),
335344
expected)
336345

337346
@parameterized.parameters([{
347+
**x,
348+
**{
349+
'q': q
350+
}
351+
} for x in [{
338352
'val_type': float,
339353
'val': 's'
340354
}, {
@@ -352,8 +366,8 @@ def test_to_proto_not_required_ok(self):
352366
}, {
353367
'val_type': List[bool],
354368
'val': (1.0,)
355-
}])
356-
def test_to_proto_type_mismatch(self, val_type, val):
369+
}] for q in [cirq.GridQubit(1, 2), cirq.LineQubit(4)]])
370+
def test_to_proto_type_mismatch(self, val_type, val, q):
357371
"""Test type mismatch fails."""
358372
serializer = op_serializer.GateOpSerializer(
359373
gate_type=GateWithProperty,
@@ -363,11 +377,11 @@ def test_to_proto_type_mismatch(self, val_type, val):
363377
serialized_type=val_type,
364378
op_getter='val')
365379
])
366-
q = cirq.GridQubit(1, 2)
367380
with self.assertRaisesRegex(ValueError, expected_regex=str(type(val))):
368381
serializer.to_proto(GateWithProperty(val)(q))
369382

370-
def test_can_serialize_operation_subclass(self):
383+
@parameterized.parameters([cirq.GridQubit(1, 2), cirq.LineQubit(4)])
384+
def test_can_serialize_operation_subclass(self, q):
371385
"""Test can serialize subclass."""
372386
serializer = op_serializer.GateOpSerializer(
373387
gate_type=GateWithAttribute,
@@ -378,11 +392,11 @@ def test_can_serialize_operation_subclass(self):
378392
op_getter='val')
379393
],
380394
can_serialize_predicate=lambda x: x.gate.val == 1)
381-
q = cirq.GridQubit(1, 1)
382395
self.assertTrue(serializer.can_serialize_operation(SubclassGate(1)(q)))
383396
self.assertFalse(serializer.can_serialize_operation(SubclassGate(0)(q)))
384397

385-
def test_defaults_not_serialized(self):
398+
@parameterized.parameters([cirq.GridQubit(1, 2), cirq.LineQubit(4)])
399+
def test_defaults_not_serialized(self, q):
386400
"""Test defaults not serialized."""
387401
serializer = op_serializer.GateOpSerializer(
388402
gate_type=GateWithAttribute,
@@ -393,7 +407,6 @@ def test_defaults_not_serialized(self):
393407
default=1.0,
394408
op_getter='val')
395409
])
396-
q = cirq.GridQubit(1, 2)
397410
no_default = op_proto({
398411
'gate': {
399412
'id': 'my_gate'
@@ -406,7 +419,7 @@ def test_defaults_not_serialized(self):
406419
}
407420
},
408421
'qubits': [{
409-
'id': '1_2'
422+
'id': '1_2' if isinstance(q, cirq.GridQubit) else '4'
410423
}]
411424
})
412425
self.assertEqual(no_default,
@@ -416,7 +429,7 @@ def test_defaults_not_serialized(self):
416429
'id': 'my_gate'
417430
},
418431
'qubits': [{
419-
'id': '1_2'
432+
'id': '1_2' if isinstance(q, cirq.GridQubit) else '4'
420433
}]
421434
})
422435
self.assertEqual(with_default,

0 commit comments

Comments
 (0)