Skip to content

Commit 90abb5a

Browse files
author
vm-aifluence-jro
committed
benchmark working
1 parent cb9a10f commit 90abb5a

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

examples/benchmark.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -73,15 +73,15 @@ cols_to_impute = ["TEMP", "PRES"]
7373

7474
The dataset `Artificial` is designed to have a sum of a periodical signal, a white noise and some outliers.
7575

76-
```python tags=[]
77-
df_data
78-
```
79-
8076
```python
8177
# df_data = data.get_data_corrupted("Artificial", ratio_masked=.2, mean_size=10)
8278
# cols_to_impute = ["signal"]
8379
```
8480

81+
```python tags=[]
82+
df_data
83+
```
84+
8585
Let's take a look at variables to impute. We only consider a station, Aotizhongxin.
8686
Time series display seasonalities (roughly 12 months).
8787

@@ -164,7 +164,7 @@ dict_imputers = {
164164
}
165165
n_imputers = len(dict_imputers)
166166

167-
search_params = {
167+
dict_config_opti = {
168168
"RPCA_opti": {
169169
"tau": {"min": .5, "max": 5, "type":"Real"},
170170
"lam": {"min": .1, "max": 1, "type":"Real"},
@@ -195,7 +195,7 @@ comparison = comparator.Comparator(
195195
generator_holes = generator_holes,
196196
metrics=["mae", "wmape", "KL_columnwise", "ks_test", "energy"],
197197
n_calls_opt=10,
198-
search_params=search_params,
198+
dict_config_opti=dict_config_opti,
199199
)
200200
results = comparison.compare(df_data)
201201
results
@@ -343,7 +343,7 @@ comparison = comparator.Comparator(
343343
df_data.columns,
344344
generator_holes = generator_holes,
345345
n_calls_opt=10,
346-
search_params=search_params,
346+
dict_config_opti=dict_config_opti,
347347
)
348348
results = comparison.compare(df_data)
349349
results

qolmat/benchmark/comparator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from functools import partial
2-
from typing import Any, Dict, List, Optional
2+
from typing import Any, Callable, Dict, List, Optional
33

44
import numpy as np
55
import pandas as pd
@@ -28,7 +28,7 @@ class Comparator:
2828
10.
2929
"""
3030

31-
dict_metrics: Dict[str, Any] = {
31+
dict_metrics: Dict[str, Callable] = {
3232
"mse": metrics.mean_squared_error,
3333
"rmse": metrics.root_mean_squared_error,
3434
"mae": metrics.mean_absolute_error,
@@ -101,7 +101,7 @@ def evaluate_errors_sample(
101101
imputation model
102102
df : pd.DataFrame
103103
dataframe to impute
104-
search_space : Dict
104+
dict_config_opti_imputer : Dict
105105
search space for tested_model's hyperparameters
106106
107107
Returns

0 commit comments

Comments
 (0)