Skip to content

Commit 8941f33

Browse files
authored
Merge pull request #633 from tensorflow/line_deser
Add cirq.LineQubit support to op_deserializer.py.
2 parents 1ab27e0 + 171633a commit 8941f33

File tree

2 files changed

+30
-19
lines changed

2 files changed

+30
-19
lines changed

tensorflow_quantum/core/serialize/op_deserializer.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import cirq
2020

2121
GRID_QUBIT_ID_PATTERN = r'^q?(-?\d+)_(-?\d+)$'
22+
LINE_QUBIT_ID_PATTERN = r'^q?(-?\d+)$'
2223
SUPPORTED_FUNCTIONS_FOR_LANGUAGE = {
2324
'': frozenset(),
2425
'linear': frozenset({'add', 'mul'}),
@@ -42,11 +43,17 @@ def qubit_from_proto(proto_id):
4243
"""
4344

4445
match = re.match(GRID_QUBIT_ID_PATTERN, proto_id)
45-
if match is None:
46-
raise ValueError(
47-
f'Expected GridQubit proto w/ form [q]<int>_<int>, got {proto_id}')
48-
row, col = match.groups()
49-
return cirq.GridQubit(row=int(row), col=int(col))
46+
if match is not None:
47+
row, col = match.groups()
48+
return cirq.GridQubit(row=int(row), col=int(col))
49+
50+
match = re.match(LINE_QUBIT_ID_PATTERN, proto_id)
51+
if match is not None:
52+
x, = match.groups()
53+
return cirq.LineQubit(int(x))
54+
55+
raise ValueError('Expected GridQubit proto w/ form [q]<int>_<int>,'
56+
f' or LineQubit w/ form [q]<int> got {proto_id}')
5057

5158

5259
def _arg_from_proto(

tensorflow_quantum/core/serialize/op_deserializer_test.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,12 @@ def _value_equality_values_(self):
9696
class OpDeserializerTest(tf.test.TestCase, parameterized.TestCase):
9797
"""Test OpDeserializer functionality."""
9898

99-
@parameterized.parameters(TEST_CASES)
100-
def test_from_proto(self, val_type, val, arg_value):
99+
@parameterized.parameters([
100+
CASE + (x,)
101+
for CASE in TEST_CASES
102+
for x in [cirq.GridQubit(1, 2), cirq.LineQubit(4)]
103+
])
104+
def test_from_proto(self, val_type, val, arg_value, q):
101105
"""Test from proto under many cases."""
102106
deserializer = op_deserializer.GateOpDeserializer(
103107
serialized_gate_id='my_gate',
@@ -116,10 +120,9 @@ def test_from_proto(self, val_type, val, arg_value):
116120
'my_val': arg_value
117121
},
118122
'qubits': [{
119-
'id': '1_2'
123+
'id': '1_2' if isinstance(q, cirq.GridQubit) else '4'
120124
}]
121125
})
122-
q = cirq.GridQubit(1, 2)
123126
result = deserializer.from_proto(serialized,
124127
arg_function_language='linear')
125128
self.assertEqual(result, GateWithAttribute(val)(q))
@@ -260,7 +263,8 @@ def test_from_proto_function_argument_not_set(self):
260263
_ = deserializer.from_proto(serialized,
261264
arg_function_language='linear')
262265

263-
def test_from_proto_value_func(self):
266+
@parameterized.parameters([cirq.GridQubit(1, 2), cirq.LineQubit(4)])
267+
def test_from_proto_value_func(self, q):
264268
"""Test value func deserialization in simple case."""
265269
deserializer = op_deserializer.GateOpDeserializer(
266270
serialized_gate_id='my_gate',
@@ -282,14 +286,14 @@ def test_from_proto_value_func(self):
282286
}
283287
},
284288
'qubits': [{
285-
'id': '1_2'
289+
'id': '1_2' if isinstance(q, cirq.GridQubit) else '4'
286290
}]
287291
})
288-
q = cirq.GridQubit(1, 2)
289292
result = deserializer.from_proto(serialized)
290293
self.assertEqual(result, GateWithAttribute(1.125)(q))
291294

292-
def test_from_proto_not_required_ok(self):
295+
@parameterized.parameters([cirq.GridQubit(1, 2), cirq.LineQubit(4)])
296+
def test_from_proto_not_required_ok(self, q):
293297
"""Deserialization succeeds for missing not required fields."""
294298
deserializer = op_deserializer.GateOpDeserializer(
295299
serialized_gate_id='my_gate',
@@ -315,13 +319,13 @@ def test_from_proto_not_required_ok(self):
315319
}
316320
},
317321
'qubits': [{
318-
'id': '1_2'
322+
'id': '1_2' if isinstance(q, cirq.GridQubit) else '4'
319323
}]
320324
})
321-
q = cirq.GridQubit(1, 2)
322325
result = deserializer.from_proto(serialized)
323326
self.assertEqual(result, GateWithAttribute(0.125)(q))
324327

328+
325329
def test_from_proto_missing_required_arg(self):
326330
"""Error raised when required field is missing."""
327331
deserializer = op_deserializer.GateOpDeserializer(
@@ -382,7 +386,8 @@ def test_from_proto_required_arg_not_assigned(self):
382386
with self.assertRaises(ValueError):
383387
deserializer.from_proto(serialized)
384388

385-
def test_defaults(self):
389+
@parameterized.parameters([cirq.GridQubit(1, 2), cirq.LineQubit(4)])
390+
def test_defaults(self, q):
386391
"""Ensure default values still deserialize."""
387392
deserializer = op_deserializer.GateOpDeserializer(
388393
serialized_gate_id='my_gate',
@@ -402,13 +407,12 @@ def test_defaults(self):
402407
},
403408
'args': {},
404409
'qubits': [{
405-
'id': '1_2'
410+
'id': '1_2' if isinstance(q, cirq.GridQubit) else '4'
406411
}]
407412
})
408413
g = GateWithAttribute(1.0)
409414
g.not_req = 'hello'
410-
self.assertEqual(deserializer.from_proto(serialized),
411-
g(cirq.GridQubit(1, 2)))
415+
self.assertEqual(deserializer.from_proto(serialized), g(q))
412416

413417

414418
if __name__ == "__main__":

0 commit comments

Comments
 (0)