Skip to content

Commit 6fb306b

Browse files
srvasudetensorflower-gardener
authored andcommitted
Enable most PSDKernel tests in numpy.
- Avoid mutating numpy arrays. - Disable dynamic shape / dtype tests. PiperOrigin-RevId: 463755123
1 parent 796ecd6 commit 6fb306b

File tree

9 files changed

+24
-10
lines changed

9 files changed

+24
-10
lines changed

tensorflow_probability/python/math/psd_kernels/BUILD

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ multi_substrate_py_test(
7272
name = "positive_semidefinite_kernel_test",
7373
size = "small",
7474
srcs = ["positive_semidefinite_kernel_test.py"],
75-
numpy_tags = ["notap"],
7675
deps = [
7776
":positive_semidefinite_kernel",
7877
# absl/testing:parameterized dep,
@@ -126,7 +125,6 @@ multi_substrate_py_test(
126125
name = "exponentiated_quadratic_test",
127126
size = "small",
128127
srcs = ["exponentiated_quadratic_test.py"],
129-
numpy_tags = ["notap"],
130128
deps = [
131129
# absl/testing:parameterized dep,
132130
# numpy dep,
@@ -153,7 +151,6 @@ multi_substrate_py_test(
153151
name = "exp_sin_squared_test",
154152
size = "small",
155153
srcs = ["exp_sin_squared_test.py"],
156-
numpy_tags = ["notap"],
157154
deps = [
158155
# absl/testing:parameterized dep,
159156
# numpy dep,
@@ -181,7 +178,6 @@ multi_substrate_py_test(
181178
size = "medium",
182179
srcs = ["matern_test.py"],
183180
jax_size = "medium",
184-
numpy_tags = ["notap"],
185181
shard_count = 4,
186182
deps = [
187183
# absl/testing:parameterized dep,
@@ -233,7 +229,6 @@ multi_substrate_py_test(
233229
name = "rational_quadratic_test",
234230
size = "small",
235231
srcs = ["rational_quadratic_test.py"],
236-
numpy_tags = ["notap"],
237232
deps = [
238233
# absl/testing:parameterized dep,
239234
# numpy dep,
@@ -291,7 +286,6 @@ multi_substrate_py_test(
291286
size = "medium",
292287
srcs = ["schur_complement_test.py"],
293288
jax_size = "large",
294-
numpy_tags = ["notap"],
295289
deps = [
296290
# absl/testing:parameterized dep,
297291
# numpy dep,

tensorflow_probability/python/math/psd_kernels/exp_sin_squared.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,20 +98,20 @@ def _apply(self, x1, x2, example_ndims=0):
9898
# pad the shape with self.feature_ndims number of ones.
9999
period = util.pad_shape_with_ones(
100100
period, ndims=(example_ndims + self.feature_ndims))
101-
difference /= period
101+
difference = difference / period
102102
log_kernel = util.sum_rightmost_ndims_preserving_shape(
103103
-2 * tf.sin(difference) ** 2, ndims=self.feature_ndims)
104104

105105
if self.length_scale is not None:
106106
length_scale = tf.convert_to_tensor(self.length_scale)
107107
length_scale = util.pad_shape_with_ones(
108108
length_scale, ndims=example_ndims)
109-
log_kernel /= length_scale ** 2
109+
log_kernel = log_kernel / length_scale ** 2
110110

111111
if self.amplitude is not None:
112112
amplitude = tf.convert_to_tensor(self.amplitude)
113113
amplitude = util.pad_shape_with_ones(amplitude, ndims=example_ndims)
114-
log_kernel += 2. * tf.math.log(amplitude)
114+
log_kernel = log_kernel + 2. * tf.math.log(amplitude)
115115
return tf.exp(log_kernel)
116116

117117
@property

tensorflow_probability/python/math/psd_kernels/exp_sin_squared_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
@test_util.test_all_tf_execution_regimes
2828
class ExpSinSquaredTest(test_util.TestCase):
2929

30+
@test_util.disable_test_for_backend(
31+
disable_numpy=True, reason='DType mismatch not caught in numpy.')
3032
def testMismatchedFloatTypesAreBad(self):
3133
tfp.math.psd_kernels.ExpSinSquared(1, 1) # Should be OK (float32 fallback).
3234
tfp.math.psd_kernels.ExpSinSquared(1, np.float64(1)) # Should be OK.

tensorflow_probability/python/math/psd_kernels/exponentiated_quadratic_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
@test_util.test_all_tf_execution_regimes
2828
class ExponentiatedQuadraticTest(test_util.TestCase):
2929

30+
@test_util.disable_test_for_backend(
31+
disable_numpy=True, reason='DType mismatch not caught in numpy.')
3032
def testMismatchedFloatTypesAreBad(self):
3133
tfp.math.psd_kernels.ExponentiatedQuadratic(
3234
1, 1) # Should be OK (float32 fallback).

tensorflow_probability/python/math/psd_kernels/matern_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ class _MaternTestCase(test_util.TestCase):
3131
Subclasses must specify _kernel_type and _numpy_kernel.
3232
"""
3333

34+
@test_util.disable_test_for_backend(
35+
disable_numpy=True, reason='DType mismatch not caught in numpy.')
3436
def testMismatchedFloatTypesAreBad(self):
3537
self._kernel_type(1., 1) # Should be OK (float32 fallback).
3638
self._kernel_type(1., np.float64(1.)) # Should be OK.

tensorflow_probability/python/math/psd_kernels/positive_semidefinite_kernel_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,8 @@ def testStaticBatchShape(self, params, shape):
195195
@parameterized.named_parameters(
196196
('Dynamic-shape [2] kernel', [1., 2.], [2]),
197197
('Dynamic-shape [2, 1] kernel', [[1.], [2.]], [2, 1]))
198+
@test_util.disable_test_for_backend(
199+
disable_numpy=True, reason='No dynamic shapes.')
198200
def testDynamicBatchShape(self, params, shape):
199201
tensor_params = tf1.placeholder_with_default(params, shape=None)
200202
k = TestKernel(tensor_params)
@@ -231,6 +233,8 @@ def testApplyOutputWithStaticShapes(self):
231233
y # shape [3, 3]
232234
).shape)
233235

236+
@test_util.disable_test_for_backend(
237+
disable_numpy=True, reason='No dynamic shapes in numpy.')
234238
def testApplyOutputWithDynamicShapes(self):
235239
params_2_dynamic = tf1.placeholder_with_default([1., 2.], shape=None)
236240
k = TestKernel(params_2_dynamic)
@@ -437,6 +441,8 @@ def testDynamicShapesAndValuesOfProduct(self):
437441
for k in product_kernel[:, :1].kernels]),
438442
self.evaluate(product_kernel[:, :1].matrix(x, y)))
439443

444+
@test_util.disable_test_for_backend(
445+
disable_numpy=True, reason='DType mismatch not caught in numpy.')
440446
def testSumOfKernelsWithNoneDtypes(self):
441447
none_kernel = TestKernel()
442448
float32_kernel = TestKernel(np.float32(1))
@@ -451,6 +457,8 @@ def testSumOfKernelsWithNoneDtypes(self):
451457
with self.assertRaises(TypeError):
452458
_ = float32_kernel + float64_kernel
453459

460+
@test_util.disable_test_for_backend(
461+
disable_numpy=True, reason='DType mismatch not caught in numpy.')
454462
def testProductOfKernelsWithNoneDtypes(self):
455463
none_kernel = TestKernel()
456464
float32_kernel = TestKernel(np.float32(1))

tensorflow_probability/python/math/psd_kernels/rational_quadratic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def _apply_with_distance(
148148
scale_mixture_rate = tf.convert_to_tensor(self.scale_mixture_rate)
149149
power = util.pad_shape_with_ones(
150150
scale_mixture_rate, ndims=example_ndims)
151-
pairwise_square_distance /= power
151+
pairwise_square_distance = pairwise_square_distance / power
152152

153153
log_result = -power * tf.math.log1p(pairwise_square_distance)
154154

tensorflow_probability/python/math/psd_kernels/rational_quadratic_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ def _rational_quadratic(
3232
return (amplitude ** 2) * (1. + np.sum((x - y) ** 2) / (
3333
2 * scale_mixture_rate * length_scale ** 2)) ** (-scale_mixture_rate)
3434

35+
@test_util.disable_test_for_backend(
36+
disable_numpy=True, reason='DType mismatch not caught in numpy.')
3537
def testMismatchedFloatTypesAreBad(self):
3638
with self.assertRaises(TypeError):
3739
tfp.math.psd_kernels.RationalQuadratic(np.float32(1.), np.float64(1.))

tensorflow_probability/python/math/psd_kernels/schur_complement_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ def _broadcast_2(s1, s2):
5656
@test_util.test_all_tf_execution_regimes
5757
class SchurComplementTest(test_util.TestCase):
5858

59+
@test_util.disable_test_for_backend(
60+
disable_numpy=True, reason='DType mismatch not caught in numpy.')
5961
def testMismatchedFloatTypesAreBad(self):
6062
base_kernel = tfpk.ExponentiatedQuadratic(
6163
np.float64(5.), np.float64(.2))
@@ -255,6 +257,8 @@ def testNoneFixedInputs(self):
255257
self.evaluate(base_kernel.matrix(x, y)),
256258
self.evaluate(schur.matrix(x, y)))
257259

260+
@test_util.disable_test_for_backend(
261+
disable_numpy=True, reason='DType mismatch not caught in numpy.')
258262
def testBaseKernelNoneDtype(self):
259263
# Test that we don't have problems when base_kernel has no explicit dtype
260264
# (ie, params are all None), but fixed_inputs has a different dtype than the

0 commit comments

Comments
 (0)