Skip to content

Commit 4547374

Browse files
srvasude authored and tensorflower-gardener committed
Add local measure to discrete distributions. This ensures Discrete distributions transform correctly under bijectors.
```python
dist = tfp.distributions.Bernoulli(probs=0.5, dtype=tf.float32)
transformed_dist = tfp.bijectors.Scale(2.)(dist)
transformed_dist.prob(0.)
# Expect this to be 0.5, but if we apply a det jacobian correction of 1 / 2,
# this would be 0.25
```

PiperOrigin-RevId: 452449290
1 parent d0b9d34 commit 4547374

22 files changed

+172
-20
lines changed

tensorflow_probability/python/distributions/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ multi_substrate_py_library(
309309
# numpy dep,
310310
# tensorflow dep,
311311
"//tensorflow_probability/python/bijectors:sigmoid",
312+
"//tensorflow_probability/python/experimental/tangent_spaces",
312313
"//tensorflow_probability/python/internal:assert_util",
313314
"//tensorflow_probability/python/internal:batched_rejection_sampler",
314315
"//tensorflow_probability/python/internal:distribution_util",
@@ -1565,6 +1566,7 @@ multi_substrate_py_library(
15651566
":distribution",
15661567
# tensorflow dep,
15671568
"//tensorflow_probability/python/bijectors:softmax_centered",
1569+
"//tensorflow_probability/python/experimental/tangent_spaces",
15681570
"//tensorflow_probability/python/internal:assert_util",
15691571
"//tensorflow_probability/python/internal:distribution_util",
15701572
"//tensorflow_probability/python/internal:dtype_util",
@@ -1584,6 +1586,7 @@ multi_substrate_py_library(
15841586
":gamma",
15851587
# tensorflow dep,
15861588
"//tensorflow_probability/python/bijectors:sigmoid",
1589+
"//tensorflow_probability/python/experimental/tangent_spaces",
15871590
"//tensorflow_probability/python/internal:assert_util",
15881591
"//tensorflow_probability/python/internal:distribution_util",
15891592
"//tensorflow_probability/python/internal:dtype_util",

tensorflow_probability/python/distributions/bernoulli.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@
2828
from tensorflow_probability.python.internal import tensor_util
2929

3030

31-
class Bernoulli(distribution.AutoCompositeTensorDistribution):
31+
class Bernoulli(
32+
distribution.DiscreteDistributionMixin,
33+
distribution.AutoCompositeTensorDistribution):
3234
"""Bernoulli distribution.
3335
3436
The Bernoulli distribution with `probs` parameter, i.e., the probability of a

tensorflow_probability/python/distributions/beta_binomial.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@
4646
"""
4747

4848

49-
class BetaBinomial(distribution.AutoCompositeTensorDistribution):
49+
class BetaBinomial(
50+
distribution.DiscreteDistributionMixin,
51+
distribution.AutoCompositeTensorDistribution):
5052
"""Beta-Binomial compound distribution.
5153
5254
The Beta-Binomial distribution is parameterized by (a batch of) `total_count`

tensorflow_probability/python/distributions/binomial.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,9 @@ def _random_binomial(
262262
return sampler_impl(**params)
263263

264264

265-
class Binomial(distribution.AutoCompositeTensorDistribution):
265+
class Binomial(
266+
distribution.DiscreteDistributionMixin,
267+
distribution.AutoCompositeTensorDistribution):
266268
"""Binomial distribution.
267269
268270
This distribution is parameterized by `probs`, a (batch of) probabilities for

tensorflow_probability/python/distributions/categorical.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,9 @@ def _broadcast_cat_event_and_params(event, params, base_dtype):
5858
return event, params
5959

6060

61-
class Categorical(distribution.AutoCompositeTensorDistribution):
61+
class Categorical(
62+
distribution.DiscreteDistributionMixin,
63+
distribution.AutoCompositeTensorDistribution):
6264
"""Categorical distribution over integers.
6365
6466
The Categorical distribution is parameterized by either probabilities or

tensorflow_probability/python/distributions/dirichlet_multinomial.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@
5151
with `self.concentration` and `self.total_count`."""
5252

5353

54-
class DirichletMultinomial(distribution.AutoCompositeTensorDistribution):
54+
class DirichletMultinomial(
55+
distribution.DiscreteDistributionMixin,
56+
distribution.AutoCompositeTensorDistribution):
5557
"""Dirichlet-Multinomial compound distribution.
5658
5759
The Dirichlet-Multinomial distribution is parameterized by a (batch of)

tensorflow_probability/python/distributions/distribution.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2119,6 +2119,48 @@ class MyDistribution(tfb.AutoCompositeTensorDistribution):
21192119
pass
21202120

21212121

2122+
class DiscreteDistributionMixin(object):
2123+
"""Mixin for Distributions over discrete spaces.
2124+
2125+
This mixin identifies a `Distribution` as a discrete distribution, which in
2126+
turn ensures that it is transformed properly under `TransformedDistribution`.
2127+
2128+
Normally, for a continuous distribution `dist` transformed by a bijector `bij`, we have
2129+
the following formula for the `log_prob`:
2130+
`dist.log_prob(bij.inverse(y)) + bij.inverse_log_det_jacobian(y)`.
2131+
For a discrete distribution, we don't apply the `inverse_log_det_jacobian`
2132+
correction (hence just `dist.log_prob(bij.inverse(y))`). This difference
2133+
comes from transforming a probability density vs. probabilities.
2134+
2135+
As an example, we could take a Bernoulli distribution (
2136+
whose samples are `0` or `1`) and square it via `tfb.Square`. Samples from
2137+
this new distribution are still `0` or `1` and one would expect that the
2138+
probabilities for `0` and `1` are unchanged after this transformation.
2139+
2140+
```python
2141+
dist = tfp.distributions.Bernoulli(probs=0.5)
2142+
dist.prob(1.) # expect 0.5
2143+
transformed_dist = tfp.bijectors.Square()(dist)
2144+
transformed_dist.prob(1.) # expect 0.5
2145+
```
2146+
2147+
If we apply the jacobian correction, we would instead get the wrong answer:
2148+
2149+
```python
2150+
# If we compute with the jacobian correction explicitly, we get the wrong
2151+
# answer.
2152+
bij = tfp.bijectors.Square()
2153+
prob_at_1 = dist.log_prob(bij.inverse(1.)) + bij.inverse_log_det_jacobian(1.)
2154+
prob_at_1 = tf.math.exp(prob_at_1) # This is 0.25
2155+
```
2156+
"""
2157+
2158+
@property
2159+
def _experimental_tangent_space(self):
2160+
from tensorflow_probability.python.experimental import tangent_spaces # pylint: disable=g-import-not-at-top
2161+
return tangent_spaces.ZeroSpace()
2162+
2163+
21222164
class _PrettyDict(dict):
21232165
"""`dict` with stable `repr`, `str`."""
21242166

tensorflow_probability/python/distributions/distribution_properties_test.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,12 @@
9191
})
9292

9393

94+
DISCRETE_BUT_NOT_TRANSFORMABLE = [
95+
# Samples are integers so can't be transformed by a float bijector.
96+
'DeterminantalPointProcess',
97+
]
98+
99+
94100
@test_util.test_all_tf_execution_regimes
95101
class StatisticConsistentShapesTest(test_util.TestCase):
96102

@@ -230,6 +236,43 @@ def testDistribution(self, dist_name, data):
230236
self.assertAllEqual(s1, s2)
231237

232238

239+
@test_util.test_all_tf_execution_regimes
240+
class TestDiscreteDistributions(test_util.TestCase):
241+
242+
@parameterized.named_parameters(
243+
{'testcase_name': dname, 'dist_name': dname}
244+
for dname in sorted(list(set(dhps.DISCRETE_DISTS) -
245+
set(DISCRETE_BUT_NOT_TRANSFORMABLE))))
246+
@hp.given(hps.data())
247+
@tfp_hps.tfp_hp_settings()
248+
def testNoJacobianCorrection(self, dist_name, data):
249+
250+
# Disable validate args since transforming with Softplus and inverting
251+
# might make arguments not as close to integers.
252+
dist = data.draw(dhps.distributions(
253+
dist_name=dist_name,
254+
enable_vars=False,
255+
validate_args=False))
256+
257+
# Ensure that these are distributions over floats so we can apply the
258+
# Softplus bijector.
259+
if 'dtype' in dist.parameters:
260+
dist = dist.copy(dtype=tf.float32)
261+
bij = tfb.Softplus()
262+
transformed_dist = tfd.TransformedDistribution(dist, bijector=bij)
263+
264+
seed = test_util.test_seed()
265+
samples = transformed_dist.sample(7, seed=seed)
266+
# Break bijector caching.
267+
samples = self.evaluate(
268+
samples + tf.constant(0., dtype=samples.dtype))
269+
270+
# Check that no jacobian correction is added for a discrete distribution.
271+
self.assertAllClose(
272+
self.evaluate(dist.log_prob(bij.inverse(samples))),
273+
self.evaluate(transformed_dist.log_prob(samples)))
274+
275+
233276
@test_util.test_all_tf_execution_regimes
234277
class SampleAndLogProbTest(test_util.TestCase):
235278

@@ -630,8 +673,8 @@ class TestMixingGraphAndEagerModes(test_util.TestCase):
630673

631674
@parameterized.named_parameters(
632675
{'testcase_name': dname, 'dist_name': dname}
633-
for dname in sorted(list(dhps.INSTANTIABLE_BASE_DISTS.keys()) +
634-
list(dhps.INSTANTIABLE_META_DISTS))
676+
for dname in sorted(list(dhps.INSTANTIABLE_BASE_DISTS.keys()) +
677+
list(dhps.INSTANTIABLE_META_DISTS))
635678
)
636679
@hp.given(hps.data())
637680
@tfp_hps.tfp_hp_settings()

tensorflow_probability/python/distributions/dpp.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,9 @@ def body(i, vecs, cur_sample, seed):
240240
return tf.cast(sample, tf.int32)
241241

242242

243-
class DeterminantalPointProcess(distribution.AutoCompositeTensorDistribution):
243+
class DeterminantalPointProcess(
244+
distribution.DiscreteDistributionMixin,
245+
distribution.AutoCompositeTensorDistribution):
244246
"""Determinantal point process (DPP) distribution.
245247
246248
The DPP distribution parameterized by the eigenvalues and eigenvectors of the

tensorflow_probability/python/distributions/empirical.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ def _broadcast_event_and_samples(event, samples, event_ndims):
5252
return event, samples
5353

5454

55-
class Empirical(distribution.AutoCompositeTensorDistribution):
55+
class Empirical(
56+
distribution.DiscreteDistributionMixin,
57+
distribution.AutoCompositeTensorDistribution):
5658
"""Empirical distribution.
5759
5860
The Empirical distribution is parameterized by a [batch] multiset of samples.

0 commit comments

Comments
 (0)