Skip to content

Commit 18de2aa

Browse files
authored
add fit_intercept to LASSO (#344)
* add fit_intercept to LASSO * lasso: intercept=None if fit_intercept is false * update CHANGELOG.md to reflect lasso changes * lasso: minor
1 parent 2bf5f7a commit 18de2aa

File tree

4 files changed

+125
-54
lines changed

4 files changed

+125
-54
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
44
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
55
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
66

7+
## [Unreleased]
8+
- WARNING: Breaking changes!
9+
- `LassoParameters` and `LassoSearchParameters` have a new field `fit_intercept`. When it is set to false, the `beta_0` term in the formula will be forced to zero, and `intercept` field in `Lasso` will be set to `None`.
10+
11+
712
## [0.4.0] - 2023-04-05
813

914
## Added

src/linear/elastic_net.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
345345
l1_reg * gamma,
346346
parameters.max_iter,
347347
TX::from_f64(parameters.tol).unwrap(),
348+
true,
348349
)?;
349350

350351
for i in 0..p {
@@ -371,6 +372,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
371372
l1_reg * gamma,
372373
parameters.max_iter,
373374
TX::from_f64(parameters.tol).unwrap(),
375+
true,
374376
)?;
375377

376378
for i in 0..p {

src/linear/lasso.rs

Lines changed: 112 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ pub struct LassoParameters {
5353
#[cfg_attr(feature = "serde", serde(default))]
5454
/// The maximum number of iterations
5555
pub max_iter: usize,
56+
#[cfg_attr(feature = "serde", serde(default))]
57+
/// If false, force the intercept parameter (beta_0) to be zero.
58+
pub fit_intercept: bool,
5659
}
5760

5861
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@@ -86,6 +89,12 @@ impl LassoParameters {
8689
self.max_iter = max_iter;
8790
self
8891
}
92+
93+
/// If false, force the intercept parameter (beta_0) to be zero.
94+
pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
95+
self.fit_intercept = fit_intercept;
96+
self
97+
}
8998
}
9099

91100
impl Default for LassoParameters {
@@ -95,6 +104,7 @@ impl Default for LassoParameters {
95104
normalize: true,
96105
tol: 1e-4,
97106
max_iter: 1000,
107+
fit_intercept: true,
98108
}
99109
}
100110
}
@@ -118,8 +128,8 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
118128
{
119129
fn new() -> Self {
120130
Self {
121-
coefficients: Option::None,
122-
intercept: Option::None,
131+
coefficients: None,
132+
intercept: None,
123133
_phantom_ty: PhantomData,
124134
_phantom_y: PhantomData,
125135
}
@@ -155,6 +165,9 @@ pub struct LassoSearchParameters {
155165
#[cfg_attr(feature = "serde", serde(default))]
156166
/// The maximum number of iterations
157167
pub max_iter: Vec<usize>,
168+
#[cfg_attr(feature = "serde", serde(default))]
169+
/// The maximum number of iterations
170+
pub fit_intercept: Vec<bool>,
158171
}
159172

160173
/// Lasso grid search iterator
@@ -164,6 +177,7 @@ pub struct LassoSearchParametersIterator {
164177
current_normalize: usize,
165178
current_tol: usize,
166179
current_max_iter: usize,
180+
current_fit_intercept: usize,
167181
}
168182

169183
impl IntoIterator for LassoSearchParameters {
@@ -177,6 +191,7 @@ impl IntoIterator for LassoSearchParameters {
177191
current_normalize: 0,
178192
current_tol: 0,
179193
current_max_iter: 0,
194+
current_fit_intercept: 0,
180195
}
181196
}
182197
}
@@ -189,6 +204,7 @@ impl Iterator for LassoSearchParametersIterator {
189204
&& self.current_normalize == self.lasso_search_parameters.normalize.len()
190205
&& self.current_tol == self.lasso_search_parameters.tol.len()
191206
&& self.current_max_iter == self.lasso_search_parameters.max_iter.len()
207+
&& self.current_fit_intercept == self.lasso_search_parameters.fit_intercept.len()
192208
{
193209
return None;
194210
}
@@ -198,6 +214,7 @@ impl Iterator for LassoSearchParametersIterator {
198214
normalize: self.lasso_search_parameters.normalize[self.current_normalize],
199215
tol: self.lasso_search_parameters.tol[self.current_tol],
200216
max_iter: self.lasso_search_parameters.max_iter[self.current_max_iter],
217+
fit_intercept: self.lasso_search_parameters.fit_intercept[self.current_fit_intercept],
201218
};
202219

203220
if self.current_alpha + 1 < self.lasso_search_parameters.alpha.len() {
@@ -214,11 +231,19 @@ impl Iterator for LassoSearchParametersIterator {
214231
self.current_normalize = 0;
215232
self.current_tol = 0;
216233
self.current_max_iter += 1;
234+
} else if self.current_fit_intercept + 1 < self.lasso_search_parameters.fit_intercept.len()
235+
{
236+
self.current_alpha = 0;
237+
self.current_normalize = 0;
238+
self.current_tol = 0;
239+
self.current_max_iter = 0;
240+
self.current_fit_intercept += 1;
217241
} else {
218242
self.current_alpha += 1;
219243
self.current_normalize += 1;
220244
self.current_tol += 1;
221245
self.current_max_iter += 1;
246+
self.current_fit_intercept += 1;
222247
}
223248

224249
Some(next)
@@ -234,6 +259,7 @@ impl Default for LassoSearchParameters {
234259
normalize: vec![default_params.normalize],
235260
tol: vec![default_params.tol],
236261
max_iter: vec![default_params.max_iter],
262+
fit_intercept: vec![default_params.fit_intercept],
237263
}
238264
}
239265
}
@@ -283,19 +309,23 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
283309
l1_reg,
284310
parameters.max_iter,
285311
TX::from_f64(parameters.tol).unwrap(),
312+
parameters.fit_intercept,
286313
)?;
287314

