@@ -52,7 +52,7 @@ def get_errors(
5252 df_origin : pd .DataFrame ,
5353 df_imputed : pd .DataFrame ,
5454 df_mask : pd .DataFrame ,
55- ) -> pd .Series :
55+ ) -> pd .DataFrame :
5656 """Functions evaluating the reconstruction's quality
5757
5858 Parameters
@@ -64,15 +64,15 @@ def get_errors(
6464
6565 Returns
6666 -------
67- dictionary
68- dictionay of results obtained via different metrics
67+ pd.DataFrame
68+ DataFrame of results obtained via different metrics
6969 """
7070 dict_errors = {}
7171 for name_metric in self .metrics :
7272 fun_metric = metrics .get_metric (name_metric )
7373 dict_errors [name_metric ] = fun_metric (df_origin , df_imputed , df_mask )
74- errors = pd .concat (dict_errors .values (), keys = dict_errors .keys ())
75- return errors
74+ df_errors = pd .concat (dict_errors .values (), keys = dict_errors .keys ())
75+ return df_errors
7676
7777 def evaluate_errors_sample (
7878 self ,
@@ -96,8 +96,8 @@ def evaluate_errors_sample(
9696
9797 Returns
9898 -------
99- pd.DataFrame
100- DataFrame with the errors for each metric (in column) and at each fold (in index)
99+ pd.Series
100+ Series with the errors for each metric and each variable
101101 """
102102 list_errors = []
103103 df_origin = df [self .selected_columns ].copy ()
@@ -115,8 +115,12 @@ def evaluate_errors_sample(
115115 )
116116 df_imputed = imputer_opti .fit_transform (df_corrupted )
117117 subset = self .generator_holes .subset
118- errors = self .get_errors (df_origin [subset ], df_imputed [subset ], df_mask [subset ])
119- list_errors .append (errors )
118+ if subset is None :
119+ raise ValueError (
120+ "HoleGenerator `subset` should be overwritten in split but it is none!"
121+ )
122+ df_errors = self .get_errors (df_origin [subset ], df_imputed [subset ], df_mask [subset ])
123+ list_errors .append (df_errors )
120124 df_errors = pd .DataFrame (list_errors )
121125 errors_mean = df_errors .mean (axis = 0 )
122126
@@ -136,7 +140,8 @@ def compare(
136140 Returns
137141 -------
138142 pd.DataFrame
139- dataframe with imputation
143+ Dataframe with the metrics results, imputers are in columns and indices represent
144+ metrics and variables.
140145 """
141146
142147 dict_errors = {}
0 commit comments