-
Notifications
You must be signed in to change notification settings - Fork 41
ENH - add support for intercept in SqrtLasso
#214
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
PascalCarrivain
wants to merge
54
commits into
scikit-learn-contrib:main
from
PascalCarrivain:fix_intercept_SqrtLasso
Closed
Changes from all commits
Commits
Show all changes
54 commits
Select commit
Hold shift + click to select a range
aabbf91
fix add support for intercept in SqrtLasso
PascalCarrivain a44c11f
fix line too long (91 > 88 characters)
PascalCarrivain d4e69f2
[CI trigger]
Badr-MOUFAD 1975371
set intercept to false in unittest
Badr-MOUFAD 6f68666
fix inverting the cases (fit_intercept)
PascalCarrivain a549bba
Merge branch 'fix_intercept_SqrtLasso' of https://github.com/PascalCa…
PascalCarrivain cdc21ea
fix undefined name 'X_coeff'
PascalCarrivain c38c00d
add fit_intercept to test_alpha_max and to self.solver_
PascalCarrivain ad47b19
add fit_intercept docstring, factorized duplicate code
PascalCarrivain e90add1
Fix imports in README (#211)
PascalCarrivain 4a7d466
DOC - Add installation instructions for conda-forge (#210)
jjerphan a1507d5
FIX - Upload documentation only when merging to ``main`` (#209)
Badr-MOUFAD c1d6a15
cd to doc before (#212)
Badr-MOUFAD 564fa61
DOC - Hide table of contents in documentation home page (#213)
Badr-MOUFAD a36c05c
Add Code of Conduct (#217)
Badr-MOUFAD 240f11a
MNT - bump to version 0.3.2 (#218)
mathurinm a1662bb
DOC - Fix link to stable documentation (#219)
Badr-MOUFAD 360b41a
Add Group Lasso with positive weigths (#221)
QB3 c1d16a2
Add prox vec L05 (#222)
mathurinm e17ffc6
DOC update contribution section (#224)
mathurinm 3fe0cc6
FIX - computation of ``subdiff_distance`` in ``WeightedGroupL2`` pena…
Badr-MOUFAD fe27eb7
API change use_acc default value to False in GreedyCD (#236)
mathurinm c560dd1
MNT change solver quantile regressor sklearn in test (#235)
mathurinm b738b3f
ENH add GroupLasso estimator with sparse X support (#228)
tomaszkacprzak dbcc207
DOC add ucurve example (#239)
mathurinm 891f75c
FIX objective function in docstring ``GroupLasso`` (#241)
mathurinm e7c25b2
DOC - update documentation (#242)
Badr-MOUFAD 27dec4b
DOC add group logistic regression example (#246)
mathurinm e0c28d7
ENH implement gradient and allow `y_i = 0` in `Poisson` datafit (#253)
mathurinm 840e8b2
ENH gradient, raw grad and hessian for Quadratic (#257)
mathurinm e60ca6a
FIX ProxNewton solver with fixpoint strategy (#259)
mathurinm 553c3d6
FEAT add WeightedQuadratic datafit to allow sample weights (#258)
sujay-pandit cac34ae
DOC L1-regularization parameter tutorial (#264)
wassimmazouz 3260ebd
FEAT implement sparse group lasso penalty and ws_strategy="fixpoint" …
mathurinm 07c49bb
ENH - check ``datafit + penalty`` compatibility with solver (#137)
PABannier ab8b9d7
MNT Update pull_request_template.md (#271)
mathurinm b35ff5b
FIX make PDCD_WS solver usable in GeneralizedLinearEstimator (#274)
mathurinm c6325e9
ENH - Adds support for L1 + L2 regularization in SparseLogisticRegres…
AnavAgrawal 7d274d8
FEAT Add QuadraticHessian datafit (#279)
tvayer 75b92cc
Docstring update for L2 penalty in SparseLogisticRegression (#281)
floriankozikowski 0837365
FIX/MNT install R with conda and use python 3.10 on test workflow (#282)
mathurinm a334d2a
MNT move citation up in readme (#284)
mathurinm 554a93c
REL release 0.4 (#285)
mathurinm 397b842
MNT start dev of 0.5 version (#286)
mathurinm 3692944
MNT add celer test dep (#288)
jolars 8985aaa
MNT add python version requirement (#289)
jolars cb284b0
MNT fix failing slope test (#287)
jolars 185c17f
MNT add tags for sklearn (#293)
mathurinm 7222e00
ENH - jit-compile datafits and penalties inside solver (#270)
Badr-MOUFAD 4629e59
first try at unit test
floriankozikowski e7568bd
first try, add support for fit_intercept in sqrtLasso, TODOS: review …
floriankozikowski e3b9df2
Merge remote-tracking branch 'origin/main' into fix_intercept_SqrtLasso
floriankozikowski 91a5608
fix merge errors
floriankozikowski 791c8cd
fix pytest, should work now
floriankozikowski File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Binary file not shown.
Binary file not shown.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -7,6 +7,7 @@ | |||||||||||
| from skglm.experimental.sqrt_lasso import (SqrtLasso, SqrtQuadratic, | ||||||||||||
| _chambolle_pock_sqrt) | ||||||||||||
| from skglm.experimental.pdcd_ws import PDCD_WS | ||||||||||||
| from skglm import Lasso | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def test_alpha_max(): | ||||||||||||
|
|
@@ -16,7 +17,10 @@ def test_alpha_max(): | |||||||||||
|
|
||||||||||||
| sqrt_lasso = SqrtLasso(alpha=alpha_max).fit(X, y) | ||||||||||||
|
|
||||||||||||
| np.testing.assert_equal(sqrt_lasso.coef_, 0) | ||||||||||||
| if sqrt_lasso.fit_intercept: | ||||||||||||
| np.testing.assert_equal(sqrt_lasso.coef_[:-1], 0) | ||||||||||||
| else: | ||||||||||||
| np.testing.assert_equal(sqrt_lasso.coef_, 0) | ||||||||||||
|
Comment on lines
+20
to
+23
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. WDYT about this refactoring?
Suggested change
|
||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def test_vs_statsmodels(): | ||||||||||||
|
|
@@ -31,7 +35,7 @@ def test_vs_statsmodels(): | |||||||||||
| n_alphas = 3 | ||||||||||||
| alphas = alpha_max * np.geomspace(1, 1e-2, n_alphas+1)[1:] | ||||||||||||
|
|
||||||||||||
| sqrt_lasso = SqrtLasso(tol=1e-9) | ||||||||||||
| sqrt_lasso = SqrtLasso(tol=1e-9, fit_intercept=False) | ||||||||||||
| coefs_skglm = sqrt_lasso.path(X, y, alphas)[1] | ||||||||||||
|
|
||||||||||||
| coefs_statsmodels = np.zeros((len(alphas), n_features)) | ||||||||||||
|
|
@@ -54,7 +58,7 @@ def test_prox_newton_cp(): | |||||||||||
|
|
||||||||||||
| alpha_max = norm(X.T @ y, ord=np.inf) / norm(y) | ||||||||||||
| alpha = alpha_max / 10 | ||||||||||||
| clf = SqrtLasso(alpha=alpha, tol=1e-12).fit(X, y) | ||||||||||||
| clf = SqrtLasso(alpha=alpha, fit_intercept=False, tol=1e-12).fit(X, y) | ||||||||||||
| w, _, _ = _chambolle_pock_sqrt(X, y, alpha, max_iter=1000) | ||||||||||||
| np.testing.assert_allclose(clf.coef_, w) | ||||||||||||
|
|
||||||||||||
|
|
@@ -73,9 +77,56 @@ def test_PDCD_WS(with_dual_init): | |||||||||||
| penalty = L1(alpha) | ||||||||||||
|
|
||||||||||||
| w = PDCD_WS(dual_init=dual_init).solve(X, y, datafit, penalty)[0] | ||||||||||||
| clf = SqrtLasso(alpha=alpha, tol=1e-12).fit(X, y) | ||||||||||||
|
|
||||||||||||
| clf = SqrtLasso(alpha=alpha, fit_intercept=False, tol=1e-12).fit(X, y) | ||||||||||||
|
|
||||||||||||
| np.testing.assert_allclose(clf.coef_, w, atol=1e-6) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def test_sqrt_lasso_with_intercept(): | ||||||||||||
| np.random.seed(0) | ||||||||||||
| X = np.random.randn(10, 20) | ||||||||||||
| y = np.random.randn(10) | ||||||||||||
| y += 1 | ||||||||||||
|
|
||||||||||||
| n = len(y) | ||||||||||||
| alpha_max = norm(X.T @ y, ord=np.inf) / n | ||||||||||||
| alpha = alpha_max / 10 | ||||||||||||
|
|
||||||||||||
| # Fit standard Lasso with intercept | ||||||||||||
| lass = Lasso(alpha=alpha, fit_intercept=True, tol=1e-8).fit(X, y) | ||||||||||||
| w_lass = lass.coef_ | ||||||||||||
| assert norm(w_lass) > 0 | ||||||||||||
|
|
||||||||||||
| scal = n / norm(y - lass.predict(X)) | ||||||||||||
|
|
||||||||||||
| # Fit SqrtLasso with intercept | ||||||||||||
| sqrt = SqrtLasso(alpha=alpha * scal, fit_intercept=True, tol=1e-8).fit(X, y) | ||||||||||||
|
|
||||||||||||
| # Make sure intercept was learned | ||||||||||||
| assert abs(sqrt.intercept_) > 1e-6 | ||||||||||||
|
|
||||||||||||
| y_pred = sqrt.predict(X) | ||||||||||||
| assert y_pred.shape == y.shape | ||||||||||||
|
|
||||||||||||
| # Check that coef_ and intercept_ are handled separately | ||||||||||||
| assert sqrt.coef_.shape == (20,) | ||||||||||||
| assert np.isscalar(sqrt.intercept_) | ||||||||||||
|
|
||||||||||||
| # Confirm prediction matches manual computation | ||||||||||||
| manual_pred = X @ sqrt.coef_ + sqrt.intercept_ | ||||||||||||
| np.testing.assert_allclose(manual_pred, y_pred, rtol=1e-6) | ||||||||||||
|
|
||||||||||||
| np.testing.assert_allclose( | ||||||||||||
| sqrt.intercept_, y.mean() - X.mean(axis=0) @ sqrt.coef_, rtol=1e-6 | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| sqrt_no_intercept = SqrtLasso( | ||||||||||||
| alpha=alpha * scal, fit_intercept=False, tol=1e-8).fit(X, y) | ||||||||||||
| assert np.isscalar(sqrt_no_intercept.intercept_) | ||||||||||||
| np.testing.assert_allclose(sqrt_no_intercept.predict( | ||||||||||||
| X), X @ sqrt_no_intercept.coef_ + sqrt_no_intercept.intercept_) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| if __name__ == '__main__': | ||||||||||||
| pass | ||||||||||||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@floriankozikowski remove these 3 files