
Commit d80c9aa

ENH - add support for sparse dataset in LBFGS (#173)

1 parent f6f0875

File tree

3 files changed: +42 additions, -8 deletions

- skglm/datafits/single_task.py
- skglm/solvers/lbfgs.py
- skglm/tests/test_lbfgs_solver.py

skglm/datafits/single_task.py

Lines changed: 22 additions & 0 deletions

```diff
@@ -5,6 +5,7 @@
 
 from skglm.datafits.base import BaseDatafit
 from skglm.utils.sparse_ops import spectral_norm
+from skglm.solvers.prox_newton import _sparse_xj_dot
 
 
 class Quadratic(BaseDatafit):
@@ -185,6 +186,16 @@ def gradient_scalar(self, X, y, w, Xw, j):
     def gradient(self, X, y, Xw):
         return X.T @ self.raw_grad(y, Xw)
 
+    def gradient_sparse(self, X_data, X_indptr, X_indices, y, Xw):
+        n_features = X_indptr.shape[0] - 1
+        out = np.zeros(n_features, dtype=X_data.dtype)
+        raw_grad = self.raw_grad(y, Xw)
+
+        for j in range(n_features):
+            out[j] = _sparse_xj_dot(X_data, X_indptr, X_indices, j, raw_grad)
+
+        return out
+
     def full_grad_sparse(
             self, X_data, X_indptr, X_indices, y, Xw):
         n_features = X_indptr.shape[0] - 1
@@ -654,6 +665,17 @@ def gradient(self, X, y, Xw):
         """Compute gradient of the datafit."""
         return X.T @ self.raw_grad(y, Xw)
 
+    def gradient_sparse(self, X_data, X_indptr, X_indices, y, Xw):
+        """Compute gradient of the datafit in case ``X`` is sparse."""
+        n_features = X_indptr.shape[0] - 1
+        out = np.zeros(n_features, dtype=X_data.dtype)
+        raw_grad = self.raw_grad(y, Xw)
+
+        for j in range(n_features):
+            out[j] = _sparse_xj_dot(X_data, X_indptr, X_indices, j, raw_grad)
+
+        return out
+
     def initialize(self, X, y):
         """Initialize the datafit attributes."""
         tm, s = y
```
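The new `gradient_sparse` methods reuse `_sparse_xj_dot` from the Prox-Newton solver. Its actual implementation lives in `skglm/solvers/prox_newton.py`; as a rough sketch (an illustration, not the library's code), the CSC column dot product it performs amounts to:

```python
def sparse_xj_dot(X_data, X_indptr, X_indices, j, other):
    # Dot product between column j of a CSC matrix and a dense vector:
    # X_indptr[j]:X_indptr[j + 1] delimits the non-zero entries of column j.
    res = 0.0
    for idx in range(X_indptr[j], X_indptr[j + 1]):
        res += X_data[idx] * other[X_indices[idx]]
    return res
```

Looping this over all columns gives `gradient_sparse` the product `X.T @ raw_grad` without ever materializing a dense `X`.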

skglm/solvers/lbfgs.py

Lines changed: 15 additions & 6 deletions

```diff
@@ -4,6 +4,7 @@
 import numpy as np
 import scipy.optimize
 from numpy.linalg import norm
+from scipy.sparse import issparse
 
 from skglm.solvers import BaseSolver
 
@@ -33,27 +34,34 @@ def __init__(self, max_iter=50, tol=1e-4, verbose=False):
 
     def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
 
-        def objective_function(w):
+        def objective(w):
             Xw = X @ w
             datafit_value = datafit.value(y, w, Xw)
             penalty_value = penalty.value(w)
 
             return datafit_value + penalty_value
 
-        def jacobian_function(w):
+        def d_jac(w):
             Xw = X @ w
             datafit_grad = datafit.gradient(X, y, Xw)
             penalty_grad = penalty.gradient(w)
 
             return datafit_grad + penalty_grad
 
+        def s_jac(w):
+            Xw = X @ w
+            datafit_grad = datafit.gradient_sparse(X.data, X.indptr, X.indices, y, Xw)
+            penalty_grad = penalty.gradient(w)
+
+            return datafit_grad + penalty_grad
+
         def callback_post_iter(w_k):
             # save p_obj
-            p_obj = objective_function(w_k)
+            p_obj = objective(w_k)
             p_objs_out.append(p_obj)
 
             if self.verbose:
-                grad = jacobian_function(w_k)
+                grad = jac(w_k)
                 stop_crit = norm(grad)
 
                 it = len(p_objs_out)
@@ -64,11 +72,12 @@ def callback_post_iter(w_k):
 
         n_features = X.shape[1]
         w = np.zeros(n_features) if w_init is None else w_init
+        jac = s_jac if issparse(X) else d_jac
        p_objs_out = []
 
         result = scipy.optimize.minimize(
-            fun=objective_function,
-            jac=jacobian_function,
+            fun=objective,
+            jac=jac,
             x0=w,
             method="L-BFGS-B",
             options=dict(
```
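The dispatch is a one-liner, `jac = s_jac if issparse(X) else d_jac`, so `scipy.optimize.minimize` receives whichever gradient closure matches the input type. Below is a standalone sketch of the same pattern, using a plain quadratic objective `||Xw - y||^2 / (2 n)` in place of skglm's datafit/penalty objects; everything in it is illustrative, not skglm API:

```python
import numpy as np
import scipy.optimize
from scipy import sparse
from scipy.sparse import issparse


def lbfgs_fit(X, y):
    """Minimize ||Xw - y||^2 / (2 n) with L-BFGS-B, dispatching on sparsity."""
    n_samples, n_features = X.shape

    def objective(w):
        r = X @ w - y
        return r @ r / (2 * n_samples)

    def d_jac(w):
        # dense gradient: X.T @ residual / n
        return X.T @ (X @ w - y) / n_samples

    def s_jac(w):
        # same value, computed column by column from the CSC buffers,
        # mirroring gradient_sparse above
        r = X @ w - y
        out = np.zeros(n_features)
        for j in range(n_features):
            for idx in range(X.indptr[j], X.indptr[j + 1]):
                out[j] += X.data[idx] * r[X.indices[idx]]
        return out / n_samples

    jac = s_jac if issparse(X) else d_jac
    result = scipy.optimize.minimize(
        objective, np.zeros(n_features), jac=jac, method="L-BFGS-B")
    return result.x


X = sparse.random(100, 50, density=0.5, format="csc", random_state=0)
y = np.random.default_rng(0).standard_normal(100)
w_hat = lbfgs_fit(X, y)
```

Feeding `X.toarray()` through the same function exercises the dense branch and should return the same minimizer up to solver tolerance.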

skglm/tests/test_lbfgs_solver.py

Lines changed: 5 additions & 2 deletions

```diff
@@ -12,12 +12,15 @@
 from skglm.utils.data import make_correlated_data, make_dummy_survival_data
 
 
-def test_lbfgs_L2_logreg():
+@pytest.mark.parametrize("X_sparse", [True, False])
+def test_lbfgs_L2_logreg(X_sparse):
     reg = 1.
+    X_density = 1. if not X_sparse else 0.5
     n_samples, n_features = 100, 50
 
     X, y, _ = make_correlated_data(
-        n_samples, n_features, random_state=0)
+        n_samples, n_features, random_state=0, X_density=X_density,
+    )
     y = np.sign(y)
 
     # fit L-BFGS
```
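The parametrized test now covers both branches: with `X_density=0.5`, `make_correlated_data` produces a sparse design matrix, so the solver picks `s_jac`. As a quick standalone sanity check that the per-column CSC gradient matches its dense counterpart (assuming the quadratic raw gradient `(Xw - y) / n_samples`; names here are illustrative):

```python
import numpy as np
from scipy import sparse

rng = np.random.default_rng(0)
X = sparse.random(20, 10, density=0.5, format="csc", random_state=0)
y, w = rng.standard_normal(20), rng.standard_normal(10)

# assumed raw gradient of the quadratic datafit: (Xw - y) / n_samples
raw_grad = (X @ w - y) / X.shape[0]

# dense path: plain matrix-vector product with the transpose
dense_grad = X.toarray().T @ raw_grad

# sparse path: one CSC column dot product per feature, as in gradient_sparse
sparse_grad = np.array([
    sum(X.data[idx] * raw_grad[X.indices[idx]]
        for idx in range(X.indptr[j], X.indptr[j + 1]))
    for j in range(X.shape[1])
])

np.testing.assert_allclose(dense_grad, sparse_grad)
```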
