Skip to content

Commit d2aee41

Browse files
authored
Fix qasm generation/parsing for classical controls (#5434)
1 parent ebd2d6d commit d2aee41

File tree

7 files changed

+140
-22
lines changed

7 files changed

+140
-22
lines changed

cirq-core/cirq/contrib/qasm_import/_lexer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def __init__(self):
3232
'measure': 'MEASURE',
3333
'if': 'IF',
3434
'->': 'ARROW',
35-
'!=': 'NE',
35+
'==': 'EQ',
3636
}
3737

3838
tokens = ['FORMAT_SPEC', 'NUMBER', 'NATURAL_NUMBER', 'QELIBINC', 'ID', 'PI'] + list(
@@ -103,8 +103,8 @@ def t_ARROW(self, t):
103103
"""->"""
104104
return t
105105

106-
def t_NE(self, t):
107-
"""!="""
106+
def t_EQ(self, t):
107+
"""=="""
108108
return t
109109

110110
def t_ID(self, t):

cirq-core/cirq/contrib/qasm_import/_parser.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Sequence, Union
1717

1818
import numpy as np
19+
import sympy
1920
from ply import yacc
2021

2122
from cirq import ops, Circuit, NamedQubit, CX
@@ -496,11 +497,19 @@ def p_measurement(self, p):
496497
]
497498

498499
# if operations
499-
# if : IF '(' carg NE NATURAL_NUMBER ')' ID qargs
500+
# if : IF '(' carg EQ NATURAL_NUMBER ')' ID qargs
500501

501502
def p_if(self, p):
502-
"""if : IF '(' carg NE NATURAL_NUMBER ')' gate_op"""
503-
p[0] = [ops.ClassicallyControlledOperation(conditions=p[3], sub_operation=tuple(p[7])[0])]
503+
"""if : IF '(' carg EQ NATURAL_NUMBER ')' gate_op"""
504+
# We have to split the register into bits (since that's what measurement does above),
505+
# and create one condition per bit, checking against that part of the binary value.
506+
conditions = []
507+
for i, key in enumerate(p[3]):
508+
v = (p[5] >> i) & 1
509+
conditions.append(sympy.Eq(sympy.Symbol(key), v))
510+
p[0] = [
511+
ops.ClassicallyControlledOperation(conditions=conditions, sub_operation=tuple(p[7])[0])
512+
]
504513

505514
def p_error(self, p):
506515
if p is None:

cirq-core/cirq/contrib/qasm_import/_parser_test.py

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -215,16 +215,17 @@ def test_CX_gate():
215215
def test_classical_control():
216216
qasm = """OPENQASM 2.0;
217217
qreg q[2];
218-
creg m_a[1];
219-
measure q[0] -> m_a[0];
220-
if (m_a!=0) CX q[0], q[1];
218+
creg a[1];
219+
measure q[0] -> a[0];
220+
if (a==1) CX q[0],q[1];
221221
"""
222222
parser = QasmParser()
223223

224224
q_0 = cirq.NamedQubit('q_0')
225225
q_1 = cirq.NamedQubit('q_1')
226226
expected_circuit = cirq.Circuit(
227-
cirq.measure(q_0, key='m_a_0'), cirq.CNOT(q_0, q_1).with_classical_controls('m_a_0')
227+
cirq.measure(q_0, key='a_0'),
228+
cirq.CNOT(q_0, q_1).with_classical_controls(sympy.Eq(sympy.Symbol('a_0'), 1)),
228229
)
229230

230231
parsed_qasm = parser.parse(qasm)
@@ -235,6 +236,62 @@ def test_classical_control():
235236
ct.assert_same_circuits(parsed_qasm.circuit, expected_circuit)
236237
assert parsed_qasm.qregs == {'q': 2}
237238

