Skip to content

Commit bee7ff1

Browse files
Merge pull request #2 from Quantmetry/fix_rpca
Fix rpca
2 parents 5187f55 + c9d68de commit bee7ff1

File tree

19 files changed

+1567
-1337
lines changed

19 files changed

+1567
-1337
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ __pycache__/
66
/figures
77
qolmat/notebooks/figures
88
qolmat/notebooks/*.ipynb
9+
qolmat/examples/*.ipynb
910
*.egg-info
1011
/dist
1112
/build

examples/1_timeSeries.ipynb

Lines changed: 277 additions & 0 deletions
Large diffs are not rendered by default.

examples/test.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
2+
import numpy as np
3+
import timesynth as ts # package for generating time series
4+
5+
import matplotlib.pyplot as plt
6+
7+
from qolmat.utils import plot
8+
from qolmat.imputations.rpca.pcp_rpca import PcpRPCA
9+
from qolmat.imputations.rpca.temporal_rpca import TemporalRPCA, OnlineTemporalRPCA
10+
np.random.seed(402)
11+
12+
################################################################################
13+
14+
time_sampler = ts.TimeSampler(stop_time=20)
15+
irregular_time_samples = time_sampler.sample_irregular_time(num_points=5_000, keep_percentage=100)
16+
sinusoid = ts.signals.Sinusoidal(frequency=2)
17+
white_noise = ts.noise.GaussianNoise(std=0.1)
18+
timeseries = ts.TimeSeries(sinusoid, noise_generator=white_noise)
19+
samples, signals, errors = timeseries.sample(irregular_time_samples)
20+
21+
n = len(samples)
22+
pc = 0.02
23+
indices_ano1 = np.random.choice(n, int(n*pc))
24+
samples[indices_ano1] = [np.random.uniform(low=2*np.min(samples), high=2*np.max(samples)) for i in range(int(n*pc))]
25+
indices = np.random.choice(n, int(n*pc))
26+
samples[indices] = np.nan
27+
28+
29+
################################################################################
30+
31+
time_sampler = ts.TimeSampler(stop_time=20)
32+
irregular_time_samples = time_sampler.sample_irregular_time(num_points=5_000, keep_percentage=100)
33+
sinusoid = ts.signals.Sinusoidal(frequency=3)
34+
white_noise = ts.noise.GaussianNoise(std=0)
35+
timeseries = ts.TimeSeries(sinusoid, noise_generator=white_noise)
36+
samples2, signals2, errors2 = timeseries.sample(irregular_time_samples)
37+
38+
n2 = len(samples2)
39+
indices_ano2 = np.random.choice(n2, int(n*pc))
40+
samples2[indices_ano2] = [np.random.uniform(low=2*np.min(samples2), high=2*np.max(samples2)) for i in range(int(n2*pc))]
41+
indices = np.random.choice(n2, int(n*pc))
42+
samples2[indices] = np.nan
43+
44+
samples += samples2
45+
signals += signals2
46+
errors += errors2
47+
48+
49+
50+
online_temp_rpca = OnlineTemporalRPCA(n_rows=25, tau=1, lam=0.3, list_periods=[20], list_etas=[0.01],
51+
burnin=0.2, online_list_etas=[0.3], nwin=20)
52+
X, A = online_temp_rpca.fit_transform(X=samples)
53+
plot.plot_sig
54+
nal([samples, X, A], style="matplotlib")
55+
len(samples)

qolmat/benchmark/comparator.py

Lines changed: 50 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Dict, List, Optional
1+
import logging
2+
from typing import Dict, List, Optional, Union
23

34
import numpy as np
45
import pandas as pd
@@ -22,31 +23,29 @@ class Comparator:
2223
search_params: Optional[Dict[str, Dict[str, Union[str, float, int]]]] = {}
2324
dictionary of search space for each implementation method. By default, the value is set to
2425
{}.
25-
n_cv_calls: Optional[int] = 10
26-
number of calls of the hyperparameters cross-validation. By default, the value is set to
26+
n_calls_opt: Optional[int] = 10
27+
number of calls of the optimization algorithm. By default, the value is set to
2728
10.
2829
"""
2930

3031
def __init__(
3132
self,
32-
dict_models: Dict,
33+
dict_models: Dict[str, any],
3334
selected_columns: List[str],
3435
generator_holes: _HoleGenerator,
35-
columnwise_evaluation: Optional[bool] = True,
36-
search_params: Optional[Dict] = {},
37-
n_cv_calls: Optional[int] = 10,
36+
search_params: Optional[Dict[str, Dict[str, Union[float, int, str]]]] = {},
37+
n_calls_opt: Optional[int] = 10,
3838
):
3939

4040
self.dict_models = dict_models
4141
self.selected_columns = selected_columns
4242
self.generator_holes = generator_holes
43-
self.columnwise_evaluation = columnwise_evaluation
4443
self.search_params = search_params
45-
self.n_cv_calls = n_cv_calls
44+
self.n_calls_opt = n_calls_opt
4645

4746
def get_errors(
4847
self, df_origin: pd.DataFrame, df_imputed: pd.DataFrame, df_mask: pd.DataFrame
49-
) -> float:
48+
) -> pd.DataFrame:
5049
"""Functions evaluating the reconstruction's quality
5150
5251
Parameters
@@ -75,27 +74,18 @@ def get_errors(
7574
df_origin[df_mask],
7675
df_imputed[df_mask],
7776
)
77+
7878
dict_errors["kl"] = utils.kl_divergence(
7979
df_origin[df_mask],
8080
df_imputed[df_mask],
8181
)
82-
# if self.columnwise_evaluation:
83-
# wd = utils.wasser_distance(
84-
# df_origin,
85-
# df_imputed,
86-
# )
87-
# if not self.columnwise_evaluation and df_origin.shape[1] > 1:
88-
# frechet = utils.frechet_distance(
89-
# df_origin,
90-
# df_imputed,
91-
# normalized=False,
92-
# )
82+
9383
errors = pd.concat(dict_errors.values(), keys=dict_errors.keys())
9484
return errors
9585

9686
def evaluate_errors_sample(
97-
self, tested_model: any, df: pd.DataFrame, search_space: Optional[dict] = None
98-
) -> Dict:
87+
self, imputer: any, df: pd.DataFrame, list_spaces: List[Dict] = {}
88+
) -> pd.Series:
9989
"""Evaluate the errors in the cross-validation
10090
10191
Parameters
@@ -104,8 +94,8 @@ def evaluate_errors_sample(
10494
imputation model
10595
df : pd.DataFrame
10696
dataframe to impute
107-
search_space : Optional[dict], optional
108-
search space for tested_model's hyperparameters, by default None
97+
list_spaces : List[Dict]
98+
search spaces for the imputer's hyperparameters
10999
110100
Returns
111101
-------
@@ -114,25 +104,30 @@ def evaluate_errors_sample(
114104
"""
115105
list_errors = []
116106
df_origin = df[self.selected_columns].copy()
107+
if list_spaces:
108+
print("Hyperparameter optimization")
109+
print(list_spaces)
110+
else:
111+
print("No hyperparameter optimization")
117112
for df_mask in self.generator_holes.split(df_origin):
118113
df_corrupted = df_origin.copy()
119114
df_corrupted[df_mask] = np.nan
120-
if search_space is None:
121-
df_imputed = tested_model.fit_transform(df_corrupted)
122-
else:
115+
if list_spaces:
123116
cv = cross_validation.CrossValidation(
124-
tested_model,
125-
search_space=search_space,
117+
imputer,
118+
list_spaces=list_spaces,
126119
hole_generator=self.generator_holes,
127-
n_calls=self.n_cv_calls,
120+
n_calls=self.n_calls_opt,
128121
)
129122
df_imputed = cv.fit_transform(df_corrupted)
123+
else:
124+
df_imputed = imputer.fit_transform(df_corrupted)
130125

131126
subset = self.generator_holes.subset
132127
errors = self.get_errors(df_origin[subset], df_imputed[subset], df_mask[subset])
133128
list_errors.append(errors)
134129
df_errors = pd.DataFrame(list_errors)
135-
errors_mean = df_errors.mean()
130+
errors_mean = df_errors.mean(axis=0)
136131

137132
return errors_mean
138133

@@ -151,13 +146,30 @@ def compare(self, df: pd.DataFrame, verbose: bool = True):
151146
"""
152147

153148
dict_errors = {}
154-
for name, tested_model in self.dict_models.items():
155-
if verbose:
156-
print(type(tested_model).__name__)
157-
158-
search_space = utils.get_search_space(tested_model, self.search_params)
159149

160-
dict_errors[name] = self.evaluate_errors_sample(tested_model, df, search_space)
150+
for name, imputer in self.dict_models.items():
151+
print(f"Tested model: {type(imputer).__name__}")
152+
153+
search_params = self.search_params.get(name, {})
154+
155+
# if imputer.columnwise:
156+
# if len(self.selected_columns) > 0:
157+
# search_params = {}
158+
# for col in self.selected_columns:
159+
# for key, value in self.search_params[type(imputer).__name__].items():
160+
# search_params[f"('{col}', '{key}')"] = value
161+
# else:
162+
# search_params = self.search_params[type(imputer).__name__]
163+
# else:
164+
# search_params = self.search_params[type(imputer).__name__]
165+
166+
list_spaces = utils.get_search_space(search_params)
167+
168+
try:
169+
dict_errors[name] = self.evaluate_errors_sample(imputer, df, list_spaces)
170+
except Exception as excp:
171+
print("Error while testing ", type(imputer).__name__)
172+
raise excp
161173

162174
df_errors = pd.DataFrame(dict_errors)
163175

0 commit comments

Comments
 (0)