@@ -165,111 +165,3 @@ def compare(
165165 df_errors = pd .DataFrame (dict_errors )
166166
167167 return df_errors
168-
169-
class ComparatorBasedPattern(Comparator):
    """Comparator that evaluates imputers separately on each missing-data pattern.

    A "pattern" is the set of columns that are missing in a given row. For
    each of the most frequent patterns, the rows showing that pattern are
    grouped with the fully observed rows, holes are generated on that
    sub-dataframe, and the imputation errors are computed. The errors are
    then averaged with weights equal to the sub-dataframe sizes.

    Parameters
    ----------
    dict_models : Dict[str, Any]
        Forwarded unchanged to ``Comparator.__init__``.
    selected_columns : List[str]
        Columns of the input dataframe that are kept for the evaluation.
    generator_holes : _HoleGenerator
        Object used to fit on the data and to generate the hole masks
        (its ``fit``, ``split`` and ``subset`` members are used here).
    metrics : Optional[List], optional
        Metrics forwarded to ``Comparator.__init__``. ``None`` (default)
        stands for ``["mae", "wmape", "KL"]``.
    search_params : Optional[Dict], optional
        Hyperparameter search configuration forwarded to
        ``Comparator.__init__``. ``None`` (default) stands for an empty dict.
    n_calls_opt : int, optional
        Forwarded to ``Comparator.__init__``; also used as ``n_calls`` of the
        cross-validation, by default 10.
    num_patterns : int, optional
        Maximum number of missing-data patterns (most frequent first) that
        are evaluated, by default 5.
    """

    def __init__(
        self,
        dict_models: Dict[str, Any],
        selected_columns: List[str],
        generator_holes: _HoleGenerator,
        metrics: Optional[List] = None,
        search_params: Optional[Dict[str, Dict[str, Union[float, int, str]]]] = None,
        n_calls_opt: int = 10,
        num_patterns: int = 5,
    ):
        # Fresh objects are built on every call so that no mutable default is
        # shared between instances.
        if metrics is None:
            metrics = ["mae", "wmape", "KL"]
        if search_params is None:
            search_params = {}

        super().__init__(
            dict_models=dict_models,
            selected_columns=selected_columns,
            generator_holes=generator_holes,
            metrics=metrics,
            search_params=search_params,
            n_calls_opt=n_calls_opt,
        )

        self.num_patterns = num_patterns

    def evaluate_errors_sample(
        self,
        imputer: Any,
        df: pd.DataFrame,
        list_spaces: Optional[List[Dict]] = None,
    ) -> pd.Series:
        """Evaluate the imputation errors, pattern by pattern.

        Parameters
        ----------
        imputer : Any
            Imputation model exposing ``fit_transform``.
        df : pd.DataFrame
            Dataframe to impute; only ``self.selected_columns`` are used.
        list_spaces : Optional[List[Dict]], optional
            Search spaces for the imputer's hyperparameters. When empty or
            ``None`` (default), the imputer is used as is, without
            cross-validation.

        Returns
        -------
        pd.Series
            Weighted mean of the errors over all patterns and all hole masks,
            sorted by index. Each evaluation is weighted by the number of rows
            of the pattern sub-dataframe it was computed on.
        """
        list_errors = []
        weights = []
        df_origin = df[self.selected_columns].copy()
        dfs_pattern = self.get_df_based_pattern(df_origin)

        # Fit then split, or fit in split ?
        self.generator_holes.fit(df_origin)
        for df_pattern in dfs_pattern:
            # Columns that are fully observed within this pattern: errors can
            # only be measured on them.
            cols_pattern = df_pattern.dropna(axis=1).columns
            for df_mask in self.generator_holes.split(df_pattern):
                weights.append(len(df_pattern))
                df_corrupted = df_pattern.copy()
                df_corrupted[df_mask] = np.nan

                if list_spaces:
                    cv = cross_validation.CrossValidation(
                        imputer,
                        list_spaces=list_spaces,
                        hole_generator=self.generator_holes,
                        n_calls=self.n_calls_opt,
                    )
                    df_imputed = cv.fit_transform(df_corrupted)
                else:
                    df_imputed = imputer.fit_transform(df_corrupted)

                # Restrict the hole generator's columns to those observed in
                # this pattern.
                subset = [col for col in self.generator_holes.subset if col in cols_pattern]
                errors = self.get_errors(
                    df_pattern[subset], df_imputed[subset], df_mask[subset]
                )
                list_errors.append(errors)

        df_errors = pd.DataFrame(list_errors)
        # Weighted errors: one row of df_errors per (pattern, mask) pair.
        weights_array = np.array(weights)
        errors_mean = df_errors.apply(
            lambda x: (x * weights_array).sum() / weights_array.sum(), axis=0
        )
        return errors_mean.sort_index()

    def get_df_based_pattern(self, df: pd.DataFrame) -> List[pd.DataFrame]:
        """Split ``df`` into one sub-dataframe per missing-data pattern.

        Each row is labelled with the "_"-joined names of its missing columns
        ("_EMPTY_" when nothing is missing, "_ALLNAN_" when everything is).
        For each of the ``self.num_patterns`` most frequent partial patterns,
        the returned sub-dataframe holds the rows with that pattern together
        with all the fully observed rows.

        Parameters
        ----------
        df : pd.DataFrame
            Dataframe whose rows are grouped by pattern.

        Returns
        -------
        List[pd.DataFrame]
            One dataframe per selected pattern, most frequent pattern first.
            Empty when no row is partially missing.
        """

        def get_pattern(row):
            list_col_pattern = [col for col in row.index.to_list() if row[col]]
            if len(list_col_pattern) == 0:
                return "_EMPTY_"
            elif len(list_col_pattern) == row.index.size:
                return "_ALLNAN_"
            else:
                return "_".join(list_col_pattern)

        row_patterns = df.isna().apply(get_pattern, axis=1)

        # value_counts orders labels by decreasing frequency. The two special
        # labels are filtered out rather than list.remove()-d, because either
        # of them may be absent from the data.
        patterns = [
            pattern
            for pattern in row_patterns.value_counts().index.to_list()
            if pattern not in ("_ALLNAN_", "_EMPTY_")
        ]

        dfs = []
        for pattern in patterns[: max(self.num_patterns, 0)]:
            # Positional boolean mask: rows are never duplicated, even when
            # the index of df contains repeated labels.
            is_selected = row_patterns.isin(["_EMPTY_", pattern]).to_numpy()
            dfs.append(df.loc[is_selected])

        return dfs
0 commit comments