Skip to content

Commit b82fd08

Browse files
srvasudetensorflower-gardener
authored andcommitted
Speed up special_test:
- Use fewer samples for tests (commonly from 100000 to 1000). In order to catch errors, we can run the tests multiple times with different seeds, so the extra sampling isn't completely necessary. PiperOrigin-RevId: 473095224
1 parent 58a7885 commit b82fd08

File tree

1 file changed

+55
-58
lines changed

1 file changed

+55
-58
lines changed

tensorflow_probability/python/math/special_test.py

Lines changed: 55 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -231,27 +231,23 @@ def testBetaincBFloat16(self):
231231
@parameterized.parameters(np.float32, np.float64)
232232
@test_util.numpy_disable_gradient_test
233233
def testBetaincGradient(self, dtype):
234-
space = np.logspace(-2., 2., 10).tolist()
234+
space = np.logspace(-2., 2., 5).tolist()
235235
space_x = np.linspace(0.01, 0.99, 10).tolist()
236236
a, b, x = zip(*list(itertools.product(space, space, space_x)))
237237

238238
a = np.array(a, dtype=dtype)
239239
b = np.array(b, dtype=dtype)
240240
x = np.array(x, dtype=dtype)
241241

242-
# Wrap in tf.function and compile for faster computations.
243-
betainc = tf.function(special.betainc, autograph=False, jit_compile=True)
242+
# Wrap in tf.function for faster computations.
243+
betainc = tf.function(special.betainc, autograph=False)
244244

245245
delta = 1e-4 if dtype == np.float64 else 1e-3
246246
tolerance = 7e-3 if dtype == np.float64 else 7e-2
247247
tolerance_x = 1e-3 if dtype == np.float64 else 1e-1
248248

249249
err = self.compute_max_gradient_error(
250-
lambda z: betainc(z, b, x), [a], delta=delta)
251-
self.assertLess(err, tolerance)
252-
253-
err = self.compute_max_gradient_error(
254-
lambda z: betainc(a, z, x), [b], delta=delta)
250+
lambda r, s: betainc(r, s, x), [a, b], delta=delta)
255251
self.assertLess(err, tolerance)
256252

257253
err = self.compute_max_gradient_error(
@@ -263,8 +259,8 @@ def testBetaincGradient(self, dtype):
263259
def testBetaincDerivativeFinite(self, dtype):
264260
eps = np.finfo(dtype).eps
265261

266-
space = np.logspace(np.log10(eps), 5.).tolist()
267-
space_x = np.linspace(eps, 1. - eps).tolist()
262+
space = np.logspace(np.log10(eps), 5., 20).tolist()
263+
space_x = np.linspace(eps, 1. - eps, 20).tolist()
268264
a, b, x = zip(*list(itertools.product(space, space, space_x)))
269265

270266
a = np.array(a, dtype=dtype)
@@ -699,8 +695,8 @@ def testBetaincinvGradientFinite(self, dtype):
699695
eps = np.finfo(dtype).eps
700696
small = np.sqrt(eps)
701697

702-
space = np.logspace(np.log10(small), 4.).tolist()
703-
space_y = np.linspace(eps, 1. - small).tolist()
698+
space = np.logspace(np.log10(small), 4., 20).tolist()
699+
space_y = np.linspace(eps, 1. - small, 20).tolist()
704700
a, b, y = [
705701
tf.constant(z, dtype=dtype)
706702
for z in zip(*list(itertools.product(space, space, space_y)))]
@@ -822,7 +818,7 @@ def testDawsnOdd(self, dtype):
822818
seed_stream = test_util.test_seed_stream()
823819
x = self.evaluate(
824820
tf.random.uniform(
825-
[int(1e4)], 0., 100., dtype=dtype, seed=seed_stream()))
821+
[int(1e3)], 0., 100., dtype=dtype, seed=seed_stream()))
826822
self.assertAllClose(
827823
self.evaluate(special.dawsn(x)), self.evaluate(-special.dawsn(-x)))
828824

