Commit fe3bedd

ENH - Add Gram Solver for single task Quadratic datafit (#59)
* gram solver && unit test
* fix bug gram solver && tighten test
* add anderson acceleration
* bug ``stop_criter`` && refactor
* refactoring of var names
* handle ``w_init``
* refactor ``_gram_cd_``
* gram epoch greedy and cyclic strategy
* extend to sparse case && unittest
* one implementation of _gram_cd && unittest
* greedy_cd arg instead of cd_strategy
* add docs
* keep grads instead
* refactor ``chosen_j``
* potential improvements, docstring

Co-authored-by: mathurinm <[email protected]>
1 parent 8900c6f commit fe3bedd
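For context, the solver introduced in this commit keeps the gradient of the quadratic datafit up to date through the Gram matrix rather than recomputing residuals at every coordinate update. Below is a rough, self-contained NumPy sketch of that idea for the Lasso; it is illustrative only, not the skglm code added in this diff, and the function name and arguments are made up for the example.

import numpy as np


def lasso_gram_cd_sketch(X, y, alpha, n_epochs=100):
    # Cyclic coordinate descent on 1/(2*n_samples)*||y - Xw||**2 + alpha*||w||_1,
    # keeping the gradient of the smooth part up to date via the Gram matrix.
    n_samples, n_features = X.shape
    Q = X.T @ X / n_samples   # (n_features, n_features) Gram matrix
    q = X.T @ y / n_samples
    w = np.zeros(n_features)
    grad = -q                 # gradient Q @ w - q at w = 0
    for _ in range(n_epochs):
        for j in range(n_features):
            old_w_j = w[j]
            step = 1 / Q[j, j]                      # coordinate-wise Lipschitz constant
            z = old_w_j - step * grad[j]
            w[j] = np.sign(z) * max(abs(z) - alpha * step, 0)   # soft-thresholding prox
            if w[j] != old_w_j:
                grad += (w[j] - old_w_j) * Q[:, j]  # O(n_features) rank-one Gram update
    return w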

File tree

* skglm/solvers/gram_cd.py
* skglm/tests/test_gram_solver.py

2 files changed: +189 −0 lines

skglm/solvers/gram_cd.py

Lines changed: 157 additions & 0 deletions
@@ -0,0 +1,157 @@
import warnings
import numpy as np
from numba import njit
from scipy.sparse import issparse

from skglm.utils import AndersonAcceleration


def gram_cd_solver(X, y, penalty, max_iter=100, w_init=None,
                   use_acc=True, greedy_cd=True, tol=1e-4, verbose=False):
    r"""Run coordinate descent while keeping the gradients up-to-date with Gram updates.

    This solver should be used when n_features < n_samples, and computes the
    (n_features, n_features) Gram matrix which comes with an overhead. It is only
    suited to Quadratic datafits.

    It minimizes::
        1 / (2*n_samples) * norm(y - Xw)**2 + penalty(w)

    which can be rewritten as::
        w.T @ Q @ w / (2*n_samples) - q.T @ w / n_samples + penalty(w)

    where::
        Q = X.T @ X (gram matrix), and q = X.T @ y

    Parameters
    ----------
    X : array or sparse CSC matrix, shape (n_samples, n_features)
        Design matrix.

    y : array, shape (n_samples,)
        Target vector.

    penalty : instance of BasePenalty
        Penalty object.

    max_iter : int, default 100
        Maximum number of iterations.

    w_init : array, shape (n_features,), default None
        Initial value of coefficients.
        If set to None, a zero vector is used instead.

    use_acc : bool, default True
        Extrapolate the iterates based on the past 5 iterates if set to True.

    greedy_cd : bool, default True
        Use a greedy strategy to select features to update in coordinate descent epochs
        if set to True. A cyclic strategy is used otherwise.

    tol : float, default 1e-4
        Tolerance for convergence.

    verbose : bool, default False
        Amount of verbosity. 0/False is silent.

    Returns
    -------
    w : array, shape (n_features,)
        Solution that minimizes the problem defined by datafit and penalty.

    objs_out : array, shape (n_iter,)
        The objective values at every outer iteration.

    stop_crit : float
        The value of the stopping criterion when the solver stops.
    """
    n_samples, n_features = X.shape

    if issparse(X):
        scaled_gram = X.T.dot(X)
        scaled_gram = scaled_gram.toarray() / n_samples
        scaled_Xty = X.T.dot(y) / n_samples
    else:
        scaled_gram = X.T @ X / n_samples
        scaled_Xty = X.T @ y / n_samples
    # TODO potential improvement: allow to pass scaled_gram (e.g. for path computation)

    scaled_y_norm2 = np.linalg.norm(y)**2 / (2*n_samples)

    all_features = np.arange(n_features)
    stop_crit = np.inf  # prevent ref before assign
    p_objs_out = []

    w = np.zeros(n_features) if w_init is None else w_init
    grad = - scaled_Xty if w_init is None else scaled_gram @ w_init - scaled_Xty
    opt = penalty.subdiff_distance(w, grad, all_features)

    if use_acc:
        if greedy_cd:
            warnings.warn(
                "Anderson acceleration does not work with greedy_cd, set use_acc=False",
                UserWarning)
        accelerator = AndersonAcceleration(K=5)
        w_acc = np.zeros(n_features)
        grad_acc = np.zeros(n_features)

    for t in range(max_iter):
        # check convergence
        stop_crit = np.max(opt)
        if verbose:
            p_obj = (0.5 * w @ (scaled_gram @ w) - scaled_Xty @ w +
                     scaled_y_norm2 + penalty.value(w))
            print(
                f"Iteration {t+1}: {p_obj:.10f}, "
                f"stopping crit: {stop_crit:.2e}"
            )

        if stop_crit <= tol:
            if verbose:
                print(f"Stopping criterion max violation: {stop_crit:.2e}")
            break

        # inplace update of w, grad
        opt = _gram_cd_epoch(scaled_gram, w, grad, penalty, greedy_cd)

        # perform Anderson extrapolation
        if use_acc:
            w_acc, grad_acc, is_extrapolated = accelerator.extrapolate(w, grad)

            if is_extrapolated:
                # omit constant term for comparison
                p_obj_acc = (0.5 * w_acc @ (scaled_gram @ w_acc) - scaled_Xty @ w_acc +
                             penalty.value(w_acc))
                p_obj = 0.5 * w @ (scaled_gram @ w) - scaled_Xty @ w + penalty.value(w)
                if p_obj_acc < p_obj:
                    w[:] = w_acc
                    grad[:] = grad_acc

        # store p_obj
        p_obj = (0.5 * w @ (scaled_gram @ w) - scaled_Xty @ w + scaled_y_norm2 +
                 penalty.value(w))
        p_objs_out.append(p_obj)
    return w, np.array(p_objs_out), stop_crit


@njit
def _gram_cd_epoch(scaled_gram, w, grad, penalty, greedy_cd):
    all_features = np.arange(len(w))
    for cd_iter in all_features:
        # select feature j
        if greedy_cd:
            opt = penalty.subdiff_distance(w, grad, all_features)
            j = np.argmax(opt)
        else:  # cyclic
            j = cd_iter

        # update w_j
        old_w_j = w[j]
        step = 1 / scaled_gram[j, j]  # 1 / lipschitz_j
        w[j] = penalty.prox_1d(old_w_j - step * grad[j], step, j)

        # gradient update with Gram matrix
        if w[j] != old_w_j:
            grad += (w[j] - old_w_j) * scaled_gram[:, j]

    return penalty.subdiff_distance(w, grad, all_features)
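In ``_gram_cd_epoch``, ``greedy_cd=True`` selects the coordinate with the largest violation of its optimality condition, as measured by ``penalty.subdiff_distance``. For the L1 penalty, that violation is, roughly, the distance of the negative gradient to the subdifferential of ``alpha * |w_j|``. The snippet below is a sketch under that assumption; skglm's ``L1.subdiff_distance`` is the reference implementation and may differ in details.

import numpy as np


def l1_violation_sketch(w, grad, alpha):
    # Per-coordinate distance of -grad_j to the subdifferential of alpha * |w_j|.
    # A value of zero means coordinate j already satisfies its optimality condition.
    dist = np.empty_like(w)
    for j in range(w.shape[0]):
        if w[j] == 0:
            # subdifferential is [-alpha, alpha]: violated only if |grad_j| exceeds alpha
            dist[j] = max(abs(grad[j]) - alpha, 0.0)
        else:
            # subdifferential reduces to {alpha * sign(w_j)}
            dist[j] = abs(grad[j] + alpha * np.sign(w[j]))
    return dist


# greedy selection: j = np.argmax(l1_violation_sketch(w, grad, alpha))
# cyclic selection: sweep j = 0, 1, ..., n_features - 1 in order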

skglm/tests/test_gram_solver.py

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
import pytest
from itertools import product

import numpy as np
from numpy.linalg import norm
from sklearn.linear_model import Lasso

from skglm.penalties import L1
from skglm.solvers.gram_cd import gram_cd_solver
from skglm.utils import make_correlated_data, compiled_clone


@pytest.mark.parametrize("rho, X_density, greedy_cd",
                         product([1e-1, 1e-3], [1., 0.8], [True, False]))
def test_vs_lasso_sklearn(rho, X_density, greedy_cd):
    X, y, _ = make_correlated_data(
        n_samples=18, n_features=8, random_state=0, X_density=X_density)
    alpha_max = norm(X.T @ y, ord=np.inf) / len(y)
    alpha = rho * alpha_max

    sk_lasso = Lasso(alpha, fit_intercept=False, tol=1e-9)
    sk_lasso.fit(X, y)

    l1_penalty = compiled_clone(L1(alpha))
    w = gram_cd_solver(X, y, l1_penalty, tol=1e-9, verbose=0,
                       max_iter=1000, greedy_cd=greedy_cd)[0]

    np.testing.assert_allclose(w, sk_lasso.coef_.flatten(), rtol=1e-7, atol=1e-7)


if __name__ == '__main__':
    pass
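The test scales ``alpha`` relative to ``alpha_max = norm(X.T @ y, ord=np.inf) / n_samples``, the smallest regularization level at which the Lasso solution is identically zero, so ``rho`` in {1e-1, 1e-3} yields nontrivial solutions to compare against scikit-learn. A quick standalone sanity check of that convention (illustrative random data, not part of the commit):

import numpy as np
from numpy.linalg import norm
from sklearn.linear_model import Lasso

rng = np.random.default_rng(0)
X = rng.standard_normal((18, 8))
y = rng.standard_normal(18)
alpha_max = norm(X.T @ y, ord=np.inf) / len(y)

# at alpha >= alpha_max the Lasso solution is exactly zero ...
assert not np.any(Lasso(alpha=alpha_max, fit_intercept=False).fit(X, y).coef_)
# ... while just below alpha_max at least one coefficient becomes active
assert np.any(Lasso(alpha=0.99 * alpha_max, fit_intercept=False).fit(X, y).coef_)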
