@@ -100,10 +100,13 @@ class SqrtLasso(LinearModel, RegressorMixin):
100100
101101 verbose : bool, default False
102102 Amount of verbosity. 0/False is silent.
103+
104+ fit_intercept: bool, optional (default=True)
105+ Whether or not to fit an intercept.
103106 """
104107
105108 def __init__ (self , alpha = 1. , max_iter = 100 , max_pn_iter = 100 , p0 = 10 ,
106- tol = 1e-4 , verbose = 0 ):
109+ tol = 1e-4 , verbose = 0 , fit_intercept = True ):
107110 super ().__init__ ()
108111 self .alpha = alpha
109112 self .max_iter = max_iter
@@ -112,6 +115,7 @@ def __init__(self, alpha=1., max_iter=100, max_pn_iter=100, p0=10,
112115 self .p0 = p0
113116 self .tol = tol
114117 self .verbose = verbose
118+ self .fit_intercept = fit_intercept
115119
116120 def fit (self , X , y ):
117121 """Fit the model according to the given training data.
@@ -131,7 +135,11 @@ def fit(self, X, y):
131135 Fitted estimator.
132136 """
133137 self .coef_ = self .path (X , y , alphas = [self .alpha ])[1 ][0 ]
134- self .intercept_ = 0. # TODO handle fit_intercept
138+ if self .fit_intercept :
139+ self .intercept_ = self .coef_ [- 1 ]
140+ self .coef_ = self .coef_ [:- 1 ]
141+ else :
142+ self .intercept_ = 0.
135143 return self
136144
137145 def path (self , X , y , alphas = None , eps = 1e-3 , n_alphas = 10 ):
@@ -168,7 +176,7 @@ def path(self, X, y, alphas=None, eps=1e-3, n_alphas=10):
168176 if not hasattr (self , "solver_" ):
169177 self .solver_ = ProxNewton (
170178 tol = self .tol , max_iter = self .max_iter , verbose = self .verbose ,
171- fit_intercept = False )
179+ fit_intercept = self . fit_intercept )
172180 # build path
173181 if alphas is None :
174182 alpha_max = norm (X .T @ y , ord = np .inf ) / (np .sqrt (len (y )) * norm (y ))
@@ -181,7 +189,7 @@ def path(self, X, y, alphas=None, eps=1e-3, n_alphas=10):
181189 sqrt_quadratic = SqrtQuadratic ()
182190 l1_penalty = L1 (1. ) # alpha is set along the path
183191
184- coefs = np .zeros ((n_alphas , n_features ))
192+ coefs = np .zeros ((n_alphas , n_features + self . fit_intercept ))
185193
186194 for i in range (n_alphas ):
187195 if self .verbose :
@@ -192,12 +200,14 @@ def path(self, X, y, alphas=None, eps=1e-3, n_alphas=10):
192200
193201 l1_penalty .alpha = alphas [i ]
194202 # no warm start for the first alpha
195- coef_init = coefs [i ].copy () if i else np .zeros (n_features )
203+ coef_init = coefs [i ].copy () if i else np .zeros (n_features
204+ + self .fit_intercept )
196205
197206 try :
198207 coef , _ , _ = self .solver_ .solve (
199208 X , y , sqrt_quadratic , l1_penalty ,
200- w_init = coef_init , Xw_init = X @ coef_init )
209+ w_init = coef_init , Xw_init = X @ coef_init [:- 1 ] + coef_init [- 1 ]
210+ if self .fit_intercept else X @ coef_init )
201211 coefs [i ] = coef
202212 except ValueError as val_exception :
203213 # make sure to catch residual error
@@ -208,7 +218,8 @@ def path(self, X, y, alphas=None, eps=1e-3, n_alphas=10):
208218 # save coef despite not converging
209219 # coef_init holds a ref to coef
210220 coef = coef_init
211- res_norm = norm (y - X @ coef )
221+ X_coef = X @ coef [:- 1 ] + coef [- 1 ] if self .fit_intercept else X @ coef
222+ res_norm = norm (y - X_coef )
212223 warnings .warn (
213224 f"Small residuals prevented the solver from converging "
214225 f"at alpha={ alphas [i ]:.2e} (residuals' norm: { res_norm :.4e} ). "
0 commit comments