Skip to content

Commit 3e18871

Browse files
authored
Merge pull request #32 from termoshtt/arraybase
Layout module
2 parents f37e155 + b9d1f01 commit 3e18871

File tree

14 files changed

+225
-126
lines changed

14 files changed

+225
-126
lines changed

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ openblas = ["blas/openblas", "lapack/openblas"]
1515
netlib = ["blas/netlib", "lapack/netlib"]
1616

1717
[dependencies]
18+
derive-new = "0.4"
19+
enum-error-derive = "0.1"
1820
num-traits = "0.1"
1921
num-complex = "0.1"
2022
ndarray = { version = "0.9", default-features = false, features = ["blas"] }

src/error.rs

Lines changed: 22 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,18 @@ use std::error;
44
use std::fmt;
55
use ndarray::{Ixs, ShapeError};
66

7-
#[derive(Debug)]
7+
pub type Result<T> = ::std::result::Result<T, LinalgError>;
8+
9+
#[derive(Debug, EnumError)]
10+
pub enum LinalgError {
11+
NotSquare(NotSquareError),
12+
Lapack(LapackError),
13+
Stride(StrideError),
14+
MemoryCont(MemoryContError),
15+
Shape(ShapeError),
16+
}
17+
18+
#[derive(Debug, new)]
819
pub struct LapackError {
920
pub return_code: i32,
1021
}
@@ -27,10 +38,10 @@ impl From<i32> for LapackError {
2738
}
2839
}
2940

30-
#[derive(Debug)]
41+
#[derive(Debug, new)]
3142
pub struct NotSquareError {
32-
pub rows: usize,
33-
pub cols: usize,
43+
pub rows: i32,
44+
pub cols: i32,
3445
}
3546

3647
impl fmt::Display for NotSquareError {
@@ -45,7 +56,7 @@ impl error::Error for NotSquareError {
4556
}
4657
}
4758

48-
#[derive(Debug)]
59+
#[derive(Debug, new)]
4960
pub struct StrideError {
5061
pub s0: Ixs,
5162
pub s1: Ixs,
@@ -63,55 +74,17 @@ impl error::Error for StrideError {
6374
}
6475
}
6576

66-
#[derive(Debug)]
67-
pub enum LinalgError {
68-
NotSquare(NotSquareError),
69-
Lapack(LapackError),
70-
Stride(StrideError),
71-
Shape(ShapeError),
72-
}
77+
#[derive(Debug, new)]
78+
pub struct MemoryContError {}
7379

74-
impl fmt::Display for LinalgError {
80+
impl fmt::Display for MemoryContError {
7581
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
76-
match *self {
77-
LinalgError::NotSquare(ref err) => err.fmt(f),
78-
LinalgError::Lapack(ref err) => err.fmt(f),
79-
LinalgError::Stride(ref err) => err.fmt(f),
80-
LinalgError::Shape(ref err) => err.fmt(f),
81-
}
82+
write!(f, "Memory is not contiguous")
8283
}
8384
}
8485

