Skip to content

Commit 17ebe83

Browse files
author
vm-aifluence-jro
committed
ImputeEM implemented in models.py, and TS MLE version added
1 parent f569cc1 commit 17ebe83

File tree

4 files changed

+137
-89
lines changed

4 files changed

+137
-89
lines changed

qolmat/benchmark/missing_patterns.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ def get_sizes_max(values_isna: pd.Series) -> pd.Series:
4444

4545
class _HoleGenerator:
4646
"""
47-
This abstract class implements the generic method to generate masks according to law of missing values.
47+
This abstract class implements the generic method to generate masks according to law of missing
48+
values.
4849
4950
Parameters
5051
----------
@@ -192,7 +193,8 @@ def generate_mask(self, X: pd.DataFrame) -> pd.DataFrame:
192193

193194

194195
class _SamplerHoleGenerator(_HoleGenerator):
195-
"""This abstract class implements a generic way to generate holes in a dataframe by sampling 1D hole size distributions.
196+
"""This abstract class implements a generic way to generate holes in a dataframe by sampling 1D
197+
hole size distributions.
196198
197199
Parameters
198200
----------

qolmat/imputations/em_sampler.py

Lines changed: 75 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def __init__(
148148
self.convergence_threshold = tolerance
149149
self.stagnation_threshold = stagnation_threshold
150150
self.stagnation_loglik = stagnation_loglik
151+
self.scaler = StandardScaler()
151152

152153
self.dict_criteria_stop = {}
153154

@@ -200,109 +201,79 @@ def _convert_numpy(self, X: ArrayLike) -> np.ndarray:
200201
def _check_convergence(self) -> bool:
201202
return False
202203

203-
def _maximize_likelihood(self, X: ArrayLike) -> ArrayLike:
204+
def fit(self, X: np.array):
204205
"""
205-
Get the argmax of a posterior distribution.
206+
Fit the statistical distribution with the input X array.
206207
207208
Parameters
208209
----------
209-
X : ArrayLike
210-
Input DataFrame.
211-
212-
Returns
213-
-------
214-
ArrayLike
215-
DataFrame with imputed values.
210+
X : np.array
211+
Numpy array to be imputed
216212
"""
217-
X_center = X - self.means[:, None]
218-
X_imputed = _gradient_conjugue(self.cov_inv, X_center)
219-
X_imputed = self.means[:, None] + X_imputed
220-
return X_imputed
221-
222-
def impute_em(self, X: ArrayLike) -> ArrayLike:
223-
"""Imputation via EM algorithm
224-
225-
Parameters
226-
----------
227-
X : ArrayLike
228-
array with missing values
213+
X = X.copy()
214+
self.hash_fit = hash(X.tobytes())
215+
if not isinstance(X, np.ndarray):
216+
raise AssertionError("Invalid type. X must be a np.ndarray.")
229217

