22
33import numpy as np
44import pandas as pd
5+ import pytest
56
67from qolmat .benchmark .comparator import Comparator
8+ from qolmat .benchmark .missing_patterns import UniformHoleGenerator
9+ from qolmat .imputations .imputers import ImputerShuffle
710
8- generator_holes_mock = MagicMock ()
9- generator_holes_mock .split .return_value = [
10- pd .DataFrame ({"A" : [False , False , True ], "B" : [True , False , False ]})
11- ]
1211
13- comparator = Comparator (
14- dict_models = {},
15- selected_columns = ["A" , "B" ],
16- generator_holes = generator_holes_mock ,
17- metrics = ["mae" , "mse" ],
18- )
@pytest.fixture
def comparator_fix():
    """Provide a ``Comparator`` backed by a mocked hole generator.

    The mocked generator's ``split`` returns a single fixed boolean mask
    over columns "A" and "B", and its ``random_state`` attribute is set
    to 0. The comparator is created with no models and the "mae" and
    "mse" metrics.
    """
    mask = pd.DataFrame(
        {"A": [False, False, True], "B": [True, False, False]}
    )
    holes_mock = MagicMock()
    holes_mock.split.return_value = [mask]
    holes_mock.random_state = 0
    return Comparator(
        dict_models={},
        selected_columns=["A", "B"],
        generator_holes=holes_mock,
        metrics=["mae", "mse"],
    )
26+
1927
2028imputer_mock = MagicMock ()
2129expected_get_errors = pd .Series (
2735
2836
2937@patch ("qolmat.benchmark.metrics.get_metric" )
30- def test_get_errors (mock_get_metric ):
38+ def test_get_errors (mock_get_metric , comparator_fix ):
3139 df_origin = pd .DataFrame ({"A" : [1 , np .nan , 3 ], "B" : [np .nan , 5 , 6 ]})
3240 df_imputed = pd .DataFrame ({"A" : [1 , 2 , 4 ], "B" : [4 , 5 , 7 ]})
3341 df_mask = pd .DataFrame (
@@ -39,7 +47,7 @@ def test_get_errors(mock_get_metric):
3947 [1.0 , 1.0 ], index = ["A" , "B" ]
4048 )
4149 )
42- errors = comparator .get_errors (df_origin , df_imputed , df_mask )
50+ errors = comparator_fix .get_errors (df_origin , df_imputed , df_mask )
4351 pd .testing .assert_series_equal (errors , expected_get_errors )
4452
4553
@@ -48,8 +56,10 @@ def test_get_errors(mock_get_metric):
4856 "qolmat.benchmark.comparator.Comparator.get_errors" ,
4957 return_value = expected_get_errors ,
5058)
51- def test_evaluate_errors_sample (mock_get_errors , mock_optimize ):
52- errors_mean = comparator .evaluate_errors_sample (
59+ def test_evaluate_errors_sample (
60+ mock_get_errors , mock_optimize , comparator_fix
61+ ):
62+ errors_mean = comparator_fix .evaluate_errors_sample (
5363 imputer_mock , pd .DataFrame ({"A" : [1 , 2 , 3 ], "B" : [4 , 5 , np .nan ]})
5464 )
5565 expected_errors_mean = expected_get_errors
@@ -62,12 +72,12 @@ def test_evaluate_errors_sample(mock_get_errors, mock_optimize):
6272 "qolmat.benchmark.comparator.Comparator.evaluate_errors_sample" ,
6373 return_value = expected_get_errors ,
6474)
65- def test_compare (mock_evaluate_errors_sample ):
75+ def test_compare (mock_evaluate_errors_sample , comparator_fix ):
6676 df_test = pd .DataFrame ({"A" : [1 , 2 , 3 ], "B" : [4 , 5 , 6 ]})
6777
6878 imputer1 = MagicMock (name = "Imputer1" )
6979 imputer2 = MagicMock (name = "Imputer2" )
70- comparator .dict_imputers = {"imputer1" : imputer1 , "imputer2" : imputer2 }
80+ comparator_fix .dict_imputers = {"imputer1" : imputer1 , "imputer2" : imputer2 }
7181
7282 errors_imputer1 = pd .Series ([0.1 , 0.2 ], index = ["mae" , "mse" ])
7383 errors_imputer2 = pd .Series ([0.3 , 0.4 ], index = ["mae" , "mse" ])
@@ -76,7 +86,7 @@ def test_compare(mock_evaluate_errors_sample):
7686 errors_imputer2 ,
7787 ]
7888
79- df_errors = comparator .compare (df_test )
89+ df_errors = comparator_fix .compare (df_test )
8090 assert mock_evaluate_errors_sample .call_count == 2
8191
8292 mock_evaluate_errors_sample .assert_any_call (imputer1 , df_test , {}, "mse" )
@@ -85,3 +95,28 @@ def test_compare(mock_evaluate_errors_sample):
8595 {"imputer1" : [0.1 , 0.2 ], "imputer2" : [0.3 , 0.4 ]}, index = ["mae" , "mse" ]
8696 )
8797 pd .testing .assert_frame_equal (df_errors , expected_df_errors )
98+
99+
def test_compare_reproducibility():
    """Check that identically seeded imputers produce identical errors.

    Two ``ImputerShuffle`` instances sharing the same ``random_state`` are
    compared through ``Comparator.compare`` on the same data; their error
    columns must be equal.
    """
    seed = 123
    dict_models = {
        "shuffle1": ImputerShuffle(random_state=seed),
        "shuffle2": ImputerShuffle(random_state=seed),
    }
    cols = ["A", "B"]
    # Draw the input from a seeded generator rather than the global NumPy
    # state, so a failing run can be replayed on exactly the same data.
    rng = np.random.default_rng(seed)
    df_data = pd.DataFrame(rng.random((100, 2)), dtype=float, columns=cols)
    generator_holes = UniformHoleGenerator(
        n_splits=2, subset=cols, ratio_masked=0.5
    )
    comparator = Comparator(
        dict_models=dict_models,
        selected_columns=df_data.columns,
        generator_holes=generator_holes,
        metrics=["mae", "mse"],
    )
    df_errors = comparator.compare(df_data)
    # The column names differ ("shuffle1" vs "shuffle2"), so only the
    # values and index are compared.
    pd.testing.assert_series_equal(
        df_errors["shuffle1"], df_errors["shuffle2"], check_names=False
    )
0 commit comments