288315
for (j, col_std_j) in col_std.iter().enumerate().take(p) {
289316
w[j] /= *col_std_j;
290317
}
291318

292-
let mut b = TX::zero();
319+
let b = if parameters.fit_intercept {
320+
let mut xw_mean = TX::zero();
321+
for (i, col_mean_i) in col_mean.iter().enumerate().take(p) {
322+
xw_mean += w[i] * *col_mean_i;
323+
}
293324

294-
for (i, col_mean_i) in col_mean.iter().enumerate().take(p) {
295-
b += w[i] * *col_mean_i;
296-
}
297-
298-
b = TX::from_f64(y.mean_by()).unwrap() - b;
325+
Some(TX::from_f64(y.mean_by()).unwrap() - xw_mean)
326+
} else {
327+
None
328+
};
299329
(X::from_column(&w), b)
300330
} else {
301331
let mut optimizer = InteriorPointOptimizer::new(x, p);
@@ -306,13 +336,21 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
306336
l1_reg,
307337
parameters.max_iter,
308338
TX::from_f64(parameters.tol).unwrap(),
339+
parameters.fit_intercept,
309340
)?;
310341

311-
(X::from_column(&w), TX::from_f64(y.mean_by()).unwrap())
342+
(
343+
X::from_column(&w),
344+
if parameters.fit_intercept {
345+
Some(TX::from_f64(y.mean_by()).unwrap())
346+
} else {
347+
None
348+
},
349+
)
312350
};
313351

