Skip to content

Commit 3a20509

Browse files
committed
add dim tests
1 parent 3eae834 commit 3a20509

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

sigllm/primitives/prompting/timeseries_preprocessing.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,14 @@ def rolling_window_sequences(X, window_size=500, step_size=100):
2626
* rolling window sequences.
2727
* first index value of each input sequence.
2828
"""
29+
if X.ndim == 1:
30+
dim = 1
31+
else:
32+
dim = X.shape[1]
33+
2934
index = range(len(X))
3035
out_X = list()
3136
X_index = list()
32-
dim = X.shape[1]
3337

3438
start = 0
3539
max_start = len(X) - window_size + 1

tests/primitives/prompting/test_timeseries_preprocessing.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,35 @@ def test_rolling_window_sequences(values, window_size, step_size):
3131
np.array([0, 1, 2, 3, 4]),
3232
3,
3333
1,
34+
1
35+
)
36+
37+
result = rolling_window_sequences(values, window_size, step_size)
38+
39+
if len(result) != len(expected):
40+
raise AssertionError('Tuples has different length')
41+
42+
for arr1, arr2 in zip(result, expected):
43+
np.testing.assert_equal(arr1, arr2)
44+
45+
46+
def test_rolling_window_sequences_multivariate(window_size, step_size):
47+
values = np.array([
48+
[0.555, 2.345],
49+
[1.501, 5.903],
50+
[9.116, 3.068],
51+
[7.432, 4.532]
52+
])
53+
54+
expected = (
55+
np.array([
56+
[[0.555, 2.345], [1.501, 5.903], [9.116, 3.068]],
57+
[[1.501, 5.903], [9.116, 3.068], [7.432, 4.532]],
58+
]),
59+
np.array([0, 1]),
60+
3,
61+
1,
62+
2
3463
)
3564

3665
result = rolling_window_sequences(values, window_size, step_size)

0 commit comments

Comments
 (0)