Skip to content

Commit 081fec3

Browse files
authored
Merge pull request #148 from rust-ndarray/inner_prod
Inner Product
2 parents b5598ec + f7ff6a0 commit 081fec3

File tree

3 files changed

+58
-0
lines changed

3 files changed

+58
-0
lines changed

src/inner.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
use crate::types::*;
2+
use ndarray::*;
3+
4+
/// Inner Product
5+
///
6+
/// Differenct from `Dot` trait, this take complex conjugate of `self` elements
7+
///
8+
pub trait InnerProduct {
9+
type Elem: Scalar;
10+
11+
/// Inner product `(self.conjugate, rhs)
12+
fn inner<S>(&self, rhs: &ArrayBase<S, Ix1>) -> Self::Elem
13+
where
14+
S: Data<Elem = Self::Elem>;
15+
}
16+
17+
impl<A, S> InnerProduct for ArrayBase<S, Ix1>
18+
where
19+
A: Scalar,
20+
S: Data<Elem = A>,
21+
{
22+
type Elem = A;
23+
fn inner<St: Data<Elem = A>>(&self, rhs: &ArrayBase<St, Ix1>) -> A {
24+
assert_eq!(self.len(), rhs.len());
25+
Zip::from(self)
26+
.and(rhs)
27+
.fold_while(A::zero(), |acc, s, r| FoldWhile::Continue(acc + s.conj() * *r))
28+
.into_inner()
29+
}
30+
}

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ pub mod diagonal;
4343
pub mod eigh;
4444
pub mod error;
4545
pub mod generate;
46+
pub mod inner;
4647
pub mod lapack;
4748
pub mod layout;
4849
pub mod norm;
@@ -62,6 +63,7 @@ pub use convert::*;
6263
pub use diagonal::*;
6364
pub use eigh::*;
6465
pub use generate::*;
66+
pub use inner::*;
6567
pub use layout::*;
6668
pub use norm::*;
6769
pub use operator::*;

tests/inner.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
use ndarray::*;
2+
use ndarray_linalg::*;
3+
4+
#[should_panic]
5+
#[test]
6+
fn size_shoter() {
7+
let a: Array1<f32> = Array::zeros(3);
8+
let b = Array::zeros(4);
9+
a.inner(&b);
10+
}
11+
12+
#[should_panic]
13+
#[test]
14+
fn size_longer() {
15+
let a: Array1<f32> = Array::zeros(3);
16+
let b = Array::zeros(4);
17+
b.inner(&a);
18+
}
19+
20+
#[test]
21+
fn abs() {
22+
let a: Array1<c32> = random(1);
23+
let aa = a.inner(&a);
24+
assert_aclose!(aa.re(), a.norm().powi(2), 1e-5);
25+
assert_aclose!(aa.im(), 0.0, 1e-5);
26+
}

0 commit comments

Comments
 (0)