Skip to content

Commit 1173877

Browse files
refactor
1 parent f036743 commit 1173877

File tree

1 file changed

+20
-26
lines changed

1 file changed

+20
-26
lines changed

skglm/solvers/lbfgs.py

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -51,45 +51,39 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
5151
datafit.initialize(X, y)
5252

5353
def objective(w):
54+
w_features = w[:n_features]
55+
Xw = X @ w_features
5456
if self.fit_intercept:
55-
Xw = X @ w[:-1] + w[-1]
56-
datafit_value = datafit.value(y, w[:-1], Xw)
57-
penalty_value = penalty.value(w[:-1])
58-
else:
59-
Xw = X @ w
60-
datafit_value = datafit.value(y, w, Xw)
61-
penalty_value = penalty.value(w)
62-
57+
Xw += w[-1]
58+
datafit_value = datafit.value(y, w_features, Xw)
59+
penalty_value = penalty.value(w_features)
6360
return datafit_value + penalty_value
6461

6562
def d_jac(w):
63+
w_features = w[:n_features]
64+
Xw = X @ w_features
65+
if self.fit_intercept:
66+
Xw += w[-1]
67+
datafit_grad = datafit.gradient(X, y, Xw)
68+
penalty_grad = penalty.gradient(w_features)
6669
if self.fit_intercept:
67-
Xw = X @ w[:-1] + w[-1]
68-
datafit_grad = datafit.gradient(X, y, Xw)
69-
penalty_grad = penalty.gradient(w[:-1])
70-
intercept_grad = datafit.intercept_update_step(y, Xw)
70+
intercept_grad = datafit.raw_grad(y, Xw).sum()
7171
return np.concatenate([datafit_grad + penalty_grad, [intercept_grad]])
7272
else:
73-
Xw = X @ w
74-
datafit_grad = datafit.gradient(X, y, Xw)
75-
penalty_grad = penalty.gradient(w)
76-
7773
return datafit_grad + penalty_grad
7874

7975
def s_jac(w):
76+
w_features = w[:n_features]
77+
Xw = X @ w_features
78+
if self.fit_intercept:
79+
Xw += w[-1]
80+
datafit_grad = datafit.gradient_sparse(
81+
X.data, X.indptr, X.indices, y, Xw)
82+
penalty_grad = penalty.gradient(w_features)
8083
if self.fit_intercept:
81-
Xw = X @ w[:-1] + w[-1]
82-
datafit_grad = datafit.gradient_sparse(
83-
X.data, X.indptr, X.indices, y, Xw)
84-
penalty_grad = penalty.gradient(w[:-1])
85-
intercept_grad = datafit.intercept_update_step(y, Xw)
84+
intercept_grad = datafit.raw_grad(y, Xw).sum()
8685
return np.concatenate([datafit_grad + penalty_grad, [intercept_grad]])
8786
else:
88-
Xw = X @ w
89-
datafit_grad = datafit.gradient_sparse(
90-
X.data, X.indptr, X.indices, y, Xw)
91-
penalty_grad = penalty.gradient(w)
92-
9387
return datafit_grad + penalty_grad
9488

9589
def callback_post_iter(w_k):

0 commit comments

Comments
 (0)