Skip to content

Commit 9fe4bae

Browse files
authored
FIX time overhead when fitting with Lasso estimator (#129)
1 parent cca6d48 commit 9fe4bae

File tree

2 files changed: 46 additions and 24 deletions.

skglm/datafits/single_task.py

Lines changed: 32 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,8 @@ class Quadratic(BaseDatafit):
1717
Attributes
1818
----------
1919
Xty : array, shape (n_features,)
20-
Pre-computed quantity used during the gradient evaluation. Equal to X.T @ y.
20+
Pre-computed quantity used during the gradient evaluation.
21+
Equal to ``X.T @ y``.
2122
2223
lipschitz : array, shape (n_features,)
2324
The coordinatewise gradient Lipschitz constants. Equal to
@@ -50,7 +51,7 @@ def params_to_dict(self):
5051
def initialize(self, X, y):
5152
self.Xty = X.T @ y
5253
n_features = X.shape[1]
53-
self.global_lipschitz = norm(X, ord=2) ** 2 / len(y)
54+
5455
self.lipschitz = np.zeros(n_features, dtype=X.dtype)
5556
for j in range(n_features):
5657
self.lipschitz[j] = (X[:, j] ** 2).sum() / len(y)
@@ -59,9 +60,6 @@ def initialize_sparse(self, X_data, X_indptr, X_indices, y):
5960
n_features = len(X_indptr) - 1
6061
self.Xty = np.zeros(n_features, dtype=X_data.dtype)
6162

62-
self.global_lipschitz = spectral_norm(X_data, X_indptr, X_indices, len(y)) ** 2
63-
self.global_lipschitz /= len(y)
64-
6563
self.lipschitz = np.zeros(n_features, dtype=X_data.dtype)
6664
for j in range(n_features):
6765
nrm2 = 0.
@@ -73,6 +71,13 @@ def initialize_sparse(self, X_data, X_indptr, X_indices, y):
7371
self.lipschitz[j] = nrm2 / len(y)
7472
self.Xty[j] = xty
7573

74+
def init_global_lipschitz(self, X, y):
75+
self.global_lipschitz = norm(X, ord=2) ** 2 / len(y)
76+
77+
def init_global_lipschitz_sparse(self, X_data, X_indptr, X_indices, y):
78+
self.global_lipschitz = spectral_norm(
79+
X_data, X_indptr, X_indices, len(y)) ** 2 / len(y)
80+
7681
def value(self, y, w, Xw):
7782
return np.sum((y - Xw) ** 2) / (2 * len(Xw))
7883

@@ -155,19 +160,22 @@ def raw_hessian(self, y, Xw):
155160

156161
def initialize(self, X, y):
157162
self.lipschitz = (X ** 2).sum(axis=0) / (len(y) * 4)
158-
self.global_lipschitz = norm(X, ord=2) ** 2 / (len(y) * 4)
159163

160164
def initialize_sparse(self, X_data, X_indptr, X_indices, y):
161165
n_features = len(X_indptr) - 1
162166

163-
self.global_lipschitz = spectral_norm(X_data, X_indptr, X_indices, len(y)) ** 2
164-
self.global_lipschitz /= 4 * len(y)
165-
166167
self.lipschitz = np.zeros(n_features, dtype=X_data.dtype)
167168
for j in range(n_features):
168169
Xj = X_data[X_indptr[j]:X_indptr[j+1]]
169170
self.lipschitz[j] = (Xj ** 2).sum() / (len(y) * 4)
170171

172+
def init_global_lipschitz(self, X, y):
173+
self.global_lipschitz = norm(X, ord=2) ** 2 / (4 * len(y))
174+
175+
def init_global_lipschitz_sparse(self, X_data, X_indptr, X_indices, y):
176+
self.global_lipschitz = spectral_norm(
177+
X_data, X_indptr, X_indices, len(y)) ** 2 / (4 * len(y))
178+
171179
def value(self, y, w, Xw):
172180
return np.log(1. + np.exp(- y * Xw)).sum() / len(y)
173181

@@ -235,23 +243,27 @@ def params_to_dict(self):
235243
def initialize(self, yXT, y):
236244
n_features = yXT.shape[1]
237245
self.lipschitz = np.zeros(n_features, dtype=yXT.dtype)
238-
self.global_lipschitz = norm(yXT, ord=2) ** 2
246+
239247
for j in range(n_features):
240248
self.lipschitz[j] = norm(yXT[:, j]) ** 2
241249

242250
def initialize_sparse(self, yXT_data, yXT_indptr, yXT_indices, y):
243251
n_features = len(yXT_indptr) - 1
244252

245-
self.global_lipschitz = spectral_norm(
246-
yXT_data, yXT_indptr, yXT_indices, max(yXT_indices)+1) ** 2
247-
248253
self.lipschitz = np.zeros(n_features, dtype=yXT_data.dtype)
249254
for j in range(n_features):
250255
nrm2 = 0.
251256
for idx in range(yXT_indptr[j], yXT_indptr[j + 1]):
252257
nrm2 += yXT_data[idx] ** 2
253258
self.lipschitz[j] = nrm2
254259

260+
def init_global_lipschitz(self, yXT, y):
261+
self.global_lipschitz = norm(yXT, ord=2) ** 2
262+
263+
def init_global_lipschitz_sparse(self, yXT_data, yXT_indptr, yXT_indices, y):
264+
self.global_lipschitz = spectral_norm(
265+
yXT_data, yXT_indptr, yXT_indices, max(yXT_indices)+1) ** 2
266+
255267
def value(self, y, w, yXTw):
256268
return (yXTw ** 2).sum() / 2 - np.sum(w)
257269

@@ -328,24 +340,26 @@ def params_to_dict(self):
328340
def initialize(self, X, y):
329341
n_features = X.shape[1]
330342
self.lipschitz = np.zeros(n_features, dtype=X.dtype)
331-
self.global_lipschitz = 0.
332343
for j in range(n_features):
333344
self.lipschitz[j] = (X[:, j] ** 2).sum() / len(y)
334-
self.global_lipschitz += (X[:, j] ** 2).sum() / len(y)
335345

336346
def initialize_sparse(self, X_data, X_indptr, X_indices, y):
337347
n_features = len(X_indptr) - 1
338348

339-
self.global_lipschitz = spectral_norm(X_data, X_indptr, X_indices, len(y)) ** 2
340-
self.global_lipschitz /= len(y)
341-
342349
self.lipschitz = np.zeros(n_features, dtype=X_data.dtype)
343350
for j in range(n_features):
344351
nrm2 = 0.
345352
for idx in range(X_indptr[j], X_indptr[j + 1]):
346353
nrm2 += X_data[idx] ** 2
347354
self.lipschitz[j] = nrm2 / len(y)
348355

356+
def init_global_lipschitz(self, X, y):
357+
self.global_lipschitz = norm(X, ord=2) ** 2 / len(y)
358+
359+
def init_global_lipschitz_sparse(self, X_data, X_indptr, X_indices, y):
360+
self.global_lipschitz = spectral_norm(
361+
X_data, X_indptr, X_indices, len(y)) ** 2 / len(y)
362+
349363
def value(self, y, w, Xw):
350364
n_samples = len(y)
351365
res = 0.

skglm/solvers/fista.py

Lines changed: 14 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -39,23 +39,31 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
3939
p_objs_out = []
4040
n_samples, n_features = X.shape
4141
all_features = np.arange(n_features)
42+
X_is_sparse = issparse(X)
4243
t_new = 1.
4344

4445
w = w_init.copy() if w_init is not None else np.zeros(n_features)
4546
z = w_init.copy() if w_init is not None else np.zeros(n_features)
4647
Xw = Xw_init.copy() if Xw_init is not None else np.zeros(n_samples)
4748

48-
if hasattr(datafit, "global_lipschitz"):
49-
lipschitz = datafit.global_lipschitz
50-
else:
51-
# TODO: OR line search
52-
raise Exception("Line search is not yet implemented for FISTA solver.")
49+
try:
50+
if X_is_sparse:
51+
datafit.init_global_lipschitz_sparse(X.data, X.indptr, X.indices, y)
52+
else:
53+
datafit.init_global_lipschitz(X, y)
54+
except AttributeError:
55+
sparse_suffix = '_sparse' if X_is_sparse else ''
56+
57+
raise Exception(
58+
"Datafit is not compatible with FISTA solver.\n Datafit must "
59+
f"implement `init_global_lipschitz{sparse_suffix}` method")
5360

61+
lipschitz = datafit.global_lipschitz
5462
for n_iter in range(self.max_iter):
5563
t_old = t_new
5664
t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2
5765
w_old = w.copy()
58-
if issparse(X):
66+
if X_is_sparse:
5967
grad = construct_grad_sparse(
6068
X.data, X.indptr, X.indices, y, z, X @ z, datafit, all_features)
6169
else:

0 commit comments

Comments (0)