import tensorflow.compat.v2 as tf

from tensorflow_probability.python.internal import custom_gradient as tfp_custom_gradient
+ from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import samplers

from tensorflow.python.util import nest  # pylint: disable=g-direct-tensorflow-import

  from jax import lax  # pylint: disable=g-import-not-at-top

- def canonicalize_axis_name(axis_name):
-   """Converts an input into a list of axis strings."""
-   if not axis_name:
+ def canonicalize_named_axis(named_axes):
+   """Converts an input into a list of named axis `str`s."""
+   if named_axes is None:
    return []
-   if (isinstance(axis_name, str) or
-       not isinstance(axis_name, collections.Iterable)):
-     return [axis_name]
-   return list(axis_name)
+   if (isinstance(named_axes, str) or
+       not isinstance(named_axes, collections.Iterable)):
+     named_axes = [named_axes]
+   if len(named_axes) > 1 and not JAX_MODE:
+     raise ValueError(
+         f'TensorFlow backend does not support multiple shard axes: {named_axes}'
+     )
+   return list(named_axes)
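
# A minimal sketch of the expected behavior of `canonicalize_named_axis`
# (the multi-axis case assumes a JAX backend):
#
#   canonicalize_named_axis(None)        # ==> []
#   canonicalize_named_axis('i')         # ==> ['i']
#   canonicalize_named_axis(['i', 'j'])  # ==> ['i', 'j'] under JAX;
#                                        #     raises ValueError under TF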


- def psum(x, axis_name=None):
-   axis_name = canonicalize_axis_name(axis_name)
-   for name in axis_name:
-     x = rwb_psum(x, name)
+ def _make_reduce_op(tensor_reduce_fn, collective_reduce_fn):
+   """Makes an op that reduces over both positional axes and named axes.
+
+   Assumes that the reducers are associative, so the tensor reduction and the
+   collective reduction can be applied in either order.
+
+   Args:
+     tensor_reduce_fn: A function that reduces over the dimensions of a
+       `Tensor`. `tensor_reduce_fn` should take in an `axis` keyword argument.
+     collective_reduce_fn: A function that reduces over named axes.
+       `collective_reduce_fn` should take in a `named_axis` keyword argument.
+
+   Returns:
+     A function that reduces a `Tensor` over both positional and named axes.
+   """
+
+   def reduce_fn(x, axis=None, named_axis=None, **kwargs):
+     named_axis = canonicalize_named_axis(named_axis)
+     x = tensor_reduce_fn(x, axis=axis, **kwargs)
+     return collective_reduce_fn(x, named_axis=named_axis)
+
+   return reduce_fn
+
+
+ def psum(x, named_axis=None):
+   axes = canonicalize_named_axis(named_axis)
+   for axis in axes:
+     x = rwb_psum(x, axis)
+   return x
+
+
+ reduce_sum = _make_reduce_op(tf.reduce_sum, psum)
+
+
+ def pbroadcast(x, named_axis=None):
+   axes = canonicalize_named_axis(named_axis)
+   for axis in axes:
+     x = rwb_pbroadcast(x, axis)
  return x
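
# A usage sketch for the combined reducers, assuming a JAX backend and a pmap
# axis named 'i' (names and shapes are illustrative only):
#
#   # x has shape [num_devices, 3]; inside pmap each device holds a [3] row.
#   total = jax.pmap(
#       lambda x: reduce_sum(x, axis=0, named_axis='i'),
#       axis_name='i')(x)
#   # Every device now holds the same scalar: the sum over positional axis 0
#   # and over the named device axis 'i'.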


- def pbroadcast(x, axis_name=None):
-   axis_name = canonicalize_axis_name(axis_name)
+ def pmean(x, named_axis=None):
+   axes = canonicalize_named_axis(named_axis)
+   for axis in axes:
+     x = psum(x, named_axis=axis) / get_axis_size(axis)
+   return x
+
+
+ reduce_mean = _make_reduce_op(tf.reduce_mean, pmean)
+
+
+ def pmax(x, named_axis=None):
+   # TODO(b/187173243): fix gradients for pmax
+   axes = canonicalize_named_axis(named_axis)
+   for axis in axes:
+     if not JAX_MODE:
+       raise NotImplementedError('`pmax` not supported in TF')
+     x = lax.pmax(x, axis)
+   return x
+
+
+ reduce_max = _make_reduce_op(tf.reduce_max, pmax)
+
+
+ def pmin(x, named_axis=None):
+   # TODO(b/187173243): fix gradients for pmin
+   axis_name = canonicalize_named_axis(named_axis)
  for name in axis_name:
-     x = rwb_pbroadcast(x, name)
+     if not JAX_MODE:
+       raise NotImplementedError('`pmin` not supported in TF')
+     x = lax.pmin(x, name)
  return x
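
# Sketch of the intended semantics (illustrative values; JAX backend assumed):
# with a named axis 'i' of size 4 holding the per-device values [1., 2., 3., 4.],
#
#   psum(x, named_axis='i')   # ==> 10.  on every device
#   pmean(x, named_axis='i')  # ==> 2.5  on every device (psum / axis size)
#   pmax(x, named_axis='i')   # ==> 4.   on every device
#   pmin(x, named_axis='i')   # ==> 1.   on every device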


- def pmean(x, axis_name=None):
-   if JAX_MODE:
-     axis_name = canonicalize_axis_name(axis_name)
-     for name in axis_name:
-       x = lax.pmean(x, name)
-     return x
-   ctx = tf.distribute.get_replica_context()
-   return ctx.all_reduce('mean', x)
+ reduce_min = _make_reduce_op(tf.reduce_min, pmin)
+
+
+ def reduce_logsumexp(x, axis=None, named_axis=None, **kwargs):
+   # Numerically stable log-sum-exp: subtract the (positional and named) max
+   # before exponentiating, then add it back after the log.
+   xmax = reduce_max(
+       tf.stop_gradient(x), axis=axis, named_axis=named_axis, keepdims=True)
+   xmax = tf.where(tf.math.is_finite(xmax), xmax, tf.zeros_like(xmax))
+   result = tf.math.log(
+       reduce_sum(tf.exp(x - xmax), axis=axis, named_axis=named_axis, **kwargs))
+   return tf.reshape(xmax, ps.shape(result)) + result
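
# Sketch of the log-sum-exp identity used above, assuming the values
# [-1., 0., 1.] are split across a named axis 'i' of size 3:
#
#   reduce_logsumexp(x, named_axis='i')
#   # == log(sum over 'i' of exp(x - xmax)) + xmax, with xmax = pmax(x, 'i') = 1.
#   # == log(exp(-2.) + exp(-1.) + exp(0.)) + 1. ~= 1.408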


def get_axis_index(axis_name=None):
@@ -83,7 +150,7 @@ def get_axis_size(axis_name=None):

def _rwb_psum_fwd(x, axis_name):
  if JAX_MODE:
-     axis_name = canonicalize_axis_name(axis_name)
+     axis_name = canonicalize_named_axis(axis_name)
    out = lax.psum(x, axis_name)
  else:
    ctx = tf.distribute.get_replica_context()
@@ -100,13 +167,15 @@ def fold_in_axis_index(seed, axis_name=None):
  if axis_name is None:
    return seed
  nest.assert_shallow_structure(seed, axis_name)
-   axis_names = nest.map_structure_up_to(
-       seed, canonicalize_axis_name, axis_name)
+   axis_names = nest.map_structure_up_to(seed, canonicalize_named_axis,
+                                         axis_name)
+
  def fold_in(seed, axes):
    for name in axes:
      axis_index = get_axis_index(name)
      seed = samplers.fold_in(seed, tf.cast(axis_index, tf.int32))
    return seed
+
  return nest.map_structure_up_to(seed, fold_in, seed, axis_names)
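
# A usage sketch, assuming a JAX backend with a pmap axis named 'i': every
# device starts from the same base seed and derives a distinct per-device seed
# by folding in its axis index ('i' and `seeds` are illustrative).
#
#   per_device_seed = jax.pmap(
#       lambda seed: fold_in_axis_index(seed, axis_name='i'),
#       axis_name='i')(seeds)  # `seeds`: the same seed stacked once per device
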
@@ -121,6 +190,7 @@ def rwb_psum(x, axis_name):
  Args:
    x: a `Tensor` target for the psum.
    axis_name: A string axis name for the psum.
+
  Returns:
    A `Tensor` that is the result of applying a psum to an input `Tensor`.
  """
@@ -161,8 +231,8 @@ def make_pbroadcast_function(fn, in_axes, out_axes, out_dtype):
      value w.r.t. the input value will be psum-ed over the axes present in the
      output but not the input.
    out_axes: A structure of axis names that should match the structure of the
-       output of `fn`. The inputs to `fn` will be pbroadcast-ed before
-       computing output terms according to their output axes.
+       output of `fn`. The inputs to `fn` will be pbroadcast-ed before computing
+       output terms according to their output axes.
    out_dtype: A structure of dtypes that matches the output of `fn`.

  Returns:
@@ -176,9 +246,9 @@ def make_pbroadcast_function(fn, in_axes, out_axes, out_dtype):
  def pbroadcast_fn(*args):
    nest.assert_shallow_structure(args, in_axes)
    nest.assert_shallow_structure(out_dtype, out_axes)
-     map_in_axes = nest.map_structure_up_to(args, canonicalize_axis_name,
+     map_in_axes = nest.map_structure_up_to(args, canonicalize_named_axis,
                                             in_axes)
-     map_out_axes = nest.map_structure_up_to(out_dtype, canonicalize_axis_name,
+     map_out_axes = nest.map_structure_up_to(out_dtype, canonicalize_named_axis,
                                              out_axes)

    def _pbroadcast_input(out_axes, x, in_axes):
@@ -232,14 +302,14 @@ def make_psum_function(fn, in_axes, out_axes, out_dtype):
    function and corrects the gradient with respect to its inputs.
  """

-   out_axes = nest.map_structure_up_to(out_dtype, canonicalize_axis_name,
+   out_axes = nest.map_structure_up_to(out_dtype, canonicalize_named_axis,
                                        out_axes)

  def psum_fn(*args):
    out = make_pbroadcast_function(fn, in_axes, out_axes, out_dtype)(*args)

    def _psum_output(x, out_axis):
-       return psum(x, out_axis)
+       return psum(x, named_axis=out_axis)

    return nest.map_structure_up_to(out_dtype, _psum_output, out, out_axes)
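
# A rough usage sketch for `make_psum_function` (the `log_prob` function, the
# 'data' axis name, and the argument structure are hypothetical; dtypes and
# structure must match `fn`'s outputs): wrapping a per-shard log-prob so that
# its value and gradients are summed over a data-sharding axis.
#
#   def log_prob(w, x):      # w is replicated, x is sharded over 'data'
#     return tf.reduce_sum(-0.5 * (x - w) ** 2)
#
#   total_log_prob_fn = make_psum_function(
#       log_prob, in_axes=(None, 'data'), out_axes='data',
#       out_dtype=tf.float32)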