Skip to content

Commit 807d22c

Browse files
brunnedu, jakubchlapek, and dennisbader
authored
Fix/max_samples_per_ts (#2987)
* Fix `max_samples_per_ts` not acting as an upper bound; add test; update changelog
* Update CHANGELOG.md

Co-authored-by: Jakub Chłapek <147340544+jakubchlapek@users.noreply.github.com>
Co-authored-by: Dustin Brunner <dustin.brunner@unit8.co>
Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
1 parent 72edd10 commit 807d22c

File tree

3 files changed

+82
-12
lines changed

3 files changed

+82
-12
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
1313

1414
**Fixed**
1515

16+
- Fixed a bug in `TorchTrainingDataset` where `max_samples_per_ts` was not acting as an upper bound on the number of samples per time series. Now `max_samples_per_ts` correctly acts as an upper bound, capping the dataset size at the actual number of samples that can be extracted from the longest series. [#2987](https://github.com/unit8co/darts/pull/2987) by [Dustin Brunner](https://github.com/brunnedu).
1617
- Updated s(m)ape to not raise a ValueError when actuals and predictions are zero for the same timestep. [#2984](https://github.com/unit8co/darts/pull/2984) by [eschibli](https://github.com/eschibli).
1718

1819
**Dependencies**

darts/tests/utils/torch_datasets/test_torch_datasets.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -881,6 +881,70 @@ def test_horizon_training_dataset_invalid_lh(self):
881881
"with `1 <= min_lh <= max_lh`."
882882
)
883883

884+
def test_max_samples_per_ts_upper_bound(self):
    """`max_samples_per_ts` must cap the dataset size at the number of samples
    actually extractable from the longest series, never inflate it.

    With input_chunk_length=11, output_chunk_length=13 and shift=24:
    size_of_both_chunks = max(11, 24 + 13) = 37, so a length-100 series
    yields 100 - 37 + 1 = 64 extractable samples.
    """
    target = self.cov1  # length 100 → 64 extractable samples

    def build(series, **overrides):
        # Shared constructor parameters; each case tweaks only what it tests.
        params = dict(
            series=series,
            input_chunk_length=11,
            output_chunk_length=13,
            shift=24,
        )
        params.update(overrides)
        return ShiftedTorchTrainingDataset(**params)

    # Case 1: no limit → all 64 samples are extracted
    assert len(build(target, max_samples_per_ts=None)) == 64

    # Case 2: limit far above the actual max (5000 >> 64) → capped at 64
    assert len(build(target, max_samples_per_ts=5000)) == 64

    # Case 3: limit below the actual max → the limit wins
    assert len(build(target, max_samples_per_ts=50)) == 50

    # Case 4: stride > 1 — extractable samples become ceil(64 / 2) = 32
    assert len(build(target, stride=2, max_samples_per_ts=100)) == 32

    # Case 5: multiple series of different lengths — the cap follows the
    # longest series, and every series contributes that many entries.
    short_series = gaussian_timeseries(length=50)  # 50 - 37 + 1 = 14 samples
    long_series = gaussian_timeseries(length=100)  # 100 - 37 + 1 = 64 samples
    ds_multi = build([short_series, long_series], max_samples_per_ts=5000)
    # Capped at 64 (the max over both series), so 2 * 64 = 128 in total.
    assert len(ds_multi) == 2 * 64
947+
884948
def test_past_covariates_sequential_dataset(self):
885949
# one target series
886950
ds = SequentialTorchTrainingDataset(

darts/utils/data/torch_datasets/training_dataset.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -170,19 +170,24 @@ def __init__(
170170

171171
size_of_both_chunks = max(input_chunk_length, shift + output_chunk_length)
172172

173-
# setup samples
174-
if max_samples_per_ts is None:
175-
# read all time series to get the maximum size
176-
max_samples_per_ts = max(len(ts) for ts in series) - size_of_both_chunks + 1
177-
if max_samples_per_ts <= 0:
178-
raise_log(
179-
ValueError(
180-
f"The input `series` are too short to extract even a single sample. "
181-
f"Expected min length: `{size_of_both_chunks}`, received max length: "
182-
f"`{max_samples_per_ts + size_of_both_chunks - 1}`."
183-
)
173+
# compute the maximum available samples over all series
174+
max_available_indices = max(len(ts) for ts in series) - size_of_both_chunks + 1
175+
max_available_samples = ceil(max_available_indices / stride)
176+
177+
if max_available_indices <= 0:
178+
raise_log(
179+
ValueError(
180+
f"The input `series` are too short to extract even a single sample. "
181+
f"Expected min length: `{size_of_both_chunks}`, received max length: "
182+
f"`{max(len(ts) for ts in series)}`."
184183
)
185-
max_samples_per_ts = ceil(max_samples_per_ts / stride)
184+
)
185+
186+
if max_samples_per_ts is None:
187+
max_samples_per_ts = max_available_samples
188+
else:
189+
# upper bound maximum available samples by max_samples_per_ts
190+
max_samples_per_ts = min(max_samples_per_ts, max_available_samples)
186191

187192
self.input_chunk_length = input_chunk_length
188193
self.output_chunk_length = output_chunk_length

0 commit comments

Comments
 (0)