Skip to content

Commit 57e9f90

Browse files
Merge pull request #112 from Quantmetry/add_test_varpem_nonregression
Add test varpem nonregression
2 parents (12322e3 + 7d25988), commit 57e9f90

File tree

6 files changed

+31
-27
lines changed

6 files changed

+31
-27
lines changed

qolmat/benchmark/metrics.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -863,7 +863,6 @@ def kl_divergence_gaussian_exact(
863863
norm_M = (M**2).sum().sum()
864864
norm_y = (y**2).sum()
865865
term_diag_L = 2 * np.sum(np.log(np.diagonal(L2) / np.diagonal(L1)))
866-
print(norm_M, "-", n_variables, "+", norm_y, "+", term_diag_L)
867866
div_kl = 0.5 * (norm_M - n_variables + norm_y + term_diag_L)
868867
return div_kl
869868

qolmat/imputations/em_sampler.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -269,7 +269,7 @@ def _sample_ou(
269269
X_init = X.copy()
270270
gamma = self.get_gamma()
271271
sqrt_gamma = np.real(spl.sqrtm(gamma))
272-
for _ in range(self.n_iter_ou):
272+
for i in range(self.n_iter_ou):
273273
noise = self.ampli * self.rng.normal(0, 1, size=(n_variables, n_samples))
274274
grad_X = self.gradient_X_loglik(X_copy)
275275
X_copy += self.dt * grad_X @ gamma + np.sqrt(2 * self.dt) * noise @ sqrt_gamma
@@ -489,8 +489,8 @@ def get_gamma(self) -> NDArray:
489489
NDArray
490490
Gamma matrix
491491
"""
492-
gamma = np.diag(np.diagonal(self.cov))
493-
# gamma = self.cov
492+
# gamma = np.diag(np.diagonal(self.cov))
493+
gamma = self.cov
494494
# gamma = np.eye(len(self.cov))
495495
return gamma
496496

@@ -571,9 +571,9 @@ def _maximize_likelihood(self, X: NDArray, mask_na: NDArray) -> NDArray:
571571
NDArray
572572
DataFrame with imputed values.
573573
"""
574-
X_center = X - self.means[:, None]
574+
X_center = X - self.means
575575
X_imputed = _conjugate_gradient(self.cov_inv, X_center, mask_na)
576-
X_imputed = self.means[:, None] + X_imputed
576+
X_imputed = self.means + X_imputed
577577
return X_imputed
578578

579579
def _check_convergence(self) -> bool:

tests/imputations/test_em_sampler.py

Lines changed: 26 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,4 @@
1-
from typing import List
2-
1+
from typing import List, Literal
32
import numpy as np
43
import pytest
54
from numpy.typing import NDArray
@@ -279,6 +278,31 @@ def test_mean_covariance_multinormalem():
279278
np.testing.assert_allclose(covariance_imputed, covariance, rtol=1e-1, atol=1e-1)
280279

281280

281+
def test_multinormal_em_minimize_llik():
282+
X, X_missing, mean, covariance = generate_multinormal_predefined_mean_cov(d=2, n=1000)
283+
imputer = em_sampler.MultiNormalEM(method="mle", random_state=11)
284+
X_imputed = imputer.fit_transform(X_missing)
285+
llikelihood_imputed = imputer.get_loglikelihood(X_imputed)
286+
for _ in range(10):
287+
Delta = imputer.rng.uniform(0, 1, size=X.shape)
288+
X_perturbated = X_imputed + Delta
289+
llikelihood_perturbated = imputer.get_loglikelihood(X_perturbated)
290+
assert llikelihood_perturbated < llikelihood_imputed
291+
X_perturbated = X
292+
X_perturbated[np.isnan(X)] = 0
293+
llikelihood_perturbated = imputer.get_loglikelihood(X_perturbated)
294+
assert llikelihood_perturbated < llikelihood_imputed
295+
296+
297+
@pytest.mark.parametrize("method", ["sample", "mle"])
298+
def test_multinormal_em_fit_transform(method: Literal["mle", "sample"]):
299+
imputer = em_sampler.MultiNormalEM(method=method, random_state=11)
300+
X = np.array([[1, 1, 1, 1], [np.nan, np.nan, 3, 2], [1, 2, 2, 1], [2, 2, 2, 2]])
301+
result = imputer.fit_transform(X)
302+
assert result.shape == X.shape
303+
np.testing.assert_allclose(result[~np.isnan(X)], X[~np.isnan(X)])
304+
305+
282306
@pytest.mark.parametrize(
283307
"p",
284308
[1],
@@ -319,7 +343,6 @@ def test_varpem_fit_transform():
319343
]
320344
)
321345
np.testing.assert_allclose(result, expected, atol=1e-12)
322-
# assert False
323346

324347

325348
@pytest.mark.parametrize(

tests/imputations/test_imputers.py

Lines changed: 0 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -174,7 +174,6 @@ def test_ImputerShuffle_fit_transform1(df: pd.DataFrame) -> None:
174174
def test_ImputerShuffle_fit_transform2(df: pd.DataFrame) -> None:
175175
imputer = imputers.ImputerShuffle(random_state=42)
176176
result = imputer.fit_transform(df)
177-
print(result)
178177
expected = pd.DataFrame({"col1": [0, 3, 2, 3, 0], "col2": [-1, 1.5, 0.5, 1.5, 1.5]})
179178
np.testing.assert_allclose(result, expected)
180179

@@ -290,20 +289,6 @@ def test_ImputerSoftImpute_fit_transform(df: pd.DataFrame) -> None:
290289
np.testing.assert_allclose(result, expected, atol=1e-2)
291290

292291

293-
@pytest.mark.parametrize("df", [df_timeseries])
294-
def test_ImputerEM_fit_transform(df: pd.DataFrame) -> None:
295-
imputer = imputers.ImputerEM(method="sample", dt=1e-3, random_state=42)
296-
result = imputer.fit_transform(df)
297-
expected = pd.DataFrame(
298-
{
299-
"col1": [i for i in range(20)],
300-
"col2": [0, 0.638, 2, 2.714, 2] + [i for i in range(5, 20)],
301-
}
302-
)
303-
print(result)
304-
np.testing.assert_allclose(result, expected, atol=1e-2)
305-
306-
307292
index_grouped = pd.MultiIndex.from_product([["a", "b"], range(4)], names=["group", "date"])
308293
dict_values = {"col1": [0, np.nan, 0, np.nan, 1, 1, 1, 1], "col2": [1, 1, 1, 1, 2, 2, 2, 2]}
309294
df_grouped = pd.DataFrame(dict_values, index=index_grouped)

tests/imputations/test_imputers_pytorch.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,6 @@ def test_ImputerRegressorPyTorch_fit_transform(df: pd.DataFrame) -> None:
5454
"col5": [93, 75, 2.132, 12, 2.345],
5555
}
5656
)
57-
print(result["col5"])
5857
np.testing.assert_allclose(result, expected, atol=1e-3)
5958

6059

tests/utils/test_data.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -186,11 +186,9 @@ def test_utils_data_get_data(name_data: str, df: pd.DataFrame, mocker: MockerFix
186186
assert df_result.columns.tolist() == expected_columns
187187
elif name_data == "Monach_weather":
188188
assert mock_download.call_count == 1
189-
print(df_result)
190189
pd.testing.assert_frame_equal(df_result, df_monach_weather_preprocess)
191190
elif name_data == "Monach_electricity_australia":
192191
assert mock_download.call_count == 1
193-
print(df_result)
194192
pd.testing.assert_frame_equal(df_result, df_monach_elec_preprocess)
195193
else:
196194
assert False

0 commit comments

Comments
 (0)