Skip to content

Commit bba0153

Browse files
authored
Execute TODO to have the prob mass auto added (#3363)
Issue #3220
1 parent fa99585 commit bba0153

File tree

2 files changed

+29
-14
lines changed

2 files changed

+29
-14
lines changed

cirq/ops/common_channels.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ def __init__(self,
3636
p_x: Optional[float] = None,
3737
p_y: Optional[float] = None,
3838
p_z: Optional[float] = None,
39-
error_probabilities: Optional[Dict[str, float]] = None
40-
) -> None:
39+
error_probabilities: Optional[Dict[str, float]] = None,
40+
tol: float = 1e-8) -> None:
4141
r"""The asymmetric depolarizing channel.
4242
4343
This channel applies one of 4**n disjoint possibilities: nothing (the
@@ -58,7 +58,10 @@ def __init__(self,
5858
p_y: The probability that a Pauli Y and no other gate occurs.
5959
p_z: The probability that a Pauli Z and no other gate occurs.
6060
error_probabilities: Dictionary of string (Pauli operator) to its
61-
probability
61+
probability. If the identity is missing from the list, it will
62+
be added so that the total probability mass is 1.
63+
tol: The tolerance used making sure the total probability mass is
64+
equal to 1.
6265
6366
Examples of calls:
6467
* Single qubit: AsymmetricDepolarizingChannel(0.2, 0.1, 0.3)
@@ -79,10 +82,10 @@ def __init__(self,
7982
for k, v in error_probabilities.items():
8083
value.validate_probability(v, f"p({k})")
8184
sum_probs = sum(error_probabilities.values())
82-
# TODO(tonybruguier): Instead of forcing the probabilities to add up
83-
# to 1, check whether the identity is missing, and if that is the
84-
# case, automatically add it with the missing probability mass.
85-
if abs(sum_probs - 1.0) > 1e-6:
85+
identity = 'I' * num_qubits
86+
if sum_probs < 1.0 - tol and identity not in error_probabilities:
87+
error_probabilities[identity] = 1.0 - sum_probs
88+
elif abs(sum_probs - 1.0) > tol:
8689
raise ValueError(
8790
f"Probabilities do not add up to 1 but to {sum_probs}")
8891
self._num_qubits = num_qubits
@@ -190,11 +193,12 @@ def _json_dict_(self) -> Dict[str, Any]:
190193
return protocols.obj_to_dict_helper(self, ['error_probabilities'])
191194

192195

193-
def asymmetric_depolarize(p_x: Optional[float] = None,
194-
p_y: Optional[float] = None,
195-
p_z: Optional[float] = None,
196-
error_probabilities: Optional[Dict[str, float]] = None
197-
) -> AsymmetricDepolarizingChannel:
196+
def asymmetric_depolarize(
197+
p_x: Optional[float] = None,
198+
p_y: Optional[float] = None,
199+
p_z: Optional[float] = None,
200+
error_probabilities: Optional[Dict[str, float]] = None,
201+
tol: float = 1e-8) -> AsymmetricDepolarizingChannel:
198202
r"""Returns a AsymmetricDepolarizingChannel with given parameter.
199203
200204
This channel applies one of 4**n disjoint possibilities: nothing (the
@@ -215,7 +219,10 @@ def asymmetric_depolarize(p_x: Optional[float] = None,
215219
p_y: The probability that a Pauli Y and no other gate occurs.
216220
p_z: The probability that a Pauli Z and no other gate occurs.
217221
error_probabilities: Dictionary of string (Pauli operator) to its
218-
probability
222+
probability. If the identity is missing from the list, it will
223+
be added so that the total probability mass is 1.
224+
tol: The tolerance used making sure the total probability mass is
225+
equal to 1.
219226
220227
Examples of calls:
221228
* Single qubit: AsymmetricDepolarizingChannel(0.2, 0.1, 0.3)
@@ -226,7 +233,8 @@ def asymmetric_depolarize(p_x: Optional[float] = None,
226233
Raises:
227234
ValueError: if the args or the sum of the args are not probabilities.
228235
"""
229-
return AsymmetricDepolarizingChannel(p_x, p_y, p_z, error_probabilities)
236+
return AsymmetricDepolarizingChannel(p_x, p_y, p_z, error_probabilities,
237+
tol)
230238

231239

232240
@value.value_equality

cirq/ops/common_channels_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,13 @@ def test_bad_probs():
606606
cirq.asymmetric_depolarize(error_probabilities={'X': 0.7, 'Y': 0.6})
607607

608608

609+
def test_missing_prob_mass():
610+
with pytest.raises(ValueError, match='Probabilities do not add up to 1'):
611+
cirq.asymmetric_depolarize(error_probabilities={'X': 0.1, 'I': 0.2})
612+
d = cirq.asymmetric_depolarize(error_probabilities={'X': 0.1})
613+
np.testing.assert_almost_equal(d.error_probabilities['I'], 0.9)
614+
615+
609616
def test_multi_asymmetric_depolarizing_channel():
610617
d = cirq.asymmetric_depolarize(error_probabilities={'II': 0.8, 'XX': 0.2})
611618
np.testing.assert_almost_equal(

0 commit comments

Comments
 (0)