Skip to content

Commit d4c22a8

Browse files
g-luo authored and tensorflower-gardener committed
Create STS module for modeling holiday effects.
PiperOrigin-RevId: 389190872
1 parent 7b214b9 commit d4c22a8

File tree

3 files changed

+252
-0
lines changed

3 files changed

+252
-0
lines changed

tensorflow_probability/python/sts/BUILD

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ py_library(
3232
":default_model",
3333
":fitting",
3434
":forecast",
35+
":holiday_effects",
3536
":regularization",
3637
":structural_time_series",
3738
"//tensorflow_probability/python/internal:all_util",
@@ -161,6 +162,34 @@ py_test(
161162
],
162163
)
163164

165+
# Library target for the holiday-regressor utilities in holiday_effects.py.
# NOTE(review): holiday_effects.py imports pandas inside its functions, but no
# pandas dep is declared here -- confirm this is intentional.
py_library(
    name = "holiday_effects",
    srcs = ["holiday_effects.py"],
    srcs_version = "PY3",
    deps = [
        # numpy dep,
        # tensorflow dep,
        "//tensorflow_probability/python/experimental/util",
        "//tensorflow_probability/python/sts/internal",
    ],
)
176+
177+
# Unit tests for holiday_effects.py, split across 6 shards.
# NOTE(review): holiday_effects_test.py imports pandas but not `holidays`;
# confirm the dep placeholders below match what the test actually needs.
py_test(
    name = "holiday_effects_test",
    size = "medium",
    srcs = ["holiday_effects_test.py"],
    python_version = "PY3",
    shard_count = 6,
    srcs_version = "PY3",
    deps = [
        # holidays dep,
        # numpy dep,
        # tensorflow dep,
        "//tensorflow_probability",
        "//tensorflow_probability/python/internal:test_util",
    ],
)
192+
164193
py_library(
165194
name = "regularization",
166195
srcs = ["regularization.py"],
tensorflow_probability/python/sts/holiday_effects.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# Copyright 2021 The TensorFlow Probability Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
"""Utilities for holiday regressors."""
16+
17+
# Names exported as the public API of this module.
__all__ = [
    'create_holiday_regressors',
]

# Column names that a holidays `DataFrame` must contain; `_check_holidays`
# rejects any frame that is missing one of them.
_HOLIDAY_FILE_FIELDS = frozenset({'geo', 'holiday', 'date'})
23+
24+
25+
def create_holiday_regressors(times, holidays):
  """Builds one 0/1 indicator column per holiday for the given timestamps.

  Args:
    times: a Pandas `DatetimeIndex` with a frequency, indexing the time series
      data.
    holidays: a Pandas `DataFrame` listing holiday dates. It must contain the
      following columns:
      * `geo`: `str`, two-letter upper-case country code
      * `holiday`: `str`, the name of the holiday
      * `date`: `str`, dates in the form of `YYYY-MM-DD`
      The `geo` column is required but is not otherwise read here.

  Returns:
    holiday_regressors: a Pandas `DataFrame` whose columns are holiday names
      and whose rows correspond to the entries of `times`. An entry is 1 when
      that timestamp falls on a date of that holiday and 0 otherwise. Holidays
      that never coincide with `times` are dropped, so there can be fewer
      columns than unique names in `holidays.holiday`.

  Raises:
    ValueError: if `times` or `holidays` fails validation.
  """
  # pylint: disable=g-import-not-at-top
  import pandas as pd
  # pylint: enable=g-import-not-at-top

  # TODO(b/195346554): Expand fixed holidays.
  _check_times(times)
  _check_holidays(holidays)

  # Sorting by date fixes the column order: holidays appear in the order of
  # their earliest listed date.
  ordered = holidays.sort_values('date')
  holiday_regressors = pd.DataFrame()
  for name in list(ordered['holiday'].unique()):
    rows_for_name = ordered.loc[ordered['holiday'] == name]
    parsed_dates = pd.to_datetime(
        list(rows_for_name['date']), errors='raise', format='%Y-%m-%d')
    holiday_regressors.loc[:, name] = _match_dates(times, parsed_dates)

  # Keep only the holidays that occur at least once within `times`.
  occurs_in_times = (holiday_regressors != 0).any(axis=0)
  return holiday_regressors.loc[:, occurs_in_times]
64+
65+
66+
def _check_times(times):
  """Validates that `times` is a `DatetimeIndex` with a frequency.

  Args:
    times: the object to validate; expected to be a Pandas `DatetimeIndex`
      that indexes time series data.

  Raises:
    ValueError: if `times` is not a Pandas `DatetimeIndex`, or if its `freq`
      attribute is not set.
  """
  # pylint: disable=g-import-not-at-top
  import pandas as pd
  # pylint: enable=g-import-not-at-top
  is_datetime_index = isinstance(times, pd.DatetimeIndex)
  if not is_datetime_index:
    raise ValueError('Times is not a Pandas DatetimeIndex.')
  if times.freq:
    return
  raise ValueError('Times does not have a frequency.')
83+
84+
85+
def _check_holidays(holidays):
  """Checks that the holidays table has every required column.

  Args:
    holidays: a Pandas `DataFrame` containing the dates of holidays. Its
      columns must include every name in `_HOLIDAY_FILE_FIELDS`; additional
      columns are allowed.

  Raises:
    ValueError: if any required column is absent. The message lists both the
      required and the missing column names in sorted order.
  """
  missing_columns = _HOLIDAY_FILE_FIELDS.difference(holidays.columns)
  if missing_columns:
    # Sort the names so the message is readable and does not depend on set
    # iteration order.
    raise ValueError(
        'Holidays column names must contain: {0}. Missing: {1}.'.format(
            sorted(_HOLIDAY_FILE_FIELDS), sorted(missing_columns)))
98+
99+
100+
def _match_dates(times, dates):
  """Creates a 0-1 dummy variable marking entries of `times` found in `dates`.

  Both inputs are floored to day resolution before being compared, so every
  timestamp in `times` that falls on a calendar day present in `dates` is
  marked, whatever its time of day.

  Args:
    times: a Pandas `DatetimeIndex` for observed data.
    dates: a Pandas `DatetimeIndex` for relevant dates of a single holiday.

  Returns:
    regressor: a list of `int` the same length as `times`, with a 1 where there
      is a date match, and otherwise 0.
  """
  # TODO(b/195347492): Approximate to the nearest prior day
  # for greater than daily granularity.
  # TODO(b/195347492): Add _MIN_HOLIDAY_OCCURRENCES.
  rounded_times = times.floor('D')
  rounded_dates = dates.floor('D')
  # A vectorized membership test treats unique, duplicated, and unsorted
  # indexes uniformly. `Index.get_loc` instead returns an int, a slice, or a
  # boolean mask depending on the index, and a mask cannot index a list.
  return rounded_times.isin(rounded_dates).astype(int).tolist()
tensorflow_probability/python/sts/holiday_effects_test.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright 2021 The TensorFlow Probability Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
"""Tests for holiday_effects."""
16+
from absl.testing import parameterized
17+
import pandas as pd
18+
import tensorflow as tf
19+
from tensorflow_probability.python.internal import test_util
20+
from tensorflow_probability.python.sts import holiday_effects
21+
22+
# Column names used when building the holidays `DataFrame` fixtures in the
# tests below.
HOLIDAY_FILE_FIELDS = ['geo', 'holiday', 'date']
23+
24+
25+
class HolidayEffectsTest(test_util.TestCase):
  """Tests for `create_holiday_regressors` and the `_match_dates` helper."""

  @parameterized.named_parameters(
      ('date_wrong_order',
       pd.DataFrame(
           [['US', 'TestHoliday', '12-20-2013']],
           columns=HOLIDAY_FILE_FIELDS)),
      ('date_invalid',
       pd.DataFrame(
           [['US', 'TestHoliday', '12-00-2013']],
           columns=HOLIDAY_FILE_FIELDS)),
      ('bad_column_names',
       pd.DataFrame(
           [['US', 'TestHoliday', '2013-12-20']],
           columns=['geo', 'wrong', 'date'])),
  )
  def test_holidays_raise_error(self, holidays):
    """Malformed dates or column names are rejected with a `ValueError`."""
    yearly = pd.DateOffset(years=1)
    times = pd.date_range(start='2013-12-20', end='2015-12-20', freq=yearly)
    with self.assertRaises(ValueError):
      holiday_effects.create_holiday_regressors(times, holidays)

  @parameterized.named_parameters(
      ('data_wrong_format', pd.Series(['2013-12-20'])),
      ('data_no_frequency', pd.DatetimeIndex(['2013-12-20'])),
  )
  def test_times_raise_error(self, times):
    """A non-`DatetimeIndex` or frequency-less `times` raises `ValueError`."""
    valid_holidays = pd.DataFrame(
        [['US', 'TestHoliday', '2013-12-20']], columns=HOLIDAY_FILE_FIELDS)
    with self.assertRaises(ValueError):
      holiday_effects.create_holiday_regressors(times, valid_holidays)

  @parameterized.named_parameters(
      ('holiday_daily', pd.DateOffset(days=1), '2012-01-01', '2012-12-31',
       [0] * 359 + [1] + [0] * 6),
      ('holiday_hourly', pd.DateOffset(hours=1), '2012-01-01',
       '2012-12-31 23:00:00', [0] * 359 * 24 + [1] * 24 + [0] * 6 * 24),
      # Note that expected should be `[0] * 51 + [1] + [0]` if
      # _match_dates supports rounding timestamps to the nearest prior day
      ('holiday_weekly', pd.DateOffset(weeks=1), '2012-01-01', '2012-12-31',
       [0] * 51 + [0] + [0]),
  )
  def test_match_dates_by_frequency(self, freq, start, end, expected):
    """A single holiday date is matched against indexes of varying frequency."""
    christmas = pd.to_datetime(['2012-12-25'])
    observed = pd.date_range(start, end, freq=freq)
    actual = holiday_effects._match_dates(observed, christmas)
    self.assertEqual(actual, expected)

  @parameterized.named_parameters(
      ('holiday_disjoint', '2011-01-01', '2011-12-31', [0] * 365),
      ('holiday_intersection', '2011-02-01', '2012-01-31',
       [0] * 334 + [1] * 31),
      ('holiday_subset', '2012-01-01', '2012-01-31', [1] * 31),
  )
  def test_match_dates_by_overlap(self, start, end, expected):
    """Daily indexes that miss, straddle, or sit inside the holiday range."""
    daily = pd.DateOffset(days=1)
    every_day_of_2012 = pd.date_range('2012-01-01', '2012-12-31', freq=daily)
    observed = pd.date_range(start, end, freq=pd.DateOffset(days=1))
    actual = holiday_effects._match_dates(observed, every_day_of_2012)
    self.assertEqual(actual, expected)

  @parameterized.named_parameters(
      ('diagonal_pattern', [('H1', 0), ('H2', 1)], [[1, 0], [0, 1]]),
      ('row_pattern', [('H1', 0), ('H2', 0)], [[1, 1], [0, 0]]),
      ('column_pattern', [('H1', 0), ('H1', 1)], [[1], [1]]),
  )
  def test_create_holiday_regressors(self, holiday_patterns, expected):
    """Holiday names placed on a two-entry index yield the expected matrix."""
    times = pd.date_range(
        '2011-01-01', '2012-01-01', freq=pd.DateOffset(years=1))
    # Each pattern entry is (holiday name, position in `times` of its date).
    rows = [['US', name, times[date_index]]
            for name, date_index in holiday_patterns]
    holidays = pd.DataFrame(rows, columns=HOLIDAY_FILE_FIELDS)
    actual = holiday_effects.create_holiday_regressors(times, holidays)
    self.assertEqual(actual.values.tolist(), expected)
93+
94+
95+
# Run the tests through TensorFlow's test runner when executed as a script.
if __name__ == '__main__':
  tf.test.main()

0 commit comments

Comments
 (0)