85-
impl error::Error for LinalgError {
86+
impl error::Error for MemoryContError {
8687
fn description(&self) -> &str {
87-
match *self {
88-
LinalgError::NotSquare(ref err) => err.description(),
89-
LinalgError::Lapack(ref err) => err.description(),
90-
LinalgError::Stride(ref err) => err.description(),
91-
LinalgError::Shape(ref err) => err.description(),
92-
}
93-
}
94-
}
95-
96-
impl From<NotSquareError> for LinalgError {
97-
fn from(err: NotSquareError) -> LinalgError {
98-
LinalgError::NotSquare(err)
99-
}
100-
}
101-
102-
impl From<LapackError> for LinalgError {
103-
fn from(err: LapackError) -> LinalgError {
104-
LinalgError::Lapack(err)
105-
}
106-
}
107-
108-
impl From<StrideError> for LinalgError {
109-
fn from(err: StrideError) -> LinalgError {
110-
LinalgError::Stride(err)
111-
}
112-
}
113-
impl From<ShapeError> for LinalgError {
114-
fn from(err: ShapeError) -> LinalgError {
115-
LinalgError::Shape(err)
88+
"Memory is not contiguous"
11689
}
11790
}

src/impl2/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
pub mod opnorm;
3+
pub use self::opnorm::*;
4+
5+
pub trait LapackScalar: OperatorNorm_ {}
6+
impl<A> LapackScalar for A where A: OperatorNorm_ {}

src/impl2/opnorm.rs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
//! Implement Operator norms for matrices
2+
3+
use lapack::c;
4+
use lapack::c::Layout::ColumnMajor as cm;
5+
6+
use types::*;
7+
use layout::*;
8+
9+
#[repr(u8)]
10+
pub enum NormType {
11+
One = b'o',
12+
Infinity = b'i',
13+
Frobenius = b'f',
14+
}
15+
16+
impl NormType {
17+
fn transpose(self) -> Self {
18+
match self {
19+
NormType::One => NormType::Infinity,
20+
NormType::Infinity => NormType::One,
21+
NormType::Frobenius => NormType::Frobenius,
22+
}
23+
}
24+
}
25+
26+
pub trait OperatorNorm_: AssociatedReal {
27+
fn opnorm(NormType, Layout, &[Self]) -> Self::Real;
28+
}
29+
30+
macro_rules! impl_opnorm {
31+
($scalar:ty, $lange:path) => {
32+
impl OperatorNorm_ for $scalar {
33+
fn opnorm(t: NormType, l: Layout, a: &[Self]) -> Self::Real {
34+
match l {
35+
Layout::F((col, lda)) => $lange(cm, t as u8, lda, col, a, lda),
36+
Layout::C((row, lda)) => $lange(cm, t.transpose() as u8, lda, row, a, lda),
37+
}
38+
}
39+
}
40+
}} // impl_opnorm!
41+
42+
impl_opnorm!(f64, c::dlange);
43+
impl_opnorm!(f32, c::slange);
44+
impl_opnorm!(c64, c::zlange);
45+
impl_opnorm!(c32, c::clange);

src/impls/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,5 @@ pub mod outer;
33
pub mod qr;
44
pub mod svd;
55
pub mod eigh;
6-
pub mod opnorm;
76
pub mod solve;
87
pub mod cholesky;

src/impls/opnorm.rs

Lines changed: 0 additions & 27 deletions
This file was deleted.

src/layout.rs

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
2+
use ndarray::*;
3+
4+
use super::error::*;
5+
6+
pub type LDA = i32;
7+
pub type Col = i32;
8+
pub type Row = i32;
9+
10+
pub enum Layout {
11+
C((Row, LDA)),
12+
F((Col, LDA)),
13+
}
14+
15+
impl Layout {
16+
pub fn size(&self) -> (Row, Col) {
17+
match *self {
18+
Layout::C((row, lda)) => (row, lda),
19+
Layout::F((col, lda)) => (lda, col),
20+
}
21+
}
22+
}
23+
24+
pub trait AllocatedArray {
25+
type Scalar;
26+
fn layout(&self) -> Result<Layout>;
27+
fn square_layout(&self) -> Result<Layout>;
28+
fn as_allocated(&self) -> Result<&[Self::Scalar]>;
29+
}
30+
31+
impl<A, S> AllocatedArray for ArrayBase<S, Ix2>
32+
where S: Data<Elem = A>
33+
{
34+
type Scalar = A;
35+
36+
fn layout(&self) -> Result<Layout> {
37+
let strides = self.strides();
38+
if ::std::cmp::min(strides[0], strides[1]) != 1 {
39+
return Err(StrideError::new(strides[0], strides[1]).into());
40+
}
41+
if strides[0] > strides[1] {
42+
Ok(Layout::C((self.rows() as i32, self.cols() as i32)))
43+
} else {
44+
Ok(Layout::F((self.cols() as i32, self.rows() as i32)))
45+
}
46+
}
47+
48+
fn square_layout(&self) -> Result<Layout> {
49+
let l = self.layout()?;
50+
let (n, m) = l.size();
51+
if n == m {
52+
Ok(l)
53+
} else {
54+
Err(NotSquareError::new(n, m).into())
55+
}
56+
}
57+
58+
fn as_allocated(&self) -> Result<&[A]> {
59+
let slice = self.as_slice_memory_order().ok_or(MemoryContError::new())?;
60+
Ok(slice)
61+
}
62+
}

src/lib.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,18 @@ extern crate num_traits;
3737
extern crate num_complex;
3838
#[macro_use(s)]
3939
extern crate ndarray;
40+
#[macro_use]
41+
extern crate enum_error_derive;
42+
#[macro_use]
43+
extern crate derive_new;
4044

41-
pub mod impls;
45+
pub mod types;
4246
pub mod error;
47+
pub mod layout;
48+
pub mod impls;
49+
pub mod impl2;
50+
51+
pub mod traits;
4352

4453
pub mod vector;
4554
pub mod matrix;

src/matrix.rs

Lines changed: 2 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,10 @@ use lapack::c::Layout;
88
use super::error::{LinalgError, StrideError};
99
use super::impls::qr::ImplQR;
1010
use super::impls::svd::ImplSVD;
11-
use super::impls::opnorm::ImplOpNorm;
1211
use super::impls::solve::ImplSolve;
1312

14-
pub trait MFloat: ImplQR + ImplSVD + ImplOpNorm + ImplSolve + NdFloat {}
15-
impl<A: ImplQR + ImplSVD + ImplOpNorm + ImplSolve + NdFloat> MFloat for A {}
13+
pub trait MFloat: ImplQR + ImplSVD + ImplSolve + NdFloat {}
14+
impl<A: ImplQR + ImplSVD + ImplSolve + NdFloat> MFloat for A {}
1615

1716
/// Methods for general matrices
1817
pub trait Matrix: Sized {
@@ -23,12 +22,6 @@ pub trait Matrix: Sized {
2322
fn size(&self) -> (usize, usize);
2423
/// Layout (C/Fortran) of matrix
2524
fn layout(&self) -> Result<Layout, StrideError>;
26-
/// Operator norm for L-1 norm
27-
fn opnorm_1(&self) -> Self::Scalar;
28-
/// Operator norm for L-inf norm
29-
fn opnorm_i(&self) -> Self::Scalar;
30-
/// Frobenius norm
31-
fn opnorm_f(&self) -> Self::Scalar;
3225
/// singular-value decomposition (SVD)
3326
fn svd(self) -> Result<(Self, Self::Vector, Self), LinalgError>;
3427
/// QR decomposition
@@ -84,28 +77,6 @@ impl<A: MFloat> Matrix for Array<A, Ix2> {
8477
fn layout(&self) -> Result<Layout, StrideError> {
8578
check_layout(self.strides())
8679
}
87-
fn opnorm_1(&self) -> Self::Scalar {
88-
let (m, n) = self.size();
89-
let strides = self.strides();
90-
if strides[0] > strides[1] {
91-
ImplOpNorm::opnorm_i(n, m, self.clone().into_raw_vec())
92-
} else {
93-
ImplOpNorm::opnorm_1(m, n, self.clone().into_raw_vec())
94-
}
95-
}
96-
fn opnorm_i(&self) -> Self::Scalar {
97-
let (m, n) = self.size();
98-
let strides = self.strides();
99-
if strides[0] > strides[1] {
100-
ImplOpNorm::opnorm_1(n, m, self.clone().into_raw_vec())
101-
} else {
102-
ImplOpNorm::opnorm_i(m, n, self.clone().into_raw_vec())
103-
}
104-
}
105-
fn opnorm_f(&self) -> Self::Scalar {
106-
let (m, n) = self.size();
107-
ImplOpNorm::opnorm_f(m, n, self.clone().into_raw_vec())
108-
}
10980
fn svd(self) -> Result<(Self, Self::Vector, Self), LinalgError> {
11081
let (n, m) = self.size();
11182
let layout = self.layout()?;
@@ -192,18 +163,6 @@ impl<A: MFloat> Matrix for RcArray<A, Ix2> {
192163
fn layout(&self) -> Result<Layout, StrideError> {
193164
check_layout(self.strides())
194165
}
195-
fn opnorm_1(&self) -> Self::Scalar {
196-
// XXX unnecessary clone
197-
self.to_owned().opnorm_1()
198-
}
199-
fn opnorm_i(&self) -> Self::Scalar {
200-
// XXX unnecessary clone
201-
self.to_owned().opnorm_i()
202-
}
203-
fn opnorm_f(&self) -> Self::Scalar {
204-
// XXX unnecessary clone
205-
self.to_owned().opnorm_f()
206-
}
207166
fn svd(self) -> Result<(Self, Self::Vector, Self), LinalgError> {
208167
let (u, s, v) = self.into_owned().svd()?;
209168
Ok((u.into_shared(), s.into_shared(), v.into_shared()))

src/prelude.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ pub use hermite::HermiteMatrix;
55
pub use triangular::*;
66
pub use util::*;
77
pub use assert::*;
8+
pub use traits::*;

0 commit comments

Comments
 (0)