Skip to content

Commit 58ed20e

Browse files
emilyfertigtensorflower-gardener
authored andcommitted
Fix dtype handling in tfpk.FeatureScaled.
PiperOrigin-RevId: 474410041
1 parent 397a53e commit 58ed20e

File tree

4 files changed

+9
-4
lines changed

4 files changed

+9
-4
lines changed

tensorflow_probability/python/internal/backend/numpy/numpy_math.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -964,7 +964,7 @@ def _softplus(x, name=None): # pylint: disable=unused-argument
964964

965965
sqrt = utils.copy_docstring(
966966
'tf.math.sqrt',
967-
lambda x, name=None: np.sqrt(x))
967+
lambda x, name=None: np.sqrt(_convert_to_tensor(x)))
968968

969969
square = utils.copy_docstring(
970970
'tf.math.square',

tensorflow_probability/python/math/psd_kernels/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ multi_substrate_py_test(
368368
# absl/testing:parameterized dep,
369369
# numpy dep,
370370
# tensorflow dep,
371+
"//tensorflow_probability/python/internal:dtype_util",
371372
"//tensorflow_probability/python/internal:test_util",
372373
],
373374
)

tensorflow_probability/python/math/psd_kernels/feature_scaled.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,12 @@ def __init__(
7575
raise ValueError(
7676
'Must specify exactly one of `scale_diag` and `inverse_scale_diag`.')
7777
with tf.name_scope(name):
78+
dtype = util.maybe_get_common_dtype(
79+
[kernel, scale_diag, inverse_scale_diag])
7880
self._scale_diag = tensor_util.convert_nonref_to_tensor(
79-
scale_diag, name='scale_diag')
81+
scale_diag, dtype=dtype, name='scale_diag')
8082
self._inverse_scale_diag = tensor_util.convert_nonref_to_tensor(
81-
inverse_scale_diag, name='inverse_scale_diag')
83+
inverse_scale_diag, dtype=dtype, name='inverse_scale_diag')
8284

8385
def rescale_input(x, feature_ndims, example_ndims):
8486
"""Computes `x / scale_diag`."""

tensorflow_probability/python/math/psd_kernels/feature_scaled_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import numpy as np
2020
import tensorflow.compat.v2 as tf
2121

22+
from tensorflow_probability.python.internal import dtype_util
2223
from tensorflow_probability.python.internal import test_util
2324
from tensorflow_probability.python.math.psd_kernels import exponentiated_quadratic
2425
from tensorflow_probability.python.math.psd_kernels import feature_scaled
@@ -50,7 +51,7 @@ def testBatchShape(self):
5051
# Use 3 feature_ndims.
5152
kernel = exponentiated_quadratic.ExponentiatedQuadratic(
5253
amplitude, inner_length_scale, feature_ndims=3)
53-
scale_diag = tf.ones([20, 1, 2, 1, 1, 1])
54+
scale_diag = tf.ones([20, 1, 2, 1, 1, 1], dtype=self.dtype)
5455
ard_kernel = feature_scaled.FeatureScaled(kernel, scale_diag=scale_diag)
5556
self.assertAllEqual([20, 10, 2], ard_kernel.batch_shape)
5657
self.assertAllEqual(
@@ -77,6 +78,7 @@ def testKernelParametersBroadcast(self, feature_ndims, dims):
7778
2, 5, size=([3, 1, 2] + input_shape)).astype(self.dtype)
7879

7980
ard_kernel = feature_scaled.FeatureScaled(kernel, scale_diag=length_scale)
81+
self.assertIs(dtype_util.as_numpy_dtype(ard_kernel.dtype), self.dtype)
8082

8183
x = np.random.uniform(-1, 1, size=input_shape).astype(self.dtype)
8284
y = np.random.uniform(-1, 1, size=input_shape).astype(self.dtype)

0 commit comments

Comments
 (0)