Skip to content

Gram-based CD/BCD/FISTA solvers for (group)Lasso when n_samples >> n_features #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 23 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions gram_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# data available at https://www.dropbox.com/sh/32b3mr3xghi496g/AACNRS_NOsUXU-hrSLixNg0ja?dl=0


import time
from numpy.linalg import norm
import matplotlib.pyplot as plt
import numpy as np
from celer import GroupLasso
from skglm.solvers.gram import gram_group_lasso

# Benchmark script: time celer's GroupLasso against the Gram-based BCD solver
# on a dataset loaded from .npy files in the current working directory.
X = np.load("design_matrix.npy")
y = np.load("target.npy")
groups = np.load("groups.npy")
weights = np.load("weights.npy")
# grps = [list(np.where(groups == i)[0]) for i in range(1, 33)]

# Both solvers are run with contiguous groups of this fixed size; the
# `groups` and `weights` arrays loaded above are not used yet (see TODO below).
group_size = 5

alpha_ratio = 1e-2
n_alphas = 10


# Case 1: slower runtime for (very) small alphas
# alpha_max = 0.003471727067743962
# Largest per-group l2 norm of X.T @ y, scaled by the number of samples.
alpha_max = np.max(
    np.linalg.norm((X.T @ y).reshape(-1, group_size), axis=1)) / len(y)
alpha = alpha_max / 100
clf = GroupLasso(fit_intercept=False,
                 groups=group_size, alpha=alpha, verbose=1)

t0 = time.time()
clf.fit(X, y)
t1 = time.time()

print(f"Celer: {t1 - t0:.3f} s")

# beware: stopping criterion is not the same, tol here needs to be lower
# to get meaningful comparison
t0 = time.time()
# The solver is imported as `gram_group_lasso`; the bare name `group_lasso`
# is not defined in this script and raised a NameError.
res = gram_group_lasso(X, y, alpha, groups=group_size, tol=1e-10,
                       max_iter=10_000, check_freq=10)
t1 = time.time()

print(f"skglm gram: {t1 - t0:.3f} s")

# TODO support weights in gram solver
56 changes: 56 additions & 0 deletions skglm/gram_solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from time import time
import numpy as np
from numpy.linalg import norm
from celer import Lasso, GroupLasso
from benchopt.datasets.simulated import make_correlated_data
from skglm.solvers.gram import gram_lasso, gram_group_lasso


# Benchmark script: compare the Gram-based Lasso / group Lasso solvers with
# celer's estimators on simulated correlated data, checking that both reach
# (approximately) the same coefficients and printing wall-clock timings.


def _print_banner(title):
    """Print `title` framed by two rules of 15 '#' characters."""
    rule = "#" * 15
    print(rule)
    print(title)
    print(rule)


def _print_timings(celer_time, gram_time):
    """Print both solver timings, surrounded by blank lines."""
    print("\n")
    print("Celer: %.2f" % celer_time)
    print("Gram: %.2f" % gram_time)
    print("\n")


# Hyperparameters
max_iter = 1000
tol = 1e-8
reg = 0.1
group_size = 3

# Simulated problem with many more samples than features.
n_samples, n_features = 1_000_000, 300
X, y, w_star = make_correlated_data(
    n_samples=n_samples, n_features=n_features, random_state=0)
alpha_max = norm(X.T @ y, ord=np.inf)

# Regularization strength: a fraction `reg` of alpha_max, scaled by n_samples.
alpha = alpha_max * reg / n_samples

# Lasso
_print_banner("Lasso")
start = time()
w = gram_lasso(X, y, alpha, max_iter, tol)
gram_lasso_time = time() - start
clf_sk = Lasso(alpha, tol=tol, fit_intercept=False)
start = time()
clf_sk.fit(X, y)
celer_lasso_time = time() - start
np.testing.assert_allclose(w, clf_sk.coef_, rtol=1e-5)

_print_timings(celer_lasso_time, gram_lasso_time)

