diff --git a/pytorch_forecasting/models/__init__.py b/pytorch_forecasting/models/__init__.py index faca5fa2c..29aeb24f5 100644 --- a/pytorch_forecasting/models/__init__.py +++ b/pytorch_forecasting/models/__init__.py @@ -2,7 +2,7 @@ Models for timeseries forecasting. """ -from pytorch_forecasting.models.base_model import ( +from pytorch_forecasting.models.base import ( AutoRegressiveBaseModel, AutoRegressiveBaseModelWithCovariates, BaseModel, diff --git a/pytorch_forecasting/models/_base_model_rename_this.py b/pytorch_forecasting/models/_base_model_rename_this.py new file mode 100644 index 000000000..12b95c32f --- /dev/null +++ b/pytorch_forecasting/models/_base_model_rename_this.py @@ -0,0 +1,17 @@ +"""Base classes for pytorch-foercasting models.""" + +from pytorch_forecasting.models.base import ( + AutoRegressiveBaseModel, + AutoRegressiveBaseModelWithCovariates, + BaseModel, + BaseModelWithCovariates, + Prediction, +) + +__all__ = [ + "AutoRegressiveBaseModel", + "AutoRegressiveBaseModelWithCovariates", + "BaseModel", + "BaseModelWithCovariates", + "Prediction", +] diff --git a/pytorch_forecasting/models/base/__init__.py b/pytorch_forecasting/models/base/__init__.py new file mode 100644 index 000000000..4860e4838 --- /dev/null +++ b/pytorch_forecasting/models/base/__init__.py @@ -0,0 +1,17 @@ +"""Base classes for pytorch-foercasting models.""" + +from pytorch_forecasting.models.base._base_model import ( + AutoRegressiveBaseModel, + AutoRegressiveBaseModelWithCovariates, + BaseModel, + BaseModelWithCovariates, + Prediction, +) + +__all__ = [ + "AutoRegressiveBaseModel", + "AutoRegressiveBaseModelWithCovariates", + "BaseModel", + "BaseModelWithCovariates", + "Prediction", +] diff --git a/pytorch_forecasting/models/base_model.py b/pytorch_forecasting/models/base/_base_model.py similarity index 100% rename from pytorch_forecasting/models/base_model.py rename to pytorch_forecasting/models/base/_base_model.py diff --git a/pytorch_forecasting/models/deepar/_deepar.py b/pytorch_forecasting/models/deepar/_deepar.py index 773445a44..3a769d7f0 100644 --- a/pytorch_forecasting/models/deepar/_deepar.py +++ b/pytorch_forecasting/models/deepar/_deepar.py @@ -26,7 +26,7 @@ MultivariateDistributionLoss, NormalDistributionLoss, ) -from pytorch_forecasting.models.base_model import ( +from pytorch_forecasting.models.base import ( AutoRegressiveBaseModelWithCovariates, Prediction, ) diff --git a/pytorch_forecasting/models/mlp/_decodermlp.py b/pytorch_forecasting/models/mlp/_decodermlp.py index 6692cc360..8c0805287 100644 --- a/pytorch_forecasting/models/mlp/_decodermlp.py +++ b/pytorch_forecasting/models/mlp/_decodermlp.py @@ -18,7 +18,7 @@ MultiHorizonMetric, QuantileLoss, ) -from pytorch_forecasting.models.base_model import BaseModelWithCovariates +from pytorch_forecasting.models.base import BaseModelWithCovariates from pytorch_forecasting.models.mlp.submodules import FullyConnectedModule from pytorch_forecasting.models.nn.embeddings import MultiEmbedding diff --git a/pytorch_forecasting/models/nbeats/_nbeats.py b/pytorch_forecasting/models/nbeats/_nbeats.py index e3a289a12..3181d818c 100644 --- a/pytorch_forecasting/models/nbeats/_nbeats.py +++ b/pytorch_forecasting/models/nbeats/_nbeats.py @@ -10,7 +10,7 @@ from pytorch_forecasting.data import TimeSeriesDataSet from pytorch_forecasting.data.encoders import NaNLabelEncoder from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric -from pytorch_forecasting.models.base_model import BaseModel +from pytorch_forecasting.models.base import BaseModel from pytorch_forecasting.models.nbeats.sub_modules import ( NBEATSGenericBlock, NBEATSSeasonalBlock, diff --git a/pytorch_forecasting/models/nhits/_nhits.py b/pytorch_forecasting/models/nhits/_nhits.py index b71dd53af..9e8051896 100644 --- a/pytorch_forecasting/models/nhits/_nhits.py +++ b/pytorch_forecasting/models/nhits/_nhits.py @@ -20,7 +20,7 @@ MultiHorizonMetric, MultiLoss, ) -from pytorch_forecasting.models.base_model import BaseModelWithCovariates +from pytorch_forecasting.models.base import BaseModelWithCovariates from pytorch_forecasting.models.nhits.sub_modules import NHiTS as NHiTSModule from pytorch_forecasting.models.nn.embeddings import MultiEmbedding from pytorch_forecasting.utils import create_mask, to_list diff --git a/pytorch_forecasting/models/rnn/_rnn.py b/pytorch_forecasting/models/rnn/_rnn.py index e5d5f0bb5..a1c0fabd1 100644 --- a/pytorch_forecasting/models/rnn/_rnn.py +++ b/pytorch_forecasting/models/rnn/_rnn.py @@ -21,7 +21,7 @@ MultiLoss, QuantileLoss, ) -from pytorch_forecasting.models.base_model import AutoRegressiveBaseModelWithCovariates +from pytorch_forecasting.models.base import AutoRegressiveBaseModelWithCovariates from pytorch_forecasting.models.nn import HiddenState, MultiEmbedding, get_rnn from pytorch_forecasting.utils import apply_to_list, to_list diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/_tft.py b/pytorch_forecasting/models/temporal_fusion_transformer/_tft.py index 6ab878ecb..237006ba8 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/_tft.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/_tft.py @@ -19,7 +19,7 @@ MultiHorizonMetric, QuantileLoss, ) -from pytorch_forecasting.models.base_model import BaseModelWithCovariates +from pytorch_forecasting.models.base import BaseModelWithCovariates from pytorch_forecasting.models.nn import LSTM, MultiEmbedding from pytorch_forecasting.models.temporal_fusion_transformer.sub_modules import ( AddNorm, diff --git a/pytorch_forecasting/models/tide/_tide.py b/pytorch_forecasting/models/tide/_tide.py index cd101dfec..afeee1103 100644 --- a/pytorch_forecasting/models/tide/_tide.py +++ b/pytorch_forecasting/models/tide/_tide.py @@ -12,7 +12,7 @@ from pytorch_forecasting.data import TimeSeriesDataSet from pytorch_forecasting.data.encoders import NaNLabelEncoder from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE -from pytorch_forecasting.models.base_model import BaseModelWithCovariates +from pytorch_forecasting.models.base import BaseModelWithCovariates from pytorch_forecasting.models.nn.embeddings import MultiEmbedding from pytorch_forecasting.models.tide.sub_modules import _TideModule