Skip to content

Commit 5b61370

Browse files
davmretensorflower-gardener
authored andcommitted
Replace einsum in JDAB test for TF 2.5 compatibility.
PiperOrigin-RevId: 374718460
1 parent e7733ce commit 5b61370

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

tensorflow_probability/python/distributions/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3030,6 +3030,7 @@ multi_substrate_py_test(
30303030
# numpy dep,
30313031
# tensorflow dep,
30323032
"//tensorflow_probability",
3033+
"//tensorflow_probability/python/internal:prefer_static",
30333034
"//tensorflow_probability/python/internal:test_util",
30343035
# tensorflow/compiler/jit dep,
30353036
],

tensorflow_probability/python/distributions/joint_distribution_auto_batched_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import tensorflow.compat.v2 as tf
3131
import tensorflow_probability as tfp
3232

33+
from tensorflow_probability.python.internal import prefer_static as ps
3334
from tensorflow_probability.python.internal import test_util
3435

3536
tfb = tfp.bijectors
@@ -493,7 +494,9 @@ def test_unit_sample_shape(self):
493494
@tfd.JointDistributionCoroutineAutoBatched
494495
def dist():
495496
x = yield tfd.Normal(loc=tf.zeros([3]), scale=1., name='x')
496-
yield tfd.Bernoulli(logits=tf.einsum('n->', x), name='y')
497+
if ps.rank(x) != 1:
498+
raise ValueError('Unexpected shape.')
499+
yield tfd.Bernoulli(logits=tf.reduce_sum(x), name='y')
497500

498501
for sample_shape in [(), 1, [1], [1, 1], [2]]:
499502
self.assertAllEqual(

0 commit comments

Comments
 (0)