Skip to content

Commit 273eba3

Browse files
srvasudetensorflower-gardener
authored andcommitted
Add tfp.math.psd_kernels.SpectralMixture.
PiperOrigin-RevId: 427396648
1 parent 20c4ce2 commit 273eba3

File tree

5 files changed

+461
-1
lines changed

5 files changed

+461
-1
lines changed

tensorflow_probability/python/math/psd_kernels/BUILD

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ multi_substrate_py_library(
4747
":positive_semidefinite_kernel",
4848
":rational_quadratic",
4949
":schur_complement",
50+
":spectral_mixture",
5051
"//tensorflow_probability/python/internal:all_util",
5152
"//tensorflow_probability/python/internal:dtype_util",
5253
"//tensorflow_probability/python/math/psd_kernels/internal",
@@ -301,6 +302,37 @@ multi_substrate_py_test(
301302
],
302303
)
303304

305+
# Library target for the `SpectralMixture` positive-semidefinite kernel.
# `spectral_mixture.py` imports `prefer_static` from
# `tensorflow_probability.python.internal`, so that target is declared
# directly here rather than relied on transitively.
multi_substrate_py_library(
    name = "spectral_mixture",
    srcs = ["spectral_mixture.py"],
    deps = [
        ":positive_semidefinite_kernel",
        # numpy dep,
        # tensorflow dep,
        "//tensorflow_probability/python/internal:assert_util",
        "//tensorflow_probability/python/internal:dtype_util",
        "//tensorflow_probability/python/internal:parameter_properties",
        "//tensorflow_probability/python/internal:prefer_static",
        "//tensorflow_probability/python/internal:tensor_util",
        "//tensorflow_probability/python/internal:tensorshape_util",
        "//tensorflow_probability/python/math:generic",
        "//tensorflow_probability/python/math/psd_kernels/internal:util",
    ],
)
321+
322+
# Test target for the `SpectralMixture` kernel; sized "small" by default and
# "medium" for the JAX substrate.
multi_substrate_py_test(
    name = "spectral_mixture_test",
    size = "small",
    srcs = ["spectral_mixture_test.py"],
    jax_size = "medium",
    deps = [
        # absl/testing:parameterized dep,
        # numpy dep,
        # tensorflow dep,
        "//tensorflow_probability",
        "//tensorflow_probability/python/internal:test_util",
    ],
)
335+
304336
multi_substrate_py_library(
305337
name = "feature_scaled",
306338
srcs = ["feature_scaled.py"],

tensorflow_probability/python/math/psd_kernels/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from tensorflow_probability.python.math.psd_kernels.positive_semidefinite_kernel import PositiveSemidefiniteKernel
3636
from tensorflow_probability.python.math.psd_kernels.rational_quadratic import RationalQuadratic
3737
from tensorflow_probability.python.math.psd_kernels.schur_complement import SchurComplement
38+
from tensorflow_probability.python.math.psd_kernels.spectral_mixture import SpectralMixture
3839

3940
_allowed_symbols = [
4041
'AutoCompositeTensorPsdKernel',
@@ -57,6 +58,7 @@
5758
'PositiveSemidefiniteKernel',
5859
'RationalQuadratic',
5960
'SchurComplement',
61+
'SpectralMixture',
6062
]
6163

6264
all_util.remove_undocumented(__name__, _allowed_symbols)

tensorflow_probability/python/math/psd_kernels/hypothesis_testlib.py

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@
3838
'FeatureScaled',
3939
'KumaraswamyTransformed',
4040
'PointwiseExponential',
41-
'SchurComplement'
41+
'SchurComplement',
42+
'SpectralMixture',
4243
]
4344

4445

