29 | 29 |
30 | 30 | from tensorflow_probability.python.internal import assert_util
31 | 31 | from tensorflow_probability.python.internal import auto_composite_tensor
| 32 | +from tensorflow_probability.python.internal import batch_shape_lib
32 | 33 | from tensorflow_probability.python.internal import cache_util
33 | 34 | from tensorflow_probability.python.internal import dtype_util
34 | 35 | from tensorflow_probability.python.internal import name_util
35 | 36 | from tensorflow_probability.python.internal import nest_util
36 | 37 | from tensorflow_probability.python.internal import prefer_static as ps
| 38 | +from tensorflow_probability.python.internal import tensorshape_util
37 | 39 | from tensorflow_probability.python.math import gradient
38 | 40 | # pylint: disable=g-direct-tensorflow-import
39 | 41 | from tensorflow.python.util import deprecation
@@ -1069,6 +1071,151 @@ def inverse_event_shape(self, output_shape):
1069 | 1071 |         self.forward_min_event_ndims, tf.TensorShape,
1070 | 1072 |         self._inverse_event_shape(output_shape))
1071 | 1073 |
| 1074 | +  def _get_x_event_ndims(self, x_event_ndims=None, y_event_ndims=None):
| 1075 | +    if x_event_ndims is None:
| 1076 | +      if y_event_ndims is not None:
| 1077 | +        x_event_ndims = self.inverse_event_ndims(y_event_ndims)
| 1078 | +      else:  # Default to `min_event_ndims` if not explicitly specified.
| 1079 | +        return self.forward_min_event_ndims
| 1080 | +    elif y_event_ndims is not None:
| 1081 | +      raise ValueError(
| 1082 | +          'Only one of `x_event_ndims` and `y_event_ndims` may be specified.')
| 1083 | +    return x_event_ndims
| 1084 | +
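The resolution logic above is easiest to see on a concrete bijector. A minimal sketch, assuming the usual `tensorflow_probability` import and using `Exp` (an arbitrary scalar bijector, not part of this commit) purely for illustration:

```python
import tensorflow_probability as tfp

bij = tfp.bijectors.Exp()

# Neither argument given: fall back to the minimal event ndims.
bij._get_x_event_ndims()                 # == bij.forward_min_event_ndims == 0

# `y_event_ndims` given: convert it to the equivalent `x_event_ndims`.
bij._get_x_event_ndims(y_event_ndims=2)  # == bij.inverse_event_ndims(2) == 2

# Both given: ambiguous, so the helper raises.
# bij._get_x_event_ndims(x_event_ndims=1, y_event_ndims=1)  # ValueError
```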
| 1085 | +  def _batch_shape(self, x_event_ndims):
| 1086 | +    if not self._params_event_ndims():
| 1087 | +      # Skip requirement for a unique difference in event ndims if this bijector
| 1088 | +      # wouldn't have batch shape anyway.
| 1089 | +      return tensorshape_util.constant_value_as_shape([])
| 1090 | +
| 1091 | +    # Infer batch shape from annotations returned by `_parameter_properties()`.
| 1092 | +    # Batch shape inference assumes that the provided and minimum event ndims
| 1093 | +    # differ by the same amount in all parts. Bijectors with multiple
| 1094 | +    # independent parts will need to override this method, or inherit from a
| 1095 | +    # class (such as Composition) that does so.
| 1096 | +    return batch_shape_lib.inferred_batch_shape(
| 1097 | +        self,
| 1098 | +        additional_event_ndims=_unique_difference(x_event_ndims,
| 1099 | +                                                  self.forward_min_event_ndims))
| 1100 | +
| 1101 | +  def experimental_batch_shape(self, x_event_ndims=None, y_event_ndims=None):
| 1102 | +    """Returns the batch shape of this bijector for inputs of the given rank.
| 1103 | +
| 1104 | +    The batch shape of a bijector describes the set of distinct
| 1105 | +    transformations it represents on events of a given size. For example: the
| 1106 | +    bijector `tfb.Scale([1., 2.])` has batch shape `[2]` for scalar events
| 1107 | +    (`event_ndims = 0`), because applying it to a scalar event produces
| 1108 | +    two scalar outputs, the result of two different scaling transformations.
| 1109 | +    The same bijector has batch shape `[]` for vector events, because applying
| 1110 | +    it to a vector produces (via elementwise multiplication) a single vector
| 1111 | +    output.
| 1112 | +
| 1113 | +    Bijectors that operate independently on multiple state parts, such as
| 1114 | +    `tfb.JointMap`, must broadcast to a coherent batch shape. Some events may
| 1115 | +    not be valid: for example, the bijector
| 1116 | +    `tfb.JointMap([tfb.Scale([1., 2.]), tfb.Scale([1., 2., 3.])])` does not
| 1117 | +    produce a valid batch shape when `event_ndims = [0, 0]`, since the batch
| 1118 | +    shapes of the two parts are inconsistent. The same bijector
| 1119 | +    does define valid batch shapes of `[]`, `[2]`, and `[3]` if `event_ndims`
| 1120 | +    is `[1, 1]`, `[0, 1]`, or `[1, 0]`, respectively.
| 1121 | +
| 1122 | +    Since transforming a single event produces a scalar log-det-Jacobian, the
| 1123 | +    batch shape of a bijector with non-constant Jacobian is expected to equal
| 1124 | +    the shape of `forward_log_det_jacobian(x, event_ndims=x_event_ndims)`
| 1125 | +    or `inverse_log_det_jacobian(y, event_ndims=y_event_ndims)`, for `x`
| 1126 | +    or `y` of the specified `ndims`.
| 1127 | +
| 1128 | +    Args:
| 1129 | +      x_event_ndims: Optional Python `int` (structure) number of dimensions in
| 1130 | +        a probabilistic event passed to `forward`; this must be greater than
| 1131 | +        or equal to `self.forward_min_event_ndims`. If `None`, defaults to
| 1132 | +        `self.forward_min_event_ndims`. Mutually exclusive with `y_event_ndims`.
| 1133 | +        Default value: `None`.
| 1134 | +      y_event_ndims: Optional Python `int` (structure) number of dimensions in
| 1135 | +        a probabilistic event passed to `inverse`; this must be greater than
| 1136 | +        or equal to `self.inverse_min_event_ndims`. Mutually exclusive with
| 1137 | +        `x_event_ndims`.
| 1138 | +        Default value: `None`.
| 1139 | +    Returns:
| 1140 | +      batch_shape: `TensorShape` batch shape of this bijector for a
| 1141 | +        value with the given event rank. May be unknown or partially defined.
| 1142 | +    """
| 1143 | +    x_event_ndims = self._get_x_event_ndims(x_event_ndims, y_event_ndims)
| 1144 | +    # Cache batch shape to avoid the overhead of recomputing it.
| 1145 | +    if not hasattr(self, '_cached_batch_shapes'):
| 1146 | +      self._cached_batch_shapes = self._no_dependency({})
| 1147 | +    key = _deep_tuple(x_event_ndims)  # Avoid hashing lists/dicts.
| 1148 | +    if key not in self._cached_batch_shapes:
| 1149 | +      self._cached_batch_shapes[key] = self._batch_shape(x_event_ndims)
| 1150 | +    return self._cached_batch_shapes[key]
| 1151 | +
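To make the docstring concrete, a usage sketch (assuming the standard `tfb = tfp.bijectors` alias; the expected shapes are the ones the docstring promises):

```python
import tensorflow_probability as tfp
tfb = tfp.bijectors

scale = tfb.Scale([1., 2.])
scale.experimental_batch_shape(x_event_ndims=0)  # TensorShape([2])
scale.experimental_batch_shape(x_event_ndims=1)  # TensorShape([])

# Multipart bijectors trade per-part event ndims against batch shape.
# (`JointMap` picks up an overridden `_batch_shape` via `Composition`,
# as the comment in `_batch_shape` above anticipates.)
joint = tfb.JointMap([tfb.Scale([1., 2.]), tfb.Scale([1., 2., 3.])])
joint.experimental_batch_shape(x_event_ndims=[1, 1])  # TensorShape([])
joint.experimental_batch_shape(x_event_ndims=[0, 1])  # TensorShape([2])
joint.experimental_batch_shape(x_event_ndims=[1, 0])  # TensorShape([3])
```

Since results are cached under `_deep_tuple(x_event_ndims)`, repeated calls with the same (possibly nested) `event_ndims` structure do not recompute the shape.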
| 1152 | +  def _batch_shape_tensor(self, x_event_ndims):
| 1153 | +    if not self._params_event_ndims():
| 1154 | +      # Skip requirement for a unique difference in event ndims if this bijector
| 1155 | +      # wouldn't have batch shape anyway.
| 1156 | +      return []
| 1157 | +
| 1158 | +    # Infer batch shape from annotations returned by `_parameter_properties()`.
| 1159 | +    # Batch shape inference assumes that the provided and minimum event ndims
| 1160 | +    # differ by the same amount in all parts. Bijectors with multiple
| 1161 | +    # independent parts will need to override this method, or inherit from a
| 1162 | +    # class (such as Composition) that does so.
| 1163 | +    return batch_shape_lib.inferred_batch_shape_tensor(
| 1164 | +        self, additional_event_ndims=_unique_difference(
| 1165 | +            x_event_ndims, self.forward_min_event_ndims))
| 1166 | +
| 1167 | +  def experimental_batch_shape_tensor(self,
| 1168 | +                                      x_event_ndims=None,
| 1169 | +                                      y_event_ndims=None):
| 1170 | +    """Returns the batch shape of this bijector for inputs of the given rank.
| 1171 | +
| 1172 | +    The batch shape of a bijector describes the set of distinct
| 1173 | +    transformations it represents on events of a given size. For example: the
| 1174 | +    bijector `tfb.Scale([1., 2.])` has batch shape `[2]` for scalar events
| 1175 | +    (`event_ndims = 0`), because applying it to a scalar event produces
| 1176 | +    two scalar outputs, the result of two different scaling transformations.
| 1177 | +    The same bijector has batch shape `[]` for vector events, because applying
| 1178 | +    it to a vector produces (via elementwise multiplication) a single vector
| 1179 | +    output.
| 1180 | +
| 1181 | +    Bijectors that operate independently on multiple state parts, such as
| 1182 | +    `tfb.JointMap`, must broadcast to a coherent batch shape. Some events may
| 1183 | +    not be valid: for example, the bijector
| 1184 | +    `tfb.JointMap([tfb.Scale([1., 2.]), tfb.Scale([1., 2., 3.])])` does not
| 1185 | +    produce a valid batch shape when `event_ndims = [0, 0]`, since the batch
| 1186 | +    shapes of the two parts are inconsistent. The same bijector
| 1187 | +    does define valid batch shapes of `[]`, `[2]`, and `[3]` if `event_ndims`
| 1188 | +    is `[1, 1]`, `[0, 1]`, or `[1, 0]`, respectively.
| 1189 | +
| 1190 | +    Since transforming a single event produces a scalar log-det-Jacobian, the
| 1191 | +    batch shape of a bijector with non-constant Jacobian is expected to equal
| 1192 | +    the shape of `forward_log_det_jacobian(x, event_ndims=x_event_ndims)`
| 1193 | +    or `inverse_log_det_jacobian(y, event_ndims=y_event_ndims)`, for `x`
| 1194 | +    or `y` of the specified `ndims`.
| 1195 | +
| 1196 | +    Args:
| 1197 | +      x_event_ndims: Optional Python `int` (structure) number of dimensions in
| 1198 | +        a probabilistic event passed to `forward`; this must be greater than
| 1199 | +        or equal to `self.forward_min_event_ndims`. If `None`, defaults to
| 1200 | +        `self.forward_min_event_ndims`. Mutually exclusive with `y_event_ndims`.
| 1201 | +        Default value: `None`.
| 1202 | +      y_event_ndims: Optional Python `int` (structure) number of dimensions in
| 1203 | +        a probabilistic event passed to `inverse`; this must be greater than
| 1204 | +        or equal to `self.inverse_min_event_ndims`. Mutually exclusive with
| 1205 | +        `x_event_ndims`.
| 1206 | +        Default value: `None`.
| 1207 | +    Returns:
| 1208 | +      batch_shape_tensor: integer `Tensor` batch shape of this bijector for a
| 1209 | +        value with the given event rank.
| 1210 | +    """
| 1211 | +    with tf.name_scope('experimental_batch_shape_tensor'):
| 1212 | +      x_event_ndims = self._get_x_event_ndims(x_event_ndims, y_event_ndims)
| 1213 | +      # Try to get the static batch shape.
| 1214 | +      batch_shape = self.experimental_batch_shape(x_event_ndims=x_event_ndims)
| 1215 | +      if not tensorshape_util.is_fully_defined(batch_shape):
| 1216 | +        batch_shape = self._batch_shape_tensor(x_event_ndims)
| 1217 | +      return batch_shape
| 1218 | +
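The tensor variant tries the static computation first and only falls back to a runtime computation when the static shape is partial. A sketch of both paths; wrapping in `tf.function` with a `[None]` signature is just one way to make a parameter's shape unknown, an assumption of this example rather than anything in the commit:

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfb = tfp.bijectors

# Static path: the parameter shape is known, so the (cached) static
# batch shape is fully defined and returned directly.
tfb.Scale([1., 2.]).experimental_batch_shape_tensor(x_event_ndims=0)  # [2]

# Dynamic path: the parameter shape is unknown at trace time, so the
# static result is partial and `_batch_shape_tensor` runs instead.
@tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
def batch_shape_of_scale(scale):
  return tfb.Scale(scale).experimental_batch_shape_tensor(x_event_ndims=0)

batch_shape_of_scale(tf.ones([5]))  # -> [5]
```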
1072 | 1219 |   @classmethod
1073 | 1220 |   def _parameter_properties(cls, dtype):
1074 | 1221 |     raise NotImplementedError(
@@ -1108,6 +1255,7 @@ def _params_event_ndims(cls):
1108 | 1255 |     return {
1109 | 1256 |         param_name: param.event_ndims
1110 | 1257 |         for param_name, param in cls.parameter_properties().items()
| 1258 | +        if param.event_ndims is not None
1111 | 1259 |     }
1112 | 1260 |
1113 | 1261 |   def _forward(self, x):
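The added `if param.event_ndims is not None` guard drops parameters whose properties carry no per-event annotation, so they no longer contribute `None` entries to batch-shape arithmetic. A self-contained sketch of the comprehension's effect, with made-up property values standing in for `cls.parameter_properties()`:

```python
# Hypothetical annotations: `scale` is a batched tensor parameter,
# `event_shape_out` is structural and carries no event_ndims.
properties = {'scale': 0, 'event_shape_out': None}

params_event_ndims = {
    param_name: event_ndims
    for param_name, event_ndims in properties.items()
    if event_ndims is not None
}
assert params_event_ndims == {'scale': 0}
```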
@@ -1968,3 +2116,22 @@ def _autodiff_log_det_jacobian(fn, x):
1968 | 2116 |   raise ValueError('Cannot compute log det jacobian; function {} has `None` '
1969 | 2117 |                    'gradient.'.format(fn))
1970 | 2118 |   return tf.math.log(tf.abs(grads))
| 2119 | +
| 2120 | +
| 2121 | +def _unique_difference(structure1, structure2):
| 2122 | +  differences = [a - b
| 2123 | +                 for a, b in
| 2124 | +                 zip(tf.nest.flatten(structure1), tf.nest.flatten(structure2))]
| 2125 | +  if all(d == differences[0] for d in differences):
| 2126 | +    return differences[0]
| 2127 | +  raise ValueError('Could not find unique difference between {} and {}'
| 2128 | +                   .format(structure1, structure2))
| 2129 | +
| 2130 | +
| 2131 | +def _deep_tuple(x):
| 2132 | +  """Converts nested `tuple`, `list`, or `dict` to nested `tuple`."""
| 2133 | +  if hasattr(x, 'keys'):
| 2134 | +    return _deep_tuple(tuple(x.items()))
| 2135 | +  elif isinstance(x, (list, tuple)):
| 2136 | +    return tuple(map(_deep_tuple, x))
| 2137 | +  return x
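Finally, the two module-level helpers: `_deep_tuple` builds hashable cache keys out of arbitrarily nested `event_ndims` structures, and `_unique_difference` enforces the 'same difference in all parts' assumption noted in `_batch_shape`. Evaluated in the module's context (where `tf` is already imported):

```python
# A nested dict/list structure becomes a hashable nested tuple, usable
# as a key in `_cached_batch_shapes`:
_deep_tuple({'a': [1, 2], 'b': (3, [4])})
# -> (('a', (1, 2)), ('b', (3, (4,))))

# All parts differ by the same amount, so the difference is unique.
# (`tf.nest.flatten` orders dict entries by sorted key.)
_unique_difference({'a': 3, 'b': 4}, {'a': 1, 'b': 2})  # -> 2

# Inconsistent differences raise, signaling that default batch shape
# inference does not apply and an override is needed:
# _unique_difference([3, 4], [1, 1])
# ValueError: Could not find unique difference between [3, 4] and [1, 1]
```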