Skip to content

Commit 944af40

Browse files
authored
Merge branch 'tensorflow:main' into frighterafix#1505
2 parents 520d4c2 + a4da3a4 commit 944af40

35 files changed

+413
-314
lines changed

tensorflow_probability/python/bijectors/bijector.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1370,7 +1370,7 @@ def forward(self, x, name='forward', **kwargs):
13701370
@classmethod
13711371
def _is_increasing(cls, **kwargs):
13721372
"""Subclass implementation for `is_increasing` public function."""
1373-
raise NotImplementedError('`_is_increasing` not implemented.')
1373+
raise NotImplementedError(f'`_is_increasing` not implemented in {cls}.')
13741374

13751375
def _call_is_increasing(self, name, **kwargs):
13761376
"""Wraps call to _is_increasing, allowing extra shared logic."""
@@ -1684,7 +1684,8 @@ def forward_log_det_jacobian(self,
16841684
def experimental_compute_density_correction(self,
16851685
x,
16861686
tangent_space,
1687-
backward_compat=False):
1687+
backward_compat=False,
1688+
**kwargs):
16881689
"""Density correction for this transformation wrt the tangent space, at x.
16891690
16901691
Subclasses of Bijector may call the most specific applicable
@@ -1699,6 +1700,7 @@ def experimental_compute_density_correction(self,
16991700
the support manifold at `x`.
17001701
backward_compat: `bool` specifying whether to assume that the Bijector
17011702
is dimension-preserving.
1703+
**kwargs: Optional keyword arguments forwarded to tangent space methods.
17021704
17031705
Returns:
17041706
density_correction: `Tensor` representing the density correction---in log
@@ -1710,7 +1712,7 @@ def experimental_compute_density_correction(self,
17101712
17111713
"""
17121714
if backward_compat:
1713-
return tangent_space.transform_dimension_preserving(x, self)
1715+
return tangent_space.transform_dimension_preserving(x, self, **kwargs)
17141716
else:
17151717
raise TypeError(
17161718
'Please call the `TangentSpace` method applicable to this Bijector.')

