Skip to content

Commit 6f62a06

Browse files
ColCarrolljburnim
authored andcommitted
Fix gamma_exponential psd kernel so that the covariance is
amplitude**2 * exp(-(||x - x'||**2 / (2 * lengthscale**2))**gamma) rather than amplitude**2 * exp(-(||x - x'||**(2*gamma) / (2 * lengthscale**2))) PiperOrigin-RevId: 547842658
1 parent 8f45820 commit 6f62a06

File tree

3 files changed

+23
-10
lines changed

3 files changed

+23
-10
lines changed

tensorflow_probability/python/math/psd_kernels/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ multi_substrate_py_test(
189189
size = "small",
190190
srcs = ["gamma_exponential_test.py"],
191191
deps = [
192+
":exponentiated_quadratic",
192193
":gamma_exponential",
193194
# absl/testing:parameterized dep,
194195
# numpy dep,

tensorflow_probability/python/math/psd_kernels/gamma_exponential.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class GammaExponential(psd_kernel.AutoCompositeTensorPsdKernel):
3535
3636
```none
3737
k(x, y) = amplitude**2 * exp(
38-
-||x - y||**(2 * power) / (2 * length_scale**2))
38+
-(||x - y||**2 / (2 * length_scale**2))**power)
3939
```
4040
4141
where the double-bars represent vector length (ie, Euclidean, or L2 norm).
@@ -150,20 +150,21 @@ def _parameter_properties(cls, dtype):
150150
def _apply_with_distance(
151151
self, x1, x2, pairwise_square_distance, example_ndims=0):
152152

153-
if self.power is not None:
154-
power = tf.convert_to_tensor(self.power)
155-
power = util.pad_shape_with_ones(power, example_ndims)
156-
pairwise_pow_distance = pairwise_square_distance ** power
157-
else:
158-
pairwise_pow_distance = pairwise_square_distance
159-
160-
exponent = -0.5 * pairwise_pow_distance
153+
exponent = 0.5 * pairwise_square_distance
161154
inverse_length_scale = self._inverse_length_scale_parameter()
162155
if inverse_length_scale is not None:
163156
inverse_length_scale = util.pad_shape_with_ones(
164157
inverse_length_scale, example_ndims)
165158
exponent = exponent * tf.math.square(inverse_length_scale)
166159

160+
if self.power is not None:
161+
power = tf.convert_to_tensor(self.power)
162+
power = util.pad_shape_with_ones(power, example_ndims)
163+
exponent = exponent ** power
164+
else:
165+
exponent = pairwise_square_distance
166+
exponent = -1. * exponent
167+
167168
if self.amplitude is not None:
168169
amplitude = tf.convert_to_tensor(self.amplitude)
169170
amplitude = util.pad_shape_with_ones(amplitude, example_ndims)

tensorflow_probability/python/math/psd_kernels/gamma_exponential_test.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import tensorflow.compat.v2 as tf
2121

2222
from tensorflow_probability.python.internal import test_util
23+
from tensorflow_probability.python.math.psd_kernels import exponentiated_quadratic
2324
from tensorflow_probability.python.math.psd_kernels import gamma_exponential
2425

2526

@@ -58,14 +59,24 @@ def testValuesAreCorrect(self, feature_ndims, dims):
5859
y = np.random.uniform(-1, 1, size=shape).astype(np.float32)
5960
self.assertAllClose(
6061
amplitude ** 2 * np.exp(
61-
-np.float32(.5) * np.sum((x - y)**2)**gamma / length_scale**2),
62+
-(np.sum((x - y)**2) /
63+
(np.float32(2.) * length_scale**2))**gamma),
6264
self.evaluate(k.apply(x, y)))
6365

6466
def testNoneShapes(self):
6567
k = gamma_exponential.GammaExponential(
6668
amplitude=np.reshape(np.arange(12.), [2, 3, 2]))
6769
self.assertAllEqual((2, 3, 2), k.batch_shape)
6870

71+
def testEqualsExponentiatedQuadratic(self):
72+
np.random.seed(42)
73+
ge = gamma_exponential.GammaExponential(
74+
amplitude=3., length_scale=0.5, power=1., feature_ndims=0)
75+
eq = exponentiated_quadratic.ExponentiatedQuadratic(
76+
amplitude=3., length_scale=0.5, feature_ndims=0)
77+
t1, t2 = np.random.rand(2, 10)
78+
self.assertAllClose(ge.apply(t1, t2), eq.apply(t1, t2))
79+
6980
def testShapesAreCorrect(self):
7081
k = gamma_exponential.GammaExponential(
7182
amplitude=1., length_scale=1., power=2.)

0 commit comments

Comments
 (0)