|
1 |
| -include!("header.rs"); |
2 | 1 |
|
3 |
| -macro_rules! impl_test { |
4 |
| - ($modname:ident, $random:path) => { |
5 |
| -mod $modname { |
6 |
| - use ndarray::prelude::*; |
7 |
| - use ndarray_linalg::prelude::*; |
8 |
| - use ndarray_rand::RandomExt; |
9 |
| - use rand_extra::*; |
10 |
| - #[test] |
11 |
| - fn solve_upper() { |
12 |
| - let r_dist = RealNormal::new(0.0, 1.0); |
13 |
| - let a = drop_lower($random((3, 3), r_dist)); |
14 |
| - println!("a = \n{:?}", &a); |
15 |
| - let b = $random(3, r_dist); |
16 |
| - println!("b = \n{:?}", &b); |
17 |
| - let x = a.solve_upper(b.clone()).unwrap(); |
18 |
| - println!("x = \n{:?}", &x); |
19 |
| - println!("Ax = \n{:?}", a.dot(&x)); |
20 |
| - assert_close_l2!(&a.dot(&x), &b, 1e-7); |
21 |
| - } |
| 2 | +extern crate ndarray; |
| 3 | +#[macro_use] |
| 4 | +extern crate ndarray_linalg; |
22 | 5 |
|
23 |
| - #[test] |
24 |
| - fn solve_upper_t() { |
25 |
| - let r_dist = RealNormal::new(0., 1.); |
26 |
| - let a = drop_lower($random((3, 3), r_dist).reversed_axes()); |
27 |
| - println!("a = \n{:?}", &a); |
28 |
| - let b = $random(3, r_dist); |
29 |
| - println!("b = \n{:?}", &b); |
30 |
| - let x = a.solve_upper(b.clone()).unwrap(); |
31 |
| - println!("x = \n{:?}", &x); |
32 |
| - println!("Ax = \n{:?}", a.dot(&x)); |
33 |
| - assert_close_l2!(&a.dot(&x), &b, 1e-7); |
34 |
| - } |
| 6 | +use ndarray::*; |
| 7 | +use ndarray_linalg::prelude::*; |
35 | 8 |
|
36 |
| - #[test] |
37 |
| - fn solve_lower() { |
38 |
| - let r_dist = RealNormal::new(0., 1.); |
39 |
| - let a = drop_upper($random((3, 3), r_dist)); |
40 |
| - println!("a = \n{:?}", &a); |
41 |
| - let b = $random(3, r_dist); |
42 |
| - println!("b = \n{:?}", &b); |
43 |
| - let x = a.solve_lower(b.clone()).unwrap(); |
44 |
| - println!("x = \n{:?}", &x); |
45 |
| - println!("Ax = \n{:?}", a.dot(&x)); |
46 |
| - assert_close_l2!(&a.dot(&x), &b, 1e-7); |
47 |
| - } |
48 |
| - |
49 |
| - #[test] |
50 |
| - fn solve_lower_t() { |
51 |
| - let r_dist = RealNormal::new(0., 1.); |
52 |
| - let a = drop_upper($random((3, 3), r_dist).reversed_axes()); |
53 |
| - println!("a = \n{:?}", &a); |
54 |
| - let b = $random(3, r_dist); |
55 |
| - println!("b = \n{:?}", &b); |
56 |
| - let x = a.solve_lower(b.clone()).unwrap(); |
57 |
| - println!("x = \n{:?}", &x); |
58 |
| - println!("Ax = \n{:?}", a.dot(&x)); |
59 |
| - assert_close_l2!(&a.dot(&x), &b, 1e-7); |
60 |
| - } |
| 9 | +fn test1d<A, Sa, Sb, Tol>(uplo: UPLO, a: ArrayBase<Sa, Ix2>, b: ArrayBase<Sb, Ix1>, tol: Tol) |
| 10 | + where A: Field + Absolute<Output = Tol>, |
| 11 | + Sa: Data<Elem = A>, |
| 12 | + Sb: DataMut<Elem = A> + DataClone, |
| 13 | + Tol: RealField |
| 14 | +{ |
| 15 | + println!("a = {:?}", &a); |
| 16 | + println!("b = {:?}", &b); |
| 17 | + let ans = b.clone(); |
| 18 | + let x = a.solve_triangular(uplo, Diag::NonUnit, b).unwrap(); |
| 19 | + let b_ = a.dot(&x); |
| 20 | + assert_close_l2!(&b_, &ans, tol); |
61 | 21 | }
|
62 |
| -}} // impl_test_opnorm |
63 |
| - |
64 |
| -impl_test!(owned, Array<f64, _>::random); |
65 |
| -impl_test!(shared, RcArray<f64, _>::random); |
66 |
| - |
67 |
| -macro_rules! impl_test_2d { |
68 |
| - ($modname:ident, $drop:path, $solve:ident) => { |
69 |
| -mod $modname { |
70 |
| - use super::random_owned; |
71 |
| - use ndarray_linalg::prelude::*; |
72 |
| - #[test] |
73 |
| - fn solve_tt() { |
74 |
| - let a = $drop(random_owned(3, 3, true)); |
75 |
| - println!("a = \n{:?}", &a); |
76 |
| - let b = random_owned(3, 2, true); |
77 |
| - println!("b = \n{:?}", &b); |
78 |
| - let x = a.$solve(&b).unwrap(); |
79 |
| - println!("x = \n{:?}", &x); |
80 |
| - println!("Ax = \n{:?}", a.dot(&x)); |
81 |
| - assert_close_l2!(&a.dot(&x), &b, 1e-7); |
82 |
| - } |
83 |
| - #[test] |
84 |
| - fn solve_tf() { |
85 |
| - let a = $drop(random_owned(3, 3, true)); |
86 |
| - println!("a = \n{:?}", &a); |
87 |
| - let b = random_owned(3, 2, false); |
88 |
| - println!("b = \n{:?}", &b); |
89 |
| - let x = a.$solve(&b).unwrap(); |
90 |
| - println!("x = \n{:?}", &x); |
91 |
| - println!("Ax = \n{:?}", a.dot(&x)); |
92 |
| - assert_close_l2!(&a.dot(&x), &b, 1e-7); |
93 |
| - } |
94 |
| - #[test] |
95 |
| - fn solve_ft() { |
96 |
| - let a = $drop(random_owned(3, 3, false)); |
97 |
| - println!("a = \n{:?}", &a); |
98 |
| - let b = random_owned(3, 2, true); |
99 |
| - println!("b = \n{:?}", &b); |
100 |
| - let x = a.$solve(&b).unwrap(); |
101 |
| - println!("x = \n{:?}", &x); |
102 |
| - println!("Ax = \n{:?}", a.dot(&x)); |
103 |
| - assert_close_l2!(&a.dot(&x), &b, 1e-7); |
104 |
| - } |
105 |
| - #[test] |
106 |
| - fn solve_ff() { |
107 |
| - let a = $drop(random_owned(3, 3, false)); |
108 |
| - println!("a = \n{:?}", &a); |
109 |
| - let b = random_owned(3, 2, false); |
110 |
| - println!("b = \n{:?}", &b); |
111 |
| - let x = a.$solve(&b).unwrap(); |
112 |
| - println!("x = \n{:?}", &x); |
113 |
| - println!("Ax = \n{:?}", a.dot(&x)); |
114 |
| - assert_close_l2!(&a.dot(&x), &b, 1e-7); |
115 |
| - } |
116 |
| -} |
117 |
| -}} // impl_test_2d |
118 |
| - |
119 |
| -impl_test_2d!(lower2d, drop_upper, solve_lower); |
120 |
| -impl_test_2d!(upper2d, drop_lower, solve_upper); |
0 commit comments