Skip to content

Commit e05f868

Browse files
jaeyooCirqBot
authored andcommitted
Add __radd__ in Circuit (#2418)
This commit adds __radd__ in cirq.Circuit() because Moment + Circuit doesn't work even if Circuit + Moment works.
1 parent 9747be7 commit e05f868

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

cirq/circuits/circuit.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,14 @@ def __add__(self, other):
271271
result = self.copy()
272272
return result.__iadd__(other)
273273

274+
def __radd__(self, other):
275+
# The Circuit + Circuit case is handled by __add__
276+
if not isinstance(other, (ops.Operation, Iterable)):
277+
return NotImplemented
278+
# Auto wrap OP_TREE inputs into a circuit.
279+
result = Circuit(other)
280+
return result.__iadd__(self)
281+
274282
def __imul__(self, repetitions: int):
275283
if not isinstance(repetitions, int):
276284
return NotImplemented

cirq/circuits/circuit_test.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,36 @@ def test_add_op_tree():
248248
_ = c + cirq.X
249249

250250

251+
def test_radd_op_tree():
252+
a = cirq.NamedQubit('a')
253+
b = cirq.NamedQubit('b')
254+
255+
c = cirq.Circuit()
256+
assert [cirq.X(a), cirq.Y(b)] + c == cirq.Circuit([
257+
cirq.Moment([cirq.X(a), cirq.Y(b)]),
258+
])
259+
260+
assert cirq.X(a) + c == cirq.Circuit(cirq.X(a))
261+
assert [cirq.X(a)] + c == cirq.Circuit(cirq.X(a))
262+
assert [[[cirq.X(a)], []]] + c == cirq.Circuit(cirq.X(a))
263+
assert (cirq.X(a),) + c == cirq.Circuit(cirq.X(a))
264+
assert (cirq.X(a) for _ in range(1)) + c == cirq.Circuit(cirq.X(a))
265+
with pytest.raises(AttributeError):
266+
_ = cirq.X + c
267+
with pytest.raises(TypeError):
268+
_ = 0 + c
269+
270+
# non-empty circuit addition
271+
d = cirq.Circuit()
272+
d.append(cirq.Y(b))
273+
assert [cirq.X(a)] + d == cirq.Circuit(
274+
[cirq.Moment([cirq.X(a)]),
275+
cirq.Moment([cirq.Y(b)])])
276+
assert cirq.Moment([cirq.X(a)]) + d == cirq.Circuit(
277+
[cirq.Moment([cirq.X(a)]),
278+
cirq.Moment([cirq.Y(b)])])
279+
280+
251281
def test_bool():
252282
assert not cirq.Circuit()
253283
assert cirq.Circuit(cirq.X(cirq.NamedQubit('a')))
@@ -1180,7 +1210,12 @@ def test_findall_operations_until_blocked():
11801210
assert circuit.findall_operations_until_blocked(
11811211
start_frontier={d: idx}, is_blocker=stop_if_op) == []
11821212
assert circuit.findall_operations_until_blocked(
1183-
start_frontier={a:idx, b:idx, c:idx, d: idx},
1213+
start_frontier={
1214+
a: idx,
1215+
b: idx,
1216+
c: idx,
1217+
d: idx
1218+
},
11841219
is_blocker=stop_if_op) == []
11851220

11861221
# Cases where nothing is blocked, it goes to the end

0 commit comments

Comments
 (0)