Skip to content

Commit 51b4cfe

Browse files
committed
added weights and warm_start
1 parent 787c8c2 commit 51b4cfe

File tree

2 files changed

+30
-22
lines changed

2 files changed

+30
-22
lines changed

skglm/gram_solver.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,17 @@
1919

2020
alpha = alpha_max * reg / n_samples
2121

22+
weights = np.random.normal(2, 0.4, n_features)
23+
weights_grp = np.random.normal(2, 0.4, n_features // group_size)
24+
2225
# Lasso
2326
print("#" * 15)
2427
print("Lasso")
2528
print("#" * 15)
2629
start = time()
27-
w = gram_lasso(X, y, alpha, max_iter, tol)
30+
w = gram_lasso(X, y, alpha, max_iter, tol, weights=weights)
2831
gram_lasso_time = time() - start
29-
clf_sk = Lasso(alpha, tol=tol, fit_intercept=False)
32+
clf_sk = Lasso(alpha, weights=weights, tol=tol, fit_intercept=False)
3033
start = time()
3134
clf_sk.fit(X, y)
3235
celer_lasso_time = time() - start
@@ -42,9 +45,10 @@
4245
print("Group Lasso")
4346
print("#" * 15)
4447
start = time()
45-
w = gram_group_lasso(X, y, alpha, group_size, max_iter, tol)
48+
w = gram_group_lasso(X, y, alpha, group_size, max_iter, tol, weights=weights_grp)
4649
gram_group_lasso_time = time() - start
47-
clf_celer = GroupLasso(group_size, alpha, tol=tol, fit_intercept=False)
50+
clf_celer = GroupLasso(group_size, alpha, weights=weights_grp, tol=tol,
51+
fit_intercept=False)
4852
start = time()
4953
clf_celer.fit(X, y)
5054
celer_group_lasso_time = time() - start

skglm/solvers/gram.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,36 +7,37 @@
77

88

99
@njit
def primal(alpha, y, X, w, weights):
    """Return the weighted Lasso primal objective evaluated at ``w``.

    The value is ``||y - X @ w||^2 / (2 * len(y))`` plus
    ``alpha * sum_j |w[j] * weights[j]|``.

    Parameters
    ----------
    alpha : float
        Regularization strength multiplying the penalty.
    y : array
        Target vector.
    X : array
        Design matrix.
    w : array
        Coefficient vector at which the objective is evaluated.
    weights : array
        Per-coefficient penalty weights, indexed like ``w``.

    Returns
    -------
    float
        Data-fit term plus weighted L1 penalty.
    """
    r = y - X @ w
    p_obj = (r @ r) / (2 * len(y))
    # Only nonzero coefficients contribute to the penalty. Restricting the sum
    # to them avoids 0 * inf = nan when a coefficient is zero and its weight is
    # infinite (cd_epoch skips such coordinates), which would otherwise turn
    # the whole objective into nan and defeat the convergence check.
    nonzero = w != 0.
    return p_obj + alpha * np.sum(np.abs(w[nonzero] * weights[nonzero]))
1414

1515

1616
@njit
def primal_grp(alpha, y, X, w, grp_ptr, grp_indices, weights):
    """Return the weighted group Lasso primal objective evaluated at ``w``.

    The value is ``||y - X @ w||^2 / (2 * len(y))`` plus, for every group
    ``g``, ``alpha * ||w_g * weights[g]||_2``.

    Parameters
    ----------
    alpha : float
        Regularization strength multiplying the penalty.
    y : array
        Target vector.
    X : array
        Design matrix.
    w : array
        Coefficient vector at which the objective is evaluated.
    grp_ptr : array
        Group boundaries: group ``g`` owns the entries
        ``grp_indices[grp_ptr[g]:grp_ptr[g + 1]]``.
    grp_indices : array
        Feature indices, concatenated group by group.
    weights : array
        Per-group penalty weights, indexed by group number.

    Returns
    -------
    float
        Data-fit term plus weighted group penalty.
    """
    r = y - X @ w
    p_obj = (r @ r) / (2 * len(y))
    for g in range(len(grp_ptr) - 1):
        w_g = w[grp_indices[grp_ptr[g]:grp_ptr[g + 1]]]
        # An all-zero group contributes nothing to the penalty. Skipping it
        # avoids 0 * inf = nan when the group's weight is infinite, which
        # would otherwise make the objective nan.
        if np.any(w_g != 0.):
            p_obj += alpha * norm(w_g * weights[g], ord=2)
    return p_obj
2424

2525

26-
def gram_lasso(X, y, alpha, max_iter, tol, check_freq=10):
26+
def gram_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, check_freq=10):
2727
p_obj_prev = np.inf
2828
n_features = X.shape[1]
2929
grads = X.T @ y / len(y)
3030
G = X.T @ X
3131
lipschitz = np.zeros(n_features, dtype=X.dtype)
3232
for j in range(n_features):
3333
lipschitz[j] = (X[:, j] ** 2).sum() / len(y)
34-
w = np.zeros(n_features)
34+
w = w_init if w_init is not None else np.zeros(n_features)
35+
weights = weights if weights is not None else np.ones(n_features)
3536
# CD
3637
for n_iter in range(max_iter):
37-
cd_epoch(X, G, grads, w, alpha, lipschitz)
38+
cd_epoch(X, G, grads, w, alpha, lipschitz, weights)
3839
if n_iter % check_freq == 0:
39-
p_obj = primal(alpha, y, X, w)
40+
p_obj = primal(alpha, y, X, w, weights)
4041
if p_obj_prev - p_obj < tol:
4142
print("Convergence reached!")
4243
break
@@ -45,7 +46,8 @@ def gram_lasso(X, y, alpha, max_iter, tol, check_freq=10):
4546
return w
4647

