@@ -182,17 +182,14 @@ def sample_halton_sequence(dim,
182182 # The coefficient dimension is an intermediate axes which will hold the
183183 # weights of the starting integer when expressed in the (prime) base for
184184 # an event dimension.
185- if num_results is not None :
186- num_results = tf .convert_to_tensor (num_results )
187185 if sequence_indices is not None :
188186 sequence_indices = tf .convert_to_tensor (sequence_indices )
189187 indices = _get_indices (num_results , sequence_indices , dtype )
190- radixes = tf .constant (_PRIMES [0 :dim ], dtype = dtype , shape = [dim , 1 ])
191-
192- max_sizes_by_axes = _base_expansion_size (
193- tf .reduce_max (indices ), radixes )
194-
195- max_size = tf .reduce_max (max_sizes_by_axes )
188+ if num_results is None :
189+ num_results = ps .reduce_max (indices )
190+ radixes = _PRIMES [0 :dim ][..., np .newaxis ]
191+ max_sizes_by_axes = _base_expansion_size (num_results , radixes , dtype )
192+ max_size = ps .reduce_max (max_sizes_by_axes )
196193
197194 # The powers of the radixes that we will need. Note that there is a bit
198195 # of an excess here. Suppose we need the place value coefficients of 7
@@ -204,14 +201,13 @@ def sample_halton_sequence(dim,
204201 # dimensions, then the 10th prime (29) we will end up computing 29^10 even
205202 # though we don't need it. We avoid this by setting the exponents for each
206203 # axes to 0 beyond the maximum value needed for that dimension.
207- exponents_by_axes = tf .tile ([tf .range (max_size )], [dim , 1 ])
204+ exponents_by_axes = tf .tile ([tf .range (max_size , dtype = dtype )], [dim , 1 ])
208205
209206 # The mask is true for those coefficients that are irrelevant.
210207 weight_mask = exponents_by_axes < max_sizes_by_axes
211- capped_exponents = tf .where (weight_mask ,
212- exponents_by_axes ,
213- tf .constant (0 , exponents_by_axes .dtype ))
214- weights = radixes ** capped_exponents
208+ capped_exponents = tf .where (
209+ weight_mask , exponents_by_axes , dtype_util .as_numpy_dtype (dtype )(0. ))
210+ weights = tf .cast (radixes ** capped_exponents , dtype = dtype )
215211 # The following computes the base b expansion of the indices. Suppose,
216212 # x = a0 + a1*b + a2*b^2 + ... Then, performing a floor div of x with
217213 # the vector (1, b, b^2, b^3, ...) will produce
@@ -246,22 +242,22 @@ def sample_halton_sequence(dim,
246242 zero_correction = samplers .uniform ([dim , 1 ],
247243 seed = zero_correction_seed ,
248244 dtype = dtype )
249- zero_correction /= radixes ** max_sizes_by_axes
245+ zero_correction /= tf . cast ( radixes ** max_sizes_by_axes , dtype )
250246 return base_values + tf .reshape (zero_correction , [- 1 ])
251247
252248
253249def _randomize (coeffs , radixes , seed = None ):
254250 """Applies the Owen (2017) randomization to the coefficients."""
255251 given_dtype = coeffs .dtype
256252 coeffs = tf .cast (coeffs , dtype = tf .int32 )
257- num_coeffs = tf .shape (coeffs )[- 1 ]
258- radixes = tf .reshape (tf .cast (radixes , dtype = tf .int32 ), shape = [- 1 ])
259- perms = _get_permutations (num_coeffs , radixes , seed = seed )
253+ num_coeffs = ps .shape (coeffs )[- 1 ]
254+ perms = _get_permutations (num_coeffs , np .squeeze (radixes , axis = - 1 ), seed = seed )
260255 perms = tf .reshape (perms , shape = [- 1 ])
256+ radixes = tf .reshape (tf .cast (radixes , dtype = tf .int32 ), shape = [- 1 ])
261257 radix_sum = tf .reduce_sum (radixes )
262258 radix_offsets = tf .reshape (tf .cumsum (radixes , exclusive = True ),
263259 shape = [- 1 , 1 ])
264- offsets = radix_offsets + tf .range (num_coeffs ) * radix_sum
260+ offsets = radix_offsets + ps .range (num_coeffs , dtype = tf . int32 ) * radix_sum
265261 permuted_coeffs = tf .gather (perms , coeffs + offsets )
266262 return tf .cast (permuted_coeffs , dtype = given_dtype )
267263
@@ -291,10 +287,11 @@ def _get_permutations(num_results, dims, seed=None):
291287 seeds = samplers .split_seed (seed , n = ps .size (dims ))
292288
293289 def generate_one (dim , seed ):
294- return tf .argsort (samplers .uniform ([num_results , dim ], seed = seed ), axis = - 1 )
290+ return tf .argsort (samplers .uniform (
291+ [num_results , dim ], seed = seed ), axis = - 1 )
295292
296293 return tf .concat ([generate_one (dim , seed )
297- for dim , seed in zip (tf . unstack ( dims ) , tf .unstack (seeds ))],
294+ for dim , seed in zip (dims , tf .unstack (seeds ))],
298295 axis = - 1 )
299296
300297
@@ -325,8 +322,13 @@ def _get_indices(num_results, sequence_indices, dtype, name=None):
325322 """
326323 with tf .name_scope (name or 'get_indices' ):
327324 if sequence_indices is None :
328- num_results = tf .cast (num_results , dtype = dtype )
329- sequence_indices = tf .range (num_results , dtype = dtype )
325+ np_dtype = dtype_util .as_numpy_dtype (dtype )
326+ num_results_ = tf .get_static_value (num_results )
327+ if num_results_ is not None :
328+ sequence_indices = ps .range (np_dtype (num_results_ ), dtype = dtype )
329+ else :
330+ num_results = tf .cast (num_results , dtype = dtype )
331+ sequence_indices = ps .range (num_results , dtype = dtype )
330332 else :
331333 sequence_indices = tf .cast (sequence_indices , dtype )
332334
@@ -338,7 +340,7 @@ def _get_indices(num_results, sequence_indices, dtype, name=None):
338340 return tf .reshape (indices , [- 1 , 1 , 1 ])
339341
340342
341- def _base_expansion_size (num , bases ):
343+ def _base_expansion_size (num , bases , dtype ):
342344 """Computes the number of terms in the place value expansion.
343345
344346 Let num = a0 + a1 b + a2 b^2 + ... ak b^k be the place value expansion of
@@ -349,16 +351,23 @@ def _base_expansion_size(num, bases):
349351 $$k = Floor(log_b (num)) + 1 = Floor( log(num) / log(b)) + 1$$
350352
351353 Args:
352- num: Scalar `Tensor` of dtype either `float32 ` or `float64 `. The number to
354+ num: Scalar `Tensor` of dtype either `int32 ` or `int64 `. The number to
353355 compute the base expansion size of.
354356 bases: `Tensor` of the same dtype as num. The bases to compute the size
355357 against.
358+ dtype: Return `dtype`.
356359
357360 Returns:
358- Tensor of same dtype and shape as `bases` containing the size of num when
361+ Tensor of dtype ` dtype` and shape as `bases` containing the size of num when
359362 written in that base.
360363 """
361- return tf .floor (tf .math .log (num ) / tf .math .log (bases )) + 1
364+ num_ = tf .get_static_value (num )
365+ if num_ is not None :
366+ return (np .floor (np .log (num_ ) / np .log (bases )) + 1 ).astype (
367+ dtype_util .as_numpy_dtype (dtype ))
368+
369+ return tf .floor (
370+ tf .math .log (tf .cast (num , dtype )) / tf .math .log (tf .cast (bases , dtype ))) + 1
362371
363372
364373def _primes_less_than (n ):
0 commit comments