Skip to content

Commit c470767

Browse files
Julien RousselJulien Roussel
authored andcommitted
uniformization em
1 parent 2394c92 commit c470767

File tree

4 files changed

+191
-129
lines changed

4 files changed

+191
-129
lines changed

qolmat/imputations/em_sampler.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
import pandas as pd
99
import scipy
1010
from numpy.typing import ArrayLike
11-
from sklearn.impute._base import _BaseImputer
11+
from sklearn.base import BaseEstimator, TransformerMixin
1212
from sklearn.preprocessing import StandardScaler
1313

1414
logger = logging.getLogger(__name__)
1515

1616

17-
def _gradient_conjugue(A: ArrayLike, X: ArrayLike, tol: float = 1e-6) -> ArrayLike:
17+
def _gradient_conjugue(A: ArrayLike, X: ArrayLike) -> ArrayLike:
1818
"""
1919
Minimize Tr(X.T AX) by imputing missing values.
2020
To this aim, we compute in parallel a gradient algorithm for each data.
@@ -25,8 +25,6 @@ def _gradient_conjugue(A: ArrayLike, X: ArrayLike, tol: float = 1e-6) -> ArrayLi
2525
A array
2626
X : ArrayLike
2727
X array
28-
tol : float, optional
29-
Tolerance, by default 1e-6
3028
3129
Returns
3230
-------
@@ -79,7 +77,7 @@ def invert_robust(M, epsilon=1e-2):
7977
return scipy.linalg.inv(Meps)
8078

8179

82-
class ImputeEM(BaseEstimator, TransformerMixin):
80+
class EM(BaseEstimator, TransformerMixin):
8381
"""
8482
Imputation of missing values using a multivariate Gaussian model through EM optimization and
8583
using a projected Ornstein-Uhlenbeck process.
@@ -131,7 +129,7 @@ def __init__(
131129
tolerance: Optional[float] = 1e-4,
132130
stagnation_threshold: Optional[float] = 5e-3,
133131
stagnation_loglik: Optional[float] = 2,
134-
) -> None:
132+
):
135133

136134
if strategy not in ["mle", "ou"]:
137135
raise Exception("strategy has to be 'mle' or 'ou'")
@@ -276,7 +274,7 @@ def transform(self, X: np.array) -> np.array:
276274
return X_transformed
277275

278276

279-
class ImputeMultiNormalEM(ImputeEM):
277+
class MultiNormalEM(EM):
280278
"""
281279
Imputation of missing values using a multivariate Gaussian model through EM optimization and
282280
using a projected Ornstein-Uhlenbeck process.
@@ -488,7 +486,7 @@ def _check_convergence(self) -> bool:
488486
return min_diff_reached or min_diff_stable or max_loglik
489487

490488

491-
class ImputeVAR1EM(ImputeEM):
489+
class VAR1EM(EM):
492490
"""
493491
Imputation of missing values using a vector autoregressive model through EM optimization and
494492
using a projected Ornstein-Uhlenbeck process.

qolmat/imputations/models.py