239+
# Note this cannot *exactly* round-trip because the way QASM and Cirq handle measurements
240+
# into classical registers is different. Cirq parses QASM classical registers into m_a_i for i
241+
# in 0..bit_count. Thus the generated key has an extra "_0" at the end.
242+
expected_generated_qasm = f"""// Generated from Cirq v{cirq.__version__}
243+
244+
OPENQASM 2.0;
245+
include "qelib1.inc";
246+
247+
248+
// Qubits: [q_0, q_1]
249+
qreg q[2];
250+
creg m_a_0[1];
251+
252+
253+
measure q[0] -> m_a_0[0];
254+
if (m_a_0==1) cx q[0],q[1];
255+
"""
256+
assert cirq.qasm(parsed_qasm.circuit) == expected_generated_qasm
257+
258+
259+
def test_classical_control_multi_bit():
260+
qasm = """OPENQASM 2.0;
261+
qreg q[2];
262+
creg a[2];
263+
measure q[0] -> a[0];
264+
measure q[0] -> a[1];
265+
if (a==1) CX q[0],q[1];
266+
"""
267+
parser = QasmParser()
268+
269+
q_0 = cirq.NamedQubit('q_0')
270+
q_1 = cirq.NamedQubit('q_1')
271+
272+
# Since we split the measurement into two, we also need two conditions.
273+
# m_a==1 corresponds to m_a[0]==1, m_a[1]==0
274+
expected_circuit = cirq.Circuit(
275+
cirq.measure(q_0, key='a_0'),
276+
cirq.measure(q_0, key='a_1'),
277+
cirq.CNOT(q_0, q_1).with_classical_controls(
278+
sympy.Eq(sympy.Symbol('a_0'), 1), sympy.Eq(sympy.Symbol('a_1'), 0)
279+
),
280+
)
281+
282+
parsed_qasm = parser.parse(qasm)
283+
284+
assert parsed_qasm.supportedFormat
285+
assert not parsed_qasm.qelib1Include
286+
287+
ct.assert_same_circuits(parsed_qasm.circuit, expected_circuit)
288+
assert parsed_qasm.qregs == {'q': 2}
289+
290+
# Note that this will *not* round-trip, but there's no good way around that due to the
291+
# difference in how Cirq and QASM do multi-bit measurements.
292+
with pytest.raises(ValueError, match='QASM does not support multiple conditions'):
293+
_ = cirq.qasm(parsed_qasm.circuit)
294+
238295

239296
def test_CX_gate_not_enough_args():
240297
qasm = """OPENQASM 2.0;

cirq-core/cirq/ops/classically_controlled_operation.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,5 +207,9 @@ def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']:
207207

208208
def _qasm_(self, args: 'cirq.QasmArgs') -> Optional[str]:
209209
args.validate_version('2.0')
210-
all_keys = " && ".join(c.qasm for c in self._conditions)
211-
return args.format('if ({0}) {1}', all_keys, protocols.qasm(self._sub_operation, args=args))
210+
if len(self._conditions) > 1:
211+
raise ValueError('QASM does not support multiple conditions.')
212+
subop_qasm = protocols.qasm(self._sub_operation, args=args)
213+
if not self._conditions:
214+
return subop_qasm
215+
return f'if ({self._conditions[0].qasm}) {subop_qasm}'

cirq-core/cirq/ops/classically_controlled_operation_test.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,11 +198,14 @@ def test_diagram_subcircuit_layered():
198198

199199
def test_qasm():
200200
q0, q1 = cirq.LineQubit.range(2)
201-
circuit = cirq.Circuit(cirq.measure(q0, key='a'), cirq.X(q1).with_classical_controls('a'))
201+
circuit = cirq.Circuit(
202+
cirq.measure(q0, key='a'),
203+
cirq.X(q1).with_classical_controls(sympy.Eq(sympy.Symbol('a'), 0)),
204+
)
202205
qasm = cirq.qasm(circuit)
203206
assert (
204207
qasm
205-
== """// Generated from Cirq v0.15.0.dev
208+
== f"""// Generated from Cirq v{cirq.__version__}
206209
207210
OPENQASM 2.0;
208211
include "qelib1.inc";
@@ -214,11 +217,49 @@ def test_qasm():
214217
215218
216219
measure q[0] -> m_a[0];
217-
if (m_a!=0) x q[1];
220+
if (m_a==0) x q[1];
218221
"""
219222
)
220223

