Skip to content

Commit d4153e4

Browse files
committed
add remainder dot_product
1 parent 42b95ad commit d4153e4

File tree

1 file changed

+86
-8
lines changed

1 file changed

+86
-8
lines changed
Lines changed: 86 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,112 @@
11
// Code taken from the `packed_simd` crate
22
// Run this code with `cargo test --example dot_product`
3+
//use core::iter::zip;
4+
//use std::iter::zip;
5+
36
#![feature(array_chunks)]
7+
#![feature(slice_as_chunks)]
8+
// Add these imports to use the stdsimd library
9+
#![feature(portable_simd)]
410
use core_simd::*;
511

// This is your barebones dot product implementation:
// Take 2 vectors, multiply them element wise and *then*
// go along the resulting array and add up the result.
// In the next example we will see if there
// is any difference to adding and multiplying in tandem.
pub fn dot_prod_0(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());

    // Lazily pair the elements up, multiply each pair, then sum the products.
    let products = a.iter().zip(b).map(|(x, y)| x * y);
    products.sum()
}
1825

// When dealing with SIMD, it is very important to think about the amount
// of data movement and when it happens. We're going over simple computation examples here, and yet
// it is not trivial to understand what may or may not contribute to performance
// changes. Eventually, you will need tools to inspect the generated assembly and confirm your
// hypothesis and benchmarks - we will mention them later on.
// With the use of `fold`, we're doing a multiplication,
// and then adding it to the sum, one element from both vectors at a time.
pub fn dot_prod_1(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    // Destructure each zipped pair directly in the closure pattern and
    // accumulate the running sum as we go.
    a.iter()
        .zip(b)
        .fold(0.0, |running, (x, y)| running + x * y)
}
2539

40+
// We now move on to the SIMD implementations: notice the following constructs:
41+
// `array_chunks::<4>`: mapping this over the vector will let use construct SIMD vectors
42+
// `f32x4::from_array`: construct the SIMD vector from a slice
43+
// `(a * b).reduce_sum()`: Multiply both f32x4 vectors together, and then reduce them.
44+
// This approach essentially uses SIMD to produce a vector of length N/4 of all the products,
45+
// and then add those with `sum()`. This is suboptimal.
46+
// TODO: ASCII diagrams
2647
pub fn dot_prod_simd_0(a: &[f32], b: &[f32]) -> f32 {
2748
assert_eq!(a.len(), b.len());
28-
2949
// TODO handle remainder when a.len() % 4 != 0
3050
a.array_chunks::<4>()
3151
.map(|&a| f32x4::from_array(a))
3252
.zip(b.array_chunks::<4>().map(|&b| f32x4::from_array(b)))
33-
.map(|(a, b)| (a * b).horizontal_sum())
53+
.map(|(a, b)| (a * b).reduce_sum())
3454
.sum()
3555
}
3656

57+
// There's some simple ways to improve the previous code:
58+
// 1. Make a `zero` `f32x4` SIMD vector that we will be accumulating into
59+
// So that there is only one `sum()` reduction when the last `f32x4` has been processed
60+
// 2. Exploit Fused Multiply Add so that the multiplication, addition and sinking into the reduciton
61+
// happen in the same step.
62+
// If the arrays are large, minimizing the data shuffling will lead to great perf.
63+
// If the arrays are small, handling the remainder elements when the length isn't a multiple of 4
64+
// Can become a problem.
65+
pub fn dot_prod_simd_1(a: &[f32], b: &[f32]) -> f32 {
66+
assert_eq!(a.len(), b.len());
67+
// TODO handle remainder when a.len() % 4 != 0
68+
a.array_chunks::<4>()
69+
.map(|&a| f32x4::from_array(a))
70+
.zip(b.array_chunks::<4>().map(|&b| f32x4::from_array(b)))
71+
.fold(f32x4::splat(0.0), |acc, zipped| {acc + zipped.0 * zipped.1})
72+
.reduce_sum()
73+
}
74+
75+
//
76+
pub fn dot_prod_simd_2(a: &[f32], b: &[f32]) -> f32 {
77+
assert_eq!(a.len(), b.len());
78+
// TODO handle remainder when a.len() % 4 != 0
79+
a.array_chunks::<4>()
80+
.map(|&a| f32x4::from_array(a))
81+
.zip(b.array_chunks::<4>().map(|&b| f32x4::from_array(b)))
82+
.fold(f32x4::splat(0.0), |acc, zipped| {acc + zipped.0 * zipped.1})
83+
.reduce_sum()
84+
}
85+
86+
const LANES: usize = 4;
87+
pub fn dot_prod_simd_3(a: &[f32], b: &[f32]) -> f32 {
88+
assert_eq!(a.len(), b.len());
89+
90+
let (a_extra, a_chunks) = a.as_rchunks();
91+
let (b_extra, b_chunks) = b.as_rchunks();
92+
93+
// These are always true, but for emphasis:
94+
assert_eq!(a_chunks.len(), b_chunks.len());
95+
assert_eq!(a_extra.len(), b_extra.len());
96+
97+
let mut sums = [0.0; LANES];
98+
for ((x, y), d) in std::iter::zip(a_extra, b_extra).zip(&mut sums) {
99+
*d = x * y;
100+
}
101+
102+
let mut sums = f32x4::from_array(sums);
103+
std::iter::zip(a_chunks, b_chunks)
104+
.for_each(|(x, y)| {
105+
sums += f32x4::from_array(*x) * f32x4::from_array(*y);
106+
});
107+
108+
sums.reduce_sum()
109+
}
37110
/// Examples need a binary entry point; the interesting code is in the
/// `dot_prod_*` functions above and is exercised via
/// `cargo test --example dot_product`.
fn main() {
    // Empty main to make cargo happy
}
@@ -45,10 +118,15 @@ mod tests {
45118
use super::*;
46119
let a: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
47120
let b: Vec<f32> = vec![-8.0, -7.0, -6.0, -5.0, 4.0, 3.0, 2.0, 1.0];
121+
let x: Vec<f32> = [0.5; 1003].to_vec();
122+
let y: Vec<f32> = [2.0; 1003].to_vec();
48123

49124
assert_eq!(0.0, dot_prod_0(&a, &b));
50125
assert_eq!(0.0, dot_prod_1(&a, &b));
51126
assert_eq!(0.0, dot_prod_simd_0(&a, &b));
52127
assert_eq!(0.0, dot_prod_simd_1(&a, &b));
128+
assert_eq!(0.0, dot_prod_simd_2(&a, &b));
129+
assert_eq!(0.0, dot_prod_simd_3(&a, &b));
130+
assert_eq!(1003.0, dot_prod_simd_3(&x, &y));
53131
}
54132
}

0 commit comments

Comments
 (0)