Skip to content

Commit c2c9a7f

Browse files
authored
Add Autoregression Compatibility for Optimized Historical Forecasts on SKLearn models (#2921)
1 parent 39f8c91 commit c2c9a7f

File tree

11 files changed

+593
-422
lines changed

11 files changed

+593
-422
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
1111

1212
**Improved**
1313

14+
- 🚀 We optimized auto-regressive historical forecasts for `SKLearnModel` (when `forecast_horizon > output_chunk_length`), increasing throughput by multiple orders of magnitude! Now all historical forecasting scenarios for `SKLearnModel` are optimized. [#2921](https://github.com/unit8co/darts/pull/2921) by [Alain Gysi](https://github.com/Kurokabe)
1415
- 🚀 Added a new configuration system for Darts, similar to pandas' options and settings. [#2956](https://github.com/unit8co/darts/pull/2956) by [Dennis Bader](https://github.com/dennisbader).
1516
- Users can now configure global behavior such as:
1617
- `display.[max_rows, max_cols]`: Maximum number of rows or columns to display in TimeSeries representation (default: 10)
@@ -23,10 +24,14 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
2324
- To use either variant, simply set `hub_model_name` parameter to the desired model ID, e.g., `"autogluon/chronos-2-small"`.
2425
- Both models can be used in the same way as the original Chronos-2 model.
2526
- `TorchForecastingModel` parameter `torch_metrics` now supports all input metric types from ``torchmetrics.MetricCollection``. Eg. now you can also pass a dictionary or sequence of metrics. [#2958](https://github.com/unit8co/darts/pull/2958) by [CorticallyAI](https://github.com/CorticallyAI).
27+
- `SKLearnModel` now raises a more informative exception, when (any of) the input target `series` is (are) too short. [#2921](https://github.com/unit8co/darts/pull/2921) by [Dennis Bader](https://github.com/dennisbader).
2628

2729
**Fixed**
2830

2931
- Fixed an issue in `TFTExplainer` where attempting to explain a list of series longer than the model's batch size resulted in an `IndexError`. A more informative error message is now raised instead. [#2957](https://github.com/unit8co/darts/pull/2957) by [Dennis Bader](https://github.com/dennisbader).
32+
- Fixed an issue in `TorchForecastingModel` where it was not possible to run historical forecasts with `overlap=True` if the only possible start point was one step after the end of the target series (e.g. the equivalent to a `predict()` call). [#2921](https://github.com/unit8co/darts/pull/2921) by [Dennis Bader](https://github.com/dennisbader).
33+
- Fixed an issue in `SKLearnModel` where attempting to run historical forecasts on a multivariate target series with component-specific lags did not work properly. [#2921](https://github.com/unit8co/darts/pull/2921) by [Dennis Bader](https://github.com/dennisbader).
34+
- Fixed a bug in `SKLearnModel` with `multi_models=False` where running historical forecasts using `start=None` started later than the actual first possible start point. [#2921](https://github.com/unit8co/darts/pull/2921) by [Dennis Bader](https://github.com/dennisbader).
3035

3136
**Dependencies**
3237

darts/models/forecasting/forecasting_model.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -660,9 +660,7 @@ def _get_last_prediction_time(
660660

661661
def _check_optimizable_historical_forecasts(
662662
self,
663-
forecast_horizon: int,
664663
retrain: Union[bool, int, Callable[..., bool]],
665-
show_warnings: bool,
666664
) -> bool:
667665
"""By default, historical forecasts cannot be optimized"""
668666
return False
@@ -940,11 +938,7 @@ def retrain_func(
940938
if (
941939
enable_optimization
942940
and model.supports_optimized_historical_forecasts
943-
and model._check_optimizable_historical_forecasts(
944-
forecast_horizon=forecast_horizon,
945-
retrain=retrain,
946-
show_warnings=show_warnings,
947-
)
941+
and model._check_optimizable_historical_forecasts(retrain)
948942
):
949943
forecasts = model._optimized_historical_forecasts(
950944
series=series,

darts/models/forecasting/sklearn_model.py

Lines changed: 48 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,7 @@
8383
)
8484
from darts.utils.historical_forecasts import (
8585
_check_optimizable_historical_forecasts_global_models,
86-
_optimized_historical_forecasts_all_points,
87-
_optimized_historical_forecasts_last_points_only,
86+
_optimized_historical_forecasts_regression,
8887
_process_historical_forecast_input,
8988
)
9089
from darts.utils.likelihood_models.base import LikelihoodType
@@ -1262,6 +1261,33 @@ def predict(
12621261
shift = self.output_chunk_length - 1
12631262
step = 1
12641263

1264+
# check all target series are long enough
1265+
target_lags = self.lags.get("target")
1266+
if target_lags is not None:
1267+
min_target_length = abs(min(target_lags)) + shift
1268+
for idx, series_ in enumerate(series):
1269+
if len(series_) < min_target_length:
1270+
index_text = (
1271+
" "
1272+
if called_with_single_series
1273+
else f" at list/sequence index {idx} "
1274+
)
1275+
end_ts = series_.end_time()
1276+
start_ts = (
1277+
series_.end_time() - (min_target_length - 1) * series_.freq
1278+
)
1279+
raise_log(
1280+
ValueError(
1281+
f"The `series`{index_text}is not long enough. "
1282+
f"Given horizon `n={n}`, `min(lags)={target_lags[0]}`, "
1283+
f"`max(lags)={target_lags[-1]}` and "
1284+
f"`output_chunk_length={self.output_chunk_length}`, the `series` has to "
1285+
f"range from {start_ts} until {end_ts} (inclusive), but it only ranges from "
1286+
f"{series_.start_time()} until {end_ts}."
1287+
),
1288+
logger=logger,
1289+
)
1290+
12651291
# dictionary containing covariate data over time span required for prediction
12661292
covariate_matrices = {}
12671293
# dictionary containing covariate lags relative to minimum covariate lag
@@ -1507,21 +1533,10 @@ def val_set_params(self) -> tuple[Optional[str], Optional[str]]:
15071533

15081534
def _check_optimizable_historical_forecasts(
15091535
self,
1510-
forecast_horizon: int,
15111536
retrain: Union[bool, int, Callable[..., bool]],
1512-
show_warnings: bool,
15131537
) -> bool:
1514-
"""
1515-
Historical forecast can be optimized only if `retrain=False` and `forecast_horizon <= model.output_chunk_length`
1516-
(no auto-regression required).
1517-
"""
1518-
return _check_optimizable_historical_forecasts_global_models(
1519-
model=self,
1520-
forecast_horizon=forecast_horizon,
1521-
retrain=retrain,
1522-
show_warnings=show_warnings,
1523-
allow_autoregression=False,
1524-
)
1538+
"""Historical forecast can be optimized if no re-training is involved"""
1539+
return _check_optimizable_historical_forecasts_global_models(retrain)
15251540

15261541
def _optimized_historical_forecasts(
15271542
self,
@@ -1545,56 +1560,33 @@ def _optimized_historical_forecasts(
15451560
For SKLearnModels we create the lagged prediction data once per series using a moving window.
15461561
With this, we can avoid having to recreate the tabular input data and call `model.predict()` for each
15471562
forecastable index and series.
1548-
Additionally, there is a dedicated subroutines for `last_points_only=True` and `last_points_only=False`.
1549-
1550-
TODO: support forecast_horizon > output_chunk_length (auto-regression)
15511563
"""
15521564
series, past_covariates, future_covariates = _process_historical_forecast_input(
15531565
model=self,
15541566
series=series,
15551567
past_covariates=past_covariates,
15561568
future_covariates=future_covariates,
15571569
forecast_horizon=forecast_horizon,
1558-
allow_autoregression=False,
15591570
)
15601571

1561-
# TODO: move the loop here instead of duplicated code in each sub-routine?
1562-
if last_points_only:
1563-
hfc = _optimized_historical_forecasts_last_points_only(
1564-
model=self,
1565-
series=series,
1566-
past_covariates=past_covariates,
1567-
future_covariates=future_covariates,
1568-
num_samples=num_samples,
1569-
start=start,
1570-
start_format=start_format,
1571-
forecast_horizon=forecast_horizon,
1572-
stride=stride,
1573-
overlap_end=overlap_end,
1574-
show_warnings=show_warnings,
1575-
verbose=verbose,
1576-
predict_likelihood_parameters=predict_likelihood_parameters,
1577-
random_state=random_state,
1578-
predict_kwargs=predict_kwargs,
1579-
)
1580-
else:
1581-
hfc = _optimized_historical_forecasts_all_points(
1582-
model=self,
1583-
series=series,
1584-
past_covariates=past_covariates,
1585-
future_covariates=future_covariates,
1586-
num_samples=num_samples,
1587-
start=start,
1588-
start_format=start_format,
1589-
forecast_horizon=forecast_horizon,
1590-
stride=stride,
1591-
overlap_end=overlap_end,
1592-
show_warnings=show_warnings,
1593-
verbose=verbose,
1594-
predict_likelihood_parameters=predict_likelihood_parameters,
1595-
random_state=random_state,
1596-
predict_kwargs=predict_kwargs,
1597-
)
1572+
hfc = _optimized_historical_forecasts_regression(
1573+
model=self,
1574+
series=series,
1575+
past_covariates=past_covariates,
1576+
future_covariates=future_covariates,
1577+
num_samples=num_samples,
1578+
start=start,
1579+
start_format=start_format,
1580+
forecast_horizon=forecast_horizon,
1581+
stride=stride,
1582+
overlap_end=overlap_end,
1583+
show_warnings=show_warnings,
1584+
verbose=verbose,
1585+
predict_likelihood_parameters=predict_likelihood_parameters,
1586+
random_state=random_state,
1587+
predict_kwargs=predict_kwargs,
1588+
last_points_only=last_points_only,
1589+
)
15981590
return hfc
15991591

16001592
@property

darts/models/forecasting/torch_forecasting_model.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2477,21 +2477,10 @@ def _requires_training(self) -> bool:
24772477

24782478
def _check_optimizable_historical_forecasts(
24792479
self,
2480-
forecast_horizon: int,
24812480
retrain: Union[bool, int, Callable[..., bool]],
2482-
show_warnings: bool,
24832481
) -> bool:
2484-
"""
2485-
Historical forecast can be optimized only if `retrain=False` and `forecast_horizon <= model.output_chunk_length`
2486-
(no auto-regression required).
2487-
"""
2488-
return _check_optimizable_historical_forecasts_global_models(
2489-
model=self,
2490-
forecast_horizon=forecast_horizon,
2491-
retrain=retrain,
2492-
show_warnings=show_warnings,
2493-
allow_autoregression=True,
2494-
)
2482+
"""Historical forecast can be optimized if no re-training is involved"""
2483+
return _check_optimizable_historical_forecasts_global_models(retrain)
24952484

24962485
def _optimized_historical_forecasts(
24972486
self,
@@ -2521,7 +2510,6 @@ def _optimized_historical_forecasts(
25212510
past_covariates=past_covariates,
25222511
future_covariates=future_covariates,
25232512
forecast_horizon=forecast_horizon,
2524-
allow_autoregression=True,
25252513
)
25262514
forecasts_list = _optimized_historical_forecasts(
25272515
model=self,

0 commit comments

Comments
 (0)