Skip to content

Commit af666a5

Browse files
committed
test with large data
1 parent 999b89e commit af666a5

File tree

2 files changed

+106
-48
lines changed

2 files changed

+106
-48
lines changed

gram_test.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# data available at https://www.dropbox.com/sh/32b3mr3xghi496g/AACNRS_NOsUXU-hrSLixNg0ja?dl=0
2+
3+
4+
import time
5+
from numpy.linalg import norm
6+
import matplotlib.pyplot as plt
7+
import numpy as np
8+
from celer import GroupLasso
9+
from skglm.gram_solver import group_lasso
10+
11+
X = np.load("design_matrix.npy")
12+
y = np.load("target.npy")
13+
groups = np.load("groups.npy")
14+
weights = np.load("weights.npy")
15+
# grps = [list(np.where(groups == i)[0]) for i in range(1, 33)]
16+
17+
18+
alpha_ratio = 1e-2
19+
n_alphas = 10
20+
21+
22+
# Case 1: slower runtime for (very) small alphas
23+
# alpha_max = 0.003471727067743962
24+
alpha_max = np.max(np.linalg.norm((X.T @ y).reshape(-1, 5), axis=1)) / len(y)
25+
alpha = alpha_max / 50
26+
clf = GroupLasso(fit_intercept=False,
27+
groups=5, alpha=alpha, verbose=1)
28+
29+
t0 = time.time()
30+
clf.fit(X, y)
31+
t1 = time.time()
32+
33+
print(f"Celer: {t1 - t0:.3f} s")
34+
35+
t0 = time.time()
36+
res = group_lasso(X, y, alpha, groups=5, tol=1e-10, max_iter=10_000, check_freq=10)
37+
t1 = time.time()
38+
39+
print(f"skglm gram: {t1 - t0:.3f} s")
40+
41+
# # Case 2: slower runtime for (very) small alphas with weights
42+
# # alpha_max_w = 0.0001897719130007628
43+
# alpha_max_w = np.max(norm((X.T @ y).reshape(-1, 5) /
44+
# weights[:, None], axis=1)) / len(y)
45+
46+
47+
# alpha_ratio = 0.1
48+
# grid_w = np.geomspace(alpha_max_w*alpha_ratio, alpha_max_w, n_alphas)[::-1]
49+
# clf = GroupLasso(fit_intercept=False,
50+
# weights=weights, groups=grps, warm_start=True)
51+
52+
# # for alpha in grid_w:
53+
# # clf.alpha = alpha
54+
# # t0 = time.time()
55+
# # clf.fit(X, y)
56+
# t1 = time.time()
57+
# print(f"Finished tuning with {alpha:.2e}. Took {t1-t0:.2f} seconds!")

skglm/gram_solver.py

Lines changed: 49 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -120,51 +120,52 @@ def group_lasso(X, y, alpha, groups, max_iter, tol, check_freq=50):
120120
return w
121121

122122

123-
n_samples, n_features = 1_000_000, 300
124-
X, y, w_star = make_correlated_data(
125-
n_samples=n_samples, n_features=n_features, random_state=0)
126-
alpha_max = norm(X.T @ y, ord=np.inf)
127-
128-
# Hyperparameters
129-
max_iter = 1000
130-
tol = 1e-8
131-
reg = 0.1
132-
group_size = 3
133-
134-
alpha = alpha_max * reg / n_samples
135-
136-
# Lasso
137-
print("#" * 15)
138-
print("Lasso")
139-
print("#" * 15)
140-
start = time()
141-
w = lasso(X, y, alpha, max_iter, tol)
142-
gram_lasso_time = time() - start
143-
clf_sk = Lasso(alpha, tol=tol, fit_intercept=False)
144-
start = time()
145-
clf_sk.fit(X, y)
146-
celer_lasso_time = time() - start
147-
np.testing.assert_allclose(w, clf_sk.coef_, rtol=1e-5)
148-
149-
print("\n")
150-
print("Celer: %.2f" % celer_lasso_time)
151-
print("Gram: %.2f" % gram_lasso_time)
152-
print("\n")
153-
154-
# Group Lasso
155-
print("#" * 15)
156-
print("Group Lasso")
157-
print("#" * 15)
158-
start = time()
159-
w = group_lasso(X, y, alpha, group_size, max_iter, tol)
160-
gram_group_lasso_time = time() - start
161-
clf_celer = GroupLasso(group_size, alpha, tol=tol, fit_intercept=False)
162-
start = time()
163-
clf_celer.fit(X, y)
164-
celer_group_lasso_time = time() - start
165-
np.testing.assert_allclose(w, clf_celer.coef_, rtol=1e-1)
166-
167-
print("\n")
168-
print("Celer: %.2f" % celer_group_lasso_time)
169-
print("Gram: %.2f" % gram_group_lasso_time)
170-
print("\n")
123+
if __name__ == "__main__":
124+
n_samples, n_features = 1_000_000, 300
125+
X, y, w_star = make_correlated_data(
126+
n_samples=n_samples, n_features=n_features, random_state=0)
127+
alpha_max = norm(X.T @ y, ord=np.inf)
128+
129+
# Hyperparameters
130+
max_iter = 1000
131+
tol = 1e-8
132+
reg = 0.1
133+
group_size = 3
134+
135+
alpha = alpha_max * reg / n_samples
136+
137+
# Lasso
138+
print("#" * 15)
139+
print("Lasso")
140+
print("#" * 15)
141+
start = time()
142+
w = lasso(X, y, alpha, max_iter, tol)
143+
gram_lasso_time = time() - start
144+
clf_sk = Lasso(alpha, tol=tol, fit_intercept=False)
145+
start = time()
146+
clf_sk.fit(X, y)
147+
celer_lasso_time = time() - start
148+
np.testing.assert_allclose(w, clf_sk.coef_, rtol=1e-5)
149+
150+
print("\n")
151+
print("Celer: %.2f" % celer_lasso_time)
152+
print("Gram: %.2f" % gram_lasso_time)
153+
print("\n")
154+
155+
# Group Lasso
156+
print("#" * 15)
157+
print("Group Lasso")
158+
print("#" * 15)
159+
start = time()
160+
w = group_lasso(X, y, alpha, group_size, max_iter, tol)
161+
gram_group_lasso_time = time() - start
162+
clf_celer = GroupLasso(group_size, alpha, tol=tol, fit_intercept=False)
163+
start = time()
164+
clf_celer.fit(X, y)
165+
celer_group_lasso_time = time() - start
166+
np.testing.assert_allclose(w, clf_celer.coef_, rtol=1e-1)
167+
168+
print("\n")
169+
print("Celer: %.2f" % celer_group_lasso_time)
170+
print("Gram: %.2f" % gram_group_lasso_time)
171+
print("\n")

0 commit comments

Comments
 (0)