Skip to content

Commit 5f8765c

Browse files
axchtensorflower-gardener
authored andcommitted
Enforce Distribution statistics' batch shape contract by broadcasting.
The contract is that a Distribution object is supposed to look like all of its parameters' batch shapes are broadcasted together to form the batch shape of the Distribution. This mostly happens automatically, but some statistics don't actually depend on all of the respective Distribution's parameters. (The most common case is the variance of a location-scale family doesn't depend on the location.) This CL inserts tf.broadcast_to at the outputs of appropriate statistic methods to enforce the contract, and updates a Hypothesis test to check for shape-correctness of instantiable Distributions mechanically. The Distribution methods affected are - GeneralizedPareto.entropy - PERT.mode - Triangular.entropy - LambertWNormal.{variance,stddev} - ExponentiallyModifiedGaussian.{mean,variance,stddev} PiperOrigin-RevId: 383723788
1 parent 9d88c41 commit 5f8765c

File tree

6 files changed

+19
-18
lines changed

6 files changed

+19
-18
lines changed

tensorflow_probability/python/distributions/distribution_properties_test.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,6 @@
5050

5151

5252
STATISTIC_CONSISTENT_SHAPES_TEST_BLOCK_LIST = (
53-
'BatchReshape', # b/183405889
54-
'Independent', # b/183405889
55-
'Mixture', # b/183405889
56-
'Sample', # b/183405889
57-
'TransformedDistribution', # b/183405889
5853
)
5954

6055

@@ -139,13 +134,12 @@ def check_statistic(
139134
try:
140135
with tfp_hps.no_tf_rank_errors():
141136
result = getattr(dist, statistic)()
142-
msg = 'Shape {} not compatible with expected {}.'.format(
143-
result.shape, expected_static_shape)
144-
self.assertTrue(expected_static_shape.is_compatible_with(
145-
tf.broadcast_static_shape(result.shape, expected_static_shape)), msg)
137+
msg = 'Shape {} of {} not compatible with expected {}.'.format(
138+
result.shape, statistic, expected_static_shape)
139+
self.assertTrue(
140+
expected_static_shape.is_compatible_with(result.shape), msg)
146141
self.assertAllEqual(self.evaluate(expected_dynamic_shape),
147-
self.evaluate(tf.broadcast_dynamic_shape(
148-
tf.shape(result), expected_dynamic_shape)))
142+
self.evaluate(tf.shape(result)))
149143
except NotImplementedError:
150144
pass
151145

tensorflow_probability/python/distributions/exponentially_modified_gaussian.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,10 +198,13 @@ def _log_cdf(self, x):
198198
-u + vsquared / 2. + special_math.log_ndtr((u - vsquared) / v))
199199

200200
def _mean(self):
201-
return self.loc + 1 / self.rate
201+
return tf.broadcast_to(
202+
self.loc + 1 / self.rate, self._batch_shape_tensor())
202203

203204
def _variance(self):
204-
return tf.square(self.scale) + 1 / tf.square(self.rate)
205+
return tf.broadcast_to(
206+
tf.square(self.scale) + 1 / tf.square(self.rate),
207+
self._batch_shape_tensor())
205208

206209
def _parameter_control_dependencies(self, is_init):
207210
assertions = []

tensorflow_probability/python/distributions/generalized_pareto.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,8 @@ def result():
290290
return result()
291291

292292
def _entropy(self):
293-
return tf.math.log(self.scale) + self.concentration + 1
293+
ans = tf.math.log(self.scale) + self.concentration + 1
294+
return tf.broadcast_to(ans, self._batch_shape_tensor())
294295

295296
# TODO(b/145620027): Finalize choice of bijector.
296297
def _default_event_space_bijector(self):

tensorflow_probability/python/distributions/lambertw_f.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,17 +269,18 @@ def _variance(self):
269269
tf.convert_to_tensor(np.inf, dtype=self.dtype))
270270

271271
if self.allow_nan_stats:
272-
return tf.where(
272+
ans = tf.where(
273273
tailweight < 1.0,
274274
result_where_defined,
275275
tf.convert_to_tensor(np.nan, self.dtype))
276276
else:
277-
return distribution_util.with_dependencies([
277+
ans = distribution_util.with_dependencies([
278278
assert_util.assert_greater_equal(
279279
tf.ones([], dtype=self.dtype),
280280
tailweight,
281281
message="variance not defined for components of tailweight >= 1"),
282282
], result_where_defined)
283+
return tf.broadcast_to(ans, self._batch_shape_tensor())
283284

284285
def _mode(self):
285286
# Mode always exists (for any tail parameter) and equals the location / mean

tensorflow_probability/python/distributions/pert.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,8 @@ def _mean(self):
213213
return self._transformed_beta().mean()
214214

215215
def _mode(self):
216-
return tf.convert_to_tensor(self.peak)
216+
return tf.broadcast_to(tf.convert_to_tensor(self.peak),
217+
self._batch_shape_tensor())
217218

218219
def _variance(self):
219220
low = tf.convert_to_tensor(self.low)

tensorflow_probability/python/distributions/triangular.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,8 @@ def _cdf(self, x):
251251
return tf.where(x >= high, tf.ones_like(x), result_if_not_big)
252252

253253
def _entropy(self):
254-
return 0.5 - np.log(2.) + tf.math.log(self.high - self.low)
254+
ans = 0.5 - np.log(2.) + tf.math.log(self.high - self.low)
255+
return tf.broadcast_to(ans, self._batch_shape_tensor())
255256

256257
def _mean(self):
257258
return (self.low + self.high + self.peak) / 3.

0 commit comments

Comments
 (0)