Skip to content

Commit 94064d2

Browse files
Julien Roussel
authored and committed
gamma changed to X covariance, and tests added
1 parent 98c3c30 commit 94064d2

File tree

3 files changed

+45
-21
lines changed

3 files changed

+45
-21
lines changed

qolmat/imputations/em_sampler.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -268,14 +268,26 @@ def _sample_ou(
268268
self.reset_learned_parameters()
269269
X_init = X.copy()
270270
gamma = self.get_gamma()
271+
print("gamma:")
272+
print(gamma)
271273
sqrt_gamma = np.real(spl.sqrtm(gamma))
272-
for _ in range(self.n_iter_ou):
274+
for i in range(self.n_iter_ou):
275+
print(f"Iteration #{i}")
273276
noise = self.ampli * self.rng.normal(0, 1, size=(n_variables, n_samples))
274277
grad_X = self.gradient_X_loglik(X_copy)
278+
print("grad")
279+
print(self.dt * grad_X @ gamma)
280+
print("noise")
281+
print(np.sqrt(2 * self.dt) * noise @ sqrt_gamma)
275282
X_copy += self.dt * grad_X @ gamma + np.sqrt(2 * self.dt) * noise @ sqrt_gamma
276283
X_copy[~mask_na] = X_init[~mask_na]
277284
if estimate_params:
278285
self.update_parameters(X_copy)
286+
print("X_copy")
287+
print(X_copy)
288+
if np.sum(np.abs(X_copy)) > 1e9:
289+
raise AssertionError
290+
print()
279291

280292
return X_copy
281293

@@ -489,8 +501,10 @@ def get_gamma(self) -> NDArray:
489501
NDArray
490502
Gamma matrix
491503
"""
492-
gamma = np.diag(np.diagonal(self.cov))
493-
# gamma = self.cov
504+
print("get_gamma")
505+
print(self.cov)
506+
# gamma = np.diag(np.diagonal(self.cov))
507+
gamma = self.cov
494508
# gamma = np.eye(len(self.cov))
495509
return gamma
496510

@@ -571,9 +585,9 @@ def _maximize_likelihood(self, X: NDArray, mask_na: NDArray) -> NDArray:
571585
NDArray
572586
DataFrame with imputed values.
573587
"""
574-
X_center = X - self.means[:, None]
588+
X_center = X - self.means
575589
X_imputed = _conjugate_gradient(self.cov_inv, X_center, mask_na)
576-
X_imputed = self.means[:, None] + X_imputed
590+
X_imputed = self.means + X_imputed
577591
return X_imputed
578592

579593
def _check_convergence(self) -> bool:

tests/imputations/test_em_sampler.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List
1+
from typing import List, Literal
22

33
import numpy as np
44
import pytest
@@ -279,6 +279,31 @@ def test_mean_covariance_multinormalem():
279279
np.testing.assert_allclose(covariance_imputed, covariance, rtol=1e-1, atol=1e-1)
280280

281281

282+
def test_multinormal_em_minimize_llik():
283+
X, X_missing, mean, covariance = generate_multinormal_predefined_mean_cov(d=2, n=1000)
284+
imputer = em_sampler.MultiNormalEM(method="mle", random_state=11)
285+
X_imputed = imputer.fit_transform(X_missing)
286+
llikelihood_imputed = imputer.get_loglikelihood(X_imputed)
287+
for _ in range(10):
288+
Delta = imputer.rng.uniform(0, 1, size=X.shape)
289+
X_perturbated = X_imputed + Delta
290+
llikelihood_perturbated = imputer.get_loglikelihood(X_perturbated)
291+
assert llikelihood_perturbated < llikelihood_imputed
292+
X_perturbated = X
293+
X_perturbated[np.isnan(X)] = 0
294+
llikelihood_perturbated = imputer.get_loglikelihood(X_perturbated)
295+
assert llikelihood_perturbated < llikelihood_imputed
296+
297+
298+
@pytest.mark.parametrize("method", ["sample", "mle"])
299+
def test_multinormal_em_fit_transform(method: Literal["mle", "sample"]):
300+
imputer = em_sampler.MultiNormalEM(method=method, random_state=11)
301+
X = np.array([[1, 1, 1, 1], [np.nan, np.nan, 3, 2], [1, 2, 2, 1], [2, 2, 2, 2]])
302+
result = imputer.fit_transform(X)
303+
assert result.shape == X.shape
304+
np.testing.assert_allclose(result[~np.isnan(X)], X[~np.isnan(X)])
305+
306+
282307
@pytest.mark.parametrize(
283308
"p",
284309
[1],
@@ -319,7 +344,6 @@ def test_varpem_fit_transform():
319344
]
320345
)
321346
np.testing.assert_allclose(result, expected, atol=1e-12)
322-
# assert False
323347

324348

325349
@pytest.mark.parametrize(

tests/imputations/test_imputers.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -290,20 +290,6 @@ def test_ImputerSoftImpute_fit_transform(df: pd.DataFrame) -> None:
290290
np.testing.assert_allclose(result, expected, atol=1e-2)
291291

292292

293-
@pytest.mark.parametrize("df", [df_timeseries])
294-
def test_ImputerEM_fit_transform(df: pd.DataFrame) -> None:
295-
imputer = imputers.ImputerEM(method="sample", dt=1e-3, random_state=42)
296-
result = imputer.fit_transform(df)
297-
expected = pd.DataFrame(
298-
{
299-
"col1": [i for i in range(20)],
300-
"col2": [0, 0.638, 2, 2.714, 2] + [i for i in range(5, 20)],
301-
}
302-
)
303-
print(result)
304-
np.testing.assert_allclose(result, expected, atol=1e-2)
305-
306-
307293
index_grouped = pd.MultiIndex.from_product([["a", "b"], range(4)], names=["group", "date"])
308294
dict_values = {"col1": [0, np.nan, 0, np.nan, 1, 1, 1, 1], "col2": [1, 1, 1, 1, 2, 2, 2, 2]}
309295
df_grouped = pd.DataFrame(dict_values, index=index_grouped)

0 commit comments

Comments
 (0)