Skip to content

Commit f58370f

Browse files
emilyfertigtensorflower-gardener
authored andcommitted
Convert PSDKernels to AutoCompositeTensor.
PiperOrigin-RevId: 375517074
1 parent 1f2b8e0 commit f58370f

15 files changed

+166
-31
lines changed

tensorflow_probability/python/distributions/variational_gaussian_process.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,16 @@
3535
from tensorflow_probability.python.internal import parameter_properties
3636
from tensorflow_probability.python.internal import tensor_util
3737
from tensorflow_probability.python.internal import tensorshape_util
38-
from tensorflow_probability.python.math import psd_kernels as tfpk
38+
from tensorflow_probability.python.math.psd_kernels import positive_semidefinite_kernel as psd_kernel
3939
from tensorflow_probability.python.math.psd_kernels.internal import util as kernel_util
4040

4141
__all__ = [
4242
'VariationalGaussianProcess',
4343
]
4444

4545

46-
class _VariationalKernel(tfpk.PositiveSemidefiniteKernel):
46+
@psd_kernel.auto_composite_tensor_psd_kernel
47+
class _VariationalKernel(psd_kernel.AutoCompositeTensorPsdKernel):
4748
"""A PSDKernel which computes the variational kernel from [Titsias, 2009].
4849
4950
The VariationalGaussianProcess can be cast as a special case of
@@ -129,12 +130,12 @@ def __init__(self,
129130
inducing_index_points, dtype=dtype, name='inducing_index_points')
130131
self._variational_scale = tensor_util.convert_nonref_to_tensor(
131132
variational_scale, dtype=dtype, name='variational_scale')
132-
jitter = tensor_util.convert_nonref_to_tensor(
133+
self._jitter = tensor_util.convert_nonref_to_tensor(
133134
jitter, dtype=dtype, name='jitter')
134135

135136
def _compute_chol_kzz(z):
136137
kzz = base_kernel.matrix(z, z)
137-
result = tf.linalg.cholesky(_add_diagonal_shift(kzz, jitter))
138+
result = tf.linalg.cholesky(_add_diagonal_shift(kzz, self._jitter))
138139
return result
139140

140141
# Somewhat confusingly, but for the sake of brevity, we use `var` to refer

tensorflow_probability/python/math/psd_kernels/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ multi_substrate_py_library(
5555
deps = [
5656
# six dep,
5757
# tensorflow dep,
58+
"//tensorflow_probability/python/internal:auto_composite_tensor",
5859
"//tensorflow_probability/python/internal:dtype_util",
5960
"//tensorflow_probability/python/internal:tensorshape_util",
6061
],
@@ -66,9 +67,11 @@ multi_substrate_py_test(
6667
srcs = ["positive_semidefinite_kernel_test.py"],
6768
numpy_tags = ["notap"],
6869
deps = [
70+
":positive_semidefinite_kernel",
6971
# absl/testing:parameterized dep,
7072
# tensorflow dep,
7173
"//tensorflow_probability",
74+
"//tensorflow_probability/python/internal:auto_composite_tensor",
7275
"//tensorflow_probability/python/internal:test_util",
7376
"//tensorflow_probability/python/math/psd_kernels/internal:util",
7477
],

tensorflow_probability/python/math/psd_kernels/exp_sin_squared.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@
2323

2424
from tensorflow_probability.python.internal import assert_util
2525
from tensorflow_probability.python.internal import tensor_util
26+
from tensorflow_probability.python.math.psd_kernels import positive_semidefinite_kernel as psd_kernel
2627
from tensorflow_probability.python.math.psd_kernels.internal import util
27-
from tensorflow_probability.python.math.psd_kernels.positive_semidefinite_kernel import PositiveSemidefiniteKernel
28-
2928
__all__ = ['ExpSinSquared']
3029

3130

32-
class ExpSinSquared(PositiveSemidefiniteKernel):
31+
@psd_kernel.auto_composite_tensor_psd_kernel
32+
class ExpSinSquared(psd_kernel.AutoCompositeTensorPsdKernel):
3333
"""Exponentiated Sine Squared Kernel.
3434
3535
Also known as the "Periodic" kernel, this kernel, when

tensorflow_probability/python/math/psd_kernels/exponentiated_quadratic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,15 @@
2222

