Skip to content

Commit a204ec8

Browse files
srvasudetensorflower-gardener
authored andcommitted
Reduce number of ops sample_halton_sequence adds to the graph.
- Turn loop over sampling and argmax in to a single sample and single argmax. PiperOrigin-RevId: 566408146
1 parent 200e753 commit a204ec8

File tree

1 file changed

+15
-10
lines changed

1 file changed

+15
-10
lines changed

tensorflow_probability/python/mcmc/sample_halton_sequence_lib.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -276,23 +276,28 @@ def _get_permutations(num_results, dims, seed=None):
276276
Args:
277277
num_results: A positive scalar `Tensor` of integral type. The number of
278278
draws from the discrete uniform distribution over the permutation groups.
279-
dims: A 1D `Tensor` of the same dtype as `num_results`. The degree of the
279+
dims: A 1D numpy array of the same dtype as `num_results`. The degree of the
280280
permutation groups from which to sample.
281281
seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
282282
283283
Returns:
284284
permutations: A `Tensor` of shape `[num_results, sum(dims)]` and the same
285285
dtype as `dims`.
286286
"""
287-
seeds = samplers.split_seed(seed, n=ps.size(dims))
288-
289-
def generate_one(dim, seed):
290-
return tf.argsort(samplers.uniform(
291-
[num_results, dim], seed=seed), axis=-1)
292-
293-
return tf.concat([generate_one(dim, seed)
294-
for dim, seed in zip(dims, tf.unstack(seeds))],
295-
axis=-1)
287+
n = dims.size
288+
max_size = np.max(dims)
289+
samples = samplers.uniform([num_results, n, max_size], seed=seed)
290+
should_mask = np.arange(max_size) >= dims[..., np.newaxis]
291+
# Choose a number that does not affect the permutation and relative location.
292+
samples = tf.where(
293+
should_mask,
294+
dtype_util.as_numpy_dtype(samples.dtype)(np.arange(max_size) + 10.),
295+
samples)
296+
samples = tf.argsort(samples, axis=-1)
297+
# Generate the set of indices to gather.
298+
should_mask = np.tile(should_mask, [num_results, 1, 1])
299+
indices = np.stack(np.where(~should_mask), axis=-1)
300+
return tf.gather_nd(samples, indices)
296301

297302

298303
def _get_indices(num_results, sequence_indices, dtype, name=None):

0 commit comments

Comments
 (0)