Skip to content

Commit 854abf1

Browse files
emilyfertigtensorflower-gardener
authored andcommitted
Minimal support for structured inputs to MultiTaskKernel.
PiperOrigin-RevId: 476460795
1 parent eeb6b48 commit 854abf1

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

tensorflow_probability/python/experimental/psd_kernels/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ multi_substrate_py_library(
7474
# numpy dep,
7575
# tensorflow dep,
7676
"//tensorflow_probability/python/internal:dtype_util",
77+
"//tensorflow_probability/python/internal:nest_util",
7778
"//tensorflow_probability/python/math/psd_kernels:positive_semidefinite_kernel",
7879
"//tensorflow_probability/python/math/psd_kernels/internal:util",
7980
],

tensorflow_probability/python/experimental/psd_kernels/multitask_kernel.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import tensorflow.compat.v2 as tf
2222

2323
from tensorflow_probability.python.internal import dtype_util
24+
from tensorflow_probability.python.internal import nest_util
2425
from tensorflow_probability.python.internal import parameter_properties
2526
from tensorflow_probability.python.internal import tensor_util
2627
from tensorflow_probability.python.math.psd_kernels import positive_semidefinite_kernel as psd_kernel
@@ -140,8 +141,10 @@ def matrix_over_all_tasks(self, x1, x2, name='matrix_over_all_tasks'):
140141
# ==> [6, 6]
141142
"""
142143
with tf.name_scope(name):
143-
x1 = tf.convert_to_tensor(x1, name='x1', dtype_hint=self.dtype)
144-
x2 = tf.convert_to_tensor(x2, name='x2', dtype_hint=self.dtype)
144+
x1 = nest_util.convert_to_nested_tensor(
145+
x1, name='x1', dtype_hint=self.dtype, allow_packing=True)
146+
x2 = nest_util.convert_to_nested_tensor(
147+
x2, name='x2', dtype_hint=self.dtype, allow_packing=True)
145148
return self._matrix_over_all_tasks(x1, x2)
146149

147150
def _matrix_over_all_tasks(self, x1, x2):

0 commit comments

Comments
 (0)