Skip to content

Commit 2931ae3

Browse files
author
Gsaes
committed
Nouveau test
1 parent 4a8dd79 commit 2931ae3

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import numpy as np
2+
import pandas as pd
3+
import pytest
4+
import skopt
5+
6+
from qolmat.benchmark import cross_validation
7+
from qolmat.imputations.imputers import ImputerRPCA
8+
from qolmat.benchmark.missing_patterns import EmpiricalHoleGenerator
9+
from qolmat.benchmark.utils import get_search_space
10+
11+
df_origin = pd.DataFrame({"col1": [0, np.nan, 2, 4, np.nan], "col2": [-1, np.nan, 0.5, 1, 1.5]})
12+
df_imputed = pd.DataFrame({"col1": [0, 1, 2, 3.5, 4], "col2": [-1.5, 0, 1.5, 2, 1.5]})
13+
df_mask = pd.DataFrame(
14+
{"col1": [False, False, True, True, False], "col2": [True, False, True, True, False]}
15+
)
16+
17+
df_corrupted = df_origin.copy()
18+
df_corrupted[df_mask] = np.nan
19+
20+
imputer_rpca = ImputerRPCA(max_iter=100, tau=2)
21+
generator_holes = EmpiricalHoleGenerator(n_splits=1, ratio_masked=0.5)
22+
search_params = {"rpca": {"lam": {"min": 0.1, "max": 1, "type": "Real"}}}
23+
list_spaces = get_search_space(search_params.get("rpca", {}))
24+
cv = cross_validation.CrossValidation(
25+
imputer=imputer_rpca, list_spaces=list_spaces, hole_generator=generator_holes
26+
)
27+
28+
29+
@pytest.mark.parametrize("df1", [df_origin])
30+
@pytest.mark.parametrize("df2", [df_imputed])
31+
@pytest.mark.parametrize("df_mask", [df_mask])
32+
def test_benchmark_cross_validation_loss_function(
33+
df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.DataFrame
34+
) -> None:
35+
36+
cv.loss_norm = 3
37+
np.testing.assert_raises(ValueError, cv.loss_function, df1, df2, df_mask)
38+
cv.loss_norm = 2
39+
result_cv2 = cv.loss_function(df_origin=df1, df_imputed=df2, df_mask=df_mask)
40+
np.testing.assert_allclose(result_cv2, 1.58113, atol=1e-5)
41+
cv.loss_norm = 1
42+
result_cv1 = cv.loss_function(df_origin=df1, df_imputed=df2, df_mask=df_mask)
43+
np.testing.assert_allclose(result_cv1, 3, atol=1e-5)
44+
45+
46+
@pytest.mark.parametrize("df", [df_corrupted])
47+
def test_benchmark_cross_validation_deflat_hyperparams(df: pd.DataFrame) -> None:
48+
res = skopt.gp_minimize(
49+
cv.objective(df),
50+
dimensions=cv.list_spaces,
51+
n_calls=cv.n_calls,
52+
n_initial_points=max(5, cv.n_calls // 5),
53+
random_state=42,
54+
n_jobs=cv.n_jobs,
55+
)
56+
hyperparams_flat = {space.name: val for space, val in zip(cv.list_spaces, res["x"])}
57+
result_hyperparams = cv.deflat_hyperparams(hyperparams_flat)
58+
result = result_hyperparams["lam"]
59+
np.testing.assert_allclose(result, 0.816888, atol=1e-5)
60+
61+
62+
@pytest.mark.parametrize("df", [df_corrupted])
63+
@pytest.mark.parametrize("return_hyper_params", [True, False])
64+
def test_benchmark_cross_validation_fit_transform(
65+
df: pd.DataFrame, return_hyper_params: bool
66+
) -> None:
67+
68+
if return_hyper_params:
69+
result_cv, result_hyp = cv.fit_transform(
70+
df_corrupted, return_hyper_params=return_hyper_params
71+
)
72+
np.testing.assert_allclose(result_hyp["lam"], 0.816888, atol=1e-5)
73+
else:
74+
result_cv = cv.fit_transform(df_corrupted, return_hyper_params=return_hyper_params)
75+
result = np.array(result_cv)
76+
result_expected = np.array([[0, 0], [0, 0], [0, 0], [0, 0], [0, 1.5]])
77+
np.testing.assert_allclose(result, result_expected, atol=1e-5)

0 commit comments

Comments
 (0)