230-
Returns
231-
-------
232-
X_transformed
233-
imputed array
234-
"""
218+
if X.shape[0] < 2:
219+
raise AssertionError("Invalid dimensions: X must be of dimension (n,m) with m>1.")
235220

236-
X_ = self._convert_numpy(X)
237-
if np.nansum(X_) == 0:
238-
return X_
221+
X = self.scaler.fit_transform(X)
222+
X = X.T
239223

240224
mask_na = np.isnan(X)
241225

242226
# first imputation
243-
X_transformed = self._linear_interpolation(X_)
227+
X_sample_last = self._linear_interpolation(X)
244228

245-
self.fit_distribution(X_transformed)
229+
self.fit_distribution(X_sample_last)
246230

247231
for iter_em in range(self.max_iter_em):
248232

249-
X_transformed = self._sample_ou(X_transformed, mask_na)
233+
X_sample_last = self._sample_ou(X_sample_last, mask_na)
250234

251235
if self._check_convergence():
252236
logger.info(f"EM converged after {iter_em} iterations.")
253237
break
254238

255-
if self.strategy == "mle":
256-
X_transformed = self._maximize_likelihood(X_)
257-
elif self.strategy == "ou":
258-
X_transformed = self._sample_ou(X_transformed, mask_na)
259-
260239
self.dict_criteria_stop = {key: [] for key in self.dict_criteria_stop}
240+
self.X_sample_last = X_sample_last
241+
return self
261242

262-
if np.all(np.isnan(X_transformed)):
263-
raise WarningMessage("Result contains NaN. This is a bug.")
264-
265-
return X_transformed
266-
267-
def fit_transform(self, df: pd.DataFrame) -> pd.DataFrame:
243+
def transform(self, X: np.array) -> np.array:
268244
"""
269-
Fit and impute input X array.
245+
Transform the input X array by imputing the missing values.
270246
271247
Parameters
272248
----------
273-
X : pd.DataFrame
274-
DataFrame to be imputed
249+
X : np.array
250+
Numpy array to be imputed
275251
276252
Returns
277253
-------
278254
ArrayLike
279255
Final array after EM sampling.
280256
"""
281-
if not ((isinstance(df, np.ndarray)) or (isinstance(df, pd.DataFrame))):
282-
raise AssertionError("Invalid type. X must be either pd.DataFrame or np.ndarray.")
283-
284-
if df.shape[1] < 2:
285-
raise AssertionError("Invalid dimensions: X must be of dimension (n,m) with m>1.")
286257

287-
X = df.values
258+
if hash(X.tobytes()) == self.hash_fit:
259+
X = self.X_sample_last
260+
else:
261+
X = self.scaler.transform(X)
262+
X = X.T
263+
X = self._linear_interpolation(X)
288264

289-
scaler = StandardScaler()
290-
X = scaler.fit_transform(X)
291-
X = X.T
292-
X = self.impute_em(X)
293-
X = X.T
294-
X = scaler.inverse_transform(X)
265+
if self.strategy == "mle":
266+
X_transformed = self._maximize_likelihood(X)
267+
elif self.strategy == "ou":
268+
mask_na = np.isnan(X)
269+
X_transformed = self._sample_ou(X, mask_na)
295270

296-
if np.isnan(np.sum(X)):
271+
if np.all(np.isnan(X_transformed)):
297272
raise WarningMessage("Result contains NaN. This is a bug.")
298273

299-
if isinstance(df, np.ndarray):
300-
return X
301-
elif isinstance(df, pd.DataFrame):
302-
return pd.DataFrame(X, index=df.index, columns=df.columns)
303-
304-
else:
305-
raise AssertionError("Invalid type. X must be either pd.DataFrame or np.ndarray.")
274+
X_transformed = X_transformed.T
275+
X_transformed = self.scaler.inverse_transform(X_transformed)
276+
return X_transformed
306277

307278

308279
class ImputeMultiNormalEM(ImputeEM): # type: ignore
@@ -372,18 +343,32 @@ def __init__(
372343
)
373344
self.tolerance = tolerance
374345

375-
# self.list_logliks = []
376-
# self.list_means = []
377-
# self.list_covs = []
378346
self.dict_criteria_stop = {"logliks": [], "means": [], "covs": []}
379347

380348
def fit_distribution(self, X):
381-
# first estimation of params
382349
self.means = np.mean(X, axis=1)
383350
self.cov = np.cov(X)
384-
385351
self.cov_inv = invert_robust(self.cov, epsilon=1e-2)
386352

353+
def _maximize_likelihood(self, X: ArrayLike) -> ArrayLike:
    """
    Get the argmax of a posterior distribution.

    The input is centered with the fitted means, handed to ``_gradient_conjugue`` together with
    the fitted inverse covariance, and the means are added back to the result.

    Parameters
    ----------
    X : ArrayLike
        Input array. One entry of ``self.means`` is subtracted from each row.
        NOTE(review): how missing entries are expected to be encoded is decided by
        ``_gradient_conjugue`` -- confirm against that helper.

    Returns
    -------
    ArrayLike
        Array with imputed values, expressed on the same (non-centered) scale as ``X``.
    """
    # Column vector of means so that it broadcasts along the second axis of X.
    offsets = self.means[:, None]
    return offsets + _gradient_conjugue(self.cov_inv, X - offsets)
371+
387372
def _sample_ou(
388373
self,
389374
X: ArrayLike,
@@ -465,10 +450,6 @@ def _check_convergence(self) -> bool:
465450
True/False if the algorithm has converged
466451
"""
467452

468-
# self.list_means.append(self.means)
469-
# self.list_covs.append(self.cov)
470-
# self.list_logliks.append(self.loglik)
471-
472453
list_means = self.dict_criteria_stop["means"]
473454
list_covs = self.dict_criteria_stop["covs"]
474455
list_logliks = self.dict_criteria_stop["logliks"]
@@ -602,11 +583,6 @@ def fit_distribution(self, X):
602583
self.fit_parameter_A(X)
603584
self.fit_parameter_omega(X)
604585

