Skip to content

Commit edaf2d3

Browse files
Merge pull request #111 from Quantmetry/add_test_varpem_nonregression
New test added to check docstring non-reproducibility
2 parents 3c50162 + 98c3c30 commit edaf2d3

File tree

3 files changed

+28
-4
lines changed

3 files changed

+28
-4
lines changed

qolmat/imputations/em_sampler.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -326,10 +326,16 @@ def fit(self, X: NDArray) -> Self:
326326
self.p = p
327327
self.fit_X(X)
328328
n1, n2 = self.X.shape
329-
aic = np.log(np.linalg.det(self.S)) + 2 * p * (n2**2) / n1
329+
det = np.linalg.det(self.S)
330+
if abs(det) < 1e-12:
331+
aic = -np.inf
332+
else:
333+
aic = np.log(det) + 2 * p * (n2**2) / n1
330334
if len(aics) > 0 and aic > aics[-1]:
331335
break
332336
aics.append(aic)
337+
if aic == -np.inf:
338+
break
333339
self.p = int(np.argmin(aics))
334340
self.fit_X(X)
335341

@@ -352,15 +358,15 @@ def transform(self, X: NDArray) -> NDArray:
352358
NDArray
353359
Final array after EM sampling.
354360
"""
361+
mask_na = np.isnan(X)
362+
355363
# shape_original = X.shape
356364
if hash(X.tobytes()) == self.hash_fit:
357365
X = self.X
358366
else:
359367
X = utils.prepare_data(X, self.period)
360368
X = utils.linear_interpolation(X)
361369

362-
mask_na = np.isnan(X)
363-
364370
if self.method == "mle":
365371
X_transformed = self._maximize_likelihood(X, mask_na)
366372
elif self.method == "sample":
@@ -842,6 +848,7 @@ def combine_parameters(self) -> None:
842848
stack_YY = np.stack(list_YY)
843849
self.YY = np.mean(stack_YY, axis=0)
844850
self.S = self.YY - self.ZY.T @ self.B - self.B.T @ self.ZY + self.B.T @ self.ZZ @ self.B
851+
self.S[self.S < 1e-12] = 0
845852
self.S_inv = np.linalg.pinv(self.S, rcond=1e-10)
846853

847854
def _check_convergence(self) -> bool:

tests/imputations/test_em_sampler.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,22 @@ def test_parameters_after_imputation_varpem(p: int):
306306
np.testing.assert_allclose(em.S, S, rtol=1e-1, atol=1e-1)
307307

308308

309+
def test_varpem_fit_transform():
310+
imputer = em_sampler.VARpEM(method="sample", random_state=11)
311+
X = np.array([[1, 1, 1, 1], [np.nan, np.nan, 3, 2], [1, 2, 2, 1], [2, 2, 2, 2]])
312+
result = imputer.fit_transform(X)
313+
expected = np.array(
314+
[
315+
[1.0, 1.0, 1.0, 1.0],
316+
[1.0, 1.5, 3.0, 2.0],
317+
[1.0, 2.0, 2.0, 1.0],
318+
[2.0, 2.0, 2.0, 2.0],
319+
]
320+
)
321+
np.testing.assert_allclose(result, expected, atol=1e-12)
322+
# assert False
323+
324+
309325
@pytest.mark.parametrize(
310326
"X, em, p",
311327
[(X_first_guess, em_sampler.MultiNormalEM(), 0), (X_first_guess, em_sampler.VARpEM(p=2), 2)],

tests/imputations/test_imputers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,9 +297,10 @@ def test_ImputerEM_fit_transform(df: pd.DataFrame) -> None:
297297
expected = pd.DataFrame(
298298
{
299299
"col1": [i for i in range(20)],
300-
"col2": [0, 0.773, 2, 2.621, 2] + [i for i in range(5, 20)],
300+
"col2": [0, 0.638, 2, 2.714, 2] + [i for i in range(5, 20)],
301301
}
302302
)
303+
print(result)
303304
np.testing.assert_allclose(result, expected, atol=1e-2)
304305

305306

0 commit comments

Comments
 (0)