4748

48-
def gram_group_lasso(X, y, alpha, groups, max_iter, tol, check_freq=50):
49+
def gram_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, weights=None,
50+
check_freq=50):
4951
p_obj_prev = np.inf
5052
n_features = X.shape[1]
5153
grp_ptr, grp_indices = _grp_converter(groups, X.shape[1])
@@ -56,12 +58,13 @@ def gram_group_lasso(X, y, alpha, groups, max_iter, tol, check_freq=50):
5658
for g in range(n_groups):
5759
X_g = X[:, grp_indices[grp_ptr[g]:grp_ptr[g + 1]]]
5860
lipschitz[g] = norm(X_g, ord=2) ** 2 / len(y)
59-
w = np.zeros(n_features)
61+
w = w_init if w_init is not None else np.zeros(n_features)
62+
weights = weights if weights is not None else np.ones(n_groups)
6063
# BCD
6164
for n_iter in range(max_iter):
62-
bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr)
65+
bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr, weights)
6366
if n_iter % check_freq == 0:
64-
p_obj = primal_grp(alpha, y, X, w, grp_ptr, grp_indices)
67+
p_obj = primal_grp(alpha, y, X, w, grp_ptr, grp_indices, weights)
6568
if p_obj_prev - p_obj < tol:
6669
print("Convergence reached!")
6770
break
@@ -71,26 +74,27 @@ def gram_group_lasso(X, y, alpha, groups, max_iter, tol, check_freq=50):
7174

7275

7376
@njit
def cd_epoch(X, G, grads, w, alpha, lipschitz, weights):
    """Run one pass of coordinate descent over all features, in place.

    ``w`` and ``grads`` are both modified in place. Coordinates whose
    Lipschitz constant is zero or whose weight is infinite are left untouched.

    Parameters
    ----------
    X : array
        Design matrix; only its shape is used here.
    G : array
        Gram matrix; row ``j`` is used to refresh ``grads`` after ``w[j]``
        changes.
    grads : array
        Running per-feature gradient-like vector, updated in place.
    w : array
        Coefficient vector, updated in place.
    alpha : float
        Regularization strength.
    lipschitz : array
        Per-feature Lipschitz constants.
    weights : array
        Per-feature penalty weights.
    """
    for j in range(X.shape[1]):
        lc_j = lipschitz[j]
        # Update only coordinates with a usable step size and a finite weight.
        if lc_j != 0. and weights[j] != np.inf:
            previous = w[j]
            # ST is defined elsewhere; presumably scalar soft-thresholding.
            w[j] = ST(previous + grads[j] / lc_j, alpha / lc_j * weights[j])
            if w[j] != previous:
                # Propagate the coefficient change to the running gradients.
                grads += G[j, :] * (previous - w[j]) / len(X)
8386

8487

8588
@njit
def bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr, weights):
    """Run one pass of block coordinate descent over all groups, in place.

    ``w`` and ``grads`` are both modified in place. Groups whose Lipschitz
    constant is zero or whose weight is infinite are left untouched.

    Parameters
    ----------
    X : array
        Design matrix; only ``len(X)`` is used here.
    G : array
        Gram matrix; the rows of the current group are used to refresh
        ``grads`` after the group's coefficients change.
    grads : array
        Running per-feature gradient-like vector, updated in place.
    w : array
        Coefficient vector, updated in place.
    alpha : float
        Regularization strength.
    lipschitz : array
        Per-group Lipschitz constants.
    grp_indices : array
        Feature indices, concatenated group by group.
    grp_ptr : array
        Group boundaries: group ``g`` owns the entries
        ``grp_indices[grp_ptr[g]:grp_ptr[g + 1]]``.
    weights : array
        Per-group penalty weights.
    """
    n_groups = len(grp_ptr) - 1
    for g in range(n_groups):
        # Either condition alone makes the update meaningless: a zero
        # Lipschitz constant would divide by zero below, and an infinite
        # weight gives an infinite threshold. This must be `or`, matching
        # the per-coordinate check in cd_epoch.
        if lipschitz[g] == 0. or weights[g] == np.inf:
            continue
        idx = grp_indices[grp_ptr[g]:grp_ptr[g + 1]]
        old_w_g = w[idx].copy()
        # BST is defined elsewhere; presumably block soft-thresholding.
        w[idx] = BST(w[idx] + grads[idx] / lipschitz[g],
                     alpha / lipschitz[g] * weights[g])
        diff = old_w_g - w[idx]
        if np.any(diff != 0.):
            # Propagate the coefficient change to the running gradients.
            grads += diff @ G[idx, :] / len(X)

0 commit comments

Comments
 (0)