# Group Lasso
_print_banner("Group Lasso")
start = time()
w = gram_group_lasso(X, y, alpha, group_size, max_iter, tol)
gram_group_lasso_time = time() - start
clf_celer = GroupLasso(group_size, alpha, tol=tol, fit_intercept=False)
start = time()
clf_celer.fit(X, y)
celer_group_lasso_time = time() - start
# Looser tolerance than for the Lasso comparison above.
np.testing.assert_allclose(w, clf_celer.coef_, rtol=1e-1)

_print_timings(celer_group_lasso_time, gram_group_lasso_time)
96 changes: 96 additions & 0 deletions skglm/solvers/gram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import numpy as np
from numba import njit
from numpy.linalg import norm
from celer.homotopy import _grp_converter

from skglm.utils import BST, ST


@njit
def primal(alpha, y, X, w):
    """Return the Lasso primal objective at `w`.

    Computes ``||y - X @ w||^2 / (2 * len(y)) + alpha * sum(|w|)``.
    """
    residual = y - X @ w
    datafit = (residual @ residual) / (2 * len(y))
    penalty = alpha * np.sum(np.abs(w))
    return datafit + penalty


@njit
def primal_grp(alpha, y, X, w, grp_ptr, grp_indices):
    """Return the group Lasso primal objective at `w`.

    Computes ``||y - X @ w||^2 / (2 * len(y))`` plus ``alpha`` times the sum
    of the l2 norms of the group sub-vectors of `w`. Group ``g`` consists of
    the coefficients indexed by ``grp_indices[grp_ptr[g]:grp_ptr[g + 1]]``.
    """
    residual = y - X @ w
    value = (residual @ residual) / (2 * len(y))
    n_groups = len(grp_ptr) - 1
    for g in range(n_groups):
        start, stop = grp_ptr[g], grp_ptr[g + 1]
        members = grp_indices[start:stop]
        value += alpha * norm(w[members], ord=2)
    return value


def gram_lasso(X, y, alpha, max_iter, tol, check_freq=10):
    """Solve the Lasso by coordinate descent on the precomputed Gram matrix.

    The objective monitored for convergence is the one evaluated by
    ``primal``: ``||y - X @ w||^2 / (2 * len(y)) + alpha * ||w||_1``.

    Parameters
    ----------
    X : 2-D array
        Design matrix; ``X.shape[1]`` is taken as the number of features.
    y : 1-D array
        Target vector; ``len(y)`` is used as the number of samples.
    alpha : float
        Regularization strength.
    max_iter : int
        Maximum number of coordinate descent epochs.
    tol : float
        The solver stops when the objective decrease between two consecutive
        checks falls below ``tol``.
    check_freq : int, default=10
        The objective is evaluated every ``check_freq`` epochs.

    Returns
    -------
    w : 1-D array of length ``X.shape[1]``
        Coefficients at the last epoch run.
    """
    n_samples = len(y)
    n_features = X.shape[1]

    G = X.T @ X
    # Starts as X.T @ y / n_samples (w = 0); `cd_epoch` updates it in place.
    grads = X.T @ y / n_samples
    # The j-th coordinate-wise Lipschitz constant ||X[:, j]||^2 / n_samples is
    # the j-th diagonal entry of G / n_samples: read it from the Gram matrix
    # rather than making n_features additional passes over the tall X.
    lipschitz = np.diag(G) / n_samples

    w = np.zeros(n_features)
    p_obj_prev = np.inf
    # CD
    for n_iter in range(max_iter):
        cd_epoch(X, G, grads, w, alpha, lipschitz)
        if n_iter % check_freq == 0:
            p_obj = primal(alpha, y, X, w)
            if p_obj_prev - p_obj < tol:
                print("Convergence reached!")
                break
            print(f"iter {n_iter} :: p_obj {p_obj}")
            p_obj_prev = p_obj
    return w


