Skip to content

Commit b78a8fa

Browse files
g-luo and tensorflower-gardener
authored and committed
Add get_default_holidays for modelling holiday effects.
PiperOrigin-RevId: 389269656
1 parent d4c22a8 commit b78a8fa

File tree

3 files changed

+68
-0
lines changed

3 files changed

+68
-0
lines changed

tensorflow_probability/python/sts/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ py_library(
167167
srcs = ["holiday_effects.py"],
168168
srcs_version = "PY3",
169169
deps = [
170+
# holidays dep,
170171
# numpy dep,
171172
# tensorflow dep,
172173
"//tensorflow_probability/python/experimental/util",
@@ -181,6 +182,8 @@ py_test(
181182
python_version = "PY3",
182183
shard_count = 6,
183184
srcs_version = "PY3",
185+
# TODO(gcluo): Add the holidays package to the OSS testing setup.
186+
tags = ["no-oss-ci"],
184187
deps = [
185188
# holidays dep,
186189
# numpy dep,

tensorflow_probability/python/sts/holiday_effects.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,52 @@
1515
"""Utilities for holiday regressors."""
1616

1717
__all__ = [
18+
'get_default_holidays',
1819
'create_holiday_regressors',
1920
]
2021

2122
# Defines expected holiday file fields.
2223
_HOLIDAY_FILE_FIELDS = frozenset({'geo', 'holiday', 'date'})
2324

2425

26+
# TODO(b/195771744): Add holidays as unofficial TFP dependency.
27+
def get_default_holidays(times, country):
  """Creates default holidays for a specific country.

  Holidays are looked up with the third-party `holidays` package for every
  calendar year spanned by `times`, then restricted to the dates that fall
  between the earliest and latest entries of `times` (inclusive).

  Args:
    times: a Pandas `DatetimeIndex` that indexes time series data.
    country: `str`, two-letter upper-case [ISO 3166-1 alpha-2 country code](
      https://en.wikipedia.org/wiki/ISO_3166-1_alpha-2) for using holiday
      regressors from a particular country.

  Returns:
    holidays: a Pandas `DataFrame` with default holidays relevant to the input
      times, sorted by date and indexed `0..n-1`. The `DataFrame` has the
      following columns:
      * `geo`: `str`, two-letter upper-case country code
      * `holiday`: `str`, the name of the holiday
      * `date`: `str`, dates in the form of `YYYY-MM-DD`
  """
  # pylint: disable=g-import-not-at-top
  import pandas as pd
  # Aliased so that the module is not shadowed by the local result variables.
  import holidays as holidays_lib
  # pylint: enable=g-import-not-at-top

  first_time = times.min()
  last_time = times.max()
  years = range(first_time.year, last_time.year + 1)
  country_holidays = holidays_lib.CountryHoliday(
      country, years=years, expand=False)
  holidays_df = pd.DataFrame(
      [(country, country_holidays.get_list(date), date)
       for date in country_holidays],
      columns=['geo', 'holiday', 'date'])
  # `get_list` yields a list of names per date; give each name its own row.
  holidays_df = holidays_df.explode('holiday')
  # Convert to datetime64 *before* filtering: comparing the raw date objects
  # against the pandas `Timestamp` bounds is deprecated in pandas 1.x and is
  # rejected by newer pandas releases.
  holidays_df['date'] = pd.to_datetime(holidays_df['date'])
  # Ensure that only holiday dates covered by times are used.
  in_range = ((holidays_df['date'] >= first_time)
              & (holidays_df['date'] <= last_time))
  holidays_df = holidays_df[in_range]
  # Sort with a stable algorithm so holidays sharing a date keep a
  # deterministic order, and only then rebuild the index so that it is
  # sequential in the returned (sorted) frame.
  holidays_df = holidays_df.sort_values('date', kind='mergesort')
  holidays_df = holidays_df.reset_index(drop=True)
  holidays_df['date'] = holidays_df['date'].dt.strftime('%Y-%m-%d')
  return holidays_df
62+
63+
2564
def create_holiday_regressors(times, holidays):
2665
"""Creates a design matrix of holiday regressors for a given time series.
2766

tensorflow_probability/python/sts/holiday_effects_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,32 @@
2424

2525
class HolidayEffectsTest(test_util.TestCase):
2626

27+
def test_get_default_holidays_invalid_country(self):
  """Checks that the country code 'AA' makes `get_default_holidays` raise."""
  date_index = pd.to_datetime(['2012-12-25', '2013-01-01'])
  invalid_code = 'AA'
  self.assertRaises(
      Exception,
      holiday_effects.get_default_holidays,
      date_index,
      invalid_code)
32+
33+
def test_get_default_holidays_invalid_times(self):
  """Checks that a plain list of date strings as `times` raises."""
  # Deliberately a Python list rather than a Pandas `DatetimeIndex`.
  raw_dates = ['2012-12-25', '2013-01-01']
  self.assertRaises(
      Exception,
      holiday_effects.get_default_holidays,
      raw_dates,
      'US')
38+
39+
@parameterized.named_parameters(
    ('united_states_holidays',
     'US',
     pd.DataFrame(
         [
             ['US', 'Christmas Day', '2012-12-25'],
             ['US', 'New Year\'s Day', '2013-01-01'],
         ],
         columns=HOLIDAY_FILE_FIELDS)),
    ('egypt_holidays',
     'EG',
     pd.DataFrame(
         [
             ['EG', 'New Year\'s Day - Bank Holiday', '2013-01-01'],
         ],
         columns=HOLIDAY_FILE_FIELDS)))
def test_get_default_holidays(self, country, expected):
  """Compares the default holidays for `country` against `expected`.

  The time index is daily from 2012-12-25 through 2013-01-01, so only
  holidays inside that window should be returned.
  """
  one_day = pd.DateOffset(days=1)
  daily_index = pd.date_range(
      start='2012-12-25', end='2013-01-01', freq=one_day)
  actual = holiday_effects.get_default_holidays(daily_index, country)
  frames_match = actual.equals(expected)
  self.assertTrue(frames_match)
52+
2753
@parameterized.named_parameters(
2854
('date_wrong_order',
2955
pd.DataFrame([['US', 'TestHoliday', '12-20-2013']],

0 commit comments

Comments
 (0)