Skip to content

Commit 909135a

Browse files
davmretensorflower-gardener
authored andcommitted
Minor cleanup of JD*AB to keep iid_sample out of stack traces for log_prob.
I found it confusing to see `iid_sample` in the stack trace for a deterministic `log_prob` call (since we now use the same code for everything). Performance-wise, `iid_sample` was already cheap when sample_shape=(), so this won't be a big optimization but also shouldn't make things any worse. PiperOrigin-RevId: 376042972
1 parent ab18b76 commit 909135a

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

tensorflow_probability/python/distributions/joint_distribution_vmap_mixin.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,11 @@ def _call_execute_model(self,
116116
self._model_flatten(value),
117117
check_types=False),
118118
flat_core_ndims=tf.nest.flatten(self._single_sample_ndims)))
119+
sample_shape_may_be_nontrivial = (
120+
distribution_util.shape_may_be_nontrivial(sample_shape))
119121

120122
if not self.use_vectorized_map or not (
121-
distribution_util.shape_may_be_nontrivial(sample_shape) or # pylint: disable=protected-access
123+
sample_shape_may_be_nontrivial or # pylint: disable=protected-access
122124
value_might_have_sample_dims):
123125
# No need to auto-vectorize.
124126
return joint_distribution_lib.JointDistribution._call_execute_model( # pylint: disable=protected-access
@@ -134,7 +136,8 @@ def _call_execute_model(self,
134136
lambda v, nd: None if v is None else nd,
135137
value, self._model_unflatten(self._single_sample_ndims),
136138
check_types=False)
137-
batch_execute_model = vectorization_util.make_rank_polymorphic(
139+
140+
vectorized_execute_model_helper = vectorization_util.make_rank_polymorphic(
138141
lambda v, seed: ( # pylint: disable=g-long-lambda
139142
joint_distribution_lib.JointDistribution._call_execute_model( # pylint: disable=protected-access
140143
self,
@@ -144,12 +147,16 @@ def _call_execute_model(self,
144147
sample_and_trace_fn=sample_and_trace_fn)),
145148
core_ndims=[value_core_ndims, None],
146149
validate_args=self.validate_args)
150+
# Redefine the polymorphic fn to hack around `make_rank_polymorphic`
151+
# not currently supporting keyword args. This is needed because the
152+
# `iid_sample` wrapper below expects to pass through a `seed` kwarg.
153+
vectorized_execute_model = (
154+
lambda v, seed: vectorized_execute_model_helper(v, seed)) # pylint: disable=unnecessary-lambda
155+
156+
if sample_shape_may_be_nontrivial:
157+
vectorized_execute_model = vectorization_util.iid_sample(
158+
vectorized_execute_model, sample_shape)
147159

148-
# Draw samples.
149-
vectorized_execute_model = vectorization_util.iid_sample(
150-
# Redefine the polymorphic fn to hack around `make_rank_polymorphic`
151-
# not currently supporting keyword args.
152-
lambda v, seed: batch_execute_model(v, seed), sample_shape) # pylint: disable=unnecessary-lambda
153160
return vectorized_execute_model(value, seed=seed)
154161

155162
def _default_event_space_bijector(self, *args, **kwargs):

0 commit comments

Comments
 (0)