Skip to content

Commit c3586af

Browse files
ColCarrolltensorflower-gardener
authored andcommitted
Use tangent space in TransformedDistribution.prob.
PiperOrigin-RevId: 576156013
1 parent 346cb6f commit c3586af

File tree

3 files changed

+25
-4
lines changed

3 files changed

+25
-4
lines changed

tensorflow_probability/python/distributions/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4540,6 +4540,7 @@ multi_substrate_py_test(
45404540
tags = ["colab-smoke"],
45414541
deps = [
45424542
":beta",
4543+
":dirichlet",
45434544
":exponential",
45444545
":independent",
45454546
":joint_distribution_auto_batched",
@@ -4552,6 +4553,7 @@ multi_substrate_py_test(
45524553
":normal",
45534554
":sample",
45544555
":transformed_distribution",
4556+
":uniform",
45554557
# numpy dep,
45564558
# scipy dep,
45574559
# tensorflow dep,

tensorflow_probability/python/distributions/transformed_distribution.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ def _log_prob(self, y, **kwargs):
388388
return tf.reduce_logsumexp(tf.stack(lp_on_fibers), axis=0)
389389

390390
def _prob(self, y, **kwargs):
391-
if not hasattr(self.distribution, '_prob'):
391+
if not hasattr(self.distribution, '_prob') or self.bijector._is_injective: # pylint: disable=protected-access
392392
return tf.exp(self._log_prob(y, **kwargs))
393393
distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)
394394

@@ -400,9 +400,6 @@ def _prob(self, y, **kwargs):
400400
)
401401
ildj = self.bijector.inverse_log_det_jacobian(
402402
y, event_ndims=event_ndims, **bijector_kwargs)
403-
if self.bijector._is_injective: # pylint: disable=protected-access
404-
base_prob = self.distribution.prob(x, **distribution_kwargs)
405-
return base_prob * tf.exp(tf.cast(ildj, base_prob.dtype))
406403

407404
# Compute prob on each element of the inverse image.
408405
prob_on_fibers = []

tensorflow_probability/python/distributions/transformed_distribution_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from tensorflow_probability.python.bijectors import split
4343
from tensorflow_probability.python.bijectors import tanh
4444
from tensorflow_probability.python.distributions import beta
45+
from tensorflow_probability.python.distributions import dirichlet
4546
from tensorflow_probability.python.distributions import exponential
4647
from tensorflow_probability.python.distributions import independent
4748
from tensorflow_probability.python.distributions import joint_distribution_auto_batched as jdab
@@ -54,6 +55,7 @@
5455
from tensorflow_probability.python.distributions import normal as normal_lib
5556
from tensorflow_probability.python.distributions import sample as sample_lib
5657
from tensorflow_probability.python.distributions import transformed_distribution
58+
from tensorflow_probability.python.distributions import uniform
5759
from tensorflow_probability.python.internal import hypothesis_testlib as tfp_hps
5860
from tensorflow_probability.python.internal import prefer_static as ps
5961
from tensorflow_probability.python.internal import tensorshape_util
@@ -650,6 +652,26 @@ def testLogProbRatio(self):
650652
# oracle_64, d0.log_prob(x0) - d1.log_prob(x1),
651653
# rtol=0., atol=0.007)
652654

655+
@test_util.numpy_disable_test_missing_functionality('b/306384754')
656+
def testLogProbMatchesProbDirichlet(self):
657+
# This was https://github.com/tensorflow/probability/issues/1761
658+
scaled_dir = transformed_distribution.TransformedDistribution(
659+
distribution=dirichlet.Dirichlet([2.0, 3.0]),
660+
bijector=scale_lib.Scale(2.0))
661+
x = np.array([0.2, 1.8], dtype=np.float32)
662+
self.assertAllClose(scaled_dir.prob(x),
663+
tf.exp(scaled_dir.log_prob(x)))
664+
665+
@test_util.numpy_disable_test_missing_functionality('b/306384754')
666+
def testLogProbMatchesProbUniform(self):
667+
# Uniform does not define _log_prob
668+
scaled_uniform = transformed_distribution.TransformedDistribution(
669+
distribution=uniform.Uniform(),
670+
bijector=scale_lib.Scale(2.0))
671+
x = np.array([0.2], dtype=np.float32)
672+
self.assertAllClose(scaled_uniform.prob(x),
673+
tf.exp(scaled_uniform.log_prob(x)))
674+
653675

654676
@test_util.test_all_tf_execution_regimes
655677
class ScalarToMultiTest(test_util.TestCase):

0 commit comments

Comments
 (0)