Skip to content

Commit 92497f4

Browse files
SiegeLordExtensorflower-gardener
authored andcommitted
Fix TruncatedNormal's sample/log_prob dtype when jax_enable_x64=True.
PiperOrigin-RevId: 378583501
1 parent 3427e09 commit 92497f4

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

tensorflow_probability/python/distributions/truncated_normal.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -341,13 +341,15 @@ def grad(dy):
341341
return std_samples * scale[tf.newaxis] + loc[tf.newaxis]
342342

343343
def _log_prob(self, x):
344+
np_dtype = dtype_util.as_numpy_dtype(x.dtype)
344345
loc, scale, low, high = self._loc_scale_low_high()
345-
log_prob = -(0.5 * tf.square(
346-
(x - loc) / scale) + 0.5 * np.log(2. * np.pi) + tf.math.log(scale) +
346+
log_prob = -(np_dtype(0.5) * tf.square(
347+
(x - loc) / scale) + (0.5 * np.log(2. * np.pi)).astype(np_dtype) +
348+
tf.math.log(scale) +
347349
self._log_normalizer(loc=loc, scale=scale, low=low, high=high))
348350
# p(x) is 0 outside the bounds.
349351
bounded_log_prob = tf.where((x > high) | (x < low),
350-
dtype_util.as_numpy_dtype(x.dtype)(-np.inf),
352+
np_dtype(-np.inf),
351353
log_prob)
352354
return bounded_log_prob
353355

tensorflow_probability/python/internal/backend/numpy/random_generators.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,9 +246,10 @@ def _truncated_normal_jax(
246246
import jax.random as jaxrand # pylint: disable=g-import-not-at-top
247247
if seed is None:
248248
raise ValueError('Must provide PRNGKey to sample in JAX.')
249+
dtype = utils.common_dtype([means, stddevs, minvals, maxvals])
249250
std_low = (minvals - means) / stddevs
250251
std_high = (maxvals - means) / stddevs
251-
std_samps = jaxrand.truncated_normal(seed, std_low, std_high, shape)
252+
std_samps = jaxrand.truncated_normal(seed, std_low, std_high, shape, dtype)
252253
return std_samps * stddevs + means
253254

254255

0 commit comments

Comments
 (0)