Skip to content

Commit a163a15

Browse files
brianwa84 and tensorflower-gardener
authored and committed
Remove broadcasting_shapes(shape, n=1), which directly returns shape. Adds an error to prevent this case in the future.
Add comment showing how to use broadcast_compatible_shapes instead of broadcasting_shapes. Unfortunately doing so would reveal new failures. (TODO) Verify that the forward transform adds batch/event dims to match the forward transform of broadcasted-to-full-batch/event. (TODO) Verify that FLDJ [effectively] broadcasts before reducing. PiperOrigin-RevId: 375137840
1 parent fd8be3c commit a163a15

File tree

2 files changed

+36
-12
lines changed

2 files changed

+36
-12
lines changed

tensorflow_probability/python/distributions/distribution_properties_test.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -429,17 +429,28 @@ def check_event_space_bijector_constrains(self, dist, data):
429429
if event_space_bijector is None:
430430
return
431431

432+
# Draw a sample shape
433+
sample_shape = data.draw(tfp_hps.shapes())
434+
inv_event_shape = event_space_bijector.inverse_event_shape(
435+
tensorshape_util.concatenate(dist.batch_shape, dist.event_shape))
436+
437+
# Draw a shape that broadcasts with `[batch_shape, inverse_event_shape]`
438+
# where `inverse_event_shape` is the event shape in the bijector's
439+
# domain. This is the shape of `y` in R**n, such that
440+
# x = event_space_bijector(y) has the event shape of the distribution.
441+
442+
# TODO(b/174778703): Actually draw broadcast compatible shapes.
443+
batch_inv_event_compat_shape = inv_event_shape
444+
# batch_inv_event_compat_shape = data.draw(
445+
# tfp_hps.broadcast_compatible_shape(inv_event_shape))
446+
# batch_inv_event_compat_shape = tensorshape_util.concatenate(
447+
# (1,) * (len(inv_event_shape) - len(batch_inv_event_compat_shape)),
448+
# batch_inv_event_compat_shape)
449+
432450
total_sample_shape = tensorshape_util.concatenate(
433-
# Draw a sample shape
434-
data.draw(tfp_hps.shapes()),
435-
# Draw a shape that broadcasts with `[batch_shape, inverse_event_shape]`
436-
# where `inverse_event_shape` is the event shape in the bijector's
437-
# domain. This is the shape of `y` in R**n, such that
438-
# x = event_space_bijector(y) has the event shape of the distribution.
439-
data.draw(tfp_hps.broadcasting_shapes(
440-
event_space_bijector.inverse_event_shape(
441-
tensorshape_util.concatenate(
442-
dist.batch_shape, dist.event_shape)), n=1))[0])
451+
sample_shape, batch_inv_event_compat_shape)
452+
# full_sample_batch_event_shape = tensorshape_util.concatenate(
453+
# sample_shape, inv_event_shape)
443454

444455
y = data.draw(
445456
tfp_hps.constrained_tensors(
@@ -451,6 +462,14 @@ def check_event_space_bijector_constrains(self, dist, data):
451462
with tf.control_dependencies(dist._sample_control_dependencies(x)):
452463
self.evaluate(tf.identity(x))
453464

465+
# TODO(b/158874412): Verify DoF changing default bijectors.
466+
# y_bc = tf.broadcast_to(y, full_sample_batch_event_shape)
467+
# x_bc = event_space_bijector(y_bc)
468+
# self.assertAllClose(x, x_bc)
469+
# fldj = event_space_bijector.forward_log_det_jacobian(y)
470+
# fldj_bc = event_space_bijector.forward_log_det_jacobian(y_bc)
471+
# self.assertAllClose(fldj, fldj_bc)
472+
454473
@parameterized.named_parameters(
455474
{'testcase_name': dname, 'dist_name': dname}
456475
for dname in sorted(list(set(dhps.INSTANTIABLE_BASE_DISTS.keys()) -

tensorflow_probability/python/internal/hypothesis_testlib.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,7 @@ def broadcasting_named_shapes(draw, batch_shape, param_names):
593593
n = len(param_names)
594594
return dict(
595595
zip(draw(hps.permutations(param_names)),
596-
draw(broadcasting_shapes(batch_shape, n))))
596+
draw(broadcasting_shapes(batch_shape, n, no_error_n_eq_1=True))))
597597

598598

599599
def _compute_rank_and_fullsize_reqd(draw, target_shape, current_shape, is_last):
@@ -651,7 +651,7 @@ def broadcast_compatible_shape(shape):
651651

652652

653653
@hps.composite
654-
def broadcasting_shapes(draw, target_shape, n):
654+
def broadcasting_shapes(draw, target_shape, n, no_error_n_eq_1=False):
655655
"""Strategy for drawing a set of `n` shapes that broadcast to `target_shape`.
656656
657657
For each shape we need to choose its rank, and whether or not each axis i is 1
@@ -663,12 +663,17 @@ def broadcasting_shapes(draw, target_shape, n):
663663
draw: Hypothesis strategy sampler supplied by `@hps.composite`.
664664
target_shape: The target (fully-defined) batch shape.
665665
n: Python `int`, the number of shapes to draw.
666+
no_error_n_eq_1: If True, don't raise ValueError when n==1.
666667
667668
Returns:
668669
shapes: A strategy for drawing sequences of `tf.TensorShape` such that the
669670
set of shapes in each sequence broadcast to `target_shape`. The shapes are
670671
fully defined.
671672
"""
673+
if n == 1 and not no_error_n_eq_1:
674+
raise ValueError('`broadcasting_shapes(shp, n=1) is just `shp`. '
675+
'Did you want `broadcast_compatible_shape`? If `n` '
676+
'is stochastic, add arg `no_error_n_eq_1=True`.')
672677
target_shape = tf.TensorShape(target_shape)
673678
target_rank = tensorshape_util.rank(target_shape)
674679
result = []

0 commit comments

Comments (0)