Skip to content

Commit 40d9757

Browse files
SiegeLordExtensorflower-gardener
authored andcommitted
Allow partially specifying values for nested joint distributions.
The issue was twofold. First, the various sample_and_value functions were over-eager in bypassing the sample method. This was fixed via a cludge-like change, it kind of feels like the `value` kwarg should be moved up to Distribution so we can just call sample unconditionally. The second issue was two places where we used map_structure to convert value to a nest of Tensors. This isn't a reasonable thing to do when there are None's present. This was replaced with a recursive call of a utility implemented via _model_flatten/unflatten etc. PiperOrigin-RevId: 377168479
1 parent 849e403 commit 40d9757

File tree

4 files changed

+202
-42
lines changed

4 files changed

+202
-42
lines changed

tensorflow_probability/python/distributions/joint_distribution.py

Lines changed: 64 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -81,27 +81,36 @@ def trace_distributions_and_values(dist, sample_shape, seed, value=None):
8181
"""Draws a sample, and traces both the distribution and sampled value."""
8282
if value is None:
8383
value = dist.sample(sample_shape, seed=seed)
84+
elif tf.nest.is_nested(dist.dtype) and any(
85+
v is None for v in tf.nest.flatten(value)):
86+
# TODO(siege): This is making an assumption that nested dtype => partial
87+
# value support, which is not necessarily reasonable.
88+
value = dist.sample(sample_shape, seed=seed, value=value)
8489
return ValueWithTrace(value=value, traced=(dist, value))
8590

8691

8792
def trace_distributions_only(dist, sample_shape, seed, value=None):
8893
"""Draws a sample, and traces the sampled value."""
89-
if value is None:
90-
value = dist.sample(sample_shape, seed=seed)
91-
return ValueWithTrace(value=value, traced=dist)
94+
ret = trace_distributions_and_values(dist, sample_shape, seed, value)
95+
return ret._replace(traced=ret.traced[0])
9296

9397

9498
def trace_values_only(dist, sample_shape, seed, value=None):
9599
"""Draws a sample, and traces the sampled value."""
96-
if value is None:
97-
value = dist.sample(sample_shape, seed=seed)
98-
return ValueWithTrace(value=value, traced=value)
100+
ret = trace_distributions_and_values(dist, sample_shape, seed, value)
101+
return ret._replace(traced=ret.traced[1])
99102

100103

101104
def trace_values_and_log_probs(dist, sample_shape, seed, value=None):
102105
"""Draws a sample, and traces both the sampled value and its log density."""
103106
if value is None:
104107
value, lp = dist.experimental_sample_and_log_prob(sample_shape, seed=seed)
108+
elif tf.nest.is_nested(dist.dtype) and any(
109+
v is None for v in tf.nest.flatten(value)):
110+
# TODO(siege): This is making an assumption that nested dtype => partial
111+
# value support, which is not necessarily reasonable.
112+
value, lp = dist.experimental_sample_and_log_prob(
113+
sample_shape, seed=seed, value=value)
105114
else:
106115
lp = dist.log_prob(value)
107116
return ValueWithTrace(value=value, traced=(value, lp))
@@ -210,7 +219,9 @@ class JointDistribution(distribution_lib.Distribution):
210219
- `_model_coroutine`: A generator that yields a sequence of
211220
`tfd.Distribution`-like instances.
212221
213-
- `_model_flatten`: takes a structured input and returns a sequence.
222+
- `_model_flatten`: takes a structured input and returns a sequence. The
223+
sequence order must match the order distributions are yielded from
224+
`_model_coroutine`.
214225
215226
- `_model_unflatten`: takes a sequence and returns a structure matching the
216227
semantics of the `JointDistribution` subclass.
@@ -613,33 +624,14 @@ def _map_attr_over_dists(self, attr, dists=None):
613624
if dists is None else dists)
614625
return (getattr(d, attr)() for d in dists)
615626

