Skip to content

Commit b35effd

Browse files
axchtensorflower-gardener
authored andcommitted
Suppress unrelated Cholesky decomposition error in Hypothesis test of batch shapes.
PiperOrigin-RevId: 386462611
1 parent e656867 commit b35effd

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

tensorflow_probability/python/distributions/distribution_properties_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,8 @@ def testCanConstructAndSampleDistribution(self, data):
372372
def testInferredBatchShapeMatchesTrueBatchShape(self, dist_name, data):
373373
dist = data.draw(
374374
dhps.distributions(dist_name=dist_name, validate_args=False))
375-
lp = dist.log_prob(dist.sample(seed=test_util.test_seed()))
375+
with tfp_hps.no_cholesky_decomposition_errors():
376+
lp = dist.log_prob(dist.sample(seed=test_util.test_seed()))
376377

377378
self.assertAllEqual(dist.batch_shape_tensor(), tf.shape(lp))
378379
self.assertAllEqual(dist.batch_shape, tf.shape(lp))

tensorflow_probability/python/internal/hypothesis_testlib.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -763,6 +763,24 @@ def no_tf_rank_errors():
763763
raise
764764

765765

766+
@contextlib.contextmanager
767+
def no_cholesky_decomposition_errors():
768+
# Use this to suppress Cholesky errors if needed in tests where the
769+
# numerics are beside the point.
770+
pat = ('Cholesky decomposition was not successful. '
771+
'The input might not be valid. [Op:Cholesky]')
772+
try:
773+
yield
774+
except tf.errors.InvalidArgumentError as e:
775+
msg = str(e)
776+
if pat in msg:
777+
# Tried an input regime where a Cholesky decomposition failed
778+
# and crashed the program.
779+
hp.assume(False)
780+
else:
781+
raise
782+
783+
766784
@contextlib.contextmanager
767785
def finite_ground_truth_only():
768786
# Recognizing the error message from python/internal/numerics_testing.py

0 commit comments

Comments
 (0)