605-
# print("distribution fitted :")
606-
# print(self.A)
607-
# print(self.B)
608-
# print(self.omega)
609-
610586
def gradient_X_centered_loglik(self, Xc):
611587
Xc_back = np.roll(Xc, 1, axis=1)
612588
Xc_back[:, 0] = 0
@@ -616,6 +592,25 @@ def gradient_X_centered_loglik(self, Xc):
616592
Z_fore = Xc_fore - self.A @ Xc
617593
return -self.omega_inv @ Z_back + self.A.T @ self.omega_inv @ Z_fore
618594

595+
def _maximize_likelihood(
    self, X: ArrayLike, dt: float = 1e-2, n_iter: int = 1000
) -> ArrayLike:
    """
    Get the argmax of a posterior distribution by fixed-step gradient ascent.

    The array is centered with ``self.B``, then repeatedly moved along
    ``self.gradient_X_centered_loglik`` and finally shifted back by ``self.B``.

    Parameters
    ----------
    X : ArrayLike
        Input numpy array. One entry of ``self.B`` is subtracted from each row.
        NOTE(review): presumably X must already be free of NaN (e.g. interpolated by the
        caller), since every entry takes part in the gradient step -- confirm.
    dt : float, optional
        Step size applied to the gradient at each iteration, by default 1e-2.
    n_iter : int, optional
        Number of gradient steps, by default 1000 (the value that was previously hard-coded).

    Returns
    -------
    ArrayLike
        Numpy array obtained after ``n_iter`` gradient steps, on the same (non-centered)
        scale as ``X``.
    """
    # The subtraction allocates a new array, so the in-place updates below never touch the
    # caller's X.
    Xc = X - self.B[:, None]
    for _ in range(n_iter):
        Xc += dt * self.gradient_X_centered_loglik(Xc)
    return Xc + self.B[:, None]
613+
619614
def _sample_ou(
620615
self,
621616
X: ArrayLike,

qolmat/imputations/models.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import missingpy
2020

2121
from qolmat.benchmark import utils
22+
from qolmat.imputations import em_sampler
2223
from qolmat.imputations.rpca.pcp_rpca import RPCA
2324
from qolmat.imputations.rpca.temporal_rpca import OnlineTemporalRPCA, TemporalRPCA
2425

@@ -561,6 +562,59 @@ def get_hyperparams(self) -> Dict[str, Union[str, float, int]]:
561562
}
562563

563564

