2020
2121import cirq
2222from cirq .type_workarounds import NotImplementedType
23- from cirq .ops import AbstractControlValues
2423
2524
2625class GateUsingWorkspaceForApplyUnitary (cirq .testing .SingleQubitGate ):
@@ -89,14 +88,25 @@ def __str__(self):
8988
9089C0Y = cirq .ControlledGate (cirq .Y , control_values = [0 ])
9190C0C1H = cirq .ControlledGate (cirq .ControlledGate (cirq .H , control_values = [1 ]), control_values = [0 ])
91+
92+ nand_control_values = cirq .SumOfProducts ([(0 , 1 ), (1 , 0 ), (1 , 1 )])
93+ xor_control_values = cirq .SumOfProducts ([[0 , 1 ], [1 , 0 ]], name = "xor" )
94+ C_01_10_11H = cirq .ControlledGate (cirq .H , control_values = nand_control_values )
95+ C_xorH = cirq .ControlledGate (cirq .H , control_values = xor_control_values )
96+ C0C_xorH = cirq .ControlledGate (C_xorH , control_values = [0 ])
97+
9298C0Restricted = cirq .ControlledGate (RestrictedGate (), control_values = [0 ])
99+ C_xorRestricted = cirq .ControlledGate (RestrictedGate (), control_values = xor_control_values )
93100
94101C2Y = cirq .ControlledGate (cirq .Y , control_values = [2 ], control_qid_shape = (3 ,))
95102C2C2H = cirq .ControlledGate (
96103 cirq .ControlledGate (cirq .H , control_values = [2 ], control_qid_shape = (3 ,)),
97104 control_values = [2 ],
98105 control_qid_shape = (3 ,),
99106)
107+ C_02_20H = cirq .ControlledGate (
108+ cirq .H , control_values = cirq .SumOfProducts ([[0 , 2 ], [1 , 0 ]]), control_qid_shape = (2 , 3 )
109+ )
100110C2Restricted = cirq .ControlledGate (RestrictedGate (), control_values = [2 ], control_qid_shape = (3 ,))
101111
102112
@@ -107,7 +117,7 @@ def test_init():
107117
108118
109119def test_init2 ():
110- with pytest .raises (ValueError , match = r'len \(control_values\) != num_controls' ):
120+ with pytest .raises (ValueError , match = r'cirq\.num_qubits \(control_values\) != num_controls' ):
111121 cirq .ControlledGate (cirq .Z , num_controls = 1 , control_values = (1 , 0 ))
112122 with pytest .raises (ValueError , match = r'len\(control_qid_shape\) != num_controls' ):
113123 cirq .ControlledGate (cirq .Z , num_controls = 1 , control_qid_shape = (2 , 2 ))
@@ -125,15 +135,15 @@ def test_init2():
125135 gate = cirq .ControlledGate (cirq .Z , 1 )
126136 assert gate .sub_gate is cirq .Z
127137 assert gate .num_controls () == 1
128- assert gate .control_values == (( 1 ,),)
138+ assert gate .control_values == cirq . ProductOfSums ((( 1 ,),) )
129139 assert gate .control_qid_shape == (2 ,)
130140 assert gate .num_qubits () == 2
131141 assert cirq .qid_shape (gate ) == (2 , 2 )
132142
133143 gate = cirq .ControlledGate (cirq .Z , 2 )
134144 assert gate .sub_gate is cirq .Z
135145 assert gate .num_controls () == 2
136- assert gate .control_values == (( 1 ,), (1 ,))
146+ assert gate .control_values == cirq . ProductOfSums ((( 1 ,), (1 ,) ))
137147 assert gate .control_qid_shape == (2 , 2 )
138148 assert gate .num_qubits () == 3
139149 assert cirq .qid_shape (gate ) == (2 , 2 , 2 )
@@ -143,7 +153,7 @@ def test_init2():
143153 )
144154 assert gate .sub_gate is cirq .Z
145155 assert gate .num_controls () == 7
146- assert gate .control_values == (( 1 ,),) * 7
156+ assert gate .control_values == cirq . ProductOfSums ((( 1 ,),) * 7 )
147157 assert gate .control_qid_shape == (2 ,) * 7
148158 assert gate .num_qubits () == 8
149159 assert cirq .qid_shape (gate ) == (2 ,) * 8
@@ -162,15 +172,15 @@ def test_init2():
162172 gate = cirq .ControlledGate (cirq .Z , control_values = (0 , (0 , 1 )))
163173 assert gate .sub_gate is cirq .Z
164174 assert gate .num_controls () == 2
165- assert gate .control_values == (( 0 ,), (0 , 1 ))
175+ assert gate .control_values == cirq . ProductOfSums ((( 0 ,), (0 , 1 ) ))
166176 assert gate .control_qid_shape == (2 , 2 )
167177 assert gate .num_qubits () == 3
168178 assert cirq .qid_shape (gate ) == (2 , 2 , 2 )
169179
170180 gate = cirq .ControlledGate (cirq .Z , control_qid_shape = (3 , 3 ))
171181 assert gate .sub_gate is cirq .Z
172182 assert gate .num_controls () == 2
173- assert gate .control_values == (( 1 ,), (1 ,))
183+ assert gate .control_values == cirq . ProductOfSums ((( 1 ,), (1 ,) ))
174184 assert gate .control_qid_shape == (3 , 3 )
175185 assert gate .num_qubits () == 3
176186 assert cirq .qid_shape (gate ) == (3 , 3 , 2 )
@@ -232,9 +242,15 @@ def test_eq():
232242 eq .add_equality_group (
233243 cirq .ControlledGate (cirq .H , control_values = [1 , (0 , 2 )], control_qid_shape = [2 , 3 ]),
234244 cirq .ControlledGate (cirq .H , control_values = (1 , [0 , 2 ]), control_qid_shape = (2 , 3 )),
245+ cirq .ControlledGate (
246+ cirq .H , control_values = cirq .SumOfProducts ([[1 , 0 ], [1 , 2 ]]), control_qid_shape = (2 , 3 )
247+ ),
235248 )
236249 eq .add_equality_group (
237- cirq .ControlledGate (cirq .H , control_values = [(2 , 0 ), 1 ], control_qid_shape = [3 , 2 ])
250+ cirq .ControlledGate (cirq .H , control_values = [(2 , 0 ), 1 ], control_qid_shape = [3 , 2 ]),
251+ cirq .ControlledGate (
252+ cirq .H , control_values = cirq .SumOfProducts ([[2 , 1 ], [0 , 1 ]]), control_qid_shape = (3 , 2 )
253+ ),
238254 )
239255 eq .add_equality_group (
240256 cirq .ControlledGate (cirq .H , control_values = [1 , 0 ], control_qid_shape = [2 , 3 ]),
@@ -278,18 +294,21 @@ def _has_mixture_(self):
278294 g .controlled (control_values = [1 ]),
279295 g .controlled (control_qid_shape = (2 ,)),
280296 cirq .ControlledGate (g , num_controls = 1 ),
297+ g .controlled (control_values = cirq .SumOfProducts ([[1 ]])),
281298 )
282299 eq .add_equality_group (
283300 cirq .ControlledGate (g , num_controls = 2 ),
284301 g .controlled (control_values = [1 , 1 ]),
285302 g .controlled (control_qid_shape = [2 , 2 ]),
286303 g .controlled (num_controls = 2 ),
287304 g .controlled ().controlled (),
305+ g .controlled (control_values = cirq .SumOfProducts ([[1 , 1 ]])),
288306 )
289307 eq .add_equality_group (
290308 cirq .ControlledGate (g , control_values = [0 , 1 ]),
291309 g .controlled (control_values = [0 , 1 ]),
292310 g .controlled (control_values = [1 ]).controlled (control_values = [0 ]),
311+ g .controlled (control_values = cirq .SumOfProducts ([[1 ]])).controlled (control_values = [0 ]),
293312 )
294313 eq .add_equality_group (g .controlled (control_values = [0 ]).controlled (control_values = [1 ]))
295314 eq .add_equality_group (
@@ -350,6 +369,20 @@ def test_unitary():
350369 atol = 1e-8 ,
351370 )
352371
372+ C_xorX = cirq .ControlledGate (cirq .X , control_values = xor_control_values )
373+ # fmt: off
374+ np .testing .assert_allclose (cirq .unitary (C_xorX ), np .array ([
375+ [1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ],
376+ [0 , 1 , 0 , 0 , 0 , 0 , 0 , 0 ],
377+ [0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 ],
378+ [0 , 0 , 1 , 0 , 0 , 0 , 0 , 0 ],
379+ [0 , 0 , 0 , 0 , 0 , 1 , 0 , 0 ],
380+ [0 , 0 , 0 , 0 , 1 , 0 , 0 , 0 ],
381+ [0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 ],
382+ [0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 ]]
383+ ))
384+ # fmt: on
385+
353386
354387@pytest .mark .parametrize (
355388 'gate, should_decompose_to_target' ,
@@ -380,6 +413,10 @@ def test_unitary():
380413 (cirq .MatrixGate (cirq .testing .random_unitary (4 , random_state = 1234 )), False ),
381414 (cirq .XX ** sympy .Symbol ("s" ), True ),
382415 (cirq .CZ ** sympy .Symbol ("s" ), True ),
416+ # Non-trivial `cirq.ProductOfSum` controls.
417+ (C_01_10_11H , False ),
418+ (C_xorH , False ),
419+ (C0C_xorH , False ),
383420 ],
384421)
385422def test_controlled_gate_is_consistent (gate : cirq .Gate , should_decompose_to_target ):
@@ -507,7 +544,7 @@ def _has_unitary_(self):
507544 return True
508545
509546
510- def test_circuit_diagram ():
547+ def test_circuit_diagram_product_of_sums ():
511548 qubits = cirq .LineQubit .range (3 )
512549 c = cirq .Circuit ()
513550 c .append (cirq .ControlledGate (MultiH (2 ))(* qubits ))
@@ -542,6 +579,35 @@ def test_circuit_diagram():
542579 )
543580
544581
582+ def test_circuit_diagram_sum_of_products ():
583+ q = cirq .LineQubit .range (4 )
584+ c = cirq .Circuit (C_xorH .on (* q [:3 ]), C_01_10_11H .on (* q [:3 ]), C0C_xorH .on (* q ))
585+ cirq .testing .assert_has_diagram (
586+ c ,
587+ """
588+ 0: ───@────────@(011)───@(00)───
589+ │ │ │
590+ 1: ───@(xor)───@(101)───@(01)───
591+ │ │ │
592+ 2: ───H────────H────────@(10)───
593+ │
594+ 3: ─────────────────────H───────
595+ """ ,
596+ )
597+ q = cirq .LineQid .for_qid_shape ((2 , 3 , 2 ))
598+ c = cirq .Circuit (C_02_20H (* q ))
599+ cirq .testing .assert_has_diagram (
600+ c ,
601+ """
602+ 0 (d=2): ───@(01)───
603+ │
604+ 1 (d=3): ───@(20)───
605+ │
606+ 2 (d=2): ───H───────
607+ """ ,
608+ )
609+
610+
545611class MockGate (cirq .testing .TwoQubitGate ):
546612 def _circuit_diagram_info_ (self , args : cirq .CircuitDiagramInfoArgs ) -> cirq .CircuitDiagramInfo :
547613 self .captured_diagram_args = args
@@ -571,12 +637,21 @@ def test_bounded_effect():
571637 assert cirq .trace_distance_bound (cirq .ControlledGate (cirq .X ** foo )) == 1
572638
573639
574- def test_repr ():
575- cirq .testing .assert_equivalent_repr (cirq .ControlledGate (cirq .Z ))
576- cirq .testing .assert_equivalent_repr (cirq .ControlledGate (cirq .Z , num_controls = 1 ))
577- cirq .testing .assert_equivalent_repr (cirq .ControlledGate (cirq .Z , num_controls = 2 ))
578- cirq .testing .assert_equivalent_repr (C0C1H )
579- cirq .testing .assert_equivalent_repr (C2C2H )
640+ @pytest .mark .parametrize (
641+ 'gate' ,
642+ [
643+ cirq .ControlledGate (cirq .Z ),
644+ cirq .ControlledGate (cirq .Z , num_controls = 1 ),
645+ cirq .ControlledGate (cirq .Z , num_controls = 2 ),
646+ C0C1H ,
647+ C2C2H ,
648+ C_01_10_11H ,
649+ C_xorH ,
650+ C_02_20H ,
651+ ],
652+ )
653+ def test_repr (gate ):
654+ cirq .testing .assert_equivalent_repr (gate )
580655
581656
582657def test_str ():
@@ -597,48 +672,3 @@ def test_controlled_mixture():
597672 c_yes = cirq .ControlledGate (sub_gate = cirq .phase_flip (0.25 ), num_controls = 1 )
598673 assert cirq .has_mixture (c_yes )
599674 assert cirq .approx_eq (cirq .mixture (c_yes ), [(0.75 , np .eye (4 )), (0.25 , cirq .unitary (cirq .CZ ))])
600-
601-
602- class MockControlValues (AbstractControlValues ):
603- def __and__ (self , other ):
604- pass
605-
606- def _expand (self ):
607- pass
608-
609- def diagram_repr (self ):
610- pass
611-
612- def _number_variables (self ):
613- pass
614-
615- def __len__ (self ):
616- return 1
617-
618- def _identifier (self ):
619- pass
620-
621- def __hash__ (self ):
622- pass
623-
624- def __repr__ (self ):
625- pass
626-
627- def validate (self , shapes ):
628- pass
629-
630- def _are_ones (self ):
631- pass
632-
633- def _json_dict_ (self ):
634- pass
635-
636-
637- def test_decompose_applies_only_to_ProductOfSums ():
638- g = cirq .ControlledGate (cirq .X , control_values = MockControlValues ())
639- assert g ._decompose_ (None ) is NotImplemented
640-
641-
642- def test_circuit_diagram_info_applies_only_to_ProductOfSums ():
643- g = cirq .ControlledGate (cirq .X , control_values = MockControlValues ())
644- assert g ._circuit_diagram_info_ (None ) is NotImplemented
0 commit comments