Skip to content

Commit d1b11b9

Browse files
Update decomposition.py
1 parent 838888a commit d1b11b9

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

tensorflow_probability/python/sts/decomposition.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def _split_covariance_into_marginals(covariance, block_sizes):
3939

4040

4141
def _decompose_from_posterior_marginals(
42-
model, posterior_means, posterior_covs, parameter_samples):
42+
model, posterior_means, posterior_covs, parameter_samples, initial_step=0):
4343
"""Utility method to decompose a joint posterior into components.
4444
4545
Args:
@@ -89,7 +89,7 @@ def _decompose_from_posterior_marginals(
8989
tf.shape(posterior_means))[-2]
9090
component_ssms = model.make_component_state_space_models(
9191
num_timesteps=num_timesteps,
92-
param_vals=parameter_samples)
92+
param_vals=parameter_samples, initial_step=initial_step)
9393
component_predictive_dists = collections.OrderedDict()
9494
for (component, component_ssm,
9595
component_mean, component_cov) in zip(model.components, component_ssms,
@@ -221,7 +221,7 @@ def decompose_by_component(model, observed_time_series, parameter_samples):
221221
model, posterior_means, posterior_covs, parameter_samples)
222222

223223

224-
def decompose_forecast_by_component(model, forecast_dist, parameter_samples):
224+
def decompose_forecast_by_component(model, forecast_dist, parameter_samples, observed_time_series):
225225
"""Decompose a forecast distribution into contributions from each component.
226226
227227
Args:
@@ -322,6 +322,5 @@ def decompose_forecast_by_component(model, forecast_dist, parameter_samples):
322322
forecast_latent_mean, source_idx=-3, dest_idx=0)
323323
forecast_latent_covs = dist_util.move_dimension(
324324
forecast_latent_covs, source_idx=-4, dest_idx=0)
325-
326325
return _decompose_from_posterior_marginals(
327-
model, forecast_latent_mean, forecast_latent_covs, parameter_samples)
326+
model, forecast_latent_mean, forecast_latent_covs, parameter_samples, initial_step=len(observed_time_series))

0 commit comments

Comments
 (0)