Skip to content

Commit 48a9d8b

Browse files
committed
fix: benchmark.md, comparator, imputers
1 parent c301382 commit 48a9d8b

File tree

4 files changed

+86
-79
lines changed

4 files changed

+86
-79
lines changed

examples/benchmark.md

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ jupyter:
1010
kernelspec:
1111
display_name: env_qolmat_dev
1212
language: python
13-
name: python3
13+
name: env_qolmat_dev
1414
---
1515

1616
**This notebook aims to present the Qolmat repo through an example of a multivariate time series.
@@ -121,27 +121,27 @@ Some methods require hyperparameters. The user can directly specify them, or rat
121121
In practice, we rely on cross-validation to find the hyperparameter values that minimize the reconstruction error.
122122

123123
```python
124-
imputer_mean = imputers.ImputerMean()
125-
imputer_median = imputers.ImputerMedian()
126-
imputer_mode = imputers.ImputerMode()
127-
imputer_locf = imputers.ImputerLOCF()
128-
imputer_nocb = imputers.ImputerNOCB()
129-
imputer_interpol = imputers.ImputerInterpolation(method="linear")
130-
imputer_spline = imputers.ImputerInterpolation(method="spline", order=2)
131-
imputer_shuffle = imputers.ImputerShuffle()
132-
imputer_residuals = imputers.ImputerResiduals(period=7, model_tsa="additive", extrapolate_trend="freq", method_interpolation="linear")
124+
imputer_mean = imputers.ImputerMean(groups=["station"])
125+
imputer_median = imputers.ImputerMedian(groups=["station"])
126+
imputer_mode = imputers.ImputerMode(groups=["station"])
127+
imputer_locf = imputers.ImputerLOCF(groups=["station"])
128+
imputer_nocb = imputers.ImputerNOCB(groups=["station"])
129+
imputer_interpol = imputers.ImputerInterpolation(groups=["station"], method="linear")
130+
imputer_spline = imputers.ImputerInterpolation(groups=["station"], method="spline", order=2)
131+
imputer_shuffle = imputers.ImputerShuffle(groups=["station"])
132+
imputer_residuals = imputers.ImputerResiduals(groups=["station"], period=7, model_tsa="additive", extrapolate_trend="freq", method_interpolation="linear")
133133

134-
imputer_rpca = imputers.ImputerRPCA(columnwise=True, period=7, max_iter=200, tau=2, lam=.3)
135-
imputer_rpca_opti = imputers.ImputerRPCA(columnwise=True, period=7, max_iter=100)
134+
imputer_rpca = imputers.ImputerRPCA(groups=["station"], columnwise=True, period=7, max_iter=200, tau=2, lam=.3)
135+
imputer_rpca_opti = imputers.ImputerRPCA(groups=["station"], columnwise=True, period=7, max_iter=100)
136136

137-
imputer_ou = imputers.ImputerEM(model="multinormal", method="sample", max_iter_em=34, n_iter_ou=15, dt=1e-3)
138-
imputer_tsou = imputers.ImputerEM(model="VAR1", method="sample", max_iter_em=34, n_iter_ou=15, dt=1e-3)
139-
imputer_tsmle = imputers.ImputerEM(model="VAR1", method="mle", max_iter_em=34, n_iter_ou=15, dt=1e-3)
137+
imputer_ou = imputers.ImputerEM(groups=["station"], model="multinormal", method="sample", max_iter_em=34, n_iter_ou=15, dt=1e-3)
138+
imputer_tsou = imputers.ImputerEM(groups=["station"], model="VAR1", method="sample", max_iter_em=34, n_iter_ou=15, dt=1e-3)
139+
imputer_tsmle = imputers.ImputerEM(groups=["station"], model="VAR1", method="mle", max_iter_em=34, n_iter_ou=15, dt=1e-3)
140140

141141

142-
imputer_knn = imputers.ImputerKNN(k=10)
143-
imputer_mice = imputers.ImputerMICE(estimator=LinearRegression(), sample_posterior=False, max_iter=100, missing_values=np.nan)
144-
imputer_regressor = imputers.ImputerRegressor(estimator=LinearRegression())
142+
imputer_knn = imputers.ImputerKNN(groups=["station"], k=10)
143+
imputer_mice = imputers.ImputerMICE(groups=["station"], estimator=LinearRegression(), sample_posterior=False, max_iter=100, missing_values=np.nan)
144+
imputer_regressor = imputers.ImputerRegressor(groups=["station"], estimator=LinearRegression())
145145

146146
dict_imputers = {
147147
"mean": imputer_mean,
@@ -197,7 +197,7 @@ comparison = comparator.Comparator(
197197
n_calls_opt=10,
198198
dict_config_opti=dict_config_opti,
199199
)
200-
results = comparison.compare(df_data, groups=["station"])
200+
results = comparison.compare(df_data)
201201
results
202202
```
203203