221224

225+
def test_qasm_no_conditions():
226+
q0, q1 = cirq.LineQubit.range(2)
227+
circuit = cirq.Circuit(
228+
cirq.measure(q0, key='a'), cirq.ClassicallyControlledOperation(cirq.X(q1), [])
229+
)
230+
qasm = cirq.qasm(circuit)
231+
assert (
232+
qasm
233+
== f"""// Generated from Cirq v{cirq.__version__}
234+
235+
OPENQASM 2.0;
236+
include "qelib1.inc";
237+
238+
239+
// Qubits: [q(0), q(1)]
240+
qreg q[2];
241+
creg m_a[1];
242+
243+
244+
measure q[0] -> m_a[0];
245+
x q[1];
246+
"""
247+
)
248+
249+
250+
def test_qasm_multiple_conditions():
251+
q0, q1 = cirq.LineQubit.range(2)
252+
circuit = cirq.Circuit(
253+
cirq.measure(q0, key='a'),
254+
cirq.measure(q0, key='b'),
255+
cirq.X(q1).with_classical_controls(
256+
sympy.Eq(sympy.Symbol('a'), 0), sympy.Eq(sympy.Symbol('b'), 0)
257+
),
258+
)
259+
with pytest.raises(ValueError, match='QASM does not support multiple conditions'):
260+
_ = cirq.qasm(circuit)
261+
262+
222263
@pytest.mark.parametrize('sim', ALL_SIMULATORS)
223264
def test_key_unset(sim):
224265
q0, q1 = cirq.LineQubit.range(2)

cirq-core/cirq/value/condition.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,7 @@ def _from_json_dict_(cls, key, **kwargs):
113113

114114
@property
115115
def qasm(self):
116-
if self.index != -1:
117-
raise NotImplementedError('Only most recent measurement at key can be used for QASM.')
118-
return f'm_{self.key}!=0'
116+
raise ValueError('QASM is defined only for SympyConditions of type key == constant.')
119117

120118

121119
@dataclasses.dataclass(frozen=True)
@@ -162,4 +160,8 @@ def _from_json_dict_(cls, expr, **kwargs):
162160

163161
@property
164162
def qasm(self):
165-
raise NotImplementedError()
163+
if isinstance(self.expr, sympy.Equality):
164+
if isinstance(self.expr.lhs, sympy.Symbol) and isinstance(self.expr.rhs, sympy.Integer):
165+
# Measurements get prepended with "m_", so the condition needs to be too.
166+
return f'm_{self.expr.lhs}=={self.expr.rhs}'
167+
raise ValueError('QASM is defined only for SympyConditions of type key == constant.')

cirq-core/cirq/value/condition_test.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ def resolve(records):
6767

6868

6969
def test_key_condition_qasm():
70-
assert cirq.KeyCondition(cirq.MeasurementKey('a')).qasm == 'm_a!=0'
70+
with pytest.raises(ValueError, match='QASM is defined only for SympyConditions'):
71+
_ = cirq.KeyCondition(cirq.MeasurementKey('a')).qasm
7172

7273

7374
def test_sympy_condition_with_keys():
@@ -111,5 +112,9 @@ def resolve(records):
111112

112113

113114
def test_sympy_condition_qasm():
114-
with pytest.raises(NotImplementedError):
115-
_ = init_sympy_condition.qasm
115+
# Measurements get prepended with "m_", so the condition needs to be too.
116+
assert cirq.SympyCondition(sympy.Eq(sympy.Symbol('a'), 2)).qasm == 'm_a==2'
117+
with pytest.raises(
118+
ValueError, match='QASM is defined only for SympyConditions of type key == constant'
119+
):
120+
_ = cirq.SympyCondition(sympy.Symbol('a') != 2).qasm

0 commit comments

Comments
 (0)