Skip to content

Commit 7baa379

Browse files
author
Gsaes
committed
Modification benchmark
1 parent 2652b8a commit 7baa379

File tree

1 file changed

+11
-53
lines changed

1 file changed

+11
-53
lines changed

examples/benchmark.md

Lines changed: 11 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -64,30 +64,12 @@ The dataset `Beijing` is the Beijing Multi-Site Air-Quality Data Set. It consist
6464
This dataset only contains numerical variables.
6565

6666
```python
67-
<<<<<<< HEAD
6867
df_data = data.get_data_corrupted("Beijing", ratio_masked=.2, mean_size=120)
69-
=======
70-
df_data = data.get_data_corrupted("Beijing_offline", ratio_masked=.2, mean_size=20)
71-
df_data = df_data.iloc[:256]
72-
73-
# cols_to_impute = ["TEMP", "PRES", "DEWP", "NO2", "CO", "O3", "WSPM"]
74-
# cols_to_impute = df_data.columns[df_data.isna().any()]
75-
>>>>>>> b88d3e213be41cbb3a3291cec94ff2aa312f48ed
7668
cols_to_impute = ["TEMP", "PRES"]
7769
```
7870

7971
The dataset `Artificial` is designed as the sum of a periodic signal, white noise and some outliers.
8072

81-
```python
82-
# df_data = data.get_data_corrupted("Artificial", ratio_masked=.2, mean_size=10)
83-
# cols_to_impute = ["signal"]
84-
```
85-
86-
```python
87-
# df_data = data.get_data("SNCF", n_groups_max=2)
88-
# cols_to_impute = ["val_in"]
89-
```
90-
9173
```python
9274
df_data
9375
```
@@ -186,15 +168,9 @@ dict_imputers = {
186168
# "spline": imputer_spline,
187169
"shuffle": imputer_shuffle,
188170
# "residuals": imputer_residuals,
189-
<<<<<<< HEAD
190-
"OU": imputer_ou,
191-
"TSOU": imputer_tsou,
192-
# "TSMLE": imputer_tsmle,
193-
=======
194171
# "OU": imputer_ou,
195172
"TSOU": imputer_tsou,
196173
"TSMLE": imputer_tsmle,
197-
>>>>>>> b88d3e213be41cbb3a3291cec94ff2aa312f48ed
198174
"RPCA": imputer_rpca,
199175
"RPCA_opti": imputer_rpca_opti,
200176
# "locf": imputer_locf,
@@ -223,23 +199,14 @@ comparison = comparator.Comparator(
223199
dict_imputers,
224200
cols_to_impute,
225201
generator_holes = generator_holes,
226-
<<<<<<< HEAD
227202
metrics=["mae", "wmape", "KL_columnwise", "ks_test"],
228-
n_calls_opt=10,
229-
=======
230-
metrics=["mae", "wmape", "KL_columnwise", "ks_test", "energy"],
231203
max_evals=10,
232-
>>>>>>> b88d3e213be41cbb3a3291cec94ff2aa312f48ed
233204
dict_config_opti=dict_config_opti,
234205
)
235206
results = comparison.compare(df_data)
236207
results
237208
```
238209

239-
```python
240-
results.loc["KL_columnwise"]
241-
```
242-
243210
```python
244211
df_plot = results.loc["KL_columnwise",'TEMP']
245212
plt.barh(df_plot.index, df_plot, color=tab10(0))
@@ -308,15 +275,19 @@ for col in cols_to_impute:
308275

309276
```
310277

278+
```python
279+
dfs_imputed
280+
```
281+
311282
```python
312283
# plot.plot_imputations(df_station, dfs_imputed_station)
313284

314-
n_columns = len(df_plot.columns)
285+
n_columns = len(cols_to_impute)
315286
n_imputers = len(dict_imputers)
316287

317288
fig = plt.figure(figsize=(12 * n_imputers, 4 * n_columns))
318289
i_plot = 1
319-
for i_col, col in enumerate(df_plot):
290+
for i_col, col in enumerate(cols_to_impute):
320291
for name_imputer, df_imp in dfs_imputed_station.items():
321292

322293
fig.add_subplot(n_columns, n_imputers, i_plot)
@@ -337,7 +308,7 @@ for i_col, col in enumerate(df_plot):
337308
loc = plticker.MultipleLocator(base=2*365)
338309
ax.xaxis.set_major_locator(loc)
339310
ax.tick_params(axis='both', which='major')
340-
plt.xlim(datetime(2019, 2, 1), datetime(2019, 3, 1))
311+
plt.xlim(datetime(2010, 1, 1), datetime(2015, 3, 1))
341312
i_plot += 1
342313
plt.savefig("figures/imputations_benchmark.png")
343314
plt.show()
@@ -379,33 +350,20 @@ dict_imputers["MLP"] = imputer_mlp = imputers_keras.ImputerRegressorKeras(estima
379350
```
380351

381352
We can re-run the imputation model benchmark as before.
382-
383-
<<<<<<< HEAD
384353
```python jupyter={"outputs_hidden": true} tags=[]
385-
generator_holes = missing_patterns.EmpiricalHoleGenerator(n_splits=2, subset = cols_to_impute, groups=['station'], ratio_masked=ratio_masked)
386-
=======
387-
```python
388-
generator_holes = missing_patterns.EmpiricalHoleGenerator(n_splits=2, subset = cols_to_impute, ratio_masked=ratio_masked)
389-
>>>>>>> b88d3e213be41cbb3a3291cec94ff2aa312f48ed
354+
generator_holes = missing_patterns.EmpiricalHoleGenerator(n_splits=2, groups=["station"], subset=cols_to_impute, ratio_masked=ratio_masked)
390355

391356
comparison = comparator.Comparator(
392357
dict_imputers,
393-
df_data.columns,
358+
cols_to_impute,
394359
generator_holes = generator_holes,
395-
n_calls_opt=10,
360+
metrics=["mae", "wmape", "KL_columnwise", "ks_test"],
361+
max_evals=10,
396362
dict_config_opti=dict_config_opti,
397363
)
398364
results = comparison.compare(df_data)
399365
results
400366
```
401-
402-
```python
403-
fig = plt.figure(figsize=(24, 4))
404-
plot.multibar(results.loc["mae"], decimals=1)
405-
plt.ylabel("mae")
406-
plt.show()
407-
```
408-
409367
```python jupyter={"outputs_hidden": true, "source_hidden": true} tags=[]
410368
df_plot = df_data
411369
dfs_imputed = {name: imp.fit_transform(df_plot) for name, imp in dict_imputers.items()}

0 commit comments

Comments
 (0)