Lines changed: 110 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -482,118 +482,9 @@ def fit_transform_element(self, df: pd.DataFrame) -> pd.DataFrame:
482482
)
483483
results = imputer.fit_transform(df)
484484
return pd.DataFrame(data=results, columns=df.columns, index=df.index)
485-
486-
487-
class ImputerRPCA(Imputer):
488-
"""
489-
This class implements the RPCA imputation
490-
491-
Parameters
492-
----------
493-
method : str
494-
name of the RPCA method:
495-
"PCP" for basic RPCA
496-
"temporal" for temporal RPCA, with regularisations
497-
"online" for online RPCA
498-
columnwise : bool
499-
for RPCA method to be applied columnwise (with reshaping of each column into an array)
500-
or to be applied directly on the dataframe. By default, the value is set to False.
501-
"""
502-
503-
def __init__(
504-
self,
505-
method: str = "temporal",
506-
groups: List[str] = [],
507-
columnwise: bool = False,
508-
**hyperparams
509-
) -> None:
510-
super().__init__(groups=groups, columnwise=columnwise, hyperparams=hyperparams)
511-
512-
self.method = method
513-
514-
def fit_transform_element(self, df: pd.DataFrame) -> pd.DataFrame:
515-
"""
516-
Fit/transform to impute with RPCA methods
517-
518-
Parameters
519-
----------
520-
df : pd.DataFrame
521-
dataframe to impute
522-
523-
Returns
524-
-------
525-
pd.DataFrame
526-
imputed dataframe
527-
"""
528-
if not isinstance(df, pd.DataFrame):
529-
raise ValueError("Input has to be a pandas.DataFrame.")
530-
531-
if self.method == "PCP":
532-
rpca = RPCA(**self.hyperparams_element)
533-
elif self.method == "temporal":
534-
rpca = TemporalRPCA(**self.hyperparams_element)
535-
elif self.method == "onlinetemporal":
536-
rpca = OnlineTemporalRPCA(**self.hyperparams_element)
537-
538-
df_imputed = pd.DataFrame(rpca.fit_transform(X=df.values), index=df.index, columns=df.columns)
539-
540-
return df_imputed
541-
542485

543-
class ImputeEM(_BaseImputer):
544-
def __init__(
545-
self,
546-
strategy: Optional[str] = "mle",
547-
method: Optional[str] = "multinormal",
548-
max_iter_em: Optional[int] = 200,
549-
n_iter_ou: Optional[int] = 50,
550-
ampli: Optional[int] = 1,
551-
random_state: Optional[int] = 123,
552-
dt: Optional[float] = 2e-2,
553-
tolerance: Optional[float] = 1e-4,
554-
stagnation_threshold: Optional[float] = 5e-3,
555-
stagnation_loglik: Optional[float] = 2,
556-
):
557-
if method == "multinormal":
558-
self.model = em_sampler.ImputeMultiNormalEM(
559-
strategy=strategy,
560-
max_iter_em=max_iter_em,
561-
n_iter_ou=n_iter_ou,
562-
ampli=ampli,
563-
random_state=random_state,
564-
dt=dt,
565-
tolerance=tolerance,
566-
stagnation_threshold=stagnation_threshold,
567-
stagnation_loglik=stagnation_loglik,
568-
)
569-
elif method == "VAR1":
570-
self.model = em_sampler.ImputeVAR1EM(
571-
strategy=strategy,
572-
max_iter_em=max_iter_em,
573-
n_iter_ou=n_iter_ou,
574-
ampli=ampli,
575-
random_state=random_state,
576-
dt=dt,
577-
tolerance=tolerance,
578-
stagnation_threshold=stagnation_threshold,
579-
stagnation_loglik=stagnation_loglik,
580-
)
581-
else:
582-
raise ValueError("Strategy '{strategy}' is not handled by ImputeEM!")
583-
584-
def fit(self, df):
585-
X = df.values
586-
self.model.fit(X)
587-
return self
588486