2323
from tensorflow_probability.python.internal import assert_util
2424
from tensorflow_probability.python.internal import tensor_util
25+
from tensorflow_probability.python.math.psd_kernels import positive_semidefinite_kernel as psd_kernel
2526
from tensorflow_probability.python.math.psd_kernels.internal import util
26-
from tensorflow_probability.python.math.psd_kernels.positive_semidefinite_kernel import PositiveSemidefiniteKernel
2727

2828

2929
__all__ = ['ExponentiatedQuadratic']
3030

3131

32-
class ExponentiatedQuadratic(PositiveSemidefiniteKernel):
32+
@psd_kernel.auto_composite_tensor_psd_kernel
33+
class ExponentiatedQuadratic(psd_kernel.AutoCompositeTensorPsdKernel):
3334
"""The ExponentiatedQuadratic kernel.
3435
3536
Sometimes called the "squared exponential", "Gaussian" or "radial basis

tensorflow_probability/python/math/psd_kernels/feature_scaled.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,17 @@
2121
import tensorflow.compat.v2 as tf
2222
from tensorflow_probability.python.internal import assert_util
2323
from tensorflow_probability.python.internal import tensor_util
24-
from tensorflow_probability.python.math.psd_kernels.feature_transformed import FeatureTransformed
24+
from tensorflow_probability.python.math.psd_kernels import feature_transformed
25+
from tensorflow_probability.python.math.psd_kernels import positive_semidefinite_kernel as psd_kernel
2526
from tensorflow_probability.python.math.psd_kernels.internal import util
2627

2728
__all__ = ['FeatureScaled']
2829

2930

3031
# TODO(b/132103412): Support more general scaling via LinearOperator, along with
3132
# scaling all feature dimensions.
32-
class FeatureScaled(FeatureTransformed):
33+
@psd_kernel.auto_composite_tensor_psd_kernel
34+
class FeatureScaled(feature_transformed.FeatureTransformed):
3335
"""Kernel that first rescales all feature dimensions.
3436
3537
Given a kernel `k` and `scale_diag` and inputs `x` and `y`, this kernel

tensorflow_probability/python/math/psd_kernels/feature_transformed.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,13 @@
2020

2121
import tensorflow.compat.v2 as tf
2222

23-
from tensorflow_probability.python.math.psd_kernels.positive_semidefinite_kernel import PositiveSemidefiniteKernel
23+
from tensorflow_probability.python.math.psd_kernels import positive_semidefinite_kernel as psd_kernel
2424

2525
__all__ = ['FeatureTransformed']
2626

2727

28-
class FeatureTransformed(PositiveSemidefiniteKernel):
28+
@psd_kernel.auto_composite_tensor_psd_kernel
29+
class FeatureTransformed(psd_kernel.AutoCompositeTensorPsdKernel):
2930
"""Input transformed kernel.
3031
3132
Given a kernel `k` and function `f`, compute `k_{new}(x, y) = k(f(x), f(y))`.

tensorflow_probability/python/math/psd_kernels/kumaraswamy_transformed.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,14 @@
2525
from tensorflow_probability.python.internal import assert_util
2626
from tensorflow_probability.python.internal import tensor_util
2727
from tensorflow_probability.python.math.psd_kernels import feature_transformed
28+
from tensorflow_probability.python.math.psd_kernels import positive_semidefinite_kernel as psd_kernel
2829
from tensorflow_probability.python.math.psd_kernels.internal import util
2930

3031

3132
__all__ = ['KumaraswamyTransformed']
3233

3334

35+
@psd_kernel.auto_composite_tensor_psd_kernel
3436
class KumaraswamyTransformed(feature_transformed.FeatureTransformed):
3537
"""Transform inputs by Kumaraswamy bijector.
3638

tensorflow_probability/python/math/psd_kernels/matern.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
from tensorflow_probability.python.internal import dtype_util
2626
from tensorflow_probability.python.internal import tensor_util
2727
from tensorflow_probability.python.math import bessel as tfp_math
28+
from tensorflow_probability.python.math.psd_kernels import positive_semidefinite_kernel as psd_kernel
2829
from tensorflow_probability.python.math.psd_kernels.internal import util
29-
from tensorflow_probability.python.math.psd_kernels.positive_semidefinite_kernel import PositiveSemidefiniteKernel
3030

3131
__all__ = [
3232
'GeneralizedMatern',
@@ -111,7 +111,9 @@ def _tensor(self, x1, x2, x1_example_ndims, x2_example_ndims):
111111
example_ndims=(x1_example_ndims + x2_example_ndims))
112112

113113

