Skip to content

Commit 039070d

Browse files
Merge pull request #1343 from europeanplaice:dev
PiperOrigin-RevId: 377076353
2 parents 497fc1e + e5dda26 commit 039070d

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

tensorflow_probability/python/sts/decomposition.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def _split_covariance_into_marginals(covariance, block_sizes):
3939

4040

4141
def _decompose_from_posterior_marginals(
42-
model, posterior_means, posterior_covs, parameter_samples):
42+
model, posterior_means, posterior_covs, parameter_samples, initial_step=0):
4343
"""Utility method to decompose a joint posterior into components.
4444
4545
Args:
@@ -59,6 +59,8 @@ def _decompose_from_posterior_marginals(
5959
param.prior.event_shape]) for param in model.parameters]`. This may
6060
optionally also be a map (Python `dict`) of parameter names to
6161
`Tensor` values.
62+
initial_step: optional `int` specifying the initial timestep of the
63+
decomposition.
6264
6365
Returns:
6466
component_dists: A `collections.OrderedDict` instance mapping
@@ -89,7 +91,7 @@ def _decompose_from_posterior_marginals(
8991
tf.shape(posterior_means))[-2]
9092
component_ssms = model.make_component_state_space_models(
9193
num_timesteps=num_timesteps,
92-
param_vals=parameter_samples)
94+
param_vals=parameter_samples, initial_step=initial_step)
9395
component_predictive_dists = collections.OrderedDict()
9496
for (component, component_ssm,
9597
component_mean, component_cov) in zip(model.components, component_ssms,
@@ -322,6 +324,6 @@ def decompose_forecast_by_component(model, forecast_dist, parameter_samples):
322324
forecast_latent_mean, source_idx=-3, dest_idx=0)
323325
forecast_latent_covs = dist_util.move_dimension(
324326
forecast_latent_covs, source_idx=-4, dest_idx=0)
325-
326327
return _decompose_from_posterior_marginals(
327-
model, forecast_latent_mean, forecast_latent_covs, parameter_samples)
328+
model, forecast_latent_mean, forecast_latent_covs, parameter_samples,
329+
initial_step=forecast_lgssm.initial_step)

0 commit comments

Comments
 (0)