@@ -39,7 +39,7 @@ def _split_covariance_into_marginals(covariance, block_sizes):
39
39
40
40
41
41
def _decompose_from_posterior_marginals (
42
- model , posterior_means , posterior_covs , parameter_samples ):
42
+ model , posterior_means , posterior_covs , parameter_samples , initial_step = 0 ):
43
43
"""Utility method to decompose a joint posterior into components.
44
44
45
45
Args:
@@ -59,6 +59,8 @@ def _decompose_from_posterior_marginals(
59
59
param.prior.event_shape]) for param in model.parameters]`. This may
60
60
optionally also be a map (Python `dict`) of parameter names to
61
61
`Tensor` values.
62
+ initial_step: optional `int` specifying the initial timestep of the
63
+ decomposition.
62
64
63
65
Returns:
64
66
component_dists: A `collections.OrderedDict` instance mapping
@@ -89,7 +91,7 @@ def _decompose_from_posterior_marginals(
89
91
tf .shape (posterior_means ))[- 2 ]
90
92
component_ssms = model .make_component_state_space_models (
91
93
num_timesteps = num_timesteps ,
92
- param_vals = parameter_samples )
94
+ param_vals = parameter_samples , initial_step = initial_step )
93
95
component_predictive_dists = collections .OrderedDict ()
94
96
for (component , component_ssm ,
95
97
component_mean , component_cov ) in zip (model .components , component_ssms ,
@@ -322,6 +324,6 @@ def decompose_forecast_by_component(model, forecast_dist, parameter_samples):
322
324
forecast_latent_mean , source_idx = - 3 , dest_idx = 0 )
323
325
forecast_latent_covs = dist_util .move_dimension (
324
326
forecast_latent_covs , source_idx = - 4 , dest_idx = 0 )
325
-
326
327
return _decompose_from_posterior_marginals (
327
- model , forecast_latent_mean , forecast_latent_covs , parameter_samples )
328
+ model , forecast_latent_mean , forecast_latent_covs , parameter_samples ,
329
+ initial_step = forecast_lgssm .initial_step )
0 commit comments