Skip to content

Commit 69c6495

Browse files
committed
feat: merge dev
2 parents 25cfd20 + 501c980 commit 69c6495

File tree

20 files changed

+495
-1122
lines changed

20 files changed

+495
-1122
lines changed

.flake8

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[flake8]
22
exclude = .git,__pycache__,.vscode,tests
33
max-line-length=99
4-
ignore=E302,E305,W503,E203,E731,E402,E501,E266,E712,F401,F821
4+
ignore=E302,E305,W503,E203,E731,E402,E266,E712,F401,F821
55
indent-size = 4
66
per-file-ignores=
77
qolmat/imputations/imputers.py:F401

.github/workflows/test.yml

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,36 @@
11
name: Unit test Qolmat
22

3-
on: [push, pull_request,workflow_dispatch]
4-
3+
on: [push, pull_request, workflow_dispatch]
54

65
jobs:
76
build-linux:
87
runs-on: ${{matrix.os}}
98
strategy:
109
matrix:
11-
os: [ubuntu-latest,windows-latest]
12-
python-version: [3.8,3.9]
10+
os: [ubuntu-latest, windows-latest]
11+
python-version: [3.8, 3.9]
1312
defaults:
1413
run:
1514
shell: bash -l {0}
1615

1716
steps:
18-
- name: Git clone
19-
uses: actions/checkout@v3
20-
- name: Set up venv for ci
21-
uses: conda-incubator/setup-miniconda@v2
22-
with:
23-
python-version: ${{matrix.python-version}}
24-
channels: default, conda-forge
25-
- name: Lint with flake8
26-
run: |
27-
conda install flake8
28-
flake8
29-
- name: Test with pytest
30-
run: |
31-
conda install pytest
32-
#pytest
33-
echo you should uncomment pytest and delete this line
34-
- name: typing with mypy
35-
run: |
36-
#mypy qolmat
37-
echo you should uncomment mypy qolmat and delete this line
17+
- name: Git clone
18+
uses: actions/checkout@v3
19+
- name: Set up venv for ci
20+
uses: conda-incubator/setup-miniconda@v2
21+
with:
22+
python-version: ${{matrix.python-version}}
23+
channels: default, conda-forge
24+
- name: Lint with flake8
25+
run: |
26+
conda install flake8
27+
flake8
28+
- name: Test with pytest
29+
run: |
30+
conda install pytest
31+
#pytest
32+
echo you should uncomment pytest and delete this line
33+
- name: typing with mypy
34+
run: |
35+
#mypy qolmat
36+
echo you should uncomment mypy qolmat and delete this line

docs/examples/imputation_example.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ Some methods take arguments. For instance, if we want to impute by the mean, we
5353
* Here, in the :class:`ImputerMean`, we specify :class:`groups=["datetime.dt.month", "datetime.dt.dayofweek"]`, which means the method will first use a groupby operation (via :class:`pd.DataFrame.groupby`) and then impute missing values with the mean of their corresponding group.
5454
* For the :class:`ImputeInterpolation`, the method can be anything supported by :class:`pd.Series.interpolate`; hence for :class:`spline` and :class:`polynomial`, we have to provide an :class:`order`.
5555
* For the :class:`ImputerRPCA`, we first need to specify the :class:`method`, i.e. :class:`PCP`, :class:`Temporal` or :class:`Online`. It is also mandatory to mention if we deal with multivariate or not. Finally, there is a set of hyperparameters that can be specified. See the doc "Focus on RPCA" for more information.
56-
* For the :class:`ImputerEM`, we can specify the maximum number of iterations or the strategy used, i.e. "sample" or "argmax" (By default, "sample"). See the doc "Focus on EM Sampler" for more information.
56+
* For the :class:`ImputerEM`, we can specify the maximum number of iterations or the method used, i.e. "sample" or "mle" (By default, "sample"). See the doc "Focus on EM Sampler" for more information.
5757
* For the :class:`ImputerIterative`, we can specify the regression model to use, with its own hyperparameters.
5858
* For the :class:`ImputerRegressor`, we can specify the regression model to use, with its own hyperparameters as well as the name of the columns to impute.
5959

examples/benchmark.md

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ from sklearn.ensemble import RandomForestRegressor, ExtraTreesRegressor, HistGra
4848

