26
26
from tensorflow_probability .python .bijectors import bijector
27
27
from tensorflow_probability .python .internal import assert_util
28
28
from tensorflow_probability .python .internal import parameter_properties
29
- from tensorflow_probability .python .internal import prefer_static
29
+ from tensorflow_probability .python .internal import prefer_static as ps
30
30
from tensorflow_probability .python .internal import tensor_util
31
31
from tensorflow_probability .python .internal import tensorshape_util
32
32
@@ -102,7 +102,7 @@ def __init__(
102
102
if static_axis >= 0 :
103
103
raise ValueError ('`axis` must be negative. Got {}' .format (axis ))
104
104
105
- self ._axis = tf . convert_to_tensor (axis , tf .int32 )
105
+ self ._axis = ps . convert_to_shape_tensor (axis , tf .int32 )
106
106
107
107
super (Split , self ).__init__ (
108
108
forward_min_event_ndims = - axis ,
@@ -154,7 +154,7 @@ def _inverse(self, y):
154
154
assertions = []
155
155
else :
156
156
assertions = self ._validate_output_shape_tensors (
157
- [prefer_static .shape (y_ ) for y_ in y ])
157
+ [ps .shape (y_ ) for y_ in y ])
158
158
159
159
with tf .control_dependencies (assertions ):
160
160
return tf .concat (y , axis = self .axis )
@@ -179,7 +179,7 @@ def _forward(self, x):
179
179
if is_validated or not self .validate_args :
180
180
assertions = []
181
181
else :
182
- assertions = self ._validate_input_shape_tensor (prefer_static .shape (x ))
182
+ assertions = self ._validate_input_shape_tensor (ps .shape (x ))
183
183
184
184
with tf .control_dependencies (assertions ):
185
185
if self .split_sizes is None :
@@ -277,7 +277,7 @@ def _forward_event_shape_tensor(self, input_shape):
277
277
# Each element of the `output_shape_tensor` list is equal to the
278
278
# `input_shape`, with the corresponding element of `split_sizes`
279
279
# substituted in the `axis` position.
280
- positive_axis = prefer_static .rank_from_shape (input_shape ) + self .axis
280
+ positive_axis = ps .rank_from_shape (input_shape ) + self .axis
281
281
tiled_input_shape = tf .tile (
282
282
input_shape [tf .newaxis , :], [self .num_splits , 1 ])
283
283
fused_output_shapes = tf .concat ([
@@ -314,7 +314,7 @@ def _inverse_event_shape_tensor(self, output_shapes):
314
314
total_size = tf .reduce_sum ([t [self .axis ] for t in output_shapes ])
315
315
inverse_event_shape = tf .tensor_scatter_nd_update (
316
316
output_shapes [0 ],
317
- [[prefer_static .rank_from_shape (output_shapes [0 ]) + self .axis ]],
317
+ [[ps .rank_from_shape (output_shapes [0 ]) + self .axis ]],
318
318
[total_size ])
319
319
return tf .identity (tf .convert_to_tensor (
320
320
inverse_event_shape , dtype_hint = tf .int32 ,
@@ -458,7 +458,7 @@ def _validate_input_shape(self, input_shape):
458
458
459
459
def _validate_input_shape_tensor (self , input_shape ):
460
460
input_dim = tf .gather (
461
- input_shape , [prefer_static .rank_from_shape (input_shape ) + self .axis ])
461
+ input_shape , [ps .rank_from_shape (input_shape ) + self .axis ])
462
462
if self .split_sizes is None :
463
463
return [assert_util .assert_equal (
464
464
0 ,
@@ -505,7 +505,7 @@ def _validate_output_shape_tensors(self, output_shapes):
505
505
for i , shape in enumerate (output_shapes ):
506
506
output_size = tf .gather (
507
507
shape ,
508
- [prefer_static .rank_from_shape (output_shapes [0 ]) + self .axis ])
508
+ [ps .rank_from_shape (output_shapes [0 ]) + self .axis ])
509
509
split_size = split_sizes [i ]
510
510
assertions .append (
511
511
assert_util .assert_equal (
0 commit comments