@@ -831,21 +827,21 @@ def testDawsnSmall(self, dtype):
831827
seed_stream = test_util.test_seed_stream()
832828
x = self.evaluate(
833829
tf.random.uniform(
834-
[int(1e4)], 0., 1., dtype=dtype, seed=seed_stream()))
830+
[int(1e3)], 0., 1., dtype=dtype, seed=seed_stream()))
835831
self.assertAllClose(scipy_special.dawsn(x), self.evaluate(special.dawsn(x)))
836832

837833
@parameterized.parameters(np.float32, np.float64)
838834
def testDawsnMedium(self, dtype):
839835
seed_stream = test_util.test_seed_stream()
840836
x = self.evaluate(
841-
tf.random.uniform([int(1e4)], 1., 10., dtype=dtype, seed=seed_stream()))
837+
tf.random.uniform([int(1e3)], 1., 10., dtype=dtype, seed=seed_stream()))
842838
self.assertAllClose(scipy_special.dawsn(x), self.evaluate(special.dawsn(x)))
843839

844840
@parameterized.parameters(np.float32, np.float64)
845841
def testDawsnLarge(self, dtype):
846842
seed_stream = test_util.test_seed_stream()
847843
x = self.evaluate(tf.random.uniform(
848-
[int(1e4)], 10., 100., dtype=dtype, seed=seed_stream()))
844+
[int(1e3)], 10., 100., dtype=dtype, seed=seed_stream()))
849845
self.assertAllClose(scipy_special.dawsn(x), self.evaluate(special.dawsn(x)))
850846

851847
@test_util.numpy_disable_gradient_test
@@ -885,50 +881,50 @@ def test_igammainv_bounds(self):
885881
@parameterized.parameters((np.float32, 1.5e-4), (np.float64, 1e-6))
886882
def test_igammainv_inverse_small_a(self, dtype, rtol):
887883
seed_stream = test_util.test_seed_stream()
888-
a = tf.random.uniform([int(1e4)], 0., 1., dtype=dtype, seed=seed_stream())
889-
p = tf.random.uniform([int(1e4)], 0., 1., dtype=dtype, seed=seed_stream())
884+
a = tf.random.uniform([int(1e3)], 0., 1., dtype=dtype, seed=seed_stream())
885+
p = tf.random.uniform([int(1e3)], 0., 1., dtype=dtype, seed=seed_stream())
890886
igammainv, a, p = self.evaluate([special.igammainv(a, p), a, p])
891887
self.assertAllClose(scipy_special.gammaincinv(a, p), igammainv, rtol=rtol)
892888

893889
@parameterized.parameters((np.float32, 1.5e-4), (np.float64, 1e-6))
894890
def test_igammacinv_inverse_small_a(self, dtype, rtol):
895891
seed_stream = test_util.test_seed_stream()
896-
a = tf.random.uniform([int(1e4)], 0., 1., dtype=dtype, seed=seed_stream())
897-
p = tf.random.uniform([int(1e4)], 0., 1., dtype=dtype, seed=seed_stream())
892+
a = tf.random.uniform([int(1e3)], 0., 1., dtype=dtype, seed=seed_stream())
893+
p = tf.random.uniform([int(1e3)], 0., 1., dtype=dtype, seed=seed_stream())
898894
igammacinv, a, p = self.evaluate([special.igammacinv(a, p), a, p])
899895
self.assertAllClose(scipy_special.gammainccinv(a, p), igammacinv, rtol=rtol)
900896

