|
| 1 | +# Copyright 2021 The TensorFlow Probability Authors. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================ |
| 15 | +"""Utilities for automatically building StructuralTimeSeries models.""" |
| 16 | + |
| 17 | +import tensorflow.compat.v2 as tf |
| 18 | + |
| 19 | +from tensorflow_probability.python.sts import components as sts_components |
| 20 | +from tensorflow_probability.python.sts import structural_time_series |
| 21 | +from tensorflow_probability.python.sts.internal import seasonality_util |
| 22 | +from tensorflow_probability.python.sts.internal import util as sts_util |
| 23 | + |
| 24 | +__all__ = [ |
| 25 | + 'build_default_model', |
| 26 | +] |
| 27 | + |
| 28 | + |
| 29 | +# TODO(davmre): before exposing publicly, consider simplifying this function |
| 30 | +# (e.g., not exposing prior specification args) and/or renaming it to something |
| 31 | +# like `auto_build_model`. |
| 32 | +def build_default_model(observed_time_series, |
| 33 | + base_component=sts_components.LocalLinearTrend, |
| 34 | + observation_noise_scale_prior=None, |
| 35 | + drift_scale_prior=None, |
| 36 | + allow_seasonal_effect_drift=True, |
| 37 | + name=None): |
| 38 | + """Builds a model with seasonality from a Pandas Series or DataFrame. |
| 39 | +
|
| 40 | + Returns a model of the form |
| 41 | + `tfp.sts.Sum([base_component] + seasonal_components)`, where |
| 42 | + `seasonal_components` are automatically selected using the frequency from the |
| 43 | + `DatetimeIndex` of the provided `pd.Series` or `pd.DataFrame`. If the index |
| 44 | + does not have a set frequency, one will be inferred from the index dates, and |
| 45 | +
|
| 46 | + Args: |
| 47 | + observed_time_series: Instance of `pd.Series` or `pd.DataFrame` containing |
| 48 | + one or more time series indexed by a `pd.DatetimeIndex`. |
| 49 | + base_component: Optional subclass of `tfp.sts.StructuralTimeSeries` |
| 50 | + specifying the model used for residual variation in the series not |
| 51 | + explained by seasonal or other effects. May also be an *instance* of such |
| 52 | + a class with specific priors set; if not provided, such an instance will |
| 53 | + be constructed with heuristic default priors. |
| 54 | + Default value: `tfp.sts.LocalLinearTrend`. |
| 55 | + observation_noise_scale_prior: Optional `tfd.Distribution` instance |
| 56 | + specifying a prior on `observation_noise_scale`. If `None`, a heuristic |
| 57 | + default prior is constructed based on the provided `observed_time_series`. |
| 58 | + Default value: `None`. |
| 59 | + drift_scale_prior: Optional `tfd.Distribution` instance |
| 60 | + specifying a prior on the `drift_scale` parameter of Seasonal components. |
| 61 | + If `None`, a heuristic default prior is constructed based on the provided |
| 62 | + `observed_time_series`. |
| 63 | + Default value: `None`. |
| 64 | + allow_seasonal_effect_drift: optional Python `bool` specifying whether the |
| 65 | + seasonal effects can drift over time. Setting this to `False` |
| 66 | + removes the `drift_scale` parameter from the model. This is |
| 67 | + mathematically equivalent to `drift_scale_prior = tfd.Deterministic(0.)`, |
| 68 | + but removing drift directly is preferred because it avoids the use of a |
| 69 | + degenerate prior. |
| 70 | + Default value: `True`. |
| 71 | + name: Python `str` name for ops created by this function. |
| 72 | + Default value: `None` (i.e., 'build_default_model'). |
| 73 | + Returns: |
| 74 | + model: instance of `tfp.sts.Sum` representing a model for the given data. |
| 75 | +
|
| 76 | + #### Example |
| 77 | +
|
| 78 | + Consider a series of eleven data points, covering a period of two weeks |
| 79 | + with three missing days. |
| 80 | +
|
| 81 | + ```python |
| 82 | + import pandas as pd |
| 83 | + import tensorflow as tf |
| 84 | + import tensorflow_probability as tfp |
| 85 | +
|
| 86 | + series = pd.Series( |
| 87 | + [100., 27., 92., 66., 51., 126., 113., 95., 48., 20., 59.,], |
| 88 | + index=pd.to_datetime(['2020-01-01', '2020-01-02', '2020-01-04', |
| 89 | + '2020-01-05', '2020-01-06', '2020-01-07', |
| 90 | + '2020-01-10', '2020-01-11', '2020-01-12', |
| 91 | + '2020-01-13', '2020-01-14'])) |
| 92 | + ``` |
| 93 | +
|
| 94 | + Before calling `build_default_model`, we must regularize the series to follow |
| 95 | + a fixed frequency (here, daily observations): |
| 96 | +
|
| 97 | + ```python |
| 98 | + series = tfp.sts.regularize_series(series) |
| 99 | + # len(series) ==> 14 |
| 100 | + ``` |
| 101 | +
|
| 102 | + The default model will combine a LocalLinearTrend baseline with a Seasonal |
| 103 | + component to capture day-of-week effects. We can then fit this model to our |
| 104 | + observed data. Here we'll use variational inference: |
| 105 | +
|
| 106 | + ```python |
| 107 | + model = tfp.sts.build_default_model(series) |
| 108 | + # len(model.components) == 2 |
| 109 | +
|
| 110 | + # Fit the model using variational inference. |
| 111 | + surrogate_posterior = tfp.sts.build_factored_surrogate_posterior(model) |
| 112 | + losses = tfp.vi.fit_surrogate_posterior( |
| 113 | + target_log_prob_fn=model.joint_log_prob(series), |
| 114 | + surrogate_posterior=surrogate_posterior, |
| 115 | + optimizer=tf.optimizers.Adam(0.1), |
| 116 | + num_steps=1000, |
| 117 | + convergence_criterion=( |
| 118 | + tfp.optimizer.convergence_criteria.SuccessiveGradientsAreUncorrelated( |
| 119 | + window_size=20, min_num_steps=50)), |
| 120 | + jit_compile=True) |
| 121 | + parameter_samples = surrogate_posterior.sample(50) |
| 122 | + ``` |
| 123 | +
|
| 124 | + Finally, use the fitted parameters to forecast the next week of data: |
| 125 | +
|
| 126 | + ```python |
| 127 | + forecast_dist = tfp.sts.forecast(model, |
| 128 | + observed_time_series=series, |
| 129 | + parameter_samples=parameter_samples, |
| 130 | + num_steps_forecast=7) |
| 131 | + # Strip trailing unit dimension from LinearGaussianStateSpaceModel events. |
| 132 | + forecast_mean = forecast_dist.mean()[..., 0] |
| 133 | + forecast_stddev = forecast_dist.stddev()[..., 0] |
| 134 | +
|
| 135 | + forecast = pd.DataFrame( |
| 136 | + {'mean': forecast_mean, |
| 137 | + 'lower_bound': forecast_mean - 2. * forecast_stddev, |
| 138 | + 'upper_bound': forecast_mean + 2. * forecast_stddev} |
| 139 | + index=pd.date_range(start=series.index[-1] + series.index.freq, |
| 140 | + periods=7, |
| 141 | + freq=series.index.freq)) |
| 142 | + ``` |
| 143 | +
|
| 144 | + """ |
| 145 | + with tf.name_scope(name or 'build_default_model'): |
| 146 | + frequency = getattr(observed_time_series.index, 'freq', None) |
| 147 | + if frequency is None: |
| 148 | + raise ValueError('Provided series has no set frequency. Consider ' |
| 149 | + 'using `tfp.sts.regularize_series` to infer a frequency ' |
| 150 | + 'and build a regularly spaced series.') |
| 151 | + observed_time_series = sts_util.canonicalize_observed_time_series_with_mask( |
| 152 | + observed_time_series) |
| 153 | + |
| 154 | + if not isinstance(base_component, |
| 155 | + structural_time_series.StructuralTimeSeries): |
| 156 | + # Build a component of the given type using default priors. |
| 157 | + base_component = base_component(observed_time_series=observed_time_series) |
| 158 | + |
| 159 | + components = [base_component] |
| 160 | + seasonal_structure = seasonality_util.create_seasonal_structure( |
| 161 | + frequency=frequency, |
| 162 | + num_steps=int(observed_time_series.time_series.shape[-2])) |
| 163 | + for season_type, season in seasonal_structure.items(): |
| 164 | + components.append( |
| 165 | + sts_components.Seasonal(num_seasons=season.num, |
| 166 | + num_steps_per_season=season.duration, |
| 167 | + drift_scale_prior=drift_scale_prior, |
| 168 | + allow_drift=allow_seasonal_effect_drift, |
| 169 | + observed_time_series=observed_time_series, |
| 170 | + name=str(season_type))) |
| 171 | + return sts_components.Sum( |
| 172 | + components, |
| 173 | + observed_time_series=observed_time_series, |
| 174 | + observation_noise_scale_prior=observation_noise_scale_prior) |
0 commit comments