Skip to content

Commit 8bea362

Browse files
committed
replace sum() with horizontal_sum()
1 parent ab6af37 commit 8bea362

File tree

1 file changed

+27
-29
lines changed

1 file changed

+27
-29
lines changed

crates/core_simd/examples/nbody.rs

Lines changed: 27 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,87 +5,85 @@ use std::f64::consts::PI;
55
const SOLAR_MASS: f64 = 4.0 * PI * PI;
66
const DAYS_PER_YEAR: f64 = 365.24;
77

8-
#[derive(Debug)]
9-
#[repr(simd)]
8+
#[derive(Debug, Clone, Copy)]
109
pub struct Body {
11-
pub x: Simdf64([f64, 4]),
12-
pub v: Simdf64([f64, 4]),
10+
pub x: f64x4,
11+
pub v: f64x4,
1312
pub mass: f64,
1413
}
1514

16-
// Translation attempt is this ^^^ far
17-
//
15+
// translation up to here
1816
const N_BODIES: usize = 5;
1917
#[allow(clippy::unreadable_literal)]
2018
const BODIES: [Body; N_BODIES] = [
2119
// sun:
2220
Body {
23-
x: f64x4::new(0., 0., 0., 0.),
24-
v: f64x4::new(0., 0., 0., 0.),
21+
x: f64x4::from_array([0., 0., 0., 0.]),
22+
v: f64x4::from_array([0., 0., 0., 0.]),
2523
mass: SOLAR_MASS,
2624
},
2725
// jupiter:
2826
Body {
29-
x: f64x4::new(
27+
x: f64x4::from_array([
3028
4.84143144246472090e+00,
3129
-1.16032004402742839e+00,
3230
-1.03622044471123109e-01,
3331
0.,
34-
),
35-
v: f64x4::new(
32+
]),
33+
v: f64x4::from_array([
3634
1.66007664274403694e-03 * DAYS_PER_YEAR,
3735
7.69901118419740425e-03 * DAYS_PER_YEAR,
3836
-6.90460016972063023e-05 * DAYS_PER_YEAR,
3937
0.,
40-
),
38+
]),
4139
mass: 9.54791938424326609e-04 * SOLAR_MASS,
4240
},
4341
// saturn:
4442
Body {
45-
x: f64x4::new(
43+
x: f64x4::from_array([
4644
8.34336671824457987e+00,
4745
4.12479856412430479e+00,
4846
-4.03523417114321381e-01,
4947
0.,
50-
),
51-
v: f64x4::new(
48+
]),
49+
v: f64x4::from_array([
5250
-2.76742510726862411e-03 * DAYS_PER_YEAR,
5351
4.99852801234917238e-03 * DAYS_PER_YEAR,
5452
2.30417297573763929e-05 * DAYS_PER_YEAR,
5553
0.,
56-
),
54+
]),
5755
mass: 2.85885980666130812e-04 * SOLAR_MASS,
5856
},
5957
// uranus:
6058
Body {
61-
x: f64x4::new(
59+
x: f64x4::from_array([
6260
1.28943695621391310e+01,
6361
-1.51111514016986312e+01,
6462
-2.23307578892655734e-01,
6563
0.,
66-
),
67-
v: f64x4::new(
64+
]),
65+
v: f64x4::from_array([
6866
2.96460137564761618e-03 * DAYS_PER_YEAR,
6967
2.37847173959480950e-03 * DAYS_PER_YEAR,
7068
-2.96589568540237556e-05 * DAYS_PER_YEAR,
7169
0.,
72-
),
70+
]),
7371
mass: 4.36624404335156298e-05 * SOLAR_MASS,
7472
},
7573
// neptune:
7674
Body {
77-
x: f64x4::new(
75+
x: f64x4::from_array([
7876
1.53796971148509165e+01,
7977
-2.59193146099879641e+01,
8078
1.79258772950371181e-01,
8179
0.,
82-
),
83-
v: f64x4::new(
80+
]),
81+
v: f64x4::from_array([
8482
2.68067772490389322e-03 * DAYS_PER_YEAR,
8583
1.62824170038242295e-03 * DAYS_PER_YEAR,
8684
-9.51592254519715870e-05 * DAYS_PER_YEAR,
8785
0.,
88-
),
86+
]),
8987
mass: 5.15138902046611451e-05 * SOLAR_MASS,
9088
},
9189
];
@@ -103,10 +101,10 @@ pub fn energy(bodies: &[Body; N_BODIES]) -> f64 {
103101
let mut e = 0.;
104102
for i in 0..N_BODIES {
105103
let bi = &bodies[i];
106-
e += bi.mass * (bi.v * bi.v).sum() * 0.5;
104+
e += bi.mass * (bi.v * bi.v).horizontal_sum() * 0.5;
107105
for bj in bodies.iter().take(N_BODIES).skip(i + 1) {
108106
let dx = bi.x - bj.x;
109-
e -= bi.mass * bj.mass / (dx * dx).sum().sqrt()
107+
e -= bi.mass * bj.mass / (dx * dx).horizontal_sum().sqrt()
110108
}
111109
}
112110
e
@@ -130,8 +128,8 @@ pub fn advance(bodies: &mut [Body; N_BODIES], dt: f64) {
130128
let mut mag = [0.0; N];
131129
let mut i = 0;
132130
while i < N {
133-
let d2s = f64x2::new((r[i] * r[i]).sum(), (r[i + 1] * r[i + 1]).sum());
134-
let dmags = f64x2::splat(dt) / (d2s * d2s.sqrte());
131+
let d2s = f64x2::from_array([(r[i] * r[i]).horizontal_sum(), (r[i + 1] * r[i + 1]).horizontal_sum()]);
132+
let dmags = f64x2::splat(dt) / (d2s * d2s.sqrt());
135133
dmags.write_to_slice_unaligned(&mut mag[i..]);
136134
i += 2;
137135
}
@@ -190,5 +188,5 @@ fn main() {
190188
//.parse()
191189
//.expect("argument should be a usize");
192190
//run(&mut std::io::stdout(), n, alg);
193-
println!("{:?}", run_k<10>(10, 10));
191+
println!("{:?}", run_k::<10>(10, 10));
194192
}

0 commit comments

Comments
 (0)