901897
@parameterized.parameters((np.float32, 1e-4), (np.float64, 1e-6))
902898
def test_igammainv_inverse_medium_a(self, dtype, rtol):
903899
seed_stream = test_util.test_seed_stream()
904-
a = tf.random.uniform([int(1e4)], 1., 100., dtype=dtype, seed=seed_stream())
905-
p = tf.random.uniform([int(1e4)], 0., 1., dtype=dtype, seed=seed_stream())
900+
a = tf.random.uniform([int(1e3)], 1., 100., dtype=dtype, seed=seed_stream())
901+
p = tf.random.uniform([int(1e3)], 0., 1., dtype=dtype, seed=seed_stream())
906902
igammainv, a, p = self.evaluate([special.igammainv(a, p), a, p])
907903
self.assertAllClose(scipy_special.gammaincinv(a, p), igammainv, rtol=rtol)
908904

909905
@parameterized.parameters((np.float32, 1e-4), (np.float64, 1e-6))
910906
def test_igammacinv_inverse_medium_a(self, dtype, rtol):
911907
seed_stream = test_util.test_seed_stream()
912-
a = tf.random.uniform([int(1e4)], 1., 100., dtype=dtype, seed=seed_stream())
913-
p = tf.random.uniform([int(1e4)], 0., 1., dtype=dtype, seed=seed_stream())
908+
a = tf.random.uniform([int(1e3)], 1., 100., dtype=dtype, seed=seed_stream())
909+
p = tf.random.uniform([int(1e3)], 0., 1., dtype=dtype, seed=seed_stream())
914910
igammacinv, a, p = self.evaluate([special.igammacinv(a, p), a, p])
915911
self.assertAllClose(scipy_special.gammainccinv(a, p), igammacinv, rtol=rtol)
916912

917913
@parameterized.parameters((np.float32, 3e-4), (np.float64, 1e-6))
918914
def test_igammainv_inverse_large_a(self, dtype, rtol):
919915
seed_stream = test_util.test_seed_stream()
920916
a = tf.random.uniform(
921-
[int(1e4)], 100., 10000., dtype=dtype, seed=seed_stream())
922-
p = tf.random.uniform([int(1e4)], 0., 1., dtype=dtype, seed=seed_stream())
917+
[int(1e3)], 100., 10000., dtype=dtype, seed=seed_stream())
918+
p = tf.random.uniform([int(1e3)], 0., 1., dtype=dtype, seed=seed_stream())
923919
igammainv, a, p = self.evaluate([special.igammainv(a, p), a, p])
924920
self.assertAllClose(scipy_special.gammaincinv(a, p), igammainv, rtol=rtol)
925921

926922
@parameterized.parameters((np.float32, 3e-4), (np.float64, 1e-6))
927923
def test_igammacinv_inverse_large_a(self, dtype, rtol):
928924
seed_stream = test_util.test_seed_stream()
929925
a = tf.random.uniform(
930-
[int(1e4)], 100., 10000., dtype=dtype, seed=seed_stream())
931-
p = tf.random.uniform([int(1e4)], 0., 1., dtype=dtype, seed=seed_stream())
926+
[int(1e3)], 100., 10000., dtype=dtype, seed=seed_stream())
927+
p = tf.random.uniform([int(1e3)], 0., 1., dtype=dtype, seed=seed_stream())
932928
igammacinv, a, p = self.evaluate([special.igammacinv(a, p), a, p])
933929
self.assertAllClose(scipy_special.gammainccinv(a, p), igammacinv, rtol=rtol)
934930

