Skip to content

Commit 0bd5811

Browse files
kylepltensorflower-gardener
authored andcommitted
Preparation change for seasonality support in the GibbsSampler:
* Stop assuming Seasonal components are not supported in tests. * Have a common function for deconstructing the list of STS component * Stop requiring the STS component be in a certain order (not needed for seasonality support, but moves towards a world where build_model_for_gibbs_fitting is not required) * Small typo fix in a test * Small test improvement to share the batch value PiperOrigin-RevId: 475990211
1 parent 01bf41b commit 0bd5811

File tree

3 files changed

+90
-62
lines changed

3 files changed

+90
-62
lines changed

tensorflow_probability/python/experimental/sts_gibbs/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ multi_substrate_py_test(
118118
"//tensorflow_probability/python/sts/components:local_linear_trend",
119119
"//tensorflow_probability/python/sts/components:regression",
120120
"//tensorflow_probability/python/sts/components:seasonal",
121+
"//tensorflow_probability/python/sts/components:semilocal_linear_trend",
121122
"//tensorflow_probability/python/sts/components:sum",
122123
"//tensorflow_probability/python/sts/internal:missing_values_util",
123124
# "//third_party/tensorflow/compiler/jit:xla_cpu_jit", # DisableOnExport

tensorflow_probability/python/experimental/sts_gibbs/gibbs_sampler.py

Lines changed: 71 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -395,10 +395,6 @@ def fit_with_gibbs_sampling(model,
395395
Returns:
396396
model: A `GibbsSamplerState` structure of posterior samples.
397397
"""
398-
if not hasattr(model, 'supports_gibbs_sampling'):
399-
raise ValueError('This STS model does not support Gibbs sampling. Models '
400-
'for Gibbs sampling must be created using the '
401-
'method `build_model_for_gibbs_fitting`.')
402398
if not tf.nest.is_nested(num_chains):
403399
num_chains = [num_chains]
404400

@@ -420,7 +416,8 @@ def fit_with_gibbs_sampling(model,
420416
# the slope_scale is always zero.
421417
initial_slope_scale = 0.
422418
initial_slope = 0.
423-
if isinstance(model.components[0], local_linear_trend.LocalLinearTrend):
419+
(_, level_component), _ = (_get_components_from_model(model))
420+
if isinstance(level_component, local_linear_trend.LocalLinearTrend):
424421
initial_slope_scale = 1. * tf.ones(batch_shape, dtype=dtype)
425422
initial_slope = tf.zeros(level_slope_shape, dtype=dtype)
426423

@@ -476,17 +473,22 @@ def model_parameter_samples_from_gibbs_samples(model, gibbs_samples):
476473
A set of posterior samples, that can be used with `make_state_space_model`
477474
or `sts.forecast`.
478475
"""
479-
if not hasattr(model, 'supports_gibbs_sampling'):
480-
raise ValueError('This STS model does not support Gibbs sampling. Models '
481-
'for Gibbs sampling must be created using the '
482-
'method `build_model_for_gibbs_fitting`.')
476+
# Make use of the indices in the model to avoid requiring a specific
477+
# order of components.
478+
((level_component_index, level_component),
479+
(regression_component_index, _)) = (
480+
_get_components_from_model(model))
481+
model_parameter_samples = (gibbs_samples.observation_noise_scale,)
483482

484-
if isinstance(model.components[0], local_linear_trend.LocalLinearTrend):
485-
return (gibbs_samples.observation_noise_scale, gibbs_samples.level_scale,
486-
gibbs_samples.slope_scale, gibbs_samples.weights)
487-
else:
488-
return (gibbs_samples.observation_noise_scale, gibbs_samples.level_scale,
489-
gibbs_samples.weights)
483+
for index in range(len(model.components)):
484+
if index == level_component_index:
485+
model_parameter_samples += (gibbs_samples.level_scale,)
486+
if isinstance(level_component, local_linear_trend.LocalLinearTrend):
487+
model_parameter_samples += (gibbs_samples.slope_scale,)
488+
elif index == regression_component_index:
489+
model_parameter_samples += (gibbs_samples.weights,)
490+
491+
return model_parameter_samples
490492

491493

492494
def one_step_predictive(model,
@@ -582,7 +584,7 @@ def one_step_predictive(model,
582584
tf.range(1., num_forecast_steps + 1., dtype=forecast_level.dtype))
583585

584586
level_pred = tf.concat(
585-
([thinned_samples.level] if use_zero_step_prediction else [
587+
([thinned_samples.level] if use_zero_step_prediction else [ # pylint:disable=g-long-ternary
586588
thinned_samples.level[..., :1], # t == 0
587589
(thinned_samples.level[..., :-1] + thinned_samples.slope[..., :-1]
588590
) # 1 <= t < T. Constructs the next level from previous level
@@ -784,6 +786,56 @@ def _resample_scale(prior, observed_residuals, is_missing=None, seed=None):
784786
posterior, seed=seed)
785787

786788

789+
def _get_components_from_model(model):
790+
"""Returns the split-apart components from an STS model.
791+
792+
Args:
793+
model: A `tf.sts.StructuralTimeSeries` to split apart.
794+
795+
Returns:
796+
A tuple of (index and level component, index and regression component) or
797+
an exception. The regression component may be None (along with its index -
798+
(None, None)).
799+
800+
Each 'index and component' is a tuple of (index, component), where index
801+
is the position in the model.
802+
"""
803+
if not hasattr(model, 'supports_gibbs_sampling'):
804+
raise ValueError('This STS model does not support Gibbs sampling. Models '
805+
'for Gibbs sampling must be created using the '
806+
'method `build_model_for_gibbs_fitting`.')
807+
808+
level_components = []
809+
regression_components = []
810+
811+
for index, component in enumerate(model.components):
812+
if (isinstance(component, local_level.LocalLevel) or
813+
isinstance(component, local_linear_trend.LocalLinearTrend)):
814+
level_components.append((index, component))
815+
elif (isinstance(component, regression.LinearRegression) or
816+
isinstance(component, SpikeAndSlabSparseLinearRegression)):
817+
regression_components.append((index, component))
818+
else:
819+
raise NotImplementedError(
820+
'Found unsupported model component for Gibbs Sampling: {}'.format(
821+
component))
822+
823+
if len(level_components) != 1:
824+
raise ValueError(
825+
'Expected exactly one level component, found {} components.'.format(
826+
len(level_components)))
827+
level_component = level_components[0]
828+
829+
regression_component = (None, None)
830+
if len(regression_components) > 1:
831+
raise ValueError(
832+
'Expected at most one regression component, found {} components.'
833+
.format(len(regression_components)))
834+
elif len(regression_components) == 1:
835+
regression_component = regression_components[0]
836+
return level_component, regression_component
837+
838+
787839
def _build_sampler_loop_body(model,
788840
observed_time_series,
789841
is_missing=None,
@@ -819,29 +871,12 @@ def _build_sampler_loop_body(model,
819871
"""
820872
if JAX_MODE and experimental_use_dynamic_cholesky:
821873
raise ValueError('Dynamic Cholesky decomposition not supported in JAX')
822-
level_component = model.components[0]
823-
if not (isinstance(level_component, local_level.LocalLevel) or
824-
isinstance(level_component, local_linear_trend.LocalLinearTrend)):
825-
raise ValueError(
826-
'Expected the first model component to be an instance of '
827-
'`tfp.sts.LocalLevel` or `tfp.local_linear_trend.LocalLinearTrend`; '
828-
'instead saw {}'.format(level_component))
874+
(_, level_component), (_, regression_component) = (
875+
_get_components_from_model(model))
829876
model_has_slope = isinstance(level_component,
830877
local_linear_trend.LocalLinearTrend)
831878

832-
# TODO(kloveless): When we add support for more flexible models, remove
833-
# this assumption.
834-
regression_component = (None if len(model.components) != 2 else
835-
model.components[1])
836-
if regression_component:
837-
if not (isinstance(regression_component, regression.LinearRegression) or
838-
isinstance(regression_component,
839-
SpikeAndSlabSparseLinearRegression)):
840-
raise ValueError(
841-
'Expected the second model component to be an instance of '
842-
'`tfp.sts.LinearRegression` or '
843-
'`SpikeAndSlabSparseLinearRegression`; '
844-
'instead saw {}'.format(regression_component))
879+
if regression_component is not None:
845880
model_has_spike_slab_regression = isinstance(
846881
regression_component, SpikeAndSlabSparseLinearRegression)
847882

tensorflow_probability/python/experimental/sts_gibbs/gibbs_sampler_test.py

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from absl.testing import parameterized
1818

1919
import numpy as np
20-
2120
import tensorflow.compat.v2 as tf
2221

2322
from tensorflow_probability.python.distributions import inverse_gamma
@@ -34,7 +33,7 @@
3433
from tensorflow_probability.python.sts.components import local_level
3534
from tensorflow_probability.python.sts.components import local_linear_trend
3635
from tensorflow_probability.python.sts.components import regression
37-
from tensorflow_probability.python.sts.components import seasonal
36+
from tensorflow_probability.python.sts.components import semilocal_linear_trend
3837
from tensorflow_probability.python.sts.components import sum as sum_lib
3938
from tensorflow_probability.python.sts.forecast import forecast
4039
from tensorflow_probability.python.sts.internal import missing_values_util
@@ -189,11 +188,12 @@ def test_forecasts_match_reference(self,
189188
if not tf.nest.is_nested(num_chains):
190189
num_results = num_results // num_chains
191190

191+
batch_shape = [3]
192192
model, observed_time_series, is_missing = self._build_test_model(
193193
num_timesteps=num_observed_steps + num_forecast_steps,
194194
true_slope_scale=0.5 if use_slope else None,
195-
batch_shape=[3],
196-
time_series_shift=time_series_shift)
195+
time_series_shift=time_series_shift,
196+
batch_shape=batch_shape)
197197

198198
@tf.function(autograph=False)
199199
def do_sampling():
@@ -226,9 +226,9 @@ def reshape_chain_and_sample(x):
226226
predictive_mean, predictive_stddev = self.evaluate((
227227
predictive_dist.mean(), predictive_dist.stddev()))
228228
self.assertAllEqual(predictive_mean.shape,
229-
[3, num_observed_steps + num_forecast_steps])
229+
batch_shape + [num_observed_steps + num_forecast_steps])
230230
self.assertAllEqual(predictive_stddev.shape,
231-
[3, num_observed_steps + num_forecast_steps])
231+
batch_shape + [num_observed_steps + num_forecast_steps])
232232

233233
# big tolerance, but makes sure the predictive mean initializes near
234234
# the initial time series value
@@ -457,31 +457,23 @@ def test_invalid_model_raises_error(self):
457457
],
458458
observed_time_series=observed_time_series)
459459

460-
with self.assertRaisesRegexp(ValueError, 'does not support Gibbs sampling'):
460+
with self.assertRaisesRegex(ValueError, 'does not support Gibbs sampling'):
461461
gibbs_sampler.fit_with_gibbs_sampling(
462462
bad_model, observed_time_series, seed=test_util.test_seed())
463463

464464
bad_model.supports_gibbs_sampling = True
465-
with self.assertRaisesRegexp(
466-
ValueError, 'Expected the first model component to be an instance of'):
467-
gibbs_sampler.fit_with_gibbs_sampling(
468-
bad_model, observed_time_series, seed=test_util.test_seed())
469-
470465
bad_model_with_correct_params = sum_lib.Sum([
471-
# A seasonal model with no drift has no parameters, so adding it
472-
# won't break the check for correct params.
473-
seasonal.Seasonal(
474-
num_seasons=2,
475-
allow_drift=False,
466+
# An unsupported model component.
467+
semilocal_linear_trend.SemiLocalLinearTrend(
476468
observed_time_series=observed_time_series),
477469
local_level.LocalLevel(observed_time_series=observed_time_series),
478470
regression.LinearRegression(design_matrix=tf.ones([5, 2]))
479471
])
480472
bad_model_with_correct_params.supports_gibbs_sampling = True
481473

482-
with self.assertRaisesRegexp(ValueError,
483-
'Expected the first model component to be an '
484-
'instance of `tfp.sts.LocalLevel`'):
474+
with self.assertRaisesRegex(
475+
NotImplementedError,
476+
'Found unsupported model component for Gibbs Sampling'):
485477
gibbs_sampler.fit_with_gibbs_sampling(bad_model_with_correct_params,
486478
observed_time_series,
487479
seed=test_util.test_seed())
@@ -673,13 +665,13 @@ def do_sampling():
673665
@parameterized.named_parameters(
674666
{
675667
'testcase_name': 'Rank1Updates',
676-
'use_dyanamic_cholesky': False,
668+
'use_dynamic_cholesky': False,
677669
}, {
678670
'testcase_name': 'DynamicCholesky',
679-
'use_dyanamic_cholesky': True,
671+
'use_dynamic_cholesky': True,
680672
})
681-
def test_sparse_regression_recovers_plausible_weights(
682-
self, use_dyanamic_cholesky):
673+
def test_sparse_regression_recovers_plausible_weights(self,
674+
use_dynamic_cholesky):
683675
true_weights = tf.constant([0., 0., 2., 0., -2.])
684676
model, observed_time_series, _ = self._build_test_model(
685677
num_timesteps=20,
@@ -698,9 +690,9 @@ def do_sampling():
698690
num_results=100,
699691
num_warmup_steps=100,
700692
seed=test_util.test_seed(sampler_type='stateless'),
701-
experimental_use_dynamic_cholesky=use_dyanamic_cholesky)
693+
experimental_use_dynamic_cholesky=use_dynamic_cholesky)
702694

703-
if JAX_MODE and use_dyanamic_cholesky:
695+
if JAX_MODE and use_dynamic_cholesky:
704696
with self.assertRaises(ValueError):
705697
self.evaluate(do_sampling())
706698
return

0 commit comments

Comments
 (0)