@@ -326,10 +326,16 @@ def fit(self, X: NDArray) -> Self:
326326 self .p = p
327327 self .fit_X (X )
328328 n1 , n2 = self .X .shape
329- aic = np .log (np .linalg .det (self .S )) + 2 * p * (n2 ** 2 ) / n1
329+ det = np .linalg .det (self .S )
330+ if abs (det ) < 1e-12 :
331+ aic = - np .inf
332+ else :
333+ aic = np .log (det ) + 2 * p * (n2 ** 2 ) / n1
330334 if len (aics ) > 0 and aic > aics [- 1 ]:
331335 break
332336 aics .append (aic )
337+ if aic == - np .inf :
338+ break
333339 self .p = int (np .argmin (aics ))
334340 self .fit_X (X )
335341
@@ -352,15 +358,15 @@ def transform(self, X: NDArray) -> NDArray:
352358 NDArray
353359 Final array after EM sampling.
354360 """
361+ mask_na = np .isnan (X )
362+
355363 # shape_original = X.shape
356364 if hash (X .tobytes ()) == self .hash_fit :
357365 X = self .X
358366 else :
359367 X = utils .prepare_data (X , self .period )
360368 X = utils .linear_interpolation (X )
361369
362- mask_na = np .isnan (X )
363-
364370 if self .method == "mle" :
365371 X_transformed = self ._maximize_likelihood (X , mask_na )
366372 elif self .method == "sample" :
@@ -664,14 +670,19 @@ class VARpEM(EM):
664670 Examples
665671 --------
666672 >>> import numpy as np
667- >>> import pandas as pd
668673 >>> from qolmat.imputations.em_sampler import VARpEM
669- >>> imputer = VARpEM(method="sample")
670- >>> X = pd.DataFrame(data=[[1, 1, 1, 1],
671- >>> [np.nan, np.nan, 3, 2],
672- >>> [1, 2, 2, 1], [2, 2, 2, 2]],
673- >>> columns=["var1", "var2", "var3", "var4"])
674+ >>> imputer = VARpEM(method="sample", random_state=11)
675+ >>> X = np.array([[1, 1, 1, 1],
676+ ... [np.nan, np.nan, 3, 2],
677+ ... [1, 2, 2, 1], [2, 2, 2, 2]])
674678 >>> imputer.fit_transform(X)
679+ EM converged after 9 iterations.
680+ EM converged after 20 iterations.
681+ EM converged after 13 iterations.
682+ array([[1. , 1. , 1. , 1. ],
683+ [1.17054054, 1.49986137, 3. , 2. ],
684+ [1. , 2. , 2. , 1. ],
685+ [2. , 2. , 2. , 2. ]])
675686 """
676687
677688 def __init__ (
@@ -837,6 +848,7 @@ def combine_parameters(self) -> None:
837848 stack_YY = np .stack (list_YY )
838849 self .YY = np .mean (stack_YY , axis = 0 )
839850 self .S = self .YY - self .ZY .T @ self .B - self .B .T @ self .ZY + self .B .T @ self .ZZ @ self .B
851+ self .S [self .S < 1e-12 ] = 0
840852 self .S_inv = np .linalg .pinv (self .S , rcond = 1e-10 )
841853
842854 def _check_convergence (self ) -> bool :
0 commit comments