314352
Ok(Lasso {
315-
intercept: Some(b),
353+
intercept: b,
316354
coefficients: Some(w),
317355
_phantom_ty: PhantomData,
318356
_phantom_y: PhantomData,
@@ -378,30 +416,28 @@ mod tests {
378416
let parameters = LassoSearchParameters {
379417
alpha: vec![0., 1.],
380418
max_iter: vec![10, 100],
419+
fit_intercept: vec![false, true],
381420
..Default::default()
382421
};
383-
let mut iter = parameters.into_iter();
384-
let next = iter.next().unwrap();
385-
assert_eq!(next.alpha, 0.);
386-
assert_eq!(next.max_iter, 10);
387-
let next = iter.next().unwrap();
388-
assert_eq!(next.alpha, 1.);
389-
assert_eq!(next.max_iter, 10);
390-
let next = iter.next().unwrap();
391-
assert_eq!(next.alpha, 0.);
392-
assert_eq!(next.max_iter, 100);
393-
let next = iter.next().unwrap();
394-
assert_eq!(next.alpha, 1.);
395-
assert_eq!(next.max_iter, 100);
422+
423+
let mut iter = parameters.clone().into_iter();
424+
for current_fit_intercept in 0..parameters.fit_intercept.len() {
425+
for current_max_iter in 0..parameters.max_iter.len() {
426+
for current_alpha in 0..parameters.alpha.len() {
427+
let next = iter.next().unwrap();
428+
assert_eq!(next.alpha, parameters.alpha[current_alpha]);
429+
assert_eq!(next.max_iter, parameters.max_iter[current_max_iter]);
430+
assert_eq!(
431+
next.fit_intercept,
432+
parameters.fit_intercept[current_fit_intercept]
433+
);
434+
}
435+
}
436+
}
396437
assert!(iter.next().is_none());
397438
}
398439

399-
#[cfg_attr(
400-
all(target_arch = "wasm32", not(target_os = "wasi")),
401-
wasm_bindgen_test::wasm_bindgen_test
402-
)]
403-
#[test]
404-
fn lasso_fit_predict() {
440+
fn get_example_x_y() -> (DenseMatrix<f64>, Vec<f64>) {
405441
let x = DenseMatrix::from_2d_array(&[
406442
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
407443
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
@@ -427,6 +463,17 @@ mod tests {
427463
114.2, 115.7, 116.9,
428464
];
429465

466+
(x, y)
467+
}
468+
469+
#[cfg_attr(
470+
all(target_arch = "wasm32", not(target_os = "wasi")),
471+
wasm_bindgen_test::wasm_bindgen_test
472+
)]
473+
#[test]
474+
fn lasso_fit_predict() {
475+
let (x, y) = get_example_x_y();
476+
430477
let y_hat = Lasso::fit(&x, &y, Default::default())
431478
.and_then(|lr| lr.predict(&x))
432479
.unwrap();
@@ -441,6 +488,7 @@ mod tests {
441488
normalize: false,
442489
tol: 1e-4,
443490
max_iter: 1000,
491+
fit_intercept: true,
444492
},
445493
)
446494
.and_then(|lr| lr.predict(&x))
@@ -479,35 +527,46 @@ mod tests {
479527
assert!(mean_absolute_error(&w, &expected_w) < 1e-3); // actual mean_absolute_error is about 2e-4
480528
}
481529

530+
#[cfg_attr(
531+
all(target_arch = "wasm32", not(target_os = "wasi")),
532+
wasm_bindgen_test::wasm_bindgen_test
533+
)]
534+
#[test]
535+
fn test_fit_intercept() {
536+
let (x, y) = get_example_x_y();
537+
let fit_result = Lasso::fit(
538+
&x,
539+
&y,
540+
LassoParameters {
541+
alpha: 0.1,
542+
normalize: false,
543+
tol: 1e-8,
544+
max_iter: 1000,
545+
fit_intercept: false,
546+
},
547+
)
548+
.unwrap();
549+
550+
let w = fit_result.coefficients().iterator(0).copied().collect();
551+
// by sklearn LassoLars. coordinate descent doesn't converge well
552+
let expected_w = vec![
553+
0.18335684,
554+
0.02106526,
555+
0.00703214,
556+
-1.35952542,
557+
0.09295222,
558+
0.,
559+
];
560+
assert!(mean_absolute_error(&w, &expected_w) < 1e-6);
561+
assert_eq!(fit_result.intercept, None);
562+
}
563+
482564
// TODO: serialization for the new DenseMatrix needs to be implemented
483565
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
484566
// #[test]
485567
// #[cfg(feature = "serde")]
486568
// fn serde() {
487-
// let x = DenseMatrix::from_2d_array(&[
488-
// &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
489-
// &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
490-
// &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
491-
// &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
492-
// &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
493-
// &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
494-
// &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
495-
// &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
496-
// &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
497-
// &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
498-
// &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
499-
// &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
500-
// &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
501-
// &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
502-
// &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
503-
// &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
504-
// ]);
505-
506-
// let y = vec![
507-
// 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
508-
// 114.2, 115.7, 116.9,
509-
// ];
510-
569+
// let (x, y) = get_lasso_sample_x_y();
511570
// let lr = Lasso::fit(&x, &y, Default::default()).unwrap();
512571

513572
// let deserialized_lr: Lasso<f64, f64, DenseMatrix<f64>, Vec<f64>> =

src/linear/lasso_optimizer.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
4545
lambda: T,
4646
max_iter: usize,
4747
tol: T,
48+
fit_intercept: bool,
4849
) -> Result<Vec<T>, Failed> {
4950
let (n, p) = x.shape();
5051
let p_f64 = T::from_usize(p).unwrap();
@@ -61,7 +62,11 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
6162
let mu = T::two();
6263

6364
// let y = M::from_row_vector(y.sub_scalar(y.mean_by())).transpose();
64-
let y = y.sub_scalar(T::from_f64(y.mean_by()).unwrap());
65+
let y = if fit_intercept {
66+
y.sub_scalar(T::from_f64(y.mean_by()).unwrap())
67+
} else {
68+
y.to_owned()
69+
};
6570

6671
let mut max_ls_iter = 100;
6772
let mut pitr = 0;

0 commit comments

Comments
 (0)