diff --git a/pytorch_forecasting/data/timeseries/_timeseries.py b/pytorch_forecasting/data/timeseries/_timeseries.py index 8271f0ef4..774d1952c 100644 --- a/pytorch_forecasting/data/timeseries/_timeseries.py +++ b/pytorch_forecasting/data/timeseries/_timeseries.py @@ -1003,7 +1003,9 @@ def _validate_data(self, data: pd.DataFrame) -> None: ), "Timeseries index should be of type integer" # numeric categoricals which can cause issues in tensorborad logging category_columns = data.head(1).select_dtypes("category").columns - object_columns = data.head(1).select_dtypes(object).columns + object_columns = ( + data.head(1).select_dtypes(include=["object", "string"]).columns + ) for name in self.flat_categoricals: if name not in data.columns: raise KeyError(f"variable {name} specified but not found in data") @@ -1882,7 +1884,7 @@ def _construct_index(self, data: pd.DataFrame, predict_mode: bool) -> pd.DataFra if predict_mode and "sequence_id" in df_index.columns: minimal_columns.append("sequence_id") - df_index = df_index[minimal_columns].astype("int32", copy=False) + df_index = df_index[minimal_columns].astype("int32") return df_index.reset_index(drop=True) def filter(self, filter_func: Callable, copy: bool = True) -> TimeSeriesDataType: diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py index aff3f8862..154f3a96d 100644 --- a/pytorch_forecasting/data/timeseries/_timeseries_v2.py +++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py @@ -131,7 +131,12 @@ def __init__( if col not in [self.time] + self._group + [self.weight] + self._target ] if self._group: - self._groups = self.data.groupby(self._group).groups + group_arg = ( + self._group[0] + if isinstance(self._group, (list, tuple)) and len(self._group) == 1 + else self._group + ) + self._groups = self.data.groupby(group_arg).groups self._group_ids = list(self._groups.keys()) else: self._groups = {"_single_group": self.data.index} @@ -255,7 +260,12 @@ def __getitem__(self, index: int) -> dict[str, torch.Tensor]: if data_future is not None: if _group: - future_mask = self.data_future.groupby(_group).groups[group_id] + group_arg = ( + self._group[0] + if isinstance(self._group, (list, tuple)) and len(self._group) == 1 + else self._group + ) + future_mask = self.data_future.groupby(group_arg).groups[group_id] future_data = self.data_future.loc[future_mask] else: future_data = self.data_future diff --git a/tests/test_data/test_timeseries.py b/tests/test_data/test_timeseries.py index fd0f001f4..e5ce26bc5 100644 --- a/tests/test_data/test_timeseries.py +++ b/tests/test_data/test_timeseries.py @@ -727,8 +727,8 @@ def test_pytorch_unwriteable_data(): already have been issued. """ # save current mode - copy_on_write = pd.options.mode.copy_on_write - pd.options.mode.copy_on_write = True + # copy_on_write = pd.options.mode.copy_on_write + # pd.options.mode.copy_on_write = True # Create a small dataset data = pd.DataFrame( @@ -762,7 +762,7 @@ def test_pytorch_unwriteable_data(): next(iter(dataset)) # reset original mode - pd.options.mode.copy_on_write = copy_on_write + # pd.options.mode.copy_on_write = copy_on_write # Check if the specific warning was triggered to_catch = "The given NumPy array is not writable, and PyTorch"