Skip to content

Commit 4bf4414

Browse files
committed
Use lapack_sys in Eigh_
1 parent 3c4dc5e commit 4bf4414

File tree

1 file changed

+55
-53
lines changed

1 file changed

+55
-53
lines changed

lax/src/eigh.rs

Lines changed: 55 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -37,50 +37,51 @@ macro_rules! impl_eigh {
3737
calc_v: bool,
3838
layout: MatrixLayout,
3939
uplo: UPLO,
40-
mut a: &mut [Self],
40+
a: &mut [Self],
4141
) -> Result<Vec<Self::Real>> {
4242
assert_eq!(layout.len(), layout.lda());
4343
let n = layout.len();
4444
let jobz = if calc_v { EigenVectorFlag::Calc } else { EigenVectorFlag::Not };
4545
let mut eigs = unsafe { vec_uninit(n as usize) };
4646

4747
$(
48-
let mut $rwork_ident = unsafe { vec_uninit(3 * n as usize - 2 as usize) };
48+
let mut $rwork_ident: Vec<Self::Real> = unsafe { vec_uninit(3 * n as usize - 2 as usize) };
4949
)*
5050

5151
// calc work size
5252
let mut info = 0;
5353
let mut work_size = [Self::zero()];
5454
unsafe {
5555
$ev(
56-
jobz as u8,
57-
uplo as u8,
58-
n,
59-
&mut a,
60-
n,
61-
&mut eigs,
62-
&mut work_size,
63-
-1,
64-
$(&mut $rwork_ident,)*
56+
jobz.as_ptr() ,
57+
uplo.as_ptr(),
58+
&n,
59+
AsPtr::as_mut_ptr(a),
60+
&n,
61+
AsPtr::as_mut_ptr(&mut eigs),
62+
AsPtr::as_mut_ptr(&mut work_size),
63+
&(-1),
64+
$(AsPtr::as_mut_ptr(&mut $rwork_ident),)*
6565
&mut info,
6666
);
6767
}
6868
info.as_lapack_result()?;
6969

7070
// actual ev
7171
let lwork = work_size[0].to_usize().unwrap();
72-
let mut work = unsafe { vec_uninit(lwork) };
72+
let mut work: Vec<Self> = unsafe { vec_uninit(lwork) };
73+
let lwork = lwork as i32;
7374
unsafe {
7475
$ev(
75-
jobz as u8,
76-
uplo as u8,
77-
n,
78-
&mut a,
79-
n,
80-
&mut eigs,
81-
&mut work,
82-
lwork as i32,
83-
$(&mut $rwork_ident,)*
76+
jobz.as_ptr(),
77+
uplo.as_ptr(),
78+
&n,
79+
AsPtr::as_mut_ptr(a),
80+
&n,
81+
AsPtr::as_mut_ptr(&mut eigs),
82+
AsPtr::as_mut_ptr(&mut work),
83+
&lwork,
84+
$(AsPtr::as_mut_ptr(&mut $rwork_ident),)*
8485
&mut info,
8586
);
8687
}
@@ -92,57 +93,58 @@ macro_rules! impl_eigh {
9293
calc_v: bool,
9394
layout: MatrixLayout,
9495
uplo: UPLO,
95-
mut a: &mut [Self],
96-
mut b: &mut [Self],
96+
a: &mut [Self],
97+
b: &mut [Self],
9798
) -> Result<Vec<Self::Real>> {
9899
assert_eq!(layout.len(), layout.lda());
99100
let n = layout.len();
100101
let jobz = if calc_v { EigenVectorFlag::Calc } else { EigenVectorFlag::Not };
101102
let mut eigs = unsafe { vec_uninit(n as usize) };
102103

103104
$(
104-
let mut $rwork_ident = unsafe { vec_uninit(3 * n as usize - 2) };
105+
let mut $rwork_ident: Vec<Self::Real> = unsafe { vec_uninit(3 * n as usize - 2) };
105106
)*
106107

107108
// calc work size
108109
let mut info = 0;
109110
let mut work_size = [Self::zero()];
110111
unsafe {
111112
$evg(
112-
&[1],
113-
jobz as u8,
114-
uplo as u8,
115-
n,
116-
&mut a,
117-
n,
118-
&mut b,
119-
n,
120-
&mut eigs,
121-
&mut work_size,
122-
-1,
123-
$(&mut $rwork_ident,)*
113+
&1, // ITYPE A*x = (lambda)*B*x
114+
jobz.as_ptr(),
115+
uplo.as_ptr(),
116+
&n,
117+
AsPtr::as_mut_ptr(a),
118+
&n,
119+
AsPtr::as_mut_ptr(b),
120+
&n,
121+
AsPtr::as_mut_ptr(&mut eigs),
122+
AsPtr::as_mut_ptr(&mut work_size),
123+
&(-1),
124+
$(AsPtr::as_mut_ptr(&mut $rwork_ident),)*
124125
&mut info,
125126
);
126127
}
127128
info.as_lapack_result()?;
128129

129130
// actual evg
130131
let lwork = work_size[0].to_usize().unwrap();
131-
let mut work = unsafe { vec_uninit(lwork) };
132+
let mut work: Vec<Self> = unsafe { vec_uninit(lwork) };
133+
let lwork = lwork as i32;
132134
unsafe {
133135
$evg(
134-
&[1],
135-
jobz as u8,
136-
uplo as u8,
137-
n,
138-
&mut a,
139-
n,
140-
&mut b,
141-
n,
142-
&mut eigs,
143-
&mut work,
144-
lwork as i32,
145-
$(&mut $rwork_ident,)*
136+
&1, // ITYPE A*x = (lambda)*B*x
137+
jobz.as_ptr(),
138+
uplo.as_ptr(),
139+
&n,
140+
AsPtr::as_mut_ptr(a),
141+
&n,
142+
AsPtr::as_mut_ptr(b),
143+
&n,
144+
AsPtr::as_mut_ptr(&mut eigs),
145+
AsPtr::as_mut_ptr(&mut work),
146+
&lwork,
147+
$(AsPtr::as_mut_ptr(&mut $rwork_ident),)*
146148
&mut info,
147149
);
148150
}
@@ -153,7 +155,7 @@ macro_rules! impl_eigh {
153155
};
154156
} // impl_eigh!
155157

156-
impl_eigh!(@real, f64, lapack::dsyev, lapack::dsygv);
157-
impl_eigh!(@real, f32, lapack::ssyev, lapack::ssygv);
158-
impl_eigh!(@complex, c64, lapack::zheev, lapack::zhegv);
159-
impl_eigh!(@complex, c32, lapack::cheev, lapack::chegv);
158+
impl_eigh!(@real, f64, lapack_sys::dsyev_, lapack_sys::dsygv_);
159+
impl_eigh!(@real, f32, lapack_sys::ssyev_, lapack_sys::ssygv_);
160+
impl_eigh!(@complex, c64, lapack_sys::zheev_, lapack_sys::zhegv_);
161+
impl_eigh!(@complex, c32, lapack_sys::cheev_, lapack_sys::chegv_);

0 commit comments

Comments
 (0)