
Commit 5838e0e

davmre authored and tensorflower-gardener committed

Define experimental batch shape methods for Bijectors.
This opens the door to correctly calculating `log_det_jacobians` via autodiff (the current implementation is incorrect for bijectors with batch shape), automatically broadcasting the batch shapes of TransformedDistributions, batch slicing for TransformedDistributions, and returning Bijectors and TransformedDistributions from vectorized_map calls, among possibly other things.

Q: Why do bijector batch shapes depend on event_ndims?

A: One way to think about this is that batch shape is (up to broadcasting) the shape of the LDJ. Since LDJs are conditional on `event_ndims`, it's not surprising that batch shape would be too. We *could* try to define a batch shape inherent to each bijector, but JointMap (and other bijectors with `parts_are_independent==True`) makes this hard, because the different parts can have different batch shapes. In the absence of a concrete event_ndims to tell us how the parts align, the only batch shape we could define would be structured. It turns out that working with structured batch shapes is a pretty big pain (e.g., it puts some heavy demands on TransformedDistribution's parameter_properties.event_ndims annotations). After going down this road for a while, it seemed a lot simpler to live in a world where the batch shape is always just a single shape.

PiperOrigin-RevId: 374960009
1 parent 96b024d commit 5838e0e
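The point above, that batch shape is (up to broadcasting) the shape of the LDJ and therefore depends on `event_ndims`, can be seen directly from an existing bijector. The snippet below is an illustrative sketch, assuming a standard TensorFlow Probability installation; it uses only the long-standing `forward_log_det_jacobian` API, not the methods added in this commit.

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

scale = tfb.Scale([1., 2.])
x = tf.zeros([2])

# Treating each element as its own scalar event yields one log-det-Jacobian
# per scaling transformation, i.e. a batch of two.
print(scale.forward_log_det_jacobian(x, event_ndims=0).shape)  # (2,)

# Treating the whole vector as a single event reduces the LDJ to a scalar,
# i.e. an empty batch shape.
print(scale.forward_log_det_jacobian(x, event_ndims=1).shape)  # ()
```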

File tree

8 files changed: +370 −15 lines

tensorflow_probability/python/bijectors/BUILD

Lines changed: 4 additions & 0 deletions
@@ -128,7 +128,9 @@ multi_substrate_py_library(
         # tensorflow dep,
         "//tensorflow_probability/python/internal:assert_util",
         "//tensorflow_probability/python/internal:auto_composite_tensor",
+        "//tensorflow_probability/python/internal:batch_shape_lib",
         "//tensorflow_probability/python/internal:cache_util",
+        "//tensorflow_probability/python/internal:docstring_util",
         "//tensorflow_probability/python/internal:dtype_util",
         "//tensorflow_probability/python/internal:name_util",
         "//tensorflow_probability/python/internal:nest_util",
@@ -972,6 +974,7 @@ multi_substrate_py_library(
     srcs = ["split.py"],
     deps = [
         ":bijector",
+        ":invert",
         # numpy dep,
         # tensorflow dep,
         "//tensorflow_probability/python/internal:assert_util",
@@ -1094,6 +1097,7 @@ multi_substrate_py_test(
         # mock dep,
         # numpy dep,
         # tensorflow dep,
+        "//tensorflow_probability",
         "//tensorflow_probability/python/distributions",
         "//tensorflow_probability/python/internal:cache_util",
         "//tensorflow_probability/python/internal:parameter_properties",

tensorflow_probability/python/bijectors/bijector.py

Lines changed: 167 additions & 0 deletions
@@ -29,11 +29,13 @@

 from tensorflow_probability.python.internal import assert_util
 from tensorflow_probability.python.internal import auto_composite_tensor
+from tensorflow_probability.python.internal import batch_shape_lib
 from tensorflow_probability.python.internal import cache_util
 from tensorflow_probability.python.internal import dtype_util
 from tensorflow_probability.python.internal import name_util
 from tensorflow_probability.python.internal import nest_util
 from tensorflow_probability.python.internal import prefer_static as ps
+from tensorflow_probability.python.internal import tensorshape_util
 from tensorflow_probability.python.math import gradient
 # pylint: disable=g-direct-tensorflow-import
 from tensorflow.python.util import deprecation
@@ -1069,6 +1071,151 @@ def inverse_event_shape(self, output_shape):
         self.forward_min_event_ndims, tf.TensorShape,
         self._inverse_event_shape(output_shape))

+  def _get_x_event_ndims(self, x_event_ndims=None, y_event_ndims=None):
+    if x_event_ndims is None:
+      if y_event_ndims is not None:
+        x_event_ndims = self.inverse_event_ndims(y_event_ndims)
+      else:  # Default to `min_event_ndims` if not explicitly specified.
+        return self.forward_min_event_ndims
+    elif y_event_ndims is not None:
+      raise ValueError(
+          'Only one of `x_event_ndims` and `y_event_ndims` may be specified.')
+    return x_event_ndims
+
+  def _batch_shape(self, x_event_ndims):
+    if not self._params_event_ndims():
+      # Skip requirement for a unique difference in event ndims if this bijector
+      # wouldn't have batch shape anyway.
+      return tensorshape_util.constant_value_as_shape([])
+
+    # Infer batch shape from annotations returned by `_parameter_properties()`.
+    # Batch shape inference assumes that the provided and minimum event ndims
+    # differ by the same amount in all parts. Bijectors with multiple
+    # independent parts will need to override this method, or inherit from a
+    # class (such as Composition) that does so.
+    return batch_shape_lib.inferred_batch_shape(
+        self,
+        additional_event_ndims=_unique_difference(x_event_ndims,
+                                                  self.forward_min_event_ndims))
+
+  def experimental_batch_shape(self, x_event_ndims=None, y_event_ndims=None):
+    """Returns the batch shape of this bijector for inputs of the given rank.
+
+    The batch shape of a bijector describes the set of distinct
+    transformations it represents on events of a given size. For example: the
+    bijector `tfb.Scale([1., 2.])` has batch shape `[2]` for scalar events
+    (`event_ndims = 0`), because applying it to a scalar event produces
+    two scalar outputs, the result of two different scaling transformations.
+    The same bijector has batch shape `[]` for vector events, because applying
+    it to a vector produces (via elementwise multiplication) a single vector
+    output.
+
+    Bijectors that operate independently on multiple state parts, such as
+    `tfb.JointMap`, must broadcast to a coherent batch shape. Some events may
+    not be valid: for example, the bijector
+    `tfb.JointMap([tfb.Scale([1., 2.]), tfb.Scale([1., 2., 3.])])` does not
+    produce a valid batch shape when `event_ndims = [0, 0]`, since the batch
+    shapes of the two parts are inconsistent. The same bijector
+    does define valid batch shapes of `[]`, `[2]`, and `[3]` if `event_ndims`
+    is `[1, 1]`, `[0, 1]`, or `[1, 0]`, respectively.
+
+    Since transforming a single event produces a scalar log-det-Jacobian, the
+    batch shape of a bijector with non-constant Jacobian is expected to equal
+    the shape of `forward_log_det_jacobian(x, event_ndims=x_event_ndims)`
+    or `inverse_log_det_jacobian(y, event_ndims=y_event_ndims)`, for `x`
+    or `y` of the specified `ndims`.
+
+    Args:
+      x_event_ndims: Optional Python `int` (structure) number of dimensions in
+        a probabilistic event passed to `forward`; this must be greater than
+        or equal to `self.forward_min_event_ndims`. If `None`, defaults to
+        `self.forward_min_event_ndims`. Mutually exclusive with `y_event_ndims`.
+        Default value: `None`.
+      y_event_ndims: Optional Python `int` (structure) number of dimensions in
+        a probabilistic event passed to `inverse`; this must be greater than
+        or equal to `self.inverse_min_event_ndims`. Mutually exclusive with
+        `x_event_ndims`.
+        Default value: `None`.
+    Returns:
+      batch_shape: `TensorShape` batch shape of this bijector for a
+        value with the given event rank. May be unknown or partially defined.
+    """
+    x_event_ndims = self._get_x_event_ndims(x_event_ndims, y_event_ndims)
+    # Cache batch shape to avoid the overhead of recomputing it.
+    if not hasattr(self, '_cached_batch_shapes'):
+      self._cached_batch_shapes = self._no_dependency({})
+    key = _deep_tuple(x_event_ndims)  # Avoid hashing lists/dicts.
+    if key not in self._cached_batch_shapes:
+      self._cached_batch_shapes[key] = self._batch_shape(x_event_ndims)
+    return self._cached_batch_shapes[key]
+
+  def _batch_shape_tensor(self, x_event_ndims):
+    if not self._params_event_ndims():
+      # Skip requirement for a unique difference in event ndims if this bijector
+      # wouldn't have batch shape anyway.
+      return []
+
+    # Infer batch shape from annotations returned by `_parameter_properties()`.
+    # Batch shape inference assumes that the provided and minimum event ndims
+    # differ by the same amount in all parts. Bijectors with multiple
+    # independent parts will need to override this method, or inherit from a
+    # class (such as Composition) that does so.
+    return batch_shape_lib.inferred_batch_shape_tensor(
+        self, additional_event_ndims=_unique_difference(
+            x_event_ndims, self.forward_min_event_ndims))
+
+  def experimental_batch_shape_tensor(self,
+                                      x_event_ndims=None,
+                                      y_event_ndims=None):
+    """Returns the batch shape of this bijector for inputs of the given rank.
+
+    The batch shape of a bijector describes the set of distinct
+    transformations it represents on events of a given size. For example: the
+    bijector `tfb.Scale([1., 2.])` has batch shape `[2]` for scalar events
+    (`event_ndims = 0`), because applying it to a scalar event produces
+    two scalar outputs, the result of two different scaling transformations.
+    The same bijector has batch shape `[]` for vector events, because applying
+    it to a vector produces (via elementwise multiplication) a single vector
+    output.
+
+    Bijectors that operate independently on multiple state parts, such as
+    `tfb.JointMap`, must broadcast to a coherent batch shape. Some events may
+    not be valid: for example, the bijector
+    `tfb.JointMap([tfb.Scale([1., 2.]), tfb.Scale([1., 2., 3.])])` does not
+    produce a valid batch shape when `event_ndims = [0, 0]`, since the batch
+    shapes of the two parts are inconsistent. The same bijector
+    does define valid batch shapes of `[]`, `[2]`, and `[3]` if `event_ndims`
+    is `[1, 1]`, `[0, 1]`, or `[1, 0]`, respectively.
+
+    Since transforming a single event produces a scalar log-det-Jacobian, the
+    batch shape of a bijector with non-constant Jacobian is expected to equal
+    the shape of `forward_log_det_jacobian(x, event_ndims=x_event_ndims)`
+    or `inverse_log_det_jacobian(y, event_ndims=y_event_ndims)`, for `x`
+    or `y` of the specified `ndims`.
+
+    Args:
+      x_event_ndims: Optional Python `int` (structure) number of dimensions in
+        a probabilistic event passed to `forward`; this must be greater than
+        or equal to `self.forward_min_event_ndims`. If `None`, defaults to
+        `self.forward_min_event_ndims`. Mutually exclusive with `y_event_ndims`.
+        Default value: `None`.
+      y_event_ndims: Optional Python `int` (structure) number of dimensions in
+        a probabilistic event passed to `inverse`; this must be greater than
+        or equal to `self.inverse_min_event_ndims`. Mutually exclusive with
+        `x_event_ndims`.
+        Default value: `None`.
+    Returns:
+      batch_shape_tensor: integer `Tensor` batch shape of this bijector for a
+        value with the given event rank.
+    """
+    with tf.name_scope('experimental_batch_shape_tensor'):
+      x_event_ndims = self._get_x_event_ndims(x_event_ndims, y_event_ndims)
+      # Try to get the static batch shape.
+      batch_shape = self.experimental_batch_shape(x_event_ndims=x_event_ndims)
+      if not tensorshape_util.is_fully_defined(batch_shape):
+        batch_shape = self._batch_shape_tensor(x_event_ndims)
+      return batch_shape
+
   @classmethod
   def _parameter_properties(cls, dtype):
     raise NotImplementedError(
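To make the docstring above concrete, here is an illustrative usage sketch of the two new methods. It assumes a TFP build that includes this commit (the methods are experimental and may change), and the `JointMap` results rely on the `Composition` override mentioned in the comments above.

```python
import tensorflow_probability as tfp

tfb = tfp.bijectors

scale = tfb.Scale([1., 2.])
scale.experimental_batch_shape(x_event_ndims=0)  # TensorShape([2]): two scalar transformations.
scale.experimental_batch_shape(x_event_ndims=1)  # TensorShape([]): one vector transformation.

# Multi-part bijectors broadcast their parts' batch shapes under the given
# event_ndims, per the docstring above.
joint = tfb.JointMap([tfb.Scale([1., 2.]), tfb.Scale([1., 2., 3.])])
joint.experimental_batch_shape(x_event_ndims=[1, 0])         # TensorShape([3])
joint.experimental_batch_shape_tensor(x_event_ndims=[0, 1])  # integer Tensor [2]
```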
@@ -1108,6 +1255,7 @@ def _params_event_ndims(cls):
     return {
         param_name: param.event_ndims
         for param_name, param in cls.parameter_properties().items()
+        if param.event_ndims is not None
     }

   def _forward(self, x):
@@ -1968,3 +2116,22 @@ def _autodiff_log_det_jacobian(fn, x):
     raise ValueError('Cannot compute log det jacobian; function {} has `None` '
                      'gradient.'.format(fn))
   return tf.math.log(tf.abs(grads))
+
+
+def _unique_difference(structure1, structure2):
+  differences = [a - b
+                 for a, b in
+                 zip(tf.nest.flatten(structure1), tf.nest.flatten(structure2))]
+  if all([d == differences[0] for d in differences]):
+    return differences[0]
+  raise ValueError('Could not find unique difference between {} and {}'
+                   .format(structure1, structure2))
+
+
+def _deep_tuple(x):
+  """Converts nested `tuple`, `list`, or `dict` to nested `tuple`."""
+  if hasattr(x, 'keys'):
+    return _deep_tuple(tuple(x.items()))
+  elif isinstance(x, (list, tuple)):
+    return tuple(map(_deep_tuple, x))
+  return x
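The two private helpers added at the end of bijector.py are small but load-bearing: `_unique_difference` collapses structured `event_ndims` into a single "additional event ndims" count, and `_deep_tuple` makes structured `event_ndims` hashable for the batch-shape cache. Below is a standalone sketch of their behavior, using plain Python lists in place of `tf.nest` structures so it runs without TensorFlow; the names mirror the diff but are not the library's own functions.

```python
def deep_tuple(x):
  """Converts nested `tuple`, `list`, or `dict` to a nested, hashable `tuple`."""
  if hasattr(x, 'keys'):
    return deep_tuple(tuple(x.items()))
  elif isinstance(x, (list, tuple)):
    return tuple(map(deep_tuple, x))
  return x


def unique_difference(flat1, flat2):
  """Returns the common elementwise difference, raising if the parts disagree."""
  differences = [a - b for a, b in zip(flat1, flat2)]
  if all(d == differences[0] for d in differences):
    return differences[0]
  raise ValueError('Could not find unique difference.')


# Structured event_ndims become hashable cache keys:
print(deep_tuple({'loc': [0, 1], 'scale': 2}))  # (('loc', (0, 1)), ('scale', 2))

# event_ndims [2, 1] vs. min_event_ndims [1, 0]: every part sits exactly one
# dimension above its minimum, so one trailing dim of each parameter is
# consumed by the event and the rest is batch shape.
print(unique_difference([2, 1], [1, 0]))  # 1
```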

tensorflow_probability/python/bijectors/bijector_properties_test.py

Lines changed: 22 additions & 6 deletions
@@ -30,13 +30,14 @@
 from tensorflow_probability.python.bijectors import hypothesis_testlib as bijector_hps
 from tensorflow_probability.python.bijectors import invert as invert_lib
 from tensorflow_probability.python.internal import hypothesis_testlib as tfp_hps
-from tensorflow_probability.python.internal import prefer_static
+from tensorflow_probability.python.internal import prefer_static as ps
 from tensorflow_probability.python.internal import samplers
 from tensorflow_probability.python.internal import tensor_util
 from tensorflow_probability.python.internal import tensorshape_util
 from tensorflow_probability.python.internal import test_util
 from tensorflow_probability.python.util.deferred_tensor import DeferredTensor

+from tensorflow.python.util import nest  # pylint: disable=g-direct-tensorflow-import

 TF2_FRIENDLY_BIJECTORS = (
     'Ascending',
@@ -701,6 +702,21 @@ def exception(bijector):
       self.assertTrue(bijector._is_constant_jacobian)
       self.assertAllEqual(ldj, 0.)

+    # Verify correctness of batch shape.
+    xs_batch_shapes = tf.nest.map_structure(
+        lambda x, nd: ps.shape(x)[:ps.rank(x) - nd],
+        xs,
+        bijector.inverse_event_ndims(event_ndims))
+    empirical_batch_shape = functools.reduce(
+        ps.broadcast_shape,
+        nest.flatten_up_to(bijector.forward_min_event_ndims, xs_batch_shapes))
+    batch_shape = bijector.experimental_batch_shape(y_event_ndims=event_ndims)
+    if tensorshape_util.is_fully_defined(batch_shape):
+      self.assertAllEqual(empirical_batch_shape, batch_shape)
+    self.assertAllEqual(empirical_batch_shape,
+                        bijector.experimental_batch_shape_tensor(
+                            y_event_ndims=event_ndims))
+
     # Check that the outputs of forward_dtype and inverse_dtype match the dtypes
     # of the outputs of forward and inverse.
     self.assertAllEqualNested(ys.dtype, bijector.forward_dtype(xs.dtype))
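The property test above strips the event dimensions off each input part, broadcasts the remaining per-part shapes, and checks the result against the new methods. A concrete, non-hypothesis version of the same check, sketched under the assumption of a TFP build containing this commit:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

bijector = tfb.Scale([[1.], [2.], [3.]])  # The scale parameter has shape [3, 1].
y_event_ndims = 1
ys = tf.zeros([3, 4])                     # A batch of three length-4 vector events.

# Strip the event dims from the forward-side inputs to get the empirical batch shape.
xs = bijector.inverse(ys)
x_event_ndims = bijector.inverse_event_ndims(y_event_ndims)
empirical_batch_shape = xs.shape[:len(xs.shape) - x_event_ndims]

print(empirical_batch_shape)                                                   # (3,)
print(bijector.experimental_batch_shape(y_event_ndims=y_event_ndims))         # (3,)
print(bijector.experimental_batch_shape_tensor(y_event_ndims=y_event_ndims))  # integer Tensor [3]
```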
@@ -737,8 +753,8 @@ def testParameterProperties(self, bijector_name, data):
     # Extract the full shape of an output from this bijector.
     xs = self._draw_domain_tensor(bijector, data, event_dim)
     ys = bijector.forward(xs)
-    output_shape = prefer_static.shape(ys)
-    sample_and_batch_ndims = (prefer_static.rank_from_shape(output_shape) -
+    output_shape = ps.shape(ys)
+    sample_and_batch_ndims = (ps.rank_from_shape(output_shape) -
                               bijector.inverse_min_event_ndims)

     try:
@@ -761,7 +777,7 @@ def testParameterProperties(self, bijector_name, data):
           'parameter {}.'.format(bijector_name, param_name))
       self.assertGreaterEqual(
           param.event_ndims,
-          prefer_static.rank_from_shape(param_shape) - sample_and_batch_ndims)
+          ps.rank_from_shape(param_shape) - sample_and_batch_ndims)

       if param.is_preferred:
         try:
@@ -827,7 +843,7 @@ def testAutoVectorization(self, bijector_name, data):
     event_ndims = data.draw(
         hps.integers(
             min_value=bijector.forward_min_event_ndims,
-            max_value=prefer_static.rank_from_shape(xs.shape) - 1))
+            max_value=ps.rank_from_shape(xs.shape) - 1))
     fldj_fn = functools.partial(bijector.forward_log_det_jacobian,
                                 event_ndims=event_ndims)
     vectorized_fldj = tf.vectorized_map(fldj_fn, xs,
@@ -848,7 +864,7 @@ def testAutoVectorization(self, bijector_name, data):
     event_ndims = data.draw(
         hps.integers(
             min_value=bijector.inverse_min_event_ndims,
-            max_value=prefer_static.rank_from_shape(ys.shape) - 1))
+            max_value=ps.rank_from_shape(ys.shape) - 1))
     ildj_fn = functools.partial(bijector.inverse_log_det_jacobian,
                                 event_ndims=event_ndims)
     vectorized_ildj = tf.vectorized_map(ildj_fn, ys,
