// Code taken from the `packed_simd` crate
// Run this code with `cargo test --example dot_product`
+ //use core::iter::zip;
+ //use std::iter::zip;
+
#![feature(array_chunks)]
+ #![feature(slice_as_chunks)]
+ // Add these imports to use the stdsimd library
+ #![feature(portable_simd)]
use core_simd::*;
- /// This is your barebones dot product implementation:
- /// Take 2 vectors, multiply them element wise and *then*
- /// add up the result. In the next example we will see if there
- /// is any difference to adding as we go along multiplying.
+ // This is your barebones dot product implementation:
+ // Take 2 vectors, multiply them element-wise and *then*
+ // go along the resulting array and add up the result.
+ // In the next example we will see if there
+ // is any difference to adding and multiplying in tandem.
pub fn dot_prod_0(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());

    a.iter()
        .zip(b.iter())
-         .map(|a, b| a * b)
+         .map(|(a, b)| a * b)
        .sum()
}
+ // When dealing with SIMD, it is very important to think about the amount
+ // of data movement and when it happens. We're going over simple computation examples here,
+ // and yet it is not trivial to understand what may or may not contribute to performance
+ // changes. Eventually, you will need tools to inspect the generated assembly, plus
+ // benchmarks, to confirm your hypotheses - we will mention them later on.
+ // With the use of `fold`, we're doing a multiplication,
+ // and then adding it to the sum, one element from both vectors at a time.
pub fn dot_prod_1(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    a.iter()
        .zip(b.iter())
-         .fold(0.0, |a, b| a * b)
+         .fold(0.0, |a, zipped| a + zipped.0 * zipped.1)
}
+ // We now move on to the SIMD implementations: notice the following constructs:
+ // `array_chunks::<4>`: mapping this over the vector will let us construct SIMD vectors
+ // `f32x4::from_array`: construct the SIMD vector from an array chunk
+ // `(a * b).reduce_sum()`: multiply both f32x4 vectors together, and then reduce them to a single scalar.
+ // This approach essentially uses SIMD to produce a vector of length N/4 of all the products,
+ // and then adds those up with `sum()`. This is suboptimal.
+ // TODO: ASCII diagrams
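+ // A rough per-chunk sketch of the data flow described above:
+ //
+ //   a chunk: [ a0  a1  a2  a3 ]     b chunk: [ b0  b1  b2  b3 ]
+ //                  |  lane-wise multiply (a * b)
+ //                  v
+ //            [ a0*b0  a1*b1  a2*b2  a3*b3 ]
+ //                  |  horizontal `reduce_sum()` per chunk
+ //                  v
+ //            a0*b0 + a1*b1 + a2*b2 + a3*b3    // one scalar per chunk, added up by `sum()`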
pub fn dot_prod_simd_0(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
-
    // TODO handle remainder when a.len() % 4 != 0
    a.array_chunks::<4>()
        .map(|&a| f32x4::from_array(a))
        .zip(b.array_chunks::<4>().map(|&b| f32x4::from_array(b)))
-         .map(|(a, b)| (a * b).horizontal_sum())
+         .map(|(a, b)| (a * b).reduce_sum())
        .sum()
}
+ // There are some simple ways to improve the previous code:
+ // 1. Make a `zero` `f32x4` SIMD vector that we will be accumulating into,
+ //    so that there is only one horizontal reduction (`reduce_sum()`), performed once the last
+ //    `f32x4` has been processed.
+ // 2. Exploit fused multiply-add so that the multiplication, addition and sinking into the reduction
+ //    happen in the same step (a sketch appears below, after `dot_prod_simd_2`).
+ // If the arrays are large, minimizing the data shuffling will lead to great performance.
+ // If the arrays are small, handling the remainder elements when the length isn't a multiple of 4
+ // can become a problem.
+ pub fn dot_prod_simd_1(a: &[f32], b: &[f32]) -> f32 {
+     assert_eq!(a.len(), b.len());
+     // TODO handle remainder when a.len() % 4 != 0
+     a.array_chunks::<4>()
+         .map(|&a| f32x4::from_array(a))
+         .zip(b.array_chunks::<4>().map(|&b| f32x4::from_array(b)))
+         .fold(f32x4::splat(0.0), |acc, zipped| acc + zipped.0 * zipped.1)
+         .reduce_sum()
+ }
+
+ // Currently this is identical to `dot_prod_simd_1`; a natural next step for it is the
+ // fused multiply-add idea described above.
+ pub fn dot_prod_simd_2(a: &[f32], b: &[f32]) -> f32 {
+     assert_eq!(a.len(), b.len());
+     // TODO handle remainder when a.len() % 4 != 0
+     a.array_chunks::<4>()
+         .map(|&a| f32x4::from_array(a))
+         .zip(b.array_chunks::<4>().map(|&b| f32x4::from_array(b)))
+         .fold(f32x4::splat(0.0), |acc, zipped| acc + zipped.0 * zipped.1)
+         .reduce_sum()
+ }
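+
+ // A sketch of improvement 2 from above: fuse the multiply and the add into one operation.
+ // This assumes a fused multiply-add method such as `mul_add` (in std's portable SIMD it is
+ // provided by the `StdFloat` trait) is in scope for `f32x4`; the function name is arbitrary.
+ pub fn dot_prod_simd_fma_sketch(a: &[f32], b: &[f32]) -> f32 {
+     assert_eq!(a.len(), b.len());
+     // TODO handle remainder when a.len() % 4 != 0
+     a.array_chunks::<4>()
+         .map(|&a| f32x4::from_array(a))
+         .zip(b.array_chunks::<4>().map(|&b| f32x4::from_array(b)))
+         // `a.mul_add(b, acc)` computes `a * b + acc` as a single fused operation per lane.
+         .fold(f32x4::splat(0.0), |acc, (a, b)| a.mul_add(b, acc))
+         .reduce_sum()
+ }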
+
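+ // `dot_prod_simd_3` finally handles the remainder: `as_rchunks` splits each slice into a
+ // short prefix (`*_extra`, at most LANES - 1 elements) plus aligned `[f32; LANES]` chunks.
+ // The prefix products seed the accumulator lanes, every chunk is then accumulated lane-wise,
+ // and a single `reduce_sum()` at the end collapses the partial sums.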
+ const LANES: usize = 4;
+ pub fn dot_prod_simd_3(a: &[f32], b: &[f32]) -> f32 {
+     assert_eq!(a.len(), b.len());
+
+     let (a_extra, a_chunks) = a.as_rchunks();
+     let (b_extra, b_chunks) = b.as_rchunks();
+
+     // These are always true, but for emphasis:
+     assert_eq!(a_chunks.len(), b_chunks.len());
+     assert_eq!(a_extra.len(), b_extra.len());
+
+     let mut sums = [0.0; LANES];
+     for ((x, y), d) in std::iter::zip(a_extra, b_extra).zip(&mut sums) {
+         *d = x * y;
+     }
+
+     let mut sums = f32x4::from_array(sums);
+     std::iter::zip(a_chunks, b_chunks).for_each(|(x, y)| {
+         sums += f32x4::from_array(*x) * f32x4::from_array(*y);
+     });
+
+     sums.reduce_sum()
+ }
fn main() {
    // Empty main to make cargo happy
}
@@ -45,10 +118,15 @@ mod tests {
        use super::*;
        let a: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b: Vec<f32> = vec![-8.0, -7.0, -6.0, -5.0, 4.0, 3.0, 2.0, 1.0];
+         let x: Vec<f32> = [0.5; 1003].to_vec();
+         let y: Vec<f32> = [2.0; 1003].to_vec();

        assert_eq!(0.0, dot_prod_0(&a, &b));
        assert_eq!(0.0, dot_prod_1(&a, &b));
        assert_eq!(0.0, dot_prod_simd_0(&a, &b));
        assert_eq!(0.0, dot_prod_simd_1(&a, &b));
+         assert_eq!(0.0, dot_prod_simd_2(&a, &b));
+         assert_eq!(0.0, dot_prod_simd_3(&a, &b));
+         assert_eq!(1003.0, dot_prod_simd_3(&x, &y));
    }
}