Skip to content

Commit fa99585

Browse files
authored
Add symmetric depol (#3361)
For issue #3220
1 parent 730cc24 commit fa99585

File tree

4 files changed

+151
-43
lines changed

4 files changed

+151
-43
lines changed

cirq/ops/common_channels.py

Lines changed: 70 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""Quantum channels that are commonly used in the literature."""
1616

17+
import itertools
1718
from typing import (Any, Dict, Iterable, Optional, Sequence, Tuple, Union,
1819
TYPE_CHECKING)
1920

@@ -77,8 +78,13 @@ def __init__(self,
7778
raise ValueError(f"{k} must have {num_qubits} Pauli gates.")
7879
for k, v in error_probabilities.items():
7980
value.validate_probability(v, f"p({k})")
80-
value.validate_probability(sum(error_probabilities.values()),
81-
'sum(error_probabilities)')
81+
sum_probs = sum(error_probabilities.values())
82+
# TODO(tonybruguier): Instead of forcing the probabilities to add up
83+
# to 1, check whether the identity is missing, and if that is the
84+
# case, automatically add it with the missing probability mass.
85+
if abs(sum_probs - 1.0) > 1e-6:
86+
raise ValueError(
87+
f"Probabilities do not add up to 1 but to {sum_probs}")
8288
self._num_qubits = num_qubits
8389
self._error_probabilities = error_probabilities
8490
else:
@@ -227,33 +233,53 @@ def asymmetric_depolarize(p_x: Optional[float] = None,
227233
class DepolarizingChannel(gate_features.SingleQubitGate):
228234
"""A channel that depolarizes a qubit."""
229235

230-
def __init__(self, p: float) -> None:
236+
def __init__(self, p: float, n_qubits: int = 1) -> None:
231237
r"""The symmetric depolarizing channel.
232238
233-
This channel applies one of four disjoint possibilities: nothing (the
234-
identity channel) or one of the three pauli gates. The disjoint
235-
probabilities of the three gates are all the same, p / 3, and the
236-
identity is done with probability 1 - p. The supplied probability
237-
must be a valid probability or else this constructor will raise a
238-
ValueError.
239+
This channel applies one of 4**n disjoint possibilities: nothing (the
240+
identity channel) or one of the 4**n - 1 pauli gates. The disjoint
241+
probabilities of the non-identity Pauli gates are all the same,
242+
p / (4**n - 1), and the identity is done with probability 1 - p. The
243+
supplied probability must be a valid probability or else this
244+
constructor will raise a ValueError.
245+
239246
240247
This channel evolves a density matrix via
241248
242249
$$
243250
\rho \rightarrow (1 - p) \rho
244-
+ (p / 3) X \rho X + (p / 3) Y \rho Y + (p / 3) Z \rho Z
251+
+ 1 / (4**n - 1) \sum _i P_i X P_i
245252
$$
246253
254+
where P_i are the 4**n - 1 Pauli gates (excluding the identity).
255+
247256
Args:
248257
p: The probability that one of the Pauli gates is applied. Each of
249-
the Pauli gates is applied independently with probability p / 3.
258+
the Pauli gates is applied independently with probability
259+
p / (4**n - 1).
260+
n_qubits: the number of qubits.
250261
251262
Raises:
252263
ValueError: if p is not a valid probability.
253264
"""
254265

266+
error_probabilities = {}
267+
268+
p_depol = p / (4**n_qubits - 1)
269+
p_identity = 1.0 - p
270+
for pauli_tuple in itertools.product(['I', 'X', 'Y', 'Z'],
271+
repeat=n_qubits):
272+
pauli_string = ''.join(pauli_tuple)
273+
if pauli_string == 'I' * n_qubits:
274+
error_probabilities[pauli_string] = p_identity
275+
else:
276+
error_probabilities[pauli_string] = p_depol
277+
255278
self._p = p
256-
self._delegate = AsymmetricDepolarizingChannel(p / 3, p / 3, p / 3)
279+
self._n_qubits = n_qubits
280+
281+
self._delegate = AsymmetricDepolarizingChannel(
282+
error_probabilities=error_probabilities)
257283

258284
def _mixture_(self) -> Sequence[Tuple[float, np.ndarray]]:
259285
return self._delegate._mixture_()
@@ -265,55 +291,70 @@ def _value_equality_values_(self):
265291
return self._p
266292

267293
def __repr__(self) -> str:
268-
return 'cirq.depolarize(p={!r})'.format(self._p)
294+
if self._n_qubits == 1:
295+
return f"cirq.depolarize(p={self._p})"
296+
return f"cirq.depolarize(p={self._p},n_qubits={self._n_qubits})"
269297

270298
def __str__(self) -> str:
271-
return 'depolarize(p={!r})'.format(self._p)
299+
if self._n_qubits == 1:
300+
return f"depolarize(p={self._p})"
301+
return f"depolarize(p={self._p},n_qubits={self._n_qubits})"
272302

273303
def _circuit_diagram_info_(self,
274304
args: 'protocols.CircuitDiagramInfoArgs') -> str:
275305
if args.precision is not None:
276-
f = '{:.' + str(args.precision) + 'g}'
277-
return 'D({})'.format(f).format(self._p)
278-
return 'D({!r})'.format(self._p)
306+
return f"D({self._p:.{args.precision}g})"
307+
return f"D({self._p})"
279308

280309
@property
281310
def p(self) -> float:
282311
"""The probability that one of the Pauli gates is applied.
283312
284-
Each of the Pauli gates is applied independently with probability p / 3.
313+
Each of the Pauli gates is applied independently with probability
314+
p / (4**n_qubits - 1).
285315
"""
286316
return self._p
287317

318+
@property
319+
def n_qubits(self) -> int:
320+
"""The number of qubits"""
321+
return self._n_qubits
322+
288323
def _json_dict_(self) -> Dict[str, Any]:
289-
return protocols.obj_to_dict_helper(self, ['p'])
324+
if self._n_qubits == 1:
325+
return protocols.obj_to_dict_helper(self, ['p'])
326+
return protocols.obj_to_dict_helper(self, ['p', 'n_qubits'])
290327

291328

292-
def depolarize(p: float) -> DepolarizingChannel:
329+
def depolarize(p: float, n_qubits: int = 1) -> DepolarizingChannel:
293330
r"""Returns a DepolarizingChannel with given probability of error.
294331
295-
This channel applies one of four disjoint possibilities: nothing (the
296-
identity channel) or one of the three pauli gates. The disjoint
297-
probabilities of the three gates are all the same, p / 3, and the
298-
identity is done with probability 1 - p. The supplied probability
299-
must be a valid probability or else this constructor will raise a
300-
ValueError.
332+
This channel applies one of 4**n disjoint possibilities: nothing (the
333+
identity channel) or one of the 4**n - 1 pauli gates. The disjoint
334+
probabilities of the non-identity Pauli gates are all the same,
335+
p / (4**n - 1), and the identity is done with probability 1 - p. The
336+
supplied probability must be a valid probability or else this constructor
337+
will raise a ValueError.
301338
302339
This channel evolves a density matrix via
303340
304341
$$
305342
\rho \rightarrow (1 - p) \rho
306-
+ (p / 3) X \rho X + (p / 3) Y \rho Y + (p / 3) Z \rho Z
343+
+ 1 / (4**n - 1) \sum _i P_i X P_i
307344
$$
308345
346+
where P_i are the 4**n - 1 Pauli gates (excluding the identity).
347+
309348
Args:
310349
p: The probability that one of the Pauli gates is applied. Each of
311-
the Pauli gates is applied independently with probability p / 3.
350+
the Pauli gates is applied independently with probability
351+
p / (4**n - 1).
352+
n_qubits: The number of qubits.
312353
313354
Raises:
314355
ValueError: if p is not a valid probability.
315356
"""
316-
return DepolarizingChannel(p)
357+
return DepolarizingChannel(p, n_qubits)
317358

318359

319360
@value.value_equality

cirq/ops/common_channels_test.py

Lines changed: 75 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -131,31 +131,77 @@ def test_asymmetric_depolarizing_channel_text_diagram():
131131

132132
def test_depolarizing_channel():
133133
d = cirq.depolarize(0.3)
134-
np.testing.assert_almost_equal(cirq.channel(d),
135-
(np.sqrt(0.7) * np.eye(2),
136-
np.sqrt(0.1) * X,
137-
np.sqrt(0.1) * Y,
138-
np.sqrt(0.1) * Z))
134+
np.testing.assert_almost_equal(cirq.channel(d), (
135+
np.sqrt(0.7) * np.eye(2),
136+
np.sqrt(0.1) * X,
137+
np.sqrt(0.1) * Y,
138+
np.sqrt(0.1) * Z,
139+
))
140+
assert cirq.has_channel(d)
141+
142+
143+
def test_depolarizing_channel_two_qubits():
144+
d = cirq.depolarize(0.15, n_qubits=2)
145+
np.testing.assert_almost_equal(cirq.channel(d), (
146+
np.sqrt(0.85) * np.eye(4),
147+
np.sqrt(0.01) * np.kron(np.eye(2), X),
148+
np.sqrt(0.01) * np.kron(np.eye(2), Y),
149+
np.sqrt(0.01) * np.kron(np.eye(2), Z),
150+
np.sqrt(0.01) * np.kron(X, np.eye(2)),
151+
np.sqrt(0.01) * np.kron(X, X),
152+
np.sqrt(0.01) * np.kron(X, Y),
153+
np.sqrt(0.01) * np.kron(X, Z),
154+
np.sqrt(0.01) * np.kron(Y, np.eye(2)),
155+
np.sqrt(0.01) * np.kron(Y, X),
156+
np.sqrt(0.01) * np.kron(Y, Y),
157+
np.sqrt(0.01) * np.kron(Y, Z),
158+
np.sqrt(0.01) * np.kron(Z, np.eye(2)),
159+
np.sqrt(0.01) * np.kron(Z, X),
160+
np.sqrt(0.01) * np.kron(Z, Y),
161+
np.sqrt(0.01) * np.kron(Z, Z),
162+
))
139163
assert cirq.has_channel(d)
140164

141165
def test_depolarizing_mixture():
142166
d = cirq.depolarize(0.3)
143167
assert_mixtures_equal(cirq.mixture(d),
144-
((0.7, np.eye(2)),
145-
(0.1, X),
146-
(0.1, Y),
147-
(0.1, Z)))
168+
((0.7, np.eye(2)), (0.1, X), (0.1, Y), (0.1, Z)))
169+
assert cirq.has_mixture(d)
170+
171+
172+
def test_depolarizing_mixture_two_qubits():
173+
d = cirq.depolarize(0.15, n_qubits=2)
174+
assert_mixtures_equal(cirq.mixture(d),
175+
((0.85, np.eye(4)), (0.01, np.kron(np.eye(2), X)),
176+
(0.01, np.kron(np.eye(2), Y)),
177+
(0.01, np.kron(np.eye(2), Z)),
178+
(0.01, np.kron(X, np.eye(2))), (0.01, np.kron(X, X)),
179+
(0.01, np.kron(X, Y)), (0.01, np.kron(X, Z)),
180+
(0.01, np.kron(Y, np.eye(2))), (0.01, np.kron(Y, X)),
181+
(0.01, np.kron(Y, Y)), (0.01, np.kron(Y, Z)),
182+
(0.01, np.kron(Z, np.eye(2))), (0.01, np.kron(Z, X)),
183+
(0.01, np.kron(Z, Y)), (0.01, np.kron(Z, Z))))
148184
assert cirq.has_mixture(d)
149185

150186

151187
def test_depolarizing_channel_repr():
152188
cirq.testing.assert_equivalent_repr(cirq.DepolarizingChannel(0.3))
153189

154190

191+
def test_depolarizing_channel_repr_two_qubits():
192+
cirq.testing.assert_equivalent_repr(
193+
cirq.DepolarizingChannel(0.3, n_qubits=2))
194+
195+
155196
def test_depolarizing_channel_str():
156197
assert str(cirq.depolarize(0.3)) == 'depolarize(p=0.3)'
157198

158199

200+
def test_depolarizing_channel_str_two_qubits():
201+
assert str(cirq.depolarize(0.3,
202+
n_qubits=2)) == 'depolarize(p=0.3,n_qubits=2)'
203+
204+
159205
def test_depolarizing_channel_eq():
160206
et = cirq.testing.EqualsTester()
161207
c = cirq.depolarize(0.0)
@@ -166,9 +212,9 @@ def test_depolarizing_channel_eq():
166212

167213

168214
def test_depolarizing_channel_invalid_probability():
169-
with pytest.raises(ValueError, match='was less than 0'):
215+
with pytest.raises(ValueError, match=re.escape('p(I) was greater than 1.')):
170216
cirq.depolarize(-0.1)
171-
with pytest.raises(ValueError, match='was greater than 1'):
217+
with pytest.raises(ValueError, match=re.escape('p(I) was less than 0.')):
172218
cirq.depolarize(1.1)
173219

174220

@@ -180,6 +226,22 @@ def test_depolarizing_channel_text_diagram():
180226
assert (cirq.circuit_diagram_info(
181227
d, args=round_to_2_prec) == cirq.CircuitDiagramInfo(
182228
wire_symbols=('D(0.12)',)))
229+
assert (cirq.circuit_diagram_info(
230+
d, args=no_precision) == cirq.CircuitDiagramInfo(
231+
wire_symbols=('D(0.1234567)',)))
232+
233+
234+
def test_depolarizing_channel_text_diagram_two_qubits():
235+
d = cirq.depolarize(0.1234567, n_qubits=2)
236+
assert (cirq.circuit_diagram_info(
237+
d, args=round_to_6_prec) == cirq.CircuitDiagramInfo(
238+
wire_symbols=('D(0.123457)',)))
239+
assert (cirq.circuit_diagram_info(
240+
d, args=round_to_2_prec) == cirq.CircuitDiagramInfo(
241+
wire_symbols=('D(0.12)',)))
242+
assert (cirq.circuit_diagram_info(
243+
d, args=no_precision) == cirq.CircuitDiagramInfo(
244+
wire_symbols=('D(0.1234567)',)))
183245

184246

185247
def test_generalized_amplitude_damping_channel():
@@ -539,9 +601,8 @@ def test_bad_error_probabilities_gate():
539601
def test_bad_probs():
540602
with pytest.raises(ValueError, match=re.escape('p(X) was greater than 1.')):
541603
cirq.asymmetric_depolarize(error_probabilities={'X': 1.1, 'Y': -0.1})
542-
with pytest.raises(
543-
ValueError,
544-
match=re.escape('sum(error_probabilities) was greater than 1.')):
604+
with pytest.raises(ValueError,
605+
match=re.escape('Probabilities do not add up to 1')):
545606
cirq.asymmetric_depolarize(error_probabilities={'X': 0.7, 'Y': 0.6})
546607

547608

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"cirq_type": "DepolarizingChannel",
3+
"p": 0.5,
4+
"n_qubits": 2
5+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
cirq.depolarize(p=0.5,n_qubits=2)

0 commit comments

Comments
 (0)