Skip to content

Commit 791c8cd

Browse files
fix pytest, should work now
1 parent 91a5608 commit 791c8cd

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

skglm/experimental/sqrt_lasso.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ class SqrtLasso(LinearModel, RegressorMixin):
107107

108108
def __init__(self, alpha=1., max_iter=100, max_pn_iter=100, p0=10,
109109
tol=1e-4, verbose=0, fit_intercept=True):
110+
110111
super().__init__()
111112
self.alpha = alpha
112113
self.max_iter = max_iter
@@ -147,7 +148,8 @@ def fit(self, X, y):
147148
self.coef_ = self.path(X_centered, y_centered, alphas=[self.alpha])[1][0]
148149

149150
if self.fit_intercept:
150-
self.intercept_ = y_mean - X_mean @ self.coef_
151+
self.intercept_ = y_mean - X_mean @ self.coef_[:-1]
152+
self.coef_ = self.coef_[:-1]
151153
else:
152154
self.intercept_ = 0.
153155
return self

skglm/experimental/tests/test_sqrt_lasso.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,24 @@ def test_sqrt_lasso_with_intercept():
109109
y_pred = sqrt.predict(X)
110110
assert y_pred.shape == y.shape
111111

112+
# Check that coef_ and intercept_ are handled separately
113+
assert sqrt.coef_.shape == (20,)
114+
assert np.isscalar(sqrt.intercept_)
115+
116+
# Confirm prediction matches manual computation
117+
manual_pred = X @ sqrt.coef_ + sqrt.intercept_
118+
np.testing.assert_allclose(manual_pred, y_pred, rtol=1e-6)
119+
120+
np.testing.assert_allclose(
121+
sqrt.intercept_, y.mean() - X.mean(axis=0) @ sqrt.coef_, rtol=1e-6
122+
)
123+
124+
sqrt_no_intercept = SqrtLasso(
125+
alpha=alpha * scal, fit_intercept=False, tol=1e-8).fit(X, y)
126+
assert np.isscalar(sqrt_no_intercept.intercept_)
127+
np.testing.assert_allclose(sqrt_no_intercept.predict(
128+
X), X @ sqrt_no_intercept.coef_ + sqrt_no_intercept.intercept_)
129+
112130

113131
if __name__ == '__main__':
114132
pass

0 commit comments

Comments
 (0)