1+ import logging
12from typing import Dict , List , Optional , Union
23
34import numpy as np
45import pandas as pd
56
7+ from qolmat import logging as qlog
68from qolmat .benchmark import cross_validation , utils
79from qolmat .benchmark .missing_patterns import _HoleGenerator
810
11+ qlog .log_setup ()
12+ logger = logging .getLogger (__name__ )
13+ # logger.setLevel(logging.DEBUG)
14+
915
1016class Comparator :
1117 """
@@ -22,8 +28,8 @@ class Comparator:
2228 search_params: Optional[Dict[str, Dict[str, Union[str, float, int]]]] = {}
2329 dictionary of search space for each implementation method. By default, the value is set to
2430 {}.
25- n_cv_calls : Optional[int] = 10
26- number of calls of the hyperparameters cross-validation. By default, the value is set to
31+ n_calls_opt : Optional[int] = 10
32+ number of calls of the optimization algorithm. By default, the value is set to
2733 10.
2834 """
2935
@@ -33,18 +39,18 @@ def __init__(
3339 selected_columns : List [str ],
3440 generator_holes : _HoleGenerator ,
3541 search_params : Optional [Dict [str , Dict [str , Union [float , int , str ]]]] = {},
36- n_cv_calls : Optional [int ] = 10 ,
42+ n_calls_opt : Optional [int ] = 10 ,
3743 ):
3844
3945 self .dict_models = dict_models
4046 self .selected_columns = selected_columns
4147 self .generator_holes = generator_holes
4248 self .search_params = search_params
43- self .n_cv_calls = n_cv_calls
49+ self .n_calls_opt = n_calls_opt
4450
4551 def get_errors (
4652 self , df_origin : pd .DataFrame , df_imputed : pd .DataFrame , df_mask : pd .DataFrame
47- ) -> float :
53+ ) -> pd . DataFrame :
4854 """Functions evaluating the reconstruction's quality
4955
5056 Parameters
@@ -73,6 +79,7 @@ def get_errors(
7379 df_origin [df_mask ],
7480 df_imputed [df_mask ],
7581 )
82+
7683 dict_errors ["kl" ] = utils .kl_divergence (
7784 df_origin [df_mask ],
7885 df_imputed [df_mask ],
@@ -82,8 +89,8 @@ def get_errors(
8289 return errors
8390
8491 def evaluate_errors_sample (
85- self , tested_model : any , df : pd .DataFrame , search_space : Optional [ dict ] = None
86- ) -> Dict :
92+ self , imputer : any , df : pd .DataFrame , list_spaces : List [ Dict ] = {}
93+ ) -> pd . Series :
8794 """Evaluate the errors in the cross-validation
8895
8996 Parameters
@@ -92,8 +99,8 @@ def evaluate_errors_sample(
9299 imputation model
93100 df : pd.DataFrame
94101 dataframe to impute
95- search_space : Optional[dict], optional
96- search space for tested_model's hyperparameters, by default None
102+ list_spaces : List[Dict]
103+ search spaces for the imputer's hyperparameters, by default {}
97104
98105 Returns
99106 -------
@@ -102,19 +109,24 @@ def evaluate_errors_sample(
102109 """
103110 list_errors = []
104111 df_origin = df [self .selected_columns ].copy ()
112+ if list_spaces :
113+ print ("Hyperparameter optimization" )
114+ print (list_spaces )
115+ else :
116+ print ("No hyperparameter optimization" )
105117 for df_mask in self .generator_holes .split (df_origin ):
106118 df_corrupted = df_origin .copy ()
107119 df_corrupted [df_mask ] = np .nan
108- if search_space is None :
109- df_imputed = tested_model .fit_transform (X = df_corrupted )
110- else :
120+ if list_spaces :
111121 cv = cross_validation .CrossValidation (
112- tested_model ,
113- search_space = search_space ,
122+ imputer ,
123+ list_spaces = list_spaces ,
114124 hole_generator = self .generator_holes ,
115- n_calls = self .n_cv_calls ,
125+ n_calls = self .n_calls_opt ,
116126 )
117- df_imputed = cv .fit_transform (X = df_corrupted )
127+ df_imputed = cv .fit_transform (df_corrupted )
128+ else :
129+ df_imputed = imputer .fit_transform (df_corrupted )
118130
119131 subset = self .generator_holes .subset
120132 errors = self .get_errors (df_origin [subset ], df_imputed [subset ], df_mask [subset ])
@@ -140,30 +152,29 @@ def compare(self, df: pd.DataFrame, verbose: bool = True):
140152
141153 dict_errors = {}
142154
143- for name , tested_model in self .dict_models .items ():
144- if verbose :
145- print ("Tested model:" , type (tested_model ).__name__ )
155+ for name , imputer in self .dict_models .items ():
156+ logger .setLevel (logging .DEBUG )
157+ print (f"Tested model: { type (imputer ).__name__ } " )
158+
159+ search_params = self .search_params .get (name , {})
146160
147- if str (type (tested_model ).__name__ ) in self .search_params .keys ():
148- if hasattr (tested_model , "columnwise" ) and tested_model .columnwise :
149- if len (self .selected_columns ) > 0 :
150- search_params = {}
151- for col in self .selected_columns :
152- for key , value in self .search_params [type (tested_model ).__name__ ].items ():
153- search_params [f"('{ col } ', '{ key } ')" ] = value
154- else :
155- search_params = self .search_params [type (tested_model ).__name__ ]
156- else :
157- search_params = self .search_params [type (tested_model ).__name__ ]
158-
159- search_space = utils .get_search_space (tested_model , search_params )
161+ # if imputer.columnwise:
162+ # if len(self.selected_columns) > 0:
163+ # search_params = {}
164+ # for col in self.selected_columns:
165+ # for key, value in self.search_params[type(imputer).__name__].items():
166+ # search_params[f"('{col}', '{key}')"] = value
167+ # else:
168+ # search_params = self.search_params[type(imputer).__name__]
169+ # else:
170+ # search_params = self.search_params[type(imputer).__name__]
171+
172+ list_spaces = utils .get_search_space (search_params )
160173
161- else :
162- search_space = None
163174 try :
164- dict_errors [name ] = self .evaluate_errors_sample (tested_model , df , search_space )
175+ dict_errors [name ] = self .evaluate_errors_sample (imputer , df , list_spaces )
165176 except Exception as excp :
166- print ("Error while testing " , type (tested_model ).__name__ )
177+ print ("Error while testing " , type (imputer ).__name__ )
167178 raise excp
168179
169180 df_errors = pd .DataFrame (dict_errors )
0 commit comments