@@ -53,6 +53,9 @@ pub struct LassoParameters {
5353 #[ cfg_attr( feature = "serde" , serde( default ) ) ]
5454 /// The maximum number of iterations
5555 pub max_iter : usize ,
56+ #[ cfg_attr( feature = "serde" , serde( default ) ) ]
57+ /// If false, force the intercept parameter (beta_0) to be zero.
58+ pub fit_intercept : bool ,
5659}
5760
5861#[ cfg_attr( feature = "serde" , derive( Serialize , Deserialize ) ) ]
@@ -86,6 +89,12 @@ impl LassoParameters {
8689 self . max_iter = max_iter;
8790 self
8891 }
92+
93+ /// Sets `fit_intercept` and returns the updated parameters. If false, force the intercept parameter (beta_0) to be zero.
94+ pub fn with_fit_intercept ( mut self , fit_intercept : bool ) -> Self {
95+ self . fit_intercept = fit_intercept;
96+ self
97+ }
8998}
9099
91100impl Default for LassoParameters {
@@ -95,6 +104,7 @@ impl Default for LassoParameters {
95104 normalize : true ,
96105 tol : 1e-4 ,
97106 max_iter : 1000 ,
107+ fit_intercept : true ,
98108 }
99109 }
100110}
@@ -118,8 +128,8 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
118128{
119129 fn new ( ) -> Self {
120130 Self {
121- coefficients : Option :: None ,
122- intercept : Option :: None ,
131+ coefficients : None ,
132+ intercept : None ,
123133 _phantom_ty : PhantomData ,
124134 _phantom_y : PhantomData ,
125135 }
@@ -155,6 +165,9 @@ pub struct LassoSearchParameters {
155165 #[ cfg_attr( feature = "serde" , serde( default ) ) ]
156166 /// The maximum number of iterations
157167 pub max_iter : Vec < usize > ,
168+ #[ cfg_attr( feature = "serde" , serde( default ) ) ]
169+ /// Candidate values for `fit_intercept`: if false, force the intercept parameter (beta_0) to be zero.
170+ pub fit_intercept : Vec < bool > ,
158171}
159172
160173/// Lasso grid search iterator
@@ -164,6 +177,7 @@ pub struct LassoSearchParametersIterator {
164177 current_normalize : usize ,
165178 current_tol : usize ,
166179 current_max_iter : usize ,
180+ current_fit_intercept : usize ,
167181}
168182
169183impl IntoIterator for LassoSearchParameters {
@@ -177,6 +191,7 @@ impl IntoIterator for LassoSearchParameters {
177191 current_normalize : 0 ,
178192 current_tol : 0 ,
179193 current_max_iter : 0 ,
194+ current_fit_intercept : 0 ,
180195 }
181196 }
182197}
@@ -189,6 +204,7 @@ impl Iterator for LassoSearchParametersIterator {
189204 && self . current_normalize == self . lasso_search_parameters . normalize . len ( )
190205 && self . current_tol == self . lasso_search_parameters . tol . len ( )
191206 && self . current_max_iter == self . lasso_search_parameters . max_iter . len ( )
207+ && self . current_fit_intercept == self . lasso_search_parameters . fit_intercept . len ( )
192208 {
193209 return None ;
194210 }
@@ -198,6 +214,7 @@ impl Iterator for LassoSearchParametersIterator {
198214 normalize : self . lasso_search_parameters . normalize [ self . current_normalize ] ,
199215 tol : self . lasso_search_parameters . tol [ self . current_tol ] ,
200216 max_iter : self . lasso_search_parameters . max_iter [ self . current_max_iter ] ,
217+ fit_intercept : self . lasso_search_parameters . fit_intercept [ self . current_fit_intercept ] ,
201218 } ;
202219
203220 if self . current_alpha + 1 < self . lasso_search_parameters . alpha . len ( ) {
@@ -214,11 +231,19 @@ impl Iterator for LassoSearchParametersIterator {
214231 self . current_normalize = 0 ;
215232 self . current_tol = 0 ;
216233 self . current_max_iter += 1 ;
234+ } else if self . current_fit_intercept + 1 < self . lasso_search_parameters . fit_intercept . len ( )
235+ {
236+ self . current_alpha = 0 ;
237+ self . current_normalize = 0 ;
238+ self . current_tol = 0 ;
239+ self . current_max_iter = 0 ;
240+ self . current_fit_intercept += 1 ;
217241 } else {
218242 self . current_alpha += 1 ;
219243 self . current_normalize += 1 ;
220244 self . current_tol += 1 ;
221245 self . current_max_iter += 1 ;
246+ self . current_fit_intercept += 1 ;
222247 }
223248
224249 Some ( next)
@@ -234,6 +259,7 @@ impl Default for LassoSearchParameters {
234259 normalize : vec ! [ default_params. normalize] ,
235260 tol : vec ! [ default_params. tol] ,
236261 max_iter : vec ! [ default_params. max_iter] ,
262+ fit_intercept : vec ! [ default_params. fit_intercept] ,
237263 }
238264 }
239265}
@@ -283,19 +309,23 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
283309 l1_reg,
284310 parameters. max_iter ,
285311 TX :: from_f64 ( parameters. tol ) . unwrap ( ) ,
312+ parameters. fit_intercept ,
286313 ) ?;
287314
288315 for ( j, col_std_j) in col_std. iter ( ) . enumerate ( ) . take ( p) {
289316 w[ j] /= * col_std_j;
290317 }
291318
292- let mut b = TX :: zero ( ) ;
319+ let b = if parameters. fit_intercept {
320+ let mut xw_mean = TX :: zero ( ) ;
321+ for ( i, col_mean_i) in col_mean. iter ( ) . enumerate ( ) . take ( p) {
322+ xw_mean += w[ i] * * col_mean_i;
323+ }
293324
294- for ( i, col_mean_i) in col_mean. iter ( ) . enumerate ( ) . take ( p) {
295- b += w[ i] * * col_mean_i;
296- }
297-
298- b = TX :: from_f64 ( y. mean_by ( ) ) . unwrap ( ) - b;
325+ Some ( TX :: from_f64 ( y. mean_by ( ) ) . unwrap ( ) - xw_mean)
326+ } else {
327+ None
328+ } ;
299329 ( X :: from_column ( & w) , b)
300330 } else {
301331 let mut optimizer = InteriorPointOptimizer :: new ( x, p) ;
@@ -306,13 +336,21 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
306336 l1_reg,
307337 parameters. max_iter ,
308338 TX :: from_f64 ( parameters. tol ) . unwrap ( ) ,
339+ parameters. fit_intercept ,
309340 ) ?;
310341
311- ( X :: from_column ( & w) , TX :: from_f64 ( y. mean_by ( ) ) . unwrap ( ) )
342+ (
343+ X :: from_column ( & w) ,
344+ if parameters. fit_intercept {
345+ Some ( TX :: from_f64 ( y. mean_by ( ) ) . unwrap ( ) )
346+ } else {
347+ None
348+ } ,
349+ )
312350 } ;
313351
314352 Ok ( Lasso {
315- intercept : Some ( b ) ,
353+ intercept : b ,
316354 coefficients : Some ( w) ,
317355 _phantom_ty : PhantomData ,
318356 _phantom_y : PhantomData ,
@@ -378,30 +416,28 @@ mod tests {
378416 let parameters = LassoSearchParameters {
379417 alpha : vec ! [ 0. , 1. ] ,
380418 max_iter : vec ! [ 10 , 100 ] ,
419+ fit_intercept : vec ! [ false , true ] ,
381420 ..Default :: default ( )
382421 } ;
383- let mut iter = parameters. into_iter ( ) ;
384- let next = iter. next ( ) . unwrap ( ) ;
385- assert_eq ! ( next. alpha, 0. ) ;
386- assert_eq ! ( next. max_iter, 10 ) ;
387- let next = iter. next ( ) . unwrap ( ) ;
388- assert_eq ! ( next. alpha, 1. ) ;
389- assert_eq ! ( next. max_iter, 10 ) ;
390- let next = iter. next ( ) . unwrap ( ) ;
391- assert_eq ! ( next. alpha, 0. ) ;
392- assert_eq ! ( next. max_iter, 100 ) ;
393- let next = iter. next ( ) . unwrap ( ) ;
394- assert_eq ! ( next. alpha, 1. ) ;
395- assert_eq ! ( next. max_iter, 100 ) ;
422+
423+ let mut iter = parameters. clone ( ) . into_iter ( ) ;
424+ for current_fit_intercept in 0 ..parameters. fit_intercept . len ( ) {
425+ for current_max_iter in 0 ..parameters. max_iter . len ( ) {
426+ for current_alpha in 0 ..parameters. alpha . len ( ) {
427+ let next = iter. next ( ) . unwrap ( ) ;
428+ assert_eq ! ( next. alpha, parameters. alpha[ current_alpha] ) ;
429+ assert_eq ! ( next. max_iter, parameters. max_iter[ current_max_iter] ) ;
430+ assert_eq ! (
431+ next. fit_intercept,
432+ parameters. fit_intercept[ current_fit_intercept]
433+ ) ;
434+ }
435+ }
436+ }
396437 assert ! ( iter. next( ) . is_none( ) ) ;
397438 }
398439
399- #[ cfg_attr(
400- all( target_arch = "wasm32" , not( target_os = "wasi" ) ) ,
401- wasm_bindgen_test:: wasm_bindgen_test
402- ) ]
403- #[ test]
404- fn lasso_fit_predict ( ) {
440+ fn get_example_x_y ( ) -> ( DenseMatrix < f64 > , Vec < f64 > ) {
405441 let x = DenseMatrix :: from_2d_array ( & [
406442 & [ 234.289 , 235.6 , 159.0 , 107.608 , 1947. , 60.323 ] ,
407443 & [ 259.426 , 232.5 , 145.6 , 108.632 , 1948. , 61.122 ] ,
@@ -427,6 +463,17 @@ mod tests {
427463 114.2 , 115.7 , 116.9 ,
428464 ] ;
429465
466+ ( x, y)
467+ }
468+
469+ #[ cfg_attr(
470+ all( target_arch = "wasm32" , not( target_os = "wasi" ) ) ,
471+ wasm_bindgen_test:: wasm_bindgen_test
472+ ) ]
473+ #[ test]
474+ fn lasso_fit_predict ( ) {
475+ let ( x, y) = get_example_x_y ( ) ;
476+
430477 let y_hat = Lasso :: fit ( & x, & y, Default :: default ( ) )
431478 . and_then ( |lr| lr. predict ( & x) )
432479 . unwrap ( ) ;
@@ -441,6 +488,7 @@ mod tests {
441488 normalize : false ,
442489 tol : 1e-4 ,
443490 max_iter : 1000 ,
491+ fit_intercept : true ,
444492 } ,
445493 )
446494 . and_then ( |lr| lr. predict ( & x) )
@@ -479,35 +527,46 @@ mod tests {
479527 assert ! ( mean_absolute_error( & w, & expected_w) < 1e-3 ) ; // actual mean_absolute_error is about 2e-4
480528 }
481529
530+ #[ cfg_attr(
531+ all( target_arch = "wasm32" , not( target_os = "wasi" ) ) ,
532+ wasm_bindgen_test:: wasm_bindgen_test
533+ ) ]
534+ #[ test]
535+ fn test_fit_intercept ( ) {
536+ let ( x, y) = get_example_x_y ( ) ;
537+ let fit_result = Lasso :: fit (
538+ & x,
539+ & y,
540+ LassoParameters {
541+ alpha : 0.1 ,
542+ normalize : false ,
543+ tol : 1e-8 ,
544+ max_iter : 1000 ,
545+ fit_intercept : false ,
546+ } ,
547+ )
548+ . unwrap ( ) ;
549+
550+ let w = fit_result. coefficients ( ) . iterator ( 0 ) . copied ( ) . collect ( ) ;
551+ // Expected coefficients were computed with sklearn's LassoLars, because coordinate descent doesn't converge well.
552+ let expected_w = vec ! [
553+ 0.18335684 ,
554+ 0.02106526 ,
555+ 0.00703214 ,
556+ -1.35952542 ,
557+ 0.09295222 ,
558+ 0. ,
559+ ] ;
560+ assert ! ( mean_absolute_error( & w, & expected_w) < 1e-6 ) ;
561+ assert_eq ! ( fit_result. intercept, None ) ;
562+ }
563+
482564 // TODO: serialization for the new DenseMatrix needs to be implemented
483565 // #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
484566 // #[test]
485567 // #[cfg(feature = "serde")]
486568 // fn serde() {
487- // let x = DenseMatrix::from_2d_array(&[
488- // &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
489- // &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
490- // &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
491- // &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
492- // &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
493- // &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
494- // &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
495- // &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
496- // &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
497- // &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
498- // &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
499- // &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
500- // &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
501- // &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
502- // &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
503- // &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
504- // ]);
505-
506- // let y = vec![
507- // 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
508- // 114.2, 115.7, 116.9,
509- // ];
510-
569+ // let (x, y) = get_example_x_y();
511570 // let lr = Lasso::fit(&x, &y, Default::default()).unwrap();
512571
513572 // let deserialized_lr: Lasso<f64, f64, DenseMatrix<f64>, Vec<f64>> =
0 commit comments