@@ -71,7 +71,7 @@ def _conjugate_gradient(A: NDArray, X: NDArray, mask: NDArray) -> NDArray:
7171 return X_final
7272
7373
74- def min_diff_Linf (list_params : List [NDArray ], n_steps : int , order : int = 1 ) -> float :
74+ def max_diff_Linf (list_params : List [NDArray ], n_steps : int , order : int = 1 ) -> float :
7575 """Computes the maximal L infinity norm between the `n_steps` last elements spaced by order.
7676 Used to compute the stop criterion.
7777
@@ -762,8 +762,8 @@ def _check_convergence(self) -> bool:
762762 if n_iter < 3 :
763763 return False
764764
765- min_diff_means1 = min_diff_Linf ( list_covs , n_steps = 1 )
766- min_diff_covs1 = min_diff_Linf ( list_means , n_steps = 1 )
765+ min_diff_means1 = max_diff_Linf ( list_means , n_steps = 1 )
766+ min_diff_covs1 = max_diff_Linf ( list_covs , n_steps = 1 )
767767 min_diff_reached = min_diff_means1 < self .tolerance and min_diff_covs1 < self .tolerance
768768
769769 if min_diff_reached :
@@ -772,16 +772,16 @@ def _check_convergence(self) -> bool:
772772 if n_iter < 7 :
773773 return False
774774
775- min_diff_means5 = min_diff_Linf ( list_covs , n_steps = 5 )
776- min_diff_covs5 = min_diff_Linf ( list_means , n_steps = 5 )
775+ min_diff_means5 = max_diff_Linf ( list_means , n_steps = 5 )
776+ min_diff_covs5 = max_diff_Linf ( list_covs , n_steps = 5 )
777777
778778 min_diff_stable = (
779779 min_diff_means5 < self .stagnation_threshold
780780 and min_diff_covs5 < self .stagnation_threshold
781781 )
782782
783- min_diff_loglik5_ord1 = min_diff_Linf (list_logliks , n_steps = 5 )
784- min_diff_loglik5_ord2 = min_diff_Linf (list_logliks , n_steps = 5 , order = 2 )
783+ min_diff_loglik5_ord1 = max_diff_Linf (list_logliks , n_steps = 5 )
784+ min_diff_loglik5_ord2 = max_diff_Linf (list_logliks , n_steps = 5 , order = 2 )
785785 max_loglik = (min_diff_loglik5_ord1 < self .stagnation_loglik ) or (
786786 min_diff_loglik5_ord2 < self .stagnation_loglik
787787 )
@@ -1105,8 +1105,8 @@ def _check_convergence(self) -> bool:
11051105 if n_iter < 3 :
11061106 return False
11071107
1108- min_diff_B1 = min_diff_Linf (list_B , n_steps = 1 )
1109- min_diff_S1 = min_diff_Linf (list_S , n_steps = 1 )
1108+ min_diff_B1 = max_diff_Linf (list_B , n_steps = 1 )
1109+ min_diff_S1 = max_diff_Linf (list_S , n_steps = 1 )
11101110 min_diff_reached = min_diff_B1 < self .tolerance and min_diff_S1 < self .tolerance
11111111
11121112 if min_diff_reached :
@@ -1115,14 +1115,14 @@ def _check_convergence(self) -> bool:
11151115 if n_iter < 7 :
11161116 return False
11171117
1118- min_diff_B5 = min_diff_Linf (list_B , n_steps = 5 )
1119- min_diff_S5 = min_diff_Linf (list_S , n_steps = 5 )
1118+ min_diff_B5 = max_diff_Linf (list_B , n_steps = 5 )
1119+ min_diff_S5 = max_diff_Linf (list_S , n_steps = 5 )
11201120 min_diff_stable = (
11211121 min_diff_B5 < self .stagnation_threshold and min_diff_S5 < self .stagnation_threshold
11221122 )
11231123
1124- max_loglik5_ord1 = min_diff_Linf (list_logliks , n_steps = 5 , order = 1 )
1125- max_loglik5_ord2 = min_diff_Linf (list_logliks , n_steps = 5 , order = 2 )
1124+ max_loglik5_ord1 = max_diff_Linf (list_logliks , n_steps = 5 , order = 1 )
1125+ max_loglik5_ord2 = max_diff_Linf (list_logliks , n_steps = 5 , order = 2 )
11261126 max_loglik = (max_loglik5_ord1 < self .stagnation_loglik ) or (
11271127 max_loglik5_ord2 < self .stagnation_loglik
11281128 )
0 commit comments