Skip to content

Commit cea4ea9

Browse files
committed
Add test for outer and bug fix
1 parent e60f49a commit cea4ea9

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

src/vector.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,11 @@ pub fn outer<A, S1, S2>(a: &ArrayBase<S1, Ix1>, b: &ArrayBase<S2, Ix1>) -> Array
6464
{
6565
let m = a.len();
6666
let n = b.len();
67-
let mut ab = Array::zeros((m, n));
67+
let mut ab = Array::zeros((n, m));
6868
ImplOuter::outer(m,
6969
n,
7070
a.as_slice_memory_order().unwrap(),
7171
b.as_slice_memory_order().unwrap(),
7272
ab.as_slice_memory_order_mut().unwrap());
73-
ab
73+
ab.reversed_axes()
7474
}

tests/outer.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
include!("header.rs");
2+
use ndarray_linalg::vector::outer;
3+
4+
#[test]
5+
fn outer_() {
6+
let dist = RealNormal::<f64>::new(0.0, 1.0);
7+
let m = 2;
8+
let n = 3;
9+
let a = Array::random(m, dist);
10+
let b = Array::random(n, dist);
11+
println!("a = \n{:?}", &a);
12+
println!("b = \n{:?}", &b);
13+
let ab = outer(&a, &b);
14+
println!("ab = \n{:?}", &ab);
15+
for i in 0..m {
16+
for j in 0..n {
17+
ab[(i, j)].assert_close(a[i] * b[j], 1e-7);
18+
}
19+
}
20+
}

0 commit comments

Comments
 (0)