Skip to content

Commit ba579e4

Browse files
committed
Merge branch 'rcarray'
2 parents 7601cb5 + a35cc78 commit ba579e4

18 files changed

+504
-610
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ lapack = "0.11.1"
1414
num-traits = "0.1.36"
1515

1616
[dependencies.ndarray]
17-
version = "0.7"
17+
version = "0.7.3"
1818
features = ["blas"]
1919

2020
[dev-dependencies]

src/hermite.rs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//! Define trait for Hermite matrices
22
3-
use ndarray::{Ix2, Array};
3+
use ndarray::{Ix2, Array, RcArray};
44
use lapack::c::Layout;
55

66
use matrix::{Matrix, MFloat};
@@ -71,3 +71,21 @@ impl<A: HMFloat> HermiteMatrix for Array<A, Ix2> {
7171
Ok(rt * rt)
7272
}
7373
}
74+
75+
impl<A: HMFloat> HermiteMatrix for RcArray<A, Ix2> {
76+
fn eigh(self) -> Result<(Self::Vector, Self), LinalgError> {
77+
let (e, v) = self.into_owned().eigh()?;
78+
Ok((e.into_shared(), v.into_shared()))
79+
}
80+
fn ssqrt(self) -> Result<Self, LinalgError> {
81+
let s = self.into_owned().ssqrt()?;
82+
Ok(s.into_shared())
83+
}
84+
fn cholesky(self) -> Result<Self, LinalgError> {
85+
let s = self.into_owned().cholesky()?;
86+
Ok(s.into_shared())
87+
}
88+
fn deth(self) -> Result<Self::Scalar, LinalgError> {
89+
self.into_owned().deth()
90+
}
91+
}

src/matrix.rs

Lines changed: 71 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
33
use std::cmp::min;
44
use ndarray::prelude::*;
5+
use ndarray::DataMut;
56
use lapack::c::Layout;
67

78
use error::{LinalgError, StrideError};
@@ -43,6 +44,35 @@ pub trait Matrix: Sized {
4344
}
4445
}
4546

