14
14
# ==============================================================================
15
15
"""Tests of deep factorized distribution."""
16
16
17
+ from absl .testing import parameterized
17
18
import tensorflow .compat .v2 as tf
18
19
import tensorflow_probability as tfp
19
20
20
21
from tensorflow_compression .python .distributions import deep_factorized
21
22
from tensorflow_compression .python .distributions import helpers
22
23
23
24
24
- class DeepFactorizedTest (tf .test .TestCase ):
25
+ class DeepFactorizedTest (tf .test .TestCase , parameterized . TestCase ):
25
26
26
27
def test_can_instantiate_scalar (self ):
27
28
df = deep_factorized .DeepFactorized ()
@@ -37,56 +38,31 @@ def test_can_instantiate_batched(self):
37
38
self .assertEqual (df .num_filters , (3 , 3 ))
38
39
self .assertEqual (df .init_scale , 10 )
39
40
40
- def test_logistic_is_special_case_prob (self ):
41
+ @parameterized .parameters (
42
+ "prob" , "log_prob" ,
43
+ "cdf" , "log_cdf" ,
44
+ "survival_function" , "log_survival_function" ,
45
+ )
46
+ def test_logistic_is_special_case (self , method ):
41
47
# With no hidden units, the density should collapse to a logistic
42
48
# distribution.
43
49
df = deep_factorized .DeepFactorized (num_filters = (), init_scale = 1 )
44
50
logistic = tfp .distributions .Logistic (loc = - df ._biases [0 ][0 , 0 ], scale = 1. )
45
51
x = tf .linspace (- 5. , 5. , 20 )
46
- prob_df = df .prob (x )
47
- prob_logistic = logistic .prob (x )
48
- self .assertAllClose (prob_df , prob_logistic )
49
-
50
- def test_logistic_is_special_case_cdf (self ):
51
- # With no hidden units, the density should collapse to a logistic
52
- # distribution.
53
- df = deep_factorized .DeepFactorized (num_filters = (), init_scale = 1 )
54
- logistic = tfp .distributions .Logistic (loc = - df ._biases [0 ][0 , 0 ], scale = 1. )
55
- x = tf .linspace (- 5. , 5. , 20 )
56
- cdf_df = df .cdf (x )
57
- cdf_logistic = logistic .cdf (x )
58
- self .assertAllClose (cdf_df , cdf_logistic )
59
-
60
- def test_logistic_is_special_case_log_prob (self ):
61
- # With no hidden units, the density should collapse to a logistic
62
- # distribution.
63
- df = deep_factorized .DeepFactorized (num_filters = (), init_scale = 1 )
64
- logistic = tfp .distributions .Logistic (loc = - df ._biases [0 ][0 , 0 ], scale = 1. )
65
- x = tf .linspace (- 5000. , 5000. , 1000 )
66
- log_prob_df = df .log_prob (x )
67
- log_prob_logistic = logistic .log_prob (x )
68
- self .assertAllClose (log_prob_df , log_prob_logistic )
69
-
70
- def test_logistic_is_special_case_log_cdf (self ):
71
- # With no hidden units, the density should collapse to a logistic
72
- # distribution.
73
- df = deep_factorized .DeepFactorized (num_filters = (), init_scale = 1 )
74
- logistic = tfp .distributions .Logistic (loc = - df ._biases [0 ][0 , 0 ], scale = 1. )
75
- x = tf .linspace (- 5000. , 5000. , 1000 )
76
- log_cdf_df = df .log_cdf (x )
77
- log_cdf_logistic = logistic .log_cdf (x )
78
- self .assertAllClose (log_cdf_df , log_cdf_logistic )
79
-
80
- def test_logistic_is_special_case_log_survival_function (self ):
81
- # With no hidden units, the density should collapse to a logistic
82
- # distribution.
83
- df = deep_factorized .DeepFactorized (num_filters = (), init_scale = 1 )
84
- logistic = tfp .distributions .Logistic (loc = - df ._biases [0 ][0 , 0 ], scale = 1. )
85
- x = tf .linspace (- 5000. , 5000. , 1000 )
86
- log_survival_function_df = df .log_survival_function (x )
87
- log_survival_function_logistic = logistic .log_survival_function (x )
88
- self .assertAllClose (log_survival_function_df ,
89
- log_survival_function_logistic )
52
+ val_df = getattr (df , method )(x )
53
+ val_logistic = getattr (logistic , method )(x )
54
+ self .assertAllClose (val_df , val_logistic )
55
+
56
+ @parameterized .parameters (
57
+ "prob" , "log_prob" ,
58
+ "cdf" , "log_cdf" ,
59
+ "survival_function" , "log_survival_function" ,
60
+ )
61
+ def test_broadcasts_correctly (self , method ):
62
+ df = deep_factorized .DeepFactorized (batch_shape = (2 , 3 ))
63
+ x = tf .reshape (tf .linspace (- 5. , 5. , 20 ), (4 , 5 , 1 , 1 ))
64
+ val = getattr (df , method )(x )
65
+ self .assertEqual (val .shape , (4 , 5 , 2 , 3 ))
90
66
91
67
92
68
class NoisyDeepFactorizedTest (tf .test .TestCase ):
@@ -140,13 +116,11 @@ def test_quantization_offset_is_zero(self):
140
116
df = deep_factorized .NoisyDeepFactorized ()
141
117
self .assertEqual (helpers .quantization_offset (df ), 0 )
142
118
143
- def test_tails_and_offset_are_in_order (self ):
119
+ def test_tails_are_in_order (self ):
144
120
df = deep_factorized .NoisyDeepFactorized ()
145
- offset = helpers .quantization_offset (df )
146
121
lower_tail = helpers .lower_tail (df , 2 ** - 8 )
147
122
upper_tail = helpers .upper_tail (df , 2 ** - 8 )
148
- self .assertGreater (upper_tail , offset )
149
- self .assertGreater (offset , lower_tail )
123
+ self .assertGreater (upper_tail , lower_tail )
150
124
151
125
def test_stats_throw_error (self ):
152
126
df = deep_factorized .NoisyDeepFactorized ()
0 commit comments