@@ -229,7 +229,7 @@ df_plot = df_data[cols_to_impute]
229229
```
230230

231231
```python
232-
dfs_imputed = {name: imp.fit_transform(df_plot, groups=["station"]) for name, imp in dict_imputers.items()}
232+
dfs_imputed = {name: imp.fit_transform(df_plot) for name, imp in dict_imputers.items()}
233233
```
234234

235235
```python
@@ -293,7 +293,7 @@ for i_col, col in enumerate(df_plot):
293293
ax.xaxis.set_major_locator(loc)
294294
ax.tick_params(axis='both', which='major')
295295
i_plot += 1
296-
plt.savefig("imputations_benchmark.png")
296+
plt.savefig("figures/imputations_benchmark.png")
297297
plt.show()
298298

299299
```
@@ -345,7 +345,7 @@ comparison = comparator.Comparator(
345345
n_calls_opt=10,
346346
dict_config_opti=dict_config_opti,
347347
)
348-
results = comparison.compare(df_data, groups=["station"])
348+
results = comparison.compare(df_data)
349349
results
350350
```
351351

@@ -358,7 +358,7 @@ plt.show()
358358

359359
```python
360360
df_plot = df_data
361-
dfs_imputed = {name: imp.fit_transform(df_plot, groups=["station"]) for name, imp in dict_imputers.items()}
361+
dfs_imputed = {name: imp.fit_transform(df_plot) for name, imp in dict_imputers.items()}
362362
station = df_plot.index.get_level_values("station")[0]
363363
df_station = df_plot.loc[station]
364364
dfs_imputed_station = {name: df_plot.loc[station] for name, df_plot in dfs_imputed.items()}
@@ -412,7 +412,7 @@ for i_col, col in enumerate(df_plot):
412412
ax.xaxis.set_major_locator(loc)
413413
ax.tick_params(axis='both', which='major')
414414
i_plot += 1
415-
plt.savefig("imputations_benchmark.png")
415+
plt.savefig("figures/imputations_benchmark.png")
416416
plt.show()
417417
```
418418

@@ -462,7 +462,7 @@ for i_col, col in enumerate(df_plot):
462462
plt.plot(acf, color="black", lw=2, ls="--", label="original")
463463
plt.legend()
464464

465-
plt.savefig("acf.png")
465+
plt.savefig("figures/acf.png")
466466
plt.show()
467467

468468
```

qolmat/benchmark/comparator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def evaluate_errors_sample(
125125
imputer.hyperparams_optim = cv.optimize_hyperparams(df_corrupted)
126126
else:
127127
imputer.hyperparams_optim = {}
128-
df_imputed = imputer.fit_transform(df_corrupted, **kwargs)
128+
df_imputed = imputer.fit_transform(df_corrupted)
129129
subset = self.generator_holes.subset
130130
errors = self.get_errors(df_origin[subset], df_imputed[subset], df_mask[subset])
131131
list_errors.append(errors)

0 commit comments

Comments
 (0)