1- import logging
21from functools import partial
3- from typing import Any , Dict , List , Optional , Union
2+ from typing import Any , Callable , Dict , List , Optional
43
54import numpy as np
65import pandas as pd
@@ -21,15 +20,15 @@ class Comparator:
2120 list of column's names selected (all with at least one null value will be imputed)
2221 columnwise_evaluation : Optional[bool], optional
2322 whether the metric should be calculated column-wise or not, by default False
24- search_params : Optional[Dict[str, Dict[str, Union[str, float, int]]]] = {}
23+ dict_config_opti : Optional[Dict[str, Dict[str, Union[str, float, int]]]] = {}
2524 dictionary of search space for each implementation method. By default, the value is set to
2625 {}.
2726 n_calls_opt: int = 10
2827 number of calls of the optimization algorithm
2928 10.
3029 """
3130
32- dict_metrics : Dict [str , Any ] = {
31+ dict_metrics : Dict [str , Callable ] = {
3332 "mse" : metrics .mean_squared_error ,
3433 "rmse" : metrics .root_mean_squared_error ,
3534 "mae" : metrics .mean_absolute_error ,
@@ -50,14 +49,14 @@ def __init__(
5049 selected_columns : List [str ],
5150 generator_holes : _HoleGenerator ,
5251 metrics : List = ["mae" , "wmape" , "KL_columnwise" ],
53- search_params : Optional [Dict [str , Dict [ str , Union [ float , int , str ]] ]] = {},
52+ dict_config_opti : Optional [Dict [str , Any ]] = {},
5453 n_calls_opt : int = 10 ,
5554 ):
5655 self .dict_imputers = dict_models
5756 self .selected_columns = selected_columns
5857 self .generator_holes = generator_holes
5958 self .metrics = metrics
60- self .search_params = search_params
59+ self .dict_config_opti = dict_config_opti
6160 self .n_calls_opt = n_calls_opt
6261
6362 def get_errors (
@@ -92,7 +91,7 @@ def evaluate_errors_sample(
9291 self ,
9392 imputer : Any ,
9493 df : pd .DataFrame ,
95- list_spaces : List [ Dict ] = [] ,
94+ dict_config_opti_imputer : Dict [ str , Any ] = {} ,
9695 ) -> pd .Series :
9796 """Evaluate the errors in the cross-validation
9897
@@ -102,7 +101,7 @@ def evaluate_errors_sample(
102101 imputation model
103102 df : pd.DataFrame
104103 dataframe to impute
105- search_space : Dict
104+ dict_config_opti_imputer : Dict
106105 search space for tested_model's hyperparameters
107106
108107 Returns
@@ -115,12 +114,10 @@ def evaluate_errors_sample(
115114 for df_mask in self .generator_holes .split (df_origin ):
116115 df_corrupted = df_origin .copy ()
117116 df_corrupted [df_mask ] = np .nan
118-
119- assert not np .logical_and (df_mask , df_origin .isna ()).any ().any ()
120- if list_spaces :
117+ if dict_config_opti_imputer :
121118 cv = cross_validation .CrossValidation (
122119 imputer ,
123- list_spaces = list_spaces ,
120+ dict_config_opti_imputer = dict_config_opti_imputer ,
124121 hole_generator = self .generator_holes ,
125122 n_calls = self .n_calls_opt ,
126123 )
@@ -155,12 +152,12 @@ def compare(
155152 dict_errors = {}
156153
157154 for name , imputer in self .dict_imputers .items ():
158- search_params = self .search_params .get (name , {})
159-
160- list_spaces = utils .get_search_space (search_params )
155+ dict_config_opti_imputer = self .dict_config_opti .get (name , {})
161156
162157 try :
163- dict_errors [name ] = self .evaluate_errors_sample (imputer , df , list_spaces )
158+ dict_errors [name ] = self .evaluate_errors_sample (
159+ imputer , df , dict_config_opti_imputer
160+ )
164161 print (f"Tested model: { type (imputer ).__name__ } " )
165162 except Exception as excp :
166163 print ("Error while testing " , type (imputer ).__name__ )
0 commit comments