Skip to content

Commit fea149f

Browse files
gneculatensorflower-gardener
authored andcommitted
Change TFP code to point to tfxla.variadic_reduce_v2.
PiperOrigin-RevId: 385176212
1 parent 784bcaa commit fea149f

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

tensorflow_probability/python/internal/variadic_reduce.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,11 +123,20 @@ def make_variadic_reduce(reducer, vjp_bwd, tangents_fn):
123123
def _xla_reduce(operands, inits, axis):
124124
"""JIT-ed wrapper for TF `xla.variadic_reduce(..., reducer)`."""
125125
from tensorflow.compiler.tf2xla.python import xla # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
126-
result = xla.variadic_reduce(
127-
operands,
128-
init_value=inits,
129-
dimensions_to_reduce=axis,
130-
reducer=tf.function(reducer).get_concrete_function(inits, inits))
126+
try:
127+
result = xla.variadic_reduce(
128+
operands,
129+
init_values=inits,
130+
dimensions_to_reduce=axis,
131+
reducer=tf.function(reducer).get_concrete_function(inits, inits))
132+
except TypeError:
133+
# TODO(necula): This is needed only temporarily, until cl/384088502
134+
# makes it in tf-nightly.
135+
result = xla.variadic_reduce_v2(
136+
operands,
137+
init_values=inits,
138+
dimensions_to_reduce=axis,
139+
reducer=tf.function(reducer).get_concrete_function(inits, inits))
131140
# Graph mode: variadic reduce doesn't specify output shapes. Patch that.
132141
shp = operands[0].shape
133142
for arg in operands:

0 commit comments

Comments
 (0)