Skip to content

Commit b65e0b4

Browse files
Adjust Circuit.__add__ behavior to match Circuit.__iadd__ (#3364)
Fixes #3353. This is a breaking change: it modifies the moment structure produced by code of the form `circuit + gate`. Other circuit-addition protocols (e.g. `circuit + moment(s)` or `circuit + circuit`) are unaffected by this change. The new test fails with the old `circuit + gate` behavior.
1 parent 83c1f23 commit b65e0b4

File tree

2 files changed

+17
-12
lines changed

2 files changed

+17
-12
lines changed

cirq/circuits/circuit.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -262,18 +262,14 @@ def __iadd__(self, other):
262262
return self
263263

264264
def __add__(self, other):
265-
if not isinstance(other, type(self)):
266-
if not isinstance(other, (ops.Operation, Iterable)):
267-
return NotImplemented
268-
# Auto wrap OP_TREE inputs into a circuit.
269-
other = Circuit(other)
270-
271-
device = (self._device if other.device is devices.UNCONSTRAINED_DEVICE
272-
else other.device)
273-
device_2 = (other.device if self._device is devices.UNCONSTRAINED_DEVICE
274-
else self._device)
275-
if device != device_2:
276-
raise ValueError("Can't add circuits with incompatible devices.")
265+
if isinstance(other, type(self)):
266+
if (devices.UNCONSTRAINED_DEVICE not in [
267+
self._device, other.device
268+
] and self._device != other.device):
269+
raise ValueError(
270+
"Can't add circuits with incompatible devices.")
271+
elif not isinstance(other, (ops.Operation, Iterable)):
272+
return NotImplemented
277273

278274
result = self.copy()
279275
return result.__iadd__(other)

cirq/circuits/circuit_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,15 @@ def test_radd_op_tree():
268268
_ = [cirq.X(cirq.NamedQubit('a'))] + c
269269

270270

271+
def test_add_iadd_equivalence():
272+
q0, q1 = cirq.LineQubit.range(2)
273+
iadd_circuit = cirq.Circuit(cirq.X(q0))
274+
iadd_circuit += cirq.H(q1)
275+
276+
add_circuit = cirq.Circuit(cirq.X(q0)) + cirq.H(q1)
277+
assert iadd_circuit == add_circuit
278+
279+
271280
def test_bool():
272281
assert not cirq.Circuit()
273282
assert cirq.Circuit(cirq.X(cirq.NamedQubit('a')))

0 commit comments

Comments
 (0)