Skip to content

Commit d85f921

Browse files
axchtensorflower-gardener
authored andcommitted
Suppress TensorFlow rank errors when batch shape testing.
PiperOrigin-RevId: 387639055
1 parent 4048d82 commit d85f921

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tensorflow_probability/python/distributions/distribution_properties_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,8 @@ def testInferredBatchShapeMatchesTrueBatchShape(self, dist_name, data):
373373
with tfp_hps.no_cholesky_decomposition_errors():
374374
dist = data.draw(
375375
dhps.distributions(dist_name=dist_name, validate_args=False))
376-
lp = dist.log_prob(dist.sample(seed=test_util.test_seed()))
376+
with tfp_hps.no_tf_rank_errors():
377+
lp = dist.log_prob(dist.sample(seed=test_util.test_seed()))
377378

378379
self.assertAllEqual(dist.batch_shape_tensor(), tf.shape(lp))
379380
self.assertAllEqual(dist.batch_shape, tf.shape(lp))

0 commit comments

Comments
 (0)