Skip to content

Commit 7323e27

Browse files
authored
MNT - add warning and fix shape in ProxNewton solver (#158)
1 parent 8f88024 commit 7323e27

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

skglm/solvers/prox_newton.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1+
import warnings
2+
13
import numpy as np
24
from numba import njit
35
from scipy.sparse import issparse
46
from skglm.solvers.base import BaseSolver
57

8+
from sklearn.exceptions import ConvergenceWarning
9+
610

711
EPS_TOL = 0.3
812
MAX_CD_ITER = 20
@@ -158,6 +162,13 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
158162

159163
p_obj = datafit.value(y, w, Xw) + penalty.value(w)
160164
p_objs_out.append(p_obj)
165+
else:
166+
warnings.warn(
167+
f"`ProxNewton` did not converge for tol={self.tol:.3e} "
168+
f"and max_iter={self.max_iter}.\n"
169+
"Consider increasing `max_iter` and/or `tol`.",
170+
category=ConvergenceWarning
171+
)
161172
return w, np.asarray(p_objs_out), stop_crit
162173

163174

@@ -242,7 +253,7 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
242253

243254
# see _descent_direction() comment
244255
past_grads = np.zeros(len(ws))
245-
X_delta_w_ws = np.zeros(len(y))
256+
X_delta_w_ws = np.zeros(Xw_epoch.shape[0])
246257
ws_intercept = np.append(ws, -1) if fit_intercept else ws
247258
w_ws = w_epoch[ws_intercept]
248259

0 commit comments

Comments
 (0)