diff --git a/README.md b/README.md index 66f7aee4..a50c5755 100644 --- a/README.md +++ b/README.md @@ -8,30 +8,16 @@ Linear algebra package for [rust-ndarray](https://github.com/bluss/rust-ndarray) Examples --------- +See [examples](https://github.com/termoshtt/ndarray-linalg/tree/master/examples) directory. -```rust -extern crate ndarray; -extern crate ndarray_linalg; - -use ndarray::prelude::*; -use ndarray_linalg::prelude::*; - -fn main() { - let a = arr2(&[[3.0, 1.0, 1.0], [1.0, 3.0, 1.0], [1.0, 1.0, 3.0]]); - let (e, vecs) = a.clone().eigh().unwrap(); - println!("eigenvalues = \n{:?}", e); - println!("V = \n{:?}", vecs); - let av = a.dot(&vecs); - println!("AV = \n{:?}", av); -} -``` +Versions +--------- -See complete example at [src/bin/main.rs](src/bin/main.rs). +- v0.5.0 (not released) + - **Breaking Change** Rewrite all algorithms to support complex numbers and general `ArrayBase` -Progress ---------- -Some algorithms have not been implemented yet. See [#6](https://github.com/termoshtt/ndarray-linalg/issues/6). +- v0.4.1 + - ADD: assertion [#31](https://github.com/termoshtt/ndarray-linalg/pull/31) -Similar Projects ------------------ -- [linxal](https://github.com/masonium/linxal) +- v0.4.0 + - MOD: use ndarray v0.9 diff --git a/src/assert.rs b/src/assert.rs index da9d6ffa..73d584cd 100644 --- a/src/assert.rs +++ b/src/assert.rs @@ -5,6 +5,7 @@ use ndarray::*; use super::types::*; use super::norm::*; +/// check two values are close in terms of the relative torrence pub fn rclose(test: A, truth: A, rtol: Tol) -> Result where A: Field + Absolute, Tol: RealField @@ -13,6 +14,7 @@ pub fn rclose(test: A, truth: A, rtol: Tol) -> Result if dev < rtol { Ok(dev) } else { Err(dev) } } +/// check two values are close in terms of the absolute torrence pub fn aclose(test: A, truth: A, atol: Tol) -> Result where A: Field + Absolute, Tol: RealField diff --git a/src/cholesky.rs b/src/cholesky.rs index 2a61f5eb..219319e4 100644 --- a/src/cholesky.rs +++ b/src/cholesky.rs @@ -1,3 +1,4 @@ +//! Cholesky decomposition use ndarray::*; use num_traits::Zero; @@ -6,8 +7,8 @@ use super::error::*; use super::layout::*; use super::triangular::IntoTriangular; -use impl2::LapackScalar; -pub use impl2::UPLO; +use lapack_traits::LapackScalar; +pub use lapack_traits::UPLO; pub trait Cholesky { fn cholesky(self, UPLO) -> Result; diff --git a/src/eigh.rs b/src/eigh.rs index 869bca2a..e9decd06 100644 --- a/src/eigh.rs +++ b/src/eigh.rs @@ -1,11 +1,12 @@ +//! Eigenvalue decomposition for Hermite matrices use ndarray::*; use super::error::*; use super::layout::*; -use impl2::LapackScalar; -pub use impl2::UPLO; +use lapack_traits::LapackScalar; +pub use lapack_traits::UPLO; pub trait Eigh { fn eigh(self, UPLO) -> Result<(EigVal, EigVec)>; diff --git a/src/error.rs b/src/error.rs index d98df7b3..c366cfad 100644 --- a/src/error.rs +++ b/src/error.rs @@ -6,6 +6,7 @@ use ndarray::{Ixs, ShapeError}; pub type Result = ::std::result::Result; +/// Master Error type of this crate #[derive(Debug, EnumError)] pub enum LinalgError { NotSquare(NotSquareError), @@ -15,6 +16,7 @@ pub enum LinalgError { Shape(ShapeError), } +/// Error from LAPACK #[derive(Debug, new)] pub struct LapackError { pub return_code: i32, @@ -38,6 +40,7 @@ impl From for LapackError { } } +/// Error that matrix is not square #[derive(Debug, new)] pub struct NotSquareError { pub rows: i32, @@ -56,6 +59,7 @@ impl error::Error for NotSquareError { } } +/// Error that strides of the array is not supported #[derive(Debug, new)] pub struct StrideError { pub s0: Ixs, @@ -74,6 +78,7 @@ impl error::Error for StrideError { } } +/// Error that the memory is not aligned continously #[derive(Debug, new)] pub struct MemoryContError {} diff --git a/src/generate.rs b/src/generate.rs index c1f1598d..893b164f 100644 --- a/src/generate.rs +++ b/src/generate.rs @@ -1,3 +1,4 @@ +//! Generator functions for matrices use ndarray::*; use std::ops::*; @@ -7,6 +8,7 @@ use super::layout::*; use super::types::*; use super::error::*; +/// Hermite conjugate matrix pub fn conjugate(a: &ArrayBase) -> ArrayBase where A: Conjugate, Si: Data, @@ -19,6 +21,7 @@ pub fn conjugate(a: &ArrayBase) -> ArrayBase a } +/// Generate random array pub fn random(sh: Sh) -> ArrayBase where A: RandNormal, S: DataOwned, diff --git a/src/impl2/cholesky.rs b/src/lapack_traits/cholesky.rs similarity index 94% rename from src/impl2/cholesky.rs rename to src/lapack_traits/cholesky.rs index 9b0c9ecf..61deaae5 100644 --- a/src/impl2/cholesky.rs +++ b/src/lapack_traits/cholesky.rs @@ -1,4 +1,4 @@ -//! implement Cholesky decomposition +//! Cholesky decomposition use lapack::c; diff --git a/src/impl2/eigh.rs b/src/lapack_traits/eigh.rs similarity index 89% rename from src/impl2/eigh.rs rename to src/lapack_traits/eigh.rs index a9d6c610..9e13dcb9 100644 --- a/src/impl2/eigh.rs +++ b/src/lapack_traits/eigh.rs @@ -1,3 +1,4 @@ +//! Eigenvalue decomposition for Hermite matrices use lapack::c; use num_traits::Zero; @@ -8,6 +9,7 @@ use layout::Layout; use super::{into_result, UPLO}; +/// Wraps `*syev` for real and `*heev` for complex pub trait Eigh_: AssociatedReal { fn eigh(calc_eigenvec: bool, Layout, UPLO, a: &mut [Self]) -> Result>; } diff --git a/src/impl2/mod.rs b/src/lapack_traits/mod.rs similarity index 90% rename from src/impl2/mod.rs rename to src/lapack_traits/mod.rs index efd80e7c..7a94f690 100644 --- a/src/impl2/mod.rs +++ b/src/lapack_traits/mod.rs @@ -1,3 +1,4 @@ +//! Define traits wrapping LAPACK routines pub mod opnorm; pub mod qr; @@ -33,6 +34,7 @@ pub fn into_result(info: i32, val: T) -> Result { } } +/// Upper/Lower specification for seveal usages #[derive(Debug, Clone, Copy)] #[repr(u8)] pub enum UPLO { diff --git a/src/impl2/opnorm.rs b/src/lapack_traits/opnorm.rs similarity index 96% rename from src/impl2/opnorm.rs rename to src/lapack_traits/opnorm.rs index b1efd9f7..7ebff918 100644 --- a/src/impl2/opnorm.rs +++ b/src/lapack_traits/opnorm.rs @@ -1,4 +1,4 @@ -//! Implement Operator norms for matrices +//! Operator norms of matrices use lapack::c; use lapack::c::Layout::ColumnMajor as cm; diff --git a/src/impl2/qr.rs b/src/lapack_traits/qr.rs similarity index 94% rename from src/impl2/qr.rs rename to src/lapack_traits/qr.rs index 714135d0..46c92956 100644 --- a/src/impl2/qr.rs +++ b/src/lapack_traits/qr.rs @@ -1,4 +1,4 @@ -//! Implement QR decomposition +//! QR decomposition use std::cmp::min; use num_traits::Zero; @@ -10,6 +10,7 @@ use layout::Layout; use super::into_result; +/// Wraps `*geqrf` and `*orgqr` (`*ungqr` for complex numbers) pub trait QR_: Sized { fn householder(Layout, a: &mut [Self]) -> Result>; fn q(Layout, a: &mut [Self], tau: &[Self]) -> Result<()>; diff --git a/src/impl2/solve.rs b/src/lapack_traits/solve.rs similarity index 94% rename from src/impl2/solve.rs rename to src/lapack_traits/solve.rs index 8a2df165..08261cf9 100644 --- a/src/impl2/solve.rs +++ b/src/lapack_traits/solve.rs @@ -1,3 +1,4 @@ +//! Solve linear problem using LU decomposition use lapack::c; @@ -9,6 +10,7 @@ use super::{Transpose, into_result}; pub type Pivot = Vec; +/// Wraps `*getrf`, `*getri`, and `*getrs` pub trait Solve_: Sized { fn lu(Layout, a: &mut [Self]) -> Result; fn inv(Layout, a: &mut [Self], &Pivot) -> Result<()>; diff --git a/src/impl2/svd.rs b/src/lapack_traits/svd.rs similarity index 90% rename from src/impl2/svd.rs rename to src/lapack_traits/svd.rs index 1151f8c8..c16725d3 100644 --- a/src/impl2/svd.rs +++ b/src/lapack_traits/svd.rs @@ -1,4 +1,4 @@ -//! Implement Operator norms for matrices +//! Singular-value decomposition use lapack::c; use num_traits::Zero; @@ -17,12 +17,17 @@ enum FlagSVD { No = b'N', } +/// Result of SVD pub struct SVDOutput { + /// diagonal values pub s: Vec, + /// Unitary matrix for destination space pub u: Option>, + /// Unitary matrix for departure space pub vt: Option>, } +/// Wraps `*gesvd` pub trait SVD_: AssociatedReal { fn svd(Layout, calc_u: bool, calc_vt: bool, a: &mut [Self]) -> Result>; } diff --git a/src/impl2/triangular.rs b/src/lapack_traits/triangular.rs similarity index 98% rename from src/impl2/triangular.rs rename to src/lapack_traits/triangular.rs index ed37ba8e..3ca4ab8e 100644 --- a/src/impl2/triangular.rs +++ b/src/lapack_traits/triangular.rs @@ -14,6 +14,7 @@ pub enum Diag { NonUnit = b'N', } +/// Wraps `*trtri` and `*trtrs` pub trait Triangular_: Sized { fn inv_triangular(l: Layout, UPLO, Diag, a: &mut [Self]) -> Result<()>; fn solve_triangular(al: Layout, bl: Layout, UPLO, Diag, a: &[Self], b: &mut [Self]) -> Result<()>; diff --git a/src/layout.rs b/src/layout.rs index 93783853..8e43a3b6 100644 --- a/src/layout.rs +++ b/src/layout.rs @@ -1,3 +1,4 @@ +//! Memory layout of matrices use ndarray::*; use lapack::c; diff --git a/src/lib.rs b/src/lib.rs index 67a47ef4..135de75c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,35 +1,45 @@ -//! This crate implements matrix manipulation for -//! [rust-ndarray](https://github.com/bluss/rust-ndarray) using LAPACK. +//! Linear algebra package for [rust-ndarray](https://github.com/bluss/rust-ndarray) using LAPACK via [stainless-steel/lapack](https://github.com/stainless-steel/lapack) //! -//! Basic manipulations are implemented as matrix traits, -//! [Matrix](matrix/trait.Matrix.html), [SquareMatrix](square/trait.SquareMatrix.html), -//! and [HermiteMatrix](hermite/trait.HermiteMatrix.html). +//! Linear algebra methods +//! ----------------------- +//! - [QR decomposition](qr/trait.QR.html) +//! - [singular value decomposition](svd/trait.SVD.html) +//! - [solve linear problem](solve/index.html) +//! - [solve linear problem for triangular matrix](triangular/trait.SolveTriangular.html) +//! - [inverse matrix](solve/trait.Inverse.html) +//! - [eigenvalue decomposition for Hermite matrix][eigh] //! -//! Matrix -//! ------- -//! - [singular-value decomposition](matrix/trait.Matrix.html#tymethod.svd) -//! - [LU decomposition](matrix/trait.Matrix.html#tymethod.lu) -//! - [QR decomposition](matrix/trait.Matrix.html#tymethod.qr) -//! - [operator norm for L1 norm](matrix/trait.Matrix.html#tymethod.norm_1) -//! - [operator norm for L-inf norm](matrix/trait.Matrix.html#tymethod.norm_i) -//! - [Frobeiuns norm](matrix/trait.Matrix.html#tymethod.norm_f) +//! [eigh]:eigh/trait.Eigh.html //! -//! SquareMatrix -//! ------------- -//! - [inverse of matrix](square/trait.SquareMatrix.html#tymethod.inv) -//! - [trace of matrix](square/trait.SquareMatrix.html#tymethod.trace) -//! - [WIP] eigenvalue +//! Utilities +//! ----------- +//! - [assertions for array](index.html#macros) +//! - [generator functions](generate/index.html) +//! - [Scalar trait](types/trait.Field.html) //! -//! HermiteMatrix -//! -------------- -//! - [eigenvalue analysis](hermite/trait.HermiteMatrix.html#tymethod.eigh) -//! - [symmetric square root](hermite/trait.HermiteMatrix.html#tymethod.ssqrt) -//! - [Cholesky factorization](hermite/trait.HermiteMatrix.html#tymethod.cholesky) +//! Usage +//! ------ +//! Most functions in this crate is defined as [self-consuming trait technique][sct] like [serde] +//! does. //! -//! Others -//! ------- -//! - [solve triangular](triangular/trait.SolveTriangular.html) -//! - [misc utilities](util/index.html) +//! For example, we can execute [eigh][eigh] using three types of interfaces: +//! +//! ```rust,ignore +//! let a = random((3, 3)); +//! let (eval, evec) = a.eigh(UPLO::Upper)?; +//! let (eval, evec) = (&a).eigh(UPLO::Upper)?; +//! let (eval, evec) = (&mut a).eigh(UPLO::Upper)?; +//! ``` +//! +//! The first type `a.eigh()` consumes `a`, and the memory of `a` is used for `evec`. +//! The second type `(&a).eigh()` consumes the reference (not `a` itself), +//! and the memory for `evec` is newly allocated. +//! The last one `(&mut a).eigh()` is similar to the first one; +//! It borrows `a` mutably, and rewrite it to contains `evec`. +//! In all cases, the array `eval` is newly allocated. +//! +//! [sct]:https://github.com/serde-rs/serde/releases/tag/v0.9.0 +//! [serde]:https://github.com/serde-rs/serde extern crate blas; extern crate lapack; @@ -47,7 +57,7 @@ extern crate derive_new; pub mod types; pub mod error; pub mod layout; -pub mod impl2; +pub mod lapack_traits; pub mod cholesky; pub mod eigh; diff --git a/src/norm.rs b/src/norm.rs index 63877b07..23b5a6f2 100644 --- a/src/norm.rs +++ b/src/norm.rs @@ -1,4 +1,4 @@ -//! Define trait for vectors +//! Norm of vectors use std::ops::*; use ndarray::*; diff --git a/src/opnorm.rs b/src/opnorm.rs index e2996d22..ca4fab1d 100644 --- a/src/opnorm.rs +++ b/src/opnorm.rs @@ -1,3 +1,4 @@ +//! Operator norm use ndarray::*; @@ -5,8 +6,8 @@ use super::types::*; use super::error::*; use super::layout::*; -pub use impl2::NormType; -use impl2::LapackScalar; +pub use lapack_traits::NormType; +use lapack_traits::LapackScalar; pub trait OperationNorm { type Output; diff --git a/src/qr.rs b/src/qr.rs index 032e1f72..9b2d14e8 100644 --- a/src/qr.rs +++ b/src/qr.rs @@ -1,3 +1,4 @@ +//! QR decomposition use num_traits::Zero; use ndarray::*; @@ -5,7 +6,7 @@ use ndarray::*; use super::error::*; use super::layout::*; -use impl2::LapackScalar; +use lapack_traits::LapackScalar; pub trait QR { fn qr(self) -> Result<(Q, R)>; diff --git a/src/solve.rs b/src/solve.rs index 0bf85e99..f64de129 100644 --- a/src/solve.rs +++ b/src/solve.rs @@ -1,10 +1,11 @@ +//! Solve linear problems use ndarray::*; use super::layout::*; use super::error::*; -use super::impl2::*; +use super::lapack_traits::*; -pub use impl2::{Pivot, Transpose}; +pub use lapack_traits::{Pivot, Transpose}; pub struct Factorized { pub a: ArrayBase, diff --git a/src/svd.rs b/src/svd.rs index 1aa1488c..ccb6972c 100644 --- a/src/svd.rs +++ b/src/svd.rs @@ -1,9 +1,10 @@ +//! singular-value decomposition use ndarray::*; use super::error::*; use super::layout::*; -use impl2::LapackScalar; +use lapack_traits::LapackScalar; pub trait SVD { fn svd(self, calc_u: bool, calc_vt: bool) -> Result<(Option, S, Option)>; diff --git a/src/trace.rs b/src/trace.rs index 86995161..4bde4667 100644 --- a/src/trace.rs +++ b/src/trace.rs @@ -1,3 +1,4 @@ +//! Trace calculation use ndarray::*; diff --git a/src/triangular.rs b/src/triangular.rs index 7c8002d4..3c48202d 100644 --- a/src/triangular.rs +++ b/src/triangular.rs @@ -1,13 +1,13 @@ -//! Define methods for triangular matrices +//! Methods for triangular matrices use ndarray::*; use num_traits::Zero; use super::layout::*; use super::error::*; -use super::impl2::*; +use super::lapack_traits::*; -pub use super::impl2::Diag; +pub use super::lapack_traits::Diag; /// solve a triangular system with upper triangular matrix pub trait SolveTriangular { @@ -126,15 +126,3 @@ impl IntoTriangular> for ArrayBase self } } - -pub fn drop_upper(a: ArrayBase) -> ArrayBase - where S: DataMut -{ - a.into_triangular(UPLO::Lower) -} - -pub fn drop_lower(a: ArrayBase) -> ArrayBase - where S: DataMut -{ - a.into_triangular(UPLO::Upper) -} diff --git a/src/types.rs b/src/types.rs index d2ce0759..2dd4331c 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,3 +1,4 @@ +//! Basic types and their methods for linear algebra use std::ops::*; use std::fmt::Debug; @@ -8,7 +9,7 @@ use rand::Rng; use rand::distributions::*; use ndarray::LinalgScalar; -use super::impl2::LapackScalar; +use super::lapack_traits::LapackScalar; pub use num_complex::Complex32 as c32; pub use num_complex::Complex64 as c64; @@ -35,14 +36,17 @@ trait_alias!(Field: LapackScalar, trait_alias!(RealField: Field, Float); +/// Define associating real float type pub trait AssociatedReal: Sized { type Real: Float + Mul; } + +/// Define associating complex type pub trait AssociatedComplex: Sized { type Complex; } -/// Field with norm +/// Define `abs()` more generally pub trait Absolute { type Output: RealField; fn squared(&self) -> Self::Output; @@ -51,14 +55,17 @@ pub trait Absolute { } } +/// Define `sqrt()` more generally pub trait SquareRoot { fn sqrt(&self) -> Self; } +/// Complex conjugate value pub trait Conjugate: Copy { fn conj(self) -> Self; } +/// Scalars which can be initialized from Gaussian random number pub trait RandNormal { fn randn(&mut R) -> Self; } diff --git a/tests/qr.rs b/tests/qr.rs index 5dc226a0..b61e981e 100644 --- a/tests/qr.rs +++ b/tests/qr.rs @@ -15,7 +15,7 @@ fn test(a: Array2, n: usize, m: usize) { println!("r = \n{:?}", &r); assert_close_l2!(&q.t().dot(&q), &Array::eye(min(n, m)), 1e-7); assert_close_l2!(&q.dot(&r), &ans, 1e-7); - assert_close_l2!(&drop_lower(r.clone()), &r, 1e-7); + assert_close_l2!(&r.clone().into_triangular(UPLO::Upper), &r, 1e-7); } #[test] diff --git a/tests/svd.rs b/tests/svd.rs index f0e963d1..9a221e04 100644 --- a/tests/svd.rs +++ b/tests/svd.rs @@ -2,12 +2,10 @@ extern crate ndarray; #[macro_use] extern crate ndarray_linalg; -extern crate num_traits; use std::cmp::min; use ndarray::*; use ndarray_linalg::*; -use num_traits::Float; fn test(a: Array2, n: usize, m: usize) { let answer = a.clone();