Skip to content

Commit 155c24a

Browse files
Merge pull request #35 from Quantmetry/refacto_cross_validation
Refacto cross validation
2 parents e7ddc07 + c4332f7 commit 155c24a

File tree

7 files changed

+68
-80
lines changed

7 files changed

+68
-80
lines changed

HISTORY.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
=======
22
History
33
=======
4-
0.0.12 (2023-05-31)
4+
0.0.13 (2023-06-07)
55
-------------------
66

7+
* Refacto cross validation
78
* Fix Readme
89
* Add test utils.plot
910

qolmat/benchmark/comparator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,10 @@ def evaluate_errors_sample(
121121
hole_generator=self.generator_holes,
122122
n_calls=self.n_calls_opt,
123123
)
124-
df_imputed = cv.fit_transform(df_corrupted)
124+
imputer.hyperparams_optim = cv.optimize_hyperparams(df_corrupted)
125125
else:
126-
df_imputed = imputer.fit_transform(df_corrupted)
126+
imputer.hyperparams_optim = {}
127+
df_imputed = imputer.fit_transform(df_corrupted)
127128
subset = self.generator_holes.subset
128129
errors = self.get_errors(df_origin[subset], df_imputed[subset], df_mask[subset])
129130
list_errors.append(errors)

qolmat/benchmark/cross_validation.py

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def obj_func(**hyperparams_flat):
207207

208208
return obj_func
209209

