     'LambertWNormal',  # CDF gradient incorrect at 0.
     'SigmoidBeta',  # inverse CDF numerical precision issues for large x
     'StudentT',  # CDF gradient incorrect at 0 (and unstable near zero).
-)
+)

 if JAX_MODE:
   PRECONDITIONING_FAILS_DISTS = (
       'VonMises',  # Abstract eval for 'von_mises_cdf_jvp' not implemented.
-  ) + PRECONDITIONING_FAILS_DISTS
+  ) + PRECONDITIONING_FAILS_DISTS


 def _constrained_zeros_fn(shape, dtype, constraint_fn):
@@ -60,15 +60,18 @@ class DistributionBijectorsTest(test_util.TestCase):

   def assertDistributionIsApproximatelyStandardNormal(self,
                                                       dist,
+                                                      rtol=1e-6,
                                                       logprob_atol=1e-2,
                                                       grad_atol=1e-2):
     """Verifies that dist's lps and gradients match those of Normal(0., 1.)."""
     batch_shape = dist.batch_shape_tensor()
+
     def make_reference_values(event_shape):
       dist_shape = ps.concat([batch_shape, event_shape], axis=0)
       x = tf.reshape([-4., -2., 0., 2., 4.],
                      ps.concat([[5], ps.ones_like(dist_shape)], axis=0))
       return tf.broadcast_to(x, ps.concat([[5], dist_shape], axis=0))
+
     flat_event_shape = tf.nest.flatten(dist.event_shape_tensor())
     zs = [make_reference_values(s) for s in flat_event_shape]
     lp_dist, grad_dist = tfp.math.value_and_gradient(
@@ -83,11 +86,14 @@ def reference_value_and_gradient(z, event_shape):
     reference_vals_and_grads = [
         reference_value_and_gradient(z, event_shape)
         for (z, event_shape) in zip(zs, flat_event_shape)]
+
     lps_reference = [lp for lp, grad in reference_vals_and_grads]
-    self.assertAllClose(sum(lps_reference), lp_dist, atol=logprob_atol)
+    self.assertAllClose(
+        sum(lps_reference), lp_dist, rtol=rtol, atol=logprob_atol)

     grads_reference = [grad for lp, grad in reference_vals_and_grads]
-    self.assertAllCloseNested(grads_reference, grad_dist, atol=grad_atol)
+    self.assertAllCloseNested(
+        grads_reference, grad_dist, rtol=rtol, atol=grad_atol)

   @parameterized.named_parameters(
       {'testcase_name': dname, 'dist_name': dname}
@@ -101,10 +107,11 @@ def test_all_distributions_either_work_or_raise_error(self, dist_name, data):
     if dist_name in PRECONDITIONING_FAILS_DISTS:
       self.skipTest('Known failure.')

-    dist = data.draw(dhps.base_distributions(
-        dist_name=dist_name,
-        enable_vars=False,
-        param_strategy_fn=_constrained_zeros_fn))
+    dist = data.draw(
+        dhps.base_distributions(
+            dist_name=dist_name,
+            enable_vars=False,
+            param_strategy_fn=_constrained_zeros_fn))
     try:
       b = tfp.experimental.bijectors.make_distribution_bijector(dist)
     except NotImplementedError:
@@ -114,22 +121,20 @@ def test_all_distributions_either_work_or_raise_error(self, dist_name, data):

   @test_util.numpy_disable_gradient_test
   def test_multivariate_normal(self):
-    d = tfd.MultivariateNormalFullCovariance(loc=[4., 8.],
-                                             covariance_matrix=[[11., 0.099],
-                                                                [0.099, 0.1]])
+    d = tfd.MultivariateNormalFullCovariance(
+        loc=[4., 8.], covariance_matrix=[[11., 0.099], [0.099, 0.1]])
     b = tfp.experimental.bijectors.make_distribution_bijector(d)
-    self.assertDistributionIsApproximatelyStandardNormal(
-        tfb.Invert(b)(d))
+    self.assertDistributionIsApproximatelyStandardNormal(tfb.Invert(b)(d))

   @test_util.numpy_disable_gradient_test
   def test_markov_chain(self):
     d = tfd.MarkovChain(
         initial_state_prior=tfd.Uniform(low=0., high=1.),
         transition_fn=lambda _, x: tfd.Uniform(low=0., high=tf.nn.softplus(x)),
-        num_steps=10)
+        num_steps=3)
     b = tfp.experimental.bijectors.make_distribution_bijector(d)
     self.assertDistributionIsApproximatelyStandardNormal(
-        tfb.Invert(b)(d))
+        tfb.Invert(b)(d), rtol=1e-4)

   @test_util.numpy_disable_gradient_test
   def test_markov_chain_joint(self):
@@ -145,21 +150,22 @@ def test_markov_chain_joint(self):
         num_steps=10)
     b = tfp.experimental.bijectors.make_distribution_bijector(d)
     self.assertDistributionIsApproximatelyStandardNormal(
-        tfb.Invert(b)(d))
+        tfb.Invert(b)(d), rtol=1e-4)

   @test_util.numpy_disable_gradient_test
   def test_nested_joint_distribution(self):

     def model():
       x = yield tfd.Normal(loc=-2., scale=1.)
       yield tfd.JointDistributionSequentialAutoBatched([
-          tfd.Uniform(low=1. + tf.exp(x),
-                      high=1 + tf.exp(x) + tf.nn.softplus(x)),
+          tfd.Uniform(low=1. - tf.exp(x),
+                      high=2. + tf.exp(x) + tf.nn.softplus(x)),
           lambda v: tfd.Exponential(v)])  # pylint: disable=unnecessary-lambda
+
     dist = tfd.JointDistributionCoroutineAutoBatched(model)
     b = tfp.experimental.bijectors.make_distribution_bijector(dist)
     self.assertDistributionIsApproximatelyStandardNormal(
-        tfb.Invert(b)(dist))
+        tfb.Invert(b)(dist), rtol=1e-4)

   @test_util.numpy_disable_gradient_test
   @test_util.jax_disable_test_missing_functionality(
@@ -171,6 +177,7 @@ def model_with_funnel():
       z = yield tfd.Normal(loc=-1., scale=2., name='z')
       x = yield tfd.Normal(loc=[0.], scale=tf.exp(z), name='x')
       yield tfd.Poisson(log_rate=x, name='y')
+
     pinned_model = model_with_funnel.experimental_pin(y=[1])
     surrogate_posterior = tfp.experimental.vi.build_asvi_surrogate_posterior(
         pinned_model)
@@ -191,15 +198,16 @@ def do_sample():
           kernel=tfp.mcmc.DualAveragingStepSizeAdaptation(
               tfp.mcmc.TransformedTransitionKernel(
                   tfp.mcmc.NoUTurnSampler(
-                      pinned_model.unnormalized_log_prob,
-                      step_size=0.1),
+                      pinned_model.unnormalized_log_prob, step_size=0.1),
                   bijector=bijector),
               num_adaptation_steps=5),
           current_state=surrogate_posterior.sample(),
           num_burnin_steps=5,
           trace_fn=lambda _0, _1: [],
           num_results=10)
+
     do_sample()

+
 if __name__ == '__main__':
   test_util.main()
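
For context, here is a minimal standalone sketch of the pattern these tests exercise, assuming the experimental TFP API shown in the diff (the distribution parameters are copied from test_multivariate_normal above): make_distribution_bijector(d) builds a bijector b that maps standard-normal draws to (approximate) draws from d, so the pulled-back distribution tfb.Invert(b)(d) should be approximately standard normal.

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors
tfd = tfp.distributions

# Distribution taken from test_multivariate_normal above.
d = tfd.MultivariateNormalFullCovariance(
    loc=[4., 8.], covariance_matrix=[[11., 0.099], [0.099, 0.1]])

# Experimental API under test: a bijector mapping N(0, 1) draws to
# (approximate) draws from `d`.
b = tfp.experimental.bijectors.make_distribution_bijector(d)

# Pulling `d` back through the inverted bijector should give a distribution
# whose components are approximately iid standard normal.
standardized = tfb.Invert(b)(d)
z = tf.constant([0.5, -1.5])
print(standardized.log_prob(z))                       # should be close...
print(tf.reduce_sum(tfd.Normal(0., 1.).log_prob(z)))  # ...to this reference

The final hunk of the diff uses the same bijector the other way around, as a preconditioner: wrapping a NoUTurnSampler in tfp.mcmc.TransformedTransitionKernel with this bijector lets the sampler explore an approximately standardized space.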