Skip to content

Commit fa952b3

Browse files
gneculatensorflower-gardener
authored andcommitted
Finish transition of tfxla.variadic_reduce to point to XlaVariadicReduceV2.
PiperOrigin-RevId: 386538798
1 parent e7a9c5c commit fa952b3

File tree

1 file changed

+5
-14
lines changed

1 file changed

+5
-14
lines changed

tensorflow_probability/python/internal/variadic_reduce.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -123,20 +123,11 @@ def make_variadic_reduce(reducer, vjp_bwd, tangents_fn):
123123
def _xla_reduce(operands, inits, axis):
124124
"""JIT-ed wrapper for TF `xla.variadic_reduce(..., reducer)`."""
125125
from tensorflow.compiler.tf2xla.python import xla # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
126-
try:
127-
result = xla.variadic_reduce(
128-
operands,
129-
init_values=inits,
130-
dimensions_to_reduce=axis,
131-
reducer=tf.function(reducer).get_concrete_function(inits, inits))
132-
except TypeError:
133-
# TODO(necula): This is needed only temporarily, until cl/384088502
134-
# makes it in tf-nightly.
135-
result = xla.variadic_reduce_v2(
136-
operands,
137-
init_values=inits,
138-
dimensions_to_reduce=axis,
139-
reducer=tf.function(reducer).get_concrete_function(inits, inits))
126+
result = xla.variadic_reduce(
127+
operands,
128+
init_values=inits,
129+
dimensions_to_reduce=axis,
130+
reducer=tf.function(reducer).get_concrete_function(inits, inits))
140131
# Graph mode: variadic reduce doesn't specify output shapes. Patch that.
141132
shp = operands[0].shape
142133
for arg in operands:

0 commit comments

Comments
 (0)