Skip to content

Commit ae0a1f5

Browse files
srvasudetensorflower-gardener
authored andcommitted
Speed up several tests:
- In special_test, ensure that we compute partial derivatives in one @tf.function instead of 3 of them. - In generic_test, disable eager mode tests for KahanSumJitTest - Disable several MCMC tests in eager mode. - Disable ODE Gradient tests in eager mode (some of these tests could take 10 mins). - Changed bessel_test and two_piece_normal_test to sample O(1000) samples instead of O(100000) per test. PiperOrigin-RevId: 474695632
1 parent a7a5742 commit ae0a1f5

File tree

9 files changed

+119
-139
lines changed

9 files changed

+119
-139
lines changed

tensorflow_probability/python/distributions/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4503,8 +4503,10 @@ multi_substrate_py_test(
45034503
name = "two_piece_normal_test",
45044504
size = "medium",
45054505
srcs = ["two_piece_normal_test.py"],
4506+
shard_count = 4,
45064507
deps = [
45074508
":two_piece_normal",
4509+
# absl/testing:parameterized dep,
45084510
# numpy dep,
45094511
# tensorflow dep,
45104512
"//tensorflow_probability/python/internal:test_util",

tensorflow_probability/python/distributions/two_piece_normal_test.py

Lines changed: 42 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import itertools
1818

1919
# Dependency imports
20+
from absl.testing import parameterized
2021
import numpy as np
2122

2223
import tensorflow.compat.v2 as tf
@@ -28,7 +29,7 @@
2829

2930

3031
@test_util.test_all_tf_execution_regimes
31-
class _TwoPieceNormalTest(object):
32+
class _TwoPieceNormalTest(parameterized.TestCase):
3233

3334
def make_two_piece_normal(self):
3435
if self.dtype is np.float32:
@@ -280,7 +281,8 @@ def testMode(self):
280281
self.assertAllClose(mode, expected_mode)
281282

282283
@test_util.numpy_disable_gradient_test
283-
def testFiniteGradientAtDifficultPoints(self):
284+
@parameterized.parameters(0.75, 1., 1.33)
285+
def testFiniteGradientAtDifficultPoints(self, skewness):
284286
def make_fn(attr):
285287
x = np.array([-100, -20, -5., 0., 5., 20, 100]).astype(self.dtype)
286288
return lambda m, s, g: getattr( # pylint: disable=g-long-lambda
@@ -300,19 +302,19 @@ def make_fn(attr):
300302
# * Implementing the cdf method using the Gamma distribution function; and
301303
# * Implementing the cdf method using the Student's t distribution function
302304
# when value < loc.
303-
for skewness in [0.75, 1., 1.33]:
304-
for attr in ('prob', 'cdf', 'survival_function', 'log_prob'):
305-
value, grads = self.evaluate(
306-
gradient.value_and_gradient(
307-
make_fn(attr),
308-
[loc, scale, tf.constant(skewness, self.dtype)]))
309-
self.assertAllFinite(value)
310-
self.assertAllFinite(grads[0]) # d/d loc
311-
self.assertAllFinite(grads[1]) # d/d scale
312-
self.assertAllFinite(grads[2]) # d/d skewness
305+
for attr in ('prob', 'cdf', 'survival_function', 'log_prob'):
306+
value, grads = self.evaluate(
307+
gradient.value_and_gradient(
308+
make_fn(attr),
309+
[loc, scale, tf.constant(skewness, self.dtype)]))
310+
self.assertAllFinite(value)
311+
self.assertAllFinite(grads[0]) # d/d loc
312+
self.assertAllFinite(grads[1]) # d/d scale
313+
self.assertAllFinite(grads[2]) # d/d skewness
313314

314315
@test_util.numpy_disable_gradient_test
315-
def testQuantileFiniteGradientAtDifficultPoints(self):
316+
@parameterized.parameters(0.75, 1., 1.33)
317+
def testQuantileFiniteGradientAtDifficultPoints(self, skewness):
316318
def quantile(loc, scale, skewness, probs):
317319
dist = two_piece_normal.TwoPieceNormal(
318320
loc, scale=scale, skewness=skewness, validate_args=True)
@@ -325,18 +327,18 @@ def quantile(loc, scale, skewness, probs):
325327
[np.exp(x), np.exp(-2.), 1. - np.exp(-2.), 1. - np.exp(x)],
326328
dtype=self.dtype)
327329

328-
for skewness in [0.75, 1., 1.33]:
329-
value, grads = gradient.value_and_gradient(
330-
quantile,
331-
[loc, scale, tf.constant(skewness, self.dtype), probs])
332-
self.assertAllFinite(value)
333-
self.assertAllFinite(grads[0]) # d/d loc
334-
self.assertAllFinite(grads[1]) # d/d scale
335-
self.assertAllFinite(grads[2]) # d/d skewness
336-
self.assertAllFinite(grads[3]) # d/d probs
330+
value, grads = gradient.value_and_gradient(
331+
quantile,
332+
[loc, scale, tf.constant(skewness, self.dtype), probs])
333+
self.assertAllFinite(value)
334+
self.assertAllFinite(grads[0]) # d/d loc
335+
self.assertAllFinite(grads[1]) # d/d scale
336+
self.assertAllFinite(grads[2]) # d/d skewness
337+
self.assertAllFinite(grads[3]) # d/d probs
337338

338339
@test_util.numpy_disable_gradient_test
339-
def testFullyReparameterized(self):
340+
@parameterized.parameters(0.75, 1., 1.33)
341+
def testFullyReparameterized(self, skewness):
340342
n = 100
341343
def sampler(loc, scale, skewness):
342344
dist = two_piece_normal.TwoPieceNormal(
@@ -346,18 +348,18 @@ def sampler(loc, scale, skewness):
346348
loc = tf.constant(0., self.dtype)
347349
scale = tf.constant(1., self.dtype)
348350

349-
for skewness in [0.75, 1., 1.33]:
350-
_, grads = gradient.value_and_gradient(
351-
sampler, [loc, scale, tf.constant(skewness, self.dtype)])
352-
self.assertIsNotNone(grads[0]) # d/d loc
353-
self.assertIsNotNone(grads[1]) # d/d scale
354-
self.assertIsNotNone(grads[2]) # d/d skewness
351+
_, grads = gradient.value_and_gradient(
352+
sampler, [loc, scale, tf.constant(skewness, self.dtype)])
353+
self.assertIsNotNone(grads[0]) # d/d loc
354+
self.assertIsNotNone(grads[1]) # d/d scale
355+
self.assertIsNotNone(grads[2]) # d/d skewness
355356

356357
@test_util.numpy_disable_gradient_test
357-
def testDifferentiableSampleNumerically(self):
358+
@parameterized.parameters(0.75, 1., 1.33)
359+
def testDifferentiableSampleNumerically(self, skewness):
358360
"""Test the gradients of the samples w.r.t. skewness."""
359-
sample_shape = [int(2e5)]
360-
seed = test_util.test_seed()
361+
sample_shape = [int(2e3)]
362+
seed = test_util.test_seed(sampler_type='stateless')
361363

362364
def get_abs_sample_mean(skewness):
363365
loc = tf.constant(0., self.dtype)
@@ -366,15 +368,15 @@ def get_abs_sample_mean(skewness):
366368
loc, scale=scale, skewness=skewness, validate_args=True)
367369
return tf.reduce_mean(tf.abs(dist.sample(sample_shape, seed=seed)))
368370

369-
for skewness in [0.75, 1., 1.33]:
370-
err = self.compute_max_gradient_error(
371-
get_abs_sample_mean, [tf.constant(skewness, self.dtype)], delta=0.1)
372-
self.assertLess(err, 0.05)
371+
err = self.compute_max_gradient_error(
372+
get_abs_sample_mean, [tf.constant(skewness, self.dtype)], delta=1e-1)
373+
maxerr = 0.05 if self.dtype == np.float64 else 0.09
374+
self.assertLess(err, maxerr)
373375

374376
@test_util.numpy_disable_gradient_test
375377
def testDifferentiableSampleAnalytically(self):
376378
"""Test the gradients of the samples w.r.t. loc and scale."""
377-
n = 100
379+
n = 10
378380
sample_shape = [n, n]
379381
n_samples = np.prod(sample_shape)
380382

@@ -453,7 +455,7 @@ def testIncompatibleArgShapesGraph(self):
453455
tf.ones([2, 3], dtype=self.dtype), shape=tf.TensorShape(None))
454456
self.evaluate(skewness.initializer)
455457

456-
with self.assertRaisesRegexp(Exception, r'compatible shapes'):
458+
with self.assertRaisesRegex(Exception, r'compatible shapes'):
457459
dist = two_piece_normal.TwoPieceNormal(
458460
loc=tf.zeros([4, 1], dtype=self.dtype),
459461
scale=tf.ones([4, 1], dtype=self.dtype),
@@ -505,6 +507,8 @@ class TwoPieceNormalTestDynamicShapeFloat64(test_util.TestCase,
505507
dtype = np.float64
506508
use_static_shape = False
507509

510+
del _TwoPieceNormalTest
511+
508512

509513
if __name__ == '__main__':
510514
test_util.main()

0 commit comments

Comments
 (0)