
Commit 6e9d1c0

sharadmv authored and tensorflower-gardener committed
Add support for more collectives and also handle integer axes
PiperOrigin-RevId: 380646068
1 parent e347851 commit 6e9d1c0
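
In brief: the collectives in `distribute_lib` now take a `named_axis` keyword (formerly `axis_name`), `canonicalize_axis_name` becomes `canonicalize_named_axis`, new collectives (`pmax`, `pmin`) are added, and new `reduce_sum`/`reduce_mean`/`reduce_max`/`reduce_min`/`reduce_logsumexp` helpers reduce over integer positional axes and named axes in a single call. A minimal sketch of the renamed keyword under the JAX backend; the import path and the axis name 'data' are assumptions for illustration, not part of the commit:

import jax
import jax.numpy as jnp
from tensorflow_probability.substrates.jax.internal import distribute_lib

def replica_fn(x):
  # Sums x across all replicas mapped over the named axis 'data'.
  return distribute_lib.psum(x, named_axis='data')

n = jax.local_device_count()
totals = jax.pmap(replica_fn, axis_name='data')(jnp.arange(float(n)))
# Every replica receives the same total: 0 + 1 + ... + (n - 1).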

File tree

7 files changed: +238 -66 lines changed

tensorflow_probability/python/experimental/bijectors/sharded.py

Lines changed: 2 additions & 2 deletions
@@ -75,10 +75,10 @@ def _forward_log_det_jacobian(self, x, **kwargs):
 
     return distribute_lib.psum(
         self.bijector.forward_log_det_jacobian(x, **kwargs),
-        axis_name=self.shard_axis_name)
+        named_axis=self.shard_axis_name)
 
   def _inverse_log_det_jacobian(self, y, **kwargs):
 
     return distribute_lib.psum(
         self.bijector.inverse_log_det_jacobian(y, **kwargs),
-        axis_name=self.shard_axis_name)
+        named_axis=self.shard_axis_name)
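
Both log-det-Jacobian methods above psum per-shard contributions over `shard_axis_name`, now passed via `named_axis`. A hedged usage sketch; it assumes the class is exposed as `tfp.experimental.bijectors.Sharded` with the constructor implied by the attributes above, and runs under a pmapped axis named 'shard':

import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

def fldj(x):
  # Assumed constructor: Sharded(bijector, shard_axis_name=...).
  bij = tfp.experimental.bijectors.Sharded(
      tfp.bijectors.Scale(2.), shard_axis_name='shard')
  # Each shard contributes 3 * log(2); psum totals this across shards.
  return bij.forward_log_det_jacobian(x, event_ndims=1)

x = jnp.ones([jax.local_device_count(), 3])
jax.pmap(fldj, axis_name='shard')(x)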

tensorflow_probability/python/experimental/distribute/joint_distribution.py

Lines changed: 2 additions & 2 deletions
@@ -28,12 +28,12 @@
 
 
 def pbroadcast_value(value, value_axis_names, output_axis_names):
-  value_axis_names = distribute_lib.canonicalize_axis_name(value_axis_names)
+  value_axis_names = distribute_lib.canonicalize_named_axis(value_axis_names)
   pbroadcast_axes = [
       axis_name for axis_name in output_axis_names
       if axis_name not in value_axis_names
   ]
-  return distribute_lib.pbroadcast(value, pbroadcast_axes)
+  return distribute_lib.pbroadcast(value, named_axis=pbroadcast_axes)
 
 
 def _maybe_substitute_or_add_value_in_tuple(value_tuple, index, value):
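
`pbroadcast_value` pbroadcasts a value over exactly the output axes it is not itself sharded over, which is what keeps gradients correct between terms with different shard axes. The axis bookkeeping is a simple set difference, sketched here in plain Python with illustrative axis names:

value_axis_names = ['data']               # axes the value already carries
output_axis_names = ['data', 'params']    # axes of the downstream output
pbroadcast_axes = [
    axis_name for axis_name in output_axis_names
    if axis_name not in value_axis_names
]
assert pbroadcast_axes == ['params']      # only the missing axis is broadcast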

tensorflow_probability/python/experimental/distribute/sharded.py

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ def __init__(self, distribution, shard_axis_name=None, validate_args=False,
     # Use inner axes before outer axes
     full_shard_axis_name = (
         distribution.experimental_shard_axis_names +
-        distribute_lib.canonicalize_axis_name(shard_axis_name))
+        distribute_lib.canonicalize_named_axis(shard_axis_name))
 
     if not JAX_MODE:
       if len(full_shard_axis_name) > 1:
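
The full shard-axis list places axes the inner distribution already owns before the newly supplied outer axis, and the TensorFlow backend only permits a single axis in total (enforced both here and in the new `canonicalize_named_axis`). A plain-Python sketch of the concatenation, with illustrative names:

inner_axes = ['a']    # distribution.experimental_shard_axis_names
outer_axes = ['b']    # canonicalize_named_axis('b') returns ['b']
full_shard_axis_name = inner_axes + outer_axes
assert full_shard_axis_name == ['a', 'b']   # inner axes come first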

tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation.py

Lines changed: 2 additions & 2 deletions
@@ -196,7 +196,7 @@ def chees_criterion(previous_state,
   """
   batch_ndims = ps.rank(accept_prob)
   batch_axes = ps.range(batch_ndims, dtype=tf.int32)
-  experimental_chain_axis_names = distribute_lib.canonicalize_axis_name(
+  experimental_chain_axis_names = distribute_lib.canonicalize_named_axis(
       experimental_chain_axis_names)
   # Number of total chains is local batch size * distributed axis size
   local_axis_size = ps.maximum(ps.size(accept_prob), 1)
@@ -733,7 +733,7 @@ def adjust_state(x, v, shard_axes=None):
   trajectory_grad *= trajectory_jitter
 
   # Weight by acceptance probability.
-  experimental_chain_axis_names = distribute_lib.canonicalize_axis_name(
+  experimental_chain_axis_names = distribute_lib.canonicalize_named_axis(
       experimental_chain_axis_names)
   trajectory_grad = tf.where(accept_prob > 1e-4, trajectory_grad, 0.)
   trajectory_grad = tf.where(
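
Per the comment in the first hunk, `chees_criterion` counts the total number of chains as the local batch size multiplied by the size of each named chain axis, so the criterion behaves as if all distributed chains were local. A plain-Python sketch of that count, with hypothetical sizes:

local_num_chains = 8        # ps.maximum(ps.size(accept_prob), 1) on this replica
axis_sizes = [4]            # get_axis_size(name) for each named chain axis
total_num_chains = local_num_chains
for size in axis_sizes:
  total_num_chains *= size  # 8 local chains * 4 replicas = 32 total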

tensorflow_probability/python/internal/BUILD

Lines changed: 2 additions & 0 deletions
@@ -112,6 +112,7 @@ multi_substrate_py_library(
     srcs_version = "PY3",
     deps = [
         ":custom_gradient",
+        ":prefer_static",
         ":samplers",
         # tensorflow dep,
         "//tensorflow_probability/python/math:gradient",
@@ -128,6 +129,7 @@ multi_substrate_py_test(
         ":distribute_lib",
         ":distribute_test_lib",
         ":test_util",
+        # absl/testing:parameterized dep,
         # tensorflow dep,
         "//tensorflow_probability",
     ],

tensorflow_probability/python/internal/distribute_lib.py

Lines changed: 101 additions & 31 deletions
@@ -23,6 +23,7 @@
 
 import tensorflow.compat.v2 as tf
 from tensorflow_probability.python.internal import custom_gradient as tfp_custom_gradient
+from tensorflow_probability.python.internal import prefer_static as ps
 from tensorflow_probability.python.internal import samplers
 
 from tensorflow.python.util import nest  # pylint: disable=g-direct-tensorflow-import
@@ -33,38 +34,104 @@
   from jax import lax  # pylint: disable=g-import-not-at-top
 
 
-def canonicalize_axis_name(axis_name):
-  """Converts an input into a list of axis strings."""
-  if not axis_name:
+def canonicalize_named_axis(named_axes):
+  """Converts an input into a list of named axis `str`s."""
+  if named_axes is None:
     return []
-  if (isinstance(axis_name, str) or
-      not isinstance(axis_name, collections.Iterable)):
-    return [axis_name]
-  return list(axis_name)
+  if (isinstance(named_axes, str) or
+      not isinstance(named_axes, collections.Iterable)):
+    named_axes = [named_axes]
+  if len(named_axes) > 1 and not JAX_MODE:
+    raise ValueError(
+        f'TensorFlow backend does not support multiple shard axes: {named_axes}'
+    )
+  return list(named_axes)
 
 
-def psum(x, axis_name=None):
-  axis_name = canonicalize_axis_name(axis_name)
-  for name in axis_name:
-    x = rwb_psum(x, name)
+def _make_reduce_op(tensor_reduce_fn, collective_reduce_fn):
+  """Makes an op that reduces over both positional axes and named axes.
+
+  Assumes that the reducers are associative, so the tensor reduction and the
+  collective reduction can be reordered.
+
+  Args:
+    tensor_reduce_fn: A function that reduces over the dimensions of a
+      `Tensor`. `tensor_reduce_fn` should take in an `axis` keyword argument.
+    collective_reduce_fn: A function that reduces over named axes.
+      `collective_reduce_fn` should take in a `named_axis` keyword argument.
+
+  Returns:
+    A function that reduces a `Tensor` over both positional and named axes.
+  """
+
+  def reduce_fn(x, axis=None, named_axis=None, **kwargs):
+    named_axis = canonicalize_named_axis(named_axis)
+    x = tensor_reduce_fn(x, axis=axis, **kwargs)
+    return collective_reduce_fn(x, named_axis=named_axis)
+
+  return reduce_fn
+
+
+def psum(x, named_axis=None):
+  axes = canonicalize_named_axis(named_axis)
+  for axis in axes:
+    x = rwb_psum(x, axis)
+  return x
+
+
+reduce_sum = _make_reduce_op(tf.reduce_sum, psum)
+
+
+def pbroadcast(x, named_axis=None):
+  axes = canonicalize_named_axis(named_axis)
+  for axis in axes:
+    x = rwb_pbroadcast(x, axis)
   return x
 
 
-def pbroadcast(x, axis_name=None):
-  axis_name = canonicalize_axis_name(axis_name)
+def pmean(x, named_axis=None):
+  axes = canonicalize_named_axis(named_axis)
+  for axis in axes:
+    x = psum(x, named_axis=axis) / get_axis_size(axis)
+  return x
+
+
+reduce_mean = _make_reduce_op(tf.reduce_mean, pmean)
+
+
+def pmax(x, named_axis=None):
+  # TODO(b/187173243): fix gradients for pmax
+  axes = canonicalize_named_axis(named_axis)
+  for axis in axes:
+    if not JAX_MODE:
+      raise NotImplementedError('`pmax` not supported in TF')
+    x = lax.pmax(x, axis)
+  return x
+
+
+reduce_max = _make_reduce_op(tf.reduce_max, pmax)
+
+
+def pmin(x, named_axis=None):
+  # TODO(b/187173243): fix gradients for pmin
+  axis_name = canonicalize_named_axis(named_axis)
   for name in axis_name:
-    x = rwb_pbroadcast(x, name)
+    if not JAX_MODE:
+      raise NotImplementedError('`pmin` not supported in TF')
+    x = lax.pmin(x, name)
   return x
 
 
-def pmean(x, axis_name=None):
-  if JAX_MODE:
-    axis_name = canonicalize_axis_name(axis_name)
-    for name in axis_name:
-      x = lax.pmean(x, name)
-    return x
-  ctx = tf.distribute.get_replica_context()
-  return ctx.all_reduce('mean', x)
+reduce_min = _make_reduce_op(tf.reduce_min, pmin)
+
+
+def reduce_logsumexp(x, axis=None, named_axis=None, **kwargs):
+  xmax = reduce_max(
+      tf.stop_gradient(x), axis=axis, named_axis=named_axis, keepdims=True)
+  xmax = tf.where(tf.math.is_finite(xmax), xmax, tf.zeros_like(xmax))
+  result = tf.math.log(
+      reduce_sum(tf.math.exp(x - xmax), axis=axis, named_axis=named_axis,
                 **kwargs))
+  return tf.reshape(xmax, ps.shape(result)) + result
 
 
 def get_axis_index(axis_name=None):
@@ -83,7 +150,7 @@ def get_axis_size(axis_name=None):
 
 def _rwb_psum_fwd(x, axis_name):
   if JAX_MODE:
-    axis_name = canonicalize_axis_name(axis_name)
+    axis_name = canonicalize_named_axis(axis_name)
     out = lax.psum(x, axis_name)
   else:
     ctx = tf.distribute.get_replica_context()
@@ -100,13 +167,15 @@ def fold_in_axis_index(seed, axis_name=None):
   if axis_name is None:
     return seed
   nest.assert_shallow_structure(seed, axis_name)
-  axis_names = nest.map_structure_up_to(
-      seed, canonicalize_axis_name, axis_name)
+  axis_names = nest.map_structure_up_to(seed, canonicalize_named_axis,
                                        axis_name)
+
   def fold_in(seed, axes):
     for name in axes:
       axis_index = get_axis_index(name)
       seed = samplers.fold_in(seed, tf.cast(axis_index, tf.int32))
     return seed
+
   return nest.map_structure_up_to(seed, fold_in, seed, axis_names)
 
 
@@ -121,6 +190,7 @@ def rwb_psum(x, axis_name):
   Args:
     x: a `Tensor` target for the psum.
     axis_name: A string axis name for the psum.
+
   Returns:
     A `Tensor` that is the result of applying a psum to an input `Tensor`.
   """
@@ -161,8 +231,8 @@ def make_pbroadcast_function(fn, in_axes, out_axes, out_dtype):
      value w.r.t. the input value will be psum-ed over the axes present in the
      output but not the input.
    out_axes: A structure of axis names that should match the structure of the
-      output of `fn`. The inputs to `fn` will be pbroadcast-ed before
-      computing output terms according to their output axes.
+      output of `fn`. The inputs to `fn` will be pbroadcast-ed before computing
+      output terms according to their output axes.
    out_dtype: A structure of dtypes that matches the output of `fn`.
 
  Returns:
@@ -176,9 +246,9 @@ def make_pbroadcast_function(fn, in_axes, out_axes, out_dtype):
  def pbroadcast_fn(*args):
    nest.assert_shallow_structure(args, in_axes)
    nest.assert_shallow_structure(out_dtype, out_axes)
-    map_in_axes = nest.map_structure_up_to(args, canonicalize_axis_name,
+    map_in_axes = nest.map_structure_up_to(args, canonicalize_named_axis,
                                           in_axes)
-    map_out_axes = nest.map_structure_up_to(out_dtype, canonicalize_axis_name,
+    map_out_axes = nest.map_structure_up_to(out_dtype, canonicalize_named_axis,
                                            out_axes)
 
    def _pbroadcast_input(out_axes, x, in_axes):
@@ -232,14 +302,14 @@ def make_psum_function(fn, in_axes, out_axes, out_dtype):
    function and corrects the gradient with respect to its inputs.
  """
 
-  out_axes = nest.map_structure_up_to(out_dtype, canonicalize_axis_name,
+  out_axes = nest.map_structure_up_to(out_dtype, canonicalize_named_axis,
                                      out_axes)
 
  def psum_fn(*args):
    out = make_pbroadcast_function(fn, in_axes, out_axes, out_dtype)(*args)
 
    def _psum_output(x, out_axis):
-      return psum(x, out_axis)
+      return psum(x, named_axis=out_axis)
 
    return nest.map_structure_up_to(out_dtype, _psum_output, out, out_axes)
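
The new `reduce_*` ops compose a positional `tf.reduce_*` with its matching collective; because both reducers are associative, reducing locally first and then across replicas yields the global reduction. `reduce_logsumexp` additionally subtracts a stop-gradiented max before exponentiating, for numerical stability. An end-to-end sketch under the JAX backend; the import path, the axis name 'data', and the shapes are illustrative assumptions:

import jax
import jax.numpy as jnp
from tensorflow_probability.substrates.jax.internal import distribute_lib

def replica_fn(x):
  # Local reduce over positional axis 0, then psum over named axis 'data'.
  total = distribute_lib.reduce_sum(x, axis=0, named_axis='data')
  # Numerically stable log-sum-exp over both kinds of axes at once.
  lse = distribute_lib.reduce_logsumexp(x, axis=0, named_axis='data')
  return total, lse

n = jax.local_device_count()
x = jnp.arange(float(n * 3)).reshape([n, 3])  # one row per replica
totals, lses = jax.pmap(replica_fn, axis_name='data')(x)
# Every replica holds the global sum and global logsumexp of all of x.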
