Skip to content

Commit 64aff6d

Browse files
Julien RousselJulien Roussel
authored andcommitted
cross_validation refacto into hyperparameters
1 parent 9b77d4e commit 64aff6d

File tree

18 files changed

+361
-456
lines changed

18 files changed

+361
-456
lines changed

environment.ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ dependencies:
99
- flake8
1010
- matplotlib
1111
- mypy
12-
- numpy==1.19
12+
- numpy
1313
- numpydoc
1414
- pytest
1515
- pytest-cov

environment.dev.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@ dependencies:
99
- jupyter=1.0.0
1010
- jupyterlab=1.2.6
1111
- jupytext=1.14.4
12-
- numpy=1.21
12+
- hyperopt=0.2.7
13+
- numpy=1.24.4
1314
- packaging=23.1
1415
- pandas=2.0.1
1516
- python=3.8
1617
- pip=23.0.1
1718
- scipy=1.10.1
1819
- scikit-learn=1.2.2
19-
- scikit-optimize=0.9
2020
- sphinx=6.2.1
2121
- sphinx-gallery=0.13.0
2222
- sphinx_rtd_theme=1.2.0

examples/benchmark.md

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ jupyter:
66
extension: .md
77
format_name: markdown
88
format_version: '1.3'
9-
jupytext_version: 1.14.5
9+
jupytext_version: 1.14.4
1010
kernelspec:
1111
display_name: env_qolmat_dev
1212
language: python
@@ -32,6 +32,8 @@ import pandas as pd
3232
from datetime import datetime
3333
import numpy as np
3434
import scipy
35+
import hyperopt as ho
36+
from hyperopt.pyll.base import Apply as hoApply
3537
np.random.seed(1234)
3638
import pprint
3739
from matplotlib import pyplot as plt
@@ -48,7 +50,7 @@ from sklearn.ensemble import RandomForestRegressor, ExtraTreesRegressor, HistGra
4850

4951

5052
import sys
51-
from qolmat.benchmark import comparator, missing_patterns
53+
from qolmat.benchmark import comparator, missing_patterns, hyperparameters
5254
from qolmat.benchmark.metrics import kl_divergence
5355
from qolmat.imputations import imputers
5456
from qolmat.utils import data, utils, plot
@@ -62,7 +64,8 @@ The dataset `Beijing` is the Beijing Multi-Site Air-Quality Data Set. It consist
6264
This dataset only contains numerical vairables.
6365

6466
```python
65-
df_data = data.get_data_corrupted("Beijing", ratio_masked=.2, mean_size=120)
67+
df_data = data.get_data_corrupted("Beijing_offline", ratio_masked=.2, mean_size=20)
68+
df_data = df_data.iloc[:256]
6669

6770
# cols_to_impute = ["TEMP", "PRES", "DEWP", "NO2", "CO", "O3", "WSPM"]
6871
# cols_to_impute = df_data.columns[df_data.isna().any()]
@@ -123,6 +126,10 @@ All presented methods are group-wise: here each station is imputed independently
123126
Some methods require hyperparameters. The user can directly specify them, or rather determine them through an optimization step using the `search_params` dictionary. The keys are the imputation method's name and the values are a dictionary specifying the minimum, maximum or list of categories and type of values (Integer, Real, Category or a dictionary indexed by the variable names) to search.
124127
In pratice, we rely on a cross validation to find the best hyperparams values minimizing an error reconstruction.
125128

129+
```python
130+
ratio_masked = 0.1
131+
```
132+
126133
```python
127134
imputer_mean = imputers.ImputerMean(groups=["station"])
128135
imputer_median = imputers.ImputerMedian(groups=["station"])
@@ -135,7 +142,6 @@ imputer_shuffle = imputers.ImputerShuffle(groups=["station"])
135142
imputer_residuals = imputers.ImputerResiduals(groups=["station"], period=365, model_tsa="additive", extrapolate_trend="freq", method_interpolation="linear")
136143

137144
imputer_rpca = imputers.ImputerRPCA(groups=["station"], columnwise=False, max_iter=256, tau=2, lam=1)
138-
# imputer_rpca_opti = imputers.ImputerRPCA(groups=["station"], columnwise=True, period=7, max_iter=100)
139145

