Skip to content

Commit a6d3239

Browse files
srvasudetensorflower-gardener
authored andcommitted
Reduce size of mtgp JIT test. This is testing JIT is working (rather than broadcasting), so the shapes can be smaller.
PiperOrigin-RevId: 454703384
1 parent 4bf8811 commit a6d3239

File tree

1 file changed

+15
-16
lines changed

1 file changed

+15
-16
lines changed

tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_test.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -263,19 +263,18 @@ def testLogProbMatchesGPNoiseless(self):
263263
disable_numpy=True, disable_jax=False,
264264
reason='Jit not available in numpy.')
265265
def testJitMultitaskGaussianProcess(self):
266-
# 5x5 grid of index points in R^2 and flatten to 25x2
267-
index_points = np.linspace(-4., 4., 5, dtype=np.float32)
266+
# 3x3 grid of index points in R^2 and flatten to 9x2
267+
index_points = np.linspace(-4., 4., 3, dtype=np.float32)
268268
index_points = np.stack(np.meshgrid(index_points, index_points), axis=-1)
269269
index_points = np.reshape(index_points, [-1, 2])
270-
# ==> shape = [25, 2]
271-
272-
# Kernel with batch_shape [2, 4, 3, 1]
273-
amplitude = np.array([1., 2.], np.float32).reshape([2, 1, 1, 1])
274-
length_scale = np.array([1., 2., 3., 4.], np.float32).reshape([1, 4, 1, 1])
275-
observation_noise_variance = np.array(
276-
[1e-5, 1e-6, 1e-5], np.float32).reshape([1, 1, 3, 1])
277-
batched_index_points = np.stack([index_points]*6)
278-
# ==> shape = [6, 25, 2]
270+
# ==> shape = [9, 2]
271+
272+
# Kernel with batch_shape [2, 4, 3]
273+
amplitude = np.array([1., 2.], np.float32).reshape([2, 1,])
274+
length_scale = np.array([1., 2., 3., 4.], np.float32).reshape([1, 4,])
275+
observation_noise_variance = np.float32(1e-5)
276+
batched_index_points = np.stack([index_points]*4)
277+
# ==> shape = [4, 9, 2]
279278
kernel = tfk.ExponentiatedQuadratic(amplitude, length_scale)
280279
multi_task_kernel = tfe.psd_kernels.Independent(
281280
num_tasks=3, base_kernel=kernel)
@@ -294,9 +293,9 @@ def sample():
294293
return multitask_gp.sample(seed=test_util.test_seed())
295294

296295
observations = tf.convert_to_tensor(
297-
np.linspace(-20., 20., 75).reshape(25, 3).astype(np.float32))
298-
self.assertAllEqual(log_prob(observations).shape, [2, 4, 3, 6])
299-
self.assertAllEqual(sample().shape, [2, 4, 3, 6, 25, 3])
296+
np.linspace(-20., 20., 27).reshape(9, 3).astype(np.float32))
297+
self.assertAllEqual(log_prob(observations).shape, [2, 4])
298+
self.assertAllEqual(sample().shape, [2, 4, 9, 3])
300299

301300
multitask_gp = tfe.distributions.MultiTaskGaussianProcess(
302301
multi_task_kernel,
@@ -312,8 +311,8 @@ def log_prob_no_noise(o):
312311
def sample_no_noise():
313312
return multitask_gp.sample(seed=test_util.test_seed())
314313

315-
self.assertAllEqual(log_prob_no_noise(observations).shape, [2, 4, 1, 6])
316-
self.assertAllEqual(sample_no_noise().shape, [2, 4, 1, 6, 25, 3])
314+
self.assertAllEqual(log_prob_no_noise(observations).shape, [2, 4])
315+
self.assertAllEqual(sample_no_noise().shape, [2, 4, 9, 3])
317316

318317
def testMultiTaskBlockSeparable(self):
319318
# Check that the naive implementation matches any optimizations for a

0 commit comments

Comments
 (0)