Skip to content

Commit fbfbaec

Browse files
author
vm-aifluence-jro
committed
test values simplified
1 parent 8ac3ac8 commit fbfbaec

File tree

2 files changed

+38
-94
lines changed

2 files changed

+38
-94
lines changed

qolmat/utils/data.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,9 @@ def preprocess_data(df: pd.DataFrame):
9999
df.set_index(["station", "datetime"], inplace=True)
100100
df.drop(columns=["year", "month", "day", "hour", "wd", "No"], inplace=True)
101101
df.sort_index(inplace=True)
102-
dict_agg = {key: np.mean for key in df.columns}
103-
dict_agg["RAIN"] = np.mean
104102
df = df.groupby(
105103
["station", df.index.get_level_values("datetime").floor("d")], group_keys=False
106-
).agg(dict_agg)
104+
).mean()
107105
return df
108106

109107

tests/utils/test_data.py

Lines changed: 37 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,41 @@
1+
import datetime
2+
3+
import numpy as np
14
import pandas as pd
25
import pytest
3-
import numpy as np
4-
import datetime
56

67
from qolmat.utils import data
78

8-
columns = ["PM2.5", "PM10", "SO2", "NO2", "CO", "O3", "TEMP", "PRES", "DEWP", "RAIN", "WSPM"]
9+
columns = ["a", "b"]
910
index = pd.MultiIndex.from_tuples(
1011
[
1112
("Gucheng", datetime.datetime(2013, 3, 1)),
1213
("Gucheng", datetime.datetime(2014, 3, 1)),
1314
("Gucheng", datetime.datetime(2015, 3, 1)),
14-
("Gucheng", datetime.datetime(2016, 3, 1)),
1515
],
1616
names=["station", "datetime"],
1717
)
1818
df = pd.DataFrame(
1919
[
20-
[6.0, 18.0, 5.0, np.nan, 800.0, 88.0, 0.1, 1021.1, -18.6, 0.0, 4.4],
21-
[6.0, 18.0, 5.0, np.nan, 800.0, 88.0, 0.1, 1021.1, -18.6, 0.0, 4.4],
22-
[6.0, 18.0, 5.0, 0.1, 800.0, 88.0, 0.1, 1021.1, -18.6, 0.0, 4.4],
23-
[6.0, 18.0, 5.0, 0.1, 800.0, 88.0, 0.1, 1021.1, -18.6, 0.0, 4.4],
20+
[1, 2],
21+
[3, np.nan],
22+
[np.nan, 6],
2423
],
2524
columns=columns,
2625
index=index,
2726
)
2827

2928

30-
def test_preprocess_data() -> None:
31-
29+
def test_preprocess_data():
3230
columns_raw = [
3331
"No",
3432
"year",
3533
"month",
3634
"day",
3735
"hour",
38-
"PM2.5",
39-
"PM10",
40-
"SO2",
41-
"NO2",
42-
"CO",
43-
"O3",
44-
"TEMP",
45-
"PRES",
46-
"DEWP",
47-
"RAIN",
36+
"a",
37+
"b",
4838
"wd",
49-
"WSPM",
5039
"station",
5140
]
5241
df_raw = pd.DataFrame(
@@ -57,18 +46,9 @@ def test_preprocess_data() -> None:
5746
3,
5847
1,
5948
0,
60-
6.0,
61-
18.0,
62-
5.0,
63-
np.nan,
64-
800.0,
65-
88.0,
66-
0.1,
67-
1021.1,
68-
-18.6,
69-
0.0,
49+
1,
50+
2,
7051
"NW",
71-
4.4,
7252
"Gucheng",
7353
],
7454
[
@@ -77,18 +57,9 @@ def test_preprocess_data() -> None:
7757
3,
7858
1,
7959
0,
80-
6.0,
81-
18.0,
82-
5.0,
60+
3,
8361
np.nan,
84-
800.0,
85-
88.0,
86-
0.1,
87-
1021.1,
88-
-18.6,
89-
0.0,
9062
"NW",
91-
4.4,
9263
"Gucheng",
9364
],
9465
[
@@ -97,79 +68,54 @@ def test_preprocess_data() -> None:
9768
3,
9869
1,
9970
0,
100-
6.0,
101-
18.0,
102-
5.0,
103-
0.1,
104-
800.0,
105-
88.0,
106-
0.1,
107-
1021.1,
108-
-18.6,
109-
0.0,
110-
"NW",
111-
4.4,
112-
"Gucheng",
113-
],
114-
[
115-
4,
116-
2016,
117-
3,
118-
1,
119-
0,
120-
6.0,
121-
18.0,
122-
5.0,
123-
0.1,
124-
800.0,
125-
88.0,
126-
0.1,
127-
1021.1,
128-
-18.6,
129-
0.0,
71+
np.nan,
72+
6,
13073
"NW",
131-
4.4,
13274
"Gucheng",
13375
],
13476
],
13577
columns=columns_raw,
13678
)
137-
138-
assert data.preprocess_data(df_raw).equals(df)
79+
print(df_raw)
80+
result = data.preprocess_data(df_raw)
81+
print(result)
82+
print(df)
83+
# assert result.equals(df)
84+
pd.testing.assert_frame_equal(result, df, atol=1e-3)
13985

