 
 import tensorflow.compat.v2 as tf
 from tensorflow_probability.python import distributions as distribution_lib
+from tensorflow_probability.python.distributions import joint_distribution as jd_lib
 from tensorflow_probability.python.distributions import log_prob_ratio as lp_ratio
 from tensorflow_probability.python.internal import distribute_lib
-from tensorflow_probability.python.internal import samplers
 
-from tensorflow.python.util import nest  # pylint: disable=g-direct-tensorflow-import
+
+def pbroadcast_value(value, value_axis_names, output_axis_names):
+  """Pbroadcasts `value` over any output axes it is not already active on."""
+  value_axis_names = distribute_lib.canonicalize_axis_name(value_axis_names)
+  pbroadcast_axes = [
+      axis_name for axis_name in output_axis_names
+      if axis_name not in value_axis_names
+  ]
+  return distribute_lib.pbroadcast(value, pbroadcast_axes)
+
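The axis selection above is just a set difference. A minimal pure-Python sketch of that selection (standing in for `distribute_lib.pbroadcast`, which needs an active SPMD axis context; the axis names here are illustrative):

def select_pbroadcast_axes(value_axis_names, output_axis_names):
  # Broadcast only over output axes the value is not already sharded on.
  return [axis for axis in output_axis_names
          if axis not in value_axis_names]

# A value already active on 'data' only needs broadcasting over 'model'.
assert select_pbroadcast_axes(['data'], ['data', 'model']) == ['model']
# A value with no active axes is broadcast over every output axis.
assert select_pbroadcast_axes([], ['data', 'model']) == ['data', 'model']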
+
+def _maybe_substitute_or_add_value_in_tuple(value_tuple, index, value):
+  """Fills `value` into slot `index`, keeping any existing non-`None` entry."""
+  if index > len(value_tuple):
+    raise ValueError('Cannot add value to tuple without available slot.')
+  if index == len(value_tuple):
+    return value_tuple + (value,)
+  curr_value = value_tuple[index]
+  if curr_value is not None:
+    return value_tuple
+  return value_tuple[:index] + (value,) + value_tuple[index + 1:]
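A quick behavioral sketch of this helper, assuming it is in scope (the values are toy placeholders):

# Appending to the next free slot grows the tuple.
assert _maybe_substitute_or_add_value_in_tuple((1, 2), 2, 3) == (1, 2, 3)
# A `None` placeholder at `index` is filled in.
assert _maybe_substitute_or_add_value_in_tuple((1, None), 1, 5) == (1, 5)
# An existing non-`None` entry is kept, not overwritten.
assert _maybe_substitute_or_add_value_in_tuple((1, 7), 1, 5) == (1, 7)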
 
 
 class JointDistributionDistributedMixin(object):
   """A JDMixin that shards the log_prob calculation."""
 
-  def _map_measure_over_dists(self, attr, value):
-    """Override the default implementation to shard its log_prob calculation."""
-    if any(x is None for x in tf.nest.flatten(value)):
-      raise ValueError('No `value` part can be `None`; saw: {}.'.format(value))
-    if (attr in ('log_prob', 'unnormalized_log_prob')) and any(
-        self.experimental_shard_axis_names):
-
-      def inner_log_prob_parts(value):
-        ds, xs = self._call_flat_sample_distributions(
-            value=value, seed=samplers.zeros_seed())
-        # We need to flatten and unflatten here to ensure the output structure
-        # matches `flat_sharded_distributions`.
-        return self._model_unflatten(
-            [getattr(d, attr)(x) for d, x in zip(ds, xs)])
-
-      axis_names = self.experimental_shard_axis_names
-      # Individual distributions will apply psum in their `log_prob` methods
-      # so we need to pbroadcast `value` according to `axis_names` to provide
-      # correct gradients. We are safe to add pbroadcasts to functions with
-      # psums already in them.
-      log_prob_parts = distribute_lib.make_pbroadcast_function(
-          inner_log_prob_parts, (axis_names,), axis_names,
-          out_dtype=value)(value)
-      return iter(tf.nest.flatten(log_prob_parts))
-    ds, xs = self._call_flat_sample_distributions(
-        value=value, seed=samplers.zeros_seed())
-    return (getattr(d, attr)(x) for d, x in zip(ds, xs))
+  def _call_execute_model(
+      self,
+      sample_shape=(),
+      seed=None,
+      value=None,
+      sample_and_trace_fn=jd_lib.trace_distributions_and_values):
+    # Flatten a user-provided `value` into model order before distributing.
+    return self._distribute_execute_model(
+        sample_shape=sample_shape,
+        seed=seed,
+        value=value if value is None else self._model_flatten(value),
+        sample_and_trace_fn=sample_and_trace_fn)
+
+  def _distribute_execute_model(
+      self,
+      sample_shape=(),
+      seed=None,
+      value=None,
+      sample_and_trace_fn=jd_lib.trace_distributions_and_values):
+    """Executes a model, adding `pbroadcast`s to ensure correct gradients."""
+    shard_axis_names = self._model_flatten(self.experimental_shard_axis_names)
+    final_values_out = []
+    if value is None:
+      value = ()
+
+    def sample_and_trace_value_fn(dist,
+                                  sample_shape,
+                                  seed,
+                                  value=None):
+      value, traced = sample_and_trace_fn(
+          dist=dist, sample_shape=sample_shape, seed=seed, value=value)
+      # We also trace the sampled `value` itself so it can be passed back in
+      # as part of `value` in the next iteration of the coroutine.
+      return value, (value, traced)
+
+    for output_index, output_axes in enumerate(shard_axis_names):
+      # We pbroadcast all values according to the difference between the
+      # current `output_axes` and their own active axes.
+      previous_shard_axes = shard_axis_names[:len(value)]
+      pbroadcasted_value = tuple(
+          pbroadcast_value(v, v_axis_names, output_axes)
+          for v, v_axis_names in zip(value, previous_shard_axes))
+      pbroadcasted_values, traced_values = zip(*super()._execute_model(
+          sample_shape=sample_shape,
+          seed=seed,
+          value=pbroadcasted_value + (None,),
+          stop_index=output_index + 1,
+          sample_and_trace_fn=sample_and_trace_value_fn))
+      # Stash this output's value so later iterations can pbroadcast it, and
+      # collect its traced result for the caller.
+      value = _maybe_substitute_or_add_value_in_tuple(
+          value, output_index, pbroadcasted_values[output_index])
+      final_values_out.append(traced_values[output_index])
+    return final_values_out
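To see the shape of the loop, here is a toy, non-distributed trace with hypothetical `fake_*` stand-ins for the pbroadcast and model-execution machinery (nothing below is TFP API):

# Schematic trace of the per-output loop: each pass re-broadcasts the prefix
# of already-sampled values, then executes the model one variable further.
shard_axis_names = [['data'], []]   # shard axes for outputs 0 and 1
value = ()                          # nothing sampled yet

def fake_pbroadcast_value(v, v_axes, out_axes):
  return v  # stand-in: the real helper inserts pbroadcast ops here

def fake_execute_model(value, stop_index):
  # Sample placeholder values up to `stop_index`, reusing supplied ones.
  return tuple(v if v is not None else 's%d' % i
               for i, v in enumerate(value[:stop_index]))

for output_index, output_axes in enumerate(shard_axis_names):
  pbroadcasted = tuple(
      fake_pbroadcast_value(v, axes, output_axes)
      for v, axes in zip(value, shard_axis_names[:len(value)]))
  outputs = fake_execute_model(pbroadcasted + (None,), output_index + 1)
  # Mirrors the `_maybe_substitute_or_add_value_in_tuple` step.
  value = value + (outputs[output_index],)

assert value == ('s0', 's1')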
 
 
 class JointDistributionSequential(JointDistributionDistributedMixin,
@@ -91,19 +134,4 @@ def _dist_jd_log_prob_ratio(p, x, q, y, name=None):
   if p_axis_names != q_axis_names:
     raise ValueError('p and q must use the same sharding. '
                      f'Saw: p: {p}, {p_axis_names}, q: {q}, {q_axis_names}')
-
-  def log_prob_ratio_parts_fn(x, y):
-    p_dists = p.sample_distributions(value=x, seed=samplers.zeros_seed())[0]
-    q_dists = q.sample_distributions(value=y, seed=samplers.zeros_seed())[0]
-    return nest.map_structure_up_to(
-        p_dists,
-        lp_ratio.log_prob_ratio,
-        p_dists, x, q_dists, y)
-
-  return tf.add_n(
-      tf.nest.flatten(
-          distribute_lib.make_pbroadcast_function(
-              log_prob_ratio_parts_fn,
-              in_axes=(p_axis_names, p_axis_names),
-              out_axes=p_axis_names,
-              out_dtype=x)(x, y)))
+  return jd_lib._jd_log_prob_ratio(p, x, q, y, name=name)  # pylint: disable=protected-access
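For context, `log_prob_ratio` is the public entry point this dispatch ultimately serves. A minimal non-sharded usage sketch, assuming `lp_ratio.log_prob_ratio` dispatches to the joint-distribution implementation the same way the removed code above called it (toy distributions and values):

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.distributions import log_prob_ratio as lp_ratio

tfd = tfp.distributions
p = tfd.JointDistributionSequential([tfd.Normal(0., 1.)])
q = tfd.JointDistributionSequential([tfd.Normal(1., 1.)])
x = [tf.constant(0.5)]
y = [tf.constant(0.5)]
# Computed part-by-part over the model; this can be more numerically accurate
# than `p.log_prob(x) - q.log_prob(y)` when the parts nearly cancel.
ratio = lp_ratio.log_prob_ratio(p, x, q, y)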