@@ -51,7 +51,6 @@ from sklearn.ensemble import RandomForestRegressor, ExtraTreesRegressor, HistGra
 
 import sys
 from qolmat.benchmark import comparator, missing_patterns, hyperparameters
-from qolmat.benchmark.metrics import kl_divergence
 from qolmat.imputations import imputers
 from qolmat.utils import data, utils, plot
 
@@ -153,7 +152,7 @@ imputer_regressor = imputers.ImputerRegressor(groups=("station",), estimator=Lin
 ```
 
 ``` python
-generator_holes = missing_patterns.EmpiricalHoleGenerator(n_splits=2, groups=("station",), subset=cols_to_impute, ratio_masked=ratio_masked)
+generator_holes = missing_patterns.EmpiricalHoleGenerator(n_splits=1, groups=("station",), subset=cols_to_impute, ratio_masked=ratio_masked)
 ```
 
 ``` python
@@ -163,13 +162,13 @@ dict_imputers = {
     # "mode": imputer_mode,
     "interpolation": imputer_interpol,
     # "spline": imputer_spline,
-    "shuffle": imputer_shuffle,
-    # "residuals": imputer_residuals,
+    # "shuffle": imputer_shuffle,
+    "residuals": imputer_residuals,
     # "OU": imputer_ou,
     "TSOU": imputer_tsou,
     "TSMLE": imputer_tsmle,
-    "RPCA": imputer_rpca,
-    "RPCA_opti": imputer_rpca_opti,
+    # "RPCA": imputer_rpca,
+    # "RPCA_opti": imputer_rpca,
     # "RPCA_opticw": imputer_rpca_opti2,
     # "locf": imputer_locf,
     # "nocb": imputer_nocb,
@@ -193,11 +192,13 @@ Concretely, the comparator takes as input a dataframe to impute, a proportion of
 Note these metrics compute reconstruction errors; it tells nothing about the distances between the "true" and "imputed" distributions.
 
 ``` python
+metrics = ["mae", "wmape", "KL_columnwise", "KL_forest", "ks_test", "dist_corr_pattern"]
+# metrics = ["KL_forest"]
 comparison = comparator.Comparator(
     dict_imputers,
     cols_to_impute,
     generator_holes=generator_holes,
-    metrics=["mae", "wmape", "KL_columnwise", "ks_test", "dist_corr_pattern"],
+    metrics=metrics,
     max_evals=10,
     dict_config_opti=dict_config_opti,
 )
@@ -206,28 +207,13 @@ results
 ```
 
 ``` python
-df_plot = results.loc["KL_columnwise", 'TEMP']
-plt.barh(df_plot.index, df_plot, color=tab10(0))
-plt.title('TEMP')
-plt.xlabel("KL")
-plt.show()
-
-df_plot = results.loc["KL_columnwise", 'PRES']
-plt.barh(df_plot.index, df_plot, color=tab10(0))
-plt.title('PRES')
-plt.xlabel("KL")
-plt.show()
-```
-
-``` python
-fig = plt.figure(figsize=(24, 8))
-fig.add_subplot(2, 1, 1)
-plot.multibar(results.loc["mae"], decimals=1)
-plt.ylabel("mae")
-
-fig.add_subplot(2, 1, 2)
-plot.multibar(results.loc["dist_corr_pattern"], decimals=2)
-plt.ylabel("dist_corr_pattern")
+n_metrics = len(metrics)
+fig = plt.figure(figsize=(24, 4 * n_metrics))
+for i, metric in enumerate(metrics):
+    fig.add_subplot(n_metrics, 1, i + 1)
+    df = results.loc[metric]
+    plot.multibar(df, decimals=2)
+    plt.ylabel(metric)
 
 plt.savefig("figures/imputations_benchmark_errors.png")
 plt.show()