Skip to content

Commit 8fa9dcb

Browse files
srvasudetensorflower-gardener
authored andcommitted
Add errors for MTGP/MTGPRM when observations are ill-formed.
PiperOrigin-RevId: 473825902
1 parent 1a5bd75 commit 8fa9dcb

File tree

5 files changed

+145
-0
lines changed

5 files changed

+145
-0
lines changed

tensorflow_probability/python/experimental/distributions/BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,8 @@ multi_substrate_py_library(
187187
"//tensorflow_probability/python/experimental/linalg:linear_operator_unitary",
188188
"//tensorflow_probability/python/experimental/psd_kernels:multitask_kernel",
189189
"//tensorflow_probability/python/internal:dtype_util",
190+
"//tensorflow_probability/python/internal:tensor_util",
191+
"//tensorflow_probability/python/internal:tensorshape_util",
190192
],
191193
)
192194

@@ -216,6 +218,8 @@ multi_substrate_py_library(
216218
# tensorflow dep,
217219
"//tensorflow_probability/python/distributions:cholesky_util",
218220
"//tensorflow_probability/python/internal:dtype_util",
221+
"//tensorflow_probability/python/internal:tensor_util",
222+
"//tensorflow_probability/python/internal:tensorshape_util",
219223
"//tensorflow_probability/python/math/psd_kernels/internal:util",
220224
],
221225
)

tensorflow_probability/python/experimental/distributions/multitask_gaussian_process.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from tensorflow_probability.python.internal import prefer_static as ps
3333
from tensorflow_probability.python.internal import reparameterization
3434
from tensorflow_probability.python.internal import tensor_util
35+
from tensorflow_probability.python.internal import tensorshape_util
3536

3637

3738
def _vec(x):
@@ -447,7 +448,45 @@ def _get_index_points(self, index_points=None):
447448
return tf.convert_to_tensor(
448449
index_points if index_points is not None else self._index_points)
449450

451+
def _check_observations_valid(self, observations, index_points):
452+
observation_rank = tensorshape_util.rank(observations.shape)
453+
454+
if observation_rank is None:
455+
return
456+
457+
if observation_rank >= 1:
458+
# Check that the last dimension of observations matches the number of
459+
# tasks.
460+
num_observations = tf.compat.dimension_value(observations.shape[-1])
461+
if (num_observations is not None and
462+
num_observations != 1 and
463+
num_observations != self.kernel.num_tasks):
464+
raise ValueError(
465+
f'Expected the number of observations {num_observations} '
466+
f'to broadcast / match the number of tasks '
467+
f'{self.kernel.num_tasks}')
468+
469+
if observation_rank >= 2:
470+
num_index_points = tf.compat.dimension_value(observations.shape[-2])
471+
472+
expected_num_index_points = index_points.shape[
473+
-(self.kernel.feature_ndims + 1)]
474+
if (num_index_points is not None and
475+
expected_num_index_points is not None and
476+
num_index_points != 1 and
477+
num_index_points != expected_num_index_points):
478+
raise ValueError(
479+
f'Expected number of index points '
480+
f'{expected_num_index_points} to broadcast / match the second '
481+
f'to last dimension of `observations` {num_index_points}')
482+
450483
def _log_prob(self, value, index_points=None):
484+
# Check that observations with at least 2 dimensions have
485+
# shape that's broadcastable to `[N, T]`, where `N` is the number
486+
# of index points, and T the number of tasks.
487+
index_points = self._get_index_points(index_points)
488+
self._check_observations_valid(value, index_points)
489+
451490
return self._get_flattened_marginal_distribution(
452491
index_points=index_points).log_prob(_vec(value))
453492

tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,8 @@ def _mean_fn(x):
286286
self._observations = observations
287287
self._observations_is_missing = observations_is_missing
288288

289+
self._check_observations_valid(observations)
290+
289291
if _flattened_conditional_mean_fn is None:
290292

291293
def flattened_conditional_mean_fn(x):
@@ -319,6 +321,38 @@ def flattened_conditional_mean_fn(x):
319321
parameters=parameters,
320322
name=name)
321323

