@@ -263,19 +263,18 @@ def testLogProbMatchesGPNoiseless(self):
263
263
disable_numpy = True , disable_jax = False ,
264
264
reason = 'Jit not available in numpy.' )
265
265
def testJitMultitaskGaussianProcess (self ):
266
- # 5x5 grid of index points in R^2 and flatten to 25x2
267
- index_points = np .linspace (- 4. , 4. , 5 , dtype = np .float32 )
266
+ # 3x3 grid of index points in R^2 and flatten to 9x2
267
+ index_points = np .linspace (- 4. , 4. , 3 , dtype = np .float32 )
268
268
index_points = np .stack (np .meshgrid (index_points , index_points ), axis = - 1 )
269
269
index_points = np .reshape (index_points , [- 1 , 2 ])
270
- # ==> shape = [25, 2]
271
-
272
- # Kernel with batch_shape [2, 4, 3, 1]
273
- amplitude = np .array ([1. , 2. ], np .float32 ).reshape ([2 , 1 , 1 , 1 ])
274
- length_scale = np .array ([1. , 2. , 3. , 4. ], np .float32 ).reshape ([1 , 4 , 1 , 1 ])
275
- observation_noise_variance = np .array (
276
- [1e-5 , 1e-6 , 1e-5 ], np .float32 ).reshape ([1 , 1 , 3 , 1 ])
277
- batched_index_points = np .stack ([index_points ]* 6 )
278
- # ==> shape = [6, 25, 2]
270
+ # ==> shape = [9, 2]
271
+
272
+ # Kernel with batch_shape [2, 4, 3]
273
+ amplitude = np .array ([1. , 2. ], np .float32 ).reshape ([2 , 1 ,])
274
+ length_scale = np .array ([1. , 2. , 3. , 4. ], np .float32 ).reshape ([1 , 4 ,])
275
+ observation_noise_variance = np .float32 (1e-5 )
276
+ batched_index_points = np .stack ([index_points ]* 4 )
277
+ # ==> shape = [4, 9, 2]
279
278
kernel = tfk .ExponentiatedQuadratic (amplitude , length_scale )
280
279
multi_task_kernel = tfe .psd_kernels .Independent (
281
280
num_tasks = 3 , base_kernel = kernel )
@@ -294,9 +293,9 @@ def sample():
294
293
return multitask_gp .sample (seed = test_util .test_seed ())
295
294
296
295
observations = tf .convert_to_tensor (
297
- np .linspace (- 20. , 20. , 75 ).reshape (25 , 3 ).astype (np .float32 ))
298
- self .assertAllEqual (log_prob (observations ).shape , [2 , 4 , 3 , 6 ])
299
- self .assertAllEqual (sample ().shape , [2 , 4 , 3 , 6 , 25 , 3 ])
296
+ np .linspace (- 20. , 20. , 27 ).reshape (9 , 3 ).astype (np .float32 ))
297
+ self .assertAllEqual (log_prob (observations ).shape , [2 , 4 ])
298
+ self .assertAllEqual (sample ().shape , [2 , 4 , 9 , 3 ])
300
299
301
300
multitask_gp = tfe .distributions .MultiTaskGaussianProcess (
302
301
multi_task_kernel ,
@@ -312,8 +311,8 @@ def log_prob_no_noise(o):
312
311
def sample_no_noise ():
313
312
return multitask_gp .sample (seed = test_util .test_seed ())
314
313
315
- self .assertAllEqual (log_prob_no_noise (observations ).shape , [2 , 4 , 1 , 6 ])
316
- self .assertAllEqual (sample_no_noise ().shape , [2 , 4 , 1 , 6 , 25 , 3 ])
314
+ self .assertAllEqual (log_prob_no_noise (observations ).shape , [2 , 4 ])
315
+ self .assertAllEqual (sample_no_noise ().shape , [2 , 4 , 9 , 3 ])
317
316
318
317
def testMultiTaskBlockSeparable (self ):
319
318
# Check that the naive implementation matches any optimizations for a
0 commit comments