|
1 | | -from typing import List |
2 | | - |
| 1 | +from typing import List, Literal |
3 | 2 | import numpy as np |
4 | 3 | import pytest |
5 | 4 | from numpy.typing import NDArray |
@@ -279,6 +278,31 @@ def test_mean_covariance_multinormalem(): |
279 | 278 | np.testing.assert_allclose(covariance_imputed, covariance, rtol=1e-1, atol=1e-1) |
280 | 279 |
|
281 | 280 |
|
| 281 | +def test_multinormal_em_minimize_llik(): |
| 282 | + X, X_missing, mean, covariance = generate_multinormal_predefined_mean_cov(d=2, n=1000) |
| 283 | + imputer = em_sampler.MultiNormalEM(method="mle", random_state=11) |
| 284 | + X_imputed = imputer.fit_transform(X_missing) |
| 285 | + llikelihood_imputed = imputer.get_loglikelihood(X_imputed) |
| 286 | + for _ in range(10): |
| 287 | + Delta = imputer.rng.uniform(0, 1, size=X.shape) |
| 288 | + X_perturbated = X_imputed + Delta |
| 289 | + llikelihood_perturbated = imputer.get_loglikelihood(X_perturbated) |
| 290 | + assert llikelihood_perturbated < llikelihood_imputed |
| 291 | + X_perturbated = X |
| 292 | + X_perturbated[np.isnan(X)] = 0 |
| 293 | + llikelihood_perturbated = imputer.get_loglikelihood(X_perturbated) |
| 294 | + assert llikelihood_perturbated < llikelihood_imputed |
| 295 | + |
| 296 | + |
| 297 | +@pytest.mark.parametrize("method", ["sample", "mle"]) |
| 298 | +def test_multinormal_em_fit_transform(method: Literal["mle", "sample"]): |
| 299 | + imputer = em_sampler.MultiNormalEM(method=method, random_state=11) |
| 300 | + X = np.array([[1, 1, 1, 1], [np.nan, np.nan, 3, 2], [1, 2, 2, 1], [2, 2, 2, 2]]) |
| 301 | + result = imputer.fit_transform(X) |
| 302 | + assert result.shape == X.shape |
| 303 | + np.testing.assert_allclose(result[~np.isnan(X)], X[~np.isnan(X)]) |
| 304 | + |
| 305 | + |
282 | 306 | @pytest.mark.parametrize( |
283 | 307 | "p", |
284 | 308 | [1], |
@@ -319,7 +343,6 @@ def test_varpem_fit_transform(): |
319 | 343 | ] |
320 | 344 | ) |
321 | 345 | np.testing.assert_allclose(result, expected, atol=1e-12) |
322 | | - # assert False |
323 | 346 |
|
324 | 347 |
|
325 | 348 | @pytest.mark.parametrize( |
|
0 commit comments