|
29 | 29 | from tensorflow_probability.python import distributions as tfd
|
30 | 30 | from tensorflow_probability.python import util as tfp_util
|
31 | 31 | from tensorflow_probability.python.bijectors import hypothesis_testlib as bijector_hps
|
| 32 | +from tensorflow_probability.python.experimental import distributions as tfed |
32 | 33 | from tensorflow_probability.python.internal import hypothesis_testlib as tfp_hps
|
33 | 34 | from tensorflow_probability.python.internal import tensorshape_util
|
34 | 35 |
|
@@ -448,6 +449,10 @@ def _instantiable_base_dists():
|
448 | 449 | functools.partial(tfd.Empirical, event_ndims=1), dict(samples=2))
|
449 | 450 | result['Empirical|event_ndims=2'] = DistInfo( #
|
450 | 451 | functools.partial(tfd.Empirical, event_ndims=2), dict(samples=3))
|
| 452 | + |
| 453 | + # We use a special strategy for instantiating this, so event_dims is set to a |
| 454 | + # dummy value. |
| 455 | + result['IncrementLogProb'] = DistInfo(tfed.IncrementLogProb, None) |
451 | 456 | return result
|
452 | 457 |
|
453 | 458 |
|
@@ -680,6 +685,12 @@ class names so they will not be drawn at the top level.
|
680 | 685 | return draw(spherical_uniforms(
|
681 | 686 | batch_shape=batch_shape, event_dim=event_dim,
|
682 | 687 | validate_args=validate_args))
|
| 688 | + elif dist_name == 'IncrementLogProb': |
| 689 | + return draw( |
| 690 | + increment_log_probs( |
| 691 | + batch_shape=batch_shape, |
| 692 | + enable_vars=enable_vars, |
| 693 | + validate_args=validate_args)) |
683 | 694 |
|
684 | 695 | if params is None:
|
685 | 696 | params_unconstrained, batch_shape = draw(
|
@@ -738,6 +749,45 @@ def spherical_uniforms(
|
738 | 749 | return result_dist
|
739 | 750 |
|
740 | 751 |
|
@hps.composite
def increment_log_probs(draw,
                        batch_shape=None,
                        enable_vars=False,
                        validate_args=True):
  """Strategy for drawing `IncrementLogProb` distributions.

  Args:
    draw: Hypothesis strategy sampler supplied by `@hps.composite`.
    batch_shape: An optional `TensorShape`. The batch shape of the resulting
      `IncrementLogProb` distribution. If omitted, one is drawn.
    enable_vars: TODO(b/181859346): Make this `True` all the time and put
      variable initialization in slicing_test. If `False`, the returned
      parameters are all `tf.Tensor`s and not {`tf.Variable`,
      `tfp.util.DeferredTensor` `tfp.util.TransformedVariable`}
    validate_args: Python `bool`; whether to enable runtime assertions.

  Returns:
    dists: A strategy for drawing `IncrementLogProb` distributions with the
      specified `batch_shape` (or an arbitrary one if omitted).
  """
  # Draw a batch shape only when the caller did not pin one down.
  if batch_shape is None:
    batch_shape = draw(tfp_hps.shapes(min_ndims=0, max_side=4))
  shape_list = tensorshape_util.as_list(batch_shape)

  # An unconstrained tensor of the requested batch shape, optionally wrapped
  # in a variable (or deferred-tensor variant) when `enable_vars` is set.
  increment = draw(
      tfp_hps.maybe_variable(
          tfp_hps.constrained_tensors(tfp_hps.identity_fn, shape_list),
          enable_vars=enable_vars))

  # Exercise both construction paths: pass the increment directly, or pass a
  # callable plus kwargs that reproduce the same value.
  direct = draw(hps.booleans())
  if not direct:
    return tfed.IncrementLogProb(
        lambda v: v,
        validate_args=validate_args,
        log_prob_increment_kwargs={'v': increment})
  return tfed.IncrementLogProb(increment, validate_args=validate_args)
| 789 | + |
| 790 | + |
741 | 791 | @hps.composite
|
742 | 792 | def batch_broadcasts(
|
743 | 793 | draw, batch_shape=None, event_dim=None,
|
|
0 commit comments