Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
49c3ab4
add feature scaling to d2
PranavBhatP Oct 12, 2025
1942383
Merge branch 'main' into feature-scaling
PranavBhatP Oct 12, 2025
403e0f2
fix incorrect orig_idx index
PranavBhatP Oct 12, 2025
5661be4
fix incorrect attribute
PranavBhatP Oct 12, 2025
38cefe4
handle unfitted scalers
PranavBhatP Oct 16, 2025
f242290
change accelerator to cpu in v2 notebook cell 10
PranavBhatP Oct 17, 2025
54da1c4
use torch.from_numpy instead of torch.tensor for numpy to torch conve…
PranavBhatP Oct 18, 2025
c145e9b
revert accelerator mode to auto from cpu for example notebook trainin…
PranavBhatP Oct 18, 2025
ec4cf03
potential fix for issue in training of v2
PranavBhatP Oct 21, 2025
18f2b2a
replace MAE() with nn.L1Loss() to fix notebook test failures
PranavBhatP Nov 2, 2025
5c99959
Merge branch 'main' into feature-scaling
PranavBhatP Nov 2, 2025
85ba7cb
Merge branch 'main' into feature-scaling
PranavBhatP Nov 4, 2025
d96aed5
Merge branch 'main' into feature-scaling
PranavBhatP Nov 25, 2025
fd8411a
revert notebook state
PranavBhatP Dec 5, 2025
0830090
Merge branch 'main' into feature-scaling
PranavBhatP Dec 5, 2025
ff42a1b
some changes to data module - incomplete
PranavBhatP Dec 7, 2025
f86f9a5
fix scaling and target norm - working
PranavBhatP Dec 8, 2025
728cfad
remove target_scale and add target_normalizer instead
PranavBhatP Dec 8, 2025
091d0f8
restore original notebook
PranavBhatP Dec 8, 2025
6d38331
revert breaking change on target scale
PranavBhatP Dec 8, 2025
4ff3444
Merge branch 'main' into feature-scaling
PranavBhatP Dec 15, 2025
ca5cb97
separate concerns for feature scaling and target normalizers inside _…
PranavBhatP Dec 15, 2025
77dc43f
fix multi target handling during normalization
PranavBhatP Dec 15, 2025
757fe83
fix data module output format
PranavBhatP Dec 15, 2025
bceb0e3
add tests for feature scaling and norm
PranavBhatP Dec 15, 2025
c07a343
remove unnecessary dataset param from internal D2 dataset class
PranavBhatP Dec 15, 2025
44e0545
add encoder normalizer support for data module
PranavBhatP Dec 28, 2025
96669f1
Merge branch 'main' into feature-scaling
PranavBhatP Dec 28, 2025
e927ec2
remove contentious line for testing normalizer behavior
PranavBhatP Dec 28, 2025
b45279a
add validation for preprocessing in test and predict dataset
PranavBhatP Jan 1, 2026
67aba1b
skip check for fitting before preprocessing
PranavBhatP Jan 1, 2026
f10fbaf
Merge branch 'main' into feature-scaling
PranavBhatP Jan 4, 2026
94ca96a
Merge branch 'main' into feature-scaling
PranavBhatP Jan 21, 2026
93159fb
add loading and saving of normalizer and scaler metadata for use acro…
PranavBhatP Jan 22, 2026
3665577
improve test suite for normalizer and scalers
PranavBhatP Jan 22, 2026
3cb479b
save and load scaler in base_pkg class
PranavBhatP Jan 22, 2026
f0969d7
fix saving state of scalers when save_ckpt is false in model_pkg fit …
PranavBhatP Jan 24, 2026
9355c01
avoid call to save and fit scalers in data modules which do not suppo…
PranavBhatP Jan 24, 2026
67b1f59
Merge branch 'main' into feature-scaling
PranavBhatP Jan 24, 2026
adc46c2
Merge branch 'main' into feature-scaling
PranavBhatP Jan 28, 2026
ce5f904
Merge branch 'main' of https://www.github.com/PranavBhatP/pytorch-for…
PranavBhatP Feb 15, 2026
5c344f8
Merge branch 'feature-scaling' of https://www.github.com/PranavBhatP/…
PranavBhatP Feb 15, 2026
8324d14
revert change in logic for handling datamodules reused for the test/p…
PranavBhatP Feb 15, 2026
32256d1
change base pkg code to handle saving and loading, while dm handles v…
PranavBhatP Mar 2, 2026
5028cee
Merge branch 'main' into feature-scaling
PranavBhatP Mar 2, 2026
f934b28
move persistence logic completely into base package
PranavBhatP Mar 8, 2026
19dd617
Merge branch 'main' into feature-scaling
PranavBhatP Mar 8, 2026
d7f1af5
raise warning for data modules not supporting feature scaling
PranavBhatP Mar 8, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 116 additions & 1 deletion pytorch_forecasting/base/_base_pkg.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path
import pickle
from typing import Any, Optional, Union
import warnings

