Skip to content

Commit 1bdfdf3

Browse files
emilyfertigtensorflower-gardener
authored andcommitted
Fix __repr__ for tfd.GaussianProcess when index_points is None.
PiperOrigin-RevId: 464690048
1 parent a7a3e04 commit 1bdfdf3

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

tensorflow_probability/python/distributions/gaussian_process.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -774,6 +774,22 @@ def _type_spec(self):
774774
def _convert_variables_to_tensors(self):
775775
return auto_composite_tensor.convert_variables_to_tensors(self)
776776

777+
def __repr__(self):
778+
if self.index_points is None:
779+
event_shape_str = '?'
780+
else:
781+
event_shape_str = distribution._str_tensorshape(self.event_shape)
782+
return ('<tfp.distributions.{type_name} '
783+
'\'{self_name}\''
784+
' batch_shape={batch_shape}'
785+
' event_shape={event_shape}'
786+
' dtype={dtype}>'.format(
787+
type_name=type(self).__name__,
788+
self_name=self.name or '<unknown>',
789+
batch_shape=distribution._str_tensorshape(self.batch_shape),
790+
event_shape=event_shape_str,
791+
dtype=distribution._str_dtype(self.dtype)))
792+
777793

778794
@auto_composite_tensor.type_spec_register(
779795
'tfp.distributions.GaussianProcess_ACTTypeSpec')

tensorflow_probability/python/distributions/gaussian_process_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,9 @@ def _kernel_fn(x, y):
213213
with self.assertRaises(ValueError):
214214
gp.mean()
215215

216+
self.assertIn("event_shape=?", repr(gp))
217+
self.assertIn("event_shape=[10]", repr(gp.copy(index_points=index_points)))
218+
216219
def testMarginalHasCorrectTypes(self):
217220
gp = tfd.GaussianProcess(
218221
kernel=psd_kernels.ExponentiatedQuadratic(), validate_args=True)

0 commit comments

Comments
 (0)