
Commit 1d651d6

jburnim authored and tensorflower-gardener committed
Add an is_missing argument to GaussianProcess.log_prob.
When `is_missing` is passed, `GaussianProcess.log_prob` returns the log-probability of the marginal distribution in which each event dimension for which `is_missing` is `True` is marginalized out.

PiperOrigin-RevId: 427487729
1 parent 93e0365 commit 1d651d6

6 files changed (+183, -55 lines)

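For illustration, a minimal usage sketch of the new argument (the kernel parameters, index points, and names like `pts` and `gp_sub` are invented for this example; the `is_missing` forwarding through `log_prob` is exercised by the tests below):

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
psd_kernels = tfp.math.psd_kernels

kernel = psd_kernels.ExponentiatedQuadratic(amplitude=1.1, length_scale=0.9)
pts = np.linspace(-1., 1., 4, dtype=np.float32)[..., np.newaxis]

gp = tfd.GaussianProcess(
    kernel=kernel, index_points=pts, observation_noise_variance=.05)
x = gp.sample()

# Marginalize out the second index point; its entry in `x` is ignored
# (the tests below even pass NaN there).
is_missing = np.array([False, True, False, False])
lp = gp.log_prob(x, is_missing=is_missing)

# Equivalent to the log-prob under a GP built only on the observed points:
gp_sub = tfd.GaussianProcess(
    kernel=kernel, index_points=pts[~is_missing],
    observation_noise_variance=.05)
lp_sub = gp_sub.log_prob(tf.boolean_mask(x, ~is_missing))  # == lp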

tensorflow_probability/python/distributions/BUILD

Lines changed: 2 additions & 0 deletions

@@ -724,6 +724,7 @@ multi_substrate_py_library(
         ":kullback_leibler",
         ":mvn_linear_operator",
         ":normal",
+        # numpy dep,
         # tensorflow dep,
         "//tensorflow_probability/python/bijectors:identity",
         "//tensorflow_probability/python/internal:distribution_util",
@@ -732,6 +733,7 @@ multi_substrate_py_library(
         "//tensorflow_probability/python/internal:reparameterization",
         "//tensorflow_probability/python/internal:tensor_util",
         "//tensorflow_probability/python/internal:tensorshape_util",
+        "//tensorflow_probability/python/math/psd_kernels/internal:util",
     ],
 )

tensorflow_probability/python/distributions/gaussian_process.py

Lines changed: 80 additions & 25 deletions

@@ -17,6 +17,7 @@
 import warnings
 
 # Dependency imports
+import numpy as np
 import tensorflow.compat.v2 as tf
 
 from tensorflow_probability.python.bijectors import identity as identity_bijector
@@ -33,6 +34,7 @@
 from tensorflow_probability.python.internal import reparameterization
 from tensorflow_probability.python.internal import tensor_util
 from tensorflow_probability.python.internal import tensorshape_util
+from tensorflow_probability.python.math.psd_kernels.internal import util as psd_kernels_util
 from tensorflow.python.util import deprecation  # pylint: disable=g-direct-tensorflow-import
 
 __all__ = [
@@ -425,30 +427,44 @@ def get_marginal_distribution(self, index_points=None):
         points, respectively.
     """
     with self._name_and_control_scope('get_marginal_distribution'):
-      # TODO(cgs): consider caching the result here, keyed on `index_points`.
-      index_points = self._get_index_points(index_points)
-      covariance = self._compute_covariance(index_points)
-      loc = self._mean_fn(index_points)
-      # If we're sure the number of index points is 1, we can just construct a
-      # scalar Normal. This has computational benefits and supports things like
-      # CDF that aren't otherwise straightforward to provide.
-      if self._is_univariate_marginal(index_points):
-        scale = tf.sqrt(covariance)
-        # `loc` has a trailing 1 in the shape; squeeze it.
-        loc = tf.squeeze(loc, axis=-1)
-        return normal.Normal(
-            loc=loc,
-            scale=scale,
-            validate_args=self._validate_args,
-            allow_nan_stats=self._allow_nan_stats,
-            name='marginal_distribution')
+      return self._get_marginal_distribution(index_points=index_points)
+
+  def _get_marginal_distribution(self, index_points=None, is_missing=None):
+    # TODO(cgs): consider caching the result here, keyed on `index_points`.
+    index_points = self._get_index_points(index_points)
+    covariance = self._compute_covariance(index_points)
+    is_univariate_marginal = self._is_univariate_marginal(index_points)
+
+    loc = self._mean_fn(index_points)
+    if is_univariate_marginal:
+      # `loc` has a trailing 1 in the shape; squeeze it.
+      loc = tf.squeeze(loc, axis=-1)
+
+    if is_missing is not None:
+      loc = tf.where(is_missing, 0., loc)
+      if is_univariate_marginal:
+        covariance = tf.where(is_missing, 1., covariance)
       else:
-        return self._marginal_fn(
-            loc=loc,
-            covariance=covariance,
-            validate_args=self._validate_args,
-            allow_nan_stats=self._allow_nan_stats,
-            name='marginal_distribution')
+        covariance = psd_kernels_util.mask_matrix(covariance, ~is_missing)  # pylint:disable=invalid-unary-operand-type
+
+    # If we're sure the number of index points is 1, we can just construct a
+    # scalar Normal. This has computational benefits and supports things like
+    # CDF that aren't otherwise straightforward to provide.
+    if is_univariate_marginal:
+      scale = tf.sqrt(covariance)
+      return normal.Normal(
+          loc=loc,
+          scale=scale,
+          validate_args=self._validate_args,
+          allow_nan_stats=self._allow_nan_stats,
+          name='marginal_distribution')
+    else:
+      return self._marginal_fn(
+          loc=loc,
+          covariance=covariance,
+          validate_args=self._validate_args,
+          allow_nan_stats=self._allow_nan_stats,
+          name='marginal_distribution')
 
   @property
   def mean_fn(self):
@@ -524,8 +540,47 @@ def _get_index_points(self, index_points=None):
     return tf.convert_to_tensor(
         index_points if index_points is not None else self._index_points)
 
-  def _log_prob(self, value, index_points=None):
-    return self.get_marginal_distribution(index_points).log_prob(value)
+  @distribution_util.AppendDocstring(kwargs_dict={
+      'index_points':
+          'optional `float` `Tensor` representing a finite (batch of) '
+          'points in the index set over which this GP is defined. The shape '
+          'has the form `[b1, ..., bB, e, f1, ..., fF]` where `F` is the '
+          'number of feature dimensions and must equal '
+          '`self.kernel.feature_ndims` and `e` is the number of index points '
+          'in each batch. Ultimately, this distribution corresponds to an '
+          '`e`-dimensional multivariate normal. The batch shape must be '
+          'broadcastable with `kernel.batch_shape` and any batch dims yielded '
+          'by `mean_fn`. If not specified, `self.index_points` is used. '
+          'Default value: `None`.',
+      'is_missing':
+          'optional `bool` `Tensor` of shape `[..., e]`, where `e` is the '
+          'number of index points in each batch. Represents a batch of '
+          'Boolean masks. When `is_missing` is not `None`, the returned '
+          'log-prob is for the *marginal* distribution, in which all '
+          'dimensions for which `is_missing` is `True` have been marginalized '
+          'out. The batch dimensions of `is_missing` must broadcast with the '
+          'sample and batch dimensions of `value` and of this `Distribution`. '
+          'Default value: `None`.'
+  })
+  def _log_prob(self, value, index_points=None, is_missing=None):
+    if is_missing is not None:
+      is_missing = tf.convert_to_tensor(is_missing)
+    index_points = self._get_index_points(index_points)
+    mvn = self._get_marginal_distribution(index_points, is_missing=is_missing)
+    if is_missing is None:
+      return mvn.log_prob(value)
+
+    # Subtract out the Normal distribution's log normalizer for each dimension
+    # that is masked out.
+    lp = mvn.log_prob(tf.where(is_missing, 0., value))
+    num_masked_dims = tf.cast(is_missing, mvn.dtype)
+    if not self._is_univariate_marginal(index_points):
+      event_shape = self._event_shape_tensor(index_points=index_points)
+      num_masked_dims = tf.reduce_sum(
+          num_masked_dims * tf.ones(event_shape, dtype=mvn.dtype),
+          axis=-1)
+    correction = num_masked_dims * -0.5 * np.log(2. * np.pi)
+    return lp - correction
 
   def _event_shape_tensor(self, index_points=None):
     index_points = self._get_index_points(index_points)
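To see why the correction in `_log_prob` is exactly right: masking the covariance to the identity on missing rows/cols (via `mask_matrix`) and zeroing the missing `loc` and `value` entries turns each missing dimension into an independent standard normal, so the joint log-prob exceeds the desired marginal log-prob by one standard-normal log-normalizer, 0.5 * log(2 * pi), per missing dimension. A self-contained NumPy/SciPy sketch of that identity (illustrative only; the covariance below is an arbitrary PSD matrix, not one built from a kernel):

import numpy as np
from scipy.stats import multivariate_normal

rng = np.random.default_rng(0)
n = 4
a = rng.normal(size=(n, n))
cov = a @ a.T + n * np.eye(n)      # an arbitrary well-conditioned covariance
loc = rng.normal(size=n)
x = rng.normal(size=n)
is_missing = np.array([False, True, False, True])
keep = ~is_missing

# Ground truth: log-prob of the marginal over the kept dimensions.
expected = multivariate_normal(
    loc[keep], cov[np.ix_(keep, keep)]).logpdf(x[keep])

# The commit's approach: mask the covariance to the identity on missing
# rows/cols, zero the missing loc/value entries, then subtract one standard
# normal log-normalizer per missing dimension.
masked_cov = np.where(np.outer(keep, keep), cov, np.eye(n))
lp = multivariate_normal(
    np.where(keep, loc, 0.), masked_cov).logpdf(np.where(keep, x, 0.))
lp -= is_missing.sum() * (-0.5) * np.log(2. * np.pi)

np.testing.assert_allclose(lp, expected)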

tensorflow_probability/python/distributions/gaussian_process_test.py

Lines changed: 71 additions & 0 deletions

@@ -330,6 +330,77 @@ def testGPPosteriorPredictive(self):
         self.evaluate(expected_gprm.log_prob(samples)),
         self.evaluate(actual_gprm.log_prob(samples)))
 
+  def testLogProbWithIsMissing(self):
+    index_points = tf.Variable(
+        [[-1.0, 0.0], [-0.5, -0.5], [1.5, 0.0], [1.6, 1.5]],
+        shape=None if self.is_static else tf.TensorShape(None))
+    self.evaluate(index_points.initializer)
+    amplitude = tf.convert_to_tensor(1.1)
+    length_scale = tf.convert_to_tensor(0.9)
+
+    gp = tfd.GaussianProcess(
+        kernel=psd_kernels.ExponentiatedQuadratic(amplitude, length_scale),
+        index_points=index_points,
+        mean_fn=lambda x: tf.reduce_mean(x, axis=-1),
+        observation_noise_variance=.05,
+        jitter=0.0)
+
+    x = gp.sample(5, seed=test_util.test_seed())
+
+    is_missing = np.array([
+        [False, True, False, False],
+        [False, False, False, False],
+        [True, False, True, True],
+        [True, False, False, True],
+        [False, False, True, True],
+    ])
+
+    lp = gp.log_prob(tf.where(is_missing, np.nan, x), is_missing=is_missing)
+
+    # For each batch member, check that the log_prob is the same as for a
+    # GaussianProcess without the missing index points.
+    for i in range(5):
+      gp_i = tfd.GaussianProcess(
+          kernel=psd_kernels.ExponentiatedQuadratic(amplitude, length_scale),
+          index_points=tf.gather(index_points, (~is_missing[i]).nonzero()[0]),
+          mean_fn=lambda x: tf.reduce_mean(x, axis=-1),
+          observation_noise_variance=.05,
+          jitter=0.0)
+      lp_i = gp_i.log_prob(tf.gather(x[i], (~is_missing[i]).nonzero()[0]))
+      # NOTE: This reshape is necessary because lp_i has shape [1] when
+      # gp_i.index_points contains a single index point.
+      self.assertAllClose(tf.reshape(lp_i, []), lp[i])
+
+    # The log_prob should be zero when all points are missing.
+    self.assertAllClose(
+        tf.zeros((3, 2)),
+        gp.log_prob(tf.ones((3, 1, 4)) * np.nan,
+                    is_missing=tf.constant(True, shape=(2, 4))))
+
+  def testUnivariateLogProbWithIsMissing(self):
+    index_points = tf.convert_to_tensor([[[0.0, 0.0]], [[0.5, 1.0]]])
+    amplitude = tf.convert_to_tensor(1.1)
+    length_scale = tf.convert_to_tensor(0.9)
+
+    gp = tfd.GaussianProcess(
+        kernel=psd_kernels.ExponentiatedQuadratic(amplitude, length_scale),
+        index_points=index_points,
+        mean_fn=lambda x: tf.reduce_mean(x, axis=-1),
+        observation_noise_variance=.05,
+        jitter=0.0)
+
+    x = gp.sample(3, seed=test_util.test_seed())
+    lp = gp.log_prob(x)
+
+    self.assertAllClose(lp, gp.log_prob(x, is_missing=[False, False]))
+    self.assertAllClose(
+        tf.convert_to_tensor([np.zeros((3, 2)), lp]),
+        gp.log_prob(x, is_missing=[[[True]], [[False]]]))
+    self.assertAllClose(
+        tf.convert_to_tensor([[lp[0, 0], 0.0], [0.0, 0.0], [0., lp[2, 1]]]),
+        gp.log_prob(x, is_missing=[[False, True], [True, True], [True, False]]))
+
 
 @test_util.test_all_tf_execution_regimes
 class GaussianProcessStaticTest(_GaussianProcessTest, test_util.TestCase):

tensorflow_probability/python/math/psd_kernels/BUILD

Lines changed: 1 addition & 1 deletion

@@ -281,8 +281,8 @@ multi_substrate_py_library(
         "//tensorflow_probability/python/distributions:cholesky_util",
         "//tensorflow_probability/python/internal:dtype_util",
         "//tensorflow_probability/python/internal:parameter_properties",
-        "//tensorflow_probability/python/internal:prefer_static",
         "//tensorflow_probability/python/internal:tensorshape_util",
+        "//tensorflow_probability/python/math/psd_kernels/internal:util",
     ],
 )

tensorflow_probability/python/math/psd_kernels/internal/util.py

Lines changed: 27 additions & 0 deletions

@@ -22,6 +22,7 @@
 from tensorflow_probability.python.internal import tensorshape_util
 
 __all__ = [
+    'mask_matrix',
     'maybe_get_common_dtype',
     'pad_shape_with_ones',
     'pairwise_square_distance_matrix',
@@ -284,3 +285,29 @@ def pairwise_square_distance_tensor(
   # Now we need to undo the transformation.
   return tf.reshape(pairwise, tf.concat([
       tf.shape(pairwise)[:-2], x1_example_shape, x2_example_shape], axis=0))
+
+
+def mask_matrix(x, mask=None):
+  """Copies a matrix, replacing masked-out rows/cols from the identity matrix.
+
+  Args:
+    x: A Tensor of shape `[..., n, n]`, representing a batch of n-by-n
+      matrices.
+    mask: A boolean Tensor of shape `[..., n]`, representing a batch of masks.
+      If `mask` is None, `x` is returned.
+
+  Returns:
+    A Tensor of shape `[..., n, n]`, representing a batch of n-by-n matrices.
+    For each batch member `r`, element `r[i, j]` equals `eye(n)[i, j]` if
+    dimension `i` or `j` is False in the corresponding input mask. Otherwise,
+    `r[i, j]` equals the corresponding element from `x`.
+  """
+  if mask is None:
+    return x
+
+  x = tf.convert_to_tensor(x)
+  mask = tf.convert_to_tensor(mask, dtype=tf.bool)
+
+  n = ps.dimension_size(x, -1)
+
+  return tf.where(~mask[..., tf.newaxis] | ~mask[..., tf.newaxis, :],
+                  tf.eye(n, dtype=x.dtype),
+                  x)
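For concreteness, here is the masking semantics on a single 3x3 matrix, written with plain TF ops rather than the internal helper (a sketch of the same `tf.where` construction with `n` fixed at 3; the matrix values are made up):

import tensorflow as tf

x = tf.constant([[2., 5., 7.],
                 [5., 3., 9.],
                 [7., 9., 4.]])
mask = tf.constant([True, False, True])

# Any entry whose row or column is masked out comes from eye(3) instead.
out = tf.where(~mask[..., tf.newaxis] | ~mask[..., tf.newaxis, :],
               tf.eye(3), x)
# out ==
# [[2., 0., 7.],
#  [0., 1., 0.],
#  [7., 0., 4.]]

Because the masked rows/cols are exactly identity rows/cols, the masked matrix's log-determinant and Cholesky factor agree on the kept indices with those of the submatrix that drops the masked dimensions; this is what makes both the GP marginalization above and the `SchurComplement` divisor masking below work.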

tensorflow_probability/python/math/psd_kernels/schur_complement.py

Lines changed: 2 additions & 29 deletions

@@ -18,7 +18,6 @@
 from tensorflow_probability.python.internal import distribution_util
 from tensorflow_probability.python.internal import dtype_util
 from tensorflow_probability.python.internal import parameter_properties
-from tensorflow_probability.python.internal import prefer_static as ps
 from tensorflow_probability.python.internal import tensor_util
 from tensorflow_probability.python.math.psd_kernels import positive_semidefinite_kernel as psd_kernel
 from tensorflow_probability.python.math.psd_kernels.internal import util
@@ -58,32 +57,6 @@ def _compute_divisor_matrix(
   return divisor_matrix
 
 
-def _mask_matrix(x, mask=None):
-  """Copies a matrix, replacing masked-out rows/cols from the identity matrix.
-
-  Args:
-    x: A Tensor of shape `[..., n, n]`, representing a batch of n-by-n
-      matrices.
-    mask: A boolean Tensor of shape `[..., n]`, representing a batch of masks.
-      If `mask` is None, `x` is returned.
-
-  Returns:
-    A Tensor of shape `[..., n, n]`, representing a batch of n-by-n matrices.
-    For each batch member `r`, element `r[i, j]` equals `eye(n)[i, j]` if
-    dimension `i` or `j` is False in the corresponding input mask. Otherwise,
-    `r[i, j]` equals the corresponding element from `x`.
-  """
-  if mask is None:
-    return x
-
-  x = tf.convert_to_tensor(x)
-  mask = tf.convert_to_tensor(mask, dtype=tf.bool)
-
-  n = ps.dimension_size(x, -1)
-
-  return tf.where(~mask[..., tf.newaxis] | ~mask[..., tf.newaxis, :],
-                  tf.eye(n, dtype=x.dtype),
-                  x)
-
-
 class SchurComplement(psd_kernel.AutoCompositeTensorPsdKernel):
   """The SchurComplement kernel.
 
@@ -363,7 +336,7 @@ def with_precomputed_divisor(
 
     # TODO(b/196219597): Add a check to ensure that we have a `base_kernel`
     # that is explicitly concretized.
-    divisor_matrix_cholesky = cholesky_fn(_mask_matrix(
+    divisor_matrix_cholesky = cholesky_fn(util.mask_matrix(
         _compute_divisor_matrix(base_kernel,
                                 diag_shift=diag_shift,
                                 fixed_inputs=fixed_inputs),
@@ -529,7 +502,7 @@ def _divisor_matrix(self, fixed_inputs=None, fixed_inputs_mask=None):
     # NOTE: Replacing masked-out rows/columns of the divisor matrix with
     # rows/columns from the identity matrix is equivalent to using a divisor
     # matrix in which those rows and columns have been dropped.
-    return _mask_matrix(
+    return util.mask_matrix(
         _compute_divisor_matrix(self._base_kernel,
                                 diag_shift=self._diag_shift,
                                 fixed_inputs=fixed_inputs),
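A small NumPy check of the equivalence asserted in the NOTE above (illustrative only; a random PSD matrix stands in for the divisor matrix): the Cholesky factor of the masked matrix, restricted to the kept rows/columns, matches the Cholesky factor of the submatrix with the masked rows/columns dropped.

import numpy as np

rng = np.random.default_rng(1)
a = rng.normal(size=(4, 4))
m = a @ a.T + 4. * np.eye(4)       # stand-in for a divisor matrix
keep = np.array([True, False, True, True])

# Mask row/col 1 to the identity, as util.mask_matrix does.
masked = np.where(np.outer(keep, keep), m, np.eye(4))
np.testing.assert_allclose(
    np.linalg.cholesky(masked)[np.ix_(keep, keep)],
    np.linalg.cholesky(m[np.ix_(keep, keep)]))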
