import tensorflow.compat.v2 as tf

from tensorflow_probability.python import math as tfp_math
+from tensorflow_probability.python.distributions import joint_distribution as jd_lib
from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import prefer_static
@@ -123,9 +124,7 @@ def event_shape_tensor(self, sample_shape=(), name='event_shape_tensor'):
              d.event_shape_tensor()))
      return self._model_unflatten(component_shapes)

-  def _map_and_reduce_measure_over_dists(self, attr, reduce_fn, value):
-    """Reduces all non-batch dimensions of the provided measure."""
-    xs = list(self._map_measure_over_dists(attr, value))
+  def _reduce_measure_over_dists(self, xs, reduce_fn):
    num_trailing_batch_dims_treated_as_event = [
        prefer_static.rank_from_shape(
            d.batch_shape_tensor()) - self._batch_ndims
@@ -145,14 +144,31 @@ def _maybe_check_batch_shape(self):
          parts[0], s, message='Component batch shapes are inconsistent.'))
    return assertions

-  def _log_prob(self, value):
+  def _reduce_log_probs_over_dists(self, lps):
    if self._experimental_use_kahan_sum:
-      xs = self._map_and_reduce_measure_over_dists(
-          'log_prob', tfp_math.reduce_kahan_sum, value)
-      return sum(xs).total
-    xs = self._map_and_reduce_measure_over_dists(
-        'log_prob', tf.reduce_sum, value)
-    return sum(xs)
+      return sum(jd_lib.maybe_check_wont_broadcast(
+          self._reduce_measure_over_dists(
+              lps, reduce_fn=tfp_math.reduce_kahan_sum),
+          self.validate_args)).total
+    else:
+      return sum(jd_lib.maybe_check_wont_broadcast(
+          self._reduce_measure_over_dists(lps, reduce_fn=tf.reduce_sum),
+          self.validate_args))
+
+  def _sample_and_log_prob(self, sample_shape, seed, value=None, **kwargs):
+    xs, lps = zip(
+        *self._call_execute_model(
+            sample_shape,
+            seed=seed,
+            value=self._resolve_value(value=value,
+                                      allow_partially_specified=True,
+                                      **kwargs),
+            sample_and_trace_fn=jd_lib.trace_values_and_log_probs))
+    return self._model_unflatten(xs), self._reduce_log_probs_over_dists(lps)
+
+  def _log_prob(self, value):
+    return self._reduce_log_probs_over_dists(
+        self._map_measure_over_dists('log_prob', value))

  def log_prob_parts(self, value, name='log_prob_parts'):
    """Log probability density/mass function, part-wise.
@@ -172,18 +188,14 @@ def log_prob_parts(self, value, name='log_prob_parts'):
      sum_fn = tf.reduce_sum
      if self._experimental_use_kahan_sum:
        sum_fn = lambda x, axis: tfp_math.reduce_kahan_sum(x, axis=axis).total
-      xs = self._map_and_reduce_measure_over_dists(
-          'log_prob', sum_fn, value)
-      return self._model_unflatten(xs)
+      return self._model_unflatten(
+          self._reduce_measure_over_dists(
+              self._map_measure_over_dists('log_prob', value),
+              sum_fn))

  def _unnormalized_log_prob(self, value):
-    if self._experimental_use_kahan_sum:
-      xs = self._map_and_reduce_measure_over_dists(
-          'unnormalized_log_prob', tfp_math.reduce_kahan_sum, value)
-      return sum(xs).total
-    xs = self._map_and_reduce_measure_over_dists(
-        'unnormalized_log_prob', tf.reduce_sum, value)
-    return sum(xs)
+    return self._reduce_log_probs_over_dists(
+        self._map_measure_over_dists('unnormalized_log_prob', value))

  def unnormalized_log_prob_parts(self, value, name=None):
    """Unnormalized log probability density/mass function, part-wise.
@@ -203,9 +215,10 @@ def unnormalized_log_prob_parts(self, value, name=None):
      sum_fn = tf.reduce_sum
      if self._experimental_use_kahan_sum:
        sum_fn = lambda x, axis: tfp_math.reduce_kahan_sum(x, axis=axis).total
-      xs = self._map_and_reduce_measure_over_dists(
-          'unnormalized_log_prob', sum_fn, value)
-      return self._model_unflatten(xs)
+      return self._model_unflatten(
+          self._reduce_measure_over_dists(
+              self._map_measure_over_dists('unnormalized_log_prob', value),
+              sum_fn))

  def prob_parts(self, value, name='prob_parts'):
    """Log probability density/mass function.
@@ -221,9 +234,10 @@ def prob_parts(self, value, name='prob_parts'):
        each `distribution_fn` evaluated at each corresponding `value`.
    """
    with self._name_and_control_scope(name):
-      xs = self._map_and_reduce_measure_over_dists(
-          'prob', tf.reduce_prod, value)
-      return self._model_unflatten(xs)
+      return self._model_unflatten(
+          self._reduce_measure_over_dists(
+              self._map_measure_over_dists('prob', value),
+              tf.reduce_prod))

  def is_scalar_batch(self, name='is_scalar_batch'):
    """Indicates that `batch_shape == []`.