|
15 | 15 | """Utilities for holiday regressors."""
|
16 | 16 |
|
17 | 17 | __all__ = [
|
| 18 | + 'get_default_holidays', |
18 | 19 | 'create_holiday_regressors',
|
19 | 20 | ]
|
20 | 21 |
|
21 | 22 | # Defines expected holiday file fields.
|
22 | 23 | _HOLIDAY_FILE_FIELDS = frozenset({'geo', 'holiday', 'date'})
|
23 | 24 |
|
24 | 25 |
|
| 26 | +# TODO(b/195771744): Add holidays as unofficial TFP dependency. |
def get_default_holidays(times, country):
  """Creates default holidays for a specific country.

  Holidays are looked up with the third-party `holidays` package for every
  calendar year touched by `times`, then restricted to the closed interval
  `[times.min(), times.max()]`. A date that carries several holiday names
  yields one row per name.

  Args:
    times: a Pandas `DatetimeIndex` that indexes time series data.
    country: `str`, two-letter upper-case [ISO 3166-1 alpha-2 country code](
      https://en.wikipedia.org/wiki/ISO_3166-1_alpha-2) for using holiday
      regressors from a particular country.

  Returns:
    holidays: a Pandas `DataFrame` with default holidays relevant to the input
      times, sorted by date and indexed `0..n-1` in that order. The
      `DataFrame` has the following columns:
      * `geo`: `str`, two-letter upper-case country code
      * `holiday`: `str`, the name of the holiday
      * `date`: `str`, dates in the form of `YYYY-MM-DD`
  """
  # pylint: disable=g-import-not-at-top
  import pandas as pd
  # Aliased so that local variables never rebind the module name.
  import holidays as holidays_lib
  # pylint: enable=g-import-not-at-top

  start, end = times.min(), times.max()
  years = range(start.year, end.year + 1)

  # `country_holidays` is the package's current factory; fall back to the
  # deprecated `CountryHoliday` name for older releases that lack it.
  make_calendar = getattr(holidays_lib, 'country_holidays', None)
  if make_calendar is None:
    make_calendar = holidays_lib.CountryHoliday
  calendar = make_calendar(country, years=years, expand=False)

  frame = pd.DataFrame(
      [(country, calendar.get_list(date), date) for date in calendar],
      columns=['geo', 'holiday', 'date'])
  # `get_list` returns every name registered for a date; one row per name.
  frame = frame.explode('holiday')

  # Convert to datetime64 *before* filtering: the calendar yields plain
  # `datetime.date` objects, and recent pandas versions refuse to order a
  # `Timestamp` against a `datetime.date`.
  frame['date'] = pd.to_datetime(frame['date'])
  # Ensure that only holiday dates covered by times are used.
  frame = frame[(frame['date'] >= start) & (frame['date'] <= end)]

  # Stable sort keeps same-day holiday names in a deterministic order; the
  # index is reset afterwards so it matches the final row order.
  frame = frame.sort_values('date', kind='stable').reset_index(drop=True)
  frame['date'] = frame['date'].dt.strftime('%Y-%m-%d')
  return frame
| 62 | + |
| 63 | + |
25 | 64 | def create_holiday_regressors(times, holidays):
|
26 | 65 | """Creates a design matrix of holiday regressors for a given time series.
|
27 | 66 |
|
|
0 commit comments