Skip to content

Commit 1e2e658

Browse files
SiegeLordExtensorflower-gardener
authored andcommitted
Make IncrementLogProb a proper tfd.Distribution.
This enables slicing, returning the distribution from tf.function, etc. Also, fix a bug with log_prob which didn't handle a callable log prob increment. PiperOrigin-RevId: 429419375
1 parent 3db2cf2 commit 1e2e658

File tree

10 files changed

+205
-189
lines changed

10 files changed

+205
-189
lines changed

tensorflow_probability/python/distributions/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -934,6 +934,7 @@ multi_substrate_py_library(
934934
# tensorflow dep,
935935
"//tensorflow_probability",
936936
"//tensorflow_probability/python/bijectors:hypothesis_testlib",
937+
"//tensorflow_probability/python/experimental/distributions",
937938
"//tensorflow_probability/python/internal:hypothesis_testlib",
938939
"//tensorflow_probability/python/internal:tensorshape_util",
939940
],

tensorflow_probability/python/distributions/distribution_properties_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@
5858
INSTANTIABLE_BUT_NOT_SLICABLE = (
5959
'BatchBroadcast',
6060
'BatchReshape',
61-
'Sample' # TODO(b/204210361)
61+
'IncrementLogProb',
62+
'Sample', # TODO(b/204210361)
6263
)
6364

6465

@@ -350,6 +351,7 @@ def testCanConstructAndSampleDistribution(self, data):
350351
'TruncatedNormal', 'Uniform')
351352
not_annotated_dists = ('Empirical|event_ndims=0', 'Empirical|event_ndims=1',
352353
'Empirical|event_ndims=2', 'FiniteDiscrete',
354+
'IncrementLogProb',
353355
# cov_perturb_factor is not annotated since its shape
354356
# could be a vector or a matrix.
355357
'MultivariateNormalDiagPlusLowRankCovariance',

tensorflow_probability/python/distributions/hypothesis_testlib.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from tensorflow_probability.python import distributions as tfd
3030
from tensorflow_probability.python import util as tfp_util
3131
from tensorflow_probability.python.bijectors import hypothesis_testlib as bijector_hps
32+
from tensorflow_probability.python.experimental import distributions as tfed
3233
from tensorflow_probability.python.internal import hypothesis_testlib as tfp_hps
3334
from tensorflow_probability.python.internal import tensorshape_util
3435

@@ -448,6 +449,10 @@ def _instantiable_base_dists():
448449
functools.partial(tfd.Empirical, event_ndims=1), dict(samples=2))
449450
result['Empirical|event_ndims=2'] = DistInfo( #
450451
functools.partial(tfd.Empirical, event_ndims=2), dict(samples=3))
452+
453+
# We use a special strategy for instantiating this, so event_dims is set to a
454+
# dummy value.
455+
result['IncrementLogProb'] = DistInfo(tfed.IncrementLogProb, None)
451456
return result
452457

453458

@@ -680,6 +685,12 @@ class names so they will not be drawn at the top level.
680685
return draw(spherical_uniforms(
681686
batch_shape=batch_shape, event_dim=event_dim,
682687
validate_args=validate_args))
688+
elif dist_name == 'IncrementLogProb':
689+
return draw(
690+
increment_log_probs(
691+
batch_shape=batch_shape,
692+
enable_vars=enable_vars,
693+
validate_args=validate_args))
683694

684695
if params is None:
685696
params_unconstrained, batch_shape = draw(
@@ -738,6 +749,45 @@ def spherical_uniforms(
738749
return result_dist
739750

740751

752+
@hps.composite
753+
def increment_log_probs(draw,
754+
batch_shape=None,
755+
enable_vars=False,
756+
validate_args=True):
757+
"""Strategy for drawing `IncrementLogProb` distributions.
758+
759+
Args:
760+
draw: Hypothesis strategy sampler supplied by `@hps.composite`.
761+
batch_shape: An optional `TensorShape`. The batch shape of the resulting
762+
`IncrementLogProb` distribution.
763+
enable_vars: TODO(b/181859346): Make this `True` all the time and put
764+
variable initialization in slicing_test. If `False`, the returned
765+
parameters are all `tf.Tensor`s and not {`tf.Variable`,
766+
`tfp.util.DeferredTensor` `tfp.util.TransformedVariable`}
767+
validate_args: Python `bool`; whether to enable runtime assertions.
768+
769+
Returns:
770+
dists: A strategy for drawing `IncrementLogProb` distributions with the
771+
specified `batch_shape` (or an arbitrary one if omitted).
772+
"""
773+
if batch_shape is None:
774+
batch_shape = draw(tfp_hps.shapes(min_ndims=0, max_side=4))
775+
776+
log_prob_value = draw(
777+
tfp_hps.maybe_variable(
778+
tfp_hps.constrained_tensors(tfp_hps.identity_fn,
779+
tensorshape_util.as_list(batch_shape)),
780+
enable_vars=enable_vars))
781+
782+
if draw(hps.booleans()):
783+
return tfed.IncrementLogProb(log_prob_value, validate_args=validate_args)
784+
else:
785+
return tfed.IncrementLogProb(
786+
lambda v: v,
787+
validate_args=validate_args,
788+
log_prob_increment_kwargs={'v': log_prob_value})
789+
790+
741791
@hps.composite
742792
def batch_broadcasts(
743793
draw, batch_shape=None, event_dim=None,

tensorflow_probability/python/distributions/jax_transformation_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
JVP_LOGPROB_SAMPLE_BLOCKLIST = (
8383
'BetaQuotient', # https://b/178552958
8484
'GeneralizedExtremeValue', # http://b/175654800
85+
'IncrementLogProb', # Sample and log prob are independent.
8586
'NegativeBinomial', # Too slow: http://b/170871051
8687
)
8788
JVP_LOGPROB_PARAM_BLOCKLIST = (
@@ -95,6 +96,7 @@
9596
VJP_LOGPROB_SAMPLE_BLOCKLIST = (
9697
'BetaQuotient', # https://b/178552958
9798
'GeneralizedExtremeValue', # http://b/175654800
99+
'IncrementLogProb', # Sample and log prob are independent.
98100
'NegativeBinomial', # Too slow: http://b/170871051
99101
)
100102
VJP_LOGPROB_PARAM_BLOCKLIST = (

tensorflow_probability/python/experimental/distributions/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ multi_substrate_py_library(
7878
# numpy dep,
7979
# tensorflow dep,
8080
"//tensorflow_probability/python/bijectors:identity",
81+
"//tensorflow_probability/python/distributions:distribution",
8182
"//tensorflow_probability/python/internal:callable_util",
8283
"//tensorflow_probability/python/internal:prefer_static",
8384
"//tensorflow_probability/python/internal:reparameterization",

0 commit comments

Comments
 (0)