Skip to content

Commit 24888ec

Browse files
authored
Merge pull request #220 from rust-ndarray/lapack-least-square
least square by LAPACK
2 parents ffb6bef + 962497e commit 24888ec

File tree

5 files changed

+256
-144
lines changed

5 files changed

+256
-144
lines changed

lax/src/layout.rs

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,37 @@ impl MatrixLayout {
9797
MatrixLayout::F { col, lda } => MatrixLayout::C { row: lda, lda: col },
9898
}
9999
}
100+
101+
/// Transpose without changing memory representation
102+
///
103+
/// C-contigious row=2, lda=3
104+
///
105+
/// ```text
106+
/// [[1, 2, 3]
107+
/// [4, 5, 6]]
108+
/// ```
109+
///
110+
/// and F-contigious col=2, lda=3
111+
///
112+
/// ```text
113+
/// [[1, 4]
114+
/// [2, 5]
115+
/// [3, 6]]
116+
/// ```
117+
///
118+
/// have same memory representation `[1, 2, 3, 4, 5, 6]`, and this toggles them.
119+
///
120+
/// ```
121+
/// # use lax::layout::*;
122+
/// let layout = MatrixLayout::C { row: 2, lda: 3 };
123+
/// assert_eq!(layout.t(), MatrixLayout::F { col: 2, lda: 3 });
124+
/// ```
125+
pub fn t(&self) -> Self {
126+
match *self {
127+
MatrixLayout::C { row, lda } => MatrixLayout::F { col: row, lda },
128+
MatrixLayout::F { col, lda } => MatrixLayout::C { row: col, lda },
129+
}
130+
}
100131
}
101132

102133
/// In-place transpose of a square matrix by keeping F/C layout
@@ -139,3 +170,59 @@ pub fn square_transpose<T: Scalar>(layout: MatrixLayout, a: &mut [T]) {
139170
}
140171
}
141172
}
173+
174+
/// Out-place transpose for general matrix
175+
///
176+
/// Inplace transpose of non-square matrices is hard.
177+
/// See also: https://en.wikipedia.org/wiki/In-place_matrix_transposition
178+
///
179+
/// ```rust
180+
/// # use lax::layout::*;
181+
/// let layout = MatrixLayout::C { row: 2, lda: 3 };
182+
/// let a = vec![1., 2., 3., 4., 5., 6.];
183+
/// let mut b = vec![0.0; a.len()];
184+
/// let l = transpose(layout, &a, &mut b);
185+
/// assert_eq!(l, MatrixLayout::F { col: 3, lda: 2 });
186+
/// assert_eq!(b, &[1., 4., 2., 5., 3., 6.]);
187+
/// ```
188+
///
189+
/// ```rust
190+
/// # use lax::layout::*;
191+
/// let layout = MatrixLayout::F { col: 2, lda: 3 };
192+
/// let a = vec![1., 2., 3., 4., 5., 6.];
193+
/// let mut b = vec![0.0; a.len()];
194+
/// let l = transpose(layout, &a, &mut b);
195+
/// assert_eq!(l, MatrixLayout::C { row: 3, lda: 2 });
196+
/// assert_eq!(b, &[1., 4., 2., 5., 3., 6.]);
197+
/// ```
198+
///
199+
/// Panics
200+
/// ------
201+
/// - If size of `a` and `layout` size mismatch
202+
///
203+
pub fn transpose<T: Scalar>(layout: MatrixLayout, from: &[T], to: &mut [T]) -> MatrixLayout {
204+
let (m, n) = layout.size();
205+
let transposed = layout.resized(n, m).t();
206+
let m = m as usize;
207+
let n = n as usize;
208+
assert_eq!(from.len(), m * n);
209+
assert_eq!(to.len(), m * n);
210+
211+
match layout {
212+
MatrixLayout::C { .. } => {
213+
for i in 0..m {
214+
for j in 0..n {
215+
to[j * m + i] = from[i * n + j];
216+
}
217+
}
218+
}
219+
MatrixLayout::F { .. } => {
220+
for i in 0..m {
221+
for j in 0..n {
222+
to[i * n + j] = from[j * m + i];
223+
}
224+
}
225+
}
226+
}
227+
transposed
228+
}

