1414use std:: autodiff:: autodiff;
1515
1616#[ no_mangle]
17- #[ autodiff( d_square1, Forward , Dual , Dual ) ]
17+ // #[autodiff(d_square1, Forward, Dual, Dual)]
1818#[ autodiff( d_square2, Forward , 4 , Dualv , Dualv ) ]
1919#[ autodiff( d_square3, Forward , 4 , Dual , Dual ) ]
2020fn square ( x : & [ f32 ] , y : & mut [ f32 ] ) {
@@ -28,7 +28,7 @@ fn square(x: &[f32], y: &mut [f32]) {
2828}
2929
3030fn main ( ) {
31- let x1 = std:: hint:: black_box ( vec ! [ 0.0 , 1.0 , 2.0 , 3.0 , 4.0 ] ) ;
31+ let x1 = std:: hint:: black_box ( vec ! [ 0.0 , 1.0 , 2.0 , 3.0 ] ) ;
3232
3333 let mut dx1 = std:: hint:: black_box ( vec ! [ 1.0 ; 12 ] ) ;
3434
@@ -66,37 +66,42 @@ fn main() {
6666 let result = std:: hint:: black_box ( x1. iter ( ) . map ( |x| 2.0 * x) . collect :: < Vec < _ > > ( ) ) ;
6767
6868 // scalar.
69- d_square1 ( & x1, & z1, & mut y1, & mut dy1_1) ;
70- d_square1 ( & x1, & z2, & mut y2, & mut dy1_2) ;
71- d_square1 ( & x1, & z3, & mut y3, & mut dy1_3) ;
72- d_square1 ( & x1, & z4, & mut y4, & mut dy1_4) ;
69+ // d_square1(&x1, &z1, &mut y1, &mut dy1_1);
70+ // d_square1(&x1, &z2, &mut y2, &mut dy1_2);
71+ // d_square1(&x1, &z3, &mut y3, &mut dy1_3);
72+ // d_square1(&x1, &z4, &mut y4, &mut dy1_4);
7373
7474 // assert y1 == y2 == y3 == y4
75- for i in 0 ..5 {
76- assert_eq ! ( y1[ i] , y2[ i] ) ;
77- assert_eq ! ( y1[ i] , y3[ i] ) ;
78- assert_eq ! ( y1[ i] , y4[ i] ) ;
79- }
75+ // for i in 0..5 {
76+ // assert_eq!(y1[i], y2[i]);
77+ // assert_eq!(y1[i], y3[i]);
78+ // assert_eq!(y1[i], y4[i]);
79+ // }
8080
8181 // batch mode A)
8282 //dx1 = std::hint::black_box(vec![1.0; 12]);
8383 d_square2 ( & x1, & z5, & mut y5, & mut dy2) ;
8484
8585 // assert y1 == y2 == y3 == y4 == y5
86- for i in 0 ..5 {
87- assert_eq ! ( y1[ i] , y5[ i] ) ;
88- }
86+ // for i in 0..5 {
87+ // assert_eq!(y1[i], y5[i]);
88+ // }
8989
9090 // batch mode B)
9191 d_square3 ( & x1, & z1, & z2, & z3, & z4, & mut y6, & mut dy3_1, & mut dy3_2, & mut dy3_3, & mut dy3_4) ;
9292 for i in 0 ..5 {
93- assert_eq ! ( y1 [ i] , y6[ i] ) ;
93+ assert_eq ! ( y5 [ i] , y6[ i] ) ;
9494 }
9595
96+ dbg ! ( & dy2) ;
97+ dbg ! ( & dy3_1) ;
98+ dbg ! ( & dy3_2) ;
99+ dbg ! ( & dy3_3) ;
100+ dbg ! ( & dy3_4) ;
96101 for i in 0 ..5 {
97- assert_eq ! ( dy1_1 [ i] , dy3_1[ i] ) ;
98- assert_eq ! ( dy1_2 [ i] , dy3_2[ i] ) ;
99- assert_eq ! ( dy1_3 [ i] , dy3_3[ i] ) ;
100- assert_eq ! ( dy1_4 [ i] , dy3_4[ i] ) ;
102+ assert_eq ! ( dy2 [ 0 .. 5 ] [ i] , dy3_1[ i] ) ;
103+ assert_eq ! ( dy2 [ 5 .. 10 ] [ i] , dy3_2[ i] ) ;
104+ assert_eq ! ( dy2 [ 10 .. 15 ] [ i] , dy3_3[ i] ) ;
105+ assert_eq ! ( dy2 [ 15 .. 20 ] [ i] , dy3_4[ i] ) ;
101106 }
102107}
0 commit comments