Skip to content

Commit e3b9df2

Browse files
Merge remote-tracking branch 'origin/main' into fix_intercept_SqrtLasso
2 parents e7568bd + 5efeb8f commit e3b9df2

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

skglm/experimental/sqrt_lasso.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,13 @@ class SqrtLasso(LinearModel, RegressorMixin):
106106
"""
107107

108108
def __init__(self, alpha=1., max_iter=100, max_pn_iter=100, p0=10,
109-
<<<<<<< HEAD
109+
<< << << < HEAD
110110
tol=1e-4, verbose=0, fit_intercept=True):
111-
=======
112-
tol=1e-4, verbose=0, fit_intercept=False):
113-
>>>>>>> 69bf74f (first try, add support for fit_intercept in sqrtLasso, TODOS: review if correct, clean up, pot. add support for sparse X (not sure if that works), enhance docstring)
111+
112+
113+
== == == =
114+
tol = 1e-4, verbose = 0, fit_intercept = False):
115+
>>>>>> > 69bf74f (first try , add support for fit_intercept in sqrtLasso, TODOS: review if correct, clean up, pot. add support for sparse X (not sure if that works), enhance docstring)
114116
super().__init__()
115117
self.alpha = alpha
116118
self.max_iter = max_iter
@@ -138,11 +140,6 @@ def fit(self, X, y):
138140
self :
139141
Fitted estimator.
140142
"""
141-
<<<<<<< HEAD
142-
self.coef_ = self.path(X, y, alphas=[self.alpha])[1][0]
143-
if self.fit_intercept:
144-
self.intercept_ = self.coef_[-1]
145-
=======
146143
# self.coef_ = self.path(X, y, alphas=[self.alpha])[1][0]
147144
if self.fit_intercept:
148145
X_mean = X.mean(axis=0)
@@ -157,7 +154,6 @@ def fit(self, X, y):
157154

158155
if self.fit_intercept:
159156
self.intercept_ = y_mean - X_mean @ self.coef_
160-
>>>>>>> 69bf74f (first try, add support for fit_intercept in sqrtLasso, TODOS: review if correct, clean up, pot. add support for sparse X (not sure if that works), enhance docstring)
161157
else:
162158
self.intercept_ = 0.
163159
return self

skglm/experimental/tests/test_sqrt_lasso.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,11 @@ def test_PDCD_WS(with_dual_init):
7777
penalty = L1(alpha)
7878

7979
w = PDCD_WS(dual_init=dual_init).solve(X, y, datafit, penalty)[0]
80+
<<<<<<< HEAD
8081
clf = SqrtLasso(alpha=alpha, fit_intercept=False, tol=1e-12).fit(X, y)
82+
=======
83+
clf = SqrtLasso(alpha=alpha, tol=1e-12).fit(X, y)
84+
>>>>>>> origin/main
8185
np.testing.assert_allclose(clf.coef_, w, atol=1e-6)
8286

8387

0 commit comments

Comments
 (0)