47+
fn check_layout(strides: &[Ixs]) -> Result<Layout, StrideError> {
48+
if min(strides[0], strides[1]) != 1 {
49+
return Err(StrideError {
50+
s0: strides[0],
51+
s1: strides[1],
52+
});;
53+
}
54+
if strides[0] < strides[1] {
55+
Ok(Layout::ColumnMajor)
56+
} else {
57+
Ok(Layout::RowMajor)
58+
}
59+
}
60+
61+
fn permutate<A: NdFloat, S>(mut a: &mut ArrayBase<S, Ix2>, ipiv: &Vec<i32>)
62+
where S: DataMut<Elem = A>
63+
{
64+
let m = a.cols();
65+
for (i, j_) in ipiv.iter().enumerate().rev() {
66+
let j = (j_ - 1) as usize;
67+
if i == j {
68+
continue;
69+
}
70+
for k in 0..m {
71+
a.swap((i, k), (j, k));
72+
}
73+
}
74+
}
75+
4676
impl<A: MFloat> Matrix for Array<A, Ix2> {
4777
type Scalar = A;
4878
type Vector = Array<A, Ix1>;
@@ -52,18 +82,7 @@ impl<A: MFloat> Matrix for Array<A, Ix2> {
5282
(self.rows(), self.cols())
5383
}
5484
fn layout(&self) -> Result<Layout, StrideError> {
55-
let strides = self.strides();
56-
if min(strides[0], strides[1]) != 1 {
57-
return Err(StrideError {
58-
s0: strides[0],
59-
s1: strides[1],
60-
});;
61-
}
62-
if strides[0] < strides[1] {
63-
Ok(Layout::ColumnMajor)
64-
} else {
65-
Ok(Layout::RowMajor)
66-
}
85+
check_layout(self.strides())
6786
}
6887
fn opnorm_1(&self) -> Self::Scalar {
6988
let (m, n) = self.size();
@@ -159,15 +178,45 @@ impl<A: MFloat> Matrix for Array<A, Ix2> {
159178
Ok((p, lm, am))
160179
}
161180
fn permutate(&mut self, ipiv: &Self::Permutator) {
162-
let (_, m) = self.size();
163-
for (i, j_) in ipiv.iter().enumerate().rev() {
164-
let j = (j_ - 1) as usize;
165-
if i == j {
166-
continue;
167-
}
168-
for k in 0..m {
169-
self.swap((i, k), (j, k));
170-
}
171-
}
181+
permutate(self, ipiv);
182+
}
183+
}
184+
185+
impl<A: MFloat> Matrix for RcArray<A, Ix2> {
186+
type Scalar = A;
187+
type Vector = RcArray<A, Ix1>;
188+
type Permutator = Vec<i32>;
189+
fn size(&self) -> (usize, usize) {
190+
(self.rows(), self.cols())
191+
}
192+
fn layout(&self) -> Result<Layout, StrideError> {
193+
check_layout(self.strides())
194+
}
195+
fn opnorm_1(&self) -> Self::Scalar {
196+
// XXX unnecessary clone
197+
self.to_owned().opnorm_1()
198+
}
199+
fn opnorm_i(&self) -> Self::Scalar {
200+
// XXX unnecessary clone
201+
self.to_owned().opnorm_i()
202+
}
203+
fn opnorm_f(&self) -> Self::Scalar {
204+
// XXX unnecessary clone
205+
self.to_owned().opnorm_f()
206+
}
207+
fn svd(self) -> Result<(Self, Self::Vector, Self), LinalgError> {
208+
let (u, s, v) = self.into_owned().svd()?;
209+
Ok((u.into_shared(), s.into_shared(), v.into_shared()))
210+
}
211+
fn qr(self) -> Result<(Self, Self), LinalgError> {
212+
let (q, r) = self.into_owned().qr()?;
213+
Ok((q.into_shared(), r.into_shared()))
214+
}
215+
fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError> {
216+
let (p, l, u) = self.into_owned().lu()?;
217+
Ok((p, l.into_shared(), u.into_shared()))
218+
}
219+
fn permutate(&mut self, ipiv: &Self::Permutator) {
220+
permutate(self, ipiv);
172221
}
173222
}

src/prelude.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ pub use matrix::Matrix;
44
pub use square::SquareMatrix;
55
pub use hermite::HermiteMatrix;
66
pub use triangular::{TriangularMatrix, drop_lower, drop_upper};
7+
pub use util::{all_close_l1, all_close_l2, all_close_max};

src/square.rs

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//! Define trait for Hermite matrices
22
3-
use ndarray::{Ix2, Array};
3+
use ndarray::{Ix2, Array, RcArray, ArrayBase, Data};
44
use lapack::c::Layout;
55

66
use matrix::{Matrix, MFloat};
@@ -32,6 +32,13 @@ pub trait SquareMatrix: Matrix {
3232
}
3333
}
3434

35+
fn trace<A: MFloat, S>(a: &ArrayBase<S, Ix2>) -> A
36+
where S: Data<Elem = A>
37+
{
38+
let n = a.rows();
39+
(0..n).fold(A::zero(), |sum, i| sum + a[(i, i)])
40+
}
41+
3542
impl<A: MFloat> SquareMatrix for Array<A, Ix2> {
3643
fn inv(self) -> Result<Self, LinalgError> {
3744
self.check_square()?;
@@ -47,7 +54,18 @@ impl<A: MFloat> SquareMatrix for Array<A, Ix2> {
4754
}
4855
fn trace(&self) -> Result<Self::Scalar, LinalgError> {
4956
self.check_square()?;
50-
let (n, _) = self.size();
51-
Ok((0..n).fold(A::zero(), |sum, i| sum + self[(i, i)]))
57+
Ok(trace(self))
58+
}
59+
}
60+
61+
impl<A: MFloat> SquareMatrix for RcArray<A, Ix2> {
62+
fn inv(self) -> Result<Self, LinalgError> {
63+
// XXX unnecessary clone (should use into_owned())
64+
let i = self.to_owned().inv()?;
65+
Ok(i.into_shared())
66+
}
67+
fn trace(&self) -> Result<Self::Scalar, LinalgError> {
68+
self.check_square()?;
69+
Ok(trace(self))
5270
}
5371
}

src/triangular.rs

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11

2-
use ndarray::{Ix2, Array};
2+
use ndarray::{Ix2, Array, RcArray, NdFloat, ArrayBase, DataMut};
33

44
use matrix::{Matrix, MFloat};
55
use square::SquareMatrix;
@@ -32,19 +32,36 @@ impl<A: MFloat> TriangularMatrix for Array<A, Ix2> {
3232
}
3333
}
3434

