Skip to content

Commit 999b89e

Browse files
committed
add gram_solver
1 parent 21fcbfe commit 999b89e

File tree

1 file changed

+170
-0
lines changed

1 file changed

+170
-0
lines changed

skglm/gram_solver.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
from time import time
2+
import numpy as np
3+
from numpy.linalg import norm
4+
from numba import njit
5+
from celer import Lasso, GroupLasso
6+
from benchopt.datasets.simulated import make_correlated_data
7+
from skglm.utils import BST, ST
8+
9+
10+
def _grp_converter(groups, n_features):
11+
if isinstance(groups, int):
12+
grp_size = groups
13+
if n_features % grp_size != 0:
14+
raise ValueError("n_features (%d) is not a multiple of the desired"
15+
" group size (%d)" % (n_features, grp_size))
16+
n_groups = n_features // grp_size
17+
grp_ptr = grp_size * np.arange(n_groups + 1)
18+
grp_indices = np.arange(n_features)
19+
elif isinstance(groups, list) and isinstance(groups[0], int):
20+
grp_indices = np.arange(n_features).astype(np.int32)
21+
grp_ptr = np.cumsum(np.hstack([[0], groups]))
22+
elif isinstance(groups, list) and isinstance(groups[0], list):
23+
grp_sizes = np.array([len(ls) for ls in groups])
24+
grp_ptr = np.cumsum(np.hstack([[0], grp_sizes]))
25+
grp_indices = np.array([idx for grp in groups for idx in grp])
26+
else:
27+
raise ValueError("Unsupported group format.")
28+
return grp_ptr.astype(np.int32), grp_indices.astype(np.int32)
29+
30+
31+
@njit
32+
def primal(alpha, y, X, w):
33+
r = y - X @ w
34+
p_obj = (r @ r) / (2 * len(y))
35+
return p_obj + alpha * np.sum(np.abs(w))
36+
37+
38+
@njit
39+
def primal_grp(alpha, y, X, w, grp_ptr, grp_indices):
40+
r = y - X @ w
41+
p_obj = (r @ r) / (2 * len(y))
42+
for g in range(len(grp_ptr) - 1):
43+
w_g = w[grp_indices[grp_ptr[g]:grp_ptr[g + 1]]]
44+
p_obj += alpha * norm(w_g, ord=2)
45+
return p_obj
46+
47+
48+
@njit
49+
def cd_epoch(X, G, grads, w, alpha, lipschitz):
50+
n_features = X.shape[1]
51+
for j in range(n_features):
52+
if lipschitz[j] == 0.:
53+
continue
54+
old_w_j = w[j]
55+
w[j] = ST(w[j] + grads[j] / lipschitz[j], alpha / lipschitz[j])
56+
if old_w_j != w[j]:
57+
grads += G[j, :] * (old_w_j - w[j]) / len(X)
58+
59+
60+
@njit
61+
def bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr):
62+
n_groups = len(grp_ptr) - 1
63+
for g in range(n_groups):
64+
if lipschitz[g] == 0.:
65+
continue
66+
idx = grp_indices[grp_ptr[g]:grp_ptr[g + 1]]
67+
old_w_g = w[idx].copy()
68+
w[idx] = BST(w[idx] + grads[idx] / lipschitz[g], alpha / lipschitz[g])
69+
diff = old_w_g - w[idx]
70+
if np.any(diff != 0.):
71+
grads += diff @ G[idx, :] / len(X)
72+
73+
74+
def lasso(X, y, alpha, max_iter, tol, check_freq=10):
75+
p_obj_prev = np.inf
76+
n_features = X.shape[1]
77+
# Initialization
78+
grads = X.T @ y / len(y)
79+
G = X.T @ X
80+
lipschitz = np.zeros(n_features, dtype=X.dtype)
81+
for j in range(n_features):
82+
lipschitz[j] = (X[:, j] ** 2).sum() / len(y)
83+
w = np.zeros(n_features)
84+
# CD
85+
for n_iter in range(max_iter):
86+
cd_epoch(X, G, grads, w, alpha, lipschitz)
87+
if n_iter % check_freq == 0:
88+
p_obj = primal(alpha, y, X, w)
89+
if p_obj_prev - p_obj < tol:
90+
print("Convergence reached!")
91+
break
92+
print(f"iter {n_iter} :: p_obj {p_obj}")
93+
p_obj_prev = p_obj
94+
return w
95+
96+
97+
def group_lasso(X, y, alpha, groups, max_iter, tol, check_freq=50):
98+
p_obj_prev = np.inf
99+
n_features = X.shape[1]
100+
grp_ptr, grp_indices = _grp_converter(groups, X.shape[1])
101+
n_groups = len(grp_ptr) - 1
102+
# Initialization
103+
grads = X.T @ y / len(y)
104+
G = X.T @ X
105+
lipschitz = np.zeros(n_groups, dtype=X.dtype)
106+
for g in range(n_groups):
107+
X_g = X[:, grp_indices[grp_ptr[g]:grp_ptr[g + 1]]]
108+
lipschitz[g] = norm(X_g, ord=2) ** 2 / len(y)
109+
w = np.zeros(n_features)
110+
# BCD
111+
for n_iter in range(max_iter):
112+
bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr)
113+
if n_iter % check_freq == 0:
114+
p_obj = primal_grp(alpha, y, X, w, grp_ptr, grp_indices)
115+
if p_obj_prev - p_obj < tol:
116+
print("Convergence reached!")
117+
break
118+
print(f"iter {n_iter} :: p_obj {p_obj}")
119+
p_obj_prev = p_obj
120+
return w
121+
122+
123+
n_samples, n_features = 1_000_000, 300
124+
X, y, w_star = make_correlated_data(
125+
n_samples=n_samples, n_features=n_features, random_state=0)
126+
alpha_max = norm(X.T @ y, ord=np.inf)
127+
128+
# Hyperparameters
129+
max_iter = 1000
130+
tol = 1e-8
131+
reg = 0.1
132+
group_size = 3
133+
134+
alpha = alpha_max * reg / n_samples
135+
136+
# Lasso
137+
print("#" * 15)
138+
print("Lasso")
139+
print("#" * 15)
140+
start = time()
141+
w = lasso(X, y, alpha, max_iter, tol)
142+
gram_lasso_time = time() - start
143+
clf_sk = Lasso(alpha, tol=tol, fit_intercept=False)
144+
start = time()
145+
clf_sk.fit(X, y)
146+
celer_lasso_time = time() - start
147+
np.testing.assert_allclose(w, clf_sk.coef_, rtol=1e-5)
148+
149+
print("\n")
150+
print("Celer: %.2f" % celer_lasso_time)
151+
print("Gram: %.2f" % gram_lasso_time)
152+
print("\n")
153+
154+
# Group Lasso
155+
print("#" * 15)
156+
print("Group Lasso")
157+
print("#" * 15)
158+
start = time()
159+
w = group_lasso(X, y, alpha, group_size, max_iter, tol)
160+
gram_group_lasso_time = time() - start
161+
clf_celer = GroupLasso(group_size, alpha, tol=tol, fit_intercept=False)
162+
start = time()
163+
clf_celer.fit(X, y)
164+
celer_group_lasso_time = time() - start
165+
np.testing.assert_allclose(w, clf_celer.coef_, rtol=1e-1)
166+
167+
print("\n")
168+
print("Celer: %.2f" % celer_group_lasso_time)
169+
print("Gram: %.2f" % gram_group_lasso_time)
170+
print("\n")

0 commit comments

Comments
 (0)