lax/src/least_squares.rs

Lines changed: 115 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
//! Least squares
22
3-
use crate::{error::*, layout::MatrixLayout};
3+
use crate::{error::*, layout::*};
44
use cauchy::*;
5-
use num_traits::Zero;
5+
use num_traits::{ToPrimitive, Zero};
66

77
/// Result of LeastSquares
88
pub struct LeastSquaresOutput<A: Scalar> {
@@ -14,13 +14,13 @@ pub struct LeastSquaresOutput<A: Scalar> {
1414

1515
/// Wraps `*gelsd`
1616
pub trait LeastSquaresSvdDivideConquer_: Scalar {
17-
unsafe fn least_squares(
17+
fn least_squares(
1818
a_layout: MatrixLayout,
1919
a: &mut [Self],
2020
b: &mut [Self],
2121
) -> Result<LeastSquaresOutput<Self>>;
2222

23-
unsafe fn least_squares_nrhs(
23+
fn least_squares_nrhs(
2424
a_layout: MatrixLayout,
2525
a: &mut [Self],
2626
b_layout: MatrixLayout,
@@ -29,81 +29,129 @@ pub trait LeastSquaresSvdDivideConquer_: Scalar {
2929
}
3030

3131
macro_rules! impl_least_squares {
32-
($scalar:ty, $gelsd:path) => {
32+
(@real, $scalar:ty, $gelsd:path) => {
33+
impl_least_squares!(@body, $scalar, $gelsd, );
34+
};
35+
(@complex, $scalar:ty, $gelsd:path) => {
36+
impl_least_squares!(@body, $scalar, $gelsd, rwork);
37+
};
38+
39+
(@body, $scalar:ty, $gelsd:path, $($rwork:ident),*) => {
3340
impl LeastSquaresSvdDivideConquer_ for $scalar {
34-
unsafe fn least_squares(
35-
a_layout: MatrixLayout,
41+
fn least_squares(
42+
l: MatrixLayout,
3643
a: &mut [Self],
3744
b: &mut [Self],
3845
) -> Result<LeastSquaresOutput<Self>> {
39-
let (m, n) = a_layout.size();
40-
if (m as usize) > b.len() || (n as usize) > b.len() {
41-
return Err(Error::InvalidShape);
42-
}
43-
let k = ::std::cmp::min(m, n);
44-
let nrhs = 1;
45-
let ldb = match a_layout {
46-
MatrixLayout::F { .. } => m.max(n),
47-
MatrixLayout::C { .. } => 1,
48-
};
49-
let rcond: Self::Real = -1.;
50-
let mut singular_values: Vec<Self::Real> = vec![Self::Real::zero(); k as usize];
51-
let mut rank: i32 = 0;
52-
53-
$gelsd(
54-
a_layout.lapacke_layout(),
55-
m,
56-
n,
57-
nrhs,
58-
a,
59-
a_layout.lda(),
60-
b,
61-
ldb,
62-
&mut singular_values,
63-
rcond,
64-
&mut rank,
65-
)
66-
.as_lapack_result()?;
67-
68-
Ok(LeastSquaresOutput {
69-
singular_values,
70-
rank,
71-
})
46+
let b_layout = l.resized(b.len() as i32, 1);
47+
Self::least_squares_nrhs(l, a, b_layout, b)
7248
}
7349

74-
unsafe fn least_squares_nrhs(
50+
fn least_squares_nrhs(
7551
a_layout: MatrixLayout,
7652
a: &mut [Self],
7753
b_layout: MatrixLayout,
7854
b: &mut [Self],
7955
) -> Result<LeastSquaresOutput<Self>> {
56+
// Minimize |b - Ax|_2
57+
//
58+
// where
59+
// A : (m, n)
60+
// b : (max(m, n), nrhs) // `b` has to store `x` on exit
61+
// x : (n, nrhs)
8062
let (m, n) = a_layout.size();
81-
if (m as usize) > b.len()
82-
|| (n as usize) > b.len()
83-
|| a_layout.lapacke_layout() != b_layout.lapacke_layout()
84-
{
85-
return Err(Error::InvalidShape);
86-
}
87-
let k = ::std::cmp::min(m, n);
88-
let nrhs = b_layout.size().1;
63+
let (m_, nrhs) = b_layout.size();
64+
let k = m.min(n);
65+
assert!(m_ >= m);
66+
67+
// Transpose if a is C-continuous
68+
let mut a_t = None;
69+
let a_layout = match a_layout {
70+
MatrixLayout::C { .. } => {
71+
a_t = Some(vec![Self::zero(); a.len()]);
72+
transpose(a_layout, a, a_t.as_mut().unwrap())
73+
}
74+
MatrixLayout::F { .. } => a_layout,
75+
};
76+
77+
// Transpose if b is C-continuous
78+
let mut b_t = None;
79+
let b_layout = match b_layout {
80+
MatrixLayout::C { .. } => {
81+
b_t = Some(vec![Self::zero(); b.len()]);
82+
transpose(b_layout, b, b_t.as_mut().unwrap())
83+
}
84+
MatrixLayout::F { .. } => b_layout,
85+
};
86+
8987
let rcond: Self::Real = -1.;
9088
let mut singular_values: Vec<Self::Real> = vec![Self::Real::zero(); k as usize];
9189
let mut rank: i32 = 0;
9290

93-
$gelsd(
94-
a_layout.lapacke_layout(),
95-
m,
96-
n,
97-
nrhs,
98-
a,
99-
a_layout.lda(),
100-
b,
101-
b_layout.lda(),
102-
&mut singular_values,
103-
rcond,
104-
&mut rank,
105-
)
106-
.as_lapack_result()?;
91+
// eval work size
92+
let mut info = 0;
93+
let mut work_size = [Self::zero()];
94+
let mut iwork_size = [0];
95+
$(
96+
let mut $rwork = [Self::Real::zero()];
97+
)*
98+
unsafe {
99+
$gelsd(
100+
m,
101+
n,
102+
nrhs,
103+
a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a),
104+
a_layout.lda(),
105+
b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b),
106+
b_layout.lda(),
107+
&mut singular_values,
108+
rcond,
109+
&mut rank,
110+
&mut work_size,
111+
-1,
112+
$(&mut $rwork,)*
113+
&mut iwork_size,
114+
&mut info,
115+
)
116+
};
117+
info.as_lapack_result()?;
118+
119+
// calc
120+
let lwork = work_size[0].to_usize().unwrap();
121+
let mut work = vec![Self::zero(); lwork];
122+
let liwork = iwork_size[0].to_usize().unwrap();
123+
let mut iwork = vec![0; liwork];
124+
$(
125+
let lrwork = $rwork[0].to_usize().unwrap();
126+
let mut $rwork = vec![Self::Real::zero(); lrwork];
127+
)*
128+
unsafe {
129+
$gelsd(
130+
m,
131+
n,
132+
nrhs,
133+
a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a),
134+
a_layout.lda(),
135+
b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b),
136+
b_layout.lda(),
137+
&mut singular_values,
138+
rcond,
139+
&mut rank,
140+
&mut work,
141+
lwork as i32,
142+
$(&mut $rwork,)*
143+
&mut iwork,
144+
&mut info,
145+
);
146+
}
147+
info.as_lapack_result()?;
148+
149+
// Skip a_t -> a transpose because A has been destroyed
150+
// Re-transpose b
151+
if let Some(b_t) = b_t {
152+
transpose(b_layout, &b_t, b);
153+
}
154+
107155
Ok(LeastSquaresOutput {
108156
singular_values,
109157
rank,
@@ -113,7 +161,7 @@ macro_rules! impl_least_squares {
113161
};
114162
}
115163

116-
impl_least_squares!(f64, lapacke::dgelsd);
117-
impl_least_squares!(f32, lapacke::sgelsd);
118-
impl_least_squares!(c64, lapacke::zgelsd);
119-
impl_least_squares!(c32, lapacke::cgelsd);
164+
impl_least_squares!(@real, f64, lapack::dgelsd);
165+
impl_least_squares!(@real, f32, lapack::sgelsd);
166+
impl_least_squares!(@complex, c64, lapack::zgelsd);
167+
impl_least_squares!(@complex, c32, lapack::cgelsd);

0 commit comments

Comments
 (0)