616-
def _sanitize_value(self, value):
617-
"""Ensures `value` matches `self.dtype` with `Tensor` or `None` elements."""
618-
if value is None:
619-
return value
620-
621-
if len(value) < len(self.dtype):
622-
# Fill in missing entries with `None`.
623-
if hasattr(self.dtype, 'keys'):
624-
value = {k: value.get(k, None) for k in self.dtype.keys()}
625-
else: # dtype is a sequence.
626-
value = [value[i] if i < len(value) else None
627-
for i in range(len(self.dtype))]
628-
629-
value = nest_util.cast_structure(value, self.dtype)
630-
return nest.map_structure_up_to(
631-
self.dtype,
632-
lambda x, d: x if x is None else tf.convert_to_tensor(x, dtype_hint=d),
633-
value, self.dtype)
634-
635627
def _resolve_value(self, *args, allow_partially_specified=False, **kwargs):
636628
"""Resolves a `value` structure from user-passed arguments."""
637629
value = kwargs.pop('value', None)
638630
if not (args or kwargs):
639-
# Fast path when `value` is the only kwarg. The case where `value` is
640-
# passed as a positional arg is handled by `_resolve_value_from_args`
641-
# below.
642-
return self._sanitize_value(value)
631+
# Fast path when `value` is the only kwarg. The case where `value` is
632+
# passed as a positional arg is handled by `_resolve_value_from_args`
633+
# below.
634+
return _sanitize_value(self, value)
643635
elif value is not None:
644636
raise ValueError('Supplied both `value` and keyword '
645637
'arguments to parameterize sampling. Supplied keyword '
@@ -665,7 +657,7 @@ def _resolve_value(self, *args, allow_partially_specified=False, **kwargs):
665657
'Found unexpected keyword arguments. Distribution names '
666658
'are\n{}\nbut received\n{}\nThese names were '
667659
'invalid:\n{}'.format(dist_name_str, kwarg_names, unmatched_str))
668-
return self._sanitize_value(value)
660+
return _sanitize_value(self, value)
669661

