Skip to content

Commit 311a631

Browse files
committed
add f32x3 dot and normalized
1 parent 7484403 commit 311a631

File tree

2 files changed

+62
-0
lines changed

2 files changed

+62
-0
lines changed

cidre/src/simd.rs

Lines changed: 57 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -251,6 +251,16 @@ impl f32x3 {
251251
Self::load(&[val; 3])
252252
}
253253

254+
#[inline]
255+
pub fn dot(&self, other: &Self) -> f32 {
256+
unsafe {
257+
let mul = std::arch::aarch64::vmulq_f32(self.0, other.0);
258+
std::arch::aarch64::vgetq_lane_f32::<0>(mul)
259+
+ std::arch::aarch64::vgetq_lane_f32::<1>(mul)
260+
+ std::arch::aarch64::vgetq_lane_f32::<2>(mul)
261+
}
262+
}
263+
254264
pub fn to_bits(&self) -> u128 {
255265
unsafe { std::mem::transmute(*self) }
256266
}
@@ -1009,6 +1019,36 @@ impl SimdNormalized for f32x4 {
10091019
}
10101020
}
10111021

1022+
impl SimdNormalized for f32x3 {
1023+
type Output = f32x3;
1024+
1025+
#[cfg(target_arch = "aarch64")]
1026+
fn simd_normalized(self) -> Self::Output {
1027+
let len_sq = self.dot(&self);
1028+
if len_sq == 0.0 {
1029+
return self;
1030+
}
1031+
1032+
let inv_len = len_sq.sqrt().recip();
1033+
unsafe {
1034+
let scaled =
1035+
std::arch::aarch64::vmulq_f32(self.0, std::arch::aarch64::vdupq_n_f32(inv_len));
1036+
f32x3(std::arch::aarch64::vsetq_lane_f32::<3>(0.0, scaled))
1037+
}
1038+
}
1039+
1040+
#[cfg(not(target_arch = "aarch64"))]
1041+
fn simd_normalized(self) -> Self::Output {
1042+
let len_sq = self.dot(&self);
1043+
if len_sq == 0.0 {
1044+
return self;
1045+
}
1046+
1047+
let inv_len = len_sq.sqrt().recip();
1048+
f32x3::with_xyz_f32(self.x() * inv_len, self.y() * inv_len, self.z() * inv_len)
1049+
}
1050+
}
1051+
10121052
impl SimdMix<f32x3, f32x3> for f32x3 {
10131053
type Output = f32x3;
10141054

@@ -1868,6 +1908,23 @@ mod tests {
18681908
assert_eq!(s, f32x4::with_xyzw(5.0, 15.0, 25.0, 35.0));
18691909
}
18701910

1911+
#[test]
fn f32x3_dot() {
    // 1*5 + 2*6 + 3*7 = 5 + 12 + 21 = 38
    let lhs = f32x3::with_xyz(1.0, 2.0, 3.0);
    let rhs = f32x3::with_xyz(5.0, 6.0, 7.0);
    let product = lhs.dot(&rhs);
    assert_eq!(product, 38.0);
}
1917+
1918+
#[test]
fn f32x3_normalized() {
    // (3, 4, 0) has length 5, so the expected unit vector is (0.6, 0.8, 0.0).
    let unit = simd::normalized(f32x3::with_xyz(3.0, 4.0, 0.0));

    let components = [(unit.x(), 0.6), (unit.y(), 0.8), (unit.z(), 0.0)];
    for (actual, expected) in components {
        assert_f32_close(actual, expected);
    }

    // A normalized vector dotted with itself is its squared length, i.e. 1.
    assert_f32_close(unit.dot(&unit), 1.0);
}
1927+
18711928
#[test]
18721929
fn f32x3_mix() {
18731930
let a = f32x3::with_xyz(0.0, 10.0, 20.0);

cidre/src/simd/vector_types.rs

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -224,6 +224,11 @@ impl Simd<f32, 4, 3> {
224224
pub const fn with_rgb_f32(r: f32, g: f32, b: f32) -> Self {
225225
Self([r, g, b, 0.0])
226226
}
227+
228+
#[inline]
229+
pub fn dot(&self, other: &Self) -> f32 {
230+
self.0[0] * other.0[0] + self.0[1] * other.0[1] + self.0[2] * other.0[2]
231+
}
227232
}
228233

229234
impl<T: Copy> Simd<T, 4, 4> {

0 commit comments

Comments
 (0)