140146
imputer_ou = imputers.ImputerEM(groups=["station"], model="multinormal", method="sample", max_iter_em=34, n_iter_ou=15, dt=1e-3)
141147
imputer_tsou = imputers.ImputerEM(groups=["station"], model="VAR1", method="sample", max_iter_em=34, n_iter_ou=15, dt=1e-3)
@@ -145,7 +151,30 @@ imputer_tsmle = imputers.ImputerEM(groups=["station"], model="VAR1", method="mle
145151
imputer_knn = imputers.ImputerKNN(groups=["station"], k=10)
146152
imputer_mice = imputers.ImputerMICE(groups=["station"], estimator=LinearRegression(), sample_posterior=False, max_iter=100, missing_values=np.nan)
147153
imputer_regressor = imputers.ImputerRegressor(groups=["station"], estimator=LinearRegression())
154+
```
155+
156+
```python
157+
generator_holes = missing_patterns.EmpiricalHoleGenerator(n_splits=2, groups=["station"], subset=cols_to_impute, ratio_masked=ratio_masked)
158+
```
159+
160+
```python
161+
dict_config_opti = {
162+
"tau": ho.hp.uniform("tau", low=.5, high=5),
163+
"lam": ho.hp.uniform("lam", low=.1, high=1),
164+
}
165+
imputer_rpca_opti = imputers.ImputerRPCA(groups=["station"], columnwise=False, max_iter=256)
166+
imputer_rpca_opti = hyperparameters.optimize(
167+
imputer_rpca_opti,
168+
df_data,
169+
generator = generator_holes,
170+
metric="mae",
171+
max_evals=10,
172+
dict_config_opti=dict_config_opti
173+
)
174+
# imputer_rpca_opti.params_optim = hyperparams_opti
175+
```
148176

177+
```python
149178
dict_imputers = {
150179
# "mean": imputer_mean,
151180
# "median": imputer_median,
@@ -158,23 +187,14 @@ dict_imputers = {
158187
"TSOU": imputer_tsou,
159188
"TSMLE": imputer_tsmle,
160189
"RPCA": imputer_rpca,
161-
# "RPCA_opti": imputer_rpca_opti,
190+
"RPCA_opti": imputer_rpca_opti,
162191
# "locf": imputer_locf,
163192
# "nocb": imputer_nocb,
164193
# "knn": imputer_knn,
165194
# "ols": imputer_regressor,
166195
# "mice_ols": imputer_mice,
167196
}
168197
n_imputers = len(dict_imputers)
169-
170-
dict_config_opti = {
171-
"RPCA_opti": {
172-
"tau": {"min": .5, "max": 5, "type":"Real"},
173-
"lam": {"min": .1, "max": 1, "type":"Real"},
174-
}
175-
}
176-
177-
ratio_masked = 0.1
178198
```
179199

180200
In order to compare the methods, we $i)$ artificially create missing data (for missing data mechanisms, see the docs); $ii)$ then impute it using the different methods chosen and $iii)$ calculate the reconstruction error. These three steps are repeated a number of times equal to `n_splits`. For each method, we calculate the average error and compare the final errors.
@@ -190,14 +210,12 @@ Concretely, the comparator takes as input a dataframe to impute, a proportion of
190210
Note these metrics compute reconstruction errors; it tells nothing about the distances between the "true" and "imputed" distributions.
191211

192212
```python
193-
generator_holes = missing_patterns.EmpiricalHoleGenerator(n_splits=2, groups=["station"], ratio_masked=ratio_masked)
194-
195213
comparison = comparator.Comparator(
196214
dict_imputers,
197215
cols_to_impute,
198216
generator_holes = generator_holes,
199217
metrics=["mae", "wmape", "KL_columnwise", "ks_test", "energy"],
200-
n_calls_opt=10,
218+
max_evals=10,
201219
dict_config_opti=dict_config_opti,
202220
)
203221
results = comparison.compare(df_data)

qolmat/benchmark/comparator.py

Lines changed: 16 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
from functools import partial
2-
from typing import Any, Callable, Dict, List, Optional
1+
from typing import Any, Dict, List, Optional
32

43
import numpy as np
54
import pandas as pd
65

7-
from qolmat.benchmark import cross_validation, metrics
6+
from qolmat.benchmark import hyperparameters, metrics
87
from qolmat.benchmark.missing_patterns import _HoleGenerator
98

109

@@ -23,41 +22,26 @@ class Comparator:
2322
dict_config_opti: Optional[Dict[str, Dict[str, Union[str, float, int]]]] = {}
2423
dictionary of search space for each implementation method. By default, the value is set to
2524
{}.
26-
n_calls_opt: int = 10
25+
max_evals: int = 10
2726
number of calls of the optimization algorithm
2827
10.
2928
"""
3029

31-
dict_metrics: Dict[str, Callable] = {
32-
"mse": metrics.mean_squared_error,
33-
"rmse": metrics.root_mean_squared_error,
34-
"mae": metrics.mean_absolute_error,
35-
"wmape": metrics.weighted_mean_absolute_percentage_error,
36-
"wasserstein_columnwise": partial(metrics.wasserstein_distance, method="columnwise"),
37-
"KL_columnwise": partial(metrics.kl_divergence, method="columnwise"),
38-
"KL_gaussian": partial(metrics.kl_divergence, method="gaussian"),
39-
"ks_test": metrics.kolmogorov_smirnov_test,
40-
"correlation_diff": metrics.mean_difference_correlation_matrix_numerical_features,
41-
"pairwise_dist": metrics.sum_pairwise_distances,
42-
"energy": metrics.sum_energy_distances,
43-
"frechet": metrics.frechet_distance,
44-
}
45-
4630
def __init__(
4731
self,
4832
dict_models: Dict[str, Any],
4933
selected_columns: List[str],
5034
generator_holes: _HoleGenerator,
5135
metrics: List = ["mae", "wmape", "KL_columnwise"],
5236
dict_config_opti: Optional[Dict[str, Any]] = {},
53-
n_calls_opt: int = 10,
37+
max_evals: int = 10,
5438
):
5539
self.dict_imputers = dict_models
5640
self.selected_columns = selected_columns
5741
self.generator_holes = generator_holes
5842
self.metrics = metrics
5943
self.dict_config_opti = dict_config_opti
60-
self.n_calls_opt = n_calls_opt
44+
self.max_evals = max_evals
6145

6246
def get_errors(
6347
self,
@@ -81,7 +65,7 @@ def get_errors(
8165
"""
8266
dict_errors = {}
8367
for name_metric in self.metrics:
84-
dict_errors[name_metric] = Comparator.dict_metrics[name_metric](
68+
dict_errors[name_metric] = metrics.get_metric(name_metric)(
8569
df_origin, df_imputed, df_mask
8670
)
8771
errors = pd.concat(dict_errors.values(), keys=dict_errors.keys())
@@ -114,17 +98,16 @@ def evaluate_errors_sample(
11498
for df_mask in self.generator_holes.split(df_origin):
11599
df_corrupted = df_origin.copy()
116100
df_corrupted[df_mask] = np.nan
117-
if dict_config_opti_imputer:
118-
cv = cross_validation.CrossValidation(
119-
imputer,
120-
dict_config_opti_imputer=dict_config_opti_imputer,
121-
hole_generator=self.generator_holes,
122-
n_calls=self.n_calls_opt,
123-
)
124-
imputer.hyperparams_optim = cv.optimize_hyperparams(df_corrupted)
125-
else:
126-
imputer.hyperparams_optim = {}
127-
df_imputed = imputer.fit_transform(df_corrupted)
101+
metric_optim = "mae"
102+
imputer_opti = hyperparameters.optimize(
103+
imputer,
104+
df,
105+
self.generator_holes,
106+
metric_optim,
107+
dict_config_opti_imputer,
108+
max_evals=self.max_evals,
109+
)
110+
df_imputed = imputer_opti.fit_transform(df_corrupted)
128111
subset = self.generator_holes.subset
129112
errors = self.get_errors(df_origin[subset], df_imputed[subset], df_mask[subset])
130113
list_errors.append(errors)

0 commit comments

Comments
 (0)