Skip to content

Commit 29c2052

Browse files
committed
feat: add the attribute groups Comparator.compare(); fix: benchmark notebook
1 parent 98d404f commit 29c2052

File tree

2 files changed

+29
-27
lines changed

2 files changed

+29
-27
lines changed

examples/benchmark.md

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ jupyter:
1010
kernelspec:
1111
display_name: env_qolmat_dev
1212
language: python
13-
name: env_qolmat_dev
13+
name: python3
1414
---
1515

1616
**This notebook aims to present the Qolmat repo through an example of a multivariate time series.
@@ -121,27 +121,27 @@ Some methods require hyperparameters. The user can directly specify them, or rat
121121
In pratice, we rely on a cross validation to find the best hyperparams values minimizing an error reconstruction.
122122

123123
```python
124-
imputer_mean = imputers.ImputerMean(groups=["station"])
125-
imputer_median = imputers.ImputerMedian(groups=["station"])
126-
imputer_mode = imputers.ImputerMode(groups=["station"])
127-
imputer_locf = imputers.ImputerLOCF(groups=["station"])
128-
imputer_nocb = imputers.ImputerNOCB(groups=["station"])
129-
imputer_interpol = imputers.ImputerInterpolation(groups=["station"], method="linear")
130-
imputer_spline = imputers.ImputerInterpolation(groups=["station"], method="spline", order=2)
131-
imputer_shuffle = imputers.ImputerShuffle(groups=["station"])
132-
imputer_residuals = imputers.ImputerResiduals(groups=["station"], period=7, model_tsa="additive", extrapolate_trend="freq", method_interpolation="linear")
124+
imputer_mean = imputers.ImputerMean()
125+
imputer_median = imputers.ImputerMedian()
126+
imputer_mode = imputers.ImputerMode()
127+
imputer_locf = imputers.ImputerLOCF()
128+
imputer_nocb = imputers.ImputerNOCB()
129+
imputer_interpol = imputers.ImputerInterpolation(method="linear")
130+
imputer_spline = imputers.ImputerInterpolation(method="spline", order=2)
131+
imputer_shuffle = imputers.ImputerShuffle()
132+
imputer_residuals = imputers.ImputerResiduals(period=7, model_tsa="additive", extrapolate_trend="freq", method_interpolation="linear")
133133

134-
imputer_rpca = imputers.ImputerRPCA(groups=["station"], columnwise=True, period=7, max_iter=200, tau=2, lam=.3)
135-
imputer_rpca_opti = imputers.ImputerRPCA(groups=["station"], columnwise=True, period=7, max_iter=100)
134+
imputer_rpca = imputers.ImputerRPCA(columnwise=True, period=7, max_iter=200, tau=2, lam=.3)
135+
imputer_rpca_opti = imputers.ImputerRPCA(columnwise=True, period=7, max_iter=100)
136136

137-
imputer_ou = imputers.ImputerEM(groups=["station"], model="multinormal", method="sample", max_iter_em=34, n_iter_ou=15, dt=1e-3)
138-
imputer_tsou = imputers.ImputerEM(groups=["station"], model="VAR1", method="sample", max_iter_em=34, n_iter_ou=15, dt=1e-3)
139-
imputer_tsmle = imputers.ImputerEM(groups=["station"], model="VAR1", method="mle", max_iter_em=34, n_iter_ou=15, dt=1e-3)
137+
imputer_ou = imputers.ImputerEM(model="multinormal", method="sample", max_iter_em=34, n_iter_ou=15, dt=1e-3)
138+
imputer_tsou = imputers.ImputerEM(model="VAR1", method="sample", max_iter_em=34, n_iter_ou=15, dt=1e-3)
139+
imputer_tsmle = imputers.ImputerEM(model="VAR1", method="mle", max_iter_em=34, n_iter_ou=15, dt=1e-3)
140140

141141

142-
imputer_knn = imputers.ImputerKNN(groups=["station"], k=10)
143-
imputer_mice = imputers.ImputerMICE(groups=["station"], estimator=LinearRegression(), sample_posterior=False, max_iter=100, missing_values=np.nan)
144-
imputer_regressor = imputers.ImputerRegressor(groups=["station"], estimator=LinearRegression())
142+
imputer_knn = imputers.ImputerKNN(k=10)
143+
imputer_mice = imputers.ImputerMICE(estimator=LinearRegression(), sample_posterior=False, max_iter=100, missing_values=np.nan)
144+
imputer_regressor = imputers.ImputerRegressor(estimator=LinearRegression())
145145

146146
dict_imputers = {
147147
"mean": imputer_mean,
@@ -197,7 +197,7 @@ comparison = comparator.Comparator(
197197
n_calls_opt=10,
198198
dict_config_opti=dict_config_opti,
199199
)
200-
results = comparison.compare(df_data)
200+
results = comparison.compare(df_data, groups=["station"])
201201
results
202202
```
203203

