Skip to content

Commit 18c5c9a

Browse files
srvasudetensorflower-gardener
authored andcommitted
Add gp.posterior_predictive for constructing a posterior predictive from a GP.
This is equivalent (and a wrapper) to calling GaussianProcessRegressionModel. PiperOrigin-RevId: 384605668
1 parent 878f096 commit 18c5c9a

File tree

2 files changed

+93
-0
lines changed

2 files changed

+93
-0
lines changed

tensorflow_probability/python/distributions/gaussian_process.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,53 @@ def _mode(self, index_points=None):
596596
def _default_event_space_bijector(self):
597597
return identity_bijector.Identity(validate_args=self.validate_args)
598598

599+
def posterior_predictive(self, observations, predictive_index_points=None):
600+
"""Return the posterior predictive distribution associated with this distribution.
601+
602+
Given `predictive_index_points` and `observations`, return the posterior
603+
predictive distribution on the `predictive_index_points` conditioned on
604+
`index_points` and `observations` associated to `index_points`.
605+
606+
This is equivalent to using the `GaussianProcessRegressionModel` class.
607+
608+
Args:
609+
observations: `float` `Tensor` representing collection, or batch of
610+
collections, of observations corresponding to
611+
`self.index_points`. Shape has the form `[b1, ..., bB, e]`, which
612+
must be broadcastable with the batch and example shapes of
613+
`self.index_points`. The batch shape `[b1, ..., bB]` must be
614+
broadcastable with the shapes of all other batched parameters
615+
predictive_index_points: `float` `Tensor` representing finite collection,
616+
or batch of collections, of points in the index set over which the GP
617+
is defined.
618+
Shape has the form `[b1, ..., bB, e, f1, ..., fF]` where `F` is the
619+
number of feature dimensions and must equal `kernel.feature_ndims` and
620+
`e` is the number (size) of predictive index points in each batch.
621+
The batch shape must be broadcastable with this distributions
622+
`batch_shape`.
623+
Default value: `None`.
624+
625+
Returns:
626+
gprm: An instance of `Distribution` that represents the posterior
627+
predictive.
628+
"""
629+
from tensorflow_probability.python.distributions import gaussian_process_regression_model as gprm # pylint:disable=g-import-not-at-top
630+
if self.index_points is None:
631+
raise ValueError(
632+
'Expected that `self.index_points` is not `None`. Using '
633+
'`self.index_points=None` is equivalent to using a `GaussianProcess` '
634+
'prior, which this class encapsulates.')
635+
return gprm.GaussianProcessRegressionModel(
636+
kernel=self.kernel,
637+
observation_index_points=self.index_points,
638+
observations=observations,
639+
index_points=predictive_index_points,
640+
observation_noise_variance=self.observation_noise_variance,
641+
mean_fn=self.mean_fn,
642+
jitter=self.jitter,
643+
validate_args=self.validate_args,
644+
allow_nan_stats=self.allow_nan_stats)
645+
599646

600647
def _assert_kl_compatible(marginal, other):
601648
if ((isinstance(marginal, normal.Normal) and

tensorflow_probability/python/distributions/gaussian_process_test.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,52 @@ def test_marginal_fn(
257257
np.eye(10),
258258
gp.get_marginal_distribution().covariance())
259259

260+
def testGPPosteriorPredictive(self):
261+
amplitude = np.float64(.5)
262+
length_scale = np.float64(2.)
263+
jitter = np.float64(1e-4)
264+
observation_noise_variance = np.float64(3e-3)
265+
kernel = psd_kernels.ExponentiatedQuadratic(amplitude, length_scale)
266+
267+
index_points = np.random.uniform(-1., 1., 10)[..., np.newaxis]
268+
269+
gp = tfd.GaussianProcess(
270+
kernel,
271+
index_points,
272+
observation_noise_variance=observation_noise_variance,
273+
jitter=jitter,
274+
validate_args=True)
275+
276+
predictive_index_points = np.random.uniform(1., 2., 10)[..., np.newaxis]
277+
observations = np.linspace(1., 10., 10)
278+
279+
expected_gprm = tfd.GaussianProcessRegressionModel(
280+
kernel=kernel,
281+
observation_index_points=index_points,
282+
observations=observations,
283+
observation_noise_variance=observation_noise_variance,
284+
jitter=jitter,
285+
index_points=predictive_index_points,
286+
validate_args=True)
287+
288+
actual_gprm = gp.posterior_predictive(
289+
predictive_index_points=predictive_index_points,
290+
observations=observations)
291+
292+
samples = self.evaluate(actual_gprm.sample(10, seed=test_util.test_seed()))
293+
294+
self.assertAllClose(
295+
self.evaluate(expected_gprm.log_prob(samples)),
296+
self.evaluate(actual_gprm.log_prob(samples)))
297+
298+
self.assertAllClose(
299+
self.evaluate(expected_gprm.mean()),
300+
self.evaluate(actual_gprm.mean()))
301+
302+
self.assertAllClose(
303+
self.evaluate(expected_gprm.covariance()),
304+
self.evaluate(actual_gprm.covariance()))
305+
260306

261307
@test_util.test_all_tf_execution_regimes
262308
class GaussianProcessStaticTest(_GaussianProcessTest, test_util.TestCase):

0 commit comments

Comments
 (0)