@@ -59,6 +59,7 @@ def _decompose_from_posterior_marginals(
59
59
param.prior.event_shape]) for param in model.parameters]`. This may
60
60
optionally also be a map (Python `dict`) of parameter names to
61
61
`Tensor` values.
62
+ initial_step: optional `int` specifying the initial timestep of the decomposition.
62
63
63
64
Returns:
64
65
component_dists: A `collections.OrderedDict` instance mapping
@@ -221,7 +222,7 @@ def decompose_by_component(model, observed_time_series, parameter_samples):
221
222
model , posterior_means , posterior_covs , parameter_samples )
222
223
223
224
224
- def decompose_forecast_by_component (model , forecast_dist , parameter_samples , observed_time_series ):
225
+ def decompose_forecast_by_component (model , forecast_dist , parameter_samples ):
225
226
"""Decompose a forecast distribution into contributions from each component.
226
227
227
228
Args:
@@ -323,4 +324,4 @@ def decompose_forecast_by_component(model, forecast_dist, parameter_samples, obs
323
324
forecast_latent_covs = dist_util .move_dimension (
324
325
forecast_latent_covs , source_idx = - 4 , dest_idx = 0 )
325
326
return _decompose_from_posterior_marginals (
326
- model , forecast_latent_mean , forecast_latent_covs , parameter_samples , initial_step = len ( observed_time_series ) )
327
+ model , forecast_latent_mean , forecast_latent_covs , parameter_samples , initial_step = forecast_lgssm . initial_step )
0 commit comments