35-
pub fn drop_upper(mut a: Array<f64, Ix2>) -> Array<f64, Ix2> {
35+
impl<A: MFloat> TriangularMatrix for RcArray<A, Ix2> {
36+
fn solve_upper(&self, b: Self::Vector) -> Result<Self::Vector, LinalgError> {
37+
// XXX unnecessary clone
38+
let x = self.to_owned().solve_upper(b.to_owned())?;
39+
Ok(x.into_shared())
40+
}
41+
fn solve_lower(&self, b: Self::Vector) -> Result<Self::Vector, LinalgError> {
42+
// XXX unnecessary clone
43+
let x = self.to_owned().solve_lower(b.to_owned())?;
44+
Ok(x.into_shared())
45+
}
46+
}
47+
48+
pub fn drop_upper<A: NdFloat, S>(mut a: ArrayBase<S, Ix2>) -> ArrayBase<S, Ix2>
49+
where S: DataMut<Elem = A>
50+
{
3651
for ((i, j), val) in a.indexed_iter_mut() {
3752
if i < j {
38-
*val = 0.0;
53+
*val = A::zero();
3954
}
4055
}
4156
a
4257
}
4358

44-
pub fn drop_lower(mut a: Array<f64, Ix2>) -> Array<f64, Ix2> {
59+
pub fn drop_lower<A: NdFloat, S>(mut a: ArrayBase<S, Ix2>) -> ArrayBase<S, Ix2>
60+
where S: DataMut<Elem = A>
61+
{
4562
for ((i, j), val) in a.indexed_iter_mut() {
4663
if i > j {
47-
*val = 0.0;
64+
*val = A::zero();
4865
}
4966
}
5067
a

tests/cholesky.rs

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,30 @@
11
include!("header.rs");
22

3-
#[test]
4-
fn cholesky() {
5-
let r_dist = RealNormal::new(0., 1.);
6-
let mut a = Array::<f64, _>::random((3, 3), r_dist);
7-
a = a.dot(&a.t());
8-
println!("a = \n{:?}", a);
9-
let c = a.clone().cholesky().unwrap();
10-
println!("c = \n{:?}", c);
11-
println!("cc = \n{:?}", c.t().dot(&c));
12-
c.t().dot(&c).assert_allclose_l2(&a, 1e-7);
3+
macro_rules! impl_test {
4+
($modname:ident, $clone:ident) => {
5+
mod $modname {
6+
use super::random_hermite;
7+
use ndarray_linalg::prelude::*;
8+
#[test]
9+
fn cholesky() {
10+
let a = random_hermite(3);
11+
println!("a = \n{:?}", a);
12+
let c = a.$clone().cholesky().unwrap();
13+
println!("c = \n{:?}", c);
14+
println!("cc = \n{:?}", c.t().dot(&c));
15+
all_close_l2(&c.t().dot(&c), &a, 1e-7).unwrap();
16+
}
17+
#[test]
18+
fn cholesky_t() {
19+
let a = random_hermite(3);
20+
println!("a = \n{:?}", a);
21+
let c = a.$clone().cholesky().unwrap();
22+
println!("c = \n{:?}", c);
23+
println!("cc = \n{:?}", c.t().dot(&c));
24+
all_close_l2(&c.t().dot(&c), &a, 1e-7).unwrap();
25+
}
1326
}
27+
}} // impl_test
1428

15-
#[test]
16-
fn cholesky_t() {
17-
let r_dist = RealNormal::new(0., 1.);
18-
let mut a = Array::<f64, _>::random((3, 3), r_dist);
19-
a = a.dot(&a.t()).reversed_axes();
20-
println!("a = \n{:?}", a);
21-
let c = a.clone().cholesky().unwrap();
22-
println!("c = \n{:?}", c);
23-
println!("cc = \n{:?}", c.t().dot(&c));
24-
c.t().dot(&c).assert_allclose_l2(&a, 1e-7);
25-
}
29+
impl_test!(owned, clone);
30+
impl_test!(shared, to_shared);

tests/det.rs

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
11
include!("header.rs");
22

3-
fn random_hermite(n: usize) -> Array<f64, Ix2> {
4-
let r_dist = RealNormal::new(0., 1.);
5-
let a = Array::<f64, _>::random((n, n), r_dist);
6-
a.dot(&a.t())
3+
macro_rules! impl_test{
4+
($modname:ident, $clone:ident) => {
5+
mod $modname {
6+
use super::random_hermite;
7+
use ndarray_linalg::prelude::*;
8+
use ndarray_numtest::prelude::*;
9+
#[test]
10+
fn deth() {
11+
let a = random_hermite(3);
12+
let (e, _) = a.$clone().eigh().unwrap();
13+
let deth = a.$clone().deth().unwrap();
14+
let det_eig = e.iter().fold(1.0, |x, y| x * y);
15+
deth.assert_close(det_eig, 1.0e-7);
16+
}
717
}
18+
}} // impl_test
819

9-
#[test]
10-
fn deth() {
11-
let a = random_hermite(3);
12-
let (e, _) = a.clone().eigh().unwrap();
13-
let deth = a.clone().deth().unwrap();
14-
let det_eig = e.iter().fold(1.0, |x, y| x * y);
15-
deth.assert_close(det_eig, 1.0e-7);
16-
}
20+
impl_test!(owned, clone);
21+
impl_test!(shared, to_shared);

tests/eigh.rs

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,33 @@
11
include!("header.rs");
22

3-
#[test]
4-
fn eigen_vector_manual() {
5-
let a = arr2(&[[3.0, 1.0, 1.0], [1.0, 3.0, 1.0], [1.0, 1.0, 3.0]]);
6-
let (e, vecs) = a.clone().eigh().unwrap();
7-
assert!(e.all_close(&arr1(&[2.0, 2.0, 5.0]), 1.0e-7));
8-
for (i, v) in vecs.axis_iter(Axis(1)).enumerate() {
9-
let av = a.dot(&v);
10-
let ev = v.mapv(|x| e[i] * x);
11-
assert!(av.all_close(&ev, 1.0e-7));
3+
macro_rules! impl_test {
4+
($modname:ident, $clone:ident) => {
5+
mod $modname {
6+
use ndarray::prelude::*;
7+
use ndarray_linalg::prelude::*;
8+
use ndarray_numtest::prelude::*;
9+
#[test]
10+
fn eigen_vector_manual() {
11+
let a = arr2(&[[3.0, 1.0, 1.0], [1.0, 3.0, 1.0], [1.0, 1.0, 3.0]]);
12+
let (e, vecs) = a.$clone().eigh().unwrap();
13+
all_close_l2(&e, &arr1(&[2.0, 2.0, 5.0]), 1.0e-7).unwrap();
14+
for (i, v) in vecs.axis_iter(Axis(1)).enumerate() {
15+
let av = a.dot(&v);
16+
let ev = v.mapv(|x| e[i] * x);
17+
all_close_l2(&av, &ev, 1.0e-7).unwrap();
18+
}
1219
}
13-
}
14-
15-
#[test]
16-
fn diagonalize() {
17-
let a = arr2(&[[3.0, 1.0, 1.0], [1.0, 3.0, 1.0], [1.0, 1.0, 3.0]]);
18-
let (e, vecs) = a.clone().eigh().unwrap();
19-
let s = vecs.t().dot(&a).dot(&vecs);
20-
for i in 0..3 {
21-
e[i].assert_close(s[(i, i)], 1e-7);
20+
#[test]
21+
fn diagonalize() {
22+
let a = arr2(&[[3.0, 1.0, 1.0], [1.0, 3.0, 1.0], [1.0, 1.0, 3.0]]);
23+
let (e, vecs) = a.$clone().eigh().unwrap();
24+
let s = vecs.t().dot(&a).dot(&vecs);
25+
for i in 0..3 {
26+
e[i].assert_close(s[(i, i)], 1e-7);
27+
}
2228
}
2329
}
30+
}} // impl_test
31+
32+
impl_test!(owned, clone);
33+
impl_test!(shared, to_shared);

0 commit comments

Comments
 (0)