Skip to content

Commit 0ab9c59

Browse files
committed
Use lapack_sys in svd.rs
1 parent f155620 commit 0ab9c59

File tree

1 file changed

+39
-35
lines changed

1 file changed

+39
-35
lines changed

lax/src/svd.rs

Lines changed: 39 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ impl FlagSVD {
2121
FlagSVD::No
2222
}
2323
}
24+
25+
fn as_ptr(&self) -> *const i8 {
26+
self as *const FlagSVD as *const i8
27+
}
2428
}
2529

2630
/// Result of SVD
@@ -49,7 +53,7 @@ macro_rules! impl_svd {
4953
};
5054
(@body, $scalar:ty, $gesvd:path, $($rwork_ident:ident),*) => {
5155
impl SVD_ for $scalar {
52-
fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, mut a: &mut [Self],) -> Result<SVDOutput<Self>> {
56+
fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, a: &mut [Self],) -> Result<SVDOutput<Self>> {
5357
let ju = match l {
5458
MatrixLayout::F { .. } => FlagSVD::from_bool(calc_u),
5559
MatrixLayout::C { .. } => FlagSVD::from_bool(calc_vt),
@@ -75,52 +79,52 @@ macro_rules! impl_svd {
7579
let mut s = unsafe { vec_uninit( k as usize) };
7680

7781
$(
78-
let mut $rwork_ident = unsafe { vec_uninit( 5 * k as usize) };
82+
let mut $rwork_ident: Vec<Self::Real> = unsafe { vec_uninit( 5 * k as usize) };
7983
)*
8084

8185
// eval work size
8286
let mut info = 0;
8387
let mut work_size = [Self::zero()];
8488
unsafe {
8589
$gesvd(
86-
ju as u8,
87-
jvt as u8,
88-
m,
89-
n,
90-
&mut a,
91-
m,
92-
&mut s,
93-
u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
94-
m,
95-
vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
96-
n,
97-
&mut work_size,
98-
-1,
99-
$(&mut $rwork_ident,)*
90+
ju.as_ptr(),
91+
jvt.as_ptr(),
92+
&m,
93+
&n,
94+
AsPtr::as_mut_ptr(a),
95+
&m,
96+
AsPtr::as_mut_ptr(&mut s),
97+
AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
98+
&m,
99+
AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
100+
&n,
101+
AsPtr::as_mut_ptr(&mut work_size),
102+
&(-1),
103+
$(AsPtr::as_mut_ptr(&mut $rwork_ident),)*
100104
&mut info,
101105
);
102106
}
103107
info.as_lapack_result()?;
104108

105109
// calc
106110
let lwork = work_size[0].to_usize().unwrap();
107-
let mut work = unsafe { vec_uninit( lwork) };
111+
let mut work: Vec<Self> = unsafe { vec_uninit( lwork) };
108112
unsafe {
109113
$gesvd(
110-
ju as u8,
111-
jvt as u8,
112-
m,
113-
n,
114-
&mut a,
115-
m,
116-
&mut s,
117-
u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
118-
m,
119-
vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
120-
n,
121-
&mut work,
122-
lwork as i32,
123-
$(&mut $rwork_ident,)*
114+
ju.as_ptr(),
115+
jvt.as_ptr() ,
116+
&m,
117+
&n,
118+
AsPtr::as_mut_ptr(a),
119+
&m,
120+
AsPtr::as_mut_ptr(&mut s),
121+
AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
122+
&m,
123+
AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
124+
&n,
125+
AsPtr::as_mut_ptr(&mut work),
126+
&(lwork as i32),
127+
$(AsPtr::as_mut_ptr(&mut $rwork_ident),)*
124128
&mut info,
125129
);
126130
}
@@ -134,7 +138,7 @@ macro_rules! impl_svd {
134138
};
135139
} // impl_svd!
136140

137-
impl_svd!(@real, f64, lapack::dgesvd);
138-
impl_svd!(@real, f32, lapack::sgesvd);
139-
impl_svd!(@complex, c64, lapack::zgesvd);
140-
impl_svd!(@complex, c32, lapack::cgesvd);
141+
impl_svd!(@real, f64, lapack_sys::dgesvd_);
142+
impl_svd!(@real, f32, lapack_sys::sgesvd_);
143+
impl_svd!(@complex, c64, lapack_sys::zgesvd_);
144+
impl_svd!(@complex, c32, lapack_sys::cgesvd_);

0 commit comments

Comments
 (0)