@@ -596,6 +596,53 @@ def _mode(self, index_points=None):
596
596
def _default_event_space_bijector (self ):
597
597
return identity_bijector .Identity (validate_args = self .validate_args )
598
598
599
+ def posterior_predictive (self , observations , predictive_index_points = None ):
600
+ """Return the posterior predictive distribution associated with this distribution.
601
+
602
+ Given `predictive_index_points` and `observations`, return the posterior
603
+ predictive distribution on the `predictive_index_points` conditioned on
604
+ `index_points` and `observations` associated to `index_points`.
605
+
606
+ This is equivalent to using the `GaussianProcessRegressionModel` class.
607
+
608
+ Args:
609
+ observations: `float` `Tensor` representing collection, or batch of
610
+ collections, of observations corresponding to
611
+ `self.index_points`. Shape has the form `[b1, ..., bB, e]`, which
612
+ must be broadcastable with the batch and example shapes of
613
+ `self.index_points`. The batch shape `[b1, ..., bB]` must be
614
+ broadcastable with the shapes of all other batched parameters
615
+ predictive_index_points: `float` `Tensor` representing finite collection,
616
+ or batch of collections, of points in the index set over which the GP
617
+ is defined.
618
+ Shape has the form `[b1, ..., bB, e, f1, ..., fF]` where `F` is the
619
+ number of feature dimensions and must equal `kernel.feature_ndims` and
620
+ `e` is the number (size) of predictive index points in each batch.
621
+ The batch shape must be broadcastable with this distributions
622
+ `batch_shape`.
623
+ Default value: `None`.
624
+
625
+ Returns:
626
+ gprm: An instance of `Distribution` that represents the posterior
627
+ predictive.
628
+ """
629
+ from tensorflow_probability .python .distributions import gaussian_process_regression_model as gprm # pylint:disable=g-import-not-at-top
630
+ if self .index_points is None :
631
+ raise ValueError (
632
+ 'Expected that `self.index_points` is not `None`. Using '
633
+ '`self.index_points=None` is equivalent to using a `GaussianProcess` '
634
+ 'prior, which this class encapsulates.' )
635
+ return gprm .GaussianProcessRegressionModel (
636
+ kernel = self .kernel ,
637
+ observation_index_points = self .index_points ,
638
+ observations = observations ,
639
+ index_points = predictive_index_points ,
640
+ observation_noise_variance = self .observation_noise_variance ,
641
+ mean_fn = self .mean_fn ,
642
+ jitter = self .jitter ,
643
+ validate_args = self .validate_args ,
644
+ allow_nan_stats = self .allow_nan_stats )
645
+
599
646
600
647
def _assert_kl_compatible (marginal , other ):
601
648
if ((isinstance (marginal , normal .Normal ) and
0 commit comments