Skip to content

Commit 561ddbe

Browse files
Julien RousselJulien Roussel
authored and committed
warnings removed when optimizing with verbose parameter
1 parent 3701fb0 commit 561ddbe

File tree

11 files changed

+133
-93
lines changed

11 files changed

+133
-93
lines changed

examples/benchmark.md

Lines changed: 17 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ ratio_masked = 0.1
116116
```
117117

118118
```python
119+
dict_config_opti = {}
120+
119121
imputer_mean = imputers.ImputerMean(groups=("station",))
120122
imputer_median = imputers.ImputerMedian(groups=("station",))
121123
imputer_mode = imputers.ImputerMode(groups=("station",))
@@ -127,6 +129,18 @@ imputer_shuffle = imputers.ImputerShuffle(groups=("station",))
127129
imputer_residuals = imputers.ImputerResiduals(groups=("station",), period=365, model_tsa="additive", extrapolate_trend="freq", method_interpolation="linear")
128130

129131
imputer_rpca = imputers.ImputerRPCA(groups=("station",), columnwise=False, max_iterations=500, tau=2, lam=0.05)
132+
imputer_rpca_opti = imputers.ImputerRPCA(groups=("station",), columnwise=False, max_iterations=256)
133+
dict_config_opti["RPCA_opti"] = {
134+
"tau": ho.hp.uniform("tau", low=.5, high=5),
135+
"lam": ho.hp.uniform("lam", low=.1, high=1),
136+
}
137+
imputer_rpca_opticw = imputers.ImputerRPCA(groups=("station",), columnwise=False, max_iterations=256)
138+
dict_config_opti["RPCA_opticw"] = {
139+
"tau/TEMP": ho.hp.uniform("tau/TEMP", low=.5, high=5),
140+
"tau/PRES": ho.hp.uniform("tau/PRES", low=.5, high=5),
141+
"lam/TEMP": ho.hp.uniform("lam/TEMP", low=.1, high=1),
142+
"lam/PRES": ho.hp.uniform("lam/PRES", low=.1, high=1),
143+
}
130144

