Skip to content

Commit 02cb469

Browse files
Merge pull request #60 from Quantmetry/review_dcor
Review dcor
2 parents 090b4ba + ffad4f6 commit 02cb469

File tree

6 files changed

+459
-259
lines changed

6 files changed

+459
-259
lines changed

.github/workflows/test.yml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
 name: Unit test Qolmat

-on: [push, pull_request, workflow_dispatch]
+on:
+  push:
+    branches:
+      - dev
+      - main
+  pull_request:
+  workflow_dispatch:

 jobs:
   build-linux:

examples/benchmark.md

Lines changed: 15 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ from sklearn.ensemble import RandomForestRegressor, ExtraTreesRegressor, HistGra

 import sys
 from qolmat.benchmark import comparator, missing_patterns, hyperparameters
-from qolmat.benchmark.metrics import kl_divergence
 from qolmat.imputations import imputers
 from qolmat.utils import data, utils, plot

@@ -153,7 +152,7 @@ imputer_regressor = imputers.ImputerRegressor(groups=("station",), estimator=Lin
 ```

 ```python
-generator_holes = missing_patterns.EmpiricalHoleGenerator(n_splits=2, groups=("station",), subset=cols_to_impute, ratio_masked=ratio_masked)
+generator_holes = missing_patterns.EmpiricalHoleGenerator(n_splits=1, groups=("station",), subset=cols_to_impute, ratio_masked=ratio_masked)
 ```
158157

159158
```python
@@ -163,13 +162,13 @@ dict_imputers = {
     # "mode": imputer_mode,
     "interpolation": imputer_interpol,
     # "spline": imputer_spline,
-    "shuffle": imputer_shuffle,
-    # "residuals": imputer_residuals,
+    # "shuffle": imputer_shuffle,
+    "residuals": imputer_residuals,
     # "OU": imputer_ou,
     "TSOU": imputer_tsou,
     "TSMLE": imputer_tsmle,
-    "RPCA": imputer_rpca,
-    "RPCA_opti": imputer_rpca_opti,
+    # "RPCA": imputer_rpca,
+    # "RPCA_opti": imputer_rpca,
     # "RPCA_opticw": imputer_rpca_opti2,
     # "locf": imputer_locf,
     # "nocb": imputer_nocb,
@@ -193,11 +192,13 @@ Concretely, the comparator takes as input a dataframe to impute, a proportion of
 Note these metrics compute reconstruction errors; it tells nothing about the distances between the "true" and "imputed" distributions.

 ```python
+metrics = ["mae", "wmape", "KL_columnwise", "KL_forest", "ks_test", "dist_corr_pattern"]
+# metrics = ["KL_forest"]
 comparison = comparator.Comparator(
     dict_imputers,
     cols_to_impute,
     generator_holes = generator_holes,
-    metrics=["mae", "wmape", "KL_columnwise", "ks_test", "dist_corr_pattern"],
+    metrics=metrics,
     max_evals=10,
     dict_config_opti=dict_config_opti,
 )
@@ -206,28 +207,13 @@ results
 ```

 ```python
-df_plot = results.loc["KL_columnwise",'TEMP']
-plt.barh(df_plot.index, df_plot, color=tab10(0))
-plt.title('TEMP')
-plt.xlabel("KL")
-plt.show()
-
-df_plot = results.loc["KL_columnwise",'PRES']
-plt.barh(df_plot.index, df_plot, color=tab10(0))
-plt.title('PRES')
-plt.xlabel("KL")
-plt.show()
-```
-
-```python
-fig = plt.figure(figsize=(24, 8))
-fig.add_subplot(2, 1, 1)
-plot.multibar(results.loc["mae"], decimals=1)
-plt.ylabel("mae")
-
-fig.add_subplot(2, 1, 2)
-plot.multibar(results.loc["dist_corr_pattern"], decimals=2)
-plt.ylabel("dist_corr_pattern")
+n_metrics = len(metrics)
+fig = plt.figure(figsize=(24, 4 * n_metrics))
+for i, metric in enumerate(metrics):
+    fig.add_subplot(n_metrics, 1, i + 1)
+    df = results.loc[metric]
+    plot.multibar(df, decimals=2)
+    plt.ylabel(metric)

 plt.savefig("figures/imputations_benchmark_errors.png")
 plt.show()

qolmat/benchmark/comparator.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,8 @@ def get_errors(
         """
         dict_errors = {}
         for name_metric in self.metrics:
-            dict_errors[name_metric] = metrics.get_metric(name_metric)(
-                df_origin, df_imputed, df_mask
-            )
+            fun_metric = metrics.get_metric(name_metric)
+            dict_errors[name_metric] = fun_metric(df_origin, df_imputed, df_mask)
         errors = pd.concat(dict_errors.values(), keys=dict_errors.keys())
         return errors

0 commit comments

Comments
 (0)