|
| 1 | +# Copyright 2021 The TensorFlow Probability Authors. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================ |
| 15 | +"""The SpectralMixture kernel.""" |
| 16 | + |
| 17 | +import numpy as np |
| 18 | +import tensorflow.compat.v2 as tf |
| 19 | + |
| 20 | +from tensorflow_probability.python.internal import assert_util |
| 21 | +from tensorflow_probability.python.internal import dtype_util |
| 22 | +from tensorflow_probability.python.internal import parameter_properties |
| 23 | +from tensorflow_probability.python.internal import prefer_static as ps |
| 24 | +from tensorflow_probability.python.internal import tensor_util |
| 25 | +from tensorflow_probability.python.math import generic as tfp_math |
| 26 | +from tensorflow_probability.python.math.psd_kernels import positive_semidefinite_kernel as psd_kernel |
| 27 | +from tensorflow_probability.python.math.psd_kernels.internal import util |
| 28 | + |
| 29 | + |
| 30 | +__all__ = ['SpectralMixture'] |
| 31 | + |
| 32 | + |
class SpectralMixture(psd_kernel.AutoCompositeTensorPsdKernel):
  """The SpectralMixture kernel.

  A stationary kernel obtained by modelling its spectral density as a mixture
  of `M` diagonal multivariate normal distributions [1].

  The resulting kernel is:

    ```none
    k(x, y) = sum_j w[j] (prod_i
      exp(-2 * (pi * (x[i] - y[i]) * s[j][i])**2) *
      cos(2 * pi * (x[i] - y[i]) * m[j][i]))
    ```

  where:
    * `j` indexes the `M` mixture components and `i` indexes the features.
    * `w[j]` are the mixture weights.
    * `m[j]` and `s[j]` parameterize a `MultivariateNormalDiag(m[j], s[j])`,
      i.e. they are the mean and the diagonal scale of component `j`.

  NOTE: This kernel can result in negative off-diagonal entries.

  #### References
  [1]: A. Wilson, R. P. Adams.
       Gaussian Process Kernels for Pattern Discovery and Extrapolation.
       https://arxiv.org/abs/1302.4245
  """

  def __init__(self,
               logits,
               locs,
               scales,
               feature_ndims=1,
               validate_args=False,
               name='SpectralMixture'):
    """Construct a SpectralMixture kernel instance.

    Args:
      logits: Floating-point `Tensor` of shape `[..., M]`, whose softmax
        represents the mixture weights for the spectral density. Must
        be broadcastable with `locs` and `scales`.
        NOTE(review): the kernel evaluation adds `logits` to the
        log-sum-exp terms without normalizing them first; confirm whether
        softmax normalization is intended.
      locs: Floating-point `Tensor` of shape `[..., M, F1, F2, ... FN]`, the
        location parameter of each of the `M` mixture components. `N` is
        `feature_ndims`. Must be broadcastable with `logits` and `scales`.
      scales: Positive floating-point `Tensor` of shape
        `[..., M, F1, F2, ..., FN]`, the scale parameter of each of the `M`
        mixture components. `N` is `feature_ndims`. Must be broadcastable
        with `locs` and `logits`. These parameters act like inverse length
        scale parameters.
      feature_ndims: Python `int` number of rightmost dims to include in the
        squared difference norm in the exponential.
      validate_args: If `True`, parameters are checked for validity despite
        possibly degrading runtime performance.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Snapshot the constructor arguments before any other local is bound so
    # the dict holds exactly the call arguments.
    parameters = dict(locals())
    with tf.name_scope(name):
      common_dtype = util.maybe_get_common_dtype([logits, locs, scales])
      self._logits = tensor_util.convert_nonref_to_tensor(
          logits, dtype=common_dtype, name='logits')
      self._locs = tensor_util.convert_nonref_to_tensor(
          locs, dtype=common_dtype, name='locs')
      self._scales = tensor_util.convert_nonref_to_tensor(
          scales, dtype=common_dtype, name='scales')
      # Two-argument `super` is kept deliberately: the zero-argument form
      # would add a `__class__` entry to the `locals()` snapshot above.
      super(SpectralMixture, self).__init__(
          feature_ndims,
          dtype=common_dtype,
          name=name,
          validate_args=validate_args,
          parameters=parameters)

  @property
  def logits(self):
    """The `logits` parameter passed at construction (mixture weights)."""
    return self._logits

  @property
  def locs(self):
    """The `locs` parameter passed at construction (component locations)."""
    return self._locs

  @property
  def scales(self):
    """The `scales` parameter passed at construction (component scales)."""
    return self._scales

  @classmethod
  def _parameter_properties(cls, dtype):
    """Describe the event rank and default constraint of each parameter.

    Args:
      dtype: dtype used to pick the lower bound (`eps(dtype)`) of the
        Softplus bijector that constrains `scales`.

    Returns:
      A `dict` mapping parameter name to its `ParameterProperties`.
    """
    # Function-scope import, as in the original (presumably to avoid an
    # import cycle with the bijectors package).
    from tensorflow_probability.python.bijectors import softplus  # pylint:disable=g-import-not-at-top

    def per_component_event_ndims(self):
      # One mixture axis plus the feature axes.
      return self.feature_ndims + 1

    def positive_scale_bijector():
      return softplus.Softplus(low=dtype_util.eps(dtype))

    return {
        'logits': parameter_properties.ParameterProperties(event_ndims=1),
        'locs': parameter_properties.ParameterProperties(
            event_ndims=per_component_event_ndims),
        'scales': parameter_properties.ParameterProperties(
            event_ndims=per_component_event_ndims,
            default_constraining_bijector_fn=positive_scale_bijector),
    }

  def _apply_with_distance(
      self, x1, x2, pairwise_square_distance, example_ndims=0):
    """Evaluate the kernel from a precomputed scaled squared distance.

    The product of cosines over the feature dims is carried as a sign and a
    log-magnitude so that the sum over mixture components can be done with a
    signed log-sum-exp.

    Args:
      x1: First input, already expanded so that `x1 - x2` broadcasts
        against the padded `locs`.
      x2: Second input, expanded likewise.
      pairwise_square_distance: Sum over the feature dims of
        `(pi * (x1 - x2) * scales)**2`, as computed by the callers.
      example_ndims: Python `int` number of example dims; it controls how
        `locs` and `logits` are padded and which axis is reduced.

    Returns:
      The kernel values with the mixture axis summed out.
    """
    num_feature_dims = self.feature_ndims

    # Log of the Gaussian envelope exp(-2 * distance).
    log_envelope = -2. * pairwise_square_distance

    padded_locs = util.pad_shape_with_ones(
        self.locs, ndims=example_ndims, start=-(num_feature_dims + 1))
    two_pi = 2 * np.pi
    cosines = tf.math.cos(two_pi * (x1 - x2) * padded_locs)

    # The rightmost `feature_ndims` axes of `cosines` are multiplied together.
    cos_rank = ps.rank(cosines)
    feature_axes = ps.range(
        cos_rank - ps.cast(num_feature_dims, cos_rank.dtype), cos_rank)

    # The cosine product may be negative, so split it into sign and log|.|.
    product_sign = tf.math.reduce_prod(
        tf.math.sign(cosines), axis=feature_axes)
    log_abs_product = tf.math.reduce_sum(
        tf.math.log(tf.math.abs(cosines)), axis=feature_axes)

    padded_logits = util.pad_shape_with_ones(
        self.logits, ndims=example_ndims, start=-1)

    # Signed log-sum-exp over the mixture axis; `logits` are added as given
    # (this method applies no normalization to them).
    mixture_axis = -(example_ndims + 1)
    log_magnitude, result_sign = tfp_math.reduce_weighted_logsumexp(
        log_envelope + log_abs_product + padded_logits,
        product_sign, return_sign=True, axis=mixture_axis)

    return result_sign * tf.math.exp(log_magnitude)

  def _apply(self, x1, x2, example_ndims=0):
    """Evaluate the kernel on `x1` and `x2` with `example_ndims` example dims.

    Args:
      x1: First input `Tensor`.
      x2: Second input `Tensor`.
      example_ndims: Python `int` number of example dims in the inputs.

    Returns:
      The result of `_apply_with_distance` for the expanded inputs.
    """
    # Give `x1` and `x2` an extra size-1 axis so they broadcast with `scales`.
    # NOTE(review): the original shape comment listed `M` after the example
    # dims; from the `start` offset the new axis looks like it lands before
    # them -- confirm against `util.pad_shape_with_ones`.
    mixture_axis_start = -(self.feature_ndims + example_ndims + 1)
    x1 = util.pad_shape_with_ones(x1, ndims=1, start=mixture_axis_start)
    x2 = util.pad_shape_with_ones(x2, ndims=1, start=mixture_axis_start)
    padded_scales = util.pad_shape_with_ones(
        self.scales, ndims=example_ndims, start=-(self.feature_ndims + 1))

    scaled_difference = np.pi * (x1 - x2) * padded_scales
    pairwise_square_distance = util.sum_rightmost_ndims_preserving_shape(
        tf.math.square(scaled_difference), ndims=self.feature_ndims)
    return self._apply_with_distance(
        x1, x2, pairwise_square_distance, example_ndims=example_ndims)

  def _matrix(self, x1, x2):
    """Evaluate the kernel on all pairs of examples from `x1` and `x2`.

    Args:
      x1: First input `Tensor`, with one example dim before the feature dims.
      x2: Second input `Tensor`, with one example dim before the feature dims.

    Returns:
      The result of `_apply_with_distance` with `example_ndims=2`.
    """
    # Give `x1` and `x2` an extra size-1 axis so they broadcast with `scales`.
    mixture_axis_start = -(self.feature_ndims + 2)
    x1 = util.pad_shape_with_ones(x1, ndims=1, start=mixture_axis_start)
    x2 = util.pad_shape_with_ones(x2, ndims=1, start=mixture_axis_start)
    padded_scales = util.pad_shape_with_ones(
        self.scales, ndims=1, start=-(self.feature_ndims + 1))
    pairwise_square_distance = util.pairwise_square_distance_matrix(
        np.pi * x1 * padded_scales,
        np.pi * x2 * padded_scales,
        self.feature_ndims)

    # Expand `x1` and `x2` at different positions so that they broadcast
    # against each other when subtracted.
    x1 = util.pad_shape_with_ones(
        x1, ndims=1, start=-(self.feature_ndims + 1))
    x2 = util.pad_shape_with_ones(
        x2, ndims=1, start=-(self.feature_ndims + 2))
    return self._apply_with_distance(
        x1, x2, pairwise_square_distance, example_ndims=2)

  def _parameter_control_dependencies(self, is_init):
    """Build the validation assertions for the kernel parameters.

    Args:
      is_init: Python `bool`, whether this is called from the constructor.

    Returns:
      A list of assertion ops; empty when `validate_args` is `False` or when
      the check on `scales` is not scheduled for this call.
    """
    if not self.validate_args:
      return []
    # The check is emitted only when `is_init` differs from whether `scales`
    # is a reference.
    if is_init == tensor_util.is_ref(self._scales):
      return []
    return [
        assert_util.assert_positive(
            self._scales, message='`scales` must be positive.'),
    ]
0 commit comments