Commit 7555d90

sharadmv authored and tensorflower-gardener committed
Simplify sharded JDs by removing reduce_over_shards
PiperOrigin-RevId: 375595276
1 parent aa1f360 commit 7555d90

File tree

4 files changed: +61 additions, -51 deletions

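In short (a hedged summary of the diffs below; `sharded_dist` and its setup are illustrative, not taken from this commit): the `Sharded` log-prob methods and `log_prob_ratio` previously accepted a `reduce_over_shards` flag for deferring the cross-shard reduction; after this change the psum over the shard axes is always applied, and gradient correctness is handled by pbroadcasting inputs instead.

    # Before (illustrative): callers could defer the all-reduce.
    lp_part = sharded_dist.log_prob(x, reduce_over_shards=False)  # per-shard term
    lp = sharded_dist.log_prob(x)                                 # psum over shard axes

    # After: the flag is gone; log_prob always psums over the shard axes.
    lp = sharded_dist.log_prob(x)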

tensorflow_probability/python/experimental/distribute/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -21,13 +21,16 @@
 from tensorflow_probability.python.experimental.distribute.joint_distribution import JointDistributionNamed
 from tensorflow_probability.python.experimental.distribute.joint_distribution import JointDistributionSequential
 from tensorflow_probability.python.experimental.distribute.sharded import Sharded
+from tensorflow_probability.python.internal.distribute_lib import make_pbroadcast_function
 from tensorflow_probability.python.internal.distribute_lib import make_psum_function
 from tensorflow_probability.python.internal.distribute_lib import make_sharded_log_prob_parts
 
 __all__ = [
     'JointDistributionCoroutine',
     'JointDistributionNamed',
     'JointDistributionSequential',
+    'make_pbroadcast_function',
+    'make_psum_function',
     'make_sharded_log_prob_parts',
     'Sharded',
 ]
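`make_pbroadcast_function` is now exported alongside `make_psum_function`. Below is a minimal construction sketch following the call pattern used later in this commit; `part_lp`, the 'data' axis name, and the scalar dtype are assumptions for illustration, not part of the commit:

    import tensorflow.compat.v2 as tf
    from tensorflow_probability.python.internal import distribute_lib

    def part_lp(x):
      # Per-replica computation; any psum it performs gets correct gradients
      # because `x` is pbroadcast over the named axis first.
      return tf.reduce_sum(x)

    # Keyword names match the commit's own calls: in_axes, out_axes, out_dtype.
    lp_fn = distribute_lib.make_pbroadcast_function(
        part_lp, in_axes=('data',), out_axes='data', out_dtype=tf.float32)
    # Invoke `lp_fn` only inside a pmap / strategy replica context where the
    # 'data' axis is defined.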

tensorflow_probability/python/experimental/distribute/joint_distribution.py

Lines changed: 17 additions & 28 deletions
@@ -18,8 +18,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import functools
-
 import tensorflow.compat.v2 as tf
 from tensorflow_probability.python import distributions as distribution_lib
 from tensorflow_probability.python.distributions import log_prob_ratio as lp_ratio
@@ -39,30 +37,23 @@ def _map_measure_over_dists(self, attr, value):
     if (attr in ('log_prob', 'unnormalized_log_prob')) and any(
         self.experimental_shard_axis_names):
 
-      def inner_log_prob_parts(flat_value):
-        unflat_value = self._model_unflatten(flat_value)
+      def inner_log_prob_parts(value):
         ds, xs = self._call_flat_sample_distributions(
-            value=unflat_value, seed=samplers.zeros_seed())
-        # For sharded distributions, we need to make sure not to do an
-        # all-reduce.
-        axis_names = self._model_flatten(self.experimental_shard_axis_names)
-        log_prob_fns = [
-            functools.partial(getattr(d, attr), reduce_over_shards=False)
-            if axis_name else getattr(d, attr)
-            for d, axis_name in zip(ds, axis_names)
-        ]
+            value=value, seed=samplers.zeros_seed())
         # We need to flatten and unflatten here to ensure the output structure
         # matches `flat_sharded_distributions`.
-        vals = self._model_unflatten(
-            [log_prob_fn(x) for log_prob_fn, x in zip(log_prob_fns, xs)])
-        return self._model_flatten(vals)
-
-      flat_value = self._model_flatten(value)
-      flat_axis_names = self._model_flatten(self.experimental_shard_axis_names)
-      flat_xs = distribute_lib.make_sharded_log_prob_parts(
-          inner_log_prob_parts, flat_axis_names)(
-              flat_value)
-      return iter(flat_xs)
+        return self._model_unflatten(
+            [getattr(d, attr)(x) for d, x in zip(ds, xs)])
+
+      axis_names = self.experimental_shard_axis_names
+      # Individual distributions will apply psum in their `log_prob` methods
+      # so we need to pbroadcast `value` according to `axis_names` to provide
+      # correct gradients. We are safe to add pbroadcasts to functions with
+      # psums already in them.
+      log_prob_parts = distribute_lib.make_pbroadcast_function(
+          inner_log_prob_parts, (axis_names,), axis_names,
+          out_dtype=value)(value)
+      return iter(tf.nest.flatten(log_prob_parts))
     ds, xs = self._call_flat_sample_distributions(
         value=value, seed=samplers.zeros_seed())
     return (getattr(d, attr)(x) for d, x in zip(ds, xs))
@@ -104,16 +95,14 @@ def _dist_jd_log_prob_ratio(p, x, q, y, name=None):
   def log_prob_ratio_parts_fn(x, y):
     p_dists = p.sample_distributions(value=x, seed=samplers.zeros_seed())[0]
     q_dists = q.sample_distributions(value=y, seed=samplers.zeros_seed())[0]
-    # Ensure sharded distributions defer reductions.
-    kwds = lambda a: {'reduce_over_shards': False} if a else {}
     return nest.map_structure_up_to(
         p_dists,
-        lambda p, x, q, y, s: lp_ratio.log_prob_ratio(p, x, q, y, **kwds(s)),
-        p_dists, x, q_dists, y, p_axis_names)
+        lp_ratio.log_prob_ratio,
+        p_dists, x, q_dists, y)
 
   return tf.add_n(
       tf.nest.flatten(
-          distribute_lib.make_psum_function(
+          distribute_lib.make_pbroadcast_function(
              log_prob_ratio_parts_fn,
              in_axes=(p_axis_names, p_axis_names),
              out_axes=p_axis_names,
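A note on the new comment above (hedged, schematic notation only): each member distribution now psums its own log-prob part, so the forward values need no further reduction. However, under pmap the gradient of a psummed quantity with respect to an unsharded input must itself be summed across replicas. pbroadcast is an identity on forward values whose transpose is a psum, so wrapping `inner_log_prob_parts` in `make_pbroadcast_function` inserts exactly that reduction on the backward pass:

    forward:   lp = psum_over_shards(f(pbroadcast(w), x_shard))   # value unchanged
    backward:  d lp / d w = psum_over_shards(df / dw)             # from pbroadcast's transpose

As the comment says, adding pbroadcasts to a function that already contains psums is safe: the forward computation is unchanged.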

tensorflow_probability/python/experimental/distribute/joint_distribution_test.py

Lines changed: 33 additions & 5 deletions
@@ -155,6 +155,7 @@ def run(key):
 
     keys = tfp.random.split_seed(self.key, 2)
     samples = []
+    unmapped_samples = []
     log_probs = []
     true_log_probs = []
 
@@ -172,24 +173,51 @@ def run(key):
       true_log_prob = true_log_prob_fn(w, x, data)
 
       samples.append(sample)
+      unmapped_samples.append((w, x, data))
       log_probs.append(log_prob[0])
       true_log_probs.append(true_log_prob)
 
+    def true_diff(x, y):
+      return true_log_prob_fn(*x) - true_log_prob_fn(*y)
+
     def run_diff(x, y):
-      return tfp.experimental.distributions.log_prob_ratio(dist, x, dist, y)
+      def _lpr(x, y):
+        return tfp.experimental.distributions.log_prob_ratio(dist, x, dist, y)
+      return tfp.math.value_and_gradient(_lpr, [x, y])
 
-    dist_lp_diff = self.per_replica_to_tensor(
+    dist_lp_diff, dist_lp_diff_grad = self.per_replica_to_tensor(
         self.strategy_run(
            run_diff, tuple(tf.nest.map_structure(self.shard_values, samples))))
 
-    true_lp_diff = true_log_probs[0] - true_log_probs[1]
+    true_lp_diff, true_lp_diff_grad = tfp.math.value_and_gradient(
+        true_diff, unmapped_samples)
+
+    if isinstance(dist, jd.JointDistributionNamed):
+      dist_lp_diff_grad[0] = (
+          dist_lp_diff_grad[0]['w'][0],
+          dist_lp_diff_grad[0]['x'],
+          dist_lp_diff_grad[0]['data'])
+      dist_lp_diff_grad[1] = (
+          dist_lp_diff_grad[1]['w'][0],
+          dist_lp_diff_grad[1]['x'],
+          dist_lp_diff_grad[1]['data'])
+    else:
+      true_lp_diff_grad[0] = list(true_lp_diff_grad[0])
+      true_lp_diff_grad[1] = list(true_lp_diff_grad[1])
+      dist_lp_diff_grad[0] = list(dist_lp_diff_grad[0])
+      dist_lp_diff_grad[0][0] = dist_lp_diff_grad[0][0][0]
+      dist_lp_diff_grad[1] = list(dist_lp_diff_grad[1])
+      dist_lp_diff_grad[1][0] = dist_lp_diff_grad[1][0][0]
+
     lp_diff = log_probs[0] - log_probs[1]
 
     self.assertAllClose(
-        self.evaluate(true_lp_diff), self.evaluate(lp_diff),
+        true_lp_diff, lp_diff,
        rtol=7e-6)  # relaxed tol for fp32 in JAX
     self.assertAllClose(
-        self.evaluate(true_lp_diff), self.evaluate(dist_lp_diff[0]))
+        true_lp_diff, dist_lp_diff[0])
+    self.assertAllClose(
+        true_lp_diff_grad, dist_lp_diff_grad)
 
   def test_default_event_space_bijector_non_interacting(self):

tensorflow_probability/python/experimental/distribute/sharded.py

Lines changed: 8 additions & 18 deletions
@@ -35,16 +35,9 @@
 
 def _implement_sharded_lp_fn(fn_name):
   """Implements log_prob or unnormalized_log_prob."""
-  def lp_fn(self, x, reduce_over_shards=True, **kwargs):
-
-    new_kwargs = dict(kwargs)
-    if self.distribution.experimental_shard_axis_names:
-      new_kwargs['reduce_over_shards'] = reduce_over_shards
-    lp = getattr(self.distribution, fn_name)(x, **new_kwargs)
-    if reduce_over_shards:
-      lp = distribute_lib.psum(lp, self.experimental_shard_axis_names)
-
-    return lp
+  def lp_fn(self, x):
+    lp = getattr(self.distribution, fn_name)(x)
+    return distribute_lib.psum(lp, self.experimental_shard_axis_names)
 
   lp_fn.__name__ = f'_{fn_name}'
   return lp_fn
@@ -181,7 +174,7 @@ def _default_event_space_bijector(self, *args, **kwargs):
 
 
 @log_prob_ratio.RegisterLogProbRatio(Sharded)
-def _sharded_log_prob_ratio(p, x, q, y, name=None, reduce_over_shards=True):
+def _sharded_log_prob_ratio(p, x, q, y, name=None):
   """Distributed log-prob ratio for Sharded."""
   with tf.name_scope(name or 'sharded_log_prob_ratio'):
     if p.experimental_shard_axis_names != q.experimental_shard_axis_names:
@@ -194,10 +187,7 @@ def log_prob_ratio_fn(x, y):
       return log_prob_ratio.log_prob_ratio(p.distribution, x,
                                            q.distribution, y)
 
-    if reduce_over_shards:
-      axes = p.experimental_shard_axis_names
-
-      return distribute_lib.make_psum_function(
-          log_prob_ratio_fn, in_axes=(axes, axes), out_axes=axes,
-          out_dtype=x)(x, y)
-    return log_prob_ratio_fn(x, y)
+    axes = p.experimental_shard_axis_names
+    return distribute_lib.make_psum_function(
+        log_prob_ratio_fn, in_axes=(axes, axes), out_axes=axes,
+        out_dtype=x)(x, y)
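Net effect on `Sharded` (a hedged sketch; the `shard_axis_name` argument and the Normal model are assumptions for illustration): the log-prob methods are now unconditionally the psum of the inner distribution's log prob over the shard axes, with no `reduce_over_shards=False` escape hatch.

    import tensorflow_probability.substrates.jax as tfp
    tfd = tfp.distributions

    sharded = tfp.experimental.distribute.Sharded(
        tfd.Normal(0., 1.), shard_axis_name='data')  # assumed constructor args
    # Inside a pmap over 'data', sharded.log_prob(x) returns the same
    # replica-summed value on every device.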