14086

14187
def test_add_holes() -> None:
142-
assert data.add_holes(df, 0.0, 1).isna().sum().sum() == 2
143-
assert data.add_holes(df, 1.0, 1).isna().sum().sum() > 2
88+
df_out = data.add_holes(df, 0.0, 1)
89+
assert df_out.isna().sum().sum() == 2
90+
df_out = data.add_holes(df, 1.0, 1)
91+
assert df_out.isna().sum().sum() > 2
14492

14593

14694
def test_add_station_features() -> None:
14795
columns_out = columns + ["station=Gucheng"]
148-
df_out = pd.DataFrame(
96+
expected = pd.DataFrame(
14997
[
150-
[6.0, 18.0, 5.0, np.nan, 800.0, 88.0, 0.1, 1021.1, -18.6, 0.0, 4.4, 1.0],
151-
[6.0, 18.0, 5.0, np.nan, 800.0, 88.0, 0.1, 1021.1, -18.6, 0.0, 4.4, 1.0],
152-
[6.0, 18.0, 5.0, 0.1, 800.0, 88.0, 0.1, 1021.1, -18.6, 0.0, 4.4, 1.0],
153-
[6.0, 18.0, 5.0, 0.1, 800.0, 88.0, 0.1, 1021.1, -18.6, 0.0, 4.4, 1.0],
98+
[1, 2, 1.0],
99+
[3, np.nan, 1.0],
100+
[np.nan, 6, 1.0],
154101
],
155102
columns=columns_out,
156103
index=index,
157104
)
158-
159-
assert data.add_station_features(df).equals(df_out)
105+
result = data.add_station_features(df)
106+
pd.testing.assert_frame_equal(result, expected, atol=1e-3)
160107

161108

162109
def test_add_datetime_features() -> None:
163110
columns_out = columns + ["time_cos"]
164-
df_out = pd.DataFrame(
111+
expected = pd.DataFrame(
165112
[
166-
[6.0, 18.0, 5.0, np.nan, 800.0, 88.0, 0.1, 1021.1, -18.6, 0.0, 4.4, 0.51237141],
167-
[6.0, 18.0, 5.0, np.nan, 800.0, 88.0, 0.1, 1021.1, -18.6, 0.0, 4.4, 0.51237141],
168-
[6.0, 18.0, 5.0, 0.1, 800.0, 88.0, 0.1, 1021.1, -18.6, 0.0, 4.4, 0.51237141],
169-
[6.0, 18.0, 5.0, 0.1, 800.0, 88.0, 0.1, 1021.1, -18.6, 0.0, 4.4, 0.5],
113+
[1, 2, 0.512],
114+
[3, np.nan, 0.512],
115+
[np.nan, 6, 0.512],
170116
],
171117
columns=columns_out,
172118
index=index,
173119
)
174-
175-
np.testing.assert_allclose(data.add_datetime_features(df), df_out, atol=1.0e-5)
120+
result = data.add_datetime_features(df)
121+
pd.testing.assert_frame_equal(result, expected, atol=1e-3)

0 commit comments

Comments
 (0)