565+
class ImputeEM(_BaseImputer):
    """
    This class wraps the EM imputers of ``qolmat.imputations.em_sampler`` so that they can be
    fitted and applied on DataFrames.

    Parameters
    ----------
    strategy : Optional[str], optional
        Forwarded to the underlying model, by default "mle".
    method : Optional[str], optional
        Model to build: "multinormal" selects ``em_sampler.ImputeMultiNormalEM`` and "VAR1"
        selects ``em_sampler.ImputeVAR1EM``, by default "multinormal".
    max_iter_em : Optional[int], optional
        Forwarded to the underlying model, by default 200.
    n_iter_ou : Optional[int], optional
        Forwarded to the underlying model, by default 50.
    ampli : Optional[int], optional
        Forwarded to the underlying model, by default 1.
    random_state : Optional[int], optional
        Forwarded to the underlying model, by default 123.
    dt : Optional[float], optional
        Forwarded to the underlying model, by default 2e-2.
    tolerance : Optional[float], optional
        Forwarded to the underlying model, by default 1e-4.
    stagnation_threshold : Optional[float], optional
        Forwarded to the underlying model, by default 5e-3.
    stagnation_loglik : Optional[float], optional
        Forwarded to the underlying model, by default 2.

    Raises
    ------
    ValueError
        If ``method`` is neither "multinormal" nor "VAR1".
    """

    def __init__(
        self,
        strategy: Optional[str] = "mle",
        method: Optional[str] = "multinormal",
        max_iter_em: Optional[int] = 200,
        n_iter_ou: Optional[int] = 50,
        ampli: Optional[int] = 1,
        random_state: Optional[int] = 123,
        dt: Optional[float] = 2e-2,
        tolerance: Optional[float] = 1e-4,
        stagnation_threshold: Optional[float] = 5e-3,
        stagnation_loglik: Optional[float] = 2,
    ):
        # Both models take exactly the same keyword arguments, so only the class differs.
        if method == "multinormal":
            model_class = em_sampler.ImputeMultiNormalEM
        elif method == "VAR1":
            model_class = em_sampler.ImputeVAR1EM
        else:
            # The rejected argument is `method`; the message is now actually interpolated.
            raise ValueError(f"Method '{method}' is not handled by ImputeEM!")

        self.model = model_class(
            strategy=strategy,
            max_iter_em=max_iter_em,
            n_iter_ou=n_iter_ou,
            ampli=ampli,
            random_state=random_state,
            dt=dt,
            tolerance=tolerance,
            stagnation_threshold=stagnation_threshold,
            stagnation_loglik=stagnation_loglik,
        )

    def fit(self, df: pd.DataFrame) -> "ImputeEM":
        """
        Fit the underlying EM model on the values of the input DataFrame.

        Parameters
        ----------
        df : pd.DataFrame
            DataFrame whose ``values`` are passed to the model's ``fit``.

        Returns
        -------
        ImputeEM
            The fitted imputer itself.
        """
        X = df.values
        self.model.fit(X)
        return self

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Impute the input DataFrame with the underlying EM model.

        Parameters
        ----------
        df : pd.DataFrame
            DataFrame whose ``values`` are passed to the model's ``transform``.

        Returns
        -------
        pd.DataFrame
            Output of the model's ``transform``, with the index and columns of ``df``.
        """
        X = df.values
        X_transformed = self.model.transform(X)
        df_transformed = pd.DataFrame(X_transformed, columns=df.columns, index=df.index)
        return df_transformed
616+
617+
564618
class ImputeMICE(_BaseImputer):
565619
"""
566620
This class implements an iterative imputer in the multivariate case.

qolmat/notebooks/benchmark.md

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ from qolmat.benchmark import comparator, missing_patterns
5353
from qolmat.benchmark.utils import kl_divergence
5454
from qolmat.imputations import models
5555
from qolmat.utils import data, utils, plot
56-
from qolmat.imputations.em_sampler import ImputeMultiNormalEM, ImputeVAR1EM
5756
# from qolmat.drawing import display_bar_table
5857

5958
```
@@ -133,8 +132,9 @@ imputer_residuals = models.ImputeOnResiduals("additive", 7, "freq", "linear")
133132
imputer_rpca = models.ImputeRPCA(
134133
method="temporal", multivariate=False, **{"n_rows":7*4, "maxIter":1000, "tau":1, "lam":0.7}
135134
)
136-
imputer_ou = ImputeMultiNormalEM(max_iter_em=34, n_iter_ou=15, verbose=0, strategy="ou")
137-
imputer_tsou = ImputeVAR1EM(max_iter_em=34, n_iter_ou=15, verbose=0, strategy="ou")
135+
imputer_ou = models.ImputeEM(method="multinormal", max_iter_em=34, n_iter_ou=15, strategy="ou")
136+
imputer_tsou = models.ImputeEM(method="VAR1", strategy="ou", max_iter_em=34, n_iter_ou=15)
137+
imputer_tsmle = models.ImputeEM(method="VAR1", strategy="mle", max_iter_em=34, n_iter_ou=15)
138138
imputer_locf = models.ImputeLOCF()
139139
imputer_nocb = models.ImputeNOCB()
140140
imputer_knn = models.ImputeKNN(k=10)
@@ -157,6 +157,7 @@ dict_models = {
157157
#"iterative": imputer_iterative,
158158
"OU": imputer_ou,
159159
"TSOU": imputer_tsou,
160+
"TSMLE": imputer_tsmle,
160161
#"RPCA": imputer_rpca,
161162
}
162163
n_models = len(dict_models)
@@ -227,10 +228,6 @@ Let's look at the imputations.
227228
When the data is missing at random, imputation is easier. Missing blocks are more challenging.
228229
Note here we didn't fit the hyperparams of the RPCA... results might be of poor quality...
229230

230-
```python
231-
plt.scatter(df_station["TEMP"], df_station["PRES"])
232-
```
233-
234231
```python
235232
palette = sns.color_palette("icefire", n_colors=len(dict_models))
236233
#palette = sns.color_palette("husl", 8)

0 commit comments

Comments
 (0)