
Commit ede69e0

davmre authored and tensorflower-gardener committed
Remove JD overrides of Distribution public methods for event and batch shape.
These are no longer needed since `Distribution` now supports structured shape. Using the public methods lets us benefit from static optimizations; e.g., `batch_shape_tensor` won't run if `batch_shape` is fully defined. This also removes the `sample_shape` argument to `event_shape_tensor` and `batch_shape_tensor`. This could theoretically be a breaking change, but we are proceeding without a deprecation cycle because passing a nontrivial sample shape was *already* broken (and has been since at least Jan 2020).

PiperOrigin-RevId: 380024311
1 parent 2d08c76 commit ede69e0
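
For context, here is a minimal sketch of the template-method dispatch the commit message refers to. This is illustrative only, not TFP's actual `Distribution` code: the public method consults the static shape first and falls back to the private `_batch_shape_tensor` hook only when the static shape is not fully defined, which is why overriding just the private hooks (as this commit does) inherits the fast path.

import tensorflow as tf

class SketchDistribution:
  """Illustrative stand-in for the `tfp.distributions.Distribution` pattern."""

  @property
  def batch_shape(self):
    return tf.TensorShape(self._batch_shape())  # Static subclass hook.

  def batch_shape_tensor(self):
    # Static fast path: if the shape is fully known, the dynamic hook never runs.
    if self.batch_shape.is_fully_defined():
      return tf.constant(self.batch_shape.as_list(), dtype=tf.int32)
    return self._batch_shape_tensor()  # Dynamic subclass hook.

class SketchVector(SketchDistribution):
  def _batch_shape(self):
    return [3]  # Fully defined, so batch_shape_tensor builds no graph op.

  def _batch_shape_tensor(self):
    return tf.shape(tf.zeros([3]))  # Only reached if the static shape were unknown.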

File tree

3 files changed: 37 additions & 103 deletions

tensorflow_probability/python/distributions/BUILD

Lines changed: 1 addition & 1 deletion
@@ -3075,7 +3075,7 @@ multi_substrate_py_test(
     name = "joint_distribution_coroutine_test",
     size = "medium",
     srcs = ["joint_distribution_coroutine_test.py"],
-    shard_count = 2,
+    shard_count = 5,
     deps = [
         # numpy dep,
         # tensorflow dep,

tensorflow_probability/python/distributions/joint_distribution.py

Lines changed: 19 additions & 75 deletions
@@ -35,7 +35,7 @@
 from tensorflow_probability.python.internal import distribution_util
 from tensorflow_probability.python.internal import docstring_util
 from tensorflow_probability.python.internal import nest_util
-from tensorflow_probability.python.internal import prefer_static
+from tensorflow_probability.python.internal import prefer_static as ps
 from tensorflow_probability.python.internal import samplers
 from tensorflow_probability.python.util.seed_stream import SeedStream
 from tensorflow_probability.python.util.seed_stream import TENSOR_SEED_MSG_PREFIX

@@ -315,56 +315,15 @@ def experimental_shard_axis_names(self):
   def use_vectorized_map(self):
     return False

-  @property
-  def batch_shape(self):
-    """Shape of a single sample from a single event index as a `TensorShape`.
-
-    May be partially defined or unknown.
-
-    The batch dimensions are indexes into independent, non-identical
-    parameterizations of this distribution.
-
-    Returns:
-      batch_shape: `tuple` of `TensorShape`s representing the `batch_shape` for
-        each distribution in `model`.
-    """
+  def _batch_shape(self):
     return self._model_unflatten([
         d.batch_shape for d in self._get_single_sample_distributions()])

-  def batch_shape_tensor(self, sample_shape=(), name='batch_shape_tensor'):
-    """Shape of a single sample from a single event index as a 1-D `Tensor`.
-
-    The batch dimensions are indexes into independent, non-identical
-    parameterizations of this distribution.
-
-    Args:
-      sample_shape: The sample shape under which to evaluate the joint
-        distribution. Sample shape at root (toplevel) nodes may affect the batch
-        or event shapes of child nodes.
-      name: name to give to the op
-
-    Returns:
-      batch_shape: `Tensor` representing batch shape of each distribution in
-        `model`.
-    """
-    with self._name_and_control_scope(name):
-      return self._model_unflatten(
-          self._map_attr_over_dists(
-              'batch_shape_tensor',
-              dists=(self.sample_distributions(sample_shape)
-                     if sample_shape else None)))
-
-  @property
-  def event_shape(self):
-    """Shape of a single sample from a single batch as a `TensorShape`.
+  def _batch_shape_tensor(self):
+    return self._model_unflatten(
+        self._map_attr_over_dists('batch_shape_tensor'))

-    May be partially defined or unknown.
-
-    Returns:
-      event_shape: `tuple` of `TensorShape`s representing the `event_shape` for
-        each distribution in `model`.
-    """
-    # Caching will not leak graph Tensors since this is a static attribute.
+  def _event_shape(self):
     if not hasattr(self, '_cached_event_shape'):
       self._cached_event_shape = [
           d.event_shape

@@ -373,24 +332,9 @@ def event_shape(self):
     # wrapping the returned value.
     return self._model_unflatten(self._cached_event_shape)

-  def event_shape_tensor(self, sample_shape=(), name='event_shape_tensor'):
-    """Shape of a single sample from a single batch as a 1-D int32 `Tensor`.
-
-    Args:
-      sample_shape: The sample shape under which to evaluate the joint
-        distribution. Sample shape at root (toplevel) nodes may affect the batch
-        or event shapes of child nodes.
-      name: name to give to the op
-    Returns:
-      event_shape: `tuple` of `Tensor`s representing the `event_shape` for each
-        distribution in `model`.
-    """
-    with self._name_and_control_scope(name):
-      return self._model_unflatten(
-          self._map_attr_over_dists(
-              'event_shape_tensor',
-              dists=(self.sample_distributions(sample_shape)
-                     if sample_shape else None)))
+  def _event_shape_tensor(self):
+    return self._model_unflatten(
+        self._map_attr_over_dists('event_shape_tensor'))

   def sample_distributions(self, sample_shape=(), seed=None, value=None,
                            name='sample_distributions', **kwargs):

@@ -847,9 +791,9 @@ def _assert_compatible_shape(self, index, sample_shape, samples):
     requested_shape, _ = self._expand_sample_shape_to_vector(
         tf.convert_to_tensor(sample_shape, dtype=tf.int32),
         name='requested_shape')
-    actual_shape = prefer_static.shape(samples)
-    actual_rank = prefer_static.rank_from_shape(actual_shape)
-    requested_rank = prefer_static.rank_from_shape(requested_shape)
+    actual_shape = ps.shape(samples)
+    actual_rank = ps.rank_from_shape(actual_shape)
+    requested_rank = ps.rank_from_shape(requested_shape)

     # We test for two properties we expect of yielded distributions:
     # (1) The rank of the tensor of generated samples must be at least

@@ -1068,8 +1012,8 @@ def maybe_check_wont_broadcast(flat_xs, validate_args):
     # Only when `validate_args` is `True` do we enforce the validation.
     return flat_xs
   msg = 'Broadcasting probably indicates an error in model specification.'
-  s = tuple(prefer_static.shape(x) for x in flat_xs)
-  if all(prefer_static.is_numpy(s_) for s_ in s):
+  s = tuple(ps.shape(x) for x in flat_xs)
+  if all(ps.is_numpy(s_) for s_ in s):
     if not all(np.all(a == b) for a, b in zip(s[1:], s[:-1])):
       raise ValueError(msg)
   return flat_xs

@@ -1092,7 +1036,7 @@ def __init__(self, jd, parameters=None, bijector_fn=None):
     bijectors = tuple(bijector_fn(d)
                       for d in jd._get_single_sample_distributions())
     i_min_event_ndims = tf.nest.map_structure(
-        prefer_static.size, jd.event_shape)
+        ps.size, jd.event_shape)
     f_min_event_ndims = jd._model_unflatten([
         b.inverse_event_ndims(nd) for b, nd in
         zip(bijectors, jd._model_flatten(i_min_event_ndims))])

@@ -1207,9 +1151,9 @@ def _jd_log_prob_ratio(p, x, q, y, name=None):
   """Implements `log_prob_ratio` for tfd.JointDistribution*."""
   with tf.name_scope(name or 'jd_log_prob_ratio'):
     tf.nest.assert_same_structure(x, y)
-    ps, _ = p.sample_distributions(value=x, seed=samplers.zeros_seed())
-    qs, _ = q.sample_distributions(value=y, seed=samplers.zeros_seed())
-    tf.nest.assert_same_structure(ps, qs)
+    p_dists, _ = p.sample_distributions(value=x, seed=samplers.zeros_seed())
+    q_dists, _ = q.sample_distributions(value=y, seed=samplers.zeros_seed())
+    tf.nest.assert_same_structure(p_dists, q_dists)
     log_prob_ratio_parts = nest.map_structure_up_to(
-        ps, log_prob_ratio.log_prob_ratio, ps, x, qs, y)
+        p_dists, log_prob_ratio.log_prob_ratio, p_dists, x, q_dists, y)
     return tf.add_n(tf.nest.flatten(log_prob_ratio_parts))
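
Two incidental points about the hunks above: the public `batch_shape` / `event_shape` overrides become the private `_batch_shape` / `_event_shape` hooks with their docstrings dropped (the inherited public methods carry the documentation), and the `ps`/`qs` locals in `_jd_log_prob_ratio` are renamed to `p_dists`/`q_dists` so they no longer shadow the new `prefer_static as ps` import alias. A hypothetical usage sketch of the resulting public API (the two-variable model below is invented for illustration):

import tensorflow_probability as tfp
tfd = tfp.distributions

jd = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),           # z
    lambda z: tfd.Normal(loc=z, scale=1.),  # x given z
])
jd.event_shape           # Structured static shape, one entry per component.
jd.batch_shape_tensor()  # Inherited from Distribution; the `sample_shape`
                         # argument no longer exists.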

tensorflow_probability/python/distributions/joint_distribution_sample_path_mixin.py

Lines changed: 17 additions & 27 deletions
@@ -77,52 +77,42 @@ def __init__(self, *args, **kwargs):
   def batch_ndims(self):
     return self._batch_ndims

-  @property
   def _batch_shape_parts(self):
     return [d.batch_shape[:self.batch_ndims]
             for d in self._get_single_sample_distributions()]

-  @property
-  def batch_shape(self):
+  def _batch_shape(self):
     # Caching will not leak graph Tensors since this is a static attribute.
-    if not hasattr(self, '_cached_batch_shape'):
-      reduce_fn = ((lambda a, b: a.merge_with(b)) if self.validate_args
-                   else tf.broadcast_static_shape)  # Allows broadcasting.
-      self._cached_batch_shape = functools.reduce(
-          reduce_fn, self._batch_shape_parts)
-    return self._cached_batch_shape
+    reduce_fn = ((lambda a, b: a.merge_with(b)) if self.validate_args
+                 else tf.broadcast_static_shape)  # Allows broadcasting.
+    return functools.reduce(reduce_fn, self._batch_shape_parts())

   def _batch_shape_tensor_parts(self):
     return [d.batch_shape_tensor()[:self.batch_ndims]
             for d in self._get_single_sample_distributions()]

-  def batch_shape_tensor(self, sample_shape=(), name='batch_shape_tensor'):
-    del sample_shape  # Unused.
-    with self._name_and_control_scope(name):
-      return tf.convert_to_tensor(functools.reduce(
-          prefer_static.broadcast_shape, self._batch_shape_tensor_parts()))
+  def _batch_shape_tensor(self):
+    return tf.convert_to_tensor(functools.reduce(
+        prefer_static.broadcast_shape, self._batch_shape_tensor_parts()))

-  @property
-  def event_shape(self):
+  def _event_shape(self):
     if not hasattr(self, '_cached_event_shape'):
       self._cached_event_shape = list([
           tf.nest.map_structure(  # Recurse over joint component distributions.
              d.batch_shape[self.batch_ndims:].concatenate,
              d.event_shape) for d in self._get_single_sample_distributions()])
     return self._model_unflatten(self._cached_event_shape)

-  def event_shape_tensor(self, sample_shape=(), name='event_shape_tensor'):
+  def _event_shape_tensor(self):
     """Shape of a single sample from a single batch."""
-    del sample_shape  # Unused.
-    with self._name_and_control_scope(name):
-      component_shapes = []
-      for d in self._get_single_sample_distributions():
-        iid_event_shape = d.batch_shape_tensor()[self.batch_ndims:]
-        # Recurse over the (potentially joint) component distribution's event.
-        component_shapes.append(tf.nest.map_structure(
-            lambda a, b=iid_event_shape: prefer_static.concat([b, a], axis=0),
-            d.event_shape_tensor()))
-      return self._model_unflatten(component_shapes)
+    component_shapes = []
+    for d in self._get_single_sample_distributions():
+      iid_event_shape = d.batch_shape_tensor()[self.batch_ndims:]
+      # Recurse over the (potentially joint) component distribution's event.
+      component_shapes.append(tf.nest.map_structure(
+          lambda a, b=iid_event_shape: prefer_static.concat([b, a], axis=0),
+          d.event_shape_tensor()))
+    return self._model_unflatten(component_shapes)

   def _reduce_measure_over_dists(self, xs, reduce_fn):
     num_trailing_batch_dims_treated_as_event = [
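
One detail worth noting in `_batch_shape` above: with `validate_args=True` the component batch shapes are combined with `merge_with`, which demands exact agreement, while the default reducer `tf.broadcast_static_shape` only requires broadcast compatibility. A small standalone illustration (the example shapes are invented):

import functools
import tensorflow as tf

parts = [tf.TensorShape([3, 1]), tf.TensorShape([1, 4])]
functools.reduce(tf.broadcast_static_shape, parts)  # TensorShape([3, 4])
# The strict reducer raises instead: [3, 1] and [1, 4] broadcast but do not
# merge dimension-wise.
# functools.reduce(lambda a, b: a.merge_with(b), parts)  # ValueError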
