|
42 | 42 | from tensorflow_probability.python.bijectors import split
|
43 | 43 | from tensorflow_probability.python.bijectors import tanh
|
44 | 44 | from tensorflow_probability.python.distributions import beta
|
| 45 | +from tensorflow_probability.python.distributions import dirichlet |
45 | 46 | from tensorflow_probability.python.distributions import exponential
|
46 | 47 | from tensorflow_probability.python.distributions import independent
|
47 | 48 | from tensorflow_probability.python.distributions import joint_distribution_auto_batched as jdab
|
|
54 | 55 | from tensorflow_probability.python.distributions import normal as normal_lib
|
55 | 56 | from tensorflow_probability.python.distributions import sample as sample_lib
|
56 | 57 | from tensorflow_probability.python.distributions import transformed_distribution
|
| 58 | +from tensorflow_probability.python.distributions import uniform |
57 | 59 | from tensorflow_probability.python.internal import hypothesis_testlib as tfp_hps
|
58 | 60 | from tensorflow_probability.python.internal import prefer_static as ps
|
59 | 61 | from tensorflow_probability.python.internal import tensorshape_util
|
@@ -650,6 +652,26 @@ def testLogProbRatio(self):
|
650 | 652 | # oracle_64, d0.log_prob(x0) - d1.log_prob(x1),
|
651 | 653 | # rtol=0., atol=0.007)
|
652 | 654 |
|
| 655 | + @test_util.numpy_disable_test_missing_functionality('b/306384754') |
| 656 | + def testLogProbMatchesProbDirichlet(self): |
| 657 | + # This was https://github.com/tensorflow/probability/issues/1761 |
| 658 | + scaled_dir = transformed_distribution.TransformedDistribution( |
| 659 | + distribution=dirichlet.Dirichlet([2.0, 3.0]), |
| 660 | + bijector=scale_lib.Scale(2.0)) |
| 661 | + x = np.array([0.2, 1.8], dtype=np.float32) |
| 662 | + self.assertAllClose(scaled_dir.prob(x), |
| 663 | + tf.exp(scaled_dir.log_prob(x))) |
| 664 | + |
| 665 | + @test_util.numpy_disable_test_missing_functionality('b/306384754') |
| 666 | + def testLogProbMatchesProbUniform(self): |
| 667 | + # Uniform does not define _log_prob |
| 668 | + scaled_uniform = transformed_distribution.TransformedDistribution( |
| 669 | + distribution=uniform.Uniform(), |
| 670 | + bijector=scale_lib.Scale(2.0)) |
| 671 | + x = np.array([0.2], dtype=np.float32) |
| 672 | + self.assertAllClose(scaled_uniform.prob(x), |
| 673 | + tf.exp(scaled_uniform.log_prob(x))) |
| 674 | + |
653 | 675 |
|
654 | 676 | @test_util.test_all_tf_execution_regimes
|
655 | 677 | class ScalarToMultiTest(test_util.TestCase):
|
|
0 commit comments