Skip to content

Commit 3f42739

Browse files
brianwa84jburnim
authored andcommitted
Fix batch slicing of precomputed GPRM.
PiperOrigin-RevId: 548144627
1 parent 6f62a06 commit 3f42739

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

tensorflow_probability/python/distributions/gaussian_process_regression_model.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from tensorflow_probability.python.internal import dtype_util
2626
from tensorflow_probability.python.internal import nest_util
2727
from tensorflow_probability.python.internal import parameter_properties
28+
from tensorflow_probability.python.internal import slicing
2829
from tensorflow_probability.python.internal import tensor_util
2930
from tensorflow_probability.python.math.psd_kernels import schur_complement
3031
from tensorflow.python.util import deprecation # pylint: disable=g-direct-tensorflow-import
@@ -819,6 +820,7 @@ def _event_ndims_fn(self):
819820
shape_fn=parameter_properties.SHAPE_FN_NOT_IMPLEMENTED,
820821
),
821822
kernel=parameter_properties.BatchedComponentProperties(),
823+
_conditional_kernel=parameter_properties.BatchedComponentProperties(),
822824
observation_noise_variance=parameter_properties.ParameterProperties(
823825
event_ndims=0,
824826
shape_fn=lambda sample_shape: sample_shape[:-1],
@@ -829,3 +831,8 @@ def _event_ndims_fn(self):
829831
shape_fn=lambda sample_shape: sample_shape[:-1],
830832
default_constraining_bijector_fn=(
831833
lambda: softplus_bijector.Softplus(low=dtype_util.eps(dtype)))))
834+
835+
def __getitem__(self, slices) -> 'GaussianProcessRegressionModel':
836+
# _conditional_mean_fn is a closure over possibly-sliced values, but will
837+
# be rebuilt by the constructor.
838+
return slicing.batch_slice(self, dict(_conditional_mean_fn=None), slices)

tensorflow_probability/python/distributions/gaussian_process_regression_model_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -689,6 +689,25 @@ def testPrivateArgPreventsCholeskyRecomputation(self):
689689
self.assertAllClose(d.log_prob(y_obs), d2.log_prob(y_obs))
690690
self.assertEqual(mock_cholesky_fn.call_count, 2)
691691

692+
def test_batch_slice_precomputed_gprm(self):
693+
base_kernel = exponentiated_quadratic.ExponentiatedQuadratic(
694+
length_scale=tf.linspace(tf.ones([]), 2., 64), feature_ndims=0)
695+
x = tf.linspace(tf.zeros([]), 1., 126)
696+
y = tf.linspace(tf.zeros([]), 1.5, 162)
697+
d = gprm.GaussianProcessRegressionModel.precompute_regression_model(
698+
base_kernel,
699+
index_points=y,
700+
observation_index_points=x,
701+
observations=tf.math.sin(x),
702+
observation_noise_variance=1e-3)
703+
self.assertEqual((64,), d.batch_shape)
704+
self.assertEqual((162,), d.event_shape)
705+
self.assertEqual((64, 162,), d.sample(seed=test_util.test_seed()).shape)
706+
707+
self.assertEqual((), d[2].batch_shape)
708+
self.assertEqual((162,), d[2].event_shape)
709+
self.assertEqual((162,), d[2].sample(seed=test_util.test_seed()).shape)
710+
692711

693712
class GaussianProcessRegressionModelStaticTest(
694713
_GaussianProcessRegressionModelTest):

0 commit comments

Comments
 (0)