@@ -214,28 +214,31 @@ def _event_shape(self):
214
214
return tensorshape_util .concatenate (sample_shape ,
215
215
self .distribution .event_shape )
216
216
217
- def _sample_n (self , n , seed , ** kwargs ):
218
- sample_shape = ps .reshape ( self . sample_shape , shape = [ - 1 ])
219
- fake_sample_ndims = ps .rank_from_shape ( sample_shape )
217
+ def _sampling_permutation (self , sample_ndims ):
218
+ fake_sample_ndims = ps .rank_from_shape (
219
+ ps .reshape ( self . sample_shape , shape = [ - 1 ]) )
220
220
event_ndims = ps .rank_from_shape (
221
221
self .distribution .event_shape_tensor , self .distribution .event_shape )
222
222
batch_ndims = ps .rank_from_shape (
223
223
self .distribution .batch_shape_tensor , self .distribution .batch_shape )
224
- perm = ps .concat ([
225
- [ 0 ] ,
226
- ps .range (1 + fake_sample_ndims ,
227
- 1 + fake_sample_ndims + batch_ndims ,
224
+ return ps .concat ([
225
+ ps . range ( sample_ndims ) ,
226
+ ps .range (sample_ndims + fake_sample_ndims ,
227
+ sample_ndims + fake_sample_ndims + batch_ndims ,
228
228
dtype = tf .int32 ),
229
- ps .range (1 , 1 + fake_sample_ndims , dtype = tf .int32 ),
230
- ps .range (1 + fake_sample_ndims + batch_ndims ,
231
- 1 + fake_sample_ndims + batch_ndims + event_ndims ,
229
+ ps .range (sample_ndims , sample_ndims + fake_sample_ndims ,
230
+ dtype = tf .int32 ),
231
+ ps .range (sample_ndims + fake_sample_ndims + batch_ndims ,
232
+ sample_ndims + fake_sample_ndims + batch_ndims + event_ndims ,
232
233
dtype = tf .int32 ),
233
234
], axis = 0 )
234
- x = self .distribution .sample (
235
- ps .concat ([[n ], sample_shape ], axis = 0 ),
236
- seed = seed ,
237
- ** kwargs )
238
- return tf .transpose (a = x , perm = perm )
235
+
236
+ def _sample_n (self , n , seed , ** kwargs ):
237
+ sample_shape = ps .reshape (self .sample_shape , shape = [- 1 ])
238
+ x = self .distribution .sample (ps .concat ([[n ], sample_shape ], axis = 0 ),
239
+ seed = seed ,
240
+ ** kwargs )
241
+ return tf .transpose (a = x , perm = self ._sampling_permutation (sample_ndims = 1 ))
239
242
240
243
def _sum_fn (self ):
241
244
if self ._experimental_use_kahan_sum :
@@ -259,20 +262,9 @@ def _prepare_for_underlying(self, x):
259
262
ps .shape (x ),
260
263
paddings = [[ps .maximum (0 , - d ), 0 ]],
261
264
constant_values = 1 ))
262
- ndims = ps .rank (x )
263
265
sample_ndims = ps .maximum (0 , d )
264
- # (2) Transpose x's dims.
265
- sample_dims = ps .range (0 , sample_ndims )
266
- batch_dims = ps .range (sample_ndims , sample_ndims + batch_ndims )
267
- extra_sample_dims = ps .range (
268
- sample_ndims + batch_ndims ,
269
- sample_ndims + batch_ndims + extra_sample_ndims )
270
- event_dims = ps .range (
271
- sample_ndims + batch_ndims + extra_sample_ndims ,
272
- ndims )
273
- perm = ps .concat (
274
- [sample_dims , extra_sample_dims , batch_dims , event_dims ], axis = 0 )
275
- x = tf .transpose (x , perm = perm )
266
+ x = tf .transpose (
267
+ x , perm = ps .invert_permutation (self ._sampling_permutation (sample_ndims )))
276
268
return x , (sample_ndims , extra_sample_ndims , batch_ndims )
277
269
278
270
def _finish_log_prob (self , lp , aux ):
@@ -289,6 +281,21 @@ def _finish_log_prob(self, lp, aux):
289
281
axis = ps .range (sample_ndims , sample_ndims + extra_sample_ndims )
290
282
return self ._sum_fn ()(lp , axis = axis )
291
283
284
+ def _sample_and_log_prob (self , sample_shape , seed , ** kwargs ):
285
+ sample_ndims = ps .rank_from_shape (sample_shape )
286
+ batch_ndims = ps .rank_from_shape (
287
+ self .distribution .batch_shape_tensor ,
288
+ self .distribution .batch_shape )
289
+ extra_sample_shape = ps .reshape (self .sample_shape , shape = [- 1 ])
290
+ extra_sample_ndims = ps .rank_from_shape (extra_sample_shape )
291
+ x , lp = self .distribution .experimental_sample_and_log_prob (
292
+ ps .concat ([sample_shape , extra_sample_shape ], axis = 0 ), seed = seed ,
293
+ ** kwargs )
294
+ return (
295
+ tf .transpose (x , perm = self ._sampling_permutation (sample_ndims )),
296
+ self ._finish_log_prob (
297
+ lp , aux = (sample_ndims , extra_sample_ndims , batch_ndims )))
298
+
292
299
def _log_prob (self , x , ** kwargs ):
293
300
x , aux = self ._prepare_for_underlying (x )
294
301
return self ._finish_log_prob (
0 commit comments