210-
def optimize_hyperparams(self, df: pd.DataFrame) -> Dict[str, Union[float, int, str]]:
210+
def optimize_hyperparams(self, df: pd.DataFrame) -> Dict[str, Any]:
211211
"""Optimize hyperparamaters
212212
213213
Parameters
@@ -217,7 +217,7 @@ def optimize_hyperparams(self, df: pd.DataFrame) -> Dict[str, Union[float, int,
217217
218218
Returns
219219
-------
220-
Dict[str, Union[float,int, str]]
220+
Dict[str, Any]
221221
hyperparameters optimize flat
222222
"""
223223
list_spaces = get_search_space(self.dict_config_opti_imputer)
@@ -231,25 +231,5 @@ def optimize_hyperparams(self, df: pd.DataFrame) -> Dict[str, Union[float, int,
231231
)
232232

233233
hyperparams_flat = {space.name: val for space, val in zip(list_spaces, res["x"])}
234-
return hyperparams_flat
235-
236-
def fit_transform(self, df: pd.DataFrame) -> pd.DataFrame:
237-
"""
238-
Fit and transform estimator and impute the missing values.
239-
240-
Parameters
241-
----------
242-
df : pd.DataFrame
243-
dataframe to impute
244-
245-
Returns
246-
-------
247-
pd.DataFrame
248-
imputed dataframe
249-
"""
250-
251-
hyperparams_flat = self.optimize_hyperparams(df)
252-
self.imputer.hyperparams_optim = deflat_hyperparams(hyperparams_flat)
253-
df_imputed = self.imputer.fit_transform(df)
254-
255-
return df_imputed
234+
hyperparams = deflat_hyperparams(hyperparams_flat)
235+
return hyperparams

qolmat/imputations/imputers.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,6 @@ def __init__(
4848
random_state: Union[None, int, np.random.RandomState] = None,
4949
):
5050
self.hyperparams_user = hyperparams
51-
self.hyperparams_optim: Dict = {}
52-
self.hyperparams_local: Dict = {}
5351
self.groups = groups
5452
self.columnwise = columnwise
5553
self.shrink = shrink
@@ -82,7 +80,8 @@ def fit_transform(self, df: pd.DataFrame) -> pd.DataFrame:
8280
self.estimator.random_state = self.rng
8381

8482
hyperparams = self.hyperparams_user.copy()
85-
hyperparams.update(self.hyperparams_optim)
83+
if hasattr(self, "hyperparams_optim"):
84+
hyperparams.update(self.hyperparams_optim)
8685
cols_with_nans = df.columns[df.isna().any()]
8786

8887
if self.groups == []:

qolmat/utils/data.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import os
22
import zipfile
3-
from datetime import datetime
43
from math import pi
54
from typing import List, Optional
65
from urllib import request
@@ -11,6 +10,22 @@
1110
from qolmat.benchmark import missing_patterns
1211

1312

13+
def download_data(zipname: str, urllink: str, datapath: str = "data/") -> List[pd.DataFrame]:
14+
path_zip = os.path.join(datapath)
15+
if not os.path.exists(path_zip + ".zip"):
16+
if not os.path.exists(datapath):
17+
os.mkdir(datapath)
18+
request.urlretrieve(urllink + zipname + ".zip", path_zip + ".zip")
19+
20+
with zipfile.ZipFile(path_zip + ".zip", "r") as zip_ref:
21+
zip_ref.extractall(path_zip)
22+
data_folder = os.listdir(path_zip)
23+
subfolder = os.path.join(path_zip, data_folder[0])
24+
data_files = os.listdir(subfolder)
25+
list_df = [pd.read_csv(os.path.join(subfolder, file)) for file in data_files]
26+
return list_df
27+
28+
1429
def get_data(
1530
name_data: str = "Beijing", datapath: str = "data/", download: Optional[bool] = True
1631
) -> pd.DataFrame:
@@ -32,19 +47,7 @@ def get_data(
3247
if name_data == "Beijing":
3348
urllink = "https://archive.ics.uci.edu/ml/machine-learning-databases/00501/"
3449
zipname = "PRSA2017_Data_20130301-20170228"
35-
path_zip = os.path.join(datapath, zipname)
36-
37-
if not os.path.exists(path_zip + ".zip"):
38-
if not os.path.exists(datapath):
39-
os.mkdir(datapath)
40-
request.urlretrieve(urllink + zipname + ".zip", path_zip + ".zip")
41-
42-
with zipfile.ZipFile(path_zip + ".zip", "r") as zip_ref:
43-
zip_ref.extractall(path_zip)
44-
data_folder = os.listdir(path_zip)
45-
subfolder = os.path.join(path_zip, data_folder[0])
46-
data_files = os.listdir(subfolder)
47-
list_df = [pd.read_csv(os.path.join(subfolder, file)) for file in data_files]
50+
list_df = download_data(zipname, urllink, datapath=datapath)
4851
list_df = [preprocess_data(df) for df in list_df]
4952
df = pd.concat(list_df)
5053
return df

tests/benchmark/test_cross_validation.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from typing import Dict, Union
2+
23
import numpy as np
34
import pandas as pd
45
import pytest
56

67
from qolmat.benchmark import cross_validation
7-
from qolmat.imputations.imputers import ImputerRPCA
88
from qolmat.benchmark.missing_patterns import EmpiricalHoleGenerator
9+
from qolmat.imputations.imputers import ImputerRPCA
910

1011
df_origin = pd.DataFrame({"col1": [0, np.nan, 2, 4, np.nan], "col2": [-1, np.nan, 0.5, 1, 1.5]})
1112
df_imputed = pd.DataFrame({"col1": [0, 1, 2, 3.5, 4], "col2": [-1.5, 0, 1.5, 2, 1.5]})
@@ -87,7 +88,6 @@ def test_benchmark_cross_validation_deflat_hyperparams(
8788
def test_benchmark_cross_validation_loss_function(
8889
df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.DataFrame
8990
) -> None:
90-
9191
cv.loss_norm = 3
9292
np.testing.assert_raises(ValueError, cv.loss_function, df1, df2, df_mask)
9393
cv.loss_norm = 2
@@ -102,17 +102,12 @@ def test_benchmark_cross_validation_loss_function(
102102
def test_benchmark_cross_validation_optimize_hyperparams(df: pd.DataFrame) -> None:
103103
result_hp = cv.optimize_hyperparams(df)
104104
result_expected = {
105-
"lam/col1": 4.799603622475375,
106-
"lam/col2": 1.5503043695984915,
105+
"lam": {
106+
"col1": 4.799603622475375,
107+
"col2": 1.5503043695984915,
108+
},
107109
"tol": 0.07796932033627668,
108110
"max_iter": 100,
109111
"norm": "L1",
110112
}
111113
assert result_hp == result_expected
112-
113-
114-
@pytest.mark.parametrize("df", [df_corrupted])
115-
def test_benchmark_cross_validation_fit_transform(df: pd.DataFrame) -> None:
116-
result_cv = cv.fit_transform(df)
117-
result_expected = pd.DataFrame({"col1": [0, 2, 2, 4, 2], "col2": [1.5, 1.5, 1.5, 1.5, 1.5]})
118-
np.testing.assert_allclose(result_cv, result_expected, atol=1e-5)

tests/utils/test_data.py

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytest
66

77
from qolmat.utils import data
8+
from pytest_mock.plugin import MockerFixture
89

910
columns = ["No", "year", "month", "day", "hour", "a", "b", "wd", "station"]
1011
df = pd.DataFrame(
@@ -28,33 +29,37 @@
2829
[[1, 2], [3, np.nan], [np.nan, 6]], columns=["a", "b"], index=index_preprocess
2930
)
3031

32+
urllink = "https://archive.ics.uci.edu/ml/machine-learning-databases/00501/"
33+
zipname = "PRSA2017_Data_20130301-20170228"
34+
35+
36+
# @pytest.mark.parametrize("zipname, urllink", [(zipname, urllink)])
37+
# def test_utils_data_download_data(zipname: str, urllink: str, mocker: MockerFixture) -> None:
38+
# mocker.patch("urllib.request.urlretrieve")
39+
# mocker.patch("zipfile.ZipFile")
40+
# list_df_result = data.download_data(zipname, urllink)
41+
3142

3243
@pytest.mark.parametrize("name_data", ["Beijing", "Artificial", "Bug"])
33-
def test_utils_data_get_data(name_data: str) -> None:
44+
def test_utils_data_get_data(name_data: str, mocker: MockerFixture) -> None:
45+
mock_download = mocker.patch("qolmat.utils.data.download_data", return_value=[df])
46+
mocker.patch("qolmat.utils.data.preprocess_data", return_value=df_preprocess)
47+
try:
48+
df_result = data.get_data(name_data=name_data)
49+
except ValueError:
50+
assert name_data not in ["Beijing", "Artificial"]
51+
np.testing.assert_raises(ValueError, data.get_data, name_data)
52+
return
53+
3454
if name_data == "Beijing":
35-
df = data.get_data(name_data=name_data)
36-
expected_columns = [
37-
"PM2.5",
38-
"PM10",
39-
"SO2",
40-
"NO2",
41-
"CO",
42-
"O3",
43-
"TEMP",
44-
"PRES",
45-
"DEWP",
46-
"RAIN",
47-
"WSPM",
48-
]
49-
assert isinstance(df, pd.DataFrame)
50-
assert df.columns.tolist() == expected_columns
55+
assert mock_download.call_count == 1
56+
pd.testing.assert_frame_equal(df_result, df_preprocess)
5157
elif name_data == "Artificial":
52-
df = data.get_data(name_data=name_data)
5358
expected_columns = ["signal", "X", "A", "E"]
54-
assert isinstance(df, pd.DataFrame)
55-
assert df.columns.tolist() == expected_columns
59+
assert isinstance(df_result, pd.DataFrame)
60+
assert df_result.columns.tolist() == expected_columns
5661
else:
57-
np.testing.assert_raises(ValueError, data.get_data, name_data)
62+
assert False
5863

5964

6065
@pytest.mark.parametrize("df", [df])
@@ -72,11 +77,15 @@ def test_utils_data_add_holes(df: pd.DataFrame) -> None:
7277

7378

7479
@pytest.mark.parametrize("name_data", ["Beijing"])
75-
def test_utils_data_get_data_corrupted(name_data: str) -> None:
80+
def test_utils_data_get_data_corrupted(name_data: str, mocker: MockerFixture) -> None:
81+
mock_download = mocker.patch("qolmat.utils.data.download_data", return_value=[df])
82+
mocker.patch("qolmat.utils.data.preprocess_data", return_value=df_preprocess)
7683
df_out = data.get_data_corrupted()
77-
size_df_out = df_out.shape
78-
n = size_df_out[0] * size_df_out[1]
79-
np.testing.assert_allclose(df_out.isna().sum().sum() / n, 0.2, atol=0.1)
84+
df_result = pd.DataFrame(
85+
[[1, 2], [np.nan, np.nan], [np.nan, 6]], columns=["a", "b"], index=index_preprocess
86+
)
87+
assert mock_download.call_count == 1
88+
pd.testing.assert_frame_equal(df_result, df_out)
8089

8190

8291
@pytest.mark.parametrize("df", [df_preprocess])

0 commit comments

Comments
 (0)