Skip to content

Commit a555d60

Browse files
author
Gsaes
committed
Merge branch 'dev' of https://github.com/Quantmetry/qolmat into dev_test_rpca
2 parents edf847c + 0f77bab commit a555d60

File tree

7 files changed

+853
-168
lines changed

7 files changed

+853
-168
lines changed

examples/benchmark.md

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ imputer_mice = imputers.ImputerMICE(groups=["station"], estimator=LinearRegressi
144144
imputer_regressor = imputers.ImputerRegressor(groups=["station"], estimator=LinearRegression())
145145

146146
dict_imputers = {
147-
# "mean": imputer_mean,
147+
"mean": imputer_mean,
148148
# "median": imputer_median,
149149
# "mode": imputer_mode,
150150
"interpolation": imputer_interpol,
@@ -160,7 +160,7 @@ dict_imputers = {
160160
# "nocb": imputer_nocb,
161161
# "knn": imputer_knn,
162162
"ols": imputer_regressor,
163-
"mice_ols": imputer_mice,
163+
# "mice_ols": imputer_mice,
164164
}
165165
n_imputers = len(dict_imputers)
166166

@@ -193,14 +193,24 @@ comparison = comparator.Comparator(
193193
dict_imputers,
194194
cols_to_impute,
195195
generator_holes = generator_holes,
196-
metrics=["mae", "wmape", "KL"],
196+
metrics=["mae", "wmape", "KL", "ks_test", "energy"],
197197
n_calls_opt=10,
198198
search_params=search_params,
199199
)
200200
results = comparison.compare(df_data)
201201
results
202202
```
203203

204+
```python
205+
df_plot
206+
```
207+
208+
```python
209+
df_plot = results.loc["energy", "All"]
210+
plt.bar(df_plot.index, df_plot, color=tab10(0))
211+
plt.show()
212+
```
213+
204214
```python
205215
fig = plt.figure(figsize=(24, 8))
206216
fig.add_subplot(2, 1, 1)

examples/metrics_usage.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ from sklearn.linear_model import LinearRegression
2828
from qolmat.utils import data, plot, utils
2929
from qolmat.imputations import imputers
3030
from qolmat.benchmark import comparator, missing_patterns
31-
from qolmat.benchmark.utils import wasser_distance, kl_divergence, frechet_distance
31+
from qolmat.benchmark.utils import wasser_distance_columnwise, kl_divergence, frechet_distance
3232
```
3333

3434
```python
@@ -106,7 +106,7 @@ ratio_masked = 0.1
106106
```python
107107
# Métriques
108108
metrics = {
109-
"wasser": wasser_distance,
109+
"wasserstein_columnwise": wasserstein_distance_columnwise,
110110
"KL": kl_divergence
111111
#"frechet": frechet_distance
112112
}

qolmat/benchmark/comparator.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from functools import partial
23
from typing import Any, Dict, List, Optional, Union
34

45
import numpy as np
@@ -33,10 +34,14 @@ class Comparator:
3334
"rmse": metrics.root_mean_squared_error,
3435
"mae": metrics.mean_absolute_error,
3536
"wmape": metrics.weighted_mean_absolute_percentage_error,
36-
"wasser": metrics.wasser_distance,
37-
"KL": metrics.kl_divergence,
37+
"wasserstein_columnwise": partial(metrics.wasserstein_distance, method="columnwise"),
38+
"KL_columnwise": partial(metrics.kl_divergence, method="columnwise"),
39+
"KL_gaussian": partial(metrics.kl_divergence, method="gaussian"),
40+
"ks_test": metrics.kolmogorov_smirnov_test,
41+
"correlation_diff": metrics.mean_difference_correlation_matrix_numerical_features,
42+
"pairwise_dist": metrics.sum_pairwise_distances,
43+
"energy": metrics.sum_energy_distances,
3844
"frechet": metrics.frechet_distance,
39-
"energy": metrics.energy_dist,
4045
}
4146

4247
def __init__(

0 commit comments

Comments
 (0)