Skip to content

Commit 0bb03f5

Browse files
committed
add convergence warning
1 parent bf71128 commit 0bb03f5

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

rehline/_class.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@
55

66
# License: MIT License
77

8+
import warnings
9+
810
import numpy as np
911
from sklearn.base import BaseEstimator
12+
from sklearn.exceptions import ConvergenceWarning
1013
from sklearn.utils.validation import (_check_sample_weight, check_array,
1114
check_is_fitted, check_X_y)
1215

@@ -167,6 +170,12 @@ def fit(self, X, sample_weight=None):
167170
self.dual_obj_ = result.dual_objfns
168171
self.primal_obj_ = result.primal_objfns
169172

173+
if self.n_iter_ >= self.max_iter:
174+
warnings.warn(
175+
"ReHLine failed to converge, increase the number of iterations: `max_iter`.",
176+
ConvergenceWarning,
177+
)
178+
170179
def decision_function(self, X):
171180
"""The decision function evaluated on the given dataset
172181
@@ -363,6 +372,12 @@ def fit(self, X, y, sample_weight=None):
363372
self.dual_obj_ = result.dual_objfns
364373
self.primal_obj_ = result.primal_objfns
365374

375+
if self.n_iter_ >= self.max_iter:
376+
warnings.warn(
377+
"ReHLine failed to converge, increase the number of iterations: `max_iter`.",
378+
ConvergenceWarning,
379+
)
380+
366381
def decision_function(self, X):
367382
"""The decision function evaluated on the given dataset
368383

0 commit comments

Comments
 (0)