# Copyright 2021 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Additive kernel."""

import tensorflow.compat.v2 as tf

from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import parameter_properties
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import tensor_util
from tensorflow_probability.python.math.psd_kernels import positive_semidefinite_kernel as psd_kernel
from tensorflow_probability.python.math.psd_kernels.internal import util


__all__ = [
    'AdditiveKernel',
]


class AdditiveKernel(psd_kernel.AutoCompositeTensorPsdKernel):
  """Additive Kernel.

  This kernel has the following form:

  ```none
  k(x, y) = sum_n k_add_n(x, y)
  k_add_n(x, y) = a_n**2 * sum_{1 <= i_1 < i_2 < ... < i_n <= D}
      k_{i_1}(x[i_1], y[i_1]) * ... * k_{i_n}(x[i_n], y[i_n])
  ```

  where `k_i` is the one-dimensional base kernel for the `i`th dimension and
  `D` is the number of input dimensions.

  In other words, this kernel computes sums of elementary symmetric
  polynomials over `k_i(x[i], y[i])`.

  This kernel is closely related to the ANOVA kernel, defined as
  `k_{ANOVA}(x, y) = prod_i (1 + k_i(x[i], y[i]))`. `k_{ANOVA}` is
  equivalent to a special case of this kernel in which the `amplitudes` are
  all one, along with a constant shift of 1.

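  #### Examples

  A minimal usage sketch (the `ExponentiatedQuadratic` base kernel and the
  specific shapes below are illustrative assumptions, not requirements):

  ```python
  import tensorflow as tf
  from tensorflow_probability.python.math.psd_kernels import exponentiated_quadratic

  # A batch of D=3 one-dimensional base kernels, one per input dimension.
  base_kernel = exponentiated_quadratic.ExponentiatedQuadratic(
      length_scale=tf.constant([0.5, 1.0, 2.0]))

  # Order M=2: first- and second-order interactions between dimensions.
  additive = AdditiveKernel(kernel=base_kernel, amplitudes=[1.0, 0.5])

  x = tf.random.normal([10, 3])  # 10 points in 3 dimensions.
  additive.matrix(x, x)  # ==> Tensor with shape [10, 10].
  ```
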
  #### References

  [1] D. Duvenaud, H. Nickisch, C. E. Rasmussen. Additive Gaussian Processes.
      https://hannes.nickisch.org/papers/conferences/duvenaud11gpadditive.pdf

  [2] M. Stitson, A. Gammerman, V. Vapnik, V. Vovk, et al.
      Support Vector Regression with ANOVA Decomposition Kernels.
      http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.34.7818&rep=rep1&type=pdf
  """

  def __init__(
      self,
      kernel,
      amplitudes,
      validate_args=False,
      name='AdditiveKernel'):
    """Instantiates an `AdditiveKernel`.

    Args:
      kernel: A `PositiveSemidefiniteKernel` instance that acts on inputs of
        the form `[B1, ..., Bk, D, 1]`; that is, `kernel` is a batch of `D`
        kernels, each acting on 1-dimensional inputs, whose batch dimensions
        are reinterpreted by this class as feature dimensions. We assume that
        the kernel has a batch shape broadcastable with `[D]`. `kernel` must
        inherit from `tf.__internal__.CompositeTensor`.
      amplitudes: `Tensor` of shape `[B1, ..., Bk, M]`, where `M` is the
        order of the additive kernel. `M` must be statically identifiable.
      validate_args: Python `bool`, default `False`. When `True` kernel
        parameters are checked for validity despite possibly degrading
        runtime performance. When `False` invalid inputs may silently render
        incorrect outputs.
      name: Python `str` name prefixed to Ops created by this class.
        Default: `'AdditiveKernel'`.

    Raises:
      TypeError: if `kernel` is not an instance of
        `tf.__internal__.CompositeTensor`.
    """
    parameters = dict(locals())
    with tf.name_scope(name):
      if not isinstance(kernel, tf.__internal__.CompositeTensor):
        raise TypeError('`kernel` must inherit from '
                        '`tf.__internal__.CompositeTensor`.')
      dtype = util.maybe_get_common_dtype([kernel, amplitudes])
      self._kernel = kernel
      self._amplitudes = tensor_util.convert_nonref_to_tensor(
          amplitudes, name='amplitudes', dtype=dtype)
      super(AdditiveKernel, self).__init__(
          feature_ndims=self.kernel.feature_ndims,
          dtype=dtype,
          name=name,
          validate_args=validate_args,
          parameters=parameters)

  @property
  def amplitudes(self):
    """Amplitude parameter for each additive kernel."""
    return self._amplitudes

  @property
  def kernel(self):
    """Inner kernel used for scalar kernel computations."""
    return self._kernel

  @classmethod
  def _parameter_properties(cls, dtype, num_classes=None):
    from tensorflow_probability.python.bijectors import softplus as softplus_bijector  # pylint:disable=g-import-not-at-top
    return dict(
        amplitudes=parameter_properties.ParameterProperties(
            event_ndims=1,
            default_constraining_bijector_fn=(
                softplus_bijector.Softplus(low=dtype_util.eps(dtype)))),
        kernel=parameter_properties.BatchedComponentProperties(event_ndims=1))

  # Below, the additive kernel is computed via a recurrence on elementary
  # symmetric polynomials.
  # Let z_i = k[i](x[i], y[i]).
  # Then we are computing the elementary symmetric polynomials
  #   S_n(z_1, ..., z_k) = sum_{1 <= j_1 < j_2 < ... < j_n <= k}
  #       z_{j_1} * z_{j_2} * ... * z_{j_n}
  # Elementary symmetric polynomials satisfy the recurrence:
  #   S_n(z_1, ..., z_k) = S_n(z_1, ..., z_{k-1}) +
  #       z_k * S_{n-1}(z_1, ..., z_{k-1})
  # Thus, we can use dynamic programming to compute the elementary symmetric
  # polynomials over the z_i, and use vectorization to do this in a batched
  # way.
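  # For example, with k = 3 and z = (z_1, z_2, z_3) the polynomials tracked
  # by the recurrence are (assuming order M >= 3):
  #   S_0 = 1
  #   S_1 = z_1 + z_2 + z_3
  #   S_2 = z_1*z_2 + z_1*z_3 + z_2*z_3
  #   S_3 = z_1*z_2*z_3
  # The scans below carry the vector (S_0, S_1, ..., S_M) forward one input
  # dimension at a time.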

  def _apply(self, x1, x2, example_ndims=0):
    @tf.recompute_grad
    def _inner_apply(x1, x2):
      order = ps.shape(self.amplitudes)[-1]

      def scan_fn(esp, i):
        s = self.kernel[..., i].apply(
            x1[..., i][..., tf.newaxis],
            x2[..., i][..., tf.newaxis],
            example_ndims=example_ndims)
        next_esp = esp[..., 1:] + s[..., tf.newaxis] * esp[..., :-1]
        # Add the zero-th polynomial.
        next_esp = tf.concat(
            [tf.ones_like(esp[..., 0][..., tf.newaxis]), next_esp], axis=-1)
        return next_esp

      batch_shape = ps.broadcast_shape(
          ps.shape(x1)[:-self.kernel.feature_ndims],
          ps.shape(x2)[:-self.kernel.feature_ndims])

      batch_shape = ps.broadcast_shape(
          batch_shape,
          ps.concat([
              self.batch_shape_tensor(),
              [1] * example_ndims], axis=0))

      initializer = tf.concat(
          [tf.ones(ps.concat([batch_shape, [1]], axis=0),
                   dtype=self.dtype),
           tf.zeros(ps.concat([batch_shape, [order]], axis=0),
                    dtype=self.dtype)], axis=-1)

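      # Fold one input dimension's kernel value into the running ESP vector
      # per scan step; keep only the final step and drop the zeroth
      # polynomial, leaving [S_1, ..., S_M].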
      esps = tf.scan(
          scan_fn,
          elems=ps.range(0, ps.shape(x1)[-1], dtype=tf.int32),
          parallel_iterations=32,
          initializer=initializer)[-1, ..., 1:]
      amplitudes = util.pad_shape_with_ones(
          self.amplitudes, ndims=example_ndims, start=-2)
      return tf.reduce_sum(esps * tf.math.square(amplitudes), axis=-1)
    return _inner_apply(x1, x2)

  def _matrix(self, x1, x2):
    @tf.recompute_grad
    def _inner_matrix(x1, x2):
      order = ps.shape(self.amplitudes)[-1]

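      # Same elementary-symmetric-polynomial recurrence as in `_apply`, but
      # over pairwise kernel matrices of shape [..., N1, N2].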
      def scan_fn(esp, i):
        s = self.kernel[..., i].matrix(
            x1[..., i][..., tf.newaxis], x2[..., i][..., tf.newaxis])
        next_esp = esp[..., 1:] + s[..., tf.newaxis] * esp[..., :-1]
        # Add the zero-th polynomial.
        next_esp = tf.concat(
            [tf.ones_like(esp[..., 0][..., tf.newaxis]), next_esp], axis=-1)
        return next_esp

      batch_shape = ps.broadcast_shape(
          ps.shape(x1)[:-(self.kernel.feature_ndims + 1)],
          ps.shape(x2)[:-(self.kernel.feature_ndims + 1)])
      batch_shape = ps.broadcast_shape(
          batch_shape, self.batch_shape_tensor())
      matrix_shape = [
          ps.shape(x1)[-(self.kernel.feature_ndims + 1)],
          ps.shape(x2)[-(self.kernel.feature_ndims + 1)]]
      total_shape = ps.concat([batch_shape, matrix_shape], axis=0)

      initializer = tf.concat(
          [tf.ones(ps.concat([total_shape, [1]], axis=0),
                   dtype=self.dtype),
           tf.zeros(ps.concat([total_shape, [order]], axis=0),
                    dtype=self.dtype)], axis=-1)

      esps = tf.scan(
          scan_fn,
          elems=ps.range(0, ps.shape(x1)[-1], dtype=tf.int32),
          parallel_iterations=32,
          initializer=initializer)[-1, ..., 1:]
      amplitudes = util.pad_shape_with_ones(
          self.amplitudes, ndims=2, start=-2)
      return tf.reduce_sum(esps * tf.math.square(amplitudes), axis=-1)
    return _inner_matrix(x1, x2)