@@ -39,7 +39,7 @@ def _split_covariance_into_marginals(covariance, block_sizes):
3939
4040
4141def _decompose_from_posterior_marginals (
42- model , posterior_means , posterior_covs , parameter_samples ):
42+ model , posterior_means , posterior_covs , parameter_samples , initial_step = 0 ):
4343 """Utility method to decompose a joint posterior into components.
4444
4545 Args:
@@ -89,7 +89,7 @@ def _decompose_from_posterior_marginals(
8989 tf .shape (posterior_means ))[- 2 ]
9090 component_ssms = model .make_component_state_space_models (
9191 num_timesteps = num_timesteps ,
92- param_vals = parameter_samples )
92+ param_vals = parameter_samples , initial_step = initial_step )
9393 component_predictive_dists = collections .OrderedDict ()
9494 for (component , component_ssm ,
9595 component_mean , component_cov ) in zip (model .components , component_ssms ,
@@ -221,7 +221,7 @@ def decompose_by_component(model, observed_time_series, parameter_samples):
221221 model , posterior_means , posterior_covs , parameter_samples )
222222
223223
224- def decompose_forecast_by_component (model , forecast_dist , parameter_samples ):
224+ def decompose_forecast_by_component (model , forecast_dist , parameter_samples , observed_time_series ):
225225 """Decompose a forecast distribution into contributions from each component.
226226
227227 Args:
@@ -322,6 +322,5 @@ def decompose_forecast_by_component(model, forecast_dist, parameter_samples):
322322 forecast_latent_mean , source_idx = - 3 , dest_idx = 0 )
323323 forecast_latent_covs = dist_util .move_dimension (
324324 forecast_latent_covs , source_idx = - 4 , dest_idx = 0 )
325-
326325 return _decompose_from_posterior_marginals (
327- model , forecast_latent_mean , forecast_latent_covs , parameter_samples )
326+ model , forecast_latent_mean , forecast_latent_covs , parameter_samples , initial_step = len ( observed_time_series ) )
0 commit comments