Skip to content

Commit 4b93386

Browse files
committed
cleanup dot_product and README.md
1 parent d4153e4 commit 4b93386

File tree

2 files changed

+12
-11
lines changed

2 files changed

+12
-11
lines changed

crates/core_simd/examples/README.md

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,4 @@ Run the tests with the command
1010
cargo run --example dot_product
1111
```
1212

13-
and the benchmarks via the command
14-
15-
```
16-
cargo run --example --benchmark ???
17-
```
18-
19-
and measure the timings on your local system.
13+
and verify the code for `dot_product.rs` on your machine.

crates/core_simd/examples/dot_product.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
// Code taken from the `packed_simd` crate
22
// Run this code with `cargo test --example dot_product`
3-
//use core::iter::zip;
43
//use std::iter::zip;
54

65
#![feature(array_chunks)]
@@ -72,17 +71,21 @@ pub fn dot_prod_simd_1(a: &[f32], b: &[f32]) -> f32 {
7271
.reduce_sum()
7372
}
7473

75-
//
74+
// Using SIMD well often comes down to knowing which specific instructions are
75+
// available. Let's try `mul_add`, which performs the fused multiply-add we were looking for.
76+
use std_float::StdFloat;
7677
pub fn dot_prod_simd_2(a: &[f32], b: &[f32]) -> f32 {
7778
assert_eq!(a.len(), b.len());
7879
// TODO handle remainder when a.len() % 4 != 0
80+
let mut res = f32x4::splat(0.0);
7981
a.array_chunks::<4>()
8082
.map(|&a| f32x4::from_array(a))
8183
.zip(b.array_chunks::<4>().map(|&b| f32x4::from_array(b)))
82-
.fold(f32x4::splat(0.0), |acc, zipped| {acc + zipped.0 * zipped.1})
83-
.reduce_sum()
84+
.for_each(|(a,b)| { res = a.mul_add(b, res); });
85+
res.reduce_sum()
8486
}
8587

88+
// Finally, we write the same operation once more, this time handling the loop remainder.
8689
const LANES: usize = 4;
8790
pub fn dot_prod_simd_3(a: &[f32], b: &[f32]) -> f32 {
8891
assert_eq!(a.len(), b.len());
@@ -121,12 +124,16 @@ mod tests {
121124
let x: Vec<f32> = [0.5; 1003].to_vec();
122125
let y: Vec<f32> = [2.0; 1003].to_vec();
123126

127+
// Basic check
124128
assert_eq!(0.0, dot_prod_0(&a, &b));
125129
assert_eq!(0.0, dot_prod_1(&a, &b));
126130
assert_eq!(0.0, dot_prod_simd_0(&a, &b));
127131
assert_eq!(0.0, dot_prod_simd_1(&a, &b));
128132
assert_eq!(0.0, dot_prod_simd_2(&a, &b));
129133
assert_eq!(0.0, dot_prod_simd_3(&a, &b));
134+
135+
// We can handle vectors whose lengths are not multiples of 4
130136
assert_eq!(1003.0, dot_prod_simd_3(&x, &y));
137+
131138
}
132139
}

0 commit comments

Comments
 (0)