def gram_group_lasso(X, y, alpha, groups, max_iter, tol, check_freq=50):
    """Solve the group Lasso by block coordinate descent on the Gram matrix.

    The objective monitored for convergence is the one evaluated by
    ``primal_grp``: ``||y - X @ w||^2 / (2 * len(y))`` plus ``alpha`` times
    the sum of the l2 norms of the group sub-vectors of ``w``.

    Parameters
    ----------
    X : 2-D array
        Design matrix; ``X.shape[1]`` is taken as the number of features.
    y : 1-D array
        Target vector; ``len(y)`` is used as the number of samples.
    alpha : float
        Regularization strength.
    groups : int or group specification
        Passed as-is to celer's ``_grp_converter`` together with the number
        of features to obtain ``grp_ptr`` and ``grp_indices``.
    max_iter : int
        Maximum number of block coordinate descent epochs.
    tol : float
        The solver stops when the objective decrease between two consecutive
        checks falls below ``tol``.
    check_freq : int, default=50
        The objective is evaluated every ``check_freq`` epochs.

    Returns
    -------
    w : 1-D array of length ``X.shape[1]``
        Coefficients at the last epoch run.
    """
    n_samples = len(y)
    n_features = X.shape[1]
    grp_ptr, grp_indices = _grp_converter(groups, n_features)
    n_groups = len(grp_ptr) - 1

    G = X.T @ X
    # Starts as X.T @ y / n_samples (w = 0); `bcd_epoch` updates it in place.
    grads = X.T @ y / n_samples

    lipschitz = np.zeros(n_groups, dtype=X.dtype)
    for g in range(n_groups):
        idx = grp_indices[grp_ptr[g]:grp_ptr[g + 1]]
        # ||X_g||_2^2 is the largest eigenvalue of X_g.T @ X_g, i.e. the
        # spectral norm of the (symmetric PSD) diagonal block of G. Using the
        # small block avoids copying and decomposing the tall X_g per group.
        lipschitz[g] = norm(G[np.ix_(idx, idx)], ord=2) / n_samples

    w = np.zeros(n_features)
    p_obj_prev = np.inf
    # BCD
    for n_iter in range(max_iter):
        bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr)
        if n_iter % check_freq == 0:
            p_obj = primal_grp(alpha, y, X, w, grp_ptr, grp_indices)
            if p_obj_prev - p_obj < tol:
                print("Convergence reached!")
                break
            print(f"iter {n_iter} :: p_obj {p_obj}")
            p_obj_prev = p_obj
    return w


@njit
def cd_epoch(X, G, grads, w, alpha, lipschitz):
    """Run one pass of coordinate descent over all features, in place.

    `w` and `grads` are both modified in place. `X` is only used for its
    dimensions (``X.shape[1]`` coordinates, ``len(X)`` as scaling factor).
    Coordinates whose Lipschitz constant is exactly zero are left untouched.
    """
    n_samples = len(X)
    for j in range(X.shape[1]):
        L_j = lipschitz[j]
        if L_j != 0.:
            w_prev = w[j]
            # ST comes from skglm.utils (presumably soft-thresholding --
            # confirm against its definition).
            w[j] = ST(w_prev + grads[j] / L_j, alpha / L_j)
            if w[j] != w_prev:
                # Propagate the change of w[j] to `grads` through row j of G.
                grads += G[j, :] * (w_prev - w[j]) / n_samples


@njit
def bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr):
    """Run one pass of block coordinate descent over all groups, in place.

    `w` and `grads` are both modified in place. `X` is only used through
    ``len(X)`` as a scaling factor. Group ``g`` consists of the coefficients
    indexed by ``grp_indices[grp_ptr[g]:grp_ptr[g + 1]]``; groups whose
    Lipschitz constant is exactly zero are left untouched.
    """
    n_samples = len(X)
    for g in range(len(grp_ptr) - 1):
        L_g = lipschitz[g]
        if L_g != 0.:
            members = grp_indices[grp_ptr[g]:grp_ptr[g + 1]]
            w_before = w[members].copy()
            # BST comes from skglm.utils (presumably block soft-thresholding
            # -- confirm against its definition).
            w[members] = BST(w_before + grads[members] / L_g, alpha / L_g)
            delta = w_before - w[members]
            if np.any(delta != 0.):
                # Propagate the block change to `grads` through the group's
                # rows of G.
                grads += delta @ G[members, :] / n_samples