324+
def _check_observations_valid(self, observations):
325+
observation_rank = tensorshape_util.rank(observations.shape)
326+
327+
if observation_rank is None:
328+
return
329+
330+
if observation_rank >= 1:
331+
# Check that the last dimension of observations matches the number of
332+
# tasks.
333+
num_observations = tf.compat.dimension_value(observations.shape[-1])
334+
if (num_observations is not None and
335+
num_observations != 1 and
336+
num_observations != self.kernel.num_tasks):
337+
raise ValueError(
338+
f'Expected the number of observations {num_observations} '
339+
f'to broadcast / match the number of tasks '
340+
f'{self.kernel.num_tasks}')
341+
342+
if observation_rank >= 2:
343+
num_index_points = tf.compat.dimension_value(observations.shape[-2])
344+
345+
expected_num_index_points = self.observation_index_points.shape[
346+
-(self.kernel.feature_ndims + 1)]
347+
if (num_index_points is not None and
348+
expected_num_index_points is not None and
349+
num_index_points != 1 and
350+
num_index_points != expected_num_index_points):
351+
raise ValueError(
352+
f'Expected number of observation index points '
353+
f'{expected_num_index_points} to broadcast / match the second '
354+
f'to last dimension of `observations` {num_index_points}')
355+
322356
@staticmethod
323357
def precompute_regression_model(
324358
kernel,

tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model_test.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,47 @@ def testShapes(self, num_tasks):
110110
self.assertAllEqual(
111111
self.evaluate(tf.shape(gp.mean())), batch_shape + event_shape)
112112

113+
def testValidateArgs(self):
114+
index_points = np.linspace(-4., 4., 10, dtype=np.float32)
115+
index_points = np.reshape(index_points, [5, 2])
116+
index_points = np.linspace(-4., 4., 16, dtype=np.float32)
117+
observation_index_points = np.reshape(index_points, [8, 2])
118+
119+
observation_noise_variance = 1e-4
120+
kernel = exponentiated_quadratic.ExponentiatedQuadratic()
121+
multi_task_kernel = multitask_kernel.Independent(
122+
num_tasks=3, base_kernel=kernel)
123+
with self.assertRaisesRegexp(ValueError, 'match the number of tasks'):
124+
observations = np.linspace(-1., 1., 24).astype(np.float32)
125+
mtgprm_lib.MultiTaskGaussianProcessRegressionModel(
126+
multi_task_kernel,
127+
observation_index_points=observation_index_points,
128+
observations=observations,
129+
index_points=index_points,
130+
observation_noise_variance=observation_noise_variance,
131+
validate_args=True)
132+
133+
with self.assertRaisesRegexp(ValueError, 'match the number of tasks'):
134+
observations = np.linspace(-1., 1., 32).reshape(8, 4).astype(np.float32)
135+
mtgprm_lib.MultiTaskGaussianProcessRegressionModel(
136+
multi_task_kernel,
137+
observation_index_points=observation_index_points,
138+
observations=observations,
139+
index_points=index_points,
140+
observation_noise_variance=observation_noise_variance,
141+
validate_args=True)
142+
143+
with self.assertRaisesRegexp(
144+
ValueError, 'match the second to last dimension'):
145+
observations = np.linspace(-1., 1., 18).reshape(6, 3).astype(np.float32)
146+
mtgprm_lib.MultiTaskGaussianProcessRegressionModel(
147+
multi_task_kernel,
148+
observation_index_points=observation_index_points,
149+
observations=observations,
150+
index_points=index_points,
151+
observation_noise_variance=observation_noise_variance,
152+
validate_args=True)
153+
113154
@parameterized.parameters(1, 3, 5)
114155
def testBindingIndexPoints(self, num_tasks):
115156
amplitude = np.float64(0.5)

tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,33 @@ def testMultiTaskBlockSeparable(self):
382382
self.evaluate(actual_multitask_var),
383383
self.evaluate(multitask_var), rtol=4e-3)
384384

385+
def testLogProbValidateArgs(self):
386+
index_points = np.linspace(-4., 4., 10, dtype=np.float32)
387+
index_points = np.reshape(index_points, [-1, 2])
388+
389+
observation_noise_variance = 1e-4
390+
kernel = exponentiated_quadratic.ExponentiatedQuadratic()
391+
multi_task_kernel = multitask_kernel.Independent(
392+
num_tasks=3, base_kernel=kernel)
393+
multitask_gp = multitask_gaussian_process.MultiTaskGaussianProcess(
394+
multi_task_kernel,
395+
index_points,
396+
observation_noise_variance=observation_noise_variance,
397+
validate_args=True)
398+
399+
with self.assertRaisesRegexp(ValueError, 'match the number of tasks'):
400+
observations = np.linspace(-1., 1., 15).astype(np.float32)
401+
multitask_gp.log_prob(observations)
402+
403+
with self.assertRaisesRegexp(ValueError, 'match the number of tasks'):
404+
observations = np.linspace(-1., 1., 20).reshape(5, 4).astype(np.float32)
405+
multitask_gp.log_prob(observations)
406+
407+
with self.assertRaisesRegexp(
408+
ValueError, 'match the second to last dimension'):
409+
observations = np.linspace(-1., 1., 18).reshape(6, 3).astype(np.float32)
410+
multitask_gp.log_prob(observations)
411+
385412
def testLogProbMatchesGP(self):
386413
# Check that the independent kernel parameterization matches using a
387414
# single-task GP.

0 commit comments

Comments
 (0)