@@ -224,47 +224,90 @@ def _event_shape(self):
   def _event_shape_tensor(self):
     return self.distribution.event_shape_tensor()

-  def _sample_n(self, n, seed=None):
+  def _augment_sample_shape(self, sample_shape):
+    # Suppose we have:
+    # - sample shape of `[n]`,
+    # - underlying distribution batch shape of `[2, 1]`,
+    # - final broadcast batch shape of `[4, 2, 3]`.
+    # Then we must draw `sample_shape + [12]` samples, where
+    # `12 == n_batch // underlying_n_batch`.
     batch_shape = self.batch_shape_tensor()
-    batch_rank = ps.rank_from_shape(batch_shape)
     n_batch = ps.reduce_prod(batch_shape)
+    underlying_batch_shape = self.distribution.batch_shape_tensor()
+    underlying_n_batch = ps.reduce_prod(underlying_batch_shape)
+    return ps.concat(
+        [sample_shape,
+         [ps.maximum(0, n_batch // underlying_n_batch)]],
+        axis=0)
+
+  def _transpose_and_reshape_result(self, x, sample_shape, event_shape=None):
+    if event_shape is None:
+      event_shape = self.event_shape_tensor()
+
+    batch_shape = self.batch_shape_tensor()
+    batch_rank = ps.rank_from_shape(batch_shape)

     underlying_batch_shape = self.distribution.batch_shape_tensor()
     underlying_batch_rank = ps.rank_from_shape(underlying_batch_shape)
-    underlying_n_batch = ps.reduce_prod(underlying_batch_shape)

-    # Left pad underlying shape with any necessary ones.
+    # Continuing the example from `_augment_sample_shape`, suppose we have:
+    # - sample shape of `[n]`,
+    # - underlying distribution batch shape of `[2, 1]`,
+    # - final broadcast batch shape of `[4, 2, 3]`,
+    # and have drawn an `x` of shape `[n, 12, 2, 1] + event_shape`, which we
+    # ultimately want to have shape `[n, 4, 2, 3] + event_shape`.
+
+    # First, we reshape to expand out the batch elements:
+    # `shape_with_doubled_batch == [n] + [4, 1, 3] + [1, 2, 1] + event_shape`,
+    # where `[1, 2, 1]` is the fully-expanded underlying batch shape, and
+    # `[4, 1, 3]` is the shape of the elements being added by broadcasting.
     underlying_bcast_shp = ps.concat(
         [ps.ones([ps.maximum(batch_rank - underlying_batch_rank, 0)],
                  dtype=underlying_batch_shape.dtype),
          underlying_batch_shape],
         axis=0)
-
-    # Determine how many underlying samples to produce.
-    n_bcast_samples = ps.maximum(0, n_batch // underlying_n_batch)
-    samps = self.distribution.sample([n, n_bcast_samples], seed=seed)
-
     is_dim_bcast = ps.not_equal(batch_shape, underlying_bcast_shp)
+    x_with_doubled_batch = tf.reshape(
+        x,
+        ps.concat([sample_shape,
+                   ps.where(is_dim_bcast, batch_shape, 1),
+                   underlying_bcast_shp,
+                   event_shape], axis=0))
+
+    # Next, construct the permutation that interleaves the batch dimensions,
+    # resulting in samples with shape
+    # `[n] + [4, 1] + [1, 2] + [3, 1] + event_shape`.
+    # Note that each interleaved pair of batch dimensions contains exactly one
+    # dim of size `1` and one of size `>= 1`.
+    sample_ndims = ps.rank_from_shape(sample_shape)
+    x_with_interleaved_batch = tf.transpose(
+        x_with_doubled_batch,
+        perm=ps.concat([
+            ps.range(sample_ndims),
+            sample_ndims + ps.reshape(
+                ps.stack([ps.range(batch_rank),
+                          ps.range(batch_rank) + batch_rank], axis=-1),
+                [-1]),
+            sample_ndims + 2 * batch_rank + ps.range(
+                ps.rank_from_shape(event_shape))], axis=0))
+
+    # Final reshape to remove the spurious `1` dimensions.
+    return tf.reshape(
+        x_with_interleaved_batch,
+        ps.concat([sample_shape, batch_shape, event_shape], axis=0))

-    event_shape = self.event_shape_tensor()
-    event_rank = ps.rank_from_shape(event_shape)
-    shp = ps.concat([[n], ps.where(is_dim_bcast, batch_shape, 1),
-                     underlying_bcast_shp,
-                     event_shape], axis=0)
-    # Reshape to expand n_bcast_samples and ones-padded underlying_bcast_shp.
-    samps = tf.reshape(samps, shp)
-    # Interleave broadcast and underlying axis indices for transpose.
-    interleaved_batch_axes = ps.reshape(
-        ps.stack([ps.range(batch_rank),
-                  ps.range(batch_rank) + batch_rank],
-                 axis=-1),
-        [-1]) + 1
-
-    event_axes = ps.range(event_rank) + (1 + 2 * batch_rank)
-    perm = ps.concat([[0], interleaved_batch_axes, event_axes], axis=0)
-    samps = tf.transpose(samps, perm=perm)
-    # Finally, reshape to the fully-broadcast batch shape.
-    return tf.reshape(samps, ps.concat([[n], batch_shape, event_shape], axis=0))
+  def _sample_n(self, n, seed=None):
+    sample_shape = ps.reshape(n, [1])
+    x = self.distribution.sample(
+        self._augment_sample_shape(sample_shape), seed=seed)
+    return self._transpose_and_reshape_result(x, sample_shape=sample_shape)
+
+  def _sample_and_log_prob(self, sample_shape, seed):
+    x, lp = self.distribution.experimental_sample_and_log_prob(
+        self._augment_sample_shape(sample_shape), seed=seed)
+    return (self._transpose_and_reshape_result(x, sample_shape),
+            self._transpose_and_reshape_result(lp, sample_shape,
+                                               event_shape=()))

   _log_prob = _make_bcast_fn('log_prob', n_event_shapes=0)
   _prob = _make_bcast_fn('prob', n_event_shapes=0)
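
For intuition, below is a standalone NumPy sketch (not part of this commit, and not using TFP's `ps`/`tf` helpers) of the bookkeeping the new helpers perform: draw enough extra samples to fill the broadcast batch, expand the flat "extra samples" axis into the dimensions added by broadcasting, interleave broadcast and underlying batch axes, and collapse to the final batch shape. The shapes ([n] sample shape, [2, 1] underlying batch, [4, 2, 3] broadcast batch) are the ones used in the diff's comments; the variable names are illustrative only.

import numpy as np

# Illustrative shapes taken from the comments in the diff above.
sample_shape = [5]                       # a concrete `n`
underlying_batch_shape = [2, 1]          # batch shape of the wrapped distribution
batch_shape = [4, 2, 3]                  # final broadcast batch shape
event_shape = [7]

n_batch = int(np.prod(batch_shape))                        # 24
underlying_n_batch = int(np.prod(underlying_batch_shape))  # 2
n_extra = max(0, n_batch // underlying_n_batch)            # 12

# Stand-in for drawing `sample_shape + [n_extra]` samples from the underlying
# distribution: shape [5, 12, 2, 1, 7].
x = np.random.normal(
    size=sample_shape + [n_extra] + underlying_batch_shape + event_shape)

# Left-pad the underlying batch shape to the broadcast rank: [1, 2, 1].
batch_rank = len(batch_shape)
padded = ([1] * (batch_rank - len(underlying_batch_shape))
          + underlying_batch_shape)
# Dimensions contributed purely by broadcasting: [4, 1, 3].
bcast_dims = [b if b != u else 1 for b, u in zip(batch_shape, padded)]

# Expand the flat "extra samples" axis into the broadcast dims:
# [5] + [4, 1, 3] + [1, 2, 1] + [7].
doubled = x.reshape(sample_shape + bcast_dims + padded + event_shape)

# Interleave broadcast and underlying batch axes, then drop the size-1 halves.
sample_ndims = len(sample_shape)
perm = (list(range(sample_ndims))
        + [sample_ndims + i + half * batch_rank
           for i in range(batch_rank) for half in (0, 1)]
        + [sample_ndims + 2 * batch_rank + i
           for i in range(len(event_shape))])
result = doubled.transpose(perm).reshape(
    sample_shape + batch_shape + event_shape)
print(result.shape)  # (5, 4, 2, 3, 7)

Each of the 12 extra draws fills one of the batch cells created by broadcasting, so every cell of the [4, 2, 3] batch ends up with its own draw rather than a copy of its neighbor; routing both `_sample_n` and `_sample_and_log_prob` through `_augment_sample_shape` and `_transpose_and_reshape_result` keeps that shape bookkeeping in one place.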