
Commit 4a8dd79

Author: Gsaes
Commit message: Added tests for the Comparator class
Parent: a555d60

File tree: 2 files changed (+85 −3 lines)

qolmat/benchmark/comparator.py

Lines changed: 2 additions & 3 deletions
@@ -49,8 +49,8 @@ def __init__(
         dict_models: Dict[str, Any],
         selected_columns: List[str],
         generator_holes: _HoleGenerator,
-        metrics: List = ["mae", "wmape", "KL"],
-        search_params: Optional[Dict[str, Dict[str, Union[float, int, str]]]] = {},
+        metrics: List = ["mae", "wmape", "KL_columnwise"],
+        search_params: Optional[Dict] = {},
         n_calls_opt: int = 10,
     ):
         self.dict_imputers = dict_models
@@ -154,7 +154,6 @@ def compare(
 
         for name, imputer in self.dict_imputers.items():
             search_params = self.search_params.get(name, {})
-
             list_spaces = utils.get_search_space(search_params)
 
             try:
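
For reference, the hunk above changes the default metrics to include "KL_columnwise" and loosens the type annotation of search_params to a plain Dict. A minimal sketch of how the updated constructor is exercised, mirroring the setup in the new tests below (all parameter values are taken from those tests, not from other documentation):

    from qolmat.benchmark import comparator
    from qolmat.benchmark.missing_patterns import EmpiricalHoleGenerator
    from qolmat.imputations.imputers import ImputerRPCA

    # Hole generator and imputer configured as in tests/benchmark/test_comparator.py
    generator_holes = EmpiricalHoleGenerator(n_splits=1, ratio_masked=0.5)
    comparison = comparator.Comparator(
        dict_models={"rpca": ImputerRPCA(max_iter=100, tau=2)},
        selected_columns=["col1", "col2"],
        generator_holes=generator_holes,
        # metrics now defaults to ["mae", "wmape", "KL_columnwise"]
        search_params={"rpca": {"lam": {"min": 0.1, "max": 1, "type": "Real"}}},
    )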

tests/benchmark/test_comparator.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+import numpy as np
+import pandas as pd
+import pytest
+
+from qolmat.benchmark import comparator
+from qolmat.imputations.imputers import ImputerMedian, ImputerRPCA
+from qolmat.benchmark.missing_patterns import EmpiricalHoleGenerator
+
+df_origin = pd.DataFrame({"col1": [0, np.nan, 2, 4, np.nan], "col2": [-1, np.nan, 0.5, 1, 1.5]})
+
+df_imputed = pd.DataFrame({"col1": [0, 1, 2, 3.5, 4], "col2": [-1.5, 0, 1.5, 2, 1.5]})
+
+df_mask = pd.DataFrame(
+    {"col1": [False, False, True, True, False], "col2": [True, False, True, True, False]}
+)
+
+cols_to_impute = ["col1", "col2"]
+generator_holes = EmpiricalHoleGenerator(n_splits=1, ratio_masked=0.5)
+dict_imputers_median = {"median": ImputerMedian()}
+dict_imputers_rpca = {"rpca": ImputerRPCA(max_iter=100, tau=2)}
+search_params = {"rpca": {"lam": {"min": 0.1, "max": 1, "type": "Real"}}}
+
+comparison_median = comparator.Comparator(
+    dict_models=dict_imputers_median,
+    selected_columns=cols_to_impute,
+    generator_holes=generator_holes,
+)
+
+comparison_rpca = comparator.Comparator(
+    dict_models=dict_imputers_rpca,
+    selected_columns=cols_to_impute,
+    generator_holes=generator_holes,
+    search_params=search_params,
+)
+
+comparison_bug = comparator.Comparator(
+    dict_models=dict_imputers_median,
+    selected_columns=["bug"],
+    generator_holes=generator_holes,
+    search_params=search_params,
+)
+
+result_expected_median = [3.0, 0.5, 0.75, 0.5, 37.88948, 39.68123]
+result_expected_rpca = [4.0, 1.0, 1.0, 1.0, 37.60179, 38.98809]
+
+comparison_dict = {"median": comparison_median, "rpca": comparison_rpca, "bug": comparison_bug}
+result_expected_dict = {"median": result_expected_median, "rpca": result_expected_rpca}
+
+
+@pytest.mark.parametrize("df1", [df_origin])
+@pytest.mark.parametrize("df2", [df_imputed])
+@pytest.mark.parametrize("df_mask", [df_mask])
+def test_benchmark_comparator_get_errors(
+    df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.DataFrame
+) -> None:
+    result_comparison = comparison_median.get_errors(
+        df_origin=df1, df_imputed=df2, df_mask=df_mask
+    )
+    result = list(result_comparison.values)
+    result_expected = [0.25, 0.83333, 0.0625, 1.16666, 18.80089, 36.63671]
+    np.testing.assert_allclose(result, result_expected, atol=1e-5)
+
+
+@pytest.mark.parametrize("df1", [df_origin])
+def test_benchmark_comparator_evaluate_errors_sample(df1: pd.DataFrame) -> None:
+    result_comparison = comparison_median.evaluate_errors_sample(
+        dict_imputers_median["median"], df1
+    )
+    result = comparison_rpca.evaluate_errors_sample(dict_imputers_rpca["rpca"], df1)
+    result = list(result_comparison.values)
+    np.testing.assert_allclose(result, result_expected_median, atol=1e-5)
+
+
+@pytest.mark.parametrize("df1", [df_origin])
+@pytest.mark.parametrize("imputer", ["median", "rpca", "bug"])
+def test_benchmark_comparator_compare(df1: pd.DataFrame, imputer: str) -> None:
+    comparison = comparison_dict[imputer]
+    if imputer == "bug":
+        np.testing.assert_raises(Exception, comparison.compare, df1)
+    else:
+        result_comparison = comparison.compare(df1)
+        result = list(result_comparison.values.flatten())
+        np.testing.assert_allclose(result, result_expected_dict[imputer], atol=1e-5)
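
A note on reading the expected values: with the three default metrics ("mae", "wmape", "KL_columnwise") and the two selected columns, each expected list holds six numbers, presumably one value per metric and per column. Assuming a standard pytest setup, the new module can be run on its own with:

    pytest tests/benchmark/test_comparator.py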
