Skip to content

Commit b9fc082

Browse files
brianwa84tensorflower-gardener
authored andcommitted
Replace einsum in JDAB test for TF 2.5 compatibility.
PiperOrigin-RevId: 374774401
1 parent 26aeb77 commit b9fc082

File tree

2 files changed

+1
-5
lines changed

2 files changed

+1
-5
lines changed

tensorflow_probability/python/distributions/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3030,7 +3030,6 @@ multi_substrate_py_test(
30303030
# numpy dep,
30313031
# tensorflow dep,
30323032
"//tensorflow_probability",
3033-
"//tensorflow_probability/python/internal:prefer_static",
30343033
"//tensorflow_probability/python/internal:test_util",
30353034
# tensorflow/compiler/jit dep,
30363035
],

tensorflow_probability/python/distributions/joint_distribution_auto_batched_test.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
import tensorflow.compat.v2 as tf
3131
import tensorflow_probability as tfp
3232

33-
from tensorflow_probability.python.internal import prefer_static as ps
3433
from tensorflow_probability.python.internal import test_util
3534

3635
tfb = tfp.bijectors
@@ -494,9 +493,7 @@ def test_unit_sample_shape(self):
494493
@tfd.JointDistributionCoroutineAutoBatched
495494
def dist():
496495
x = yield tfd.Normal(loc=tf.zeros([3]), scale=1., name='x')
497-
if ps.rank(x) != 1:
498-
raise ValueError('Unexpected shape.')
499-
yield tfd.Bernoulli(logits=tf.reduce_sum(x), name='y')
496+
yield tfd.Bernoulli(logits=tf.einsum('n->', x), name='y')
500497

501498
for sample_shape in [(), 1, [1], [1, 1], [2]]:
502499
self.assertAllEqual(

0 commit comments

Comments
 (0)