1- from typing import Dict , List , Optional
1+ import logging
2+ from typing import Dict , List , Optional , Union
23
34import numpy as np
45import pandas as pd
@@ -22,31 +23,29 @@ class Comparator:
2223 search_params: Optional[Dict[str, Dict[str, Union[str, float, int]]]] = {}
2324 dictionary of search space for each implementation method. By default, the value is set to
2425 {}.
25- n_cv_calls : Optional[int] = 10
26- number of calls of the hyperparameters cross-validation. By default, the value is set to
26+ n_calls_opt : Optional[int] = 10
27+ number of calls of the optimization algorithm
2728 10.
2829 """
2930
3031 def __init__ (
3132 self ,
32- dict_models : Dict ,
33+ dict_models : Dict [ str , any ] ,
3334 selected_columns : List [str ],
3435 generator_holes : _HoleGenerator ,
35- columnwise_evaluation : Optional [bool ] = True ,
36- search_params : Optional [Dict ] = {},
37- n_cv_calls : Optional [int ] = 10 ,
36+ search_params : Optional [Dict [str , Dict [str , Union [float , int , str ]]]] = {},
37+ n_calls_opt : Optional [int ] = 10 ,
3838 ):
3939
4040 self .dict_models = dict_models
4141 self .selected_columns = selected_columns
4242 self .generator_holes = generator_holes
43- self .columnwise_evaluation = columnwise_evaluation
4443 self .search_params = search_params
45- self .n_cv_calls = n_cv_calls
44+ self .n_calls_opt = n_calls_opt
4645
4746 def get_errors (
4847 self , df_origin : pd .DataFrame , df_imputed : pd .DataFrame , df_mask : pd .DataFrame
49- ) -> float :
48+ ) -> pd . DataFrame :
5049 """Functions evaluating the reconstruction's quality
5150
5251 Parameters
@@ -75,27 +74,18 @@ def get_errors(
7574 df_origin [df_mask ],
7675 df_imputed [df_mask ],
7776 )
77+
7878 dict_errors ["kl" ] = utils .kl_divergence (
7979 df_origin [df_mask ],
8080 df_imputed [df_mask ],
8181 )
82- # if self.columnwise_evaluation:
83- # wd = utils.wasser_distance(
84- # df_origin,
85- # df_imputed,
86- # )
87- # if not self.columnwise_evaluation and df_origin.shape[1] > 1:
88- # frechet = utils.frechet_distance(
89- # df_origin,
90- # df_imputed,
91- # normalized=False,
92- # )
82+
9383 errors = pd .concat (dict_errors .values (), keys = dict_errors .keys ())
9484 return errors
9585
9686 def evaluate_errors_sample (
97- self , tested_model : any , df : pd .DataFrame , search_space : Optional [ dict ] = None
98- ) -> Dict :
87+ self , imputer : any , df : pd .DataFrame , list_spaces : List [ Dict ] = {}
88+ ) -> pd . Series :
9989 """Evaluate the errors in the cross-validation
10090
10191 Parameters
@@ -104,8 +94,8 @@ def evaluate_errors_sample(
10494 imputation model
10595 df : pd.DataFrame
10696 dataframe to impute
107- search_space : Optional[dict], optional
108- search space for tested_model's hyperparameters, by default None
97+ search_space : Dict
98+ search space for tested_model's hyperparameters
10999
110100 Returns
111101 -------
@@ -114,25 +104,30 @@ def evaluate_errors_sample(
114104 """
115105 list_errors = []
116106 df_origin = df [self .selected_columns ].copy ()
107+ if list_spaces :
108+ print ("Hyperparameter optimization" )
109+ print (list_spaces )
110+ else :
111+ print ("No hyperparameter optimization" )
117112 for df_mask in self .generator_holes .split (df_origin ):
118113 df_corrupted = df_origin .copy ()
119114 df_corrupted [df_mask ] = np .nan
120- if search_space is None :
121- df_imputed = tested_model .fit_transform (df_corrupted )
122- else :
115+ if list_spaces :
123116 cv = cross_validation .CrossValidation (
124- tested_model ,
125- search_space = search_space ,
117+ imputer ,
118+ list_spaces = list_spaces ,
126119 hole_generator = self .generator_holes ,
127- n_calls = self .n_cv_calls ,
120+ n_calls = self .n_calls_opt ,
128121 )
129122 df_imputed = cv .fit_transform (df_corrupted )
123+ else :
124+ df_imputed = imputer .fit_transform (df_corrupted )
130125
131126 subset = self .generator_holes .subset
132127 errors = self .get_errors (df_origin [subset ], df_imputed [subset ], df_mask [subset ])
133128 list_errors .append (errors )
134129 df_errors = pd .DataFrame (list_errors )
135- errors_mean = df_errors .mean ()
130+ errors_mean = df_errors .mean (axis = 0 )
136131
137132 return errors_mean
138133
@@ -151,13 +146,30 @@ def compare(self, df: pd.DataFrame, verbose: bool = True):
151146 """
152147
153148 dict_errors = {}
154- for name , tested_model in self .dict_models .items ():
155- if verbose :
156- print (type (tested_model ).__name__ )
157-
158- search_space = utils .get_search_space (tested_model , self .search_params )
159149
160- dict_errors [name ] = self .evaluate_errors_sample (tested_model , df , search_space )
150+ for name , imputer in self .dict_models .items ():
151+ print (f"Tested model: { type (imputer ).__name__ } " )
152+
153+ search_params = self .search_params .get (name , {})
154+
155+ # if imputer.columnwise:
156+ # if len(self.selected_columns) > 0:
157+ # search_params = {}
158+ # for col in self.selected_columns:
159+ # for key, value in self.search_params[type(imputer).__name__].items():
160+ # search_params[f"('{col}', '{key}')"] = value
161+ # else:
162+ # search_params = self.search_params[type(imputer).__name__]
163+ # else:
164+ # search_params = self.search_params[type(imputer).__name__]
165+
166+ list_spaces = utils .get_search_space (search_params )
167+
168+ try :
169+ dict_errors [name ] = self .evaluate_errors_sample (imputer , df , list_spaces )
170+ except Exception as excp :
171+ print ("Error while testing " , type (imputer ).__name__ )
172+ raise excp
161173
162174 df_errors = pd .DataFrame (dict_errors )
163175
0 commit comments