Skip to content

Commit 4362c2c

Browse files
committed
mv _prox_vec to utils
1 parent a24ed9c commit 4362c2c

File tree

2 files changed

+31
-10
lines changed

2 files changed

+31
-10
lines changed

skglm/solvers/fista.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,8 @@
11
import numpy as np
22
from scipy.sparse import issparse
3-
from numba import njit
43
from skglm.solvers.base import BaseSolver
54
from skglm.solvers.common import construct_grad, construct_grad_sparse
6-
7-
8-
@njit
9-
def _prox_vec(w, z, penalty, lipschitz):
10-
n_features = w.shape[0]
11-
for j in range(n_features):
12-
w[j] = penalty.prox_1d(z[j], 1 / lipschitz, j)
13-
return w
5+
from skglm.utils import prox_vec
146

157

168
class FISTA(BaseSolver):
@@ -71,7 +63,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
7163
else:
7264
grad = construct_grad(X, y, z, X @ z, datafit, all_features)
7365
z -= grad / lipschitz
74-
w = _prox_vec(w, z, penalty, lipschitz)
66+
w = prox_vec(w, z, penalty, lipschitz)
7567
Xw = X @ w
7668
z = w + (t_old - 1.) / t_new * (w - w_old)
7769

skglm/utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,3 +457,32 @@ def extrapolate(self, w, Xw):
457457
C = inv_UTU_ones / np.sum(inv_UTU_ones)
458458
# floating point errors may cause w and Xw to disagree
459459
return self.arr_w_[:, 1:] @ C, self.arr_Xw_[:, 1:] @ C, True
460+
461+
462+
@njit
463+
def prox_vec(w, z, penalty, lipschitz):
464+
"""Evaluate the vectorized proximal operator for the FISTA solver.
465+
466+
Parameters
467+
----------
468+
w : array, shape (n_features,)
469+
Coefficient vector.
470+
471+
z : array, shape (n_features,)
472+
FISTA auxiliary variable.
473+
474+
penalty : instance of Penalty.
475+
Penalty object.
476+
477+
lipschitz : float
478+
Global Lipschitz constant.
479+
480+
Returns
481+
-------
482+
w : array; shape (n_features,)
483+
Updated coefficient vector.
484+
"""
485+
n_features = w.shape[0]
486+
for j in range(n_features):
487+
w[j] = penalty.prox_1d(z[j], 1 / lipschitz, j)
488+
return w

0 commit comments

Comments
 (0)