Skip to content

Commit 05b7b9c

Browse files
srvasude and tensorflower-gardener
authored and committed
Fixes to Polynomial and Linear kernel.
- Enable tests in numpy mode by avoiding mutable numpy operations.
- Add parameters for `bias_amplitude` and `slope_amplitude`. Currently the parameters `bias_variance` and `slope_variance` are actually standard deviations, which is somewhat confusing. This proposes to deprecate them in favor of amplitude parameters, which matches the terminology in other places and acts like standard deviations everywhere.

PiperOrigin-RevId: 463761934
1 parent 6fb306b commit 05b7b9c

File tree

4 files changed

+177
-88
lines changed

4 files changed

+177
-88
lines changed

tensorflow_probability/python/math/psd_kernels/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,6 @@ multi_substrate_py_test(
255255
name = "polynomial_test",
256256
size = "small",
257257
srcs = ["polynomial_test.py"],
258-
numpy_tags = ["notap"],
259258
deps = [
260259
# absl/testing:parameterized dep,
261260
# numpy dep,

tensorflow_probability/python/math/psd_kernels/hypothesis_testlib.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ def _instantiable_base_kernels():
119119
MUTEX_PARAMS = (
120120
set(['length_scale', 'inverse_length_scale']),
121121
set(['scale_diag', 'inverse_scale_diag']),
122+
set(['bias_variance', 'bias_amplitude']),
123+
set(['slope_variance', 'slope_amplitude']),
122124
)
123125

124126
# pylint is unable to handle @hps.composite (e.g. complains "No value for
@@ -1047,12 +1049,14 @@ def constrain_to_range(low, high):
10471049
# well-conditioned. The ranges below were chosen to ensure kernel
10481050
# matrices are positive definite.
10491051
'amplitude': constrain_to_range(1., 2.),
1052+
'bias_amplitude': constrain_to_range(0.1, 0.5),
10501053
'bias_variance': constrain_to_range(0.1, 0.5),
10511054
'constant': constrain_to_range(0.1, 0.5),
10521055
'concentration0': constrain_to_range(1., 2.),
10531056
'concentration1': constrain_to_range(1., 2.),
10541057
'df': constrain_to_range(2., 5.),
10551058
'scales': constrain_to_range(1., 2.),
1059+
'slope_amplitude': constrain_to_range(0.1, 0.5),
10561060
'slope_variance': constrain_to_range(0.1, 0.5),
10571061
'exponent': lambda x: tf.math.floor(constrain_to_range(1, 4.)(x)),
10581062
'length_scale': constrain_to_range(1., 6.),

tensorflow_probability/python/math/psd_kernels/polynomial.py

Lines changed: 123 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from tensorflow_probability.python.internal import tensor_util
2323
from tensorflow_probability.python.math.psd_kernels import positive_semidefinite_kernel as psd_kernel
2424
from tensorflow_probability.python.math.psd_kernels.internal import util
25+
from tensorflow.python.util import deprecation # pylint: disable=g-direct-tensorflow-import
2526

2627
__all__ = [
2728
'Constant',
@@ -40,7 +41,7 @@ class Polynomial(psd_kernel.AutoCompositeTensorPsdKernel):
4041
or None.
4142
4243
```none
43-
k(x, y) = bias_variance**2 + slope_variance**2 *
44+
k(x, y) = bias_amplitude**2 + slope_amplitude**2 *
4445
((x - shift) dot (y - shift))**exponent
4546
```
4647
@@ -54,9 +55,16 @@ class Polynomial(psd_kernel.AutoCompositeTensorPsdKernel):
5455
5556
"""
5657

58+
@deprecation.deprecated_args(
59+
'2022-11-01',
60+
'`bias_variance` and `slope_variance` are deprecated. Please use '
61+
'`bias_amplitude` and `slope_amplitude` instead.',
62+
'bias_variance', 'slope_variance')
5763
def __init__(self,
5864
bias_variance=None,
5965
slope_variance=None,
66+
bias_amplitude=None,
67+
slope_amplitude=None,
6068
shift=None,
6169
exponent=None,
6270
feature_ndims=1,
@@ -66,28 +74,40 @@ def __init__(self,
6674
"""Construct a Polynomial kernel instance.
6775
6876
Args:
69-
bias_variance: Non-negative floating point `Tensor` that controls the
70-
variance from the origin. If bias = 0, there is no variance and the
77+
bias_variance: Deprecated. Non-negative floating point `Tensor` that
78+
controls the variance from the origin. If bias = 0, there is no
79+
variance and the fitted function goes through the origin. Must be
80+
broadcastable with `slope_variance`, `shift`, `exponent`, and inputs
81+
to `apply` and `matrix` methods. A value of `None` is treated like 0.
82+
Default Value: `None`
83+
slope_variance: Deprecated. Non-negative floating point `Tensor` that
84+
controls the variance of the regression line slope that is the basis
85+
for the polynomial. Must be broadcastable with `bias_variance`, `shift`,
86+
`exponent`, and inputs to `apply` and `matrix` methods. A value of
87+
`None` is treated like 1.
88+
Default Value: `None`
89+
bias_amplitude: Non-negative floating point `Tensor` that controls the
90+
stddev from the origin. If bias = 0, there is no stddev and the
7191
fitted function goes through the origin. Must be broadcastable with
72-
`slope_variance`, `shift`, `exponent`, and inputs to `apply` and
92+
`slope_amplitude`, `shift`, `exponent`, and inputs to `apply` and
7393
`matrix` methods. A value of `None` is treated like 0.
7494
Default Value: `None`
75-
slope_variance: Non-negative floating point `Tensor` that controls the
76-
variance of the regression line slope that is the basis for the
77-
polynomial. Must be broadcastable with `bias_variance`, `shift`,
95+
slope_amplitude: Non-negative floating point `Tensor` that controls the
96+
stddev of the regression line slope that is the basis for the
97+
polynomial. Must be broadcastable with `bias_amplitude`, `shift`,
7898
`exponent`, and inputs to `apply` and `matrix` methods. A value of
7999
`None` is treated like 1.
80100
Default Value: `None`
81101
shift: Floating point `Tensor` that contols the intercept with the x-axis
82102
of the linear function to be exponentiated to get this polynomial. Must
83-
be broadcastable with `bias_variance`, `slope_variance`, `exponent` and
84-
inputs to `apply` and `matrix` methods. A value of `None` is treated
103+
be broadcastable with `bias_amplitude`, `slope_amplitude`, `exponent`
104+
and inputs to `apply` and `matrix` methods. A value of `None` is treated
85105
like 0, which results in having the intercept at the origin.
86106
Default Value: `None`
87107
exponent: Positive floating point `Tensor` that controls the exponent
88108
(also known as the degree) of the polynomial function, and must be an
89109
integer.
90-
Must be broadcastable with `bias_variance`, `slope_variance`, `shift`,
110+
Must be broadcastable with `bias_amplitude`, `slope_amplitude`, `shift`,
91111
and inputs to `apply` and `matrix` methods. A value of `None` is treated
92112
like 1, which results in a linear kernel.
93113
Default Value: `None`
@@ -104,11 +124,20 @@ def __init__(self,
104124
parameters = dict(locals()) if parameters is None else parameters
105125
with tf.name_scope(name):
106126
dtype = util.maybe_get_common_dtype(
107-
[bias_variance, slope_variance, shift, exponent])
127+
[bias_variance,
128+
slope_variance,
129+
bias_amplitude,
130+
slope_amplitude,
131+
shift,
132+
exponent])
108133
self._bias_variance = tensor_util.convert_nonref_to_tensor(
109134
bias_variance, name='bias_variance', dtype=dtype)
110135
self._slope_variance = tensor_util.convert_nonref_to_tensor(
111136
slope_variance, name='slope_variance', dtype=dtype)
137+
self._bias_amplitude = tensor_util.convert_nonref_to_tensor(
138+
bias_amplitude, name='bias_amplitude', dtype=dtype)
139+
self._slope_amplitude = tensor_util.convert_nonref_to_tensor(
140+
slope_amplitude, name='slope_amplitude', dtype=dtype)
112141
self._shift = tensor_util.convert_nonref_to_tensor(
113142
shift, name='shift', dtype=dtype)
114143
self._exponent = tensor_util.convert_nonref_to_tensor(
@@ -124,12 +153,18 @@ def __init__(self,
124153
def _parameter_properties(cls, dtype):
125154
from tensorflow_probability.python.bijectors import softplus # pylint:disable=g-import-not-at-top
126155
return dict(
156+
bias_amplitude=parameter_properties.ParameterProperties(
157+
default_constraining_bijector_fn=(
158+
lambda: softplus.Softplus(low=dtype_util.eps(dtype)))),
127159
bias_variance=parameter_properties.ParameterProperties(
128160
default_constraining_bijector_fn=(
129161
lambda: softplus.Softplus(low=dtype_util.eps(dtype)))),
130162
exponent=parameter_properties.ParameterProperties(
131163
default_constraining_bijector_fn=(
132164
lambda: softplus.Softplus(low=dtype_util.eps(dtype)))),
165+
slope_amplitude=parameter_properties.ParameterProperties(
166+
default_constraining_bijector_fn=(
167+
lambda: softplus.Softplus(low=dtype_util.eps(dtype)))),
133168
slope_variance=parameter_properties.ParameterProperties(
134169
default_constraining_bijector_fn=(
135170
lambda: softplus.Softplus(low=dtype_util.eps(dtype)))),
@@ -145,6 +180,16 @@ def slope_variance(self):
145180
"""Variance on slope parameter."""
146181
return self._slope_variance
147182

183+
@property
184+
def bias_amplitude(self):
185+
"""Stddev on bias parameter."""
186+
return self._bias_amplitude
187+
188+
@property
189+
def slope_amplitude(self):
190+
"""Amplitude on slope parameter."""
191+
return self._slope_amplitude
192+
148193
@property
149194
def shift(self):
150195
"""Shift of linear function that is exponentiated."""
@@ -155,6 +200,16 @@ def exponent(self):
155200
"""Exponent of the polynomial term."""
156201
return self._exponent
157202

203+
def _get_bias_amplitude(self):
204+
if self.bias_amplitude is not None:
205+
return self.bias_amplitude
206+
return self.bias_variance
207+
208+
def _get_slope_amplitude(self):
209+
if self.slope_amplitude is not None:
210+
return self.slope_amplitude
211+
return self.slope_variance
212+
158213
def _apply(self, x1, x2, example_ndims=0):
159214
if self.shift is None:
160215
dot_prod = util.sum_rightmost_ndims_preserving_shape(
@@ -169,17 +224,19 @@ def _apply(self, x1, x2, example_ndims=0):
169224
if self.exponent is not None:
170225
exponent = tf.convert_to_tensor(self.exponent)
171226
exponent = util.pad_shape_with_ones(exponent, example_ndims)
172-
dot_prod **= exponent
227+
dot_prod = dot_prod ** exponent
173228

174-
if self.slope_variance is not None:
175-
slope_variance = tf.convert_to_tensor(self.slope_variance)
176-
slope_variance = util.pad_shape_with_ones(slope_variance, example_ndims)
177-
dot_prod *= slope_variance**2.
229+
slope_amplitude = self._get_slope_amplitude()
230+
if slope_amplitude is not None:
231+
slope_amplitude = tf.convert_to_tensor(slope_amplitude)
232+
slope_amplitude = util.pad_shape_with_ones(slope_amplitude, example_ndims)
233+
dot_prod = dot_prod * slope_amplitude**2.
178234

179-
if self.bias_variance is not None:
180-
bias_variance = tf.convert_to_tensor(self.bias_variance)
181-
bias_variance = util.pad_shape_with_ones(bias_variance, example_ndims)
182-
dot_prod += bias_variance**2.
235+
bias_amplitude = self._get_bias_amplitude()
236+
if bias_amplitude is not None:
237+
bias_amplitude = tf.convert_to_tensor(bias_amplitude)
238+
bias_amplitude = util.pad_shape_with_ones(bias_amplitude, example_ndims)
239+
dot_prod = dot_prod + bias_amplitude**2.
183240

184241
return dot_prod
185242

@@ -190,8 +247,8 @@ def _parameter_control_dependencies(self, is_init):
190247
ok_to_check = lambda x: ( # pylint:disable=g-long-lambda
191248
x is not None) and (is_init != tensor_util.is_ref(x))
192249

193-
bias_variance = self.bias_variance
194-
slope_variance = self.slope_variance
250+
bias_amplitude = self._get_bias_amplitude()
251+
slope_amplitude = self._get_slope_amplitude()
195252

196253
if ok_to_check(self.exponent):
197254
exponent = tf.convert_to_tensor(self.exponent)
@@ -202,23 +259,23 @@ def _parameter_control_dependencies(self, is_init):
202259
assertions.append(
203260
distribution_util.assert_integer_form(
204261
exponent, message='`exponent` must be an integer.'))
205-
if ok_to_check(self.bias_variance):
206-
bias_variance = tf.convert_to_tensor(self.bias_variance)
262+
if ok_to_check(bias_amplitude):
263+
bias_amplitude = tf.convert_to_tensor(bias_amplitude)
207264
assertions.append(
208265
assert_util.assert_non_negative(
209-
bias_variance, message='`bias_variance` must be non-negative.'))
210-
if ok_to_check(self.slope_variance):
211-
slope_variance = tf.convert_to_tensor(self.slope_variance)
266+
bias_amplitude, message='`bias_amplitude` must be non-negative.'))
267+
if ok_to_check(slope_amplitude):
268+
slope_amplitude = tf.convert_to_tensor(slope_amplitude)
212269
assertions.append(
213270
assert_util.assert_non_negative(
214-
slope_variance,
215-
message='`slope_variance` must be non-negative.'))
271+
slope_amplitude,
272+
message='`slope_amplitude` must be non-negative.'))
216273

217-
if (ok_to_check(self.bias_variance) and ok_to_check(self.slope_variance)):
274+
if (ok_to_check(self.bias_amplitude) and ok_to_check(self.slope_amplitude)):
218275
assertions.append(
219276
assert_util.assert_positive(
220-
tf.math.abs(slope_variance) + tf.math.abs(bias_variance),
221-
message=('`slope_variance` and `bias_variance` '
277+
tf.math.abs(slope_amplitude) + tf.math.abs(bias_amplitude),
278+
message=('`slope_amplitude` and `bias_amplitude` '
222279
'can not both be zero.')))
223280

224281
return assertions
@@ -234,14 +291,21 @@ class Linear(Polynomial):
234291
exponent.
235292
236293
```none
237-
k(x, y) = bias_variance**2 + slope_variance**2 *
294+
k(x, y) = bias_amplitude**2 + slope_amplitude**2 *
238295
((x - shift) dot (y - shift))
239296
```
240297
"""
241298

299+
@deprecation.deprecated_args(
300+
'2022-11-01',
301+
'`bias_variance` and `slope_variance` are deprecated. Please use '
302+
'`bias_amplitude` and `slope_amplitude` instead.',
303+
'bias_variance', 'slope_variance')
242304
def __init__(self,
243305
bias_variance=None,
244306
slope_variance=None,
307+
bias_amplitude=None,
308+
slope_amplitude=None,
245309
shift=None,
246310
feature_ndims=1,
247311
validate_args=False,
@@ -261,11 +325,23 @@ def __init__(self,
261325
`bias_variance`, `shift`, and inputs to `apply` and `matrix` methods. A
262326
value of `None` is treated like 1.
263327
Default Value: `None`
328+
bias_amplitude: Non-negative floating point `Tensor` that controls the
329+
stddev from the origin. If bias = 0, there is no stddev and the
330+
fitted function goes through the origin. Must be broadcastable with
331+
`slope_amplitude`, `shift`, `exponent`, and inputs to `apply` and
332+
`matrix` methods. A value of `None` is treated like 0.
333+
Default Value: `None`
334+
slope_amplitude: Non-negative floating point `Tensor` that controls the
335+
stddev of the regression line slope that is the basis for the
336+
polynomial. Must be broadcastable with `bias_amplitude`, `shift`,
337+
`exponent`, and inputs to `apply` and `matrix` methods. A value of
338+
`None` is treated like 1.
339+
Default Value: `None`
264340
shift: Floating point `Tensor` that controls the intercept with the x-axis
265-
of the linear interpolation. Must be broadcastable with `bias_variance`,
266-
`slope_variance`, and inputs to `apply` and `matrix` methods. A value of
267-
`None` is treated like 0, which results in having the intercept at the
268-
origin.
341+
of the linear interpolation. Must be broadcastable with
342+
`bias_amplitude`, `slope_amplitude`, and inputs to `apply` and `matrix`
343+
methods. A value of `None` is treated like 0, which results in having
344+
the intercept at the origin.
269345
feature_ndims: Python `int` number of rightmost dims to include in kernel
270346
computation.
271347
Default Value: 1
@@ -280,6 +356,8 @@ def __init__(self,
280356
super(Linear, self).__init__(
281357
bias_variance=bias_variance,
282358
slope_variance=slope_variance,
359+
bias_amplitude=bias_amplitude,
360+
slope_amplitude=slope_amplitude,
283361
shift=shift,
284362
exponent=None,
285363
feature_ndims=feature_ndims,
@@ -291,9 +369,15 @@ def __init__(self,
291369
def _parameter_properties(cls, dtype):
292370
from tensorflow_probability.python.bijectors import softplus # pylint:disable=g-import-not-at-top
293371
return dict(
372+
bias_amplitude=parameter_properties.ParameterProperties(
373+
default_constraining_bijector_fn=(
374+
lambda: softplus.Softplus(low=dtype_util.eps(dtype)))),
294375
bias_variance=parameter_properties.ParameterProperties(
295376
default_constraining_bijector_fn=(
296377
lambda: softplus.Softplus(low=dtype_util.eps(dtype)))),
378+
slope_amplitude=parameter_properties.ParameterProperties(
379+
default_constraining_bijector_fn=(
380+
lambda: softplus.Softplus(low=dtype_util.eps(dtype)))),
297381
slope_variance=parameter_properties.ParameterProperties(
298382
default_constraining_bijector_fn=(
299383
lambda: softplus.Softplus(low=dtype_util.eps(dtype)))),
@@ -334,8 +418,8 @@ def __init__(self,
334418
constant, name='constant')
335419
from tensorflow_probability.python import util as tfp_util # pylint:disable=g-import-not-at-top
336420
super(Constant, self).__init__(
337-
bias_variance=tfp_util.DeferredTensor(self._constant, tf.math.sqrt),
338-
slope_variance=0.0,
421+
bias_amplitude=tfp_util.DeferredTensor(self._constant, tf.math.sqrt),
422+
slope_amplitude=0.0,
339423
shift=None,
340424
feature_ndims=feature_ndims,
341425
validate_args=validate_args,

0 commit comments

Comments
 (0)