Skip to content

Commit 4f3e370

Browse files
author
Gsaes
committed
Modification test data
1 parent 13ae7dc commit 4f3e370

File tree

2 files changed

+34
-34
lines changed

2 files changed

+34
-34
lines changed

qolmat/utils/data.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,23 @@
1111
from qolmat.benchmark import missing_patterns
1212

1313

14+
def download_data(zipname: str, urllink: str, datapath: str = "data/") -> List[pd.DataFrame]:
15+
path_zip = os.path.join(datapath)
16+
print("Download data")
17+
if not os.path.exists(path_zip + ".zip"):
18+
if not os.path.exists(datapath):
19+
os.mkdir(datapath)
20+
request.urlretrieve(urllink + zipname + ".zip", path_zip + ".zip")
21+
22+
with zipfile.ZipFile(path_zip + ".zip", "r") as zip_ref:
23+
zip_ref.extractall(path_zip)
24+
data_folder = os.listdir(path_zip)
25+
subfolder = os.path.join(path_zip, data_folder[0])
26+
data_files = os.listdir(subfolder)
27+
list_df = [pd.read_csv(os.path.join(subfolder, file)) for file in data_files]
28+
return list_df
29+
30+
1431
def get_data(
1532
name_data: str = "Beijing", datapath: str = "data/", download: Optional[bool] = True
1633
) -> pd.DataFrame:
@@ -32,19 +49,7 @@ def get_data(
3249
if name_data == "Beijing":
3350
urllink = "https://archive.ics.uci.edu/ml/machine-learning-databases/00501/"
3451
zipname = "PRSA2017_Data_20130301-20170228"
35-
path_zip = os.path.join(datapath, zipname)
36-
37-
if not os.path.exists(path_zip + ".zip"):
38-
if not os.path.exists(datapath):
39-
os.mkdir(datapath)
40-
request.urlretrieve(urllink + zipname + ".zip", path_zip + ".zip")
41-
42-
with zipfile.ZipFile(path_zip + ".zip", "r") as zip_ref:
43-
zip_ref.extractall(path_zip)
44-
data_folder = os.listdir(path_zip)
45-
subfolder = os.path.join(path_zip, data_folder[0])
46-
data_files = os.listdir(subfolder)
47-
list_df = [pd.read_csv(os.path.join(subfolder, file)) for file in data_files]
52+
list_df = download_data(zipname, urllink, datapath=datapath)
4853
list_df = [preprocess_data(df) for df in list_df]
4954
df = pd.concat(list_df)
5055
return df

tests/utils/test_data.py

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytest
66

77
from qolmat.utils import data
8+
from pytest_mock.plugin import MockerFixture
89

910
columns = ["No", "year", "month", "day", "hour", "a", "b", "wd", "station"]
1011
df = pd.DataFrame(
@@ -30,31 +31,25 @@
3031

3132

3233
@pytest.mark.parametrize("name_data", ["Beijing", "Artificial", "Bug"])
33-
def test_utils_data_get_data(name_data: str) -> None:
34+
def test_utils_data_get_data(name_data: str, mocker: MockerFixture) -> None:
35+
mock_download = mocker.patch("qolmat.utils.data.download_data", return_value=[df])
36+
mocker.patch("qolmat.utils.data.preprocess_data", return_value=df_preprocess)
37+
try:
38+
df_result = data.get_data(name_data=name_data)
39+
except ValueError:
40+
assert name_data not in ["Beijing", "Artificial"]
41+
np.testing.assert_raises(ValueError, data.get_data, name_data)
42+
return
43+
3444
if name_data == "Beijing":
35-
df = data.get_data(name_data=name_data)
36-
expected_columns = [
37-
"PM2.5",
38-
"PM10",
39-
"SO2",
40-
"NO2",
41-
"CO",
42-
"O3",
43-
"TEMP",
44-
"PRES",
45-
"DEWP",
46-
"RAIN",
47-
"WSPM",
48-
]
49-
assert isinstance(df, pd.DataFrame)
50-
assert df.columns.tolist() == expected_columns
45+
assert mock_download.call_count == 1
46+
pd.testing.assert_frame_equal(df_result, df_preprocess)
5147
elif name_data == "Artificial":
52-
df = data.get_data(name_data=name_data)
5348
expected_columns = ["signal", "X", "A", "E"]
54-
assert isinstance(df, pd.DataFrame)
55-
assert df.columns.tolist() == expected_columns
49+
assert isinstance(df_result, pd.DataFrame)
50+
assert df_result.columns.tolist() == expected_columns
5651
else:
57-
np.testing.assert_raises(ValueError, data.get_data, name_data)
52+
assert False
5853

5954

6055
@pytest.mark.parametrize("df", [df])

0 commit comments

Comments
 (0)