Skip to content

Commit aff89b7

Browse files
committed
Use lapack_sys in svddc.rs
1 parent 0ab9c59 commit aff89b7

File tree

1 file changed

+41
-35
lines changed

1 file changed

+41
-35
lines changed

lax/src/svddc.rs

Lines changed: 41 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ pub enum UVTFlag {
1616
None = b'N',
1717
}
1818

19+
impl UVTFlag {
20+
fn as_ptr(&self) -> *const i8 {
21+
self as *const UVTFlag as *const i8
22+
}
23+
}
24+
1925
pub trait SVDDC_: Scalar {
2026
fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self]) -> Result<SVDOutput<Self>>;
2127
}
@@ -29,7 +35,7 @@ macro_rules! impl_svddc {
2935
};
3036
(@body, $scalar:ty, $gesdd:path, $($rwork_ident:ident),*) => {
3137
impl SVDDC_ for $scalar {
32-
fn svddc(l: MatrixLayout, jobz: UVTFlag, mut a: &mut [Self],) -> Result<SVDOutput<Self>> {
38+
fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self],) -> Result<SVDOutput<Self>> {
3339
let m = l.lda();
3440
let n = l.len();
3541
let k = m.min(n);
@@ -58,7 +64,7 @@ macro_rules! impl_svddc {
5864
UVTFlag::None => 7 * mn,
5965
_ => std::cmp::max(5*mn*mn + 5*mn, 2*mx*mn + 2*mn*mn + mn),
6066
};
61-
let mut $rwork_ident = unsafe { vec_uninit( lrwork) };
67+
let mut $rwork_ident: Vec<Self::Real> = unsafe { vec_uninit( lrwork) };
6268
)*
6369

6470
// eval work size
@@ -67,44 +73,44 @@ macro_rules! impl_svddc {
6773
let mut work_size = [Self::zero()];
6874
unsafe {
6975
$gesdd(
70-
jobz as u8,
71-
m,
72-
n,
73-
&mut a,
74-
m,
75-
&mut s,
76-
u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
77-
m,
78-
vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
79-
vt_row,
80-
&mut work_size,
81-
-1,
82-
$(&mut $rwork_ident,)*
83-
&mut iwork,
76+
jobz.as_ptr(),
77+
&m,
78+
&n,
79+
AsPtr::as_mut_ptr(a),
80+
&m,
81+
AsPtr::as_mut_ptr(&mut s),
82+
AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
83+
&m,
84+
AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
85+
&vt_row,
86+
AsPtr::as_mut_ptr(&mut work_size),
87+
&(-1),
88+
$(AsPtr::as_mut_ptr(&mut $rwork_ident),)*
89+
iwork.as_mut_ptr(),
8490
&mut info,
8591
);
8692
}
8793
info.as_lapack_result()?;
8894

8995
// do svd
9096
let lwork = work_size[0].to_usize().unwrap();
91-
let mut work = unsafe { vec_uninit( lwork) };
97+
let mut work: Vec<Self> = unsafe { vec_uninit( lwork) };
9298
unsafe {
9399
$gesdd(
94-
jobz as u8,
95-
m,
96-
n,
97-
&mut a,
98-
m,
99-
&mut s,
100-
u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
101-
m,
102-
vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
103-
vt_row,
104-
&mut work,
105-
lwork as i32,
106-
$(&mut $rwork_ident,)*
107-
&mut iwork,
100+
jobz.as_ptr(),
101+
&m,
102+
&n,
103+
AsPtr::as_mut_ptr(a),
104+
&m,
105+
AsPtr::as_mut_ptr(&mut s),
106+
AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
107+
&m,
108+
AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
109+
&vt_row,
110+
AsPtr::as_mut_ptr(&mut work),
111+
&(lwork as i32),
112+
$(AsPtr::as_mut_ptr(&mut $rwork_ident),)*
113+
iwork.as_mut_ptr(),
108114
&mut info,
109115
);
110116
}
@@ -119,7 +125,7 @@ macro_rules! impl_svddc {
119125
};
120126
}
121127

122-
impl_svddc!(@real, f32, lapack::sgesdd);
123-
impl_svddc!(@real, f64, lapack::dgesdd);
124-
impl_svddc!(@complex, c32, lapack::cgesdd);
125-
impl_svddc!(@complex, c64, lapack::zgesdd);
128+
impl_svddc!(@real, f32, lapack_sys::sgesdd_);
129+
impl_svddc!(@real, f64, lapack_sys::dgesdd_);
130+
impl_svddc!(@complex, c32, lapack_sys::cgesdd_);
131+
impl_svddc!(@complex, c64, lapack_sys::zgesdd_);

0 commit comments

Comments
 (0)