8
8
#![ feature( portable_simd) ]
9
9
use core_simd:: * ;
10
10
11
- // This is your barebones dot product implementation:
11
+ // This is your barebones dot product implementation:
12
12
// Take 2 vectors, multiply them element wise and *then*
13
13
// go along the resulting array and add up the result.
14
14
// In the next example we will see if there
15
15
// is any difference to adding and multiplying in tandem.
16
16
/// Computes the dot product of `a` and `b` in two conceptual steps:
/// multiply the slices element by element, then sum the products.
///
/// Returns the sum of `a[i] * b[i]` over every index `i`.
///
/// # Panics
///
/// Panics if `a` and `b` do not have the same length.
pub fn dot_prod_0(a: &[f32], b: &[f32]) -> f32 {
    // Without this check `zip` would silently stop at the shorter slice.
    assert_eq!(a.len(), b.len());

    // The closure parameters are named `x`/`y` so they do not shadow the
    // slice arguments `a`/`b`.
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}
24
21
25
22
// When dealing with SIMD, it is very important to think about the amount
26
23
// of data movement and when it happens. We're going over simple computation examples here, and yet
27
24
// it is not trivial to understand what may or may not contribute to performance
28
- // changes. Eventually, you will need tools to inspect the generated assembly and confirm your
25
+ // changes. Eventually, you will need tools to inspect the generated assembly and confirm your
29
26
// hypothesis and benchmarks - we will mention them later on.
30
27
// With the use of `fold`, we're doing a multiplication,
31
28
// and then adding it to the sum, one element from both vectors at a time.
32
29
/// Computes the dot product of `a` and `b` with a single `fold`: each
/// element pair is multiplied and immediately added to the running sum.
///
/// Returns the sum of `a[i] * b[i]` over every index `i`, starting from `0.0`.
///
/// # Panics
///
/// Panics if `a` and `b` do not have the same length.
pub fn dot_prod_1(a: &[f32], b: &[f32]) -> f32 {
    // Without this check `zip` would silently stop at the shorter slice.
    assert_eq!(a.len(), b.len());

    // The accumulator is named `acc` so it does not shadow the slice `a`;
    // the zipped pair is destructured instead of using `.0` / `.1`.
    a.iter()
        .zip(b.iter())
        .fold(0.0, |acc, (x, y)| acc + x * y)
}
38
35
39
36
// We now move on to the SIMD implementations: notice the following constructs:
40
37
// `array_chunks::<4>`: mapping this over the vector will let us construct SIMD vectors
41
- // `f32x4::from_array`: construct the SIMD vector from a slice
38
+ // `f32x4::from_array`: construct the SIMD vector from a fixed-size array
42
39
// `(a * b).reduce_sum()`: Multiply both f32x4 vectors together, and then reduce them.
43
40
// This approach essentially uses SIMD to produce a vector of length N/4 of all the products,
44
41
// and then add those with `sum()`. This is suboptimal.
@@ -67,11 +64,11 @@ pub fn dot_prod_simd_1(a: &[f32], b: &[f32]) -> f32 {
67
64
a. array_chunks :: < 4 > ( )
68
65
. map ( |& a| f32x4:: from_array ( a) )
69
66
. zip ( b. array_chunks :: < 4 > ( ) . map ( |& b| f32x4:: from_array ( b) ) )
70
- . fold ( f32x4:: splat ( 0.0 ) , |acc, zipped| { acc + zipped. 0 * zipped. 1 } )
67
+ . fold ( f32x4:: splat ( 0.0 ) , |acc, zipped| acc + zipped. 0 * zipped. 1 )
71
68
. reduce_sum ( )
72
69
}
73
70
74
- // A lot of knowledgeable use of SIMD comes from knowing specific instructions that are
71
+ // A lot of knowledgeable use of SIMD comes from knowing specific instructions that are
75
72
// available - let's try to use the `mul_add` instruction, which is the fused-multiply-add we were looking for.
76
73
use std_float:: StdFloat ;
77
74
pub fn dot_prod_simd_2 ( a : & [ f32 ] , b : & [ f32 ] ) -> f32 {
@@ -81,8 +78,10 @@ pub fn dot_prod_simd_2(a: &[f32], b: &[f32]) -> f32 {
81
78
a. array_chunks :: < 4 > ( )
82
79
. map ( |& a| f32x4:: from_array ( a) )
83
80
. zip ( b. array_chunks :: < 4 > ( ) . map ( |& b| f32x4:: from_array ( b) ) )
84
- . for_each ( |( a, b) | { res = a. mul_add ( b, res) ; } ) ;
85
- res. reduce_sum ( )
81
+ . for_each ( |( a, b) | {
82
+ res = a. mul_add ( b, res) ;
83
+ } ) ;
84
+ res. reduce_sum ( )
86
85
}
87
86
88
87
// Finally, we will write the same operation but handling the loop remainder.
@@ -103,10 +102,9 @@ pub fn dot_prod_simd_3(a: &[f32], b: &[f32]) -> f32 {
103
102
}
104
103
105
104
let mut sums = f32x4:: from_array ( sums) ;
106
- std:: iter:: zip ( a_chunks, b_chunks)
107
- . for_each ( |( x, y) | {
108
- sums += f32x4:: from_array ( * x) * f32x4:: from_array ( * y) ;
109
- } ) ;
105
+ std:: iter:: zip ( a_chunks, b_chunks) . for_each ( |( x, y) | {
106
+ sums += f32x4:: from_array ( * x) * f32x4:: from_array ( * y) ;
107
+ } ) ;
110
108
111
109
sums. reduce_sum ( )
112
110
}
@@ -134,6 +132,5 @@ mod tests {
134
132
135
133
// We can handle vectors that are non-multiples of 4
136
134
assert_eq ! ( 1003.0 , dot_prod_simd_3( & x, & y) ) ;
137
-
138
135
}
139
136
}
0 commit comments