from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
Expand Down Expand Up @@ -46,8 +47,17 @@ def __init__(
trainer_cfg: dict[str, Any] | str | Path | None = None,
datamodule_cfg: dict[str, Any] | str | Path | None = None,
ckpt_path: str | Path | None = None,
scaler_path: str | Path | None = None,
):
self.ckpt_path = Path(ckpt_path) if ckpt_path else None
if scaler_path:
self._scaler_path = Path(scaler_path)
elif self.ckpt_path:
# attempt automatic discovery of scaler path when ckpt path is specified.
potential_scaler = self.ckpt_path.parent / "scalers.pkl"
self._scaler_path = potential_scaler if potential_scaler.exists() else None
else:
self._scaler_path = None
self.model_cfg = self._load_config(
model_cfg, ckpt_path=self.ckpt_path, auto_file_name="model_cfg.pkl"
)
Expand Down Expand Up @@ -156,12 +166,85 @@ def _build_datamodule(self, data: TimeSeries) -> LightningDataModule:
datamodule_cls = self.get_datamodule_cls()
return datamodule_cls(data, **self.datamodule_cfg)

def _save_scalers(self, scaler_path: Path):
    """Persist the fitted scaler state of ``self.datamodule`` to disk.

    The package object is only the storage layer: it asks the datamodule
    for its scaler state and pickles that state verbatim.

    Parameters
    ----------
    scaler_path : Path
        Destination file for the pickled scaler state.
    """
    dm = self.datamodule
    if not hasattr(dm, "get_scalers_state"):
        # Datamodule cannot export scaler state - warn and bail out.
        warnings.warn(
            f"DataModule of type {type(dm).__name__} does not support "
            "scaler operations. It must implement 'get_scalers_state()' method. "
            "Skipping scaler saving.",
            UserWarning,
            stacklevel=2,
        )
        return

    state = dm.get_scalers_state()
    with open(scaler_path, "wb") as f:
        pickle.dump(state, f)

    # Remember where the state lives so later predict/test stages can reload it.
    self._scaler_path = scaler_path
    print(f"Scalers saved to: {scaler_path}")

def _load_scalers(self, datamodule: LightningDataModule, scaler_path: Path):
    """Load pickled scaler state from disk and hand it to the datamodule.

    BasePkg acts as delivery layer - unpickles state and passes to DataModule
    for validation.

    Parameters
    ----------
    datamodule : LightningDataModule
        The datamodule to load scalers into. It must implement
        ``set_scalers_state()``; otherwise a ``UserWarning`` is issued and
        loading is skipped.
    scaler_path : Path
        Path to load scalers from.

    Raises
    ------
    FileNotFoundError
        If ``scaler_path`` does not exist.
    RuntimeError
        If unpickling the state or validating it inside the datamodule fails;
        the original exception is attached as ``__cause__``.
    """
    if not hasattr(datamodule, "set_scalers_state"):
        warnings.warn(
            f"DataModule of type {type(datamodule).__name__} does not support "
            "scaler operations. It must implement 'set_scalers_state()' method. "
            "Skipping scaler loading.",
            UserWarning,
            stacklevel=2,
        )
        return

    if not scaler_path.exists():
        raise FileNotFoundError(f"Scaler file not found: {scaler_path}")

    try:
        with open(scaler_path, "rb") as f:
            # NOTE: pickle is unsafe on untrusted input; only files written
            # by ``_save_scalers`` are expected here.
            scaler_state = pickle.load(f)  # noqa: S301
        datamodule.set_scalers_state(scaler_state)
        print(f"Scalers loaded from: {scaler_path}")
    except Exception as e:
        # Chain explicitly so the original traceback survives as __cause__.
        raise RuntimeError(
            f"Failed to load or validate scalers from {scaler_path}: {e}"
        ) from e

