
Commit 7d28f9b

SiegeLordEx authored and tensorflower-gardener committed
Make interacting default event space bijectors work for sharded JDs in simple cases.

The cases where it should work are those where all RVs are tensor-valued. This excludes nested JDs, as well as JDs with non-JD multipart components (which we can't express anyway).

A nontrivial change here is that `_sanitize_value` was moved into the sample_and_trace functions: using `_execute_model` inside the bijector, as done here, opens up the possibility that elements of the `value` arg to `_execute_model` have a different structure than the actual distribution. This feels like a good thing anyway if we interpret `_execute_model` as a primitive effect system.

Ideally I'd make the base JDs use the new `_conditioned_bijectors` function, but doing so interacted poorly with autobatched JDs for reasons not yet explored.

PiperOrigin-RevId: 379403363
1 parent fe2d35e commit 7d28f9b
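To ground the commit message, here is a minimal sketch, not taken from this commit, of the simple case it targets: a sharded JD whose RVs are all tensor-valued, now usable with the default event space bijector. The model, RV names, and shard axis below are illustrative; `Sharded` parts must actually execute under a matching named axis (e.g. inside `jax.pmap` or a `tf.distribute` strategy).

```python
# A hypothetical sharded model; names and axis setup are assumptions.
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfde = tfp.experimental.distribute
root = tfd.JointDistributionCoroutine.Root


@tfde.JointDistributionCoroutine
def sharded_model():
  x = yield root(tfd.LogNormal(0., 1., name='x'))
  yield tfde.Sharded(
      tfd.Uniform(0., x), shard_axis_name='data', name='y')

# After this commit, the default event space bijector works for such JDs:
bij = sharded_model.experimental_default_event_space_bijector()
# Pulling a sample back to unconstrained space (needs an active shard axis):
# unconstrained = bij.inverse(sharded_model.sample(seed=some_seed))
```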

File tree: 4 files changed (+36, -38 lines)


tensorflow_probability/python/distributions/joint_distribution.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -79,6 +79,7 @@ class ValueWithTrace(collections.namedtuple(
 
 def trace_distributions_and_values(dist, sample_shape, seed, value=None):
   """Draws a sample, and traces both the distribution and sampled value."""
+  value = _sanitize_value(dist, value)
   if value is None:
     value = dist.sample(sample_shape, seed=seed)
   elif tf.nest.is_nested(dist.dtype) and any(
@@ -103,6 +104,7 @@ def trace_values_only(dist, sample_shape, seed, value=None):
 
 def trace_values_and_log_probs(dist, sample_shape, seed, value=None):
   """Draws a sample, and traces both the sampled value and its log density."""
+  value = _sanitize_value(dist, value)
   if value is None:
     value, lp = dist.experimental_sample_and_log_prob(sample_shape, seed=seed)
   elif tf.nest.is_nested(dist.dtype) and any(
@@ -786,7 +788,7 @@ def _execute_model(self,
       value_at_index = None
       if (value is not None and len(value) > index and
           value[index] is not None):
-        value_at_index = _sanitize_value(actual_distribution, value[index])
+        value_at_index = value[index]
       try:
         next_value, traced_values = sample_and_trace_fn(
            actual_distribution,
```
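The `_execute_model` hunk is the flip side of moving `_sanitize_value`: the driver now passes user-supplied values through untouched, and each sample-and-trace hook sanitizes its own input. Below is a toy sketch of the "primitive effect system" reading of `_execute_model` mentioned in the commit message; all names and signatures are illustrative simplifications, not TFP's actual implementation.

```python
import collections

# Simplified stand-in for jd_lib.ValueWithTrace.
ValueWithTrace = collections.namedtuple('ValueWithTrace', ['value', 'traced'])


def execute_model(model_fns, sample_and_trace_fn, value=None):
  """Replays the model one distribution at a time, threading values through."""
  values, traces = [], []
  for index, make_dist in enumerate(model_fns):
    dist = make_dist(values)  # later RVs may depend on earlier values
    value_at_index = None
    if (value is not None and len(value) > index and
        value[index] is not None):
      # After this commit, `value[index]` is passed through untouched; any
      # per-distribution sanitization is the hook's job, not the driver's.
      value_at_index = value[index]
    result = sample_and_trace_fn(dist, value=value_at_index)
    values.append(result.value)
    traces.append(result.traced)
  return traces
```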

tensorflow_probability/python/experimental/distribute/BUILD

Lines changed: 2 additions & 1 deletion
```diff
@@ -77,9 +77,11 @@ multi_substrate_py_library(
     deps = [
         ":sharded",
         # tensorflow dep,
+        "//tensorflow_probability/python/bijectors:identity",
         "//tensorflow_probability/python/distributions",
         "//tensorflow_probability/python/distributions:log_prob_ratio",
         "//tensorflow_probability/python/internal:distribute_lib",
+        "//tensorflow_probability/python/internal:samplers",
     ],
 )
 
@@ -112,7 +114,6 @@ multi_substrate_py_test(
         # absl/testing:parameterized dep,
         # tensorflow dep,
         "//tensorflow_probability",
-        "//tensorflow_probability/python/internal:distribute_lib",
         "//tensorflow_probability/python/internal:distribute_test_lib",
         "//tensorflow_probability/python/internal:samplers",
         "//tensorflow_probability/python/internal:test_util",
```

tensorflow_probability/python/experimental/distribute/joint_distribution.py

Lines changed: 30 additions & 0 deletions
```diff
@@ -20,9 +20,11 @@
 
 import tensorflow.compat.v2 as tf
 from tensorflow_probability.python import distributions as distribution_lib
+from tensorflow_probability.python.bijectors import identity as identity_bijector
 from tensorflow_probability.python.distributions import joint_distribution as jd_lib
 from tensorflow_probability.python.distributions import log_prob_ratio as lp_ratio
 from tensorflow_probability.python.internal import distribute_lib
+from tensorflow_probability.python.internal import samplers
 
 
 def pbroadcast_value(value, value_axis_names, output_axis_names):
@@ -101,6 +103,11 @@ def sample_and_trace_value_fn(dist,
         final_values_out.append(traced_values[output_index])
       return final_values_out
 
+  def _default_event_space_bijector(self, *args, **kwargs):
+    if args or kwargs:
+      return _DefaultJointBijector(self.experimental_pin(*args, **kwargs))
+    return _DefaultJointBijector(self)
+
 
 class JointDistributionSequential(JointDistributionDistributedMixin,
                                   distribution_lib.JointDistributionSequential):
@@ -135,3 +142,26 @@ def _dist_jd_log_prob_ratio(p, x, q, y, name=None):
     raise ValueError('p and q must use the same sharding. '
                      f'Saw: p: {p}, {p_axis_names}, q: {q}, {q_axis_names}')
   return jd_lib._jd_log_prob_ratio(p, x, q, y, name=name)  # pylint: disable=protected-access
+
+
+class _DefaultJointBijector(jd_lib._DefaultJointBijector):  # pylint: disable=protected-access
+  """Sharding-compatible event space bijector for JDs."""
+
+  def _conditioned_bijectors(self, samples, constrained=False):
+    if samples is None:
+      return self.bijectors
+
+    def sample_and_trace_fn(dist, value, **_):
+      bij = self._bijector_fn(dist)
+      if bij is None:
+        bij = identity_bijector.Identity()
+
+      # If the RV is not yet constrained, transform it.
+      value = value if constrained else bij.forward(value)
+      return jd_lib.ValueWithTrace(value=value, traced=bij)
+
+    return self._jd._call_execute_model(  # pylint: disable=protected-access
+        sample_shape=(),
+        value=samples,
+        seed=samplers.zeros_seed(),
+        sample_and_trace_fn=sample_and_trace_fn)
```
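Hypothetical usage of the `_default_event_space_bijector` override added above, assuming `model` is a sharded JD with an RV named `x` (like the sketch near the top of this page). With no arguments the bijector spans every RV; keyword pins route through `experimental_pin`, so the resulting bijector covers only the unpinned RVs.

```python
# `model` and the RV name `x` are assumptions for illustration.
bij_all = model.experimental_default_event_space_bijector()
bij_given_x = model.experimental_default_event_space_bijector(x=1.)
```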

tensorflow_probability/python/experimental/distribute/joint_distribution_test.py

Lines changed: 1 addition & 36 deletions
```diff
@@ -23,7 +23,6 @@
 import tensorflow_probability as tfp
 from tensorflow_probability.python.experimental.distribute import joint_distribution as jd
 from tensorflow_probability.python.experimental.distribute import sharded
-from tensorflow_probability.python.internal import distribute_lib
 from tensorflow_probability.python.internal import distribute_test_lib as test_lib
 from tensorflow_probability.python.internal import test_util
 
@@ -381,19 +380,6 @@ def sharded_model():
           shard_axis_name=self.axis_name,
           name='z')
 
-    @tfd.JointDistributionCoroutine
-    def manual_sharded_model():
-      # This one has manual pbroadcasts; the goal is to get sharded_model above
-      # to do this automatically.
-      x = yield root(tfd.LogNormal(0., 1., name='x'))
-      x = distribute_lib.pbroadcast(x, axis_name=self.axis_name)
-      yield sharded.Sharded(
-          tfd.Uniform(0., x), shard_axis_name=self.axis_name, name='y')
-      yield sharded.Sharded(
-          tfb.Scale(x)(tfd.Normal(0., 1.)),
-          shard_axis_name=self.axis_name,
-          name='z')
-
     sample = model.sample(seed=self.key)
     unconstrained_sample = (
         model.experimental_default_event_space_bijector().inverse(sample))
@@ -416,12 +402,6 @@ def run(unconstrained_sample):
           lambda unconstrained_sample: unconstrained_lp(  # pylint: disable=g-long-lambda
               sharded_model, unconstrained_sample), (unconstrained_sample,))
 
-    def manual_run(unconstrained_sample):
-      return tfp.math.value_and_gradient(
-          lambda unconstrained_sample: unconstrained_lp(  # pylint: disable=g-long-lambda
-              manual_sharded_model, unconstrained_sample),
-          (unconstrained_sample,))
-
     sharded_unconstrained_sample = unconstrained_sample._replace(
         y=self.shard_values(unconstrained_sample.y),
         z=self.shard_values(unconstrained_sample.z))
@@ -433,23 +413,8 @@ def manual_run(unconstrained_sample):
     lp = lp[0]
     g = g._replace(x=g.x[0])
 
-    manual_lp, (manual_g,) = self.per_replica_to_tensor(
-        self.strategy_run(
-            manual_run, (sharded_unconstrained_sample,),
-            in_axes=(model.dtype._replace(x=None, y=0, z=0),)))
-    manual_lp = manual_lp[0]
-    manual_g = manual_g._replace(x=manual_g.x[0])
-
     self.assertAllClose(true_lp, lp)
-    # TODO(b/175084455): This will fail because there are sharded <->
-    # non-sharded edges in the gradient graph not accounted for. The edges arise
-    # because the sharded bijectors' parameterizations depend non-sharded
-    # parameters.
-    with self.assertRaises(AssertionError):
-      self.assertAllCloseNested(true_g, g)
-
-    self.assertAllClose(true_lp, manual_lp)
-    self.assertAllCloseNested(true_g, manual_g)
+    self.assertAllCloseNested(true_g, g)
 
 if __name__ == '__main__':
   tf.test.main()
```