@@ -1031,13 +1027,13 @@ def testOwensTOddEven(self, dtype):
10311027
def testOwensTSmall(self, dtype):
10321028
seed_stream = test_util.test_seed_stream()
10331029
a = tf.random.uniform(
1034-
shape=[int(1e4)],
1030+
shape=[int(1e3)],
10351031
minval=0.,
10361032
maxval=1.,
10371033
dtype=dtype,
10381034
seed=seed_stream())
10391035
h = tf.random.uniform(
1040-
shape=[int(1e4)],
1036+
shape=[int(1e3)],
10411037
minval=0.,
10421038
maxval=1.,
10431039
dtype=dtype,
@@ -1049,13 +1045,13 @@ def testOwensTSmall(self, dtype):
10491045
def testOwensTLarger(self, dtype):
10501046
seed_stream = test_util.test_seed_stream()
10511047
a = tf.random.uniform(
1052-
shape=[int(1e4)],
1048+
shape=[int(1e3)],
10531049
minval=1.,
10541050
maxval=100.,
10551051
dtype=dtype,
10561052
seed=seed_stream())
10571053
h = tf.random.uniform(
1058-
shape=[int(1e4)],
1054+
shape=[int(1e3)],
10591055
minval=1.,
10601056
maxval=100.,
10611057
dtype=dtype,
@@ -1067,13 +1063,13 @@ def testOwensTLarger(self, dtype):
10671063
def testOwensTLarge(self, dtype):
10681064
seed_stream = test_util.test_seed_stream()
10691065
a = tf.random.uniform(
1070-
shape=[int(1e4)],
1066+
shape=[int(1e3)],
10711067
minval=100.,
10721068
maxval=1000.,
10731069
dtype=dtype,
10741070
seed=seed_stream())
10751071
h = tf.random.uniform(
1076-
shape=[int(1e4)],
1072+
shape=[int(1e3)],
10771073
minval=100.,
10781074
maxval=1000.,
10791075
dtype=dtype,
@@ -1119,13 +1115,13 @@ class SpecialTest(test_util.TestCase):
11191115
def testAtanDifferenceSmall(self, dtype):
11201116
seed_stream = test_util.test_seed_stream()
11211117
x = tf.random.uniform(
1122-
shape=[int(1e5)],
1118+
shape=[int(1e3)],
11231119
minval=-10.,
11241120
maxval=10.,
11251121
dtype=dtype,
11261122
seed=seed_stream())
11271123
y = tf.random.uniform(
1128-
shape=[int(1e5)],
1124+
shape=[int(1e3)],
11291125
minval=-10.,
11301126
maxval=10.,
11311127
dtype=dtype,
@@ -1138,13 +1134,13 @@ def testAtanDifferenceSmall(self, dtype):
11381134
def testAtanDifferenceLarge(self, dtype):
11391135
seed_stream = test_util.test_seed_stream()
11401136
x = tf.random.uniform(
1141-
shape=[int(1e5)],
1137+
shape=[int(1e3)],
11421138
minval=-100.,
11431139
maxval=100.,
11441140
dtype=dtype,
11451141
seed=seed_stream())
11461142
y = tf.random.uniform(
1147-
shape=[int(1e5)],
1143+
shape=[int(1e3)],
11481144
minval=-100.,
11491145
maxval=100.,
11501146
dtype=dtype,
@@ -1166,7 +1162,7 @@ def testAtanDifferenceCloseInputs(self, dtype):
11661162
def testAtanDifferenceProductIsNegativeOne(self, dtype):
11671163
seed_stream = test_util.test_seed_stream()
11681164
x = tf.random.uniform(
1169-
shape=[int(1e5)],
1165+
shape=[int(1e3)],
11701166
minval=-10.,
11711167
maxval=10.,
11721168
dtype=dtype,
@@ -1180,7 +1176,7 @@ def testAtanDifferenceProductIsNegativeOne(self, dtype):
11801176
def testErfcinvPreservesDtype(self, dtype):
11811177
x = self.evaluate(
11821178
tf.random.uniform(
1183-
shape=[int(1e5)],
1179+
shape=[int(1e3)],
11841180
minval=0.,
11851181
maxval=1.,
11861182
dtype=dtype,
@@ -1190,7 +1186,7 @@ def testErfcinvPreservesDtype(self, dtype):
11901186
def testErfcinv(self):
11911187
x = self.evaluate(
11921188
tf.random.uniform(
1193-
shape=[int(1e5)],
1189+
shape=[int(1e3)],
11941190
minval=0.,
11951191
maxval=1.,
11961192
seed=test_util.test_seed()))
@@ -1206,7 +1202,7 @@ def testErfcinv(self):
12061202
@parameterized.parameters(tf.float32, tf.float64)
12071203
def testErfcxSmall(self, dtype):
12081204
x = tf.random.uniform(
1209-
shape=[int(1e5)],
1205+
shape=[int(1e3)],
12101206
minval=0.,
12111207
maxval=1.,
12121208
dtype=dtype,
@@ -1218,7 +1214,7 @@ def testErfcxSmall(self, dtype):
12181214
@parameterized.parameters(tf.float32, tf.float64)
12191215
def testErfcxMedium(self, dtype):
12201216
x = tf.random.uniform(
1221-
shape=[int(1e5)],
1217+
shape=[int(1e3)],
12221218
minval=1.,
12231219
maxval=20.,
12241220
dtype=dtype,
@@ -1230,7 +1226,7 @@ def testErfcxMedium(self, dtype):
12301226
@parameterized.parameters(tf.float32, tf.float64)
12311227
def testErfcxLarge(self, dtype):
12321228
x = tf.random.uniform(
1233-
shape=[int(1e5)],
1229+
shape=[int(1e3)],
12341230
minval=20.,
12351231
maxval=100.,
12361232
dtype=dtype,
@@ -1242,7 +1238,7 @@ def testErfcxLarge(self, dtype):
12421238
@parameterized.parameters(tf.float32, tf.float64)
12431239
def testErfcxSmallNegative(self, dtype):
12441240
x = tf.random.uniform(
1245-
shape=[int(1e5)],
1241+
shape=[int(1e3)],
12461242
minval=-1.,
12471243
maxval=0.,
12481244
dtype=dtype,
@@ -1254,7 +1250,7 @@ def testErfcxSmallNegative(self, dtype):
12541250
@parameterized.parameters(tf.float32, tf.float64)
12551251
def testErfcxMediumNegative(self, dtype):
12561252
x = tf.random.uniform(
1257-
shape=[int(1e5)],
1253+
shape=[int(1e3)],
12581254
minval=-20.,
12591255
maxval=-1.,
12601256
dtype=dtype,
@@ -1266,7 +1262,7 @@ def testErfcxMediumNegative(self, dtype):
12661262
@parameterized.parameters(tf.float32, tf.float64)
12671263
def testErfcxLargeNegative(self, dtype):
12681264
x = tf.random.uniform(
1269-
shape=[int(1e5)],
1265+
shape=[int(1e3)],
12701266
minval=-100.,
12711267
maxval=-20.,
12721268
dtype=dtype,
@@ -1291,7 +1287,7 @@ def testErfcxSecondDerivative(self):
12911287
@parameterized.parameters(tf.float32, tf.float64)
12921288
def testLogErfc(self, dtype):
12931289
x = tf.random.uniform(
1294-
shape=[int(1e5)],
1290+
shape=[int(1e3)],
12951291
minval=-3.,
12961292
maxval=3.,
12971293
dtype=dtype,
@@ -1317,7 +1313,7 @@ def testLogErfcValueAndGradientNoNaN(self, dtype):
13171313
@parameterized.parameters(tf.float32, tf.float64)
13181314
def testLogErfcx(self, dtype):
13191315
x = tf.random.uniform(
1320-
shape=[int(1e5)],
1316+
shape=[int(1e3)],
13211317
minval=-3.,
13221318
maxval=3.,
13231319
dtype=dtype,
@@ -1400,7 +1396,7 @@ def testLambertWGradient(self, value, expected):
14001396

14011397
def testLogGammaCorrection(self):
14021398
x = half_cauchy.HalfCauchy(
1403-
loc=8., scale=10.).sample(10000, test_util.test_seed())
1399+
loc=8., scale=10.).sample(int(1e3), test_util.test_seed())
14041400
pi = 3.14159265
14051401
stirling = x * tf.math.log(x) - x + 0.5 * tf.math.log(2 * pi / x)
14061402
tfp_gamma_ = stirling + special.log_gamma_correction(x)
@@ -1409,12 +1405,13 @@ def testLogGammaCorrection(self):
14091405

14101406
def testLogGammaDifference(self):
14111407
y = half_cauchy.HalfCauchy(
1412-
loc=8., scale=10.).sample(10000, test_util.test_seed())
1408+
loc=8., scale=10.).sample(int(1e3), test_util.test_seed())
14131409
y_64 = tf.cast(y, tf.float64)
14141410
# Not testing x near zero because the naive method is too inaccurate.
14151411
# We will get implicit coverage in testLogBeta, where a good reference
14161412
# implementation is available (scipy_special.betaln).
1417-
x = uniform.Uniform(low=4., high=12.).sample(10000, test_util.test_seed())
1413+
x = uniform.Uniform(
1414+
low=4., high=12.).sample(int(1e3), test_util.test_seed())
14181415
x_64 = tf.cast(x, tf.float64)
14191416
naive_64_ = tf.math.lgamma(y_64) - tf.math.lgamma(x_64 + y_64)
14201417
naive_64, sophisticated, sophisticated_64 = self.evaluate([
@@ -1437,8 +1434,8 @@ def simple_difference(x, y):
14371434
return tf.math.lgamma(y) - tf.math.lgamma(x + y)
14381435

14391436
y = half_cauchy.HalfCauchy(
1440-
loc=8., scale=10.).sample(10000, test_util.test_seed())
1441-
x = uniform.Uniform(low=0., high=8.).sample(10000, test_util.test_seed())
1437+
loc=8., scale=10.).sample(int(1e3), test_util.test_seed())
1438+
x = uniform.Uniform(low=0., high=8.).sample(int(1e3), test_util.test_seed())
14421439
_, [simple_gx_,
14431440
simple_gy_] = gradient.value_and_gradient(simple_difference, [x, y])
14441441
_, [gx_, gy_] = gradient.value_and_gradient(special.log_gamma_difference,
@@ -1465,9 +1462,9 @@ def simple_difference(x, y):
14651462

14661463
def testLogBeta(self):
14671464
strm = test_util.test_seed_stream()
1468-
x = half_cauchy.HalfCauchy(loc=1., scale=15.).sample(10000, strm())
1465+
x = half_cauchy.HalfCauchy(loc=1., scale=15.).sample(int(1e3), strm())
14691466
x = self.evaluate(x)
1470-
y = half_cauchy.HalfCauchy(loc=1., scale=15.).sample(10000, strm())
1467+
y = half_cauchy.HalfCauchy(loc=1., scale=15.).sample(int(1e3), strm())
14711468
y = self.evaluate(y)
14721469
# Why not 1e-8?
14731470
# - Could be because scipy does the reduction loops recommended
@@ -1484,8 +1481,8 @@ def testLogBetaGradient(self):
14841481
def simple_lbeta(x, y):
14851482
return tf.math.lgamma(x) + tf.math.lgamma(y) - tf.math.lgamma(x + y)
14861483
strm = test_util.test_seed_stream()
1487-
x = half_cauchy.HalfCauchy(loc=1., scale=15.).sample(10000, strm())
1488-
y = half_cauchy.HalfCauchy(loc=1., scale=15.).sample(10000, strm())
1484+
x = half_cauchy.HalfCauchy(loc=1., scale=15.).sample(int(1e3), strm())
1485+
y = half_cauchy.HalfCauchy(loc=1., scale=15.).sample(int(1e3), strm())
14891486
_, [simple_gx_,
14901487
simple_gy_] = gradient.value_and_gradient(simple_lbeta, [x, y])
14911488
_, [gx_, gy_] = gradient.value_and_gradient(special.lbeta, [x, y])

0 commit comments

Comments
 (0)