114-
class GeneralizedMatern(_AmplitudeLengthScaleMixin, PositiveSemidefiniteKernel):
114+
@psd_kernel.auto_composite_tensor_psd_kernel
115+
class GeneralizedMatern(_AmplitudeLengthScaleMixin,
116+
psd_kernel.AutoCompositeTensorPsdKernel):
115117
"""Generalized Matern Kernel.
116118
117119
This kernel parameterizes the Matern family of kernels.
@@ -224,7 +226,9 @@ def _parameter_control_dependencies(self, is_init):
224226
return assertions
225227

226228

227-
class MaternOneHalf(_AmplitudeLengthScaleMixin, PositiveSemidefiniteKernel):
229+
@psd_kernel.auto_composite_tensor_psd_kernel
230+
class MaternOneHalf(_AmplitudeLengthScaleMixin,
231+
psd_kernel.AutoCompositeTensorPsdKernel):
228232
"""Matern Kernel with parameter 1/2.
229233
230234
This kernel is part of the Matern family of kernels, with parameter 1/2.
@@ -295,7 +299,9 @@ def _apply_with_distance(
295299
return tf.exp(log_result)
296300

297301

298-
class MaternThreeHalves(_AmplitudeLengthScaleMixin, PositiveSemidefiniteKernel):
302+
@psd_kernel.auto_composite_tensor_psd_kernel
303+
class MaternThreeHalves(_AmplitudeLengthScaleMixin,
304+
psd_kernel.AutoCompositeTensorPsdKernel):
299305
"""Matern Kernel with parameter 3/2.
300306
301307
This kernel is part of the Matern family of kernels, with parameter 3/2.
@@ -363,7 +369,9 @@ def _apply_with_distance(
363369
return tf.exp(log_result)
364370

365371

366-
class MaternFiveHalves(_AmplitudeLengthScaleMixin, PositiveSemidefiniteKernel):
372+
@psd_kernel.auto_composite_tensor_psd_kernel
373+
class MaternFiveHalves(_AmplitudeLengthScaleMixin,
374+
psd_kernel.AutoCompositeTensorPsdKernel):
367375
"""Matern 5/2 Kernel.
368376
369377
This kernel is part of the Matern family of kernels, with parameter 5/2.

tensorflow_probability/python/math/psd_kernels/parabolic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,14 @@
2323
from tensorflow_probability.python.internal import assert_util
2424
from tensorflow_probability.python.internal import prefer_static as ps
2525
from tensorflow_probability.python.internal import tensor_util
26+
from tensorflow_probability.python.math.psd_kernels import positive_semidefinite_kernel as psd_kernel
2627
from tensorflow_probability.python.math.psd_kernels.internal import util
27-
from tensorflow_probability.python.math.psd_kernels.positive_semidefinite_kernel import PositiveSemidefiniteKernel
2828

2929
__all__ = ['Parabolic']
3030

3131

32-
class Parabolic(PositiveSemidefiniteKernel):
32+
@psd_kernel.auto_composite_tensor_psd_kernel
33+
class Parabolic(psd_kernel.AutoCompositeTensorPsdKernel):
3334
"""The Parabolic kernel.
3435
3536
```none

tensorflow_probability/python/math/psd_kernels/polynomial.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424

2525
from tensorflow_probability.python.internal import assert_util
2626
from tensorflow_probability.python.internal import tensor_util
27+
from tensorflow_probability.python.math.psd_kernels import positive_semidefinite_kernel as psd_kernel
2728
from tensorflow_probability.python.math.psd_kernels.internal import util
28-
from tensorflow_probability.python.math.psd_kernels.positive_semidefinite_kernel import PositiveSemidefiniteKernel
2929

3030
__all__ = [
3131
'Linear',
@@ -45,7 +45,8 @@ def _maybe_shape_dynamic(tensor):
4545
return tf.shape(tensor)
4646

4747

48-
class Polynomial(PositiveSemidefiniteKernel):
48+
@psd_kernel.auto_composite_tensor_psd_kernel
49+
class Polynomial(psd_kernel.AutoCompositeTensorPsdKernel):
4950
"""Polynomial Kernel.
5051
5152
Is based on the dot product covariance function and can be obtained
@@ -213,6 +214,7 @@ def _parameter_control_dependencies(self, is_init):
213214
return assertions
214215

215216

217+
@psd_kernel.auto_composite_tensor_psd_kernel
216218
class Linear(Polynomial):
217219
"""Linear Kernel.
218220

0 commit comments

Comments
 (0)