def _load_dataloader(
self, data: TimeSeries | LightningDataModule | DataLoader
) -> DataLoader:
"""Converts various data input types into a DataLoader for prediction."""
"""Converts various data input types into a DataLoader for prediction.

Reuses datamodules from fitting stage if already exists to persist scalers
across stages of fit->test/predict.
"""
if isinstance(data, TimeSeries): # D1 Layer
dm = self._build_datamodule(data)

# BasePkg handles scaler loading
if self._scaler_path:
self._load_scalers(dm, self._scaler_path)

dm.setup(stage="predict")
return dm.predict_dataloader()
elif isinstance(data, LightningDataModule): # D2 Layer
Expand Down Expand Up @@ -196,6 +279,8 @@ def fit(
save_ckpt: bool = True,
ckpt_dir: str | Path = "checkpoints",
ckpt_kwargs: dict[str, Any] | None = None,
save_scalers: bool = True,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we expose this param? I mean if the user is checkpointing the model, they will always save the scalers, no? Are there any cases when the user only uses the models and not the scalers?

Copy link
Contributor Author

@PranavBhatP PranavBhatP Jan 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not 100% sure about whether this param should be exposed, I was thinking maybe due to some reason the user might not want to save their scalers after training due to memory constraints or simply because they know that they are not going to be testing or using the scalers afterwards in a new session, hence won't require any kind of persistence?

scaler_dir: str | Path = None,
**trainer_fit_kwargs,
):
"""
Expand All @@ -212,6 +297,12 @@ def fit(
Directory to save artifacts.
ckpt_kwargs : dict, optional
Keyword arguments passed to ``ModelCheckpoint``.
save_scalers : bool, default=True
If True, save the fitted scalers after training.
Necessary in order to use fitted scalers on new data, e.g.,
during prediction.
scaler_dir : Union[str, Path], default="fitted_scalers"
Directory to save fitted scalers.
**trainer_fit_kwargs :
Additional keyword arguments passed to `trainer.fit()`.

Expand All @@ -224,6 +315,18 @@ def fit(
self.datamodule = self._build_datamodule(data)
else:
self.datamodule = data

# Validate datamodule supports scaling
if save_scalers and not hasattr(self.datamodule, "get_scalers_state"):
warnings.warn(
f"DataModule of type {type(self.datamodule).__name__} does not support "
"scaler operations. Skipping scaler saving. Use a DataModule that "
"implements 'get_scalers_state()' method to enable this feature.",
UserWarning,
stacklevel=2,
)
save_scalers = False

self.datamodule.setup(stage="fit")

if self.model is None:
Expand Down Expand Up @@ -256,6 +359,18 @@ def fit(
self.trainer = Trainer(**trainer_init_cfg, callbacks=callbacks)

self.trainer.fit(self.model, datamodule=self.datamodule, **trainer_fit_kwargs)

# BasePkg handles all scaler persistence
if save_scalers:
if scaler_dir is None:
scaler_dir = Path(ckpt_dir) if save_ckpt else Path("fitted_scalers")
else:
scaler_dir = Path(scaler_dir)
scaler_dir.mkdir(parents=True, exist_ok=True)
scaler_path = scaler_dir / "scalers.pkl"

self._save_scalers(scaler_path)

if save_ckpt and checkpoint_cb:
best_model_path = Path(checkpoint_cb.best_model_path)
self._save_artifact(best_model_path.parent)
Expand Down
Loading
Loading