Skip to content

Commit d29b17e

Browse files
srvasudetensorflower-gardener
authored andcommitted
Add quantiles to Student-T, Beta and SigmoidBeta.
PiperOrigin-RevId: 473158097
1 parent b82fd08 commit d29b17e

File tree

6 files changed

+93
-6
lines changed

6 files changed

+93
-6
lines changed

tensorflow_probability/python/distributions/beta.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,12 @@ def _cdf(self, x):
332332
return distribution_util.extend_cdf_outside_support(
333333
x, answer, low=0., high=1.)
334334

335+
@distribution_util.AppendDocstring(_beta_sample_note)
336+
def _quantile(self, p):
337+
concentration1 = tf.convert_to_tensor(self.concentration1)
338+
concentration0 = tf.convert_to_tensor(self.concentration0)
339+
return special.betaincinv(concentration1, concentration0, p)
340+
335341
def _log_unnormalized_prob(self, x, concentration1, concentration0):
336342
return (tf.math.xlogy(concentration1 - 1., x) +
337343
tf.math.xlog1py(concentration0 - 1., -x))

tensorflow_probability/python/distributions/beta_test.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ def testBetaSampleMultidimensional(self):
291291

292292
@parameterized.parameters((np.float32, 5e-3), (np.float64, 1e-4))
293293
def testBetaCdf(self, dt, rtol):
294-
shape = (30, 40, 50)
294+
shape = (30, 4, 5)
295295
a = 10. * np.random.random(shape).astype(dt)
296296
b = 10. * np.random.random(shape).astype(dt)
297297
x = np.random.random(shape).astype(dt)
@@ -304,9 +304,27 @@ def testBetaCdfBeyondSupport(self):
304304
cdf = beta.Beta(2., 3., validate_args=False).cdf([-3.7, 1.03])
305305
self.assertAllEqual([0., 1.], self.evaluate(cdf))
306306

307+
@parameterized.parameters((np.float32, 5e-3), (np.float64, 1e-4))
308+
def testBetaQuantile(self, dt, rtol):
309+
shape = (30, 4, 5)
310+
a = 5. * np.random.random(shape).astype(dt)
311+
b = 5. * np.random.random(shape).astype(dt)
312+
p = np.random.uniform(low=0., high=1., size=shape).astype(dt)
313+
quantile = tf.function(beta.Beta(a, b).quantile)
314+
actual = self.evaluate(quantile(p))
315+
# Pass f64 values to avoid errors in scipy.
316+
self.assertAllClose(
317+
sp_stats.beta.ppf(
318+
p.astype(np.float64),
319+
a.astype(np.float64),
320+
b.astype(np.float64)),
321+
actual,
322+
rtol=rtol,
323+
atol=1e-10)
324+
307325
@parameterized.parameters((np.float32, 5e-3), (np.float64, 1e-4))
308326
def testBetaLogCdf(self, dt, rtol):
309-
shape = (30, 40, 50)
327+
shape = (30, 4, 5)
310328
a = 10. * np.random.random(shape).astype(dt)
311329
b = 10. * np.random.random(shape).astype(dt)
312330
x = np.random.random(shape).astype(dt)

tensorflow_probability/python/distributions/sigmoid_beta.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,5 +244,11 @@ def _cdf(self, x):
244244
sig_x = tf.math.sigmoid(x)
245245
return special.betainc(concentration1, concentration0, sig_x)
246246

247+
def _quantile(self, p):
248+
concentration1 = tf.convert_to_tensor(self.concentration1)
249+
concentration0 = tf.convert_to_tensor(self.concentration0)
250+
y = special.betaincinv(concentration1, concentration0, p)
251+
return tf.math.log(y) - tf.math.log1p(-y)
252+
247253
def _mode(self):
248254
return tf.math.log(self.concentration1 / self.concentration0)

tensorflow_probability/python/distributions/sigmoid_beta_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,20 @@ def testCDF(self):
176176
bijector=invert.Invert(sigmoid.Sigmoid())).cdf(x)
177177
self.assertAllClose(cdf, expected_cdf)
178178

179+
def testQuantile(self):
180+
batch_shape = [6, 1]
181+
a = 2. * np.ones(batch_shape, dtype=np.float32)
182+
b = 3. * np.ones(batch_shape, dtype=np.float32)
183+
p = np.logspace(-4., -0.01, 20).astype(np.float32)
184+
dist = sigmoid_beta.SigmoidBeta(a, b, validate_args=True)
185+
quantile = dist.quantile(p)
186+
self.assertEqual(quantile.shape, (6, 20))
187+
188+
expected_quantile = transformed_distribution.TransformedDistribution(
189+
distribution=beta.Beta(a, b),
190+
bijector=invert.Invert(sigmoid.Sigmoid())).quantile(p)
191+
self.assertAllClose(quantile, expected_quantile)
192+
179193
def testMode(self):
180194
a = tf.constant(1.0)
181195
b = tf.constant(2.0)

