@@ -395,10 +395,6 @@ def fit_with_gibbs_sampling(model,
395
395
Returns:
396
396
model: A `GibbsSamplerState` structure of posterior samples.
397
397
"""
398
- if not hasattr (model , 'supports_gibbs_sampling' ):
399
- raise ValueError ('This STS model does not support Gibbs sampling. Models '
400
- 'for Gibbs sampling must be created using the '
401
- 'method `build_model_for_gibbs_fitting`.' )
402
398
if not tf .nest .is_nested (num_chains ):
403
399
num_chains = [num_chains ]
404
400
@@ -420,7 +416,8 @@ def fit_with_gibbs_sampling(model,
420
416
# the slope_scale is always zero.
421
417
initial_slope_scale = 0.
422
418
initial_slope = 0.
423
- if isinstance (model .components [0 ], local_linear_trend .LocalLinearTrend ):
419
+ (_ , level_component ), _ = (_get_components_from_model (model ))
420
+ if isinstance (level_component , local_linear_trend .LocalLinearTrend ):
424
421
initial_slope_scale = 1. * tf .ones (batch_shape , dtype = dtype )
425
422
initial_slope = tf .zeros (level_slope_shape , dtype = dtype )
426
423
@@ -476,17 +473,22 @@ def model_parameter_samples_from_gibbs_samples(model, gibbs_samples):
476
473
A set of posterior samples, that can be used with `make_state_space_model`
477
474
or `sts.forecast`.
478
475
"""
479
- if not hasattr (model , 'supports_gibbs_sampling' ):
480
- raise ValueError ('This STS model does not support Gibbs sampling. Models '
481
- 'for Gibbs sampling must be created using the '
482
- 'method `build_model_for_gibbs_fitting`.' )
476
+ # Make use of the indices in the model to avoid requiring a specific
477
+ # order of components.
478
+ ((level_component_index , level_component ),
479
+ (regression_component_index , _ )) = (
480
+ _get_components_from_model (model ))
481
+ model_parameter_samples = (gibbs_samples .observation_noise_scale ,)
483
482
484
- if isinstance (model .components [0 ], local_linear_trend .LocalLinearTrend ):
485
- return (gibbs_samples .observation_noise_scale , gibbs_samples .level_scale ,
486
- gibbs_samples .slope_scale , gibbs_samples .weights )
487
- else :
488
- return (gibbs_samples .observation_noise_scale , gibbs_samples .level_scale ,
489
- gibbs_samples .weights )
483
+ for index in range (len (model .components )):
484
+ if index == level_component_index :
485
+ model_parameter_samples += (gibbs_samples .level_scale ,)
486
+ if isinstance (level_component , local_linear_trend .LocalLinearTrend ):
487
+ model_parameter_samples += (gibbs_samples .slope_scale ,)
488
+ elif index == regression_component_index :
489
+ model_parameter_samples += (gibbs_samples .weights ,)
490
+
491
+ return model_parameter_samples
490
492
491
493
492
494
def one_step_predictive (model ,
@@ -582,7 +584,7 @@ def one_step_predictive(model,
582
584
tf .range (1. , num_forecast_steps + 1. , dtype = forecast_level .dtype ))
583
585
584
586
level_pred = tf .concat (
585
- ([thinned_samples .level ] if use_zero_step_prediction else [
587
+ ([thinned_samples .level ] if use_zero_step_prediction else [ # pylint:disable=g-long-ternary
586
588
thinned_samples .level [..., :1 ], # t == 0
587
589
(thinned_samples .level [..., :- 1 ] + thinned_samples .slope [..., :- 1 ]
588
590
) # 1 <= t < T. Constructs the next level from previous level
@@ -784,6 +786,56 @@ def _resample_scale(prior, observed_residuals, is_missing=None, seed=None):
784
786
posterior , seed = seed )
785
787
786
788
789
+ def _get_components_from_model (model ):
790
+ """Returns the split-apart components from an STS model.
791
+
792
+ Args:
793
+ model: A `tf.sts.StructuralTimeSeries` to split apart.
794
+
795
+ Returns:
796
+ A tuple of (index and level component, index and regression component) or
797
+ an exception. The regression component may be None (along with its index -
798
+ (None, None)).
799
+
800
+ Each 'index and component' is a tuple of (index, component), where index
801
+ is the position in the model.
802
+ """
803
+ if not hasattr (model , 'supports_gibbs_sampling' ):
804
+ raise ValueError ('This STS model does not support Gibbs sampling. Models '
805
+ 'for Gibbs sampling must be created using the '
806
+ 'method `build_model_for_gibbs_fitting`.' )
807
+
808
+ level_components = []
809
+ regression_components = []
810
+
811
+ for index , component in enumerate (model .components ):
812
+ if (isinstance (component , local_level .LocalLevel ) or
813
+ isinstance (component , local_linear_trend .LocalLinearTrend )):
814
+ level_components .append ((index , component ))
815
+ elif (isinstance (component , regression .LinearRegression ) or
816
+ isinstance (component , SpikeAndSlabSparseLinearRegression )):
817
+ regression_components .append ((index , component ))
818
+ else :
819
+ raise NotImplementedError (
820
+ 'Found unsupported model component for Gibbs Sampling: {}' .format (
821
+ component ))
822
+
823
+ if len (level_components ) != 1 :
824
+ raise ValueError (
825
+ 'Expected exactly one level component, found {} components.' .format (
826
+ len (level_components )))
827
+ level_component = level_components [0 ]
828
+
829
+ regression_component = (None , None )
830
+ if len (regression_components ) > 1 :
831
+ raise ValueError (
832
+ 'Expected at most one regression component, found {} components.'
833
+ .format (len (regression_components )))
834
+ elif len (regression_components ) == 1 :
835
+ regression_component = regression_components [0 ]
836
+ return level_component , regression_component
837
+
838
+
787
839
def _build_sampler_loop_body (model ,
788
840
observed_time_series ,
789
841
is_missing = None ,
@@ -819,29 +871,12 @@ def _build_sampler_loop_body(model,
819
871
"""
820
872
if JAX_MODE and experimental_use_dynamic_cholesky :
821
873
raise ValueError ('Dynamic Cholesky decomposition not supported in JAX' )
822
- level_component = model .components [0 ]
823
- if not (isinstance (level_component , local_level .LocalLevel ) or
824
- isinstance (level_component , local_linear_trend .LocalLinearTrend )):
825
- raise ValueError (
826
- 'Expected the first model component to be an instance of '
827
- '`tfp.sts.LocalLevel` or `tfp.local_linear_trend.LocalLinearTrend`; '
828
- 'instead saw {}' .format (level_component ))
874
+ (_ , level_component ), (_ , regression_component ) = (
875
+ _get_components_from_model (model ))
829
876
model_has_slope = isinstance (level_component ,
830
877
local_linear_trend .LocalLinearTrend )
831
878
832
- # TODO(kloveless): When we add support for more flexible models, remove
833
- # this assumption.
834
- regression_component = (None if len (model .components ) != 2 else
835
- model .components [1 ])
836
- if regression_component :
837
- if not (isinstance (regression_component , regression .LinearRegression ) or
838
- isinstance (regression_component ,
839
- SpikeAndSlabSparseLinearRegression )):
840
- raise ValueError (
841
- 'Expected the second model component to be an instance of '
842
- '`tfp.sts.LinearRegression` or '
843
- '`SpikeAndSlabSparseLinearRegression`; '
844
- 'instead saw {}' .format (regression_component ))
879
+ if regression_component is not None :
845
880
model_has_spike_slab_regression = isinstance (
846
881
regression_component , SpikeAndSlabSparseLinearRegression )
847
882
0 commit comments