
Commit d5a7494

refactor common code (#36)
1 parent e3efa30 commit d5a7494

2 files changed: 115 additions, 112 deletions


skglm/solvers/cd_solver.py

Lines changed: 1 addition & 112 deletions
@@ -2,6 +2,7 @@
 from numba import njit
 from scipy import sparse
 from sklearn.utils import check_array
+from skglm.solvers.common import construct_grad, construct_grad_sparse, dist_fix_point
 
 
 def cd_solver_path(X, y, datafit, penalty, alphas=None,
@@ -347,118 +348,6 @@ def cd_solver(
     return w, np.array(obj_out), stop_crit
 
 
-@njit
-def dist_fix_point(w, grad, datafit, penalty, ws):
-    """Compute the violation of the fixed point iterate scheme.
-
-    Parameters
-    ----------
-    w : array, shape (n_features,)
-        Coefficient vector.
-
-    grad : array, shape (n_features,)
-        Gradient.
-
-    datafit : instance of BaseDatafit
-        Datafit.
-
-    penalty : instance of BasePenalty
-        Penalty.
-
-    ws : array, shape (n_features,)
-        The working set.
-
-    Returns
-    -------
-    dist_fix_point : array, shape (n_features,)
-        Violation score for every feature.
-    """
-    dist_fix_point = np.zeros(ws.shape[0])
-    for idx, j in enumerate(ws):
-        lcj = datafit.lipschitz[j]
-        if lcj != 0:
-            dist_fix_point[idx] = np.abs(
-                w[j] - penalty.prox_1d(w[j] - grad[idx] / lcj, 1. / lcj, j))
-    return dist_fix_point
-
-
-@njit
-def construct_grad(X, y, w, Xw, datafit, ws):
-    """Compute the gradient of the datafit restricted to the working set.
-
-    Parameters
-    ----------
-    X : array, shape (n_samples, n_features)
-        Design matrix.
-
-    y : array, shape (n_samples,)
-        Target vector.
-
-    w : array, shape (n_features,)
-        Coefficient vector.
-
-    Xw : array, shape (n_samples,)
-        Model fit.
-
-    datafit : Datafit
-        Datafit.
-
-    ws : array, shape (n_features,)
-        The working set.
-
-    Returns
-    -------
-    grad : array, shape (ws_size,)
-        The gradient restricted to the working set.
-    """
-    grad = np.zeros(ws.shape[0])
-    for idx, j in enumerate(ws):
-        grad[idx] = datafit.gradient_scalar(X, y, w, Xw, j)
-    return grad
-
-
-@njit
-def construct_grad_sparse(data, indptr, indices, y, w, Xw, datafit, ws):
-    """Compute the gradient of the datafit restricted to the working set.
-
-    Parameters
-    ----------
-    data : array-like
-        Data array of the design matrix in CSC format.
-
-    indptr : array-like
-        CSC format index pointer array.
-
-    indices : array-like
-        CSC format index array.
-
-    y : array, shape (n_samples,)
-        Target vector.
-
-    w : array, shape (n_features,)
-        Coefficient vector.
-
-    Xw : array, shape (n_samples,)
-        Model fit.
-
-    datafit : Datafit
-        Datafit.
-
-    ws : array, shape (n_features,)
-        The working set.
-
-    Returns
-    -------
-    grad : array, shape (ws_size,)
-        The gradient restricted to the working set.
-    """
-    grad = np.zeros(ws.shape[0])
-    for idx, j in enumerate(ws):
-        grad[idx] = datafit.gradient_scalar_sparse(
-            data, indptr, indices, y, Xw, j)
-    return grad
-
-
 @njit
 def _cd_epoch(X, y, w, Xw, datafit, penalty, feats):
     """Run an epoch of coordinate descent in place.

skglm/solvers/common.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
+import numpy as np
+from numba import njit
+
+
+@njit
+def dist_fix_point(w, grad, datafit, penalty, ws):
+    """Compute the violation of the fixed point iterate scheme.
+
+    Parameters
+    ----------
+    w : array, shape (n_features,)
+        Coefficient vector.
+
+    grad : array, shape (n_features,)
+        Gradient.
+
+    datafit : instance of BaseDatafit
+        Datafit.
+
+    penalty : instance of BasePenalty
+        Penalty.
+
+    ws : array, shape (n_features,)
+        The working set.
+
+    Returns
+    -------
+    dist_fix_point : array, shape (n_features,)
+        Violation score for every feature.
+    """
+    dist_fix_point = np.zeros(ws.shape[0])
+    for idx, j in enumerate(ws):
+        lcj = datafit.lipschitz[j]
+        if lcj != 0:
+            dist_fix_point[idx] = np.abs(
+                w[j] - penalty.prox_1d(w[j] - grad[idx] / lcj, 1. / lcj, j))
+    return dist_fix_point
+
+
+@njit
+def construct_grad(X, y, w, Xw, datafit, ws):
+    """Compute the gradient of the datafit restricted to the working set.
+
+    Parameters
+    ----------
+    X : array, shape (n_samples, n_features)
+        Design matrix.
+
+    y : array, shape (n_samples,)
+        Target vector.
+
+    w : array, shape (n_features,)
+        Coefficient vector.
+
+    Xw : array, shape (n_samples,)
+        Model fit.
+
+    datafit : Datafit
+        Datafit.
+
+    ws : array, shape (n_features,)
+        The working set.
+
+    Returns
+    -------
+    grad : array, shape (ws_size,)
+        The gradient restricted to the working set.
+    """
+    grad = np.zeros(ws.shape[0])
+    for idx, j in enumerate(ws):
+        grad[idx] = datafit.gradient_scalar(X, y, w, Xw, j)
+    return grad
+
+
+@njit
+def construct_grad_sparse(data, indptr, indices, y, w, Xw, datafit, ws):
+    """Compute the gradient of the datafit restricted to the working set.
+
+    Parameters
+    ----------
+    data : array-like
+        Data array of the design matrix in CSC format.
+
+    indptr : array-like
+        CSC format index pointer array.
+
+    indices : array-like
+        CSC format index array.
+
+    y : array, shape (n_samples,)
+        Target vector.
+
+    w : array, shape (n_features,)
+        Coefficient vector.
+
+    Xw : array, shape (n_samples,)
+        Model fit.
+
+    datafit : Datafit
+        Datafit.
+
+    ws : array, shape (n_features,)
+        The working set.
+
+    Returns
+    -------
+    grad : array, shape (ws_size,)
+        The gradient restricted to the working set.
+    """
+    grad = np.zeros(ws.shape[0])
+    for idx, j in enumerate(ws):
+        grad[idx] = datafit.gradient_scalar_sparse(
+            data, indptr, indices, y, Xw, j)
+    return grad
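
The new module keeps the dense and sparse gradient helpers next to each other; both return one scalar gradient per working-set feature, the sparse variant reading column j directly from the CSC buffers (data, indptr, indices). The sketch below is an illustration only, using a quadratic-datafit gradient X[:, j] @ (Xw - y) / n_samples as an assumed stand-in for datafit.gradient_scalar and gradient_scalar_sparse, and checks that the dense and CSC computations agree.

# Illustration only (not part of the commit): the dense/CSC equivalence that
# construct_grad and construct_grad_sparse rely on, with a quadratic-datafit
# gradient as an assumed stand-in for the datafit's gradient_scalar methods.
import numpy as np
from scipy import sparse

rng = np.random.default_rng(0)
n_samples, n_features = 30, 8
# a sparse-ish dense matrix: keep roughly 30% of the entries
X = rng.standard_normal((n_samples, n_features)) * (rng.random((n_samples, n_features)) < 0.3)
y = rng.standard_normal(n_samples)
w = rng.standard_normal(n_features)
Xw = X @ w
ws = np.array([1, 4, 6])  # an arbitrary working set

# dense path: one scalar gradient per working-set feature
grad_dense = np.array([X[:, j] @ (Xw - y) / n_samples for j in ws])

# sparse path: same quantity read from the CSC buffers
X_csc = sparse.csc_matrix(X)
data, indptr, indices = X_csc.data, X_csc.indptr, X_csc.indices
grad_sparse = np.zeros(ws.shape[0])
for idx, j in enumerate(ws):
    start, end = indptr[j], indptr[j + 1]  # nonzeros of column j
    grad_sparse[idx] = data[start:end] @ (Xw - y)[indices[start:end]] / n_samples

print(np.allclose(grad_dense, grad_sparse))  # True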