131145
imputer_ou = imputers.ImputerEM(groups=("station",), model="multinormal", method="sample", max_iter_em=34, n_iter_ou=15, dt=1e-3)
132146
imputer_tsou = imputers.ImputerEM(groups=("station",), model="VAR1", method="sample", max_iter_em=34, n_iter_ou=15, dt=1e-3)
@@ -142,40 +156,6 @@ imputer_regressor = imputers.ImputerRegressor(groups=("station",), estimator=Lin
142156
generator_holes = missing_patterns.EmpiricalHoleGenerator(n_splits=2, groups=("station",), subset=cols_to_impute, ratio_masked=ratio_masked)
143157
```
144158

145-
```python
146-
dict_config_opti = {
147-
"tau": ho.hp.uniform("tau", low=.5, high=5),
148-
"lam": ho.hp.uniform("lam", low=.1, high=1),
149-
}
150-
imputer_rpca_opti = imputers.ImputerRPCA(groups=("station",), columnwise=False, max_iterations=256)
151-
imputer_rpca_opti = hyperparameters.optimize(
152-
imputer_rpca_opti,
153-
df_data,
154-
generator = generator_holes,
155-
metric="mae",
156-
max_evals=10,
157-
dict_spaces=dict_config_opti
158-
)
159-
```
160-
161-
```python jupyter={"source_hidden": true}
162-
dict_config_opti2 = {
163-
"tau/TEMP": ho.hp.uniform("tau/TEMP", low=.5, high=5),
164-
"tau/PRES": ho.hp.uniform("tau/PRES", low=.5, high=5),
165-
"lam/TEMP": ho.hp.uniform("lam/TEMP", low=.1, high=1),
166-
"lam/PRES": ho.hp.uniform("lam/PRES", low=.1, high=1),
167-
}
168-
imputer_rpca_opti2 = imputers.ImputerRPCA(groups=("station",), columnwise=True, max_iterations=256)
169-
imputer_rpca_opti2 = hyperparameters.optimize(
170-
imputer_rpca_opti2,
171-
df_data,
172-
generator = generator_holes,
173-
metric="mae",
174-
max_evals=10,
175-
dict_spaces=dict_config_opti2
176-
)
177-
```
178-
179159
```python
180160
dict_imputers = {
181161
"mean": imputer_mean,
@@ -190,7 +170,7 @@ dict_imputers = {
190170
"TSMLE": imputer_tsmle,
191171
"RPCA": imputer_rpca,
192172
"RPCA_opti": imputer_rpca_opti,
193-
# "RPCA_opti2": imputer_rpca_opti2,
173+
# "RPCA_opticw": imputer_rpca_opti2,
194174
# "locf": imputer_locf,
195175
# "nocb": imputer_nocb,
196176
# "knn": imputer_knn,
@@ -225,7 +205,7 @@ results = comparison.compare(df_data)
225205
results
226206
```
227207

228-
```python jupyter={"source_hidden": true}
208+
```python
229209
df_plot = results.loc["KL_columnwise",'TEMP']
230210
plt.barh(df_plot.index, df_plot, color=tab10(0))
231211
plt.title('TEMP')
@@ -246,7 +226,7 @@ plot.multibar(results.loc["mae"], decimals=1)
246226
plt.ylabel("mae")
247227

248228
fig.add_subplot(2, 1, 2)
249-
plot.multibar(results.loc["dist_corr_pattern"], decimals=1)
229+
plot.multibar(results.loc["dist_corr_pattern"], decimals=2)
250230
plt.ylabel("dist_corr_pattern")
251231

252232
plt.savefig("figures/imputations_benchmark_errors.png")

qolmat/benchmark/comparator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,15 @@ def __init__(
3535
metrics: List = ["mae", "wmape", "KL_columnwise"],
3636
dict_config_opti: Optional[Dict[str, Any]] = {},
3737
max_evals: int = 10,
38+
verbose: bool = False,
3839
):
3940
self.dict_imputers = dict_models
4041
self.selected_columns = selected_columns
4142
self.generator_holes = generator_holes
4243
self.metrics = metrics
4344
self.dict_config_opti = dict_config_opti
4445
self.max_evals = max_evals
46+
self.verbose = verbose
4547

4648
def get_errors(
4749
self,
@@ -106,6 +108,7 @@ def evaluate_errors_sample(
106108
metric_optim,
107109
dict_config_opti_imputer,
108110
max_evals=self.max_evals,
111+
verbose=self.verbose,
109112
)
110113
df_imputed = imputer_opti.fit_transform(df_corrupted)
111114
subset = self.generator_holes.subset

qolmat/benchmark/hyperparameters.py

Lines changed: 61 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,38 @@
1111
from qolmat.benchmark import metrics
1212

1313
from qolmat.benchmark.missing_patterns import _HoleGenerator
14+
from qolmat.imputations.imputers import _Imputer
15+
from qolmat.utils.utils import HyperValue
1416

15-
HyperValue = Union[int, float, str]
1617

17-
18-
def get_objective(imputer, df, generator, metric, names_hyperparams) -> Callable:
18+
def get_objective(
19+
imputer: _Imputer,
20+
df: pd.DataFrame,
21+
generator: _HoleGenerator,
22+
metric: str,
23+
names_hyperparams: List[str],
24+
) -> Callable:
1925
"""
20-
Define the objective function for the cross-validation
26+
Define the objective function, which is the average metric computed over the folds provided by
27+
the hole generator, using a cross-validation.
28+
29+
Parameters
30+
----------
31+
imputer: _Imputer
32+
Imputer that should be optimized, it should at least have a fit_transform method and an
33+
imputer_params attribute
34+
generator: _HoleGenerator
35+
Generator creating the masked values in the nested cross validation allowing to measure the
36+
imputer performance
37+
metric: str
38+
Metric used as performance indicator, common values are `mse` and `mae`
39+
names_hyperparams: List[str]
40+
List of the names of the hyperparameters which are being optimized
2141
2242
Returns
2343
-------
24-
_type_
25-
objective function
44+
Callable[List[HyperValue], float]
45+
Objective function
2646
"""
2747

2848
def fun_obf(args: List[HyperValue]) -> float:
@@ -47,32 +67,55 @@ def fun_obf(args: List[HyperValue]) -> float:
4767
return fun_obf
4868

4969

50-
def optimize(imputer, df, generator, metric, dict_spaces, max_evals=100):
51-
"""Optimize hyperparameters
70+
def optimize(
71+
imputer: _Imputer,
72+
df: pd.DataFrame,
73+
generator: _HoleGenerator,
74+
metric: str,
75+
dict_config: Dict[str, HyperValue],
76+
max_evals: int = 100,
77+
verbose: bool = False,
78+
):
79+
"""Return the provided imputer with hyperparameters optimized in the provided range in order to
80+
minimize the provided metric.
5281
5382
Parameters
5483
----------
55-
df : pd.DataFrame
56-
DataFrame masked
84+
imputer: _Imputer
85+
Imputer that should be optimized, it should at least have a fit_transform method and an
86+
imputer_params attribute
87+
generator: _HoleGenerator
88+
Generator creating the masked values in the nested cross validation allowing to measure the
89+
imputer performance
90+
metric: str
91+
Metric used as performance indicator, common values are `mse` and `mae`
92+
dict_config: Dict[str, HyperValue]
93+
Search space for the tested hyperparameters
94+
max_evals: int
95+
Maximum number of evaluation of the performance of the algorithm. Each estimation involves
96+
one call to fit_transform per fold returned by the generator. See the n_fold attribute.
97+
verbose: bool
98+
Verbosity switch, useful for imputers that can have unstable behavior for some
99+
hyperparameters values
57100
58101
Returns
59102
-------
60-
Dict[str, Any]
61-
hyperparameters optimize flat
103+
_Imputer
104+
Optimized imputer
62105
"""
63106
imputer = copy.deepcopy(imputer)
64-
if dict_spaces == {}:
107+
if dict_config == {}:
65108
return imputer
66-
names_hyperparams = list(dict_spaces.keys())
67-
values_hyperparams = list(dict_spaces.values())
68-
imputer.imputer_params = tuple(set(imputer.imputer_params) | set(dict_spaces.keys()))
109+
names_hyperparams = list(dict_config.keys())
110+
values_hyperparams = list(dict_config.values())
111+
imputer.imputer_params = tuple(set(imputer.imputer_params) | set(dict_config.keys()))
112+
if verbose and hasattr(imputer, "verbose"):
113+
setattr(imputer, "verbose", False)
69114
fun_obj = get_objective(imputer, df, generator, metric, names_hyperparams)
70115
hyperparams = ho.fmin(
71116
fn=fun_obj, space=values_hyperparams, algo=ho.tpe.suggest, max_evals=max_evals
72117
)
73118

74-
# hyperparams = deflat_hyperparams(hyperparams_flat)
75119
for key, value in hyperparams.items():
76120
setattr(imputer, key, value)
77-
# imputer.hyperparams = hyperparams
78121
return imputer

qolmat/benchmark/metrics.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -919,6 +919,7 @@ def get_metric(name: str) -> Callable:
919919
"wasserstein_columnwise": partial(wasserstein_distance, method="columnwise"),
920920
"KL_columnwise": partial(kl_divergence, method="columnwise"),
921921
"KL_gaussian": partial(kl_divergence, method="gaussian"),
922+
"KL_forest": partial(kl_divergence, method="random_forest"),
922923
"ks_test": kolmogorov_smirnov_test,
923924
"correlation_diff": mean_difference_correlation_matrix_numerical_features,
924925
"pairwise_dist": sum_pairwise_distances,

qolmat/imputations/em_sampler.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def __init__(
113113
stagnation_threshold: float = 5e-3,
114114
stagnation_loglik: float = 2,
115115
period: int = 1,
116+
verbose: bool = False,
116117
):
117118
if method not in ["mle", "sample"]:
118119
raise ValueError(f"`method` must be 'mle' or 'sample', provided value is '{method}'")
@@ -131,6 +132,7 @@ def __init__(
131132

132133
self.dict_criteria_stop: Dict[str, List] = {}
133134
self.period = period
135+
self.verbose = verbose
134136

135137
def _convert_numpy(self, X: NDArray) -> NDArray:
136138
"""
@@ -248,6 +250,8 @@ class MultiNormalEM(EM):
248250
dt : float
249251
Process integration time step, a large value increases the sample bias and can make
250252
the algorithm unstable, but compensates for a smaller n_iter_ou. By default, 2e-2.
253+
verbose: bool
254+
default `False`
251255
252256
Attributes
253257
----------
@@ -280,6 +284,7 @@ def __init__(
280284
stagnation_threshold: float = 5e-3,
281285
stagnation_loglik: float = 2,
282286
period: int = 1,
287+
verbose: bool = False,
283288
) -> None:
284289
super().__init__(
285290
method=method,
@@ -292,6 +297,7 @@ def __init__(
292297
stagnation_threshold=stagnation_threshold,
293298
stagnation_loglik=stagnation_loglik,
294299
period=period,
300+
verbose=verbose,
295301
)
296302
self.dict_criteria_stop = {"logliks": [], "means": [], "covs": []}
297303

@@ -473,6 +479,8 @@ class VAR1EM(EM):
473479
dt : float
474480
Process integration time step, a large value increases the sample bias and can make
475481
the algorithm unstable, but compensates for a smaller n_iter_ou. By default, 2e-2.
482+
verbose: bool
483+
default `False`
476484
477485
Attributes
478486
----------
@@ -505,6 +513,7 @@ def __init__(
505513
stagnation_threshold: float = 5e-3,
506514
stagnation_loglik: float = 2,
507515
period: int = 1,
516+
verbose: bool = False,
508517
) -> None:
509518
super().__init__(
510519
method=method,
@@ -517,6 +526,7 @@ def __init__(
517526
stagnation_threshold=stagnation_threshold,
518527
stagnation_loglik=stagnation_loglik,
519528
period=period,
529+
verbose=verbose,
520530
)
521531

522532
def fit_parameter_A(self, X):

0 commit comments

Comments
 (0)