@@ -116,6 +116,8 @@ ratio_masked = 0.1
116116```
117117
118118``` python
119+ dict_config_opti = {}
120+
119121imputer_mean = imputers.ImputerMean(groups = (" station" ,))
120122imputer_median = imputers.ImputerMedian(groups = (" station" ,))
121123imputer_mode = imputers.ImputerMode(groups = (" station" ,))
@@ -127,6 +129,18 @@ imputer_shuffle = imputers.ImputerShuffle(groups=("station",))
127129imputer_residuals = imputers.ImputerResiduals(groups = (" station" ,), period = 365 , model_tsa = " additive" , extrapolate_trend = " freq" , method_interpolation = " linear" )
128130
129131imputer_rpca = imputers.ImputerRPCA(groups = (" station" ,), columnwise = False , max_iterations = 500 , tau = 2 , lam = 0.05 )
132+ imputer_rpca_opti = imputers.ImputerRPCA(groups = (" station" ,), columnwise = False , max_iterations = 256 )
133+ dict_config_opti[" RPCA_opti" ] = {
134+ " tau" : ho.hp.uniform(" tau" , low = .5 , high = 5 ),
135+ " lam" : ho.hp.uniform(" lam" , low = .1 , high = 1 ),
136+ }
137+ imputer_rpca_opticw = imputers.ImputerRPCA(groups = (" station" ,), columnwise = False , max_iterations = 256 )
138+ dict_config_opti[" RPCA_opticw" ] = {
139+ " tau/TEMP" : ho.hp.uniform(" tau/TEMP" , low = .5 , high = 5 ),
140+ " tau/PRES" : ho.hp.uniform(" tau/PRES" , low = .5 , high = 5 ),
141+ " lam/TEMP" : ho.hp.uniform(" lam/TEMP" , low = .1 , high = 1 ),
142+ " lam/PRES" : ho.hp.uniform(" lam/PRES" , low = .1 , high = 1 ),
143+ }
130144
131145imputer_ou = imputers.ImputerEM(groups = (" station" ,), model = " multinormal" , method = " sample" , max_iter_em = 34 , n_iter_ou = 15 , dt = 1e-3 )
132146imputer_tsou = imputers.ImputerEM(groups = (" station" ,), model = " VAR1" , method = " sample" , max_iter_em = 34 , n_iter_ou = 15 , dt = 1e-3 )
@@ -142,40 +156,6 @@ imputer_regressor = imputers.ImputerRegressor(groups=("station",), estimator=Lin
142156generator_holes = missing_patterns.EmpiricalHoleGenerator(n_splits = 2 , groups = (" station" ,), subset = cols_to_impute, ratio_masked = ratio_masked)
143157```
144158
145- ``` python
146- dict_config_opti = {
147- " tau" : ho.hp.uniform(" tau" , low = .5 , high = 5 ),
148- " lam" : ho.hp.uniform(" lam" , low = .1 , high = 1 ),
149- }
150- imputer_rpca_opti = imputers.ImputerRPCA(groups = (" station" ,), columnwise = False , max_iterations = 256 )
151- imputer_rpca_opti = hyperparameters.optimize(
152- imputer_rpca_opti,
153- df_data,
154- generator = generator_holes,
155- metric = " mae" ,
156- max_evals = 10 ,
157- dict_spaces = dict_config_opti
158- )
159- ```
160-
161- ``` python jupyter={"source_hidden": true}
162- dict_config_opti2 = {
163- " tau/TEMP" : ho.hp.uniform(" tau/TEMP" , low = .5 , high = 5 ),
164- " tau/PRES" : ho.hp.uniform(" tau/PRES" , low = .5 , high = 5 ),
165- " lam/TEMP" : ho.hp.uniform(" lam/TEMP" , low = .1 , high = 1 ),
166- " lam/PRES" : ho.hp.uniform(" lam/PRES" , low = .1 , high = 1 ),
167- }
168- imputer_rpca_opti2 = imputers.ImputerRPCA(groups = (" station" ,), columnwise = True , max_iterations = 256 )
169- imputer_rpca_opti2 = hyperparameters.optimize(
170- imputer_rpca_opti2,
171- df_data,
172- generator = generator_holes,
173- metric = " mae" ,
174- max_evals = 10 ,
175- dict_spaces = dict_config_opti2
176- )
177- ```
178-
179159``` python
180160dict_imputers = {
181161 " mean" : imputer_mean,
@@ -190,7 +170,7 @@ dict_imputers = {
190170 " TSMLE" : imputer_tsmle,
191171 " RPCA" : imputer_rpca,
192172 " RPCA_opti" : imputer_rpca_opti,
193- # "RPCA_opti2 ": imputer_rpca_opti2,
173+ # "RPCA_opticw ": imputer_rpca_opticw,
194174 # "locf": imputer_locf,
195175 # "nocb": imputer_nocb,
196176 # "knn": imputer_knn,
@@ -225,7 +205,7 @@ results = comparison.compare(df_data)
225205results
226206```
227207
228- ``` python jupyter={"source_hidden": true}
208+ ``` python
229209df_plot = results.loc[" KL_columnwise" ,' TEMP' ]
230210plt.barh(df_plot.index, df_plot, color = tab10(0 ))
231211plt.title(' TEMP' )
@@ -246,7 +226,7 @@ plot.multibar(results.loc["mae"], decimals=1)
246226plt.ylabel(" mae" )
247227
248228fig.add_subplot(2 , 1 , 2 )
249- plot.multibar(results.loc[" dist_corr_pattern" ], decimals = 1 )
229+ plot.multibar(results.loc[" dist_corr_pattern" ], decimals = 2 )
250230plt.ylabel(" dist_corr_pattern" )
251231
252232plt.savefig(" figures/imputations_benchmark_errors.png" )
0 commit comments