|
1 | 1 | // Code taken from the `packed_simd` crate
|
2 | 2 | // Run this code with `cargo test --example dot_product`
|
3 |
| -//use core::iter::zip; |
4 | 3 | //use std::iter::zip;
|
5 | 4 |
|
6 | 5 | #![feature(array_chunks)]
|
@@ -72,17 +71,21 @@ pub fn dot_prod_simd_1(a: &[f32], b: &[f32]) -> f32 {
|
72 | 71 | .reduce_sum()
|
73 | 72 | }
|
74 | 73 |
|
75 |
| -// |
| 74 | +// A lot of knowledgeable use of SIMD comes from knowing specific instructions that are |
| 75 | +// available - let's try to use the `mul_add` instruction, which is the fused-multiply-add we were looking for. |
| 76 | +use std_float::StdFloat; |
76 | 77 | pub fn dot_prod_simd_2(a: &[f32], b: &[f32]) -> f32 {
|
77 | 78 | assert_eq!(a.len(), b.len());
|
78 | 79 | // TODO handle remainder when a.len() % 4 != 0
|
| 80 | + let mut res = f32x4::splat(0.0); |
79 | 81 | a.array_chunks::<4>()
|
80 | 82 | .map(|&a| f32x4::from_array(a))
|
81 | 83 | .zip(b.array_chunks::<4>().map(|&b| f32x4::from_array(b)))
|
82 |
| - .fold(f32x4::splat(0.0), |acc, zipped| {acc + zipped.0 * zipped.1}) |
83 |
| - .reduce_sum() |
| 84 | + .for_each(|(a,b)| { res = a.mul_add(b, res); }); |
| 85 | + res.reduce_sum() |
84 | 86 | }
|
85 | 87 |
|
| 88 | +// Finally, we will write the same operation but handling the loop remainder. |
86 | 89 | const LANES: usize = 4;
|
87 | 90 | pub fn dot_prod_simd_3(a: &[f32], b: &[f32]) -> f32 {
|
88 | 91 | assert_eq!(a.len(), b.len());
|
@@ -121,12 +124,16 @@ mod tests {
|
121 | 124 | let x: Vec<f32> = [0.5; 1003].to_vec();
|
122 | 125 | let y: Vec<f32> = [2.0; 1003].to_vec();
|
123 | 126 |
|
| 127 | + // Basic check |
124 | 128 | assert_eq!(0.0, dot_prod_0(&a, &b));
|
125 | 129 | assert_eq!(0.0, dot_prod_1(&a, &b));
|
126 | 130 | assert_eq!(0.0, dot_prod_simd_0(&a, &b));
|
127 | 131 | assert_eq!(0.0, dot_prod_simd_1(&a, &b));
|
128 | 132 | assert_eq!(0.0, dot_prod_simd_2(&a, &b));
|
129 | 133 | assert_eq!(0.0, dot_prod_simd_3(&a, &b));
|
| 134 | + |
| 135 | + // We can handle vectors that are non-multiples of 4 |
130 | 136 | assert_eq!(1003.0, dot_prod_simd_3(&x, &y));
|
| 137 | + |
131 | 138 | }
|
132 | 139 | }
|
0 commit comments