Skip to content

Commit 311bc1d

Browse files
authored
Merge pull request #45 from termoshtt/triangular
Triangular matrix
2 parents b7c0ecf + e53a328 commit 311bc1d

File tree

17 files changed

+432
-595
lines changed

17 files changed

+432
-595
lines changed

src/generate.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,18 @@ pub fn conjugate<A, Si, So>(a: &ArrayBase<Si, Ix2>) -> ArrayBase<So, Ix2>
1919
a
2020
}
2121

22+
/// Random vector
23+
pub fn random_vector<A, S>(n: usize) -> ArrayBase<S, Ix1>
24+
where A: RandNormal,
25+
S: DataOwned<Elem = A>
26+
{
27+
let mut rng = thread_rng();
28+
let v: Vec<A> = (0..n).map(|_| A::randn(&mut rng)).collect();
29+
ArrayBase::from_vec(v)
30+
}
31+
2232
/// Random matrix
23-
pub fn random<A, S>(n: usize, m: usize) -> ArrayBase<S, Ix2>
33+
pub fn random_matrix<A, S>(n: usize, m: usize) -> ArrayBase<S, Ix2>
2434
where A: RandNormal,
2535
S: DataOwned<Elem = A>
2636
{
@@ -34,7 +44,7 @@ pub fn random_square<A, S>(n: usize) -> ArrayBase<S, Ix2>
3444
where A: RandNormal,
3545
S: DataOwned<Elem = A>
3646
{
37-
random(n, n)
47+
random_matrix(n, n)
3848
}
3949

4050
/// Random Hermite matrix

src/impl2/mod.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@ pub mod svd;
55
pub mod solve;
66
pub mod cholesky;
77
pub mod eigh;
8+
pub mod triangular;
89

910
pub use self::opnorm::*;
1011
pub use self::qr::*;
1112
pub use self::svd::*;
1213
pub use self::solve::*;
1314
pub use self::cholesky::*;
1415
pub use self::eigh::*;
16+
pub use self::triangular::*;
1517

1618
use super::error::*;
1719

@@ -20,7 +22,8 @@ trait_alias!(LapackScalar: OperatorNorm_,
2022
SVD_,
2123
Solve_,
2224
Cholesky_,
23-
Eigh_);
25+
Eigh_,
26+
Triangular_);
2427

2528
pub fn into_result<T>(info: i32, val: T) -> Result<T> {
2629
if info == 0 {
@@ -36,3 +39,11 @@ pub enum UPLO {
3639
Upper = b'U',
3740
Lower = b'L',
3841
}
42+
43+
#[derive(Debug, Clone, Copy)]
44+
#[repr(u8)]
45+
pub enum Transpose {
46+
No = b'N',
47+
Transpose = b'T',
48+
Hermite = b'C',
49+
}

src/impl2/solve.rs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,10 @@ use types::*;
55
use error::*;
66
use layout::Layout;
77

8-
use super::into_result;
8+
use super::{Transpose, into_result};
99

1010
pub type Pivot = Vec<i32>;
1111

12-
#[derive(Debug, Clone, Copy)]
13-
#[repr(u8)]
14-
pub enum Transpose {
15-
No = b'N',
16-
Transpose = b'T',
17-
Hermite = b'C',
18-
}
19-
2012
pub trait Solve_: Sized {
2113
fn lu(Layout, a: &mut [Self]) -> Result<Pivot>;
2214
fn inv(Layout, a: &mut [Self], &Pivot) -> Result<()>;

src/impl2/triangular.rs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
//! Implement linear solver and inverse matrix
2+
3+
use lapack::c;
4+
5+
use error::*;
6+
use types::*;
7+
use layout::Layout;
8+
use super::{UPLO, Transpose, into_result};
9+
10+
#[derive(Debug, Clone, Copy)]
11+
#[repr(u8)]
12+
pub enum Diag {
13+
Unit = b'U',
14+
NonUnit = b'N',
15+
}
16+
17+
pub trait Triangular_: Sized {
18+
fn inv_triangular(l: Layout, UPLO, Diag, a: &mut [Self]) -> Result<()>;
19+
fn solve_triangular(al: Layout, bl: Layout, UPLO, Diag, a: &[Self], b: &mut [Self]) -> Result<()>;
20+
}
21+
22+
macro_rules! impl_triangular {
23+
($scalar:ty, $trtri:path, $trtrs:path) => {
24+
25+
impl Triangular_ for $scalar {
26+
fn inv_triangular(l: Layout, uplo: UPLO, diag: Diag, a: &mut [Self]) -> Result<()> {
27+
let (n, _) = l.size();
28+
let lda = l.lda();
29+
let info = $trtri(l.lapacke_layout(), uplo as u8, diag as u8, n, a, lda);
30+
into_result(info, ())
31+
}
32+
33+
fn solve_triangular(al: Layout, bl: Layout, uplo: UPLO, diag: Diag, a: &[Self], mut b: &mut [Self]) -> Result<()> {
34+
let (n, _) = al.size();
35+
let lda = al.lda();
36+
let (_, nrhs) = bl.size();
37+
let ldb = bl.lda();
38+
println!("al = {:?}", al);
39+
println!("bl = {:?}", bl);
40+
println!("n = {}", n);
41+
println!("lda = {}", lda);
42+
println!("nrhs = {}", nrhs);
43+
println!("ldb = {}", ldb);
44+
let info = $trtrs(al.lapacke_layout(), uplo as u8, Transpose::No as u8, diag as u8, n, nrhs, a, lda, &mut b, ldb);
45+
into_result(info, ())
46+
}
47+
}
48+
49+
}} // impl_triangular!
50+
51+
impl_triangular!(f64, c::dtrtri, c::dtrtrs);
52+
impl_triangular!(f32, c::strtri, c::strtrs);
53+
impl_triangular!(c64, c::ztrtri, c::ztrtrs);
54+
impl_triangular!(c32, c::ctrtri, c::ctrtrs);

src/impls/mod.rs

Lines changed: 0 additions & 2 deletions
This file was deleted.

src/impls/solve.rs

Lines changed: 0 additions & 87 deletions
This file was deleted.

0 commit comments

Comments
 (0)