tensorflow_probability/python/bijectors/ffjord.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ def _solve_ode(self, ode_fn, state, **kwargs):
356356
def _augmented_forward(self, x, **condition_kwargs):
357357
"""Computes forward and forward_log_det_jacobian transformations."""
358358
augmented_ode_fn = self._trace_augmentation_fn(
359-
self._state_time_derivative_fn, x.shape, x.dtype)
359+
self._state_time_derivative_fn, prefer_static.shape(x), x.dtype)
360360
augmented_x = (x, tf.zeros_like(x))
361361
if condition_kwargs:
362362
y, fldj = self._solve_ode(augmented_ode_fn, augmented_x,
@@ -368,7 +368,7 @@ def _augmented_forward(self, x, **condition_kwargs):
368368
def _augmented_inverse(self, y, **condition_kwargs):
369369
"""Computes inverse and inverse_log_det_jacobian transformations."""
370370
augmented_inv_ode_fn = self._trace_augmentation_fn(
371-
self._inv_state_time_derivative_fn, y.shape, y.dtype)
371+
self._inv_state_time_derivative_fn, prefer_static.shape(y), y.dtype)
372372
augmented_y = (y, tf.zeros_like(y))
373373
if condition_kwargs:
374374
x, ildj = self._solve_ode(augmented_inv_ode_fn, augmented_y,

tensorflow_probability/python/distributions/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -934,6 +934,7 @@ multi_substrate_py_library(
934934
# tensorflow dep,
935935
"//tensorflow_probability",
936936
"//tensorflow_probability/python/bijectors:hypothesis_testlib",
937+
"//tensorflow_probability/python/experimental/distributions",
937938
"//tensorflow_probability/python/internal:hypothesis_testlib",
938939
"//tensorflow_probability/python/internal:tensorshape_util",
939940
],

tensorflow_probability/python/distributions/distribution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1926,7 +1926,7 @@ def experimental_local_measure(self, value, backward_compat=False, **kwargs):
19261926
"""
19271927
log_prob = self.log_prob(value, **kwargs)
19281928
tangent_space = None
1929-
if getattr(self, '_experimental_tangent_space'):
1929+
if hasattr(self, '_experimental_tangent_space'):
19301930
tangent_space = self._experimental_tangent_space
19311931
elif backward_compat:
19321932
# Import here rather than top-level to avoid circular import.

tensorflow_probability/python/distributions/distribution_properties_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@
5858
INSTANTIABLE_BUT_NOT_SLICABLE = (
5959
'BatchBroadcast',
6060
'BatchReshape',
61-
'Sample' # TODO(b/204210361)
61+
'IncrementLogProb',
62+
'Sample', # TODO(b/204210361)
6263
)
6364

6465

@@ -350,6 +351,7 @@ def testCanConstructAndSampleDistribution(self, data):
350351
'TruncatedNormal', 'Uniform')
351352
not_annotated_dists = ('Empirical|event_ndims=0', 'Empirical|event_ndims=1',
352353
'Empirical|event_ndims=2', 'FiniteDiscrete',
354+
'IncrementLogProb',
353355
# cov_perturb_factor is not annotated since its shape
354356
# could be a vector or a matrix.
355357
'MultivariateNormalDiagPlusLowRankCovariance',

tensorflow_probability/python/distributions/hypothesis_testlib.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from tensorflow_probability.python import distributions as tfd
3030
from tensorflow_probability.python import util as tfp_util
3131
from tensorflow_probability.python.bijectors import hypothesis_testlib as bijector_hps
32+
from tensorflow_probability.python.experimental import distributions as tfed
3233
from tensorflow_probability.python.internal import hypothesis_testlib as tfp_hps
3334
from tensorflow_probability.python.internal import tensorshape_util
3435

@@ -448,6 +449,10 @@ def _instantiable_base_dists():
448449
functools.partial(tfd.Empirical, event_ndims=1), dict(samples=2))
449450
result['Empirical|event_ndims=2'] = DistInfo( #
450451
functools.partial(tfd.Empirical, event_ndims=2), dict(samples=3))
452+
453+
# We use a special strategy for instantiating this, so event_dims is set to a
454+
# dummy value.
455+
result['IncrementLogProb'] = DistInfo(tfed.IncrementLogProb, None)
451456
return result
452457

453458

@@ -680,6 +685,12 @@ class names so they will not be drawn at the top level.
680685
return draw(spherical_uniforms(
681686
batch_shape=batch_shape, event_dim=event_dim,
682687
validate_args=validate_args))
688+
elif dist_name == 'IncrementLogProb':
689+
return draw(
690+
increment_log_probs(
691+
batch_shape=batch_shape,
692+
enable_vars=enable_vars,
693+
validate_args=validate_args))
683694

684695
if params is None:
685696
params_unconstrained, batch_shape = draw(
@@ -738,6 +749,45 @@ def spherical_uniforms(
738749
return result_dist
739750

740751

752+
@hps.composite
753+
def increment_log_probs(draw,
754+
batch_shape=None,
755+
enable_vars=False,
756+
validate_args=True):
757+
"""Strategy for drawing `IncrementLogProb` distributions.
758+
759+
Args:
760+
draw: Hypothesis strategy sampler supplied by `@hps.composite`.
761+
batch_shape: An optional `TensorShape`. The batch shape of the resulting
762+
`IncrementLogProb` distribution.
763+
enable_vars: TODO(b/181859346): Make this `True` all the time and put
764+
variable initialization in slicing_test. If `False`, the returned
765+
parameters are all `tf.Tensor`s and not {`tf.Variable`,
766+
`tfp.util.DeferredTensor` `tfp.util.TransformedVariable`}
767+
validate_args: Python `bool`; whether to enable runtime assertions.
768+
769+
Returns:
770+
dists: A strategy for drawing `IncrementLogProb` distributions with the
771+
specified `batch_shape` (or an arbitrary one if omitted).
772+
"""
773+
if batch_shape is None:
774+
batch_shape = draw(tfp_hps.shapes(min_ndims=0, max_side=4))
775+
776+
log_prob_value = draw(
777+
tfp_hps.maybe_variable(
778+
tfp_hps.constrained_tensors(tfp_hps.identity_fn,
779+
tensorshape_util.as_list(batch_shape)),
780+
enable_vars=enable_vars))
781+
782+
if draw(hps.booleans()):
783+
return tfed.IncrementLogProb(log_prob_value, validate_args=validate_args)
784+
else:
785+
return tfed.IncrementLogProb(
786+
lambda v: v,
787+
validate_args=validate_args,
788+
log_prob_increment_kwargs={'v': log_prob_value})
789+
790+
741791
@hps.composite
742792
def batch_broadcasts(
743793
draw, batch_shape=None, event_dim=None,

tensorflow_probability/python/distributions/jax_transformation_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
JVP_LOGPROB_SAMPLE_BLOCKLIST = (
8383
'BetaQuotient', # https://b/178552958
8484
'GeneralizedExtremeValue', # http://b/175654800
85+
'IncrementLogProb', # Sample and log prob are independent.
8586
'NegativeBinomial', # Too slow: http://b/170871051
8687
)
8788
JVP_LOGPROB_PARAM_BLOCKLIST = (
@@ -95,6 +96,7 @@
9596
VJP_LOGPROB_SAMPLE_BLOCKLIST = (
9697
'BetaQuotient', # https://b/178552958
9798
'GeneralizedExtremeValue', # http://b/175654800
99+
'IncrementLogProb', # Sample and log prob are independent.
98100
'NegativeBinomial', # Too slow: http://b/170871051
99101
)
100102
VJP_LOGPROB_PARAM_BLOCKLIST = (

tensorflow_probability/python/experimental/distributions/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ multi_substrate_py_library(
7878
# numpy dep,
7979
# tensorflow dep,
8080
"//tensorflow_probability/python/bijectors:identity",
81+
"//tensorflow_probability/python/distributions:distribution",
8182
"//tensorflow_probability/python/internal:callable_util",
8283
"//tensorflow_probability/python/internal:prefer_static",
8384
"//tensorflow_probability/python/internal:reparameterization",

0 commit comments

Comments
 (0)