Skip to content

Commit 5559c85

Browse files
Merge pull request #137 from scikit-learn-contrib/em_sampler_correction
actions/setup-python version patched
2 parents e59e1e0 + ca67e8b commit 5559c85

File tree

3 files changed

+20
-14
lines changed

3 files changed

+20
-14
lines changed

.github/workflows/publish.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
steps:
1414
- uses: actions/checkout@v4
1515
- name: Set up Python
16-
uses: actions/setup-python@v3.12.0
16+
uses: actions/setup-python@v4
1717
with:
1818
python-version: '3.10'
1919
- name: Install dependencies

HISTORY.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22
History
33
=======
44

5+
0.1.5 (2024-04-17)
6+
------------------
7+
8+
* CICD now relies on Node.js 20
9+
* New tests for comparator.py and data.py
10+
511
0.1.4 (2024-04-15)
612
------------------
713

qolmat/imputations/em_sampler.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def _conjugate_gradient(A: NDArray, X: NDArray, mask: NDArray) -> NDArray:
7171
return X_final
7272

7373

74-
def min_diff_Linf(list_params: List[NDArray], n_steps: int, order: int = 1) -> float:
74+
def max_diff_Linf(list_params: List[NDArray], n_steps: int, order: int = 1) -> float:
7575
"""Computes the maximal L infinity norm between the `n_steps` last elements spaced by order.
7676
Used to compute the stop criterion.
7777
@@ -762,8 +762,8 @@ def _check_convergence(self) -> bool:
762762
if n_iter < 3:
763763
return False
764764

765-
min_diff_means1 = min_diff_Linf(list_covs, n_steps=1)
766-
min_diff_covs1 = min_diff_Linf(list_means, n_steps=1)
765+
min_diff_means1 = max_diff_Linf(list_means, n_steps=1)
766+
min_diff_covs1 = max_diff_Linf(list_covs, n_steps=1)
767767
min_diff_reached = min_diff_means1 < self.tolerance and min_diff_covs1 < self.tolerance
768768

769769
if min_diff_reached:
@@ -772,16 +772,16 @@ def _check_convergence(self) -> bool:
772772
if n_iter < 7:
773773
return False
774774

775-
min_diff_means5 = min_diff_Linf(list_covs, n_steps=5)
776-
min_diff_covs5 = min_diff_Linf(list_means, n_steps=5)
775+
min_diff_means5 = max_diff_Linf(list_means, n_steps=5)
776+
min_diff_covs5 = max_diff_Linf(list_covs, n_steps=5)
777777

778778
min_diff_stable = (
779779
min_diff_means5 < self.stagnation_threshold
780780
and min_diff_covs5 < self.stagnation_threshold
781781
)
782782

783-
min_diff_loglik5_ord1 = min_diff_Linf(list_logliks, n_steps=5)
784-
min_diff_loglik5_ord2 = min_diff_Linf(list_logliks, n_steps=5, order=2)
783+
min_diff_loglik5_ord1 = max_diff_Linf(list_logliks, n_steps=5)
784+
min_diff_loglik5_ord2 = max_diff_Linf(list_logliks, n_steps=5, order=2)
785785
max_loglik = (min_diff_loglik5_ord1 < self.stagnation_loglik) or (
786786
min_diff_loglik5_ord2 < self.stagnation_loglik
787787
)
@@ -1105,8 +1105,8 @@ def _check_convergence(self) -> bool:
11051105
if n_iter < 3:
11061106
return False
11071107

1108-
min_diff_B1 = min_diff_Linf(list_B, n_steps=1)
1109-
min_diff_S1 = min_diff_Linf(list_S, n_steps=1)
1108+
min_diff_B1 = max_diff_Linf(list_B, n_steps=1)
1109+
min_diff_S1 = max_diff_Linf(list_S, n_steps=1)
11101110
min_diff_reached = min_diff_B1 < self.tolerance and min_diff_S1 < self.tolerance
11111111

11121112
if min_diff_reached:
@@ -1115,14 +1115,14 @@ def _check_convergence(self) -> bool:
11151115
if n_iter < 7:
11161116
return False
11171117

1118-
min_diff_B5 = min_diff_Linf(list_B, n_steps=5)
1119-
min_diff_S5 = min_diff_Linf(list_S, n_steps=5)
1118+
min_diff_B5 = max_diff_Linf(list_B, n_steps=5)
1119+
min_diff_S5 = max_diff_Linf(list_S, n_steps=5)
11201120
min_diff_stable = (
11211121
min_diff_B5 < self.stagnation_threshold and min_diff_S5 < self.stagnation_threshold
11221122
)
11231123

1124-
max_loglik5_ord1 = min_diff_Linf(list_logliks, n_steps=5, order=1)
1125-
max_loglik5_ord2 = min_diff_Linf(list_logliks, n_steps=5, order=2)
1124+
max_loglik5_ord1 = max_diff_Linf(list_logliks, n_steps=5, order=1)
1125+
max_loglik5_ord2 = max_diff_Linf(list_logliks, n_steps=5, order=2)
11261126
max_loglik = (max_loglik5_ord1 < self.stagnation_loglik) or (
11271127
max_loglik5_ord2 < self.stagnation_loglik
11281128
)

0 commit comments

Comments
 (0)