670662
def _call_execute_model(self,
671663
sample_shape=(),
@@ -793,17 +785,7 @@ def _execute_model(self,
793785
value_at_index = None
794786
if (value is not None and len(value) > index and
795787
value[index] is not None):
796-
797-
def convert_tree_to_tensor(x, dtype_hint):
798-
return tf.convert_to_tensor(x, dtype_hint=dtype_hint)
799-
800-
# This signature does not allow kwarg names. Applies
801-
# `convert_to_tensor` on the next value.
802-
value_at_index = nest.map_structure_up_to(
803-
actual_distribution.dtype, # shallow_tree
804-
convert_tree_to_tensor, # func
805-
value[index], # x
806-
actual_distribution.dtype) # dtype_hint
788+
value_at_index = _sanitize_value(actual_distribution, value[index])
807789
try:
808790
next_value, traced_values = sample_and_trace_fn(
809791
actual_distribution,
@@ -1175,6 +1157,46 @@ def _inverse_log_det_jacobian(self, y, event_ndims, **kwargs):
11751157
y, event_ndims, _jd_conditioning=y, **kwargs)
11761158

11771159

1160+
def _sanitize_value(distribution, value):
1161+
"""Ensures `value` matches `distribution.dtype`, adding `None`s as needed."""
1162+
if value is None:
1163+
return value
1164+
1165+
if not tf.nest.is_nested(distribution.dtype):
1166+
return tf.convert_to_tensor(value, dtype_hint=distribution.dtype)
1167+
1168+
if len(value) < len(distribution.dtype):
1169+
# Fill in missing entries with `None`.
1170+
if hasattr(distribution.dtype, 'keys'):
1171+
value = {k: value.get(k, None) for k in distribution.dtype.keys()}
1172+
else: # dtype is a sequence.
1173+
value = [value[i] if i < len(value) else None
1174+
for i in range(len(distribution.dtype))]
1175+
1176+
value = nest_util.cast_structure(value, distribution.dtype)
1177+
jdlike_attrs = [
1178+
'_get_single_sample_distributions',
1179+
'_model_flatten',
1180+
'_model_unflatten',
1181+
]
1182+
if all(hasattr(distribution, attr) for attr in jdlike_attrs):
1183+
flat_dists = distribution._get_single_sample_distributions()
1184+
flat_value = distribution._model_flatten(value)
1185+
flat_value = map(_sanitize_value, flat_dists, flat_value)
1186+
return distribution._model_unflatten(flat_value)
1187+
else:
1188+
# A joint distribution that isn't tfd.JointDistribution-like; assume it has
1189+
# some reasonable dtype semantics. We can't use this for
1190+
# tfd.JointDistribution because we might have a None standing in for a
1191+
# sub-tree (e.g. consider omitting a nested JD).
1192+
return nest.map_structure_up_to(
1193+
distribution.dtype,
1194+
lambda x, d: x if x is None else tf.convert_to_tensor(x, dtype_hint=d),
1195+
value,
1196+
distribution.dtype,
1197+
)
1198+
1199+
11781200
@log_prob_ratio.RegisterLogProbRatio(JointDistribution)
11791201
def _jd_log_prob_ratio(p, x, q, y, name=None):
11801202
"""Implements `log_prob_ratio` for tfd.JointDistribution*."""

tensorflow_probability/python/distributions/joint_distribution_coroutine_test.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -980,6 +980,49 @@ def _get_support_bijectors(dists, xs=None, ys=None):
980980
self.evaluate(bijectors[i].inverse_event_shape_tensor(
981981
event_shapes[i])))
982982

983+
@parameterized.named_parameters(
984+
('_sample', lambda d, **kwargs: d.sample(**kwargs)),
985+
('_sample_and_log_prob',
986+
lambda d, **kwargs: d.experimental_sample_and_log_prob(**kwargs)[0]),
987+
)
988+
def test_nested_partial_value(self, sample_fn):
989+
@tfd.JointDistributionCoroutine
990+
def innermost():
991+
a = yield Root(tfd.Exponential(1., name='a'))
992+
yield tfd.Sample(tfd.LogNormal(a, a), [5], name='b')
993+
994+
@tfd.JointDistributionCoroutine
995+
def inner():
996+
yield Root(tfd.Exponential(1., name='c'))
997+
yield Root(innermost.copy(name='d'))
998+
999+
@tfd.JointDistributionCoroutine
1000+
def outer():
1001+
yield Root(tfd.Exponential(1., name='e'))
1002+
yield Root(inner.copy(name='f'))
1003+
1004+
seed = test_util.test_seed(sampler_type='stateless')
1005+
true_xs = outer.sample(seed=seed)
1006+
1007+
# These asserts work because we advance the stateless seed inside the model
1008+
# whether or not a sample is actually generated.
1009+
partial_xs = true_xs._replace(f=None)
1010+
xs = sample_fn(outer, value=partial_xs, seed=seed)
1011+
self.assertAllCloseNested(true_xs, xs)
1012+
1013+
partial_xs = true_xs._replace(e=None)
1014+
xs = sample_fn(outer, value=partial_xs, seed=seed)
1015+
self.assertAllCloseNested(true_xs, xs)
1016+
1017+
partial_xs = true_xs._replace(f=true_xs.f._replace(d=None))
1018+
xs = sample_fn(outer, value=partial_xs, seed=seed)
1019+
self.assertAllCloseNested(true_xs, xs)
1020+
1021+
partial_xs = true_xs._replace(
1022+
f=true_xs.f._replace(d=true_xs.f.d._replace(a=None)))
1023+
xs = sample_fn(outer, value=partial_xs, seed=seed)
1024+
self.assertAllCloseNested(true_xs, xs)
1025+
9831026
def test_default_event_space_bijector_nested(self):
9841027
@tfd.JointDistributionCoroutine
9851028
def inner():

tensorflow_probability/python/distributions/joint_distribution_named_test.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,53 @@ def test_can_call_log_prob_with_kwargs(self):
270270
"can't take positional args"):
271271
lp_kwargs = d.log_prob(e, a, x)
272272

