Skip to content

Commit 564ead2

Browse files
Merge pull request #27 from Quantmetry/angoho_unit_tests_utils
Angoho unit tests utils
2 parents be68a4c + 9faefa9 commit 564ead2

File tree

2 files changed

+128
-6
lines changed

2 files changed

+128
-6
lines changed

qolmat/utils/data.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,9 @@ def preprocess_data(df: pd.DataFrame):
9999
df.set_index(["station", "datetime"], inplace=True)
100100
df.drop(columns=["year", "month", "day", "hour", "wd", "No"], inplace=True)
101101
df.sort_index(inplace=True)
102-
dict_agg = {key: np.mean for key in df.columns}
103-
dict_agg["RAIN"] = np.mean
104102
df = df.groupby(
105103
["station", df.index.get_level_values("datetime").floor("d")], group_keys=False
106-
).agg(dict_agg)
104+
).mean()
107105
return df
108106

109107

@@ -142,6 +140,7 @@ def add_holes(df: pd.DataFrame, ratio_masked: float, mean_size: int):
142140
mask = generator.generate_mask(df)
143141
else:
144142
mask = df.groupby(groups, group_keys=False).apply(generator.generate_mask)
143+
145144
X_with_nans = df.copy()
146145
X_with_nans[mask] = np.nan
147146
return X_with_nans
@@ -187,6 +186,7 @@ def add_station_features(df: pd.DataFrame):
187186
pd.DataFrame
188187
dataframe with missing values
189188
"""
189+
df = df.copy()
190190
stations = df.index.get_level_values("station")
191191
for station in stations.unique():
192192
df[f"station={station}"] = (stations == station).astype(float)
@@ -207,10 +207,11 @@ def add_datetime_features(df: pd.DataFrame):
207207
pd.DataFrame
208208
dataframe with missing values
209209
"""
210-
211-
time = df.index.get_level.values("datetime")
210+
df = df.copy()
211+
time = df.index.get_level_values("datetime").to_series()
212212
days_in_year = time.dt.year.apply(
213213
lambda x: 366 if ((x % 4 == 0) and (x % 100 != 0)) or (x % 400 == 0) else 365
214214
)
215-
df["time_cos"] = np.cos(2 * np.pi * time.dt.dayofyear / days_in_year)
215+
time_cos = np.cos(2 * np.pi * time.dt.dayofyear / days_in_year)
216+
df["time_cos"] = np.array(time_cos)
216217
return df

tests/utils/test_data.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import datetime
2+
3+
import numpy as np
4+
import pandas as pd
5+
import pytest
6+
7+
from qolmat.utils import data
8+
9+
columns = ["a", "b"]
10+
index = pd.MultiIndex.from_tuples(
11+
[
12+
("Gucheng", datetime.datetime(2013, 3, 1)),
13+
("Gucheng", datetime.datetime(2014, 3, 1)),
14+
("Gucheng", datetime.datetime(2015, 3, 1)),
15+
],
16+
names=["station", "datetime"],
17+
)
18+
df = pd.DataFrame(
19+
[
20+
[1, 2],
21+
[3, np.nan],
22+
[np.nan, 6],
23+
],
24+
columns=columns,
25+
index=index,
26+
)
27+
28+
29+
def test_preprocess_data():
30+
columns_raw = [
31+
"No",
32+
"year",
33+
"month",
34+
"day",
35+
"hour",
36+
"a",
37+
"b",
38+
"wd",
39+
"station",
40+
]
41+
df_raw = pd.DataFrame(
42+
[
43+
[
44+
1,
45+
2013,
46+
3,
47+
1,
48+
0,
49+
1,
50+
2,
51+
"NW",
52+
"Gucheng",
53+
],
54+
[
55+
2,
56+
2014,
57+
3,
58+
1,
59+
0,
60+
3,
61+
np.nan,
62+
"NW",
63+
"Gucheng",
64+
],
65+
[
66+
3,
67+
2015,
68+
3,
69+
1,
70+
0,
71+
np.nan,
72+
6,
73+
"NW",
74+
"Gucheng",
75+
],
76+
],
77+
columns=columns_raw,
78+
)
79+
print(df_raw)
80+
result = data.preprocess_data(df_raw)
81+
print(result)
82+
print(df)
83+
# assert result.equals(df)
84+
pd.testing.assert_frame_equal(result, df, atol=1e-3)
85+
86+
87+
def test_add_holes() -> None:
88+
df_out = data.add_holes(df, 0.0, 1)
89+
assert df_out.isna().sum().sum() == 2
90+
df_out = data.add_holes(df, 1.0, 1)
91+
assert df_out.isna().sum().sum() > 2
92+
93+
94+
def test_add_station_features() -> None:
95+
columns_out = columns + ["station=Gucheng"]
96+
expected = pd.DataFrame(
97+
[
98+
[1, 2, 1.0],
99+
[3, np.nan, 1.0],
100+
[np.nan, 6, 1.0],
101+
],
102+
columns=columns_out,
103+
index=index,
104+
)
105+
result = data.add_station_features(df)
106+
pd.testing.assert_frame_equal(result, expected, atol=1e-3)
107+
108+
109+
def test_add_datetime_features() -> None:
110+
columns_out = columns + ["time_cos"]
111+
expected = pd.DataFrame(
112+
[
113+
[1, 2, 0.512],
114+
[3, np.nan, 0.512],
115+
[np.nan, 6, 0.512],
116+
],
117+
columns=columns_out,
118+
index=index,
119+
)
120+
result = data.add_datetime_features(df)
121+
pd.testing.assert_frame_equal(result, expected, atol=1e-3)

0 commit comments

Comments
 (0)