
Commit 6c29df8

sharadmv authored and tensorflower-gardener committed

Use pbroadcast while executing sharded JDs to ensure proper gradients for
both sample and log_prob.

PiperOrigin-RevId: 377414022

1 parent 5d1fd40 · commit 6c29df8
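
Why the pbroadcast is needed: a `Sharded` distribution psums its `log_prob` across devices, so any input that is not itself sharded over those axes must be pbroadcast for reverse-mode gradients to be reduced across devices as well. Below is a rough usage sketch of the behavior this commit fixes, modeled on the new tests; the `'data'` axis name and the surrounding pmap/strategy context are assumptions, not part of the commit:

```python
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfde = tfp.experimental.distribute
Root = tfd.JointDistributionCoroutine.Root

# A joint model whose final variable is sharded across devices. This is
# meant to run inside a distributed context (e.g. pmap or a TPU strategy)
# that defines the 'data' axis.
@tfde.JointDistributionCoroutine
def model():
  x = yield Root(tfd.Normal(0., 1.))
  y = yield tfd.Normal(x, 1.)
  yield tfde.Sharded(tfd.Normal(x + y, 1.), shard_axis_name='data')

# With this commit, gradients through both `model.sample` and
# `model.log_prob` match those of the equivalent unsharded model.
```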

3 files changed: +162 −50 lines


tensorflow_probability/python/distributions/joint_distribution.py

Lines changed: 6 additions & 4 deletions
```diff
@@ -753,6 +753,7 @@ def _execute_model(self,
                      sample_shape=(),
                      seed=None,
                      value=None,
+                     stop_index=None,
                      sample_and_trace_fn=trace_distributions_and_values):
     """Executes `model`, creating both samples and distributions."""
     values_out = []
@@ -833,6 +834,8 @@ def _execute_model(self,
         values_out.append(traced_values)
 
         index += 1
+        if stop_index is not None and index == stop_index:
+          break
         d = gen.send(next_value)
     except StopIteration:
       pass
@@ -1205,7 +1208,6 @@ def _jd_log_prob_ratio(p, x, q, y, name=None):
   ps, _ = p.sample_distributions(value=x, seed=samplers.zeros_seed())
   qs, _ = q.sample_distributions(value=y, seed=samplers.zeros_seed())
   tf.nest.assert_same_structure(ps, qs)
-  parts = []
-  for p_, x_, q_, y_ in zip(ps, x, qs, y):
-    parts.append(log_prob_ratio.log_prob_ratio(p_, x_, q_, y_))
-  return tf.add_n(parts)
+  log_prob_ratio_parts = nest.map_structure_up_to(
+      ps, log_prob_ratio.log_prob_ratio, ps, x, qs, y)
+  return tf.add_n(tf.nest.flatten(log_prob_ratio_parts))
```
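
In the last hunk, the explicit `zip` loop becomes a `nest.map_structure_up_to` call, which maps an n-ary function over parallel structures while recursing only as deep as a shallow reference structure (here `ps`), so the value structures may be richer than the distribution structure. A toy sketch using the public `tf.nest` API; the values are illustrative, not from the commit:

```python
import tensorflow.compat.v2 as tf

ps = [None, None]  # shallow reference structure: two parts
xs = [tf.constant(1.), tf.constant(2.)]
ys = [tf.constant(3.), tf.constant(5.)]

# Map a 4-ary function over the parallel structures, recursing only to the
# depth of `ps`, then reduce the flattened parts, as the new code does.
parts = tf.nest.map_structure_up_to(
    ps, lambda p, x, q, y: y - x, ps, xs, ps, ys)
total = tf.add_n(tf.nest.flatten(parts))  # 2.0 + 3.0 = 5.0
```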

tensorflow_probability/python/experimental/distribute/joint_distribution.py

Lines changed: 73 additions & 45 deletions
```diff
@@ -20,43 +20,86 @@
 
 import tensorflow.compat.v2 as tf
 from tensorflow_probability.python import distributions as distribution_lib
+from tensorflow_probability.python.distributions import joint_distribution as jd_lib
 from tensorflow_probability.python.distributions import log_prob_ratio as lp_ratio
 from tensorflow_probability.python.internal import distribute_lib
-from tensorflow_probability.python.internal import samplers
 
-from tensorflow.python.util import nest  # pylint: disable=g-direct-tensorflow-import
+
+def pbroadcast_value(value, value_axis_names, output_axis_names):
+  value_axis_names = distribute_lib.canonicalize_axis_name(value_axis_names)
+  pbroadcast_axes = [
+      axis_name for axis_name in output_axis_names
+      if axis_name not in value_axis_names
+  ]
+  return distribute_lib.pbroadcast(value, pbroadcast_axes)
+
+
+def _maybe_substitute_or_add_value_in_tuple(value_tuple, index, value):
+  if index > len(value_tuple):
+    raise ValueError('Cannot add value to tuple without available slot.')
+  if index == len(value_tuple):
+    return value_tuple + (value,)
+  curr_value = value_tuple[index]
+  if curr_value is not None:
+    return value_tuple
+  return value_tuple[:index] + (value,) + value_tuple[index + 1:]
 
 
 class JointDistributionDistributedMixin(object):
   """A JDMixin that shards the log_prob calculation."""
 
-  def _map_measure_over_dists(self, attr, value):
-    """Override the default implementation to shard its log_prob calculation."""
-    if any(x is None for x in tf.nest.flatten(value)):
-      raise ValueError('No `value` part can be `None`; saw: {}.'.format(value))
-    if (attr in ('log_prob', 'unnormalized_log_prob')) and any(
-        self.experimental_shard_axis_names):
-
-      def inner_log_prob_parts(value):
-        ds, xs = self._call_flat_sample_distributions(
-            value=value, seed=samplers.zeros_seed())
-        # We need to flatten and unflatten here to ensure the output structure
-        # matches `flat_sharded_distributions`.
-        return self._model_unflatten(
-            [getattr(d, attr)(x) for d, x in zip(ds, xs)])
-
-      axis_names = self.experimental_shard_axis_names
-      # Individual distributions will apply psum in their `log_prob` methods
-      # so we need to pbroadcast `value` according to `axis_names` to provide
-      # correct gradients. We are safe to add pbroadcasts to functions with
-      # psums already in them.
-      log_prob_parts = distribute_lib.make_pbroadcast_function(
-          inner_log_prob_parts, (axis_names,), axis_names,
-          out_dtype=value)(value)
-      return iter(tf.nest.flatten(log_prob_parts))
-    ds, xs = self._call_flat_sample_distributions(
-        value=value, seed=samplers.zeros_seed())
-    return (getattr(d, attr)(x) for d, x in zip(ds, xs))
+  def _call_execute_model(
+      self,
+      sample_shape=(),
+      seed=None,
+      value=None,
+      sample_and_trace_fn=jd_lib.trace_distributions_and_values):
+    return self._distribute_execute_model(
+        sample_shape=sample_shape,
+        seed=seed,
+        value=value if value is None else self._model_flatten(value),
+        sample_and_trace_fn=sample_and_trace_fn)
+
+  def _distribute_execute_model(
+      self,
+      sample_shape=(),
+      seed=None,
+      value=None,
+      sample_and_trace_fn=jd_lib.trace_distributions_and_values):
+    """Executes a model, adding `pbroadcast`s to ensure correct gradients."""
+    shard_axis_names = self._model_flatten(self.experimental_shard_axis_names)
+    final_values_out = []
+    if value is None:
+      value = ()
+
+    def sample_and_trace_value_fn(dist,
+                                  sample_shape,
+                                  seed,
+                                  value=None):
+      value, traced = sample_and_trace_fn(
+          dist=dist, sample_shape=sample_shape, seed=seed, value=value)
+      # We trace `next_value` here so we can pass it back in as part of
+      # `value` in the next iteration of the coroutine.
+      return value, (value, traced)
+
+    for output_index, output_axes in enumerate(shard_axis_names):
+      # We pbroadcast all values according to the difference between the
+      # current `output_axes` and their own active axes.
+      previous_shard_axes = shard_axis_names[:len(value)]
+      pbroadcasted_value = tuple(
+          pbroadcast_value(v, v_axis_names, output_axes)
+          for v, v_axis_names in zip(value, previous_shard_axes)
+      )
+      pbroadcasted_values, traced_values = zip(*super()._execute_model(
+          sample_shape=sample_shape,
+          seed=seed,
+          value=pbroadcasted_value + (None,),
+          stop_index=output_index + 1,
+          sample_and_trace_fn=sample_and_trace_value_fn))
+      value = _maybe_substitute_or_add_value_in_tuple(
+          value, output_index, pbroadcasted_values[output_index])
+      final_values_out.append(traced_values[output_index])
+    return final_values_out
 
 
 class JointDistributionSequential(JointDistributionDistributedMixin,
```
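
Two details of the new helpers, isolated as toy sketches (plain Python; `pbroadcast_axes` is a hypothetical stand-in for the list comprehension inside `pbroadcast_value`, and `_maybe_substitute_or_add_value_in_tuple` is the private helper defined above):

```python
# Axis-difference logic from `pbroadcast_value`: broadcast a value over
# exactly the output axes it is not already sharded on.
def pbroadcast_axes(value_axes, output_axes):
  return [a for a in output_axes if a not in value_axes]

assert pbroadcast_axes([], ['i']) == ['i']          # unsharded value: broadcast
assert pbroadcast_axes(['i'], ['i', 'j']) == ['j']  # partially aligned
assert pbroadcast_axes(['i'], ['i']) == []          # fully aligned: no-op

# Slot-filling semantics of `_maybe_substitute_or_add_value_in_tuple`:
# append when `index` is one past the end, fill a `None` slot, and never
# clobber a value the caller pinned.
assert _maybe_substitute_or_add_value_in_tuple((), 0, 'a') == ('a',)
assert _maybe_substitute_or_add_value_in_tuple(('a', None), 1, 'b') == ('a', 'b')
assert _maybe_substitute_or_add_value_in_tuple(('a',), 0, 'zzz') == ('a',)
```

The execution loop ties these together: on each pass it pbroadcasts every previously realized value to the current output's shard axes, re-runs the model one variable further (via the new `stop_index` argument to `_execute_model`), and splices the newly produced value back into `value` for the next pass.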
```diff
@@ -91,19 +134,4 @@ def _dist_jd_log_prob_ratio(p, x, q, y, name=None):
   if p_axis_names != q_axis_names:
     raise ValueError('p and q must use the same sharding. '
                      f'Saw: p: {p}, {p_axis_names}, q: {q}, {q_axis_names}')
-
-  def log_prob_ratio_parts_fn(x, y):
-    p_dists = p.sample_distributions(value=x, seed=samplers.zeros_seed())[0]
-    q_dists = q.sample_distributions(value=y, seed=samplers.zeros_seed())[0]
-    return nest.map_structure_up_to(
-        p_dists,
-        lp_ratio.log_prob_ratio,
-        p_dists, x, q_dists, y)
-
-  return tf.add_n(
-      tf.nest.flatten(
-          distribute_lib.make_pbroadcast_function(
-              log_prob_ratio_parts_fn,
-              in_axes=(p_axis_names, p_axis_names),
-              out_axes=p_axis_names,
-              out_dtype=x)(x, y)))
+  return jd_lib._jd_log_prob_ratio(p, x, q, y, name=name)  # pylint: disable=protected-access
```
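
The hand-rolled pbroadcast wrapper is deleted here, presumably because the pbroadcasts now happen during model execution itself, so the sharded case can delegate to the base `_jd_log_prob_ratio`. For reference, `log_prob_ratio(p, x, q, y)` has the semantics of the naive difference below; registered specializations may compute it more accurately or with the appropriate cross-device reductions (a sketch of the semantics, not TFP's dispatch code):

```python
def naive_log_prob_ratio(p, x, q, y):
  # Semantics of `log_prob_ratio`: log p(x) - log q(y).
  return p.log_prob(x) - q.log_prob(y)
```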

tensorflow_probability/python/experimental/distribute/joint_distribution_test.py

Lines changed: 83 additions & 1 deletion
```diff
@@ -30,6 +30,8 @@
 tfb = tfp.bijectors
 tfd = tfp.distributions
 
+Root = tfd.JointDistributionCoroutine.Root
+
 
 def true_log_prob_fn(w, x, data):
   return (tfd.Normal(0., 1.).log_prob(w) +
@@ -65,7 +67,7 @@ def make_jd_named(axis_name):
 def make_jd_coroutine(axis_name):
 
   def model_coroutine():
-    w = yield tfd.JointDistributionCoroutine.Root(tfd.Normal(0., 1.))
+    w = yield Root(tfd.Normal(0., 1.))
     x = yield sharded.Sharded(
         tfd.Sample(tfd.Normal(w, 1.), 1), shard_axis_name=axis_name)
     yield sharded.Sharded(
@@ -219,6 +221,86 @@ def _lpr(x, y):
     self.assertAllClose(
         true_lp_diff_grad, dist_lp_diff_grad)
 
+  def test_jd_has_correct_sample_path_gradients(self):
+
+    def log_prob_fn(x_loc):
+      @tfd.JointDistributionCoroutine
+      def surrogate():
+        x = yield Root(tfd.Normal(x_loc, 1.))
+        y = yield tfd.Normal(x, 1.)
+        yield tfd.Sample(tfd.Normal(x + y, 1.), test_lib.NUM_DEVICES)
+
+      @tfd.JointDistributionCoroutine
+      def model():
+        yield Root(tfd.Normal(1., 1.))
+        yield Root(tfd.Normal(1., 1.))
+        yield tfd.Sample(tfd.Normal(1., 1.), test_lib.NUM_DEVICES)
+      return tf.reduce_mean(
+          model.log_prob(surrogate.sample(sample_shape=1e6, seed=self.key)))
+
+    true_log_prob, true_log_prob_grad = tfp.math.value_and_gradient(
+        log_prob_fn, 0.)
+
+    def run(seed):
+      def sharded_log_prob_fn(x_loc):
+        @jd.JointDistributionCoroutine
+        def surrogate():
+          x = yield Root(tfd.Normal(x_loc, 1.))
+          y = yield tfd.Normal(x, 1.)
+          yield sharded.Sharded(tfd.Normal(x + y, 1.), self.axis_name)
+
+        @jd.JointDistributionCoroutine
+        def model():
+          yield Root(tfd.Normal(1., 1.))
+          yield Root(tfd.Normal(1., 1.))
+          yield sharded.Sharded(tfd.Normal(1., 1.), self.axis_name)
+        return tf.reduce_mean(
+            model.log_prob(surrogate.sample(sample_shape=1e6, seed=seed)))
+      sharded_log_prob, sharded_log_prob_grad = tfp.math.value_and_gradient(
+          sharded_log_prob_fn, 0.)
+      return sharded_log_prob, sharded_log_prob_grad
+
+    sharded_log_prob, sharded_log_prob_grad = self.per_replica_to_tensor(
+        self.strategy_run(
+            run, (self.key,), in_axes=None))
+    for i in range(test_lib.NUM_DEVICES):
+      self.assertAllClose(sharded_log_prob[i], true_log_prob, atol=1e-2)
+      self.assertAllClose(sharded_log_prob_grad[i], true_log_prob_grad,
+                          atol=1e-2)
+
+  def test_jd_has_correct_sample_path_gradients_with_partial_values(self):
+
+    def run(seed):
+      @jd.JointDistributionCoroutine
+      def model():
+        yield Root(tfd.Normal(0., 1., name='x'))
+        yield tfd.Normal(0., 1., name='y')
+        yield sharded.Sharded(tfd.Normal(1., 1.), self.axis_name, name='z')
+
+      sample = model.sample(seed=seed)
+
+      def lp_fn1(x, y, z):
+        return model.log_prob((x, y, z))
+
+      def lp_fn2(x, z):
+        return model.log_prob(model.sample(value=(x, None, z), seed=seed))
+
+      lp_and_grad1 = tfp.math.value_and_gradient(
+          lp_fn1, [*sample])
+      (lp2, grad2) = tfp.math.value_and_gradient(
+          lp_fn2, [sample.x, sample.z])
+      return lp_and_grad1, (lp2, grad2)
+
+    (lp1, grad1), (lp2, grad2) = self.per_replica_to_tensor(
+        self.strategy_run(
+            run, (self.key,), in_axes=None))
+    grad2 = [grad2[0], None, grad2[1]]
+    for i in range(test_lib.NUM_DEVICES):
+      for j in range(3):
+        self.assertAllClose(lp1[i], lp2[i])
+        if grad2[j] is not None:
+          self.assertAllClose(grad1[j][i], grad2[j][i])
+
   def test_default_event_space_bijector_non_interacting(self):
 
     root = jd.JointDistributionCoroutine.Root
```