4949
import sys
5050
from qolmat.benchmark import comparator, missing_patterns
51-
from qolmat.benchmark.utils import kl_divergence
51+
from qolmat.benchmark.metrics import kl_divergence
5252
from qolmat.imputations import imputers
5353
from qolmat.utils import data, utils, plot
5454
# from qolmat.drawing import display_bar_table
@@ -132,9 +132,9 @@ imputer_residuals = imputers.ImputerResiduals(groups=["station"], period=7, mode
132132
imputer_rpca = imputers.ImputerRPCA(groups=["station"], columnwise=True, period=365, max_iter=200, tau=2, lam=.3)
133133
imputer_rpca_opti = imputers.ImputerRPCA(groups=["station"], columnwise=True, period=365, max_iter=100)
134134

135-
imputer_ou = imputers.ImputerEM(groups=["station"], method="multinormal", strategy="ou", max_iter_em=34, n_iter_ou=15, dt=1e-3)
136-
imputer_tsou = imputers.ImputerEM(groups=["station"], method="VAR1", strategy="ou", max_iter_em=34, n_iter_ou=15, dt=1e-3)
137-
imputer_tsmle = imputers.ImputerEM(groups=["station"], method="VAR1", strategy="mle", max_iter_em=34, n_iter_ou=15, dt=1e-3)
135+
imputer_ou = imputers.ImputerEM(groups=["station"], model="multinormal", method="sample", max_iter_em=34, n_iter_ou=15, dt=1e-3)
136+
imputer_tsou = imputers.ImputerEM(groups=["station"], model="VAR1", method="sample", max_iter_em=34, n_iter_ou=15, dt=1e-3)
137+
imputer_tsmle = imputers.ImputerEM(groups=["station"], model="VAR1", method="mle", max_iter_em=34, n_iter_ou=15, dt=1e-3)
138138

139139

140140
imputer_knn = imputers.ImputerKNN(groups=["station"], k=10)
@@ -191,6 +191,7 @@ comparison = comparator.Comparator(
191191
dict_imputers,
192192
cols_to_impute,
193193
generator_holes = generator_holes,
194+
metrics=["mae", "wmape", "KL"],
194195
n_calls_opt=10,
195196
search_params=search_params,
196197
)
@@ -205,8 +206,8 @@ plot.multibar(results.loc["mae"], decimals=1)
205206
plt.ylabel("mae")
206207

207208
fig.add_subplot(2, 1, 2)
208-
plot.multibar(results.loc["energy"], decimals=1)
209-
plt.ylabel("energy")
209+
plot.multibar(results.loc["KL"], decimals=1)
210+
plt.ylabel("KL")
210211
plt.show()
211212
```
212213

examples/metrics_usage.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,9 @@ imputer_residuals = imputers.ImputerResiduals(groups=["station"], period=7, mode
6262
imputer_rpca = imputers.ImputerRPCA(groups=["station"], columnwise=True, period=365, max_iter=200, tau=2, lam=.3)
6363
imputer_rpca_opti = imputers.ImputerRPCA(groups=["station"], columnwise=True, period=365, max_iter=100)
6464

65-
imputer_ou = imputers.ImputerEM(groups=["station"], method="multinormal", max_iter_em=34, n_iter_ou=15, strategy="ou")
66-
imputer_tsou = imputers.ImputerEM(groups=["station"], method="VAR1", strategy="ou", max_iter_em=34, n_iter_ou=15)
67-
imputer_tsmle = imputers.ImputerEM(groups=["station"], method="VAR1", strategy="mle", max_iter_em=34, n_iter_ou=15)
65+
imputer_ou = imputers.ImputerEM(groups=["station"], model="multinormal", method="sample", max_iter_em=34, n_iter_ou=15)
66+
imputer_tsou = imputers.ImputerEM(groups=["station"], model="VAR1", method="sample", max_iter_em=34, n_iter_ou=15)
67+
imputer_tsmle = imputers.ImputerEM(groups=["station"], model="VAR1", method="mle", max_iter_em=34, n_iter_ou=15)
6868

6969

7070
imputer_knn = imputers.ImputerKNN(groups=["station"], k=10)

qolmat/benchmark/comparator.py

Lines changed: 23 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
import numpy as np
55
import pandas as pd
66

7-
from qolmat.benchmark import cross_validation, utils
8-
from qolmat.benchmark import metrics as mtr
7+
from qolmat.benchmark import cross_validation, metrics, utils
98
from qolmat.benchmark.missing_patterns import _HoleGenerator
109

1110

@@ -30,30 +29,32 @@ class Comparator:
3029
"""
3130

3231
dict_metrics: Dict[str, Any] = {
33-
"mse": mtr.mean_squared_error,
34-
"rmse": mtr.root_mean_squared_error,
35-
"mae": mtr.mean_absolute_error,
36-
"wmape": mtr.weighted_mean_absolute_percentage_error,
37-
"wasser": mtr.wasser_distance,
38-
"KL": mtr.kl_divergence,
39-
"ks_test": mtr.kolmogorov_smirnov_test,
40-
"correlation_diff": mtr.mean_difference_correlation_matrix_numerical_features,
41-
"pairwise_dist": mtr.sum_pairwise_distances,
42-
"energy": mtr.sum_energy_distances,
43-
"frechet": mtr.frechet_distance,
32+
"mse": metrics.mean_squared_error,
33+
"rmse": metrics.root_mean_squared_error,
34+
"mae": metrics.mean_absolute_error,
35+
"wmape": metrics.weighted_mean_absolute_percentage_error,
36+
"wasser": metrics.wasser_distance,
37+
"KL": metrics.kl_divergence_columnwise,
38+
"ks_test": metrics.kolmogorov_smirnov_test,
39+
"correlation_diff": metrics.mean_difference_correlation_matrix_numerical_features,
40+
"pairwise_dist": metrics.sum_pairwise_distances,
41+
"energy": metrics.sum_energy_distances,
42+
"frechet": metrics.frechet_distance,
4443
}
4544

4645
def __init__(
4746
self,
4847
dict_models: Dict[str, Any],
4948
selected_columns: List[str],
5049
generator_holes: _HoleGenerator,
50+
metrics: List = ["mae", "wmape", "KL"],
5151
search_params: Optional[Dict[str, Dict[str, Union[float, int, str]]]] = {},
5252
n_calls_opt: int = 10,
5353
):
5454
self.dict_imputers = dict_models
5555
self.selected_columns = selected_columns
5656
self.generator_holes = generator_holes
57+
self.metrics = metrics
5758
self.search_params = search_params
5859
self.n_calls_opt = n_calls_opt
5960

@@ -62,8 +63,6 @@ def get_errors(
6263
df_origin: pd.DataFrame,
6364
df_imputed: pd.DataFrame,
6465
df_mask: pd.DataFrame,
65-
metrics: List = ["mae", "wmape", "kl"],
66-
on_mask=True,
6766
) -> pd.DataFrame:
6867
"""Functions evaluating the reconstruction's quality
6968
@@ -79,14 +78,11 @@ def get_errors(
7978
dictionary
8079
dictionary of results obtained via different metrics
8180
"""
82-
83-
# TODO comment comparer la distribution initiale et la distribution générée, pas la même taille,
84-
# ne fonctionne pas avec les métriques actuelles
85-
8681
dict_errors = {}
87-
for name_metric in metrics:
88-
dict_errors[name_metric] = Comparator.dict_metrics[name_metric](df_origin, df_imputed)
89-
82+
for name_metric in self.metrics:
83+
dict_errors[name_metric] = Comparator.dict_metrics[name_metric](
84+
df_origin, df_imputed, df_mask
85+
)
9086
errors = pd.concat(dict_errors.values(), keys=dict_errors.keys())
9187
return errors
9288

@@ -95,8 +91,6 @@ def evaluate_errors_sample(
9591
imputer: Any,
9692
df: pd.DataFrame,
9793
list_spaces: List[Dict] = [],
98-
metrics: List = ["mae", "wmape", "kl"],
99-
on_mask=True,
10094
) -> pd.Series:
10195
"""Evaluate the errors in the cross-validation
10296
@@ -114,7 +108,6 @@ def evaluate_errors_sample(
114108
pd.DataFrame
115109
DataFrame with the errors for each metric (in column) and at each fold (in index)
116110
"""
117-
118111
list_errors = []
119112
df_origin = df[self.selected_columns].copy()
120113
for df_mask in self.generator_holes.split(df_origin):
@@ -130,11 +123,8 @@ def evaluate_errors_sample(
130123
df_imputed = cv.fit_transform(df_corrupted)
131124
else:
132125
df_imputed = imputer.fit_transform(df_corrupted)
133-
134126
subset = self.generator_holes.subset
135-
errors = self.get_errors(
136-
df_origin[subset], df_imputed[subset], df_mask[subset], metrics, on_mask
137-
)
127+
errors = self.get_errors(df_origin[subset], df_imputed[subset], df_mask[subset])
138128
list_errors.append(errors)
139129
df_errors = pd.DataFrame(list_errors)
140130
errors_mean = df_errors.mean(axis=0)
@@ -144,9 +134,6 @@ def evaluate_errors_sample(
144134
def compare(
145135
self,
146136
df: pd.DataFrame,
147-
verbose: bool = True,
148-
metrics: List = ["mae", "wmape", "KL"],
149-
on_mask=True,
150137
):
151138
"""Function to compare different imputation methods on dataframe df
152139
@@ -164,15 +151,12 @@ def compare(
164151
dict_errors = {}
165152

166153
for name, imputer in self.dict_imputers.items():
167-
168154
search_params = self.search_params.get(name, {})
169155

170156
list_spaces = utils.get_search_space(search_params)
171157

172158
try:
173-
dict_errors[name] = self.evaluate_errors_sample(
174-
imputer, df, list_spaces, metrics, on_mask
175-
)
159+
dict_errors[name] = self.evaluate_errors_sample(imputer, df, list_spaces)
176160
print(f"Tested model: {type(imputer).__name__}")
177161
except Exception as excp:
178162
print("Error while testing ", type(imputer).__name__)
@@ -185,25 +169,12 @@ def compare(
185169

186170
class ComparatorBasedPattern(Comparator):
187171

188-
dict_metrics: Dict[str, Any] = {
189-
"mse": mtr.mean_squared_error,
190-
"rmse": mtr.root_mean_squared_error,
191-
"mae": mtr.mean_absolute_error,
192-
"wmape": mtr.weighted_mean_absolute_percentage_error,
193-
"wasser": mtr.wasser_distance,
194-
"KL": mtr.kl_divergence,
195-
"ks_test": mtr.kolmogorov_smirnov_test,
196-
"correlation_diff": mtr.mean_difference_correlation_matrix_numerical_features,
197-
"pairwise_dist": mtr.sum_pairwise_distances,
198-
"energy": mtr.sum_energy_distances,
199-
"frechet": mtr.frechet_distance,
200-
}
201-
202172
def __init__(
203173
self,
204174
dict_models: Dict[str, Any],
205175
selected_columns: List[str],
206176
generator_holes: _HoleGenerator,
177+
metrics: List = ["mae", "wmape", "KL"],
207178
search_params: Optional[Dict[str, Dict[str, Union[float, int, str]]]] = {},
208179
n_calls_opt: int = 10,
209180
num_patterns: int = 5,
@@ -212,6 +183,7 @@ def __init__(
212183
dict_models=dict_models,
213184
selected_columns=selected_columns,
214185
generator_holes=generator_holes,
186+
metrics=metrics,
215187
search_params=search_params,
216188
n_calls_opt=n_calls_opt,
217189
)
@@ -223,8 +195,6 @@ def evaluate_errors_sample(
223195
imputer: Any,
224196
df: pd.DataFrame,
225197
list_spaces: List[Dict] = [],
226-
metrics: List = ["mae", "wmape", "KL"],
227-
on_mask=True,
228198
) -> pd.Series:
229199
"""Evaluate the errors in the cross-validation
230200
@@ -270,9 +240,7 @@ def evaluate_errors_sample(
270240

271241
subset = self.generator_holes.subset # columns selected
272242
subset = [col for col in subset if col in cols_pattern]
273-
errors = self.get_errors(
274-
df_pattern[subset], df_imputed[subset], df_mask[subset], metrics, on_mask
275-
)
243+
errors = self.get_errors(df_pattern[subset], df_imputed[subset], df_mask[subset])
276244
list_errors.append(errors)
277245

278246
df_errors = pd.DataFrame(list_errors)

0 commit comments

Comments
 (0)