Skip to content

Commit 77b7d3c

Browse files
SiegeLordExtensorflower-gardener
authored andcommitted
Some JAX omnistaging fixes for tfb.Split
PiperOrigin-RevId: 379342851
1 parent 636693f commit 77b7d3c

File tree

1 file changed

+8
-8
lines changed
  • tensorflow_probability/python/bijectors

1 file changed

+8
-8
lines changed

tensorflow_probability/python/bijectors/split.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from tensorflow_probability.python.bijectors import bijector
2727
from tensorflow_probability.python.internal import assert_util
2828
from tensorflow_probability.python.internal import parameter_properties
29-
from tensorflow_probability.python.internal import prefer_static
29+
from tensorflow_probability.python.internal import prefer_static as ps
3030
from tensorflow_probability.python.internal import tensor_util
3131
from tensorflow_probability.python.internal import tensorshape_util
3232

@@ -102,7 +102,7 @@ def __init__(
102102
if static_axis >= 0:
103103
raise ValueError('`axis` must be negative. Got {}'.format(axis))
104104

105-
self._axis = tf.convert_to_tensor(axis, tf.int32)
105+
self._axis = ps.convert_to_shape_tensor(axis, tf.int32)
106106

107107
super(Split, self).__init__(
108108
forward_min_event_ndims=-axis,
@@ -154,7 +154,7 @@ def _inverse(self, y):
154154
assertions = []
155155
else:
156156
assertions = self._validate_output_shape_tensors(
157-
[prefer_static.shape(y_) for y_ in y])
157+
[ps.shape(y_) for y_ in y])
158158

159159
with tf.control_dependencies(assertions):
160160
return tf.concat(y, axis=self.axis)
@@ -179,7 +179,7 @@ def _forward(self, x):
179179
if is_validated or not self.validate_args:
180180
assertions = []
181181
else:
182-
assertions = self._validate_input_shape_tensor(prefer_static.shape(x))
182+
assertions = self._validate_input_shape_tensor(ps.shape(x))
183183

184184
with tf.control_dependencies(assertions):
185185
if self.split_sizes is None:
@@ -277,7 +277,7 @@ def _forward_event_shape_tensor(self, input_shape):
277277
# Each element of the `output_shape_tensor` list is equal to the
278278
# `input_shape`, with the corresponding element of `split_sizes`
279279
# substituted in the `axis` position.
280-
positive_axis = prefer_static.rank_from_shape(input_shape) + self.axis
280+
positive_axis = ps.rank_from_shape(input_shape) + self.axis
281281
tiled_input_shape = tf.tile(
282282
input_shape[tf.newaxis, :], [self.num_splits, 1])
283283
fused_output_shapes = tf.concat([
@@ -314,7 +314,7 @@ def _inverse_event_shape_tensor(self, output_shapes):
314314
total_size = tf.reduce_sum([t[self.axis] for t in output_shapes])
315315
inverse_event_shape = tf.tensor_scatter_nd_update(
316316
output_shapes[0],
317-
[[prefer_static.rank_from_shape(output_shapes[0]) + self.axis]],
317+
[[ps.rank_from_shape(output_shapes[0]) + self.axis]],
318318
[total_size])
319319
return tf.identity(tf.convert_to_tensor(
320320
inverse_event_shape, dtype_hint=tf.int32,
@@ -458,7 +458,7 @@ def _validate_input_shape(self, input_shape):
458458

459459
def _validate_input_shape_tensor(self, input_shape):
460460
input_dim = tf.gather(
461-
input_shape, [prefer_static.rank_from_shape(input_shape) + self.axis])
461+
input_shape, [ps.rank_from_shape(input_shape) + self.axis])
462462
if self.split_sizes is None:
463463
return [assert_util.assert_equal(
464464
0,
@@ -505,7 +505,7 @@ def _validate_output_shape_tensors(self, output_shapes):
505505
for i, shape in enumerate(output_shapes):
506506
output_size = tf.gather(
507507
shape,
508-
[prefer_static.rank_from_shape(output_shapes[0]) + self.axis])
508+
[ps.rank_from_shape(output_shapes[0]) + self.axis])
509509
split_size = split_sizes[i]
510510
assertions.append(
511511
assert_util.assert_equal(

0 commit comments

Comments
 (0)