Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions pytorch_forecasting/data/timeseries/_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,7 +1003,9 @@ def _validate_data(self, data: pd.DataFrame) -> None:
), "Timeseries index should be of type integer"
# numeric categoricals which can cause issues in tensorborad logging
category_columns = data.head(1).select_dtypes("category").columns
object_columns = data.head(1).select_dtypes(object).columns
object_columns = (
data.head(1).select_dtypes(include=["object", "string"]).columns
)
for name in self.flat_categoricals:
if name not in data.columns:
raise KeyError(f"variable {name} specified but not found in data")
Expand Down Expand Up @@ -1882,7 +1884,7 @@ def _construct_index(self, data: pd.DataFrame, predict_mode: bool) -> pd.DataFra
if predict_mode and "sequence_id" in df_index.columns:
minimal_columns.append("sequence_id")

df_index = df_index[minimal_columns].astype("int32", copy=False)
df_index = df_index[minimal_columns].astype("int32")
return df_index.reset_index(drop=True)

def filter(self, filter_func: Callable, copy: bool = True) -> TimeSeriesDataType:
Expand Down
14 changes: 12 additions & 2 deletions pytorch_forecasting/data/timeseries/_timeseries_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,12 @@
if col not in [self.time] + self._group + [self.weight] + self._target
]
if self._group:
self._groups = self.data.groupby(self._group).groups
group_arg = (
self._group[0]
if isinstance(self._group, (list, tuple)) and len(self._group) == 1
else self._group
)
self._groups = self.data.groupby(group_arg).groups

Check warning on line 139 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / test-deps-2025 (ubuntu-latest, 3.12)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 139 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / test-deps-2025 (ubuntu-latest, 3.12)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 139 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / test-deps-2025 (ubuntu-latest, 3.12)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.

Check warning on line 139 in pytorch_forecasting/data/timeseries/_timeseries_v2.py

View workflow job for this annotation

GitHub Actions / test-deps-2025 (ubuntu-latest, 3.12)

The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.
self._group_ids = list(self._groups.keys())
else:
self._groups = {"_single_group": self.data.index}
Expand Down Expand Up @@ -255,7 +260,12 @@

if data_future is not None:
if _group:
future_mask = self.data_future.groupby(_group).groups[group_id]
group_arg = (
self._group[0]
if isinstance(self._group, (list, tuple)) and len(self._group) == 1
else self._group
)
future_mask = self.data_future.groupby(group_arg).groups[group_id]
future_data = self.data_future.loc[future_mask]
else:
future_data = self.data_future
Expand Down
6 changes: 3 additions & 3 deletions tests/test_data/test_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,8 +727,8 @@ def test_pytorch_unwriteable_data():
already have been issued.
"""
# save current mode
copy_on_write = pd.options.mode.copy_on_write
pd.options.mode.copy_on_write = True
# copy_on_write = pd.options.mode.copy_on_write
# pd.options.mode.copy_on_write = True

# Create a small dataset
data = pd.DataFrame(
Expand Down Expand Up @@ -762,7 +762,7 @@ def test_pytorch_unwriteable_data():
next(iter(dataset))

# reset original mode
pd.options.mode.copy_on_write = copy_on_write
# pd.options.mode.copy_on_write = copy_on_write

# Check if the specific warning was triggered
to_catch = "The given NumPy array is not writable, and PyTorch"
Expand Down
Loading