Skip to content

Commit 200b9d4

Browse files
committed
Take transpose
1 parent 8ff555d commit 200b9d4

File tree

3 files changed

+49
-20
lines changed

3 files changed

+49
-20
lines changed

lax/src/least_squares.rs

Lines changed: 47 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//! Least squares
22
3-
use crate::{error::*, layout::MatrixLayout};
3+
use crate::{error::*, layout::*};
44
use cauchy::*;
55
use num_traits::{ToPrimitive, Zero};
66

@@ -46,16 +46,37 @@ macro_rules! impl_least_squares_real {
4646
b_layout: MatrixLayout,
4747
b: &mut [Self],
4848
) -> Result<LeastSquaresOutput<Self>> {
49-
let m = a_layout.lda();
50-
let n = a_layout.len();
49+
// Minimize |b - Ax|_2
50+
//
51+
// where
52+
// A : (m, n)
53+
// b : (m, p)
54+
// x : (n, p)
55+
let (m, n) = a_layout.size();
56+
let (m_, p) = b_layout.size();
5157
let k = m.min(n);
52-
if (m as usize) > b.len()
53-
|| (n as usize) > b.len()
54-
|| a_layout.lapacke_layout() != b_layout.lapacke_layout()
55-
{
56-
return Err(Error::InvalidShape);
57-
}
58-
let (b_lda, nrhs) = b_layout.size();
58+
assert_eq!(m, m_);
59+
60+
// Transpose if a is C-continuous
61+
let mut a_t = None;
62+
let a_layout = match a_layout {
63+
MatrixLayout::C { .. } => {
64+
a_t = Some(vec![Self::zero(); a.len()]);
65+
transpose(a_layout, a, a_t.as_mut().unwrap())
66+
}
67+
MatrixLayout::F { .. } => a_layout,
68+
};
69+
70+
// Transpose if b is C-continuous
71+
let mut b_t = None;
72+
let b_layout = match b_layout {
73+
MatrixLayout::C { .. } => {
74+
b_t = Some(vec![Self::zero(); b.len()]);
75+
transpose(b_layout, b, b_t.as_mut().unwrap())
76+
}
77+
MatrixLayout::F { .. } => b_layout,
78+
};
79+
5980
let rcond: Self::Real = -1.;
6081
let mut singular_values: Vec<Self::Real> = vec![Self::Real::zero(); k as usize];
6182
let mut rank: i32 = 0;
@@ -67,11 +88,11 @@ macro_rules! impl_least_squares_real {
6788
$gelsd(
6889
m,
6990
n,
70-
nrhs,
71-
a,
72-
m,
73-
b,
74-
b_lda,
91+
p,
92+
a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a),
93+
a_layout.lda(),
94+
b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b),
95+
b_layout.lda(),
7596
&mut singular_values,
7697
rcond,
7798
&mut rank,
@@ -90,11 +111,11 @@ macro_rules! impl_least_squares_real {
90111
$gelsd(
91112
m,
92113
n,
93-
nrhs,
94-
a,
95-
m,
96-
b,
97-
b_lda,
114+
p,
115+
a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a),
116+
a_layout.lda(),
117+
b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b),
118+
b_layout.lda(),
98119
&mut singular_values,
99120
rcond,
100121
&mut rank,
@@ -105,6 +126,12 @@ macro_rules! impl_least_squares_real {
105126
);
106127
info.as_lapack_result()?;
107128

129+
// Skip a_t -> a transpose because A has been destroyed
130+
// Re-transpose b
131+
if let Some(b_t) = b_t {
132+
transpose(b_layout, &b_t, b);
133+
}
134+
108135
Ok(LeastSquaresOutput {
109136
singular_values,
110137
rank,

ndarray-linalg/src/least_squares.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ use crate::types::*;
7676
/// is a `m x 1` column vector. If `I` is `Ix2`, the RHS is a `n x k` matrix
7777
/// (which can be seen as solving `Ax = b` k times for different b) and
7878
/// the solution is a `m x k` matrix.
79+
#[derive(Debug, Clone)]
7980
pub struct LeastSquaresResult<E: Scalar, I: Dimension> {
8081
/// The singular values of the matrix A in `Ax = b`
8182
pub singular_values: Array1<E::Real>,

ndarray-linalg/tests/least_squares_nrhs.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ fn test_exact<T: Scalar + Lapack>(a: Array2<T>, b: Array2<T>) {
99
assert_eq!(b.layout().unwrap().size(), (3, 2));
1010

1111
let result = a.least_squares(&b).unwrap();
12+
dbg!(&result);
1213
// unpack result
1314
let x: Array2<T> = result.solution;
1415
let residual_l2_square: Array1<T::Real> = result.residual_sum_of_squares.unwrap();

0 commit comments

Comments
 (0)