-from typing import Dict, List, Optional
+import logging
+from typing import Dict, List, Optional, Union
 
 import numpy as np
 import pandas as pd
@@ -22,33 +23,28 @@ class Comparator:
     search_params: Optional[Dict[str, Dict[str, Union[str, float, int]]]] = {}
         dictionary of search space for each implementation method. By default, the value is set to
         {}.
-    n_cv_calls : Optional[int] = 10
-        number of calls of the hyperparameters cross-validation. By default, the value is set to
+    n_calls_opt : Optional[int] = 10
+        number of calls of the optimization algorithm. By default, the value is set to
         10.
     """
 
     def __init__(
         self,
-        dict_models: Dict,
+        dict_models: Dict[str, any],
         selected_columns: List[str],
         generator_holes: _HoleGenerator,
-        columnwise_evaluation: Optional[bool] = True,
-        search_params: Optional[Dict] = {},
-        n_cv_calls: Optional[int] = 10,
+        search_params: Optional[Dict[str, Dict[str, Union[float, int, str]]]] = {},
+        n_calls_opt: Optional[int] = 10,
     ):
         self.dict_models = dict_models
         self.selected_columns = selected_columns
         self.generator_holes = generator_holes
-        self.columnwise_evaluation = columnwise_evaluation
         self.search_params = search_params
-        self.n_cv_calls = n_cv_calls
+        self.n_calls_opt = n_calls_opt
 
     def get_errors(
-        self,
-        df_origin: pd.DataFrame,
-        df_imputed: pd.DataFrame,
-        df_mask: pd.DataFrame,
-    ) -> float:
+        self, df_origin: pd.DataFrame, df_imputed: pd.DataFrame, df_mask: pd.DataFrame
+    ) -> pd.DataFrame:
         """Functions evaluating the reconstruction's quality
 
         Parameters
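For context, a minimal sketch (not part of this commit) of how the updated constructor might be called under the new signature. The stand-in imputer, the column names, and the hole generator below are hypothetical; only the parameter names and types come from the diff above.

import pandas as pd

# Hypothetical stand-in: any object exposing fit_transform(df) -> df fits dict_models.
class MeanImputer:
    def fit_transform(self, df: pd.DataFrame) -> pd.DataFrame:
        return df.fillna(df.mean())

comparator = Comparator(
    dict_models={"mean": MeanImputer()},  # keys double as lookup names in search_params
    selected_columns=["col1", "col2"],    # hypothetical column names
    generator_holes=hole_generator,       # placeholder for a _HoleGenerator instance
    search_params={},                     # empty dict -> no hyperparameter optimization
    n_calls_opt=10,                       # calls of the optimization algorithm
)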
@@ -77,30 +73,18 @@ def get_errors(
             df_origin[df_mask],
             df_imputed[df_mask],
         )
+
         dict_errors["kl"] = utils.kl_divergence(
             df_origin[df_mask],
             df_imputed[df_mask],
         )
-        # if self.columnwise_evaluation:
-        #     wd = utils.wasser_distance(
-        #         df_origin,
-        #         df_imputed,
-        #     )
-        # if not self.columnwise_evaluation and df_origin.shape[1] > 1:
-        #     frechet = utils.frechet_distance(
-        #         df_origin,
-        #         df_imputed,
-        #         normalized=False,
-        #     )
+
         errors = pd.concat(dict_errors.values(), keys=dict_errors.keys())
         return errors
 
     def evaluate_errors_sample(
-        self,
-        tested_model: any,
-        df: pd.DataFrame,
-        search_space: Optional[dict] = None,
-    ) -> Dict:
+        self, imputer: any, df: pd.DataFrame, list_spaces: List[Dict] = {}
+    ) -> pd.Series:
         """Evaluate the errors in the cross-validation
 
         Parameters
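Aside (sketch, not from this commit): the `pd.concat(dict_errors.values(), keys=dict_errors.keys())` line above stacks the per-metric results into a single object with the metric name as an outer index level. A toy illustration with made-up numbers and metric names:

import pandas as pd

# Toy per-metric, per-column scores (values are made up).
dict_errors = {
    "mae": pd.Series({"col1": 0.10, "col2": 0.20}),
    "kl": pd.Series({"col1": 0.03, "col2": 0.07}),
}

# keys= adds an outer index level, so the result is indexed by (metric, column).
errors = pd.concat(dict_errors.values(), keys=dict_errors.keys())
print(errors.loc[("kl", "col2")])  # 0.07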
@@ -109,8 +93,8 @@ def evaluate_errors_sample(
             imputation model
         df : pd.DataFrame
             dataframe to impute
-        search_space : Optional[dict], optional
-            search space for tested_model's hyperparameters, by default None
+        list_spaces : List[Dict]
+            list of search spaces for the imputer's hyperparameters
 
         Returns
         -------
@@ -119,25 +103,30 @@ def evaluate_errors_sample(
119103 """
120104 list_errors = []
121105 df_origin = df [self .selected_columns ].copy ()
106+ if list_spaces :
107+ print ("Hyperparameter optimization" )
108+ print (list_spaces )
109+ else :
110+ print ("No hyperparameter optimization" )
122111 for df_mask in self .generator_holes .split (df_origin ):
123112 df_corrupted = df_origin .copy ()
124113 df_corrupted [df_mask ] = np .nan
125- if search_space is None :
126- df_imputed = tested_model .fit_transform (df_corrupted )
127- else :
114+ if list_spaces :
128115 cv = cross_validation .CrossValidation (
129- tested_model ,
130- search_space = search_space ,
116+ imputer ,
117+ list_spaces = list_spaces ,
131118 hole_generator = self .generator_holes ,
132- n_calls = self .n_cv_calls ,
119+ n_calls = self .n_calls_opt ,
133120 )
134121 df_imputed = cv .fit_transform (df_corrupted )
122+ else :
123+ df_imputed = imputer .fit_transform (df_corrupted )
135124
136125 subset = self .generator_holes .subset
137126 errors = self .get_errors (df_origin [subset ], df_imputed [subset ], df_mask [subset ])
138127 list_errors .append (errors )
139128 df_errors = pd .DataFrame (list_errors )
140- errors_mean = df_errors .mean ()
129+ errors_mean = df_errors .mean (axis = 0 )
141130
142131 return errors_mean
143132
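The loop above follows a mask, corrupt, impute, score-on-masked-cells, average-over-splits pattern. A self-contained sketch of that same pattern (not the library's classes), using a random mask in place of the hole generator and a column-mean stand-in imputer:

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
df_origin = pd.DataFrame({"col1": rng.normal(size=100), "col2": rng.normal(size=100)})

list_errors = []
for _ in range(4):  # stands in for generator_holes.split(df_origin)
    # Boolean mask of artificially created holes (~20% of cells).
    df_mask = pd.DataFrame(rng.random(df_origin.shape) < 0.2,
                           index=df_origin.index, columns=df_origin.columns)
    df_corrupted = df_origin.copy()
    df_corrupted[df_mask] = np.nan

    # Stand-in imputer: column means (the real code calls imputer.fit_transform).
    df_imputed = df_corrupted.fillna(df_corrupted.mean())

    # Score only the cells that were masked, as get_errors does.
    mae = (df_origin[df_mask] - df_imputed[df_mask]).abs().mean()
    list_errors.append(mae)

errors_mean = pd.DataFrame(list_errors).mean(axis=0)  # average over the splits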
@@ -156,13 +145,30 @@ def compare(self, df: pd.DataFrame, verbose: bool = True):
156145 """
157146
158147 dict_errors = {}
159- for name , tested_model in self .dict_models .items ():
160- if verbose :
161- print (type (tested_model ).__name__ )
162-
163- search_space = utils .get_search_space (tested_model , self .search_params )
164148
165- dict_errors [name ] = self .evaluate_errors_sample (tested_model , df , search_space )
149+ for name , imputer in self .dict_models .items ():
150+ print (f"Tested model: { type (imputer ).__name__ } " )
151+
152+ search_params = self .search_params .get (name , {})
153+
154+ # if imputer.columnwise:
155+ # if len(self.selected_columns) > 0:
156+ # search_params = {}
157+ # for col in self.selected_columns:
158+ # for key, value in self.search_params[type(imputer).__name__].items():
159+ # search_params[f"('{col}', '{key}')"] = value
160+ # else:
161+ # search_params = self.search_params[type(imputer).__name__]
162+ # else:
163+ # search_params = self.search_params[type(imputer).__name__]
164+
165+ list_spaces = utils .get_search_space (search_params )
166+
167+ try :
168+ dict_errors [name ] = self .evaluate_errors_sample (imputer , df , list_spaces )
169+ except Exception as excp :
170+ print ("Error while testing " , type (imputer ).__name__ )
171+ raise excp
166172
167173 df_errors = pd .DataFrame (dict_errors )
168174
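Usage sketch with hypothetical names: assuming `compare` returns the `df_errors` frame built above (the return statement is outside this diff), calling it on the comparator instance sketched earlier would look like this.

# `comparator` and `df` are the hypothetical objects from the earlier sketch;
# df must contain the selected columns.
df_errors = comparator.compare(df)

# One column per key of dict_models; each cell is the corresponding error
# metric averaged over the generated hole masks (see evaluate_errors_sample).
print(df_errors)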