Commit 6b36b64

davmre authored and tensorflower-gardener committed
Fix a couple of glitches with bijector batch shape.
1. ScaleMatvec* bijectors were missing annotations because I'd accidentally
   added them to the Block variant rather than to the base class.
2. Passing `np.int32` or `np.int64` values for `event_ndims` could lead to
   dtype mismatch errors.

PiperOrigin-RevId: 379987638
1 parent 01139b8 commit 6b36b64
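
Taken together, the two fixes show up directly in the public batch-shape API. Below is a minimal sketch of the behavior this commit enables, assuming a TFP build that includes this revision (the expected shape `[2]` follows from the leading batch dimension of `scale_diag`):

```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

# Fix 1: ScaleMatvec* bijectors now carry the `scale` parameter
# annotation, so batch shape inference sees the [2] batch dimension.
bij = tfb.ScaleMatvecDiag(scale_diag=tf.ones([2, 1]))
print(bij.experimental_batch_shape_tensor(x_event_ndims=1))  # ==> [2]

# Fix 2: NumPy integer `event_ndims` no longer trigger a dtype
# mismatch against the internal int32 shape tensors.
print(bij.experimental_batch_shape_tensor(x_event_ndims=np.int64(1)))  # ==> [2]
```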

4 files changed (+17, -4 lines)


tensorflow_probability/python/bijectors/BUILD

Lines changed: 1 addition & 0 deletions
```diff
@@ -254,6 +254,7 @@ multi_substrate_py_library(
         ":bijector",
         # tensorflow dep,
         "//tensorflow_probability/python/internal:dtype_util",
+        "//tensorflow_probability/python/internal:parameter_properties",
         "//tensorflow_probability/python/internal:prefer_static",
         "//tensorflow_probability/python/internal:tensor_util",
         "//tensorflow_probability/python/internal:tensorshape_util",
```

tensorflow_probability/python/bijectors/bijector_test.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -371,6 +371,8 @@ class BijectorBatchShapesTest(test_util.TestCase):
        lambda: tfb.Scale(tf.ones([4, 2])), None),
       ('sigmoid',
        lambda: tfb.Sigmoid(low=tf.zeros([3]), high=tf.ones([4, 1])), None),
+      ('scale_matvec',
+       lambda: tfb.ScaleMatvecDiag([[0.], [3.]]), None),
       ('invert',
        lambda: tfb.Invert(tfb.ScaleMatvecDiag(tf.ones([2, 1]))), None),
       ('reshape',
@@ -416,6 +418,14 @@ def test_batch_shape_matches_output_shapes(self,
     self.assertAllEqual(batch_shape_tensor_x, batch_shape_tensor_y)
     self.assertAllEqual(batch_shape_tensor_x, batch_shape_x)
 
+    # Check that we're robust to integer type.
+    batch_shape_tensor_x64 = bijector.experimental_batch_shape_tensor(
+        x_event_ndims=tf.nest.map_structure(np.int64, x_event_ndims))
+    batch_shape_tensor_y64 = bijector.experimental_batch_shape_tensor(
+        y_event_ndims=tf.nest.map_structure(np.int64, y_event_ndims))
+    self.assertAllEqual(batch_shape_tensor_x64, batch_shape_tensor_y64)
+    self.assertAllEqual(batch_shape_tensor_x64, batch_shape_x)
+
     # Pushing a value through the bijector should return a Tensor(s) with
     # the expected batch shape...
     xs = tf.nest.map_structure(lambda nd: tf.ones([1] * nd), x_event_ndims)
```
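
The new `scale_matvec` case exercises the first fix: `scale_diag=[[0.], [3.]]` has shape `[2, 1]`, its trailing dimension is the event dimension, so the expected batch shape is `[2]`. The int64 block exercises the second fix by recasting every `event_ndims` leaf before re-querying. A brief illustration of those two pieces (values chosen only for the example):

```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

bij = tfb.ScaleMatvecDiag([[0.], [3.]])  # scale_diag shape [2, 1]
# Trailing dim is the event dim, leaving batch shape [2].
print(bij.experimental_batch_shape_tensor(x_event_ndims=1))  # ==> [2]

# `tf.nest.map_structure(np.int64, ...)` recasts plain ints and
# nested structures alike, matching what the test passes in.
print(tf.nest.map_structure(np.int64, 1))          # ==> 1 (as np.int64)
print(tf.nest.map_structure(np.int64, {'a': 1}))   # ==> {'a': np.int64(1)}
```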

tensorflow_probability/python/bijectors/scale_matvec_linear_operator.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -22,6 +22,7 @@
 
 from tensorflow_probability.python.bijectors import bijector
 from tensorflow_probability.python.internal import dtype_util
+from tensorflow_probability.python.internal import parameter_properties
 from tensorflow_probability.python.internal import prefer_static as ps
 from tensorflow_probability.python.internal import tensor_util
 from tensorflow_probability.python.internal import tensorshape_util
@@ -46,6 +47,10 @@ def adjoint(self):
     """`bool` indicating whether this class uses `self.scale` or its adjoint."""
     return self._adjoint
 
+  @classmethod
+  def _parameter_properties(cls, dtype):
+    return dict(scale=parameter_properties.BatchedComponentProperties())
+
   def _forward(self, x):
     return self.scale.matvec(x, adjoint=self.adjoint)
 
@@ -237,10 +242,6 @@ def _inverse_event_shape_tensor(self, output_shape):
       return _cumulative_broadcast_dynamic(output_shape)
     return output_shape
 
-  @classmethod
-  def _parameter_properties(cls, dtype):
-    return {}
-
 
 def _cumulative_broadcast_static(event_shape):
   broadcast_shapes = [s[:-1] for s in event_shape]
```
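
With `_parameter_properties` on the base class, every concrete ScaleMatvec* subclass in this hierarchy inherits the `scale` annotation, instead of only the Block variant carrying it. A hedged check of that inheritance, assuming the standard constructor signatures:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

# Both subclasses should now report a batch shape read off the batch
# dimensions of their underlying `scale` LinearOperator.
for bij in [tfb.ScaleMatvecDiag(scale_diag=tf.ones([4, 2])),
            tfb.ScaleMatvecTriL(scale_tril=tf.eye(2, batch_shape=[4]))]:
  print(type(bij).__name__,
        bij.experimental_batch_shape_tensor(x_event_ndims=1))  # ==> [4]
```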

tensorflow_probability/python/internal/batch_shape_lib.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -137,6 +137,7 @@ def get_base_shape(x):
 
 def slice_batch_shape_tensor(base_shape, event_ndims):
   base_shape = ps.convert_to_shape_tensor(base_shape, dtype_hint=np.int32)
+  event_ndims = ps.convert_to_shape_tensor(event_ndims, dtype_hint=np.int32)
   base_rank = ps.rank_from_shape(base_shape)
   return base_shape[:(base_rank -
                       # Don't try to slice away more ndims than the parameter
```
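
The added line normalizes `event_ndims` to the same int32 dtype as `base_shape` before the subtraction and slice below it. Without it, a NumPy int64 scalar reaching `base_rank - event_ndims` mixes int32 and int64 operands, which TF rejects. A standalone sketch of the failure mode and the normalization (names mirror the function above, but the snippet itself is illustrative):

```python
import numpy as np
import tensorflow as tf

base_shape = tf.constant([4, 2, 3], dtype=tf.int32)
base_rank = tf.size(base_shape)  # int32 scalar
event_ndims = np.int64(1)

# Unnormalized, `base_rank - event_ndims` subtracts int64 from int32
# and raises InvalidArgumentError. Casting first, as the fix does via
# ps.convert_to_shape_tensor, makes the slice well-typed:
event_ndims = tf.cast(event_ndims, tf.int32)
print(base_shape[:(base_rank - event_ndims)])  # ==> [4 2]
```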
