@@ -276,23 +276,28 @@ def _get_permutations(num_results, dims, seed=None):
276276 Args:
277277 num_results: A positive scalar `Tensor` of integral type. The number of
278278 draws from the discrete uniform distribution over the permutation groups.
279- dims: A 1D `Tensor` of the same dtype as `num_results`. The degree of the
279+ dims: A 1D numpy array of the same dtype as `num_results`. The degree of the
280280 permutation groups from which to sample.
281281 seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
282282
283283 Returns:
284284 permutations: A `Tensor` of shape `[num_results, sum(dims)]` and the same
285285 dtype as `dims`.
286286 """
287- seeds = samplers .split_seed (seed , n = ps .size (dims ))
288-
289- def generate_one (dim , seed ):
290- return tf .argsort (samplers .uniform (
291- [num_results , dim ], seed = seed ), axis = - 1 )
292-
293- return tf .concat ([generate_one (dim , seed )
294- for dim , seed in zip (dims , tf .unstack (seeds ))],
295- axis = - 1 )
287+ n = dims .size
288+ max_size = np .max (dims )
289+ samples = samplers .uniform ([num_results , n , max_size ], seed = seed )
290+ should_mask = np .arange (max_size ) >= dims [..., np .newaxis ]
291+ # Choose a number that does not affect the permutation and relative location.
292+ samples = tf .where (
293+ should_mask ,
294+ dtype_util .as_numpy_dtype (samples .dtype )(np .arange (max_size ) + 10. ),
295+ samples )
296+ samples = tf .argsort (samples , axis = - 1 )
297+ # Generate the set of indices to gather.
298+ should_mask = np .tile (should_mask , [num_results , 1 , 1 ])
299+ indices = np .stack (np .where (~ should_mask ), axis = - 1 )
300+ return tf .gather_nd (samples , indices )
296301
297302
298303def _get_indices (num_results , sequence_indices , dtype , name = None ):
0 commit comments