@@ -380,7 +380,6 @@ def fit_with_gibbs_sampling(model,
380
380
update for the posterior precision of the weight in case of a spike and
381
381
slab sampler.
382
382
383
-
384
383
Returns:
385
384
model: A `GibbsSamplerState` structure of posterior samples.
386
385
"""
@@ -436,8 +435,13 @@ def fit_with_gibbs_sampling(model,
436
435
seed = samplers .sanitize_seed (seed , salt = 'initial_GibbsSamplerState' ))
437
436
438
437
sampler_loop_body = _build_sampler_loop_body (
439
- model , observed_time_series , is_missing , default_pseudo_observations ,
440
- experimental_use_dynamic_cholesky , experimental_use_weight_adjustment )
438
+ model = model ,
439
+ observed_time_series = observed_time_series ,
440
+ is_missing = is_missing ,
441
+ default_pseudo_observations = default_pseudo_observations ,
442
+ experimental_use_dynamic_cholesky = experimental_use_dynamic_cholesky ,
443
+ experimental_use_weight_adjustment = experimental_use_weight_adjustment
444
+ )
441
445
442
446
samples = tf .scan (sampler_loop_body ,
443
447
np .arange (num_warmup_steps + num_results ), initial_state )
@@ -885,6 +889,19 @@ def _build_sampler_loop_body(model,
885
889
else :
886
890
weights_prior_scale = (regression_component .parameters [0 ].prior .scale )
887
891
892
+ # Sub-selects in `forward_filter_sequential` take up a lot of the runtime
893
+ # with a dynamic Cholesky, but compiling here seems to help.
894
+ # TODO(b/234726324): Should this always be compiled?
895
+ if experimental_use_dynamic_cholesky :
896
+ resample_latents = tf .function (
897
+ jit_compile = True , autograph = False )(
898
+ _resample_latents )
899
+ resample_scale = tf .function (
900
+ jit_compile = True , autograph = False )(
901
+ _resample_scale )
902
+ else :
903
+ resample_latents = _resample_latents
904
+ resample_scale = _resample_scale
888
905
def sampler_loop_body (previous_sample , _ ):
889
906
"""Runs one sampler iteration, resampling all model variables."""
890
907
@@ -940,7 +957,7 @@ def sampler_loop_body(previous_sample, _):
940
957
observation_noise_scale = previous_sample .observation_noise_scale
941
958
weights = previous_sample .weights
942
959
943
- latents = _resample_latents (
960
+ latents = resample_latents (
944
961
observed_residuals = regression_residuals ,
945
962
level_scale = previous_sample .level_scale ,
946
963
slope_scale = previous_sample .slope_scale if model_has_slope else None ,
@@ -956,20 +973,20 @@ def sampler_loop_body(previous_sample, _):
956
973
slope_residuals = slope [..., 1 :] - slope [..., :- 1 ]
957
974
958
975
# Estimate level scale from the empirical changes in level.
959
- level_scale = _resample_scale (
976
+ level_scale = resample_scale (
960
977
prior = level_scale_variance_prior ,
961
978
observed_residuals = level_residuals ,
962
979
is_missing = None ,
963
980
seed = level_scale_seed )
964
981
if model_has_slope :
965
- slope_scale = _resample_scale (
982
+ slope_scale = resample_scale (
966
983
prior = slope_scale_variance_prior ,
967
984
observed_residuals = slope_residuals ,
968
985
is_missing = None ,
969
986
seed = slope_scale_seed )
970
987
if not (regression_component and model_has_spike_slab_regression ):
971
988
# Estimate noise scale from the residuals.
972
- observation_noise_scale = _resample_scale (
989
+ observation_noise_scale = resample_scale (
973
990
prior = observation_noise_variance_prior ,
974
991
observed_residuals = regression_residuals - level ,
975
992
is_missing = is_missing ,
0 commit comments