Skip to content

Commit 342bcc3

Browse files
authored
Fix sign error in EjectFullW and silent merge error in eject_z_test.py (#691)
- Slipped by because of a silent merge conflict overwriting updated methods with old ones - The old methods were calling the new "canonicalize_up_to_measurement_phase" with the wrong number of arguments, triggering a type error, triggering a path expecting a type error for a *different reason* (i.e. non-unitary operations in the circuit) - Removed the non-unitary fallback path and renamed the method to make it clear only measurement-implies-terminal circuits were expected - Simplified the W-over-partial-W derivation - Fixed false-positive assertions that were now correctly failing - Added a test that fails for the old behavior but not for the new behavior Fixes #684
1 parent 2beb783 commit 342bcc3

File tree

3 files changed

+28
-24
lines changed

3 files changed

+28
-24
lines changed

cirq/google/eject_full_w.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,9 @@
1717

1818
from typing import Optional, cast, TYPE_CHECKING, Iterable
1919

20-
from cirq import circuits, ops, extension
20+
from cirq import circuits, ops, extension, value
2121
from cirq.google import decompositions
2222
from cirq.google.xmon_gates import ExpZGate, ExpWGate, Exp11Gate
23-
from cirq.value import Symbol
2423

2524
if TYPE_CHECKING:
2625
# pylint: disable=unused-import
@@ -205,15 +204,11 @@ def _potential_cross_partial_w(moment_index: int,
205204
206205
Uses the following identity:
207206
───W(a)───W(b)^t───
208-
≡ ───Z^-a───X───Z^a───Z^-b───X^t───Z^b─── (expand Ws)
209-
≡ ───Z^-a───Z^-a───Z^b───X^t───Z^-b───X─── (move X right flipping Zs)
210-
≡ ───Z^(b-2a)───X^t───Z^-b───X─── (merge Zs)
211-
≡ ───Z^(b-2a)───X^t───Z^-(b-2a)───Z^(b-2a)───Z^-b───X─── (match left Z)
212-
≡ ───W(b-2a)^t───Z^(b-2a)───Z^-b───X─── (merge into W)
213-
≡ ───W(b-2a)^t───Z^-2a───X─── (cancel Z^b)
214-
≡ ───W(b-2a)^t───Z^-a───Z^-a───X─── (split Z^-2a)
215-
≡ ───W(b-2a)^t───Z^-a───X───Z^a─── (flip Z^-a across)
216-
≡ ───W(b-2a)^t───W(a)─── (merge W)
207+
≡ ───Z^-a───X───Z^a───W(b)^t────── (expand W(a))
208+
≡ ───Z^-a───X───W(b-a)^t───Z^a──── (move Z^a across, phasing axis)
209+
≡ ───Z^-a───W(a-b)^t───X───Z^a──── (move X across, negating axis angle)
210+
≡ ───W(2a-b)^t───Z^-a───X───Z^a─── (move Z^-a across, phasing axis)
211+
≡ ───W(2a-b)^t───W(a)───
217212
"""
218213
a = state.held_w_phases.get(op.qubits[0])
219214
if a is None:
@@ -222,7 +217,7 @@ def _potential_cross_partial_w(moment_index: int,
222217
b = cast(float, w.axis_half_turns)
223218
t = cast(float, w.half_turns)
224219
new_op = ExpWGate(half_turns=t,
225-
axis_half_turns=b-2*a).on(op.qubits[0])
220+
axis_half_turns=2*a-b).on(op.qubits[0])
226221
state.deletions.append((moment_index, op))
227222
state.inline_intos.append((moment_index, new_op))
228223

@@ -314,16 +309,16 @@ def _try_get_known_cz_half_turns(op: ops.Operation) -> Optional[float]:
314309
not isinstance(op.gate, (Exp11Gate, ops.Rot11Gate))):
315310
return None
316311
h = op.gate.half_turns
317-
if isinstance(h, Symbol):
312+
if isinstance(h, value.Symbol):
318313
return None
319314
return h
320315

321316

322317
def _try_get_known_w(op: ops.Operation) -> Optional[ExpWGate]:
323318
if (not isinstance(op, ops.GateOperation) or
324319
not isinstance(op.gate, ExpWGate) or
325-
isinstance(op.gate.half_turns, Symbol) or
326-
isinstance(op.gate.axis_half_turns, Symbol)):
320+
isinstance(op.gate.half_turns, value.Symbol) or
321+
isinstance(op.gate.axis_half_turns, value.Symbol)):
327322
return None
328323
return op.gate
329324

@@ -333,6 +328,6 @@ def _try_get_known_z_half_turns(op: ops.Operation) -> Optional[float]:
333328
not isinstance(op.gate, (ExpZGate, ops.RotZGate))):
334329
return None
335330
h = op.gate.half_turns
336-
if isinstance(h, Symbol):
331+
if isinstance(h, value.Symbol):
337332
return None
338333
return h

cirq/google/eject_full_w_test.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def test_phases_partial_ws():
247247
),
248248
expected=quick_circuit(
249249
[],
250-
[cg.ExpWGate(axis_half_turns=0.25, half_turns=0.5).on(q)],
250+
[cg.ExpWGate(axis_half_turns=-0.25, half_turns=0.5).on(q)],
251251
[cg.ExpWGate().on(q)],
252252
))
253253

@@ -258,7 +258,7 @@ def test_phases_partial_ws():
258258
),
259259
expected=quick_circuit(
260260
[],
261-
[cg.ExpWGate(axis_half_turns=-0.5, half_turns=0.5).on(q)],
261+
[cg.ExpWGate(axis_half_turns=0.5, half_turns=0.5).on(q)],
262262
[cg.ExpWGate(axis_half_turns=0.25).on(q)],
263263
))
264264

@@ -273,6 +273,17 @@ def test_phases_partial_ws():
273273
[cg.ExpWGate(axis_half_turns=0.25).on(q)],
274274
))
275275

276+
assert_optimizes(
277+
before=quick_circuit(
278+
[cg.ExpWGate().on(q)],
279+
[cg.ExpWGate(half_turns=-0.25, axis_half_turns=0.5).on(q)]
280+
),
281+
expected=quick_circuit(
282+
[],
283+
[cg.ExpWGate(half_turns=-0.25, axis_half_turns=-0.5).on(q)],
284+
[cg.ExpWGate().on(q)],
285+
))
286+
276287

277288
def test_blocked_by_unknown_and_symbols():
278289
a = cirq.NamedQubit('a')

cirq/google/eject_z.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,16 @@
1818

1919
from collections import defaultdict
2020

21-
from cirq import ops, extension
22-
from cirq.circuits import Circuit, OptimizationPass
21+
from cirq import circuits, ops, extension, value
2322
from cirq.google.decompositions import is_negligible_turn
2423
from cirq.google.xmon_gates import ExpZGate
25-
from cirq.value import Symbol
2624

2725
if TYPE_CHECKING:
2826
# pylint: disable=unused-import
2927
from typing import Dict, List, Tuple
3028

3129

32-
class EjectZ(OptimizationPass):
30+
class EjectZ(circuits.OptimizationPass):
3331
"""Pushes Z gates towards the end of the circuit.
3432
3533
As the Z gates get pushed they may absorb other Z gates, get absorbed into
@@ -50,7 +48,7 @@ def __init__(self,
5048
self.tolerance = tolerance
5149
self.ext = ext or extension.Extensions()
5250

53-
def optimize_circuit(self, circuit: Circuit):
51+
def optimize_circuit(self, circuit: circuits.Circuit):
5452
turns_state = defaultdict(lambda: 0) # type: Dict[ops.QubitId, float]
5553

5654
def dump_phases(qubits, index):
@@ -105,6 +103,6 @@ def _try_get_known_z_half_turns(op: ops.Operation) -> Optional[float]:
105103
if not isinstance(op.gate, (ExpZGate, ops.RotZGate)):
106104
return None
107105
h = op.gate.half_turns
108-
if isinstance(h, Symbol):
106+
if isinstance(h, value.Symbol):
109107
return None
110108
return h

0 commit comments

Comments
 (0)