Skip to content

Commit 7951c1c

Browse files
kylepltensorflower-gardener
authored andcommitted
First piece of Seasonality support in the GibbsSampler - support resampling drift scale in Seasonality components.
This includes: * Splitting out some sampling logic for scale into a separate file to be shared with seasonality * A convenience method for accessing parameters by name. * Comparison of expectation generated by sampling the mean/variance from a similar estimation PiperOrigin-RevId: 475877493
1 parent 389b346 commit 7951c1c

File tree

10 files changed

+673
-27
lines changed

10 files changed

+673
-27
lines changed

tensorflow_probability/python/experimental/sts_gibbs/BUILD

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ multi_substrate_py_library(
6363
srcs = ["gibbs_sampler.py"],
6464
deps = [
6565
":dynamic_spike_and_slab",
66+
":sample_parameters",
6667
":spike_and_slab",
6768
# numpy dep,
6869
# tensorflow dep,
@@ -204,3 +205,32 @@ multi_substrate_py_test(
204205
# "//third_party/tensorflow/compiler/jit:xla_cpu_jit", # DisableOnExport
205206
],
206207
)
208+
209+
multi_substrate_py_library(
210+
name = "sample_parameters",
211+
srcs = ["sample_parameters.py"],
212+
deps = [
213+
# numpy dep,
214+
# tensorflow dep,
215+
"//tensorflow_probability/python/bijectors:invert",
216+
"//tensorflow_probability/python/bijectors:square",
217+
"//tensorflow_probability/python/distributions:distribution",
218+
"//tensorflow_probability/python/distributions:inverse_gamma",
219+
"//tensorflow_probability/python/distributions:transformed_distribution",
220+
"//tensorflow_probability/python/internal:dtype_util",
221+
"//tensorflow_probability/python/internal:prefer_static",
222+
],
223+
)
224+
225+
multi_substrate_py_test(
226+
name = "sample_parameters_test",
227+
size = "medium",
228+
srcs = ["sample_parameters_test.py"],
229+
deps = [
230+
":sample_parameters",
231+
# absl/testing:parameterized dep,
232+
# tensorflow dep,
233+
"//tensorflow_probability/python/distributions:inverse_gamma",
234+
"//tensorflow_probability/python/internal:test_util",
235+
],
236+
)

tensorflow_probability/python/experimental/sts_gibbs/gibbs_sampler.py

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ class is somewhat general, in that we assume that any seasonal/holiday variation
7676
from tensorflow_probability.python.distributions import normal_conjugate_posteriors
7777
from tensorflow_probability.python.experimental.distributions import mvn_precision_factor_linop as mvnpflo
7878
from tensorflow_probability.python.experimental.sts_gibbs import dynamic_spike_and_slab
79+
from tensorflow_probability.python.experimental.sts_gibbs import sample_parameters
7980
from tensorflow_probability.python.experimental.sts_gibbs import spike_and_slab
8081
from tensorflow_probability.python.internal import distribution_util as dist_util
8182
from tensorflow_probability.python.internal import dtype_util
@@ -777,27 +778,10 @@ def _resample_scale(prior, observed_residuals, is_missing=None, seed=None):
777778
Returns:
778779
sampled_scale: A `Tensor` sample from the posterior `p(scale | x)`.
779780
"""
780-
dtype = observed_residuals.dtype
781-
782-
if is_missing is not None:
783-
num_missing = tf.reduce_sum(tf.cast(is_missing, dtype), axis=-1)
784-
num_observations = prefer_static.shape(observed_residuals)[-1]
785-
if is_missing is not None:
786-
observed_residuals = tf.where(is_missing, tf.zeros_like(observed_residuals),
787-
observed_residuals)
788-
num_observations -= num_missing
789-
790-
variance_posterior = type(prior)(
791-
concentration=prior.concentration + tf.cast(num_observations / 2., dtype),
792-
scale=prior.scale +
793-
tf.reduce_sum(tf.square(observed_residuals), axis=-1) / 2.)
794-
new_scale = tf.sqrt(variance_posterior.sample(seed=seed))
795-
796-
# Support truncated priors.
797-
if hasattr(prior, 'upper_bound') and prior.upper_bound is not None:
798-
new_scale = tf.minimum(new_scale, prior.upper_bound)
799-
800-
return new_scale
781+
posterior = sample_parameters.normal_scale_posterior_inverse_gamma_conjugate(
782+
prior, observed_residuals, is_missing)
783+
return sample_parameters.sample_with_optional_upper_bound(
784+
posterior, seed=seed)
801785

802786

803787
def _build_sampler_loop_body(model,

tensorflow_probability/python/experimental/sts_gibbs/gibbs_sampler_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@ def _build_test_model(self,
106106
observation_noise_variance_prior = prior_class(
107107
concentration=tf.cast(0.01, dtype),
108108
scale=tf.cast(0.01 * 0.01, dtype))
109-
observation_noise_variance_prior.upper_bound = 100.0
109+
observation_noise_variance_prior.upper_bound = tf.constant(
110+
100.0, dtype=dtype)
110111

111112
observed_time_series = missing_values_util.MaskedTimeSeries(
112113
time_series[..., tf.newaxis], is_missing)
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright 2022 The TensorFlow Probability Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
"""Functions for sampling parameters useful in Gibbs Sampling."""
16+
17+
import tensorflow.compat.v2 as tf
18+
19+
from tensorflow_probability.python.bijectors import invert
20+
from tensorflow_probability.python.bijectors import square
21+
from tensorflow_probability.python.distributions import inverse_gamma
22+
from tensorflow_probability.python.distributions import transformed_distribution
23+
from tensorflow_probability.python.internal import prefer_static
24+
25+
26+
def normal_scale_posterior_inverse_gamma_conjugate(variance_prior,
27+
observations,
28+
is_missing=None):
29+
"""Returns the conditional posterior of Normal scale given observations.
30+
31+
We assume the conjugate InverseGamma->Normal model:
32+
33+
```
34+
scale ~ Sqrt(InverseGamma(variance_prior.concentration, variance_prior.scale))
35+
for i in [1, ..., num_observations]:
36+
x[i] ~ Normal(0, scale)
37+
```
38+
39+
and return a sample from `p(scale | x)`.
40+
41+
Args:
42+
variance_prior: Variance prior distribution as a `tfd.InverseGamma`
43+
instance. Note that the prior is given on the variance, but the value
44+
returned is a sample of the scale.
45+
observations: Float `Tensor` of shape `[..., num_observations]`, specifying
46+
the centered observations `(x)`.
47+
is_missing: Optional `bool` `Tensor` of shape `[..., num_observations]`. A
48+
`True` value indicates that the corresponding observation is missing.
49+
50+
Returns:
51+
sampled_scale: A `tfd.Distribution` of the conditional posterior
52+
of the inverse gamma scale.
53+
"""
54+
dtype = observations.dtype
55+
num_observations = prefer_static.shape(observations)[-1]
56+
if is_missing is not None:
57+
num_missing = tf.reduce_sum(tf.cast(is_missing, dtype), axis=-1)
58+
observations = tf.where(is_missing, tf.zeros_like(observations),
59+
observations)
60+
num_observations -= num_missing
61+
62+
variance_posterior = inverse_gamma.InverseGamma(
63+
concentration=variance_prior.concentration +
64+
tf.cast(num_observations / 2, dtype),
65+
scale=variance_prior.scale +
66+
tf.reduce_sum(tf.square(observations), axis=-1) / 2.)
67+
scale_posterior = transformed_distribution.TransformedDistribution(
68+
bijector=invert.Invert(square.Square()), distribution=variance_posterior)
69+
70+
if hasattr(variance_prior,
71+
'upper_bound') and variance_prior.upper_bound is not None:
72+
variance_posterior.upper_bound = variance_prior.upper_bound
73+
# TODO(kloveless): This should have sqrt applied, but it is not for
74+
# temporary backwards compatibility.
75+
scale_posterior.upper_bound = variance_prior.upper_bound
76+
77+
return scale_posterior
78+
79+
80+
# TODO(kloveless): This seems like this should be a function on the
81+
# distribution itself.
82+
def sample_with_optional_upper_bound(distribution, sample_shape=(), seed=None):
83+
"""Samples from the given distribution with an optional upper bound."""
84+
sample = distribution.sample(sample_shape=sample_shape, seed=seed)
85+
if hasattr(distribution,
86+
'upper_bound') and distribution.upper_bound is not None:
87+
sample = tf.minimum(sample, distribution.upper_bound)
88+
89+
return sample
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Copyright 2022 The TensorFlow Probability Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
"""Tests for sample_parameters."""
16+
17+
import tensorflow as tf
18+
from tensorflow_probability.python.bijectors import invert
19+
from tensorflow_probability.python.bijectors import square
20+
from tensorflow_probability.python.distributions import inverse_gamma
21+
from tensorflow_probability.python.distributions import transformed_distribution
22+
from tensorflow_probability.python.experimental.sts_gibbs import sample_parameters
23+
from tensorflow_probability.python.internal import samplers
24+
from tensorflow_probability.python.internal import test_util
25+
26+
27+
@test_util.test_all_tf_execution_regimes
28+
class NormalScalePosteriorInverseGammaConjugate(test_util.TestCase):
29+
30+
def testNoObservations(self):
31+
distribution = inverse_gamma.InverseGamma(16., 4.)
32+
posterior_distribution = sample_parameters.normal_scale_posterior_inverse_gamma_conjugate(
33+
distribution, observations=tf.constant([], dtype=tf.float32))
34+
self.assertIsInstance(posterior_distribution,
35+
transformed_distribution.TransformedDistribution)
36+
self.assertIsInstance(posterior_distribution.bijector, invert.Invert)
37+
self.assertIsInstance(posterior_distribution.bijector.bijector,
38+
square.Square)
39+
self.assertIsInstance(posterior_distribution.distribution,
40+
inverse_gamma.InverseGamma)
41+
self.assertAllEqual(distribution.concentration,
42+
posterior_distribution.distribution.concentration)
43+
self.assertAllEqual(distribution.scale,
44+
posterior_distribution.distribution.scale)
45+
46+
def testSingleObservation(self):
47+
concentration = 16.
48+
scale = 4.
49+
distribution = inverse_gamma.InverseGamma(
50+
concentration=concentration, scale=scale)
51+
posterior_distribution = sample_parameters.normal_scale_posterior_inverse_gamma_conjugate(
52+
distribution, observations=tf.constant([10.]))
53+
self.assertAllEqual(
54+
concentration + 0.5, # Add half the number of observations
55+
posterior_distribution.distribution.concentration)
56+
self.assertAllEqual(
57+
scale + 50, # Add half the square of observations sum.
58+
posterior_distribution.distribution.scale)
59+
60+
def testTwoObservations(self):
61+
concentration = 16.
62+
scale = 4.
63+
distribution = inverse_gamma.InverseGamma(
64+
concentration=concentration, scale=scale)
65+
posterior_distribution = sample_parameters.normal_scale_posterior_inverse_gamma_conjugate(
66+
distribution, observations=tf.constant([10., 4.]))
67+
self.assertAllEqual(
68+
concentration + 1, # Add half the number of observations
69+
posterior_distribution.distribution.concentration)
70+
self.assertAllEqual(
71+
scale + 58, # Add half the square of observations sum.
72+
posterior_distribution.distribution.scale)
73+
74+
def testUpperBoundPropagatedFromPrior(self):
75+
# If no upper bound is provided, expect there to be none.
76+
distribution = inverse_gamma.InverseGamma(16., 4.)
77+
posterior_distribution = sample_parameters.normal_scale_posterior_inverse_gamma_conjugate(
78+
distribution, observations=tf.constant([], dtype=tf.float32))
79+
self.assertFalse(hasattr(posterior_distribution, 'upper_bound'))
80+
self.assertFalse(
81+
hasattr(posterior_distribution.distribution, 'upper_bound'))
82+
83+
distribution = inverse_gamma.InverseGamma(16., 4.)
84+
distribution.upper_bound = 16.
85+
posterior_distribution = sample_parameters.normal_scale_posterior_inverse_gamma_conjugate(
86+
distribution, observations=tf.constant([], dtype=tf.float32))
87+
self.assertAllEqual(posterior_distribution.distribution.upper_bound, 16.)
88+
self.assertAllEqual(
89+
posterior_distribution.upper_bound,
90+
# TODO(kloveless): This should have sqrt applied, but it is not for
91+
# temporary backwards compatibility.
92+
16.)
93+
94+
95+
@test_util.test_all_tf_execution_regimes
96+
class SampleWithOptionalUpperBoundTest(test_util.TestCase):
97+
98+
def testBasic(self):
99+
distribution = inverse_gamma.InverseGamma(16., 4.)
100+
seed = samplers.sanitize_seed((0, 1))
101+
unbounded_result = sample_parameters.sample_with_optional_upper_bound(
102+
distribution, seed=seed)
103+
# Whatever the result, we want to get that value again minus a fixed offset.
104+
# Since we fix the seed, we know it is equal just because of the upper
105+
# bound.
106+
new_target_result = unbounded_result - 0.5
107+
distribution.upper_bound = new_target_result
108+
self.assertAllEqual(
109+
new_target_result,
110+
sample_parameters.sample_with_optional_upper_bound(
111+
distribution, seed=seed))
112+
113+
114+
if __name__ == '__main__':
115+
test_util.main()

tensorflow_probability/python/sts/components/BUILD

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,13 +291,18 @@ multi_substrate_py_library(
291291
# numpy dep,
292292
# tensorflow dep,
293293
"//tensorflow_probability/python/bijectors:chain",
294+
"//tensorflow_probability/python/bijectors:invert",
294295
"//tensorflow_probability/python/bijectors:scale",
295296
"//tensorflow_probability/python/bijectors:softplus",
297+
"//tensorflow_probability/python/bijectors:square",
298+
"//tensorflow_probability/python/distributions:inverse_gamma",
296299
"//tensorflow_probability/python/distributions:linear_gaussian_ssm",
297300
"//tensorflow_probability/python/distributions:lognormal",
298301
"//tensorflow_probability/python/distributions:mvn_diag",
299302
"//tensorflow_probability/python/distributions:mvn_tril",
300303
"//tensorflow_probability/python/distributions:normal",
304+
"//tensorflow_probability/python/distributions:transformed_distribution",
305+
"//tensorflow_probability/python/experimental/sts_gibbs:sample_parameters",
301306
"//tensorflow_probability/python/internal:docstring_util",
302307
"//tensorflow_probability/python/internal:dtype_util",
303308
"//tensorflow_probability/python/internal:prefer_static",
@@ -309,7 +314,7 @@ multi_substrate_py_test(
309314
name = "seasonal_test",
310315
size = "medium",
311316
srcs = ["seasonal_test.py"],
312-
shard_count = 6,
317+
shard_count = 12,
313318
deps = [
314319
":seasonal",
315320
# numpy dep,

0 commit comments

Comments
 (0)