tensorflow_probability/python/distributions/student_t.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,28 @@ def cdf(x, df, loc, scale):
120120
return tf.where(y < 0., neg_cdf, 1. - neg_cdf)
121121

122122

123+
def quantile(p, df, loc, scale):
124+
"""Compute quantile function of Student T distribution.
125+
126+
Note that scale can be negative.
127+
128+
Args:
129+
p: Floating-point `Tensor`. Probabilities from 0 to 1.
130+
df: Floating-point `Tensor`. The degrees of freedom of the
131+
distribution(s). `df` must contain only positive values.
132+
loc: Floating-point `Tensor`; the location(s) of the distribution(s).
133+
scale: Floating-point `Tensor`; the scale(s) of the distribution(s).
134+
135+
Returns:
136+
A `Tensor` with shape broadcast according to the arguments.
137+
"""
138+
df = tf.convert_to_tensor(df)
139+
p_adjusted = tf.where(p < 0.5, p, 1. - p)
140+
y = special.betaincinv(0.5 * df, 0.5, 2 * p_adjusted)
141+
return loc + tf.math.abs(
142+
scale) * tf.math.sign(p - 0.5) * tf.math.sqrt(df * (1 - y) / y)
143+
144+
123145
def entropy(df, scale, batch_shape, dtype):
124146
"""Compute entropy of the StudentT distribution.
125147
@@ -350,6 +372,10 @@ def _cdf(self, x):
350372
df = tf.convert_to_tensor(self.df)
351373
return cdf(x, df, self.loc, self.scale)
352374

375+
def _quantile(self, x):
376+
df = tf.convert_to_tensor(self.df)
377+
return quantile(x, df, self.loc, self.scale)
378+
353379
def _entropy(self):
354380
df = tf.convert_to_tensor(self.df)
355381
scale = tf.convert_to_tensor(self.scale)

tensorflow_probability/python/distributions/student_t_test.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@ def testStudentPDFAndLogPDF(self):
4242
student = student_t.StudentT(df, loc=mu, scale=-sigma, validate_args=True)
4343

4444
log_pdf = student.log_prob(t)
45-
self.assertEquals(log_pdf.shape, (6,))
45+
self.assertAllEqual(log_pdf.shape, (6,))
4646
log_pdf_values = self.evaluate(log_pdf)
4747
pdf = student.prob(t)
48-
self.assertEquals(pdf.shape, (6,))
48+
self.assertAllEqual(pdf.shape, (6,))
4949
pdf_values = self.evaluate(pdf)
5050

5151
expected_log_pdf = sp_stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
@@ -92,10 +92,10 @@ def testStudentCDFAndLogCDF(self):
9292
student = student_t.StudentT(df, loc=mu, scale=sigma, validate_args=True)
9393

9494
log_cdf = student.log_cdf(t)
95-
self.assertEquals(log_cdf.shape, (6,))
95+
self.assertAllEqual(log_cdf.shape, (6,))
9696
log_cdf_values = self.evaluate(log_cdf)
9797
cdf = student.cdf(t)
98-
self.assertEquals(cdf.shape, (6,))
98+
self.assertAllEqual(cdf.shape, (6,))
9999
cdf_values = self.evaluate(cdf)
100100

101101
expected_log_cdf = sp_stats.t.logcdf(t, df_v, loc=mu_v, scale=sigma_v)
@@ -107,6 +107,23 @@ def testStudentCDFAndLogCDF(self):
107107
self.assertAllClose(
108108
np.exp(expected_log_cdf), cdf_values, atol=0., rtol=1e-5)
109109

110+
def testStudentQuantile(self):
111+
batch_shape = (40, 1)
112+
df = np.random.uniform(
113+
low=0.5, high=10., size=batch_shape).astype(np.float32)
114+
mu = 7. * np.ones(batch_shape, dtype=np.float32)
115+
sigma = -8.
116+
p = np.logspace(-4., -0.01, 20).astype(np.float32)
117+
student = student_t.StudentT(
118+
df, loc=mu, scale=sigma, validate_args=True)
119+
120+
quantile = student.quantile(p)
121+
self.assertAllEqual(quantile.shape, (40, 20))
122+
quantile_values = self.evaluate(quantile)
123+
124+
expected_quantile = sp_stats.t.ppf(p, df, loc=mu, scale=np.abs(sigma))
125+
self.assertAllClose(expected_quantile, quantile_values, rtol=5e-4)
126+
110127
def testStudentEntropy(self):
111128
df_v = np.array([[2., 3., 7.]]) # 1x3
112129
mu_v = np.array([[1., -1, 0]]) # 1x3

0 commit comments

Comments
 (0)