Skip to content

Commit 27371c5

Browse files
committed
Introduce Field, RealField
1 parent 873272b commit 27371c5

File tree

7 files changed

+80
-35
lines changed

7 files changed

+80
-35
lines changed

src/assert.rs

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,30 @@
11
//! Assertions for array
22
3-
use std::iter::Sum;
4-
use num_traits::Float;
53
use ndarray::*;
64

75
use super::types::*;
86
use super::vector::*;
97

108
pub fn rclose<A, Tol>(test: A, truth: A, rtol: Tol) -> Result<Tol, Tol>
11-
where A: LinalgScalar + Absolute<Output = Tol>,
12-
Tol: Float
9+
where A: Field + Absolute<Output = Tol>,
10+
Tol: RealField
1311
{
1412
let dev = (test - truth).abs() / truth.abs();
1513
if dev < rtol { Ok(dev) } else { Err(dev) }
1614
}
1715

1816
pub fn aclose<A, Tol>(test: A, truth: A, atol: Tol) -> Result<Tol, Tol>
19-
where A: LinalgScalar + Absolute<Output = Tol>,
20-
Tol: Float
17+
where A: Field + Absolute<Output = Tol>,
18+
Tol: RealField
2119
{
2220
let dev = (test - truth).abs();
2321
if dev < atol { Ok(dev) } else { Err(dev) }
2422
}
2523

2624
/// check two arrays are close in maximum norm
2725
pub fn close_max<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, atol: Tol) -> Result<Tol, Tol>
28-
where A: LinalgScalar + Absolute<Output = Tol>,
29-
Tol: Float + Sum,
26+
where A: Field + Absolute<Output = Tol>,
27+
Tol: RealField,
3028
S1: Data<Elem = A>,
3129
S2: Data<Elem = A>,
3230
D: Dimension
@@ -37,8 +35,8 @@ pub fn close_max<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S
3735

3836
/// check two arrays are close in L1 norm
3937
pub fn close_l1<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: Tol) -> Result<Tol, Tol>
40-
where A: LinalgScalar + Absolute<Output = Tol>,
41-
Tol: Float + Sum,
38+
where A: Field + Absolute<Output = Tol>,
39+
Tol: RealField,
4240
S1: Data<Elem = A>,
4341
S2: Data<Elem = A>,
4442
D: Dimension
@@ -49,8 +47,8 @@ pub fn close_l1<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2
4947

5048
/// check two arrays are close in L2 norm
5149
pub fn close_l2<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: Tol) -> Result<Tol, Tol>
52-
where A: LinalgScalar + Absolute<Output = Tol>,
53-
Tol: Float + Sum,
50+
where A: Field + Absolute<Output = Tol>,
51+
Tol: RealField,
5452
S1: Data<Elem = A>,
5553
S2: Data<Elem = A>,
5654
D: Dimension

src/impl2/triangular.rs

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
use lapack::c;
44

55
use error::*;
6+
use types::*;
67
use layout::Layout;
78
use super::{UPLO, Transpose, into_result};
89

@@ -18,11 +19,14 @@ pub trait Triangular_: Sized {
1819
fn solve_triangular(al: Layout, bl: Layout, UPLO, Diag, a: &[Self], b: &mut [Self]) -> Result<()>;
1920
}
2021

