|
1 | 1 | from time import time
|
2 | 2 | import numpy as np
|
3 | 3 | from numpy.linalg import norm
|
4 |
| -from numba import njit |
5 | 4 | from celer import Lasso, GroupLasso
|
6 | 5 | from benchopt.datasets.simulated import make_correlated_data
|
7 |
| -from skglm.utils import BST, ST |
8 |
| - |
9 |
| - |
10 |
| -def _grp_converter(groups, n_features): |
11 |
| - if isinstance(groups, int): |
12 |
| - grp_size = groups |
13 |
| - if n_features % grp_size != 0: |
14 |
| - raise ValueError("n_features (%d) is not a multiple of the desired" |
15 |
| - " group size (%d)" % (n_features, grp_size)) |
16 |
| - n_groups = n_features // grp_size |
17 |
| - grp_ptr = grp_size * np.arange(n_groups + 1) |
18 |
| - grp_indices = np.arange(n_features) |
19 |
| - elif isinstance(groups, list) and isinstance(groups[0], int): |
20 |
| - grp_indices = np.arange(n_features).astype(np.int32) |
21 |
| - grp_ptr = np.cumsum(np.hstack([[0], groups])) |
22 |
| - elif isinstance(groups, list) and isinstance(groups[0], list): |
23 |
| - grp_sizes = np.array([len(ls) for ls in groups]) |
24 |
| - grp_ptr = np.cumsum(np.hstack([[0], grp_sizes])) |
25 |
| - grp_indices = np.array([idx for grp in groups for idx in grp]) |
26 |
| - else: |
27 |
| - raise ValueError("Unsupported group format.") |
28 |
| - return grp_ptr.astype(np.int32), grp_indices.astype(np.int32) |
29 |
| - |
30 |
| - |
31 |
| -@njit |
32 |
| -def primal(alpha, y, X, w): |
33 |
| - r = y - X @ w |
34 |
| - p_obj = (r @ r) / (2 * len(y)) |
35 |
| - return p_obj + alpha * np.sum(np.abs(w)) |
36 |
| - |
37 |
| - |
38 |
| -@njit |
39 |
| -def primal_grp(alpha, y, X, w, grp_ptr, grp_indices): |
40 |
| - r = y - X @ w |
41 |
| - p_obj = (r @ r) / (2 * len(y)) |
42 |
| - for g in range(len(grp_ptr) - 1): |
43 |
| - w_g = w[grp_indices[grp_ptr[g]:grp_ptr[g + 1]]] |
44 |
| - p_obj += alpha * norm(w_g, ord=2) |
45 |
| - return p_obj |
46 |
| - |
47 |
| - |
48 |
| -@njit |
49 |
| -def cd_epoch(X, G, grads, w, alpha, lipschitz): |
50 |
| - n_features = X.shape[1] |
51 |
| - for j in range(n_features): |
52 |
| - if lipschitz[j] == 0.: |
53 |
| - continue |
54 |
| - old_w_j = w[j] |
55 |
| - w[j] = ST(w[j] + grads[j] / lipschitz[j], alpha / lipschitz[j]) |
56 |
| - if old_w_j != w[j]: |
57 |
| - grads += G[j, :] * (old_w_j - w[j]) / len(X) |
58 |
| - |
59 |
| - |
60 |
| -@njit |
61 |
| -def bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr): |
62 |
| - n_groups = len(grp_ptr) - 1 |
63 |
| - for g in range(n_groups): |
64 |
| - if lipschitz[g] == 0.: |
65 |
| - continue |
66 |
| - idx = grp_indices[grp_ptr[g]:grp_ptr[g + 1]] |
67 |
| - old_w_g = w[idx].copy() |
68 |
| - w[idx] = BST(w[idx] + grads[idx] / lipschitz[g], alpha / lipschitz[g]) |
69 |
| - diff = old_w_g - w[idx] |
70 |
| - if np.any(diff != 0.): |
71 |
| - grads += diff @ G[idx, :] / len(X) |
72 |
| - |
73 |
| - |
74 |
| -def lasso(X, y, alpha, max_iter, tol, check_freq=10): |
75 |
| - p_obj_prev = np.inf |
76 |
| - n_features = X.shape[1] |
77 |
| - # Initialization |
78 |
| - grads = X.T @ y / len(y) |
79 |
| - G = X.T @ X |
80 |
| - lipschitz = np.zeros(n_features, dtype=X.dtype) |
81 |
| - for j in range(n_features): |
82 |
| - lipschitz[j] = (X[:, j] ** 2).sum() / len(y) |
83 |
| - w = np.zeros(n_features) |
84 |
| - # CD |
85 |
| - for n_iter in range(max_iter): |
86 |
| - cd_epoch(X, G, grads, w, alpha, lipschitz) |
87 |
| - if n_iter % check_freq == 0: |
88 |
| - p_obj = primal(alpha, y, X, w) |
89 |
| - if p_obj_prev - p_obj < tol: |
90 |
| - print("Convergence reached!") |
91 |
| - break |
92 |
| - print(f"iter {n_iter} :: p_obj {p_obj}") |
93 |
| - p_obj_prev = p_obj |
94 |
| - return w |
95 |
| - |
96 |
| - |
97 |
| -def group_lasso(X, y, alpha, groups, max_iter, tol, check_freq=50): |
98 |
| - p_obj_prev = np.inf |
99 |
| - n_features = X.shape[1] |
100 |
| - grp_ptr, grp_indices = _grp_converter(groups, X.shape[1]) |
101 |
| - n_groups = len(grp_ptr) - 1 |
102 |
| - # Initialization |
103 |
| - grads = X.T @ y / len(y) |
104 |
| - G = X.T @ X |
105 |
| - lipschitz = np.zeros(n_groups, dtype=X.dtype) |
106 |
| - for g in range(n_groups): |
107 |
| - X_g = X[:, grp_indices[grp_ptr[g]:grp_ptr[g + 1]]] |
108 |
| - lipschitz[g] = norm(X_g, ord=2) ** 2 / len(y) |
109 |
| - w = np.zeros(n_features) |
110 |
| - # BCD |
111 |
| - for n_iter in range(max_iter): |
112 |
| - bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr) |
113 |
| - if n_iter % check_freq == 0: |
114 |
| - p_obj = primal_grp(alpha, y, X, w, grp_ptr, grp_indices) |
115 |
| - if p_obj_prev - p_obj < tol: |
116 |
| - print("Convergence reached!") |
117 |
| - break |
118 |
| - print(f"iter {n_iter} :: p_obj {p_obj}") |
119 |
| - p_obj_prev = p_obj |
120 |
| - return w |
121 |
| - |
122 |
| - |
123 |
| -if __name__ == "__main__": |
124 |
| - n_samples, n_features = 1_000_000, 300 |
125 |
| - X, y, w_star = make_correlated_data( |
126 |
| - n_samples=n_samples, n_features=n_features, random_state=0) |
127 |
| - alpha_max = norm(X.T @ y, ord=np.inf) |
128 |
| - |
129 |
| - # Hyperparameters |
130 |
| - max_iter = 1000 |
131 |
| - tol = 1e-8 |
132 |
| - reg = 0.1 |
133 |
| - group_size = 3 |
134 |
| - |
135 |
| - alpha = alpha_max * reg / n_samples |
136 |
| - |
137 |
| - # Lasso |
138 |
| - print("#" * 15) |
139 |
| - print("Lasso") |
140 |
| - print("#" * 15) |
141 |
| - start = time() |
142 |
| - w = lasso(X, y, alpha, max_iter, tol) |
143 |
| - gram_lasso_time = time() - start |
144 |
| - clf_sk = Lasso(alpha, tol=tol, fit_intercept=False) |
145 |
| - start = time() |
146 |
| - clf_sk.fit(X, y) |
147 |
| - celer_lasso_time = time() - start |
148 |
| - np.testing.assert_allclose(w, clf_sk.coef_, rtol=1e-5) |
149 |
| - |
150 |
| - print("\n") |
151 |
| - print("Celer: %.2f" % celer_lasso_time) |
152 |
| - print("Gram: %.2f" % gram_lasso_time) |
153 |
| - print("\n") |
154 |
| - |
155 |
| - # Group Lasso |
156 |
| - print("#" * 15) |
157 |
| - print("Group Lasso") |
158 |
| - print("#" * 15) |
159 |
| - start = time() |
160 |
| - w = group_lasso(X, y, alpha, group_size, max_iter, tol) |
161 |
| - gram_group_lasso_time = time() - start |
162 |
| - clf_celer = GroupLasso(group_size, alpha, tol=tol, fit_intercept=False) |
163 |
| - start = time() |
164 |
| - clf_celer.fit(X, y) |
165 |
| - celer_group_lasso_time = time() - start |
166 |
| - np.testing.assert_allclose(w, clf_celer.coef_, rtol=1e-1) |
167 |
| - |
168 |
| - print("\n") |
169 |
| - print("Celer: %.2f" % celer_group_lasso_time) |
170 |
| - print("Gram: %.2f" % gram_group_lasso_time) |
171 |
| - print("\n") |
| 6 | +from skglm.solvers.gram import gram_lasso, gram_group_lasso |
| 7 | + |
| 8 | + |
| 9 | +n_samples, n_features = 1_000_000, 300 |
| 10 | +X, y, w_star = make_correlated_data( |
| 11 | + n_samples=n_samples, n_features=n_features, random_state=0) |
| 12 | +alpha_max = norm(X.T @ y, ord=np.inf) |
| 13 | + |
| 14 | +# Hyperparameters |
| 15 | +max_iter = 1000 |
| 16 | +tol = 1e-8 |
| 17 | +reg = 0.1 |
| 18 | +group_size = 3 |
| 19 | + |
| 20 | +alpha = alpha_max * reg / n_samples |
| 21 | + |
| 22 | +# Lasso |
| 23 | +print("#" * 15) |
| 24 | +print("Lasso") |
| 25 | +print("#" * 15) |
| 26 | +start = time() |
| 27 | +w = gram_lasso(X, y, alpha, max_iter, tol) |
| 28 | +gram_lasso_time = time() - start |
| 29 | +clf_sk = Lasso(alpha, tol=tol, fit_intercept=False) |
| 30 | +start = time() |
| 31 | +clf_sk.fit(X, y) |
| 32 | +celer_lasso_time = time() - start |
| 33 | +np.testing.assert_allclose(w, clf_sk.coef_, rtol=1e-5) |
| 34 | + |
| 35 | +print("\n") |
| 36 | +print("Celer: %.2f" % celer_lasso_time) |
| 37 | +print("Gram: %.2f" % gram_lasso_time) |
| 38 | +print("\n") |
| 39 | + |
| 40 | +# Group Lasso |
| 41 | +print("#" * 15) |
| 42 | +print("Group Lasso") |
| 43 | +print("#" * 15) |
| 44 | +start = time() |
| 45 | +w = gram_group_lasso(X, y, alpha, group_size, max_iter, tol) |
| 46 | +gram_group_lasso_time = time() - start |
| 47 | +clf_celer = GroupLasso(group_size, alpha, tol=tol, fit_intercept=False) |
| 48 | +start = time() |
| 49 | +clf_celer.fit(X, y) |
| 50 | +celer_group_lasso_time = time() - start |
| 51 | +np.testing.assert_allclose(w, clf_celer.coef_, rtol=1e-1) |
| 52 | + |
| 53 | +print("\n") |
| 54 | +print("Celer: %.2f" % celer_group_lasso_time) |
| 55 | +print("Gram: %.2f" % gram_group_lasso_time) |
| 56 | +print("\n") |
0 commit comments