Skip to content

Commit 12e0f6b

Browse files
ColCarroll authored and tensorflower-gardener committed
Improve dynamic Cholesky performance by compiling private functions.
PiperOrigin-RevId: 452573333
1 parent 4547374 commit 12e0f6b

File tree

3 files changed

+31
-8
lines changed

3 files changed

+31
-8
lines changed

tensorflow_probability/python/experimental/sts_gibbs/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ multi_substrate_py_library(
9191
# tensorflow dep,
9292
"//tensorflow_probability/python/bijectors:softplus",
9393
"//tensorflow_probability/python/distributions:bernoulli",
94+
"//tensorflow_probability/python/distributions:gamma",
9495
"//tensorflow_probability/python/distributions:inverse_gamma",
9596
"//tensorflow_probability/python/distributions:joint_distribution_auto_batched",
9697
"//tensorflow_probability/python/distributions:sample",

tensorflow_probability/python/experimental/sts_gibbs/dynamic_spike_and_slab.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from tensorflow_probability.python.bijectors import softplus as softplus_bijector
2222
from tensorflow_probability.python.distributions import bernoulli
23+
from tensorflow_probability.python.distributions import gamma
2324
from tensorflow_probability.python.distributions import inverse_gamma
2425
from tensorflow_probability.python.distributions import joint_distribution_auto_batched
2526
from tensorflow_probability.python.distributions import sample as sample_dist
@@ -55,7 +56,11 @@ def _parameter_properties(cls, dtype, num_classes=None):
5556
lambda: softplus_bijector.Softplus(low=dtype_util.eps(dtype)))))
5657

5758
def _sample_n(self, n, seed=None):
58-
xs = super()._sample_n(n, seed=seed)
59+
# TODO(b/151571025): revert to `super()._sample_n` once the InverseGamma
60+
# sampler is XLA-able.
61+
xs = 1. / gamma.Gamma(
62+
concentration=self.concentration, rate=self.scale).sample(
63+
n, seed=seed)
5964
if self._upper_bound is not None:
6065
xs = tf.minimum(xs, self._upper_bound)
6166
return xs

tensorflow_probability/python/experimental/sts_gibbs/gibbs_sampler.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,6 @@ def fit_with_gibbs_sampling(model,
380380
update for the posterior precision of the weight in case of a spike and
381381
slab sampler.
382382
383-
384383
Returns:
385384
model: A `GibbsSamplerState` structure of posterior samples.
386385
"""
@@ -436,8 +435,13 @@ def fit_with_gibbs_sampling(model,
436435
seed=samplers.sanitize_seed(seed, salt='initial_GibbsSamplerState'))
437436

438437
sampler_loop_body = _build_sampler_loop_body(
439-
model, observed_time_series, is_missing, default_pseudo_observations,
440-
experimental_use_dynamic_cholesky, experimental_use_weight_adjustment)
438+
model=model,
439+
observed_time_series=observed_time_series,
440+
is_missing=is_missing,
441+
default_pseudo_observations=default_pseudo_observations,
442+
experimental_use_dynamic_cholesky=experimental_use_dynamic_cholesky,
443+
experimental_use_weight_adjustment=experimental_use_weight_adjustment
444+
)
441445

442446
samples = tf.scan(sampler_loop_body,
443447
np.arange(num_warmup_steps + num_results), initial_state)
@@ -885,6 +889,19 @@ def _build_sampler_loop_body(model,
885889
else:
886890
weights_prior_scale = (regression_component.parameters[0].prior.scale)
887891

892+
# Sub-selects in `forward_filter_sequential` take up a lot of the runtime
893+
# with a dynamic Cholesky, but compiling here seems to help.
894+
# TODO(b/234726324): Should this always be compiled?
895+
if experimental_use_dynamic_cholesky:
896+
resample_latents = tf.function(
897+
jit_compile=True, autograph=False)(
898+
_resample_latents)
899+
resample_scale = tf.function(
900+
jit_compile=True, autograph=False)(
901+
_resample_scale)
902+
else:
903+
resample_latents = _resample_latents
904+
resample_scale = _resample_scale
888905
def sampler_loop_body(previous_sample, _):
889906
"""Runs one sampler iteration, resampling all model variables."""
890907

@@ -940,7 +957,7 @@ def sampler_loop_body(previous_sample, _):
940957
observation_noise_scale = previous_sample.observation_noise_scale
941958
weights = previous_sample.weights
942959

943-
latents = _resample_latents(
960+
latents = resample_latents(
944961
observed_residuals=regression_residuals,
945962
level_scale=previous_sample.level_scale,
946963
slope_scale=previous_sample.slope_scale if model_has_slope else None,
@@ -956,20 +973,20 @@ def sampler_loop_body(previous_sample, _):
956973
slope_residuals = slope[..., 1:] - slope[..., :-1]
957974

958975
# Estimate level scale from the empirical changes in level.
959-
level_scale = _resample_scale(
976+
level_scale = resample_scale(
960977
prior=level_scale_variance_prior,
961978
observed_residuals=level_residuals,
962979
is_missing=None,
963980
seed=level_scale_seed)
964981
if model_has_slope:
965-
slope_scale = _resample_scale(
982+
slope_scale = resample_scale(
966983
prior=slope_scale_variance_prior,
967984
observed_residuals=slope_residuals,
968985
is_missing=None,
969986
seed=slope_scale_seed)
970987
if not (regression_component and model_has_spike_slab_regression):
971988
# Estimate noise scale from the residuals.
972-
observation_noise_scale = _resample_scale(
989+
observation_noise_scale = resample_scale(
973990
prior=observation_noise_variance_prior,
974991
observed_residuals=regression_residuals - level,
975992
is_missing=is_missing,

0 commit comments

Comments (0)