273+
@parameterized.named_parameters(
274+
('_sample', lambda d, **kwargs: d.sample(**kwargs)),
275+
('_sample_and_log_prob',
276+
lambda d, **kwargs: d.experimental_sample_and_log_prob(**kwargs)[0]),
277+
)
278+
def test_nested_partial_value(self, sample_fn):
279+
innermost = tfd.JointDistributionNamed({
280+
'a': tfd.Exponential(1.),
281+
'b': lambda a: tfd.Sample(tfd.LogNormal(a, a), [5]),
282+
})
283+
284+
inner = tfd.JointDistributionNamed({
285+
'c': tfd.Exponential(1.),
286+
'd': innermost,
287+
})
288+
289+
outer = tfd.JointDistributionNamed({
290+
'e': tfd.Exponential(1.),
291+
'f': inner,
292+
})
293+
294+
seed = test_util.test_seed(sampler_type='stateless')
295+
true_xs = outer.sample(seed=seed)
296+
297+
def _update(dict_, **kwargs):
298+
dict_.copy().update(**kwargs)
299+
return dict_
300+
301+
# These asserts work because we advance the stateless seed inside the model
302+
# whether or not a sample is actually generated.
303+
partial_xs = _update(true_xs, f=None)
304+
xs = sample_fn(outer, value=partial_xs, seed=seed)
305+
self.assertAllCloseNested(true_xs, xs)
306+
307+
partial_xs = _update(true_xs, e=None)
308+
xs = sample_fn(outer, value=partial_xs, seed=seed)
309+
self.assertAllCloseNested(true_xs, xs)
310+
311+
partial_xs = _update(true_xs, f=_update(true_xs['f'], d=None))
312+
xs = sample_fn(outer, value=partial_xs, seed=seed)
313+
self.assertAllCloseNested(true_xs, xs)
314+
315+
partial_xs = _update(
316+
true_xs, f=_update(true_xs['f'], d=_update(true_xs['f']['d'], a=None)))
317+
xs = sample_fn(outer, value=partial_xs, seed=seed)
318+
self.assertAllCloseNested(true_xs, xs)
319+
273320
@parameterized.named_parameters(
274321
('basic', basic_ordered_model_fn),
275322
('nested_lists', nested_lists_model_fn))

tensorflow_probability/python/distributions/joint_distribution_sequential_test.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,54 @@ def test_dist_fn_takes_varargs(self):
326326
lp = dist.log_prob(dist.sample(5, seed=test_util.test_seed()))
327327
self.assertAllEqual(lp.shape, [5])
328328

329+
@parameterized.named_parameters(
330+
('_sample', lambda d, **kwargs: d.sample(**kwargs)),
331+
('_sample_and_log_prob',
332+
lambda d, **kwargs: d.experimental_sample_and_log_prob(**kwargs)[0]),
333+
)
334+
def test_nested_partial_value(self, sample_fn):
335+
innermost = tfd.JointDistributionSequential((
336+
tfd.Exponential(1.),
337+
lambda a: tfd.Sample(tfd.LogNormal(a, a), [5]),
338+
))
339+
340+
inner = tfd.JointDistributionSequential((
341+
tfd.Exponential(1.),
342+
innermost,
343+
))
344+
345+
outer = tfd.JointDistributionSequential((
346+
tfd.Exponential(1.),
347+
inner,
348+
))
349+
350+
seed = test_util.test_seed(sampler_type='stateless')
351+
true_xs = outer.sample(seed=seed)
352+
353+
def _update(tuple_, index, value):
354+
res = list(tuple_)
355+
res[index] = value
356+
return tuple(res)
357+
358+
# These asserts work because we advance the stateless seed inside the model
359+
# whether or not a sample is actually generated.
360+
partial_xs = _update(true_xs, 1, None)
361+
xs = sample_fn(outer, value=partial_xs, seed=seed)
362+
self.assertAllCloseNested(true_xs, xs)
363+
364+
partial_xs = _update(true_xs, 0, None)
365+
xs = sample_fn(outer, value=partial_xs, seed=seed)
366+
self.assertAllCloseNested(true_xs, xs)
367+
368+
partial_xs = _update(true_xs, 1, _update(true_xs[1], 1, None))
369+
xs = sample_fn(outer, value=partial_xs, seed=seed)
370+
self.assertAllCloseNested(true_xs, xs)
371+
372+
partial_xs = _update(
373+
true_xs, 1, _update(true_xs[1], 1, _update(true_xs[1][1], 0, None)))
374+
xs = sample_fn(outer, value=partial_xs, seed=seed)
375+
self.assertAllCloseNested(true_xs, xs)
376+
329377
@parameterized.named_parameters(
330378
('basic', basic_model_fn),
331379
('nested_lists', nested_lists_model_fn))

0 commit comments

Comments
 (0)