17
17
import itertools
18
18
19
19
# Dependency imports
20
+ from absl .testing import parameterized
20
21
import numpy as np
21
22
22
23
import tensorflow .compat .v2 as tf
28
29
29
30
30
31
@test_util .test_all_tf_execution_regimes
31
- class _TwoPieceNormalTest (object ):
32
+ class _TwoPieceNormalTest (parameterized . TestCase ):
32
33
33
34
def make_two_piece_normal (self ):
34
35
if self .dtype is np .float32 :
@@ -280,7 +281,8 @@ def testMode(self):
280
281
self .assertAllClose (mode , expected_mode )
281
282
282
283
@test_util .numpy_disable_gradient_test
283
- def testFiniteGradientAtDifficultPoints (self ):
284
+ @parameterized .parameters (0.75 , 1. , 1.33 )
285
+ def testFiniteGradientAtDifficultPoints (self , skewness ):
284
286
def make_fn (attr ):
285
287
x = np .array ([- 100 , - 20 , - 5. , 0. , 5. , 20 , 100 ]).astype (self .dtype )
286
288
return lambda m , s , g : getattr ( # pylint: disable=g-long-lambda
@@ -300,19 +302,19 @@ def make_fn(attr):
300
302
# * Implementing the cdf method using the Gamma distribution function; and
301
303
# * Implementing the cdf method using the Student's t distribution function
302
304
# when value < loc.
303
- for skewness in [0.75 , 1. , 1.33 ]:
304
- for attr in ('prob' , 'cdf' , 'survival_function' , 'log_prob' ):
305
- value , grads = self .evaluate (
306
- gradient .value_and_gradient (
307
- make_fn (attr ),
308
- [loc , scale , tf .constant (skewness , self .dtype )]))
309
- self .assertAllFinite (value )
310
- self .assertAllFinite (grads [0 ]) # d/d loc
311
- self .assertAllFinite (grads [1 ]) # d/d scale
312
- self .assertAllFinite (grads [2 ]) # d/d skewness
305
+ for attr in ('prob' , 'cdf' , 'survival_function' , 'log_prob' ):
306
+ value , grads = self .evaluate (
307
+ gradient .value_and_gradient (
308
+ make_fn (attr ),
309
+ [loc , scale , tf .constant (skewness , self .dtype )]))
310
+ self .assertAllFinite (value )
311
+ self .assertAllFinite (grads [0 ]) # d/d loc
312
+ self .assertAllFinite (grads [1 ]) # d/d scale
313
+ self .assertAllFinite (grads [2 ]) # d/d skewness
313
314
314
315
@test_util .numpy_disable_gradient_test
315
- def testQuantileFiniteGradientAtDifficultPoints (self ):
316
+ @parameterized .parameters (0.75 , 1. , 1.33 )
317
+ def testQuantileFiniteGradientAtDifficultPoints (self , skewness ):
316
318
def quantile (loc , scale , skewness , probs ):
317
319
dist = two_piece_normal .TwoPieceNormal (
318
320
loc , scale = scale , skewness = skewness , validate_args = True )
@@ -325,18 +327,18 @@ def quantile(loc, scale, skewness, probs):
325
327
[np .exp (x ), np .exp (- 2. ), 1. - np .exp (- 2. ), 1. - np .exp (x )],
326
328
dtype = self .dtype )
327
329
328
- for skewness in [0.75 , 1. , 1.33 ]:
329
- value , grads = gradient .value_and_gradient (
330
- quantile ,
331
- [loc , scale , tf .constant (skewness , self .dtype ), probs ])
332
- self .assertAllFinite (value )
333
- self .assertAllFinite (grads [0 ]) # d/d loc
334
- self .assertAllFinite (grads [1 ]) # d/d scale
335
- self .assertAllFinite (grads [2 ]) # d/d skewness
336
- self .assertAllFinite (grads [3 ]) # d/d probs
330
+ value , grads = gradient .value_and_gradient (
331
+ quantile ,
332
+ [loc , scale , tf .constant (skewness , self .dtype ), probs ])
333
+ self .assertAllFinite (value )
334
+ self .assertAllFinite (grads [0 ]) # d/d loc
335
+ self .assertAllFinite (grads [1 ]) # d/d scale
336
+ self .assertAllFinite (grads [2 ]) # d/d skewness
337
+ self .assertAllFinite (grads [3 ]) # d/d probs
337
338
338
339
@test_util .numpy_disable_gradient_test
339
- def testFullyReparameterized (self ):
340
+ @parameterized .parameters (0.75 , 1. , 1.33 )
341
+ def testFullyReparameterized (self , skewness ):
340
342
n = 100
341
343
def sampler (loc , scale , skewness ):
342
344
dist = two_piece_normal .TwoPieceNormal (
@@ -346,18 +348,18 @@ def sampler(loc, scale, skewness):
346
348
loc = tf .constant (0. , self .dtype )
347
349
scale = tf .constant (1. , self .dtype )
348
350
349
- for skewness in [0.75 , 1. , 1.33 ]:
350
- _ , grads = gradient .value_and_gradient (
351
- sampler , [loc , scale , tf .constant (skewness , self .dtype )])
352
- self .assertIsNotNone (grads [0 ]) # d/d loc
353
- self .assertIsNotNone (grads [1 ]) # d/d scale
354
- self .assertIsNotNone (grads [2 ]) # d/d skewness
351
+ _ , grads = gradient .value_and_gradient (
352
+ sampler , [loc , scale , tf .constant (skewness , self .dtype )])
353
+ self .assertIsNotNone (grads [0 ]) # d/d loc
354
+ self .assertIsNotNone (grads [1 ]) # d/d scale
355
+ self .assertIsNotNone (grads [2 ]) # d/d skewness
355
356
356
357
@test_util .numpy_disable_gradient_test
357
- def testDifferentiableSampleNumerically (self ):
358
+ @parameterized .parameters (0.75 , 1. , 1.33 )
359
+ def testDifferentiableSampleNumerically (self , skewness ):
358
360
"""Test the gradients of the samples w.r.t. skewness."""
359
- sample_shape = [int (2e5 )]
360
- seed = test_util .test_seed ()
361
+ sample_shape = [int (2e3 )]
362
+ seed = test_util .test_seed (sampler_type = 'stateless' )
361
363
362
364
def get_abs_sample_mean (skewness ):
363
365
loc = tf .constant (0. , self .dtype )
@@ -366,15 +368,15 @@ def get_abs_sample_mean(skewness):
366
368
loc , scale = scale , skewness = skewness , validate_args = True )
367
369
return tf .reduce_mean (tf .abs (dist .sample (sample_shape , seed = seed )))
368
370
369
- for skewness in [ 0.75 , 1. , 1.33 ]:
370
- err = self . compute_max_gradient_error (
371
- get_abs_sample_mean , [ tf . constant ( skewness , self .dtype )], delta = 0.1 )
372
- self .assertLess (err , 0.05 )
371
+ err = self . compute_max_gradient_error (
372
+ get_abs_sample_mean , [ tf . constant ( skewness , self . dtype )], delta = 1e-1 )
373
+ maxerr = 0.05 if self .dtype == np . float64 else 0.09
374
+ self .assertLess (err , maxerr )
373
375
374
376
@test_util .numpy_disable_gradient_test
375
377
def testDifferentiableSampleAnalytically (self ):
376
378
"""Test the gradients of the samples w.r.t. loc and scale."""
377
- n = 100
379
+ n = 10
378
380
sample_shape = [n , n ]
379
381
n_samples = np .prod (sample_shape )
380
382
@@ -453,7 +455,7 @@ def testIncompatibleArgShapesGraph(self):
453
455
tf .ones ([2 , 3 ], dtype = self .dtype ), shape = tf .TensorShape (None ))
454
456
self .evaluate (skewness .initializer )
455
457
456
- with self .assertRaisesRegexp (Exception , r'compatible shapes' ):
458
+ with self .assertRaisesRegex (Exception , r'compatible shapes' ):
457
459
dist = two_piece_normal .TwoPieceNormal (
458
460
loc = tf .zeros ([4 , 1 ], dtype = self .dtype ),
459
461
scale = tf .ones ([4 , 1 ], dtype = self .dtype ),
@@ -505,6 +507,8 @@ class TwoPieceNormalTestDynamicShapeFloat64(test_util.TestCase,
505
507
dtype = np .float64
506
508
use_static_shape = False
507
509
510
+ del _TwoPieceNormalTest
511
+
508
512
509
513
if __name__ == '__main__' :
510
514
test_util .main ()
0 commit comments