Skip to content

Commit 8e17749

Browse files
srvasudetensorflower-gardener
authored andcommitted
Ensure that sample_halton_sequence is jittable.
PiperOrigin-RevId: 565432730
1 parent 6bba52b commit 8e17749

File tree

2 files changed

+52
-26
lines changed

2 files changed

+52
-26
lines changed

tensorflow_probability/python/mcmc/sample_halton_sequence_lib.py

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -182,17 +182,14 @@ def sample_halton_sequence(dim,
182182
# The coefficient dimension is an intermediate axes which will hold the
183183
# weights of the starting integer when expressed in the (prime) base for
184184
# an event dimension.
185-
if num_results is not None:
186-
num_results = tf.convert_to_tensor(num_results)
187185
if sequence_indices is not None:
188186
sequence_indices = tf.convert_to_tensor(sequence_indices)
189187
indices = _get_indices(num_results, sequence_indices, dtype)
190-
radixes = tf.constant(_PRIMES[0:dim], dtype=dtype, shape=[dim, 1])
191-
192-
max_sizes_by_axes = _base_expansion_size(
193-
tf.reduce_max(indices), radixes)
194-
195-
max_size = tf.reduce_max(max_sizes_by_axes)
188+
if num_results is None:
189+
num_results = ps.reduce_max(indices)
190+
radixes = _PRIMES[0:dim][..., np.newaxis]
191+
max_sizes_by_axes = _base_expansion_size(num_results, radixes, dtype)
192+
max_size = ps.reduce_max(max_sizes_by_axes)
196193

197194
# The powers of the radixes that we will need. Note that there is a bit
198195
# of an excess here. Suppose we need the place value coefficients of 7
@@ -204,14 +201,13 @@ def sample_halton_sequence(dim,
204201
# dimensions, then the 10th prime (29) we will end up computing 29^10 even
205202
# though we don't need it. We avoid this by setting the exponents for each
206203
# axes to 0 beyond the maximum value needed for that dimension.
207-
exponents_by_axes = tf.tile([tf.range(max_size)], [dim, 1])
204+
exponents_by_axes = tf.tile([tf.range(max_size, dtype=dtype)], [dim, 1])
208205

209206
# The mask is true for those coefficients that are irrelevant.
210207
weight_mask = exponents_by_axes < max_sizes_by_axes
211-
capped_exponents = tf.where(weight_mask,
212-
exponents_by_axes,
213-
tf.constant(0, exponents_by_axes.dtype))
214-
weights = radixes ** capped_exponents
208+
capped_exponents = tf.where(
209+
weight_mask, exponents_by_axes, dtype_util.as_numpy_dtype(dtype)(0.))
210+
weights = tf.cast(radixes ** capped_exponents, dtype=dtype)
215211
# The following computes the base b expansion of the indices. Suppose,
216212
# x = a0 + a1*b + a2*b^2 + ... Then, performing a floor div of x with
217213
# the vector (1, b, b^2, b^3, ...) will produce
@@ -246,22 +242,22 @@ def sample_halton_sequence(dim,
246242
zero_correction = samplers.uniform([dim, 1],
247243
seed=zero_correction_seed,
248244
dtype=dtype)
249-
zero_correction /= radixes ** max_sizes_by_axes
245+
zero_correction /= tf.cast(radixes ** max_sizes_by_axes, dtype)
250246
return base_values + tf.reshape(zero_correction, [-1])
251247

252248

253249
def _randomize(coeffs, radixes, seed=None):
254250
"""Applies the Owen (2017) randomization to the coefficients."""
255251
given_dtype = coeffs.dtype
256252
coeffs = tf.cast(coeffs, dtype=tf.int32)
257-
num_coeffs = tf.shape(coeffs)[-1]
258-
radixes = tf.reshape(tf.cast(radixes, dtype=tf.int32), shape=[-1])
259-
perms = _get_permutations(num_coeffs, radixes, seed=seed)
253+
num_coeffs = ps.shape(coeffs)[-1]
254+
perms = _get_permutations(num_coeffs, np.squeeze(radixes, axis=-1), seed=seed)
260255
perms = tf.reshape(perms, shape=[-1])
256+
radixes = tf.reshape(tf.cast(radixes, dtype=tf.int32), shape=[-1])
261257
radix_sum = tf.reduce_sum(radixes)
262258
radix_offsets = tf.reshape(tf.cumsum(radixes, exclusive=True),
263259
shape=[-1, 1])
264-
offsets = radix_offsets + tf.range(num_coeffs) * radix_sum
260+
offsets = radix_offsets + ps.range(num_coeffs, dtype=tf.int32) * radix_sum
265261
permuted_coeffs = tf.gather(perms, coeffs + offsets)
266262
return tf.cast(permuted_coeffs, dtype=given_dtype)
267263

@@ -291,10 +287,11 @@ def _get_permutations(num_results, dims, seed=None):
291287
seeds = samplers.split_seed(seed, n=ps.size(dims))
292288

293289
def generate_one(dim, seed):
294-
return tf.argsort(samplers.uniform([num_results, dim], seed=seed), axis=-1)
290+
return tf.argsort(samplers.uniform(
291+
[num_results, dim], seed=seed), axis=-1)
295292

296293
return tf.concat([generate_one(dim, seed)
297-
for dim, seed in zip(tf.unstack(dims), tf.unstack(seeds))],
294+
for dim, seed in zip(dims, tf.unstack(seeds))],
298295
axis=-1)
299296

300297

@@ -325,8 +322,13 @@ def _get_indices(num_results, sequence_indices, dtype, name=None):
325322
"""
326323
with tf.name_scope(name or 'get_indices'):
327324
if sequence_indices is None:
328-
num_results = tf.cast(num_results, dtype=dtype)
329-
sequence_indices = tf.range(num_results, dtype=dtype)
325+
np_dtype = dtype_util.as_numpy_dtype(dtype)
326+
num_results_ = tf.get_static_value(num_results)
327+
if num_results_ is not None:
328+
sequence_indices = ps.range(np_dtype(num_results_), dtype=dtype)
329+
else:
330+
num_results = tf.cast(num_results, dtype=dtype)
331+
sequence_indices = ps.range(num_results, dtype=dtype)
330332
else:
331333
sequence_indices = tf.cast(sequence_indices, dtype)
332334

@@ -338,7 +340,7 @@ def _get_indices(num_results, sequence_indices, dtype, name=None):
338340
return tf.reshape(indices, [-1, 1, 1])
339341

340342

341-
def _base_expansion_size(num, bases):
343+
def _base_expansion_size(num, bases, dtype):
342344
"""Computes the number of terms in the place value expansion.
343345
344346
Let num = a0 + a1 b + a2 b^2 + ... ak b^k be the place value expansion of
@@ -349,16 +351,23 @@ def _base_expansion_size(num, bases):
349351
$$k = Floor(log_b (num)) + 1 = Floor( log(num) / log(b)) + 1$$
350352
351353
Args:
352-
num: Scalar `Tensor` of dtype either `float32` or `float64`. The number to
354+
num: Scalar `Tensor` of dtype either `int32` or `int64`. The number to
353355
compute the base expansion size of.
354356
bases: `Tensor` of the same dtype as num. The bases to compute the size
355357
against.
358+
dtype: Return `dtype`.
356359
357360
Returns:
358-
Tensor of same dtype and shape as `bases` containing the size of num when
361+
Tensor of dtype `dtype` and shape as `bases` containing the size of num when
359362
written in that base.
360363
"""
361-
return tf.floor(tf.math.log(num) / tf.math.log(bases)) + 1
364+
num_ = tf.get_static_value(num)
365+
if num_ is not None:
366+
return (np.floor(np.log(num_) / np.log(bases)) + 1).astype(
367+
dtype_util.as_numpy_dtype(dtype))
368+
369+
return tf.floor(
370+
tf.math.log(tf.cast(num, dtype)) / tf.math.log(tf.cast(bases, dtype))) + 1
362371

363372

364373
def _primes_less_than(n):

tensorflow_probability/python/mcmc/sample_halton_sequence_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,23 @@ def test_dtypes_works_correctly(self):
7777
self.assertEqual(self.evaluate(sample_float32).dtype, np.float32)
7878
self.assertEqual(self.evaluate(sample_float64).dtype, np.float64)
7979

80+
@test_util.disable_test_for_backend(
81+
disable_numpy=True, reason="Numpy has no notion of jit compilation.")
82+
def test_jit_works_correctly(self):
83+
@tf.function(jit_compile=True)
84+
def sample_float32():
85+
return sample_halton_sequence_lib.sample_halton_sequence(
86+
5, num_results=10, dtype=tf.float32, seed=test_util.test_seed())
87+
samples = sample_float32()
88+
self.assertEqual(samples.shape, [10, 5])
89+
90+
@tf.function(jit_compile=True)
91+
def sample_float64():
92+
return sample_halton_sequence_lib.sample_halton_sequence(
93+
5, num_results=10, dtype=tf.float64, seed=test_util.test_seed())
94+
samples = sample_float64()
95+
self.assertEqual(samples.shape, [10, 5])
96+
8097
def test_normal_integral_mean_and_var_correctly_estimated(self):
8198
n = 1000
8299
# This test is almost identical to the similarly named test in

0 commit comments

Comments
 (0)