Skip to content

Commit 169e122

Browse files
committed
use apple invert for aarch64 simd inverse
1 parent 19f5a57 commit 169e122

File tree

1 file changed

+255
-0
lines changed

1 file changed

+255
-0
lines changed

cidre/src/simd.rs

Lines changed: 255 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -778,6 +778,20 @@ where
778778
lhs.simd_mix(rhs, t)
779779
}
780780

781+
/// Inversion of a SIMD value.
///
/// Implementors consume `self` and hand back the inverse as
/// [`SimdInverse::Output`]. The free function `inverse` forwards to this
/// trait.
pub trait SimdInverse {
    /// Type produced by [`SimdInverse::simd_inverse`].
    type Output;

    /// Consumes `self` and returns its inverse.
    fn simd_inverse(self) -> Self::Output;
}
786+
787+
#[inline]
788+
pub fn inverse<T>(val: T) -> <T as SimdInverse>::Output
789+
where
790+
T: SimdInverse,
791+
{
792+
val.simd_inverse()
793+
}
794+
781795
#[inline]
782796
#[cfg(target_arch = "aarch64")]
783797
fn f32x4_dot_cols(c0: f32x4, c1: f32x4, c2: f32x4, c3: f32x4, v: f32x4) -> f32x4 {
@@ -818,6 +832,16 @@ fn f32x3x3_with_cols(c0: f32x3, c1: f32x3, c2: f32x3) -> f32x3x3 {
818832
f32x3x3(std::arch::aarch64::float32x4x3_t(c0.0, c1.0, c2.0))
819833
}
820834

835+
#[inline]
836+
#[cfg(not(target_arch = "aarch64"))]
837+
fn f32x3_cross(a: f32x3, b: f32x3) -> f32x3 {
838+
f32x3::with_xyz(
839+
a.y() * b.z() - a.z() * b.y(),
840+
a.z() * b.x() - a.x() * b.z(),
841+
a.x() * b.y() - a.y() * b.x(),
842+
)
843+
}
844+
821845
#[inline]
822846
#[cfg(not(target_arch = "aarch64"))]
823847
fn f32x4x4_with_cols(c0: f32x4, c1: f32x4, c2: f32x4, c3: f32x4) -> f32x4x4 {
@@ -1001,6 +1025,182 @@ impl SimdMix<f32x3x3, f32> for f32x3x3 {
10011025
f32x3x3_with_cols(c0, c1, c2)
10021026
}
10031027
}
1028+
1029+
impl SimdInverse for f32x3x3 {
1030+
type Output = f32x3x3;
1031+
1032+
fn simd_inverse(self) -> Self::Output {
1033+
#[cfg(target_arch = "aarch64")]
1034+
{
1035+
let q0: std::arch::aarch64::float32x4_t;
1036+
let q1: std::arch::aarch64::float32x4_t;
1037+
let q2: std::arch::aarch64::float32x4_t;
1038+
1039+
unsafe {
1040+
core::arch::asm!(
1041+
"bl ___invert_f3",
1042+
inlateout("q0") self.0.0 => q0,
1043+
inlateout("q1") self.0.1 => q1,
1044+
inlateout("q2") self.0.2 => q2,
1045+
clobber_abi("C"),
1046+
);
1047+
}
1048+
1049+
return f32x3x3(std::arch::aarch64::float32x4x3_t(q0, q1, q2));
1050+
}
1051+
1052+
#[cfg(not(target_arch = "aarch64"))]
1053+
{
1054+
let c0 = self[0];
1055+
let c1 = self[1];
1056+
let c2 = self[2];
1057+
1058+
let cof0 = f32x3_cross(c1, c2);
1059+
let cof1 = f32x3_cross(c2, c0);
1060+
let cof2 = f32x3_cross(c0, c1);
1061+
1062+
let det = c0.x() * cof0.x() + c0.y() * cof0.y() + c0.z() * cof0.z();
1063+
let inv_det = 1.0 / det;
1064+
1065+
f32x3x3_with_cols(
1066+
f32x3::with_xyz(cof0.x() * inv_det, cof1.x() * inv_det, cof2.x() * inv_det),
1067+
f32x3::with_xyz(cof0.y() * inv_det, cof1.y() * inv_det, cof2.y() * inv_det),
1068+
f32x3::with_xyz(cof0.z() * inv_det, cof1.z() * inv_det, cof2.z() * inv_det),
1069+
)
1070+
}
1071+
}
1072+
}
1073+
1074+
impl SimdInverse for f32x4x4 {
1075+
type Output = f32x4x4;
1076+
1077+
fn simd_inverse(self) -> Self::Output {
1078+
#[cfg(target_arch = "aarch64")]
1079+
{
1080+
let q0: std::arch::aarch64::float32x4_t;
1081+
let q1: std::arch::aarch64::float32x4_t;
1082+
let q2: std::arch::aarch64::float32x4_t;
1083+
let q3: std::arch::aarch64::float32x4_t;
1084+
1085+
unsafe {
1086+
core::arch::asm!(
1087+
"bl ___invert_f4",
1088+
inlateout("q0") self.0.0 => q0,
1089+
inlateout("q1") self.0.1 => q1,
1090+
inlateout("q2") self.0.2 => q2,
1091+
inlateout("q3") self.0.3 => q3,
1092+
clobber_abi("C"),
1093+
);
1094+
}
1095+
1096+
return f32x4x4(std::arch::aarch64::float32x4x4_t(q0, q1, q2, q3));
1097+
}
1098+
1099+
#[cfg(not(target_arch = "aarch64"))]
1100+
{
1101+
let m00 = self[0].x();
1102+
let m01 = self[1].x();
1103+
let m02 = self[2].x();
1104+
let m03 = self[3].x();
1105+
let m10 = self[0].y();
1106+
let m11 = self[1].y();
1107+
let m12 = self[2].y();
1108+
let m13 = self[3].y();
1109+
let m20 = self[0].z();
1110+
let m21 = self[1].z();
1111+
let m22 = self[2].z();
1112+
let m23 = self[3].z();
1113+
let m30 = self[0].w();
1114+
let m31 = self[1].w();
1115+
let m32 = self[2].w();
1116+
let m33 = self[3].w();
1117+
1118+
let mut inv = [0.0f32; 16];
1119+
1120+
inv[0] = m11 * m22 * m33 - m11 * m23 * m32 - m21 * m12 * m33
1121+
+ m21 * m13 * m32
1122+
+ m31 * m12 * m23
1123+
- m31 * m13 * m22;
1124+
inv[4] = -m10 * m22 * m33 + m10 * m23 * m32 + m20 * m12 * m33
1125+
- m20 * m13 * m32
1126+
- m30 * m12 * m23
1127+
+ m30 * m13 * m22;
1128+
inv[8] = m10 * m21 * m33 - m10 * m23 * m31 - m20 * m11 * m33
1129+
+ m20 * m13 * m31
1130+
+ m30 * m11 * m23
1131+
- m30 * m13 * m21;
1132+
inv[12] = -m10 * m21 * m32 + m10 * m22 * m31 + m20 * m11 * m32
1133+
- m20 * m12 * m31
1134+
- m30 * m11 * m22
1135+
+ m30 * m12 * m21;
1136+
1137+
inv[1] = -m01 * m22 * m33 + m01 * m23 * m32 + m21 * m02 * m33
1138+
- m21 * m03 * m32
1139+
- m31 * m02 * m23
1140+
+ m31 * m03 * m22;
1141+
inv[5] = m00 * m22 * m33 - m00 * m23 * m32 - m20 * m02 * m33
1142+
+ m20 * m03 * m32
1143+
+ m30 * m02 * m23
1144+
- m30 * m03 * m22;
1145+
inv[9] = -m00 * m21 * m33 + m00 * m23 * m31 + m20 * m01 * m33
1146+
- m20 * m03 * m31
1147+
- m30 * m01 * m23
1148+
+ m30 * m03 * m21;
1149+
inv[13] = m00 * m21 * m32 - m00 * m22 * m31 - m20 * m01 * m32
1150+
+ m20 * m02 * m31
1151+
+ m30 * m01 * m22
1152+
- m30 * m02 * m21;
1153+
1154+
inv[2] = m01 * m12 * m33 - m01 * m13 * m32 - m11 * m02 * m33
1155+
+ m11 * m03 * m32
1156+
+ m31 * m02 * m13
1157+
- m31 * m03 * m12;
1158+
inv[6] = -m00 * m12 * m33 + m00 * m13 * m32 + m10 * m02 * m33
1159+
- m10 * m03 * m32
1160+
- m30 * m02 * m13
1161+
+ m30 * m03 * m12;
1162+
inv[10] = m00 * m11 * m33 - m00 * m13 * m31 - m10 * m01 * m33
1163+
+ m10 * m03 * m31
1164+
+ m30 * m01 * m13
1165+
- m30 * m03 * m11;
1166+
inv[14] = -m00 * m11 * m32 + m00 * m12 * m31 + m10 * m01 * m32
1167+
- m10 * m02 * m31
1168+
- m30 * m01 * m12
1169+
+ m30 * m02 * m11;
1170+
1171+
inv[3] = -m01 * m12 * m23 + m01 * m13 * m22 + m11 * m02 * m23
1172+
- m11 * m03 * m22
1173+
- m21 * m02 * m13
1174+
+ m21 * m03 * m12;
1175+
inv[7] = m00 * m12 * m23 - m00 * m13 * m22 - m10 * m02 * m23
1176+
+ m10 * m03 * m22
1177+
+ m20 * m02 * m13
1178+
- m20 * m03 * m12;
1179+
inv[11] = -m00 * m11 * m23 + m00 * m13 * m21 + m10 * m01 * m23
1180+
- m10 * m03 * m21
1181+
- m20 * m01 * m13
1182+
+ m20 * m03 * m11;
1183+
inv[15] = m00 * m11 * m22 - m00 * m12 * m21 - m10 * m01 * m22
1184+
+ m10 * m02 * m21
1185+
+ m20 * m01 * m12
1186+
- m20 * m02 * m11;
1187+
1188+
let det = m00 * inv[0] + m01 * inv[4] + m02 * inv[8] + m03 * inv[12];
1189+
let inv_det = 1.0 / det;
1190+
1191+
for v in &mut inv {
1192+
*v *= inv_det;
1193+
}
1194+
1195+
f32x4x4_with_cols(
1196+
f32x4::with_xyzw(inv[0], inv[4], inv[8], inv[12]),
1197+
f32x4::with_xyzw(inv[1], inv[5], inv[9], inv[13]),
1198+
f32x4::with_xyzw(inv[2], inv[6], inv[10], inv[14]),
1199+
f32x4::with_xyzw(inv[3], inv[7], inv[11], inv[15]),
1200+
)
1201+
}
1202+
}
1203+
}
10041204
#[cfg(feature = "half")]
10051205
#[derive(Debug, Copy, Clone, PartialEq)]
10061206
#[allow(non_camel_case_types)]
@@ -1433,6 +1633,13 @@ mod tests {
14331633
assert_f32_close(a.z(), b.z());
14341634
}
14351635

1636+
/// Asserts that two 4×4 matrices are approximately equal by comparing
/// `a[i]` with `b[i]` for each of the four indices via `assert_f32x4_close`.
fn assert_f32x4x4_close(a: f32x4x4, b: f32x4x4) {
    assert_f32x4_close(a[0], b[0]);
    assert_f32x4_close(a[1], b[1]);
    assert_f32x4_close(a[2], b[2]);
    assert_f32x4_close(a[3], b[3]);
}
1642+
14361643
fn assert_f32x3x3_close(a: f32x3x3, b: f32x3x3) {
14371644
assert_f32x3_close(a[0], b[0]);
14381645
assert_f32x3_close(a[1], b[1]);
@@ -1609,6 +1816,54 @@ mod tests {
16091816
assert_f32x3x3_close(s, expected_s);
16101817
}
16111818

1819+
#[test]
fn f32x3x3_inverse() {
    // Diagonal matrix: the inverse is the element-wise reciprocal of the diagonal.
    let d = f32x3x3::diagonal(f32x3::with_xyz(2.0, 4.0, 5.0));
    let d_inv = simd::inverse(d);
    let d_expected = f32x3x3::diagonal(f32x3::with_xyz(0.5, 0.25, 0.2));
    assert_f32x3x3_close(d_inv, d_expected);

    // Unit upper-triangular matrix given by columns, i.e.
    //   | 1 2 3 |
    //   | 0 1 4 |
    //   | 0 0 1 |
    let t = super::f32x3x3_with_cols(
        f32x3::with_xyz(1.0, 0.0, 0.0),
        f32x3::with_xyz(2.0, 1.0, 0.0),
        f32x3::with_xyz(3.0, 4.0, 1.0),
    );
    let t_inv = simd::inverse(t);
    // Its inverse, again column by column:
    //   | 1 -2  5 |
    //   | 0  1 -4 |
    //   | 0  0  1 |
    let t_expected = super::f32x3x3_with_cols(
        f32x3::with_xyz(1.0, 0.0, 0.0),
        f32x3::with_xyz(-2.0, 1.0, 0.0),
        f32x3::with_xyz(5.0, -4.0, 1.0),
    );
    assert_f32x3x3_close(t_inv, t_expected);
}
1839+
1840+
#[test]
fn f32x4x4_inverse() {
    // Diagonal matrix: the inverse is the element-wise reciprocal of the diagonal.
    let d = f32x4x4::diagonal(f32x4::with_xyzw(2.0, 4.0, 5.0, 10.0));
    let d_inv = simd::inverse(d);
    let d_expected = f32x4x4::diagonal(f32x4::with_xyzw(0.5, 0.25, 0.2, 0.1));
    assert_f32x4x4_close(d_inv, d_expected);

    // Unit upper-triangular matrix given by columns, i.e.
    //   | 1 2 3 4 |
    //   | 0 1 5 6 |
    //   | 0 0 1 7 |
    //   | 0 0 0 1 |
    let t = super::f32x4x4_with_cols(
        f32x4::with_xyzw(1.0, 0.0, 0.0, 0.0),
        f32x4::with_xyzw(2.0, 1.0, 0.0, 0.0),
        f32x4::with_xyzw(3.0, 5.0, 1.0, 0.0),
        f32x4::with_xyzw(4.0, 6.0, 7.0, 1.0),
    );
    let t_inv = simd::inverse(t);
    // Its inverse, again column by column:
    //   | 1 -2  7 -41 |
    //   | 0  1 -5  29 |
    //   | 0  0  1  -7 |
    //   | 0  0  0   1 |
    let t_expected = super::f32x4x4_with_cols(
        f32x4::with_xyzw(1.0, 0.0, 0.0, 0.0),
        f32x4::with_xyzw(-2.0, 1.0, 0.0, 0.0),
        f32x4::with_xyzw(7.0, -5.0, 1.0, 0.0),
        f32x4::with_xyzw(-41.0, 29.0, -7.0, 1.0),
    );
    assert_f32x4x4_close(t_inv, t_expected);

    // Round trip: multiplying by the inverse on either side yields the identity.
    let id = f32x4x4::identity();
    assert_f32x4x4_close(simd::mul(t, t_inv), id);
    assert_f32x4x4_close(simd::mul(t_inv, t), id);
}
1866+
16121867
#[cfg(feature = "half")]
16131868
#[test]
16141869
fn f16quat() {

0 commit comments

Comments
 (0)