Skip to content

Commit 09f92c2

Browse files
axchtensorflower-gardener
authored andcommitted
Suppress high-rank broadcast errors in psd_kernel_properties_test.testCompositeTensor.
PiperOrigin-RevId: 386863877
1 parent f320a36 commit 09f92c2

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

tensorflow_probability/python/math/psd_kernels/psd_kernel_properties_test.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,8 @@ def testCompositeTensor(self, kernel_name, data):
140140
example_ndims=1,
141141
feature_dim=2,
142142
feature_ndims=1)))
143-
diag = kernel.apply(xs, xs, example_ndims=1)
143+
with tfp_hps.no_tf_rank_errors():
144+
diag = kernel.apply(xs, xs, example_ndims=1)
144145

145146
# Test flatten/unflatten.
146147
flat = tf.nest.flatten(kernel, expand_composites=True)
@@ -152,8 +153,9 @@ def diag_fn(k):
152153
return k.apply(xs, xs, example_ndims=1)
153154

154155
self.evaluate([v.initializer for v in kernel.variables])
155-
self.assertAllClose(diag, diag_fn(kernel))
156-
self.assertAllClose(diag, diag_fn(unflat))
156+
with tfp_hps.no_tf_rank_errors():
157+
self.assertAllClose(diag, diag_fn(kernel))
158+
self.assertAllClose(diag, diag_fn(unflat))
157159

158160

159161
CONSTRAINTS = {

0 commit comments

Comments
 (0)