
Commit 9f3f390

pravnar authored and tensorflower-gardener committed
Internal change
PiperOrigin-RevId: 451046918
1 parent 5a059fc commit 9f3f390
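The commit message is minimal, but the diff shows the substance of the change: `TransformedDistribution` gains an `experimental_local_measure` method that routes the injective-bijector path of `_log_prob` through base-measure-aware machinery (pairing the base distribution's `experimental_local_measure` with the bijector's `experimental_compute_density_correction`), the Moyal CDF bijector pins its NumPy constants to the tensor dtype, the new methods are excluded from JIT wrapping in `jit_public_methods.py`, and two test files pick up formatting and tolerance adjustments.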

5 files changed: +72 −33 lines changed

tensorflow_probability/python/bijectors/moyal_cdf.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -109,14 +109,15 @@ def _forward(self, x):
 
   def _inverse(self, y):
     with tf.control_dependencies(self._maybe_assert_valid_y(y)):
+      np_dtype = dtype_util.as_numpy_dtype(y.dtype)
       return (self.loc - self.scale *
-              (np.log(2.) + 2. * tf.math.log(tfp_math.erfcinv(y))))
+              (np.log(np_dtype(2.)) + 2. * tf.math.log(tfp_math.erfcinv(y))))
 
   def _inverse_log_det_jacobian(self, y):
     with tf.control_dependencies(self._maybe_assert_valid_y(y)):
-      return (tf.math.square(tfp_math.erfcinv(y)) +
-              tf.math.log(self.scale) + 0.5 * np.log(np.pi) -
-              tf.math.log(tfp_math.erfcinv(y)))
+      np_dtype = dtype_util.as_numpy_dtype(y.dtype)
+      return (tf.math.square(tfp_math.erfcinv(y)) + tf.math.log(self.scale) +
+              0.5 * np_dtype(np.log(np.pi)) - tf.math.log(tfp_math.erfcinv(y)))
 
   def _forward_log_det_jacobian(self, x):
     scale = tf.convert_to_tensor(self.scale)
```
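Why this matters: `np.log(2.)` and `np.log(np.pi)` are NumPy float64 scalars, and mixing them into a float32 TF graph raises a dtype mismatch rather than silently promoting. A minimal sketch of the failure mode and the fix pattern used above (a standalone illustration, not code from this commit):

```python
import numpy as np
import tensorflow as tf
from tensorflow_probability.python.internal import dtype_util

y = tf.constant([0.5], dtype=tf.float32)

# np.log(2.) is a NumPy float64 scalar; float32 + float64 is an error in TF.
try:
  bad = tf.math.log(y) + np.log(2.)
except (TypeError, tf.errors.InvalidArgumentError) as e:
  print('dtype mismatch:', e)

# The fix pattern in the diff: pin the constant to the tensor's dtype first.
np_dtype = dtype_util.as_numpy_dtype(y.dtype)  # numpy dtype matching y
good = tf.math.log(y) + np.log(np_dtype(2.))   # stays float32
```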

tensorflow_probability/python/distributions/transformed_distribution.py

Lines changed: 33 additions & 4 deletions
```diff
@@ -354,6 +354,13 @@ def _sample_and_log_prob(self, sample_shape, seed, **kwargs):
             tf.cast(fldj, base_distribution_log_prob.dtype))
 
   def _log_prob(self, y, **kwargs):
+    if self.bijector._is_injective:  # pylint: disable=protected-access
+      log_prob, _ = self.experimental_local_measure(
+          y, backward_compat=True, **kwargs)
+      return log_prob
+
+    # TODO(b/197680518): Support base measure handling for non-injective
+    # bijectors.
     distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)
 
     # For caching to work, it is imperative that the bijector is the first to
@@ -366,9 +373,6 @@ def _log_prob(self, y, **kwargs):
 
     ildj = self.bijector.inverse_log_det_jacobian(
         y, event_ndims=event_ndims, **bijector_kwargs)
-    if self.bijector._is_injective:  # pylint: disable=protected-access
-      base_log_prob = self.distribution.log_prob(x, **distribution_kwargs)
-      return base_log_prob + tf.cast(ildj, base_log_prob.dtype)
 
     # Compute log_prob on each element of the inverse image.
     lp_on_fibers = []
@@ -596,6 +600,32 @@ def _default_event_space_bijector(self):
         self.distribution.experimental_default_event_space_bijector())
   # pylint: enable=not-callable
 
+  def experimental_local_measure(self, y, backward_compat=False, **kwargs):
+    distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)
+
+    # For caching to work, it is imperative that the bijector is the first to
+    # modify the input.
+    x = self.bijector.inverse(y, **bijector_kwargs)
+    event_ndims = self.bijector.inverse_event_ndims(
+        tf.nest.map_structure(ps.rank_from_shape, self._event_shape_tensor(),
+                              self.event_shape), **bijector_kwargs)
+
+    if self.bijector._is_injective:  # pylint: disable=protected-access
+      local_measure_fn = self.distribution.experimental_local_measure
+      density_corr_fn = self.bijector.experimental_compute_density_correction
+      base_log_prob, tangent_space = local_measure_fn(
+          x, backward_compat=backward_compat, **distribution_kwargs)
+      correction, new_tangent_space = density_corr_fn(
+          x,
+          tangent_space,
+          backward_compat=backward_compat,
+          event_ndims=event_ndims,
+          **bijector_kwargs)
+      log_prob = base_log_prob - tf.cast(correction, base_log_prob.dtype)
+      return log_prob, new_tangent_space
+    else:
+      raise NotImplementedError
+
 
 class TransformedDistribution(
     _TransformedDistribution, distribution_lib.AutoCompositeTensorDistribution):
@@ -671,4 +701,3 @@ def _transformed_log_prob_ratio(p, x, q, y, name=None):
   ildj_ratio = ldj_ratio.inverse_log_det_jacobian_ratio(
       p.bijector, x, q.bijector, y, event_ndims)
   return base_log_prob_ratio + tf.cast(ildj_ratio, base_log_prob_ratio.dtype)
-
```
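The net effect: for injective bijectors, `log_prob` is now computed via `experimental_local_measure` with `backward_compat=True`, which reproduces the previous change-of-variables result while also threading tangent-space information for density corrections. A hedged usage sketch (the concrete base distribution and bijector here are illustrative, not from this commit):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# A TransformedDistribution with an injective bijector: LogNormal via Exp.
dist = tfb.Exp()(tfd.Normal(loc=0., scale=1.))
y = tf.constant([0.5, 1.5])

# With backward_compat=True the base measure falls back to Lebesgue, so the
# returned log density should agree with ordinary log_prob, which after this
# commit delegates to this method on the injective path.
log_prob, tangent_space = dist.experimental_local_measure(
    y, backward_compat=True)
tf.debugging.assert_near(log_prob, dist.log_prob(y))
```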

tensorflow_probability/python/experimental/bijectors/distribution_bijectors_test.py

Lines changed: 29 additions & 21 deletions
```diff
@@ -42,12 +42,12 @@
     'LambertWNormal',  # CDF gradient incorrect at 0.
     'SigmoidBeta',  # inverse CDF numerical precision issues for large x
     'StudentT',  # CDF gradient incorrect at 0 (and unstable near zero).
-    )
+)
 
 if JAX_MODE:
   PRECONDITIONING_FAILS_DISTS = (
       'VonMises',  # Abstract eval for 'von_mises_cdf_jvp' not implemented.
-      ) + PRECONDITIONING_FAILS_DISTS
+  ) + PRECONDITIONING_FAILS_DISTS
 
 
 def _constrained_zeros_fn(shape, dtype, constraint_fn):
@@ -60,15 +60,18 @@ class DistributionBijectorsTest(test_util.TestCase):
 
   def assertDistributionIsApproximatelyStandardNormal(self,
                                                       dist,
+                                                      rtol=1e-6,
                                                       logprob_atol=1e-2,
                                                       grad_atol=1e-2):
     """Verifies that dist's lps and gradients match those of Normal(0., 1.)."""
     batch_shape = dist.batch_shape_tensor()
+
     def make_reference_values(event_shape):
       dist_shape = ps.concat([batch_shape, event_shape], axis=0)
       x = tf.reshape([-4., -2., 0., 2., 4.],
                      ps.concat([[5], ps.ones_like(dist_shape)], axis=0))
       return tf.broadcast_to(x, ps.concat([[5], dist_shape], axis=0))
+
     flat_event_shape = tf.nest.flatten(dist.event_shape_tensor())
     zs = [make_reference_values(s) for s in flat_event_shape]
     lp_dist, grad_dist = tfp.math.value_and_gradient(
@@ -83,11 +86,14 @@ def reference_value_and_gradient(z, event_shape):
     reference_vals_and_grads = [
         reference_value_and_gradient(z, event_shape)
         for (z, event_shape) in zip(zs, flat_event_shape)]
+
     lps_reference = [lp for lp, grad in reference_vals_and_grads]
-    self.assertAllClose(sum(lps_reference), lp_dist, atol=logprob_atol)
+    self.assertAllClose(
+        sum(lps_reference), lp_dist, rtol=rtol, atol=logprob_atol)
 
     grads_reference = [grad for lp, grad in reference_vals_and_grads]
-    self.assertAllCloseNested(grads_reference, grad_dist, atol=grad_atol)
+    self.assertAllCloseNested(
+        grads_reference, grad_dist, rtol=rtol, atol=grad_atol)
 
   @parameterized.named_parameters(
       {'testcase_name': dname, 'dist_name': dname}
@@ -101,10 +107,11 @@ def test_all_distributions_either_work_or_raise_error(self, dist_name, data):
     if dist_name in PRECONDITIONING_FAILS_DISTS:
       self.skipTest('Known failure.')
 
-    dist = data.draw(dhps.base_distributions(
-        dist_name=dist_name,
-        enable_vars=False,
-        param_strategy_fn=_constrained_zeros_fn))
+    dist = data.draw(
+        dhps.base_distributions(
+            dist_name=dist_name,
+            enable_vars=False,
+            param_strategy_fn=_constrained_zeros_fn))
     try:
       b = tfp.experimental.bijectors.make_distribution_bijector(dist)
     except NotImplementedError:
@@ -114,22 +121,20 @@ def test_all_distributions_either_work_or_raise_error(self, dist_name, data):
 
   @test_util.numpy_disable_gradient_test
   def test_multivariate_normal(self):
-    d = tfd.MultivariateNormalFullCovariance(loc=[4., 8.],
-                                             covariance_matrix=[[11., 0.099],
-                                                                [0.099, 0.1]])
+    d = tfd.MultivariateNormalFullCovariance(
+        loc=[4., 8.], covariance_matrix=[[11., 0.099], [0.099, 0.1]])
     b = tfp.experimental.bijectors.make_distribution_bijector(d)
-    self.assertDistributionIsApproximatelyStandardNormal(
-        tfb.Invert(b)(d))
+    self.assertDistributionIsApproximatelyStandardNormal(tfb.Invert(b)(d))
 
   @test_util.numpy_disable_gradient_test
   def test_markov_chain(self):
     d = tfd.MarkovChain(
         initial_state_prior=tfd.Uniform(low=0., high=1.),
         transition_fn=lambda _, x: tfd.Uniform(low=0., high=tf.nn.softplus(x)),
-        num_steps=10)
+        num_steps=3)
     b = tfp.experimental.bijectors.make_distribution_bijector(d)
     self.assertDistributionIsApproximatelyStandardNormal(
-        tfb.Invert(b)(d))
+        tfb.Invert(b)(d), rtol=1e-4)
 
   @test_util.numpy_disable_gradient_test
   def test_markov_chain_joint(self):
@@ -145,21 +150,22 @@ def test_markov_chain_joint(self):
         num_steps=10)
     b = tfp.experimental.bijectors.make_distribution_bijector(d)
     self.assertDistributionIsApproximatelyStandardNormal(
-        tfb.Invert(b)(d))
+        tfb.Invert(b)(d), rtol=1e-4)
 
   @test_util.numpy_disable_gradient_test
   def test_nested_joint_distribution(self):
 
     def model():
       x = yield tfd.Normal(loc=-2., scale=1.)
       yield tfd.JointDistributionSequentialAutoBatched([
-          tfd.Uniform(low=1. + tf.exp(x),
-                      high=1 + tf.exp(x) + tf.nn.softplus(x)),
+          tfd.Uniform(low=1. - tf.exp(x),
+                      high=2. + tf.exp(x) + tf.nn.softplus(x)),
           lambda v: tfd.Exponential(v)])  # pylint: disable=unnecessary-lambda
+
     dist = tfd.JointDistributionCoroutineAutoBatched(model)
     b = tfp.experimental.bijectors.make_distribution_bijector(dist)
     self.assertDistributionIsApproximatelyStandardNormal(
-        tfb.Invert(b)(dist))
+        tfb.Invert(b)(dist), rtol=1e-4)
 
   @test_util.numpy_disable_gradient_test
   @test_util.jax_disable_test_missing_functionality(
@@ -171,6 +177,7 @@ def model_with_funnel():
       z = yield tfd.Normal(loc=-1., scale=2., name='z')
       x = yield tfd.Normal(loc=[0.], scale=tf.exp(z), name='x')
       yield tfd.Poisson(log_rate=x, name='y')
+
     pinned_model = model_with_funnel.experimental_pin(y=[1])
     surrogate_posterior = tfp.experimental.vi.build_asvi_surrogate_posterior(
         pinned_model)
@@ -191,15 +198,16 @@ def do_sample():
         kernel=tfp.mcmc.DualAveragingStepSizeAdaptation(
             tfp.mcmc.TransformedTransitionKernel(
                 tfp.mcmc.NoUTurnSampler(
-                    pinned_model.unnormalized_log_prob,
-                    step_size=0.1),
+                    pinned_model.unnormalized_log_prob, step_size=0.1),
                 bijector=bijector),
             num_adaptation_steps=5),
         current_state=surrogate_posterior.sample(),
         num_burnin_steps=5,
         trace_fn=lambda _0, _1: [],
         num_results=10)
+
     do_sample()
 
+
 if __name__ == '__main__':
   test_util.main()
```
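For context, these tests check that `tfp.experimental.bijectors.make_distribution_bijector` produces a bijector mapping a standard normal to the target distribution, so pulling the target back through the inverted bijector should be approximately standard normal. A minimal sketch of that pattern, with an illustrative scalar distribution not taken from the tests:

```python
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

d = tfd.Gamma(concentration=2., rate=1.)

# b approximately maps Normal(0., 1.) samples to samples of d, so the
# pullback tfb.Invert(b)(d) should be near-standard-normal, up to the
# rtol/atol tolerances asserted in the tests above.
b = tfp.experimental.bijectors.make_distribution_bijector(d)
standardized = tfb.Invert(b)(d)

z = [-2., 0., 2.]
print(standardized.log_prob(z))
print(tfd.Normal(0., 1.).log_prob(z))  # should be close
```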

tensorflow_probability/python/experimental/util/jit_public_methods.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -36,6 +36,7 @@
     'dtype',
     'kl_divergence',  # Wrapping applied explicitly in `_traced_kl_divergence`.
     'experimental_default_event_space_bijector',
+    'experimental_local_measure',
     # tfb.Bijector
     # TODO(davmre): Test wrapping bijectors.
     'forward_event_shape',
@@ -45,7 +46,8 @@
     'forward_dtype',
     'inverse_dtype',
     'forward_event_ndims',
-    'inverse_event_ndims'
+    'inverse_event_ndims',
+    'experimental_compute_density_correction',
 )
 
 if NUMPY_MODE:
```
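Both additions to the exclusion list correspond to the methods introduced in `transformed_distribution.py` above; since each returns a tangent-space object alongside tensors, wrapping them in `tf.function` would presumably be problematic, so they are skipped. A brief usage sketch, assuming the public `JitPublicMethods` entry point in this module:

```python
import tensorflow_probability as tfp

dist = tfp.distributions.Normal(0., 1.)

# Wraps public methods such as log_prob and sample in tf.function; methods
# named in the exclusion list above, including the two new entries, are
# left unwrapped.
jitted = tfp.experimental.util.JitPublicMethods(dist)
jitted.log_prob(0.5)
```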

tensorflow_probability/python/sts/forecast_test.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -181,9 +181,8 @@ def _run():
 
   @test_util.jax_disable_test_missing_functionality('fit_with_hmc')
   def test_forecast_from_hmc(self):
-    if not (tf1.control_flow_v2_enabled() or self.use_static_shape):
-      self.skipTest('test_forecast_from_hmc does not currently work with TF1 '
-                    'and dynamic shapes')
+    if not tf1.control_flow_v2_enabled():
+      self.skipTest('test_forecast_from_hmc does not currently work with TF1')
 
     # test that we can directly plug in the output of an HMC chain as
     # the input to `forecast`, as done in the example, with no `sess.run` call.
```