@@ -751,6 +752,7 @@ def schur_complements(
751752
'fixed_inputs': fixed_inputs,
752753
'diag_shift': diag_shift
753754
}
755+
754756
for param_name in schur_complement_params:
755757
if enable_vars and draw(hps.booleans()):
756758
kernel_variable_names.append(param_name)
@@ -768,6 +770,96 @@ def schur_complements(
768770
return result_kernel, kernel_variable_names
769771

770772

773+
@hps.composite
def spectral_mixtures(
    draw,
    batch_shape=None,
    event_dim=None,
    feature_dim=None,
    feature_ndims=None,
    enable_vars=None,
    depth=None):
  """Strategy for drawing `SpectralMixture` kernels.

  The `logits`, `locs` and `scales` parameters are each drawn with the
  `kernel_input` strategy, using a randomly drawn number of mixture
  components (between 2 and 5). The drawn `scales` are additionally passed
  through `tfp_hps.softplus_plus_eps()` before being handed to the kernel.

  Args:
    draw: Hypothesis strategy sampler supplied by `@hps.composite`.
    batch_shape: An optional `TensorShape`.  The batch shape of the resulting
      Kernel.  Hypothesis will pick a batch shape if omitted.
    event_dim: Optional Python int giving the size of each of the
      kernel's parameters' event dimensions.  If omitted, Hypothesis will
      choose one. It is accepted so that `kernels` can call every kernel
      strategy with the same keyword arguments; this strategy does not use the
      value when building the kernel.
    feature_dim: Optional Python int giving the size of each feature dimension.
      If omitted, Hypothesis will choose one.
    feature_ndims: Optional Python int stating the number of feature dimensions
      inputs will have.  If omitted, Hypothesis will choose one.
    enable_vars: TODO(bjp): Make this `True` all the time and put variable
      initialization in slicing_test.  If `False`, the returned parameters are
      all Tensors, never Variables or DeferredTensor.
    depth: Python `int` giving maximum nesting depth of compound kernel. If
      omitted, Hypothesis will choose one. `SpectralMixture` wraps no other
      kernel, so this strategy does not use the value.

  Returns:
    kernels: A strategy for drawing `(kernel, kernel_variable_names)` pairs,
      where `kernel` is a `SpectralMixture` kernel with the specified
      `batch_shape` (or an arbitrary one if omitted) and
      `kernel_variable_names` is the list of parameter names that were
      wrapped in a `tf.Variable`.
  """
  # Unused arguments are still drawn when omitted so that the sequence of
  # Hypothesis draws made by this strategy stays unchanged.
  if depth is None:
    depth = draw(depths())
  if batch_shape is None:
    batch_shape = draw(tfp_hps.shapes())
  if event_dim is None:
    event_dim = draw(hps.integers(min_value=2, max_value=6))
  if feature_dim is None:
    feature_dim = draw(hps.integers(min_value=2, max_value=6))
  if feature_ndims is None:
    feature_ndims = draw(hps.integers(min_value=2, max_value=6))

  num_mixtures = draw(hps.integers(min_value=2, max_value=5))

  # `logits` carries one entry per mixture component in its last dimension.
  logits = draw(kernel_input(
      batch_shape=batch_shape,
      example_ndims=0,
      feature_dim=num_mixtures,
      feature_ndims=1))

  # `locs` and `scales` use the single example dimension as the mixture
  # dimension, followed by `feature_ndims` feature dimensions.
  locs = draw(kernel_input(
      batch_shape=batch_shape,
      example_ndims=1,
      example_dim=num_mixtures,
      feature_dim=feature_dim,
      feature_ndims=feature_ndims))

  scales = tfp_hps.softplus_plus_eps()(draw(kernel_input(
      batch_shape=batch_shape,
      example_ndims=1,
      example_dim=num_mixtures,
      feature_dim=feature_dim,
      feature_ndims=feature_ndims)))

  hp.note(f'Forming SpectralMixture kernel with logits: {logits} '
          f'locs: {locs} and scales: {scales}')

  spectral_mixture_params = {'locs': locs, 'logits': logits, 'scales': scales}

  kernel_variable_names = []
  # Values are reassigned while iterating, but no key is added or removed, so
  # iterating over the dict directly is safe.
  for param_name in spectral_mixture_params:
    if enable_vars and draw(hps.booleans()):
      kernel_variable_names.append(param_name)
      spectral_mixture_params[param_name] = tf.Variable(
          spectral_mixture_params[param_name], name=param_name)
      if draw(hps.booleans()):
        spectral_mixture_params[param_name] = tfp_hps.defer_and_count_usage(
            spectral_mixture_params[param_name])
  result_kernel = tfpk.SpectralMixture(
      logits=spectral_mixture_params['logits'],
      locs=spectral_mixture_params['locs'],
      scales=spectral_mixture_params['scales'],
      feature_ndims=feature_ndims,
      validate_args=True)
  return result_kernel, kernel_variable_names
861+
862+
771863
@hps.composite
772864
def base_kernels(
773865
draw,
@@ -932,6 +1024,14 @@ def kernels(
9321024
feature_ndims=feature_ndims,
9331025
enable_vars=enable_vars,
9341026
depth=depth))
1027+
elif kernel_name == 'SpectralMixture':
1028+
return draw(spectral_mixtures(
1029+
batch_shape=batch_shape,
1030+
event_dim=event_dim,
1031+
feature_dim=feature_dim,
1032+
feature_ndims=feature_ndims,
1033+
enable_vars=enable_vars,
1034+
depth=depth))
9351035

9361036
raise ValueError('Kernel name {} not found.'.format(kernel_name))
9371037

@@ -952,6 +1052,7 @@ def constrain_to_range(low, high):
9521052
'concentration0': constrain_to_range(1., 2.),
9531053
'concentration1': constrain_to_range(1., 2.),
9541054
'df': constrain_to_range(2., 5.),
1055+
'scales': constrain_to_range(1., 2.),
9551056
'slope_variance': constrain_to_range(0.1, 0.5),
9561057
'exponent': lambda x: tf.math.floor(constrain_to_range(1, 4.)(x)),
9571058
'length_scale': constrain_to_range(1., 6.),
tensorflow_probability/python/math/psd_kernels/spectral_mixture.py

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
# Copyright 2021 The TensorFlow Probability Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
"""The SpectralMixture kernel."""
16+
17+
import numpy as np
18+
import tensorflow.compat.v2 as tf
19+
20+
from tensorflow_probability.python.internal import assert_util
21+
from tensorflow_probability.python.internal import dtype_util
22+
from tensorflow_probability.python.internal import parameter_properties
23+
from tensorflow_probability.python.internal import prefer_static as ps
24+
from tensorflow_probability.python.internal import tensor_util
25+
from tensorflow_probability.python.math import generic as tfp_math
26+
from tensorflow_probability.python.math.psd_kernels import positive_semidefinite_kernel as psd_kernel
27+
from tensorflow_probability.python.math.psd_kernels.internal import util
28+
29+
30+
__all__ = ['SpectralMixture']
31+
32+
33+
class SpectralMixture(psd_kernel.AutoCompositeTensorPsdKernel):
  """The SpectralMixture kernel.

  This kernel is derived from parameterizing the spectral density of a
  stationary kernel by a mixture of `M` diagonal multivariate normal
  distributions [1].

  This in turn parameterizes the following kernel:

  ```none
  k(x, y) = sum_j w[j] (prod_i
    exp(-2 * (pi * (x[i] - y[i]) * s[j][i])**2) *
    cos(2 * pi * (x[i] - y[i]) * m[j][i]))
  ```

  where:
    * `j` indexes the `M` mixture components (as mentioned above).
    * `i` indexes the feature dimensions.
    * `w[j]` are the mixture weights.
    * `m[j]` and `s[j]` parameterize a `MultivariateNormalDiag(m[j], s[j])`.
      In other words, they are the mean and diagonal scale for each mixture
      component.

  NOTE: This kernel can result in negative off-diagonal entries.

  #### References
  [1]: A. Wilson, R. P. Adams.
       Gaussian Process Kernels for Pattern Discovery and Extrapolation.
       https://arxiv.org/abs/1302.4245
  """

  def __init__(self,
               logits,
               locs,
               scales,
               feature_ndims=1,
               validate_args=False,
               name='SpectralMixture'):
    """Construct a SpectralMixture kernel instance.

    Args:
      logits: Floating-point `Tensor` of shape `[..., M]` parameterizing the
        mixture weights of the spectral density on a log scale. Must be
        broadcastable with `locs` and `scales`. NOTE: the implementation adds
        `logits` directly to the per-component log-magnitudes before a
        weighted log-sum-exp over the mixture axis, so (assuming
        `reduce_weighted_logsumexp` computes `log|sum(w * exp(x))|`) the
        effective weights are `exp(logits)`; no normalization across `M` is
        applied in this class.
      locs: Floating-point `Tensor` of shape `[..., M, F1, F2, ... FN]`, which
        represents the location parameter of each of the `M` mixture components.
        `N` is `feature_ndims`. Must be broadcastable with `logits` and
        `scales`.
      scales: Positive Floating-point `Tensor` of shape
        `[..., M, F1, F2, ..., FN]`, which represents the scale parameter of
        each of the `M` mixture components. `N` is `feature_ndims`. Must be
        broadcastable with `locs` and `logits`. These parameters act like
        inverse length scale parameters.
      feature_ndims: Python `int` number of rightmost dims to include in the
        squared difference norm in the exponential.
      validate_args: If `True`, parameters are checked for validity despite
        possibly degrading runtime performance.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Capture the constructor arguments before any other local is defined.
    parameters = dict(locals())
    with tf.name_scope(name):
      dtype = util.maybe_get_common_dtype([logits, locs, scales])
      self._logits = tensor_util.convert_nonref_to_tensor(
          logits, name='logits', dtype=dtype)
      self._locs = tensor_util.convert_nonref_to_tensor(
          locs, name='locs', dtype=dtype)
      self._scales = tensor_util.convert_nonref_to_tensor(
          scales, name='scales', dtype=dtype)
      super(SpectralMixture, self).__init__(
          feature_ndims,
          dtype=dtype,
          name=name,
          validate_args=validate_args,
          parameters=parameters)

  @property
  def logits(self):
    """Logits parameter."""
    return self._logits

  @property
  def locs(self):
    """Location parameter."""
    return self._locs

  @property
  def scales(self):
    """Scale parameter."""
    return self._scales

  @classmethod
  def _parameter_properties(cls, dtype):
    """Describes the event rank and constraining bijector of each parameter."""
    from tensorflow_probability.python.bijectors import softplus  # pylint:disable=g-import-not-at-top
    return dict(
        # `logits` has a single trailing (mixture) event dimension.
        logits=parameter_properties.ParameterProperties(event_ndims=1),
        # `locs` and `scales` have the mixture dimension plus the feature
        # dimensions as event dimensions.
        locs=parameter_properties.ParameterProperties(
            event_ndims=lambda self: self.feature_ndims + 1),
        scales=parameter_properties.ParameterProperties(
            event_ndims=lambda self: self.feature_ndims + 1,
            default_constraining_bijector_fn=(
                lambda: softplus.Softplus(low=dtype_util.eps(dtype)))))

  def _apply_with_distance(
      self, x1, x2, pairwise_square_distance, example_ndims=0):
    """Combines the per-component exponent and cosine terms into the kernel.

    Args:
      x1: First input, already expanded by the caller so that `x1 - x2`
        broadcasts against `locs` padded with `example_ndims` size-1 axes
        ahead of its feature dimensions.
      x2: Second input, expanded like `x1`.
      pairwise_square_distance: Scaled squared distance with the feature
        dimensions already reduced; it is multiplied by `-2` to form the
        exponent of each mixture component.
      example_ndims: Python `int` number of example dimensions. The mixture
        axis that gets reduced is axis `-(example_ndims + 1)` once the feature
        dimensions have been reduced.

    Returns:
      The kernel value with the mixture axis summed out.
    """
    exponent = -2. * pairwise_square_distance
    locs = util.pad_shape_with_ones(
        self.locs, ndims=example_ndims, start=-(self.feature_ndims + 1))
    cos_coeffs = tf.math.cos(2 * np.pi * (x1 - x2) * locs)

    # The product of cosines over the feature dimensions is carried out in log
    # space on the magnitudes, with the signs tracked separately.
    rank = ps.rank(cos_coeffs)
    feature_ndims = ps.cast(self.feature_ndims, rank.dtype)
    reduction_axes = ps.range(rank - feature_ndims, rank)
    coeff_sign = tf.math.reduce_prod(
        tf.math.sign(cos_coeffs), axis=reduction_axes)
    log_cos_coeffs = tf.math.reduce_sum(
        tf.math.log(tf.math.abs(cos_coeffs)), axis=reduction_axes)

    logits = util.pad_shape_with_ones(
        self.logits, ndims=example_ndims, start=-1)

    # Signed log-sum-exp over the mixture axis; the sign is re-applied after
    # exponentiating because individual components may be negative.
    log_result, sign = tfp_math.reduce_weighted_logsumexp(
        exponent + log_cos_coeffs + logits,
        coeff_sign, return_sign=True, axis=-(example_ndims + 1))

    return sign * tf.math.exp(log_result)

  def _apply(self, x1, x2, example_ndims=0):
    # Add an extra size-1 dimension to x1 and x2, ahead of their example
    # dimensions, so they broadcast against the mixture dimension `M` of
    # `scales`. The broadcast layout is
    # [B1, ..., Bb, M, E1, ..., Ek, F1, ..., Fn], which is what the reduction
    # over axis `-(example_ndims + 1)` in `_apply_with_distance` relies on.
    x1 = util.pad_shape_with_ones(
        x1, ndims=1, start=-(self.feature_ndims + example_ndims + 1))
    x2 = util.pad_shape_with_ones(
        x2, ndims=1, start=-(self.feature_ndims + example_ndims + 1))
    scales = util.pad_shape_with_ones(
        self.scales, ndims=example_ndims, start=-(self.feature_ndims + 1))
    pairwise_square_distance = util.sum_rightmost_ndims_preserving_shape(
        tf.math.square(np.pi * (x1 - x2) * scales), ndims=self.feature_ndims)
    return self._apply_with_distance(
        x1, x2, pairwise_square_distance, example_ndims=example_ndims)

  def _matrix(self, x1, x2):
    # Add an extra size-1 dimension to x1 and x2, ahead of their single example
    # dimension, so they broadcast against the mixture dimension of `scales`.
    x1 = util.pad_shape_with_ones(x1, ndims=1, start=-(self.feature_ndims + 2))
    x2 = util.pad_shape_with_ones(x2, ndims=1, start=-(self.feature_ndims + 2))
    scales = util.pad_shape_with_ones(
        self.scales, ndims=1, start=-(self.feature_ndims + 1))
    pairwise_square_distance = util.pairwise_square_distance_matrix(
        np.pi * x1 * scales, np.pi * x2 * scales, self.feature_ndims)
    # Expand `x1` and `x2` so that they broadcast against each other: `x1`
    # gets a size-1 axis after its example dimension and `x2` gets one before
    # its example dimension, matching `example_ndims=2` below.
    x1 = util.pad_shape_with_ones(x1, ndims=1, start=-(self.feature_ndims + 1))
    x2 = util.pad_shape_with_ones(x2, ndims=1, start=-(self.feature_ndims + 2))
    return self._apply_with_distance(
        x1, x2, pairwise_square_distance, example_ndims=2)

  def _parameter_control_dependencies(self, is_init):
    """Returns assertions on `scales`; empty unless `validate_args` is set."""
    if not self.validate_args:
      return []
    assertions = []
    # Assert at init time for non-ref `scales`, and at non-init time for ref
    # (e.g. Variable) `scales`.
    if is_init != tensor_util.is_ref(self._scales):
      assertions.append(assert_util.assert_positive(
          self._scales,
          message='`scales` must be positive.'))
    return assertions

0 commit comments

Comments
 (0)