
Commit 787c8c2

isolate gram solver in solvers submodule
1 parent af666a5

3 files changed: +152 −185 lines
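
In practice the refactor boils down to a change of import path and of public names: the solver functions now live in the skglm.solvers submodule and carry a gram_ prefix. A minimal before/after sketch of what downstream code has to update, assuming no API change beyond what this diff shows:

# before this commit
from skglm.gram_solver import group_lasso
# after this commit
from skglm.solvers.gram import gram_lasso, gram_group_lasso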

gram_test.py

Lines changed: 5 additions & 19 deletions
@@ -6,7 +6,7 @@
 import matplotlib.pyplot as plt
 import numpy as np
 from celer import GroupLasso
-from skglm.gram_solver import group_lasso
+from skglm.solvers.gram import gram_group_lasso
 
 X = np.load("design_matrix.npy")
 y = np.load("target.npy")
@@ -22,7 +22,7 @@
 # Case 1: slower runtime for (very) small alphas
 # alpha_max = 0.003471727067743962
 alpha_max = np.max(np.linalg.norm((X.T @ y).reshape(-1, 5), axis=1)) / len(y)
-alpha = alpha_max / 50
+alpha = alpha_max / 100
 clf = GroupLasso(fit_intercept=False,
                  groups=5, alpha=alpha, verbose=1)
 
@@ -32,26 +32,12 @@
 
 print(f"Celer: {t1 - t0:.3f} s")
 
+# beware: stopping criterion is not the same, tol here needs to be lower
+# to get a meaningful comparison
 t0 = time.time()
 res = group_lasso(X, y, alpha, groups=5, tol=1e-10, max_iter=10_000, check_freq=10)
 t1 = time.time()
 
 print(f"skglm gram: {t1 - t0:.3f} s")
 
-# # Case 2: slower runtime for (very) small alphas with weights
-# # alpha_max_w = 0.0001897719130007628
-# alpha_max_w = np.max(norm((X.T @ y).reshape(-1, 5) /
-#                           weights[:, None], axis=1)) / len(y)
-
-
-# alpha_ratio = 0.1
-# grid_w = np.geomspace(alpha_max_w*alpha_ratio, alpha_max_w, n_alphas)[::-1]
-# clf = GroupLasso(fit_intercept=False,
-#                  weights=weights, groups=grps, warm_start=True)
-
-# # for alpha in grid_w:
-# #     clf.alpha = alpha
-# #     t0 = time.time()
-# #     clf.fit(X, y)
-#     t1 = time.time()
-#     print(f"Finished tuning with {alpha:.2e}. Took {t1-t0:.2f} seconds!")
+# TODO support weights in gram solver
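
The "beware" comment added above is about tolerance semantics: as skglm/solvers/gram.py below shows, the gram solver declares convergence once the primal objective decreases by less than tol between two checks (one check every check_freq epochs), while celer's tol applies to a different criterion of its own, so identical tol values do not mean identical accuracy. A small self-contained sketch of the gram-side test, with toy data and a hypothetical pair of iterates taken one check apart:

import numpy as np
from numpy.linalg import norm

rng = np.random.default_rng(0)
X = rng.standard_normal((40, 10))
y = rng.standard_normal(40)
alpha, tol, grp_size = 0.05, 1e-10, 5

def p_obj(w):
    # group Lasso primal: mean squared residual / 2 plus group-wise L2 penalty,
    # mirroring primal_grp in skglm/solvers/gram.py below
    r = y - X @ w
    return (r @ r) / (2 * len(y)) + alpha * norm(w.reshape(-1, grp_size), axis=1).sum()

# hypothetical iterates at two consecutive checks (check_freq epochs apart)
w_prev, w = rng.standard_normal(10), rng.standard_normal(10)
converged = p_obj(w_prev) - p_obj(w) < tol  # the gram solver's stopping test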

skglm/gram_solver.py

Lines changed: 51 additions & 166 deletions
@@ -1,171 +1,56 @@
 from time import time
 import numpy as np
 from numpy.linalg import norm
-from numba import njit
 from celer import Lasso, GroupLasso
 from benchopt.datasets.simulated import make_correlated_data
-from skglm.utils import BST, ST
-
-
-def _grp_converter(groups, n_features):
-    if isinstance(groups, int):
-        grp_size = groups
-        if n_features % grp_size != 0:
-            raise ValueError("n_features (%d) is not a multiple of the desired"
-                             " group size (%d)" % (n_features, grp_size))
-        n_groups = n_features // grp_size
-        grp_ptr = grp_size * np.arange(n_groups + 1)
-        grp_indices = np.arange(n_features)
-    elif isinstance(groups, list) and isinstance(groups[0], int):
-        grp_indices = np.arange(n_features).astype(np.int32)
-        grp_ptr = np.cumsum(np.hstack([[0], groups]))
-    elif isinstance(groups, list) and isinstance(groups[0], list):
-        grp_sizes = np.array([len(ls) for ls in groups])
-        grp_ptr = np.cumsum(np.hstack([[0], grp_sizes]))
-        grp_indices = np.array([idx for grp in groups for idx in grp])
-    else:
-        raise ValueError("Unsupported group format.")
-    return grp_ptr.astype(np.int32), grp_indices.astype(np.int32)
-
-
-@njit
-def primal(alpha, y, X, w):
-    r = y - X @ w
-    p_obj = (r @ r) / (2 * len(y))
-    return p_obj + alpha * np.sum(np.abs(w))
-
-
-@njit
-def primal_grp(alpha, y, X, w, grp_ptr, grp_indices):
-    r = y - X @ w
-    p_obj = (r @ r) / (2 * len(y))
-    for g in range(len(grp_ptr) - 1):
-        w_g = w[grp_indices[grp_ptr[g]:grp_ptr[g + 1]]]
-        p_obj += alpha * norm(w_g, ord=2)
-    return p_obj
-
-
-@njit
-def cd_epoch(X, G, grads, w, alpha, lipschitz):
-    n_features = X.shape[1]
-    for j in range(n_features):
-        if lipschitz[j] == 0.:
-            continue
-        old_w_j = w[j]
-        w[j] = ST(w[j] + grads[j] / lipschitz[j], alpha / lipschitz[j])
-        if old_w_j != w[j]:
-            grads += G[j, :] * (old_w_j - w[j]) / len(X)
-
-
-@njit
-def bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr):
-    n_groups = len(grp_ptr) - 1
-    for g in range(n_groups):
-        if lipschitz[g] == 0.:
-            continue
-        idx = grp_indices[grp_ptr[g]:grp_ptr[g + 1]]
-        old_w_g = w[idx].copy()
-        w[idx] = BST(w[idx] + grads[idx] / lipschitz[g], alpha / lipschitz[g])
-        diff = old_w_g - w[idx]
-        if np.any(diff != 0.):
-            grads += diff @ G[idx, :] / len(X)
-
-
-def lasso(X, y, alpha, max_iter, tol, check_freq=10):
-    p_obj_prev = np.inf
-    n_features = X.shape[1]
-    # Initialization
-    grads = X.T @ y / len(y)
-    G = X.T @ X
-    lipschitz = np.zeros(n_features, dtype=X.dtype)
-    for j in range(n_features):
-        lipschitz[j] = (X[:, j] ** 2).sum() / len(y)
-    w = np.zeros(n_features)
-    # CD
-    for n_iter in range(max_iter):
-        cd_epoch(X, G, grads, w, alpha, lipschitz)
-        if n_iter % check_freq == 0:
-            p_obj = primal(alpha, y, X, w)
-            if p_obj_prev - p_obj < tol:
-                print("Convergence reached!")
-                break
-            print(f"iter {n_iter} :: p_obj {p_obj}")
-            p_obj_prev = p_obj
-    return w
-
-
-def group_lasso(X, y, alpha, groups, max_iter, tol, check_freq=50):
-    p_obj_prev = np.inf
-    n_features = X.shape[1]
-    grp_ptr, grp_indices = _grp_converter(groups, X.shape[1])
-    n_groups = len(grp_ptr) - 1
-    # Initialization
-    grads = X.T @ y / len(y)
-    G = X.T @ X
-    lipschitz = np.zeros(n_groups, dtype=X.dtype)
-    for g in range(n_groups):
-        X_g = X[:, grp_indices[grp_ptr[g]:grp_ptr[g + 1]]]
-        lipschitz[g] = norm(X_g, ord=2) ** 2 / len(y)
-    w = np.zeros(n_features)
-    # BCD
-    for n_iter in range(max_iter):
-        bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr)
-        if n_iter % check_freq == 0:
-            p_obj = primal_grp(alpha, y, X, w, grp_ptr, grp_indices)
-            if p_obj_prev - p_obj < tol:
-                print("Convergence reached!")
-                break
-            print(f"iter {n_iter} :: p_obj {p_obj}")
-            p_obj_prev = p_obj
-    return w
-
-
-if __name__ == "__main__":
-    n_samples, n_features = 1_000_000, 300
-    X, y, w_star = make_correlated_data(
-        n_samples=n_samples, n_features=n_features, random_state=0)
-    alpha_max = norm(X.T @ y, ord=np.inf)
-
-    # Hyperparameters
-    max_iter = 1000
-    tol = 1e-8
-    reg = 0.1
-    group_size = 3
-
-    alpha = alpha_max * reg / n_samples
-
-    # Lasso
-    print("#" * 15)
-    print("Lasso")
-    print("#" * 15)
-    start = time()
-    w = lasso(X, y, alpha, max_iter, tol)
-    gram_lasso_time = time() - start
-    clf_sk = Lasso(alpha, tol=tol, fit_intercept=False)
-    start = time()
-    clf_sk.fit(X, y)
-    celer_lasso_time = time() - start
-    np.testing.assert_allclose(w, clf_sk.coef_, rtol=1e-5)
-
-    print("\n")
-    print("Celer: %.2f" % celer_lasso_time)
-    print("Gram: %.2f" % gram_lasso_time)
-    print("\n")
-
-    # Group Lasso
-    print("#" * 15)
-    print("Group Lasso")
-    print("#" * 15)
-    start = time()
-    w = group_lasso(X, y, alpha, group_size, max_iter, tol)
-    gram_group_lasso_time = time() - start
-    clf_celer = GroupLasso(group_size, alpha, tol=tol, fit_intercept=False)
-    start = time()
-    clf_celer.fit(X, y)
-    celer_group_lasso_time = time() - start
-    np.testing.assert_allclose(w, clf_celer.coef_, rtol=1e-1)
-
-    print("\n")
-    print("Celer: %.2f" % celer_group_lasso_time)
-    print("Gram: %.2f" % gram_group_lasso_time)
-    print("\n")
+from skglm.solvers.gram import gram_lasso, gram_group_lasso
+
+
+n_samples, n_features = 1_000_000, 300
+X, y, w_star = make_correlated_data(
+    n_samples=n_samples, n_features=n_features, random_state=0)
+alpha_max = norm(X.T @ y, ord=np.inf)
+
+# Hyperparameters
+max_iter = 1000
+tol = 1e-8
+reg = 0.1
+group_size = 3
+
+alpha = alpha_max * reg / n_samples
+
+# Lasso
+print("#" * 15)
+print("Lasso")
+print("#" * 15)
+start = time()
+w = gram_lasso(X, y, alpha, max_iter, tol)
+gram_lasso_time = time() - start
+clf_sk = Lasso(alpha, tol=tol, fit_intercept=False)
+start = time()
+clf_sk.fit(X, y)
+celer_lasso_time = time() - start
+np.testing.assert_allclose(w, clf_sk.coef_, rtol=1e-5)
+
+print("\n")
+print("Celer: %.2f" % celer_lasso_time)
+print("Gram: %.2f" % gram_lasso_time)
+print("\n")
+
+# Group Lasso
+print("#" * 15)
+print("Group Lasso")
+print("#" * 15)
+start = time()
+w = gram_group_lasso(X, y, alpha, group_size, max_iter, tol)
+gram_group_lasso_time = time() - start
+clf_celer = GroupLasso(group_size, alpha, tol=tol, fit_intercept=False)
+start = time()
+clf_celer.fit(X, y)
+celer_group_lasso_time = time() - start
+np.testing.assert_allclose(w, clf_celer.coef_, rtol=1e-1)
+
+print("\n")
+print("Celer: %.2f" % celer_group_lasso_time)
+print("Gram: %.2f" % gram_group_lasso_time)
+print("\n")

skglm/solvers/gram.py

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
+import numpy as np
+from numba import njit
+from numpy.linalg import norm
+from celer.homotopy import _grp_converter
+
+from skglm.utils import BST, ST
+
+
+@njit
+def primal(alpha, y, X, w):
+    r = y - X @ w
+    p_obj = (r @ r) / (2 * len(y))
+    return p_obj + alpha * np.sum(np.abs(w))
+
+
+@njit
+def primal_grp(alpha, y, X, w, grp_ptr, grp_indices):
+    r = y - X @ w
+    p_obj = (r @ r) / (2 * len(y))
+    for g in range(len(grp_ptr) - 1):
+        w_g = w[grp_indices[grp_ptr[g]:grp_ptr[g + 1]]]
+        p_obj += alpha * norm(w_g, ord=2)
+    return p_obj
+
+
+def gram_lasso(X, y, alpha, max_iter, tol, check_freq=10):
+    p_obj_prev = np.inf
+    n_features = X.shape[1]
+    grads = X.T @ y / len(y)
+    G = X.T @ X
+    lipschitz = np.zeros(n_features, dtype=X.dtype)
+    for j in range(n_features):
+        lipschitz[j] = (X[:, j] ** 2).sum() / len(y)
+    w = np.zeros(n_features)
+    # CD
+    for n_iter in range(max_iter):
+        cd_epoch(X, G, grads, w, alpha, lipschitz)
+        if n_iter % check_freq == 0:
+            p_obj = primal(alpha, y, X, w)
+            if p_obj_prev - p_obj < tol:
+                print("Convergence reached!")
+                break
+            print(f"iter {n_iter} :: p_obj {p_obj}")
+            p_obj_prev = p_obj
+    return w
+
+
+def gram_group_lasso(X, y, alpha, groups, max_iter, tol, check_freq=50):
+    p_obj_prev = np.inf
+    n_features = X.shape[1]
+    grp_ptr, grp_indices = _grp_converter(groups, X.shape[1])
+    n_groups = len(grp_ptr) - 1
+    grads = X.T @ y / len(y)
+    G = X.T @ X
+    lipschitz = np.zeros(n_groups, dtype=X.dtype)
+    for g in range(n_groups):
+        X_g = X[:, grp_indices[grp_ptr[g]:grp_ptr[g + 1]]]
+        lipschitz[g] = norm(X_g, ord=2) ** 2 / len(y)
+    w = np.zeros(n_features)
+    # BCD
+    for n_iter in range(max_iter):
+        bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr)
+        if n_iter % check_freq == 0:
+            p_obj = primal_grp(alpha, y, X, w, grp_ptr, grp_indices)
+            if p_obj_prev - p_obj < tol:
+                print("Convergence reached!")
+                break
+            print(f"iter {n_iter} :: p_obj {p_obj}")
+            p_obj_prev = p_obj
+    return w
+
+
+@njit
+def cd_epoch(X, G, grads, w, alpha, lipschitz):
+    n_features = X.shape[1]
+    for j in range(n_features):
+        if lipschitz[j] == 0.:
+            continue
+        old_w_j = w[j]
+        w[j] = ST(w[j] + grads[j] / lipschitz[j], alpha / lipschitz[j])
+        if old_w_j != w[j]:
+            grads += G[j, :] * (old_w_j - w[j]) / len(X)
+
+
+@njit
+def bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr):
+    n_groups = len(grp_ptr) - 1
+    for g in range(n_groups):
+        if lipschitz[g] == 0.:
+            continue
+        idx = grp_indices[grp_ptr[g]:grp_ptr[g + 1]]
+        old_w_g = w[idx].copy()
+        w[idx] = BST(w[idx] + grads[idx] / lipschitz[g], alpha / lipschitz[g])
+        diff = old_w_g - w[idx]
+        if np.any(diff != 0.):
+            grads += diff @ G[idx, :] / len(X)
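
As a usage sketch (not part of the commit): both solvers take a dense design matrix, a regularization level and the iteration/tolerance settings, and groups may be passed in the formats the removed in-repo _grp_converter handled, assuming the celer.homotopy version imported above accepts the same ones (an int group size, a list of group sizes, or a list of index lists):

import numpy as np
from skglm.solvers.gram import gram_lasso, gram_group_lasso

rng = np.random.default_rng(0)
X = rng.standard_normal((200, 30))
y = X @ rng.standard_normal(30) + 0.1 * rng.standard_normal(200)
alpha = 0.1 * np.max(np.abs(X.T @ y)) / len(y)

# plain Lasso, coordinate descent on the precomputed Gram matrix X.T @ X
w = gram_lasso(X, y, alpha, max_iter=1000, tol=1e-8)

# group Lasso: 10 contiguous groups of size 3, given as an int group size...
w_grp = gram_group_lasso(X, y, alpha, groups=3, max_iter=1000, tol=1e-8)
# ...or as explicit group sizes / explicit index lists
w_grp = gram_group_lasso(X, y, alpha, groups=[3] * 10, max_iter=1000, tol=1e-8)
w_grp = gram_group_lasso(X, y, alpha,
                         groups=[[3 * g + k for k in range(3)] for g in range(10)],
                         max_iter=1000, tol=1e-8)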
