Skip to content

Commit d0b9d34

Browse files
SiegeLordExtensorflower-gardener
authored andcommitted
Make tfd.Gamma.sample use log_space sampling under XLA/JAX.
Benchmarks have shown that log_space sampling is a bit slower in Graph mode, so we keep the old behavior for that configuration. This should help JAX the most, which typically does not have 64 bit dtype enabled. The old default assumed it was, causing warnings and reduced numerical precision. PiperOrigin-RevId: 452391450
1 parent 13952c8 commit d0b9d34

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

tensorflow_probability/python/distributions/gamma.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,14 +265,19 @@ def _log_rate_parameter(self):
265265
caveats.""")
266266
def _sample_n(self, n, seed=None):
267267
seed = samplers.sanitize_seed(seed, salt='gamma')
268+
log_space = implementation_selection.is_xla()
268269

269-
return random_gamma(
270+
res = random_gamma(
270271
shape=ps.convert_to_shape_tensor([n]),
271272
concentration=tf.convert_to_tensor(self.concentration),
272273
rate=None if self.rate is None else tf.convert_to_tensor(self.rate),
273274
log_rate=(None if self.log_rate is None else
274275
tf.convert_to_tensor(self.log_rate)),
276+
log_space=log_space,
275277
seed=seed)
278+
if log_space:
279+
res = tf.math.exp(res)
280+
return res
276281

277282
def _log_prob(self, x, rate=None):
278283
concentration = tf.convert_to_tensor(self.concentration)

tensorflow_probability/python/internal/implementation_selection.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
__all__ = [
2424
'implementation_selecting',
25+
'is_xla',
2526
'never_runs_functions_eagerly',
2627
]
2728

@@ -47,7 +48,7 @@
4748
NUMPY_MODE = False
4849

4950

50-
def _is_xla():
51+
def is_xla():
5152
"""Returns `True` when we are tracing a function for XLA compilation."""
5253
if JAX_MODE:
5354
return True
@@ -134,7 +135,7 @@ def stub_fn(**kwargs):
134135

135136
def impl_selecting_fn(**kwargs):
136137
"""The wrapper function to be returned."""
137-
if _is_xla(): # JAX, XLA breakout.
138+
if is_xla(): # JAX, XLA breakout.
138139
return default_fn(**kwargs)
139140
if NUMPY_MODE: # Numpy breakout.
140141
return cpu_fn(**kwargs)

0 commit comments

Comments
 (0)