Skip to content

Commit 68b8ae5

Browse files
committed
Use lapack_sys in solve.rs
1 parent b470f5f commit 68b8ae5

File tree

1 file changed

+65
-14
lines changed

1 file changed

+65
-14
lines changed

lax/src/solve.rs

Lines changed: 65 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,16 @@ macro_rules! impl_solve {
3535
let k = ::std::cmp::min(row, col);
3636
let mut ipiv = unsafe { vec_uninit(k as usize) };
3737
let mut info = 0;
38-
unsafe { $getrf(l.lda(), l.len(), a, l.lda(), &mut ipiv, &mut info) };
38+
unsafe {
39+
$getrf(
40+
&l.lda(),
41+
&l.len(),
42+
AsPtr::as_mut_ptr(a),
43+
&l.lda(),
44+
ipiv.as_mut_ptr(),
45+
&mut info,
46+
)
47+
};
3948
info.as_lapack_result()?;
4049
Ok(ipiv)
4150
}
@@ -50,20 +59,30 @@ macro_rules! impl_solve {
5059
// calc work size
5160
let mut info = 0;
5261
let mut work_size = [Self::zero()];
53-
unsafe { $getri(n, a, l.lda(), ipiv, &mut work_size, -1, &mut info) };
62+
unsafe {
63+
$getri(
64+
&n,
65+
AsPtr::as_mut_ptr(a),
66+
&l.lda(),
67+
ipiv.as_ptr(),
68+
AsPtr::as_mut_ptr(&mut work_size),
69+
&(-1),
70+
&mut info,
71+
)
72+
};
5473
info.as_lapack_result()?;
5574

5675
// actual
5776
let lwork = work_size[0].to_usize().unwrap();
58-
let mut work = unsafe { vec_uninit(lwork) };
77+
let mut work: Vec<Self> = unsafe { vec_uninit(lwork) };
5978
unsafe {
6079
$getri(
61-
l.len(),
62-
a,
63-
l.lda(),
64-
ipiv,
65-
&mut work,
66-
lwork as i32,
80+
&l.len(),
81+
AsPtr::as_mut_ptr(a),
82+
&l.lda(),
83+
ipiv.as_ptr(),
84+
AsPtr::as_mut_ptr(&mut work),
85+
&(lwork as i32),
6786
&mut info,
6887
)
6988
};
@@ -116,7 +135,19 @@ macro_rules! impl_solve {
116135
*b_elem = b_elem.conj();
117136
}
118137
}
119-
unsafe { $getrs(t as u8, n, nrhs, a, l.lda(), ipiv, b, ldb, &mut info) };
138+
unsafe {
139+
$getrs(
140+
t.as_ptr(),
141+
&n,
142+
&nrhs,
143+
AsPtr::as_ptr(a),
144+
&l.lda(),
145+
ipiv.as_ptr(),
146+
AsPtr::as_mut_ptr(b),
147+
&ldb,
148+
&mut info,
149+
)
150+
};
120151
if conj {
121152
for b_elem in &mut *b {
122153
*b_elem = b_elem.conj();
@@ -129,7 +160,27 @@ macro_rules! impl_solve {
129160
};
130161
} // impl_solve!
131162

132-
impl_solve!(f64, lapack::dgetrf, lapack::dgetri, lapack::dgetrs);
133-
impl_solve!(f32, lapack::sgetrf, lapack::sgetri, lapack::sgetrs);
134-
impl_solve!(c64, lapack::zgetrf, lapack::zgetri, lapack::zgetrs);
135-
impl_solve!(c32, lapack::cgetrf, lapack::cgetri, lapack::cgetrs);
163+
impl_solve!(
164+
f64,
165+
lapack_sys::dgetrf_,
166+
lapack_sys::dgetri_,
167+
lapack_sys::dgetrs_
168+
);
169+
impl_solve!(
170+
f32,
171+
lapack_sys::sgetrf_,
172+
lapack_sys::sgetri_,
173+
lapack_sys::sgetrs_
174+
);
175+
impl_solve!(
176+
c64,
177+
lapack_sys::zgetrf_,
178+
lapack_sys::zgetri_,
179+
lapack_sys::zgetrs_
180+
);
181+
impl_solve!(
182+
c32,
183+
lapack_sys::cgetrf_,
184+
lapack_sys::cgetri_,
185+
lapack_sys::cgetrs_
186+
);

0 commit comments

Comments
 (0)