Skip to content

Commit 2aac7fd

Browse files
Revert "Fix some np.array calls"
This reverts commit e34b656.
1 parent c4c3425 commit 2aac7fd

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

tensorflow_probability/python/math/special_test.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,10 @@ def testBetaincGradient(self, dtype):
200200
space = np.logspace(-2., 2., 10).tolist()
201201
space_x = np.linspace(0.01, 0.99, 10).tolist()
202202
a, b, x = zip(*list(itertools.product(space, space, space_x)))
203-
a, b, x = [np.array(z, dtype=dtype) for z in [a, b, x]]
203+
204+
a = np.array(a, dtype=dtype)
205+
b = np.array(b, dtype=dtype)
206+
x = np.array(x, dtype=dtype)
204207

205208
# Wrap in tf.function and compile for faster computations.
206209
betainc = tf.function(tfp_math.betainc, autograph=False, jit_compile=True)
@@ -229,7 +232,10 @@ def testBetaincDerivativeFinite(self, dtype):
229232
space = np.logspace(np.log10(eps), 5.).tolist()
230233
space_x = np.linspace(eps, 1. - eps).tolist()
231234
a, b, x = zip(*list(itertools.product(space, space, space_x)))
232-
a, b, x = [np.array(z, dtype=dtype) for z in [a, b, x]]
235+
236+
a = np.array(a, dtype=dtype)
237+
b = np.array(b, dtype=dtype)
238+
x = np.array(x, dtype=dtype)
233239

234240
def betainc_partials(a, b, x):
235241
return tfp_math.value_and_gradient(tfp_math.betainc, [a, b, x])[1]
@@ -433,7 +439,10 @@ def testBetaincSecondDerivativeFinite(self, dtype):
433439
space = np.logspace(-2., 2., 5).tolist()
434440
space_x = np.linspace(0.01, 0.99, 5).tolist()
435441
a, b, x = zip(*list(itertools.product(space, space, space_x)))
436-
a, b, x = [np.array(z, dtype=dtype) for z in [a, b, x]]
442+
443+
a = np.array(a, dtype=dtype)
444+
b = np.array(b, dtype=dtype)
445+
x = np.array(x, dtype=dtype)
437446

438447
def betainc_partials(a, b, x):
439448
return tfp_math.value_and_gradient(tfp_math.betainc, [a, b, x])[1]

0 commit comments

Comments
 (0)