@@ -229,7 +229,7 @@ df_plot = df_data[cols_to_impute]
229229
```
230230

231231
```python
232-
dfs_imputed = {name: imp.fit_transform(df_plot) for name, imp in dict_imputers.items()}
232+
dfs_imputed = {name: imp.fit_transform(df_plot, groups=["station"]) for name, imp in dict_imputers.items()}
233233
```
234234

235235
```python
@@ -293,7 +293,7 @@ for i_col, col in enumerate(df_plot):
293293
ax.xaxis.set_major_locator(loc)
294294
ax.tick_params(axis='both', which='major')
295295
i_plot += 1
296-
plt.savefig("figures/imputations_benchmark.png")
296+
plt.savefig("imputations_benchmark.png")
297297
plt.show()
298298

299299
```
@@ -345,7 +345,7 @@ comparison = comparator.Comparator(
345345
n_calls_opt=10,
346346
dict_config_opti=dict_config_opti,
347347
)
348-
results = comparison.compare(df_data)
348+
results = comparison.compare(df_data, groups=["station"])
349349
results
350350
```
351351

@@ -358,7 +358,7 @@ plt.show()
358358

359359
```python
360360
df_plot = df_data
361-
dfs_imputed = {name: imp.fit_transform(df_plot) for name, imp in dict_imputers.items()}
361+
dfs_imputed = {name: imp.fit_transform(df_plot, groups=["station"]) for name, imp in dict_imputers.items()}
362362
station = df_plot.index.get_level_values("station")[0]
363363
df_station = df_plot.loc[station]
364364
dfs_imputed_station = {name: df_plot.loc[station] for name, df_plot in dfs_imputed.items()}
@@ -412,7 +412,7 @@ for i_col, col in enumerate(df_plot):
412412
ax.xaxis.set_major_locator(loc)
413413
ax.tick_params(axis='both', which='major')
414414
i_plot += 1
415-
plt.savefig("figures/imputations_benchmark.png")
415+
plt.savefig("imputations_benchmark.png")
416416
plt.show()
417417
```
418418

@@ -462,7 +462,7 @@ for i_col, col in enumerate(df_plot):
462462
plt.plot(acf, color="black", lw=2, ls="--", label="original")
463463
plt.legend()
464464

465-
plt.savefig("figures/acf.png")
465+
plt.savefig("acf.png")
466466
plt.show()
467467

468468
```

qolmat/benchmark/comparator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def evaluate_errors_sample(
9292
imputer: Any,
9393
df: pd.DataFrame,
9494
dict_config_opti_imputer: Dict[str, Any] = {},
95+
**kwargs,
9596
) -> pd.Series:
9697
"""Evaluate the errors in the cross-validation
9798
@@ -124,7 +125,7 @@ def evaluate_errors_sample(
124125
imputer.hyperparams_optim = cv.optimize_hyperparams(df_corrupted)
125126
else:
126127
imputer.hyperparams_optim = {}
127-
df_imputed = imputer.fit_transform(df_corrupted)
128+
df_imputed = imputer.fit_transform(df_corrupted, **kwargs)
128129
subset = self.generator_holes.subset
129130
errors = self.get_errors(df_origin[subset], df_imputed[subset], df_mask[subset])
130131
list_errors.append(errors)
@@ -136,6 +137,7 @@ def evaluate_errors_sample(
136137
def compare(
137138
self,
138139
df: pd.DataFrame,
140+
**kwargs,
139141
):
140142
"""Function to compare different imputation methods on dataframe df
141143
@@ -157,7 +159,7 @@ def compare(
157159

158160
try:
159161
dict_errors[name] = self.evaluate_errors_sample(
160-
imputer, df, dict_config_opti_imputer
162+
imputer, df, dict_config_opti_imputer, **kwargs
161163
)
162164
print(f"Tested model: {type(imputer).__name__}")
163165
except Exception as excp:

0 commit comments

Comments
 (0)