589-
def transform(self, df):
590-
X = df.values
591-
X_transformed = self.model.transform(X)
592-
df_transformed = pd.DataFrame(X_transformed, columns=df.columns, index=df.index)
593-
return df_transformed
594-
595-
596-
class ImputeMICE(Imputer):
487+
class ImputerMICE(Imputer):
597488
"""
598489
This class implements an iterative imputer in the multivariate case.
599490
It imputes each Series within a DataFrame multiple times using an iteration of fits
@@ -728,7 +619,7 @@ def fit_transform_element(self, df: pd.DataFrame) -> pd.DataFrame:
728619
hyperparams[hyperparam] = value
729620

730621
model = self.type_model(**hyperparams)
731-
622+
732623
if self.fit_on_nan:
733624
X = df.drop(columns=col)
734625
else:
@@ -802,3 +693,111 @@ def fit_transform_element(self, df: pd.DataFrame) -> pd.Series:
802693
df_imp.loc[is_na, col] = random_pred[is_na]
803694

804695
return df_imp
696+
697+
698+
class ImputerRPCA(Imputer):
699+
"""
700+
This class implements the RPCA imputation
701+
702+
Parameters
703+
----------
704+
method : str
705+
name of the RPCA method:
706+
"PCP" for basic RPCA
707+
"temporal" for temporal RPCA, with regularisations
708+
"online" for online RPCA
709+
columnwise : bool
710+
for RPCA method to be applied columnwise (with reshaping of each column into an array)
711+
or to be applied directly on the dataframe. By default, the value is set to False.
712+
"""
713+
714+
def __init__(
715+
self,
716+
method: str = "temporal",
717+
groups: List[str] = [],
718+
columnwise: bool = False,
719+
**hyperparams
720+
) -> None:
721+
super().__init__(groups=groups, columnwise=columnwise, hyperparams=hyperparams)
722+
723+
self.method = method
724+
725+
def fit_transform_element(self, df: pd.DataFrame) -> pd.DataFrame:
726+
"""
727+
Fit/transform to impute with RPCA methods
728+
729+
Parameters
730+
----------
731+
df : pd.DataFrame
732+
dataframe to impute
733+
734+
Returns
735+
-------
736+
pd.DataFrame
737+
imputed dataframe
738+
"""
739+
if not isinstance(df, pd.DataFrame):
740+
raise ValueError("Input has to be a pandas.DataFrame.")
741+
742+
if self.method == "PCP":
743+
rpca = RPCA(**self.hyperparams_element)
744+
elif self.method == "temporal":
745+
rpca = TemporalRPCA(**self.hyperparams_element)
746+
elif self.method == "onlinetemporal":
747+
rpca = OnlineTemporalRPCA(**self.hyperparams_element)
748+
749+
df_imputed = pd.DataFrame(rpca.fit_transform(X=df.values), index=df.index, columns=df.columns)
750+
751+
return df_imputed
752+
753+
754+
class ImputeEM(Imputer):
755+
def __init__(
756+
self,
757+
groups: List[str]=[],
758+
method: Optional[str] = "multinormal",
759+
columnwise: bool=False,
760+
**hyperparams
761+
762+
):
763+
super().__init__(groups=groups, columnwise=columnwise, hyperparams=hyperparams)
764+
self.method = method
765+
# if method == "multinormal":
766+
# self.model = em_sampler.ImputeMultiNormalEM(
767+
# **hyperparams
768+
# )
769+
# elif method == "VAR1":
770+
# self.model = em_sampler.ImputeVAR1EM(
771+
# **hyperparams
772+
# )
773+
# else:
774+
# raise ValueError("Strategy '{strategy}' is not handled by ImputeEM!")
775+
776+
def fit_transform_element(self, df: pd.DataFrame) -> pd.DataFrame:
777+
if self.method == "multinormal":
778+
model = em_sampler.MultiNormalEM(
779+
**self.hyperparams_element
780+
)
781+
elif self.method == "VAR1":
782+
model = em_sampler.VAR1EM(
783+
**self.hyperparams_element
784+
)
785+
else:
786+
raise ValueError("Strategy '{strategy}' is not handled by ImputeEM!")
787+
X = df.values
788+
model.fit(X)
789+
790+
X_transformed = model.transform(X)
791+
df_transformed = pd.DataFrame(X_transformed, columns=df.columns, index=df.index)
792+
return df_transformed
793+
794+
# def fit(self, df):
795+
# X = df.values
796+
# self.model.fit(X)
797+
# return self
798+
799+
# def transform(self, df):
800+
# X = df.values
801+
# X_transformed = self.model.transform(X)
802+
# df_transformed = pd.DataFrame(X_transformed, columns=df.columns, index=df.index)
803+
# return df_transformed

qolmat/notebooks/benchmark.md

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -219,22 +219,29 @@ results = comparison.compare(df_data)
219219
results
220220
```
221221

222-
### **IV. Comparison of methods**
223-
224222
```python
225-
df
223+
fig = plt.figure(figsize=(24, 4))
224+
plot.multibar(results.loc["mae"])
225+
plt.show()
226226
```
227227

228+
### **IV. Comparison of methods**
229+
230+
228231
We now run just one time each algorithm on the initial corrupted dataframe and compare the different performances through multiple analysis.
229232

230233
```python
231-
dfs_imputed = {name: imp.fit_transform(df_data) for name, imp in dict_models.items()}
234+
df_plot = df_data[["TEMP", "PRES"]]
235+
```
236+
237+
```python
238+
dfs_imputed = {name: imp.fit_transform(df_plot) for name, imp in dict_models.items()}
232239
```
233240

234241
```python
235242
station = "Aotizhongxin"
236-
df_station = df_data.loc[station]
237-
dfs_imputed_station = {name: df.loc[station] for name, df in dfs_imputed.items()}
243+
df_station = df_plot.loc[station]
244+
dfs_imputed_station = {name: df_plot.loc[station] for name, df_plot in dfs_imputed.items()}
238245
```
239246

240247
Let's look at the imputations.
@@ -243,12 +250,37 @@ Note here we didn't fit the hyperparams of the RPCA... results might be of poor
243250

244251
```python
245252
# palette = sns.color_palette("icefire", n_colors=len(dict_models))
246-
#palette = sns.color_palette("husl", 8)
253+
# palette = sns.color_palette("husl", 8)
247254
# sns.set_palette(palette)
248-
markers = ["o", "s", "D", "+", "P", ">", "^", "d"]
249-
colors = ["tab:red", "tab:blue", "tab:blue"]
255+
# markers = ["o", "s", "D", "+", "P", ">", "^", "d"]
250256

257+
for col in cols_to_impute:
258+
fig, ax = plt.subplots(figsize=(10, 3))
259+
values_orig = df_station[col]
260+
261+
plt.plot(values_orig, ".", color='black', label="original")
262+
#plt.plot(df.iloc[870:1000][col], markers[0], color='k', linestyle='-' , ms=3)
251263

264+
for ind, (name, model) in enumerate(list(dict_models.items())):
265+
values_imp = dfs_imputed_station[name][col].copy()
266+
values_imp[values_orig.notna()] = np.nan
267+
plt.plot(values_imp, ".", color=tab10(ind), label=name, alpha=1)
268+
plt.ylabel(col, fontsize=16)
269+
plt.legend(loc=[1, 0], fontsize=18)
270+
loc = plticker.MultipleLocator(base=2*365)
271+
ax.xaxis.set_major_locator(loc)
272+
ax.tick_params(axis='both', which='major', labelsize=17)
273+
plt.show()
274+
275+
```
276+
277+
```python
278+
# palette = sns.color_palette("icefire", n_colors=len(dict_models))
279+
# palette = sns.color_palette("husl", 8)
280+
# sns.set_palette(palette)
281+
# markers = ["o", "s", "D", "+", "P", ">", "^", "d"]
282+
283+
fig = plt.figure(figsize=(
252284
for col in cols_to_impute:
253285
fig, ax = plt.subplots(figsize=(10, 3))
254286
values_orig = df_station[col]
@@ -265,8 +297,8 @@ for col in cols_to_impute:
265297
loc = plticker.MultipleLocator(base=2*365)
266298
ax.xaxis.set_major_locator(loc)
267299
ax.tick_params(axis='both', which='major', labelsize=17)
268-
sns.despine()
269300
plt.show()
301+
270302
```
271303

272304
**IV.a. Covariance**

0 commit comments

Comments
 (0)