21-
impl Triangular_ for f64 {
22+
macro_rules! impl_triangular {
23+
($scalar:ty, $trtri:path, $trtrs:path) => {
24+
25+
impl Triangular_ for $scalar {
2226
fn inv_triangular(l: Layout, uplo: UPLO, diag: Diag, a: &mut [Self]) -> Result<()> {
2327
let (n, _) = l.size();
2428
let lda = l.lda();
25-
let info = c::dtrtri(l.lapacke_layout(), uplo as u8, diag as u8, n, a, lda);
29+
let info = $trtri(l.lapacke_layout(), uplo as u8, diag as u8, n, a, lda);
2630
into_result(info, ())
2731
}
2832

@@ -31,16 +35,14 @@ impl Triangular_ for f64 {
3135
let lda = al.lda();
3236
let nrhs = bl.len();
3337
let ldb = bl.lda();
34-
let info = c::dtrtrs(al.lapacke_layout(),
35-
uplo as u8,
36-
Transpose::No as u8,
37-
diag as u8,
38-
n,
39-
nrhs,
40-
a,
41-
lda,
42-
&mut b,
43-
ldb);
38+
let info = $trtrs(al.lapacke_layout(), uplo as u8, Transpose::No as u8, diag as u8, n, nrhs, a, lda, &mut b, ldb);
4439
into_result(info, ())
4540
}
4641
}
42+
43+
}} // impl_triangular!
44+
45+
impl_triangular!(f64, c::dtrtri, c::dtrtrs);
46+
impl_triangular!(f32, c::strtri, c::strtrs);
47+
impl_triangular!(c64, c::ztrtri, c::ztrtrs);
48+
impl_triangular!(c32, c::ctrtri, c::ctrtrs);

src/prelude.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ pub use opnorm::*;
1313
pub use solve::*;
1414
pub use eigh::*;
1515
pub use cholesky::*;
16+
pub use impl2::LapackScalar;

src/triangular.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ use super::layout::*;
77
use super::error::*;
88
use super::impl2::*;
99

10+
pub use super::impl2::Diag;
11+
1012
/// solve a triangular system with upper triangular matrix
1113
pub trait SolveTriangular<Rhs> {
1214
type Output;

src/types.rs

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,39 @@
11

2-
pub use num_complex::Complex32 as c32;
3-
pub use num_complex::Complex64 as c64;
4-
use num_complex::Complex;
5-
use num_traits::Float;
62
use std::ops::*;
3+
use std::fmt::Debug;
4+
use std::iter::Sum;
5+
use num_complex::Complex;
6+
use num_traits::*;
77
use rand::Rng;
88
use rand::distributions::*;
9+
use ndarray::LinalgScalar;
10+
11+
use super::impl2::LapackScalar;
12+
13+
pub use num_complex::Complex32 as c32;
14+
pub use num_complex::Complex64 as c64;
15+
16+
macro_rules! trait_alias {
17+
($name:ident: $($t:ident),*) => {
18+
19+
pub trait $name : $($t +)* {}
20+
21+
impl<T> $name for T where T: $($t +)* {}
22+
23+
}} // trait_alias!
24+
25+
trait_alias!(Field: LapackScalar,
26+
LinalgScalar,
27+
AssociatedReal,
28+
AssociatedComplex,
29+
Absolute,
30+
SquareRoot,
31+
Conjugate,
32+
RandNormal,
33+
Sum,
34+
Debug);
35+
36+
trait_alias!(RealField: Field, Float);
937

1038
pub trait AssociatedReal: Sized {
1139
type Real: Float + Mul<Self, Output = Self>;
@@ -16,13 +44,17 @@ pub trait AssociatedComplex: Sized {
1644

1745
/// Field with norm
1846
pub trait Absolute {
19-
type Output: Float;
47+
type Output: RealField;
2048
fn squared(&self) -> Self::Output;
2149
fn abs(&self) -> Self::Output {
2250
self.squared().sqrt()
2351
}
2452
}
2553

54+
pub trait SquareRoot {
55+
fn sqrt(&self) -> Self;
56+
}
57+
2658
pub trait Conjugate: Copy {
2759
fn conj(self) -> Self;
2860
}
@@ -70,6 +102,18 @@ impl Absolute for $complex {
70102
}
71103
}
72104

105+
impl SquareRoot for $real {
106+
fn sqrt(&self) -> Self {
107+
Float::sqrt(*self)
108+
}
109+
}
110+
111+
impl SquareRoot for $complex {
112+
fn sqrt(&self) -> Self {
113+
Complex::sqrt(self)
114+
}
115+
}
116+
73117
impl Conjugate for $real {
74118
fn conj(self) -> Self {
75119
self

src/util.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
//! misc utilities
22
3-
use std::iter::Sum;
43
use ndarray::*;
54
use num_traits::Float;
65
use std::ops::Div;
@@ -15,9 +14,9 @@ pub enum NormalizeAxis {
1514

1615
/// normalize in L2 norm
1716
pub fn normalize<A, S, T>(mut m: ArrayBase<S, Ix2>, axis: NormalizeAxis) -> (ArrayBase<S, Ix2>, Vec<T>)
18-
where A: LinalgScalar + Absolute<Output = T> + Div<T, Output = A>,
17+
where A: Field + Absolute<Output = T> + Div<T, Output = A>,
1918
S: DataMut<Elem = A>,
20-
T: Float + Sum
19+
T: Field + Float
2120
{
2221
let mut ms = Vec::new();
2322
for mut v in m.axis_iter_mut(Axis(axis as usize)) {

src/vector.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
//! Define trait for vectors
22
3-
use std::iter::Sum;
43
use ndarray::*;
54
use num_traits::Float;
65

@@ -24,8 +23,8 @@ pub trait Norm {
2423
}
2524

2625
impl<A, S, D, T> Norm for ArrayBase<S, D>
27-
where A: LinalgScalar + Absolute<Output = T>,
28-
T: Float + Sum,
26+
where A: Field + Absolute<Output = T>,
27+
T: Field + Float,
2928
S: Data<Elem = A>,
3029
D: Dimension
3130
{

0 commit comments

Comments
 (0)