Commit 390970e

Merge pull request #981 from zeux/vfopt-arm
vertexfilter: Optimize various decoders for AArch64
2 parents: 4fe6568 + e27b32c

1 file changed: src/vertexfilter.cpp (56 additions, 44 deletions)
@@ -550,6 +550,13 @@ inline float32x4_t vdivq_f32(float32x4_t x, float32x4_t y)
 	r = vmulq_f32(r, vrecpsq_f32(y, r)); // refine rcp estimate
 	return vmulq_f32(x, r);
 }
+
+#ifndef __ARM_FEATURE_FMA
+inline float32x4_t vfmaq_f32(float32x4_t x, float32x4_t y, float32x4_t z)
+{
+	return vaddq_f32(x, vmulq_f32(y, z));
+}
+#endif
 #endif
 
 #ifdef SIMD_NEON
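
Note: the block added above supplies a vfmaq_f32 fallback for ARM targets compiled without FMA support (__ARM_FEATURE_FMA undefined). A rough per-lane sketch of what the fallback computes, with an illustrative helper name (not part of the patch):

	// Without hardware FMA, the fallback is a separate multiply and add (two roundings);
	// the real vfmaq_f32 instruction fuses them into a single rounding.
	inline float fma_fallback_lane(float x, float y, float z)
	{
		return x + y * z;
	}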
@@ -580,23 +587,21 @@ static void decodeFilterOctSimd8(signed char* data, size_t count)
 		y = vaddq_f32(y, vreinterpretq_f32_s32(veorq_s32(vreinterpretq_s32_f32(t), vandq_s32(vreinterpretq_s32_f32(y), sign))));
 
 		// compute normal length & scale
-		float32x4_t ll = vaddq_f32(vmulq_f32(x, x), vaddq_f32(vmulq_f32(y, y), vmulq_f32(z, z)));
+		float32x4_t ll = vfmaq_f32(vfmaq_f32(vmulq_f32(x, x), y, y), z, z);
 		float32x4_t rl = vrsqrteq_f32(ll);
 		float32x4_t s = vmulq_f32(vdupq_n_f32(127.f), rl);
 
 		// fast rounded signed float->int: addition triggers renormalization after which mantissa stores the integer value
-		// note: the result is offset by 0x4B40_0000, but we only need the low 16 bits so we can omit the subtraction
+		// note: the result is offset by 0x4B40_0000, but we only need the low 8 bits so we can omit the subtraction
 		const float32x4_t fsnap = vdupq_n_f32(3 << 22);
 
-		int32x4_t xr = vreinterpretq_s32_f32(vaddq_f32(vmulq_f32(x, s), fsnap));
-		int32x4_t yr = vreinterpretq_s32_f32(vaddq_f32(vmulq_f32(y, s), fsnap));
-		int32x4_t zr = vreinterpretq_s32_f32(vaddq_f32(vmulq_f32(z, s), fsnap));
+		int32x4_t xr = vreinterpretq_s32_f32(vfmaq_f32(fsnap, x, s));
+		int32x4_t yr = vreinterpretq_s32_f32(vfmaq_f32(fsnap, y, s));
+		int32x4_t zr = vreinterpretq_s32_f32(vfmaq_f32(fsnap, z, s));
 
 		// combine xr/yr/zr into final value
-		int32x4_t res = vandq_s32(n4, vdupq_n_s32(0xff000000));
-		res = vorrq_s32(res, vandq_s32(xr, vdupq_n_s32(0xff)));
-		res = vorrq_s32(res, vshlq_n_s32(vandq_s32(yr, vdupq_n_s32(0xff)), 8));
-		res = vorrq_s32(res, vshlq_n_s32(vandq_s32(zr, vdupq_n_s32(0xff)), 16));
+		int32x4_t res = vsliq_n_s32(xr, vsliq_n_s32(yr, zr, 8), 8);
+		res = vbslq_s32(vdupq_n_u32(0xff000000), n4, res);
 
 		vst1q_s32(reinterpret_cast<int32_t*>(&data[i * 4]), res);
 	}
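
Note: the repack above leans on SLI (shift left and insert) and BSL (bitwise select). A scalar model of the per-lane semantics, with illustrative helper names (not part of the patch):

	#include <cstdint>

	// vsliq_n_s32(a, b, n): per 32-bit lane, shift b left by n and keep the low n bits of a.
	inline uint32_t sli32(uint32_t a, uint32_t b, int n)
	{
		uint32_t low = n ? (~0u >> (32 - n)) : 0u;
		return (b << n) | (a & low);
	}

	// vbslq_s32(mask, a, b): per bit, select a where mask is set and b where it is clear.
	inline uint32_t bsl32(uint32_t mask, uint32_t a, uint32_t b)
	{
		return (a & mask) | (b & ~mask);
	}

	// The assembled value is then equivalent per lane to
	//   bsl32(0xff000000, n4, sli32(xr, sli32(yr, zr, 8), 8))
	// i.e. x in bits 0-7, y in bits 8-15, z in bits 16-23, and the top byte preserved from n4.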
@@ -634,21 +639,25 @@ static void decodeFilterOctSimd16(short* data, size_t count)
 		y = vaddq_f32(y, vreinterpretq_f32_s32(veorq_s32(vreinterpretq_s32_f32(t), vandq_s32(vreinterpretq_s32_f32(y), sign))));
 
 		// compute normal length & scale
-		float32x4_t ll = vaddq_f32(vmulq_f32(x, x), vaddq_f32(vmulq_f32(y, y), vmulq_f32(z, z)));
+		float32x4_t ll = vfmaq_f32(vfmaq_f32(vmulq_f32(x, x), y, y), z, z);
+#if !defined(__aarch64__) && !defined(_M_ARM64)
 		float32x4_t rl = vrsqrteq_f32(ll);
 		rl = vmulq_f32(rl, vrsqrtsq_f32(vmulq_f32(rl, ll), rl)); // refine rsqrt estimate
 		float32x4_t s = vmulq_f32(vdupq_n_f32(32767.f), rl);
+#else
+		float32x4_t s = vdivq_f32(vdupq_n_f32(32767.f), vsqrtq_f32(ll));
+#endif
 
 		// fast rounded signed float->int: addition triggers renormalization after which mantissa stores the integer value
 		// note: the result is offset by 0x4B40_0000, but we only need the low 16 bits so we can omit the subtraction
 		const float32x4_t fsnap = vdupq_n_f32(3 << 22);
 
-		int32x4_t xr = vreinterpretq_s32_f32(vaddq_f32(vmulq_f32(x, s), fsnap));
-		int32x4_t yr = vreinterpretq_s32_f32(vaddq_f32(vmulq_f32(y, s), fsnap));
-		int32x4_t zr = vreinterpretq_s32_f32(vaddq_f32(vmulq_f32(z, s), fsnap));
+		int32x4_t xr = vreinterpretq_s32_f32(vfmaq_f32(fsnap, x, s));
+		int32x4_t yr = vreinterpretq_s32_f32(vfmaq_f32(fsnap, y, s));
+		int32x4_t zr = vreinterpretq_s32_f32(vfmaq_f32(fsnap, z, s));
 
 		// mix x/z and y/0 to make 16-bit unpack easier
-		int32x4_t xzr = vorrq_s32(vandq_s32(xr, vdupq_n_s32(0xffff)), vshlq_n_s32(zr, 16));
+		int32x4_t xzr = vsliq_n_s32(xr, zr, 16);
 		int32x4_t y0r = vandq_s32(yr, vdupq_n_s32(0xffff));
 
 		// pack x/y/z using 16-bit unpacks; note that this has 0 where we should have .w
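
Note: on AArch64 the patch computes the scale with a full-precision divide and square root, while the 32-bit ARM path keeps the reciprocal-square-root estimate plus one refinement. A scalar sketch of the two paths, using 1/sqrt as a stand-in for the hardware estimate (not part of the patch):

	#include <cmath>

	// 32-bit ARM path: vrsqrteq_f32 estimate refined once by vrsqrtsq_f32, which computes (3 - a*b) / 2.
	inline float scale_estimate(float ll)
	{
		float rl = 1.f / std::sqrt(ll);            // stand-in for the vrsqrteq_f32 estimate
		rl = rl * ((3.f - (rl * ll) * rl) * 0.5f); // one Newton-Raphson refinement step
		return 32767.f * rl;
	}

	// AArch64 path: exact divide + sqrt, which the patch prefers on 64-bit ARM.
	inline float scale_precise(float ll)
	{
		return 32767.f / std::sqrt(ll);
	}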
@@ -694,7 +703,7 @@ static void decodeFilterQuatSimd(short* data, size_t count)
 
 		// reconstruct w as a square root (unscaled); we clamp to 0.f to avoid NaN due to precision errors
 		float32x4_t ws = vmulq_f32(s, s);
-		float32x4_t ww = vsubq_f32(vaddq_f32(ws, ws), vaddq_f32(vmulq_f32(x, x), vaddq_f32(vmulq_f32(y, y), vmulq_f32(z, z))));
+		float32x4_t ww = vsubq_f32(vaddq_f32(ws, ws), vfmaq_f32(vfmaq_f32(vmulq_f32(x, x), y, y), z, z));
 		float32x4_t w = vsqrtq_f32(vmaxq_f32(ww, vdupq_n_f32(0.f)));
 
 		// compute final scale; note that all computations above are unscaled
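
Note: the changed line only folds x*x + y*y + z*z into two fused multiply-adds; the reconstructed quantity is unchanged. A scalar sketch of the unscaled w reconstruction (not part of the patch):

	#include <algorithm>
	#include <cmath>

	// The dropped quaternion component is reconstructed in the w slot: in the unscaled domain
	// w^2 = 2*s^2 - (x^2 + y^2 + z^2); clamping at 0 avoids NaN when rounding pushes the
	// difference slightly negative.
	inline float reconstruct_w(float x, float y, float z, float s)
	{
		float ww = 2.f * s * s - (x * x + y * y + z * z);
		return std::sqrt(std::max(ww, 0.f));
	}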
@@ -705,26 +714,32 @@
 		// note: the result is offset by 0x4B40_0000, but we only need the low 16 bits so we can omit the subtraction
 		const float32x4_t fsnap = vdupq_n_f32(3 << 22);
 
-		int32x4_t xr = vreinterpretq_s32_f32(vaddq_f32(vmulq_f32(x, ss), fsnap));
-		int32x4_t yr = vreinterpretq_s32_f32(vaddq_f32(vmulq_f32(y, ss), fsnap));
-		int32x4_t zr = vreinterpretq_s32_f32(vaddq_f32(vmulq_f32(z, ss), fsnap));
-		int32x4_t wr = vreinterpretq_s32_f32(vaddq_f32(vmulq_f32(w, ss), fsnap));
+		int32x4_t xr = vreinterpretq_s32_f32(vfmaq_f32(fsnap, x, ss));
+		int32x4_t yr = vreinterpretq_s32_f32(vfmaq_f32(fsnap, y, ss));
+		int32x4_t zr = vreinterpretq_s32_f32(vfmaq_f32(fsnap, z, ss));
+		int32x4_t wr = vreinterpretq_s32_f32(vfmaq_f32(fsnap, w, ss));
 
 		// mix x/z and w/y to make 16-bit unpack easier
-		int32x4_t xzr = vorrq_s32(vandq_s32(xr, vdupq_n_s32(0xffff)), vshlq_n_s32(zr, 16));
-		int32x4_t wyr = vorrq_s32(vandq_s32(wr, vdupq_n_s32(0xffff)), vshlq_n_s32(yr, 16));
+		int32x4_t xzr = vsliq_n_s32(xr, zr, 16);
+		int32x4_t wyr = vsliq_n_s32(wr, yr, 16);
 
 		// pack x/y/z/w using 16-bit unpacks; we pack wxyz by default (for qc=0)
-		int32x4_t res_0 = vreinterpretq_s32_s16(vzipq_s16(vreinterpretq_s16_s32(wyr), vreinterpretq_s16_s32(xzr)).val[0]);
-		int32x4_t res_1 = vreinterpretq_s32_s16(vzipq_s16(vreinterpretq_s16_s32(wyr), vreinterpretq_s16_s32(xzr)).val[1]);
+		uint64x2_t res_0 = vreinterpretq_u64_s16(vzipq_s16(vreinterpretq_s16_s32(wyr), vreinterpretq_s16_s32(xzr)).val[0]);
+		uint64x2_t res_1 = vreinterpretq_u64_s16(vzipq_s16(vreinterpretq_s16_s32(wyr), vreinterpretq_s16_s32(xzr)).val[1]);
+
+		// store results to stack so that we can rotate using scalar instructions
+		// TODO: volatile works around LLVM mis-optimizing code; https://github.com/llvm/llvm-project/issues/166808
+		volatile uint64_t res[4];
+		vst1q_u64(const_cast<uint64_t*>(&res[0]), res_0);
+		vst1q_u64(const_cast<uint64_t*>(&res[2]), res_1);
 
 		// rotate and store
-		uint64_t* out = (uint64_t*)&data[i * 4];
+		uint64_t* out = reinterpret_cast<uint64_t*>(&data[i * 4]);
 
-		out[0] = rotateleft64(vgetq_lane_u64(vreinterpretq_u64_s32(res_0), 0), vgetq_lane_s32(cf, 0) << 4);
-		out[1] = rotateleft64(vgetq_lane_u64(vreinterpretq_u64_s32(res_0), 1), vgetq_lane_s32(cf, 1) << 4);
-		out[2] = rotateleft64(vgetq_lane_u64(vreinterpretq_u64_s32(res_1), 0), vgetq_lane_s32(cf, 2) << 4);
-		out[3] = rotateleft64(vgetq_lane_u64(vreinterpretq_u64_s32(res_1), 1), vgetq_lane_s32(cf, 3) << 4);
+		out[0] = rotateleft64(res[0], data[(i + 0) * 4 + 3] << 4);
+		out[1] = rotateleft64(res[1], data[(i + 1) * 4 + 3] << 4);
+		out[2] = rotateleft64(res[2], data[(i + 2) * 4 + 3] << 4);
+		out[3] = rotateleft64(res[3], data[(i + 3) * 4 + 3] << 4);
 	}
 }
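
Note: the rewritten tail stores the packed wxyz quads to a small stack array and rotates them with scalar code, taking the rotation amount from the fourth short of each still-encoded quaternion (data[... * 4 + 3]); the volatile qualifier only works around the linked LLVM issue. rotateleft64 is a helper defined earlier in vertexfilter.cpp; a typical masked rotate-left looks like this (assumed shape, not copied from the file):

	#include <stdint.h>

	// Rotate left by r bits, reducing the count modulo 64 so only the low 6 bits of r matter.
	inline uint64_t rotl64_sketch(uint64_t v, int r)
	{
		return (v << (r & 63)) | (v >> ((64 - r) & 63));
	}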

@@ -778,19 +793,16 @@ static void decodeFilterColorSimd8(unsigned char* data, size_t count)
 		int32x4_t bf = vsubq_s32(yf, vaddq_s32(cof, cgf));
 
 		// fast rounded signed float->int: addition triggers renormalization after which mantissa stores the integer value
-		// note: the result is offset by 0x4B40_0000, but we only need the low 16 bits so we can omit the subtraction
+		// note: the result is offset by 0x4B40_0000, but we only need the low 8 bits so we can omit the subtraction
 		const float32x4_t fsnap = vdupq_n_f32(3 << 22);
 
-		int32x4_t rr = vreinterpretq_s32_f32(vaddq_f32(vmulq_f32(vcvtq_f32_s32(rf), ss), fsnap));
-		int32x4_t gr = vreinterpretq_s32_f32(vaddq_f32(vmulq_f32(vcvtq_f32_s32(gf), ss), fsnap));
-		int32x4_t br = vreinterpretq_s32_f32(vaddq_f32(vmulq_f32(vcvtq_f32_s32(bf), ss), fsnap));
-		int32x4_t ar = vreinterpretq_s32_f32(vaddq_f32(vmulq_f32(vcvtq_f32_s32(af), ss), fsnap));
+		int32x4_t rr = vreinterpretq_s32_f32(vfmaq_f32(fsnap, vcvtq_f32_s32(rf), ss));
+		int32x4_t gr = vreinterpretq_s32_f32(vfmaq_f32(fsnap, vcvtq_f32_s32(gf), ss));
+		int32x4_t br = vreinterpretq_s32_f32(vfmaq_f32(fsnap, vcvtq_f32_s32(bf), ss));
+		int32x4_t ar = vreinterpretq_s32_f32(vfmaq_f32(fsnap, vcvtq_f32_s32(af), ss));
 
 		// repack rgba into final value
-		int32x4_t res = vandq_s32(rr, vdupq_n_s32(0xff));
-		res = vorrq_s32(res, vshlq_n_s32(vandq_s32(gr, vdupq_n_s32(0xff)), 8));
-		res = vorrq_s32(res, vshlq_n_s32(vandq_s32(br, vdupq_n_s32(0xff)), 16));
-		res = vorrq_s32(res, vshlq_n_s32(ar, 24));
+		int32x4_t res = vsliq_n_s32(rr, vsliq_n_s32(gr, vsliq_n_s32(br, ar, 8), 8), 8);
 
 		vst1q_s32(reinterpret_cast<int32_t*>(&data[i * 4]), res);
 	}
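
Note: the color decoders (and the oct/quat paths above) all use the float->int snapping trick the corrected comments describe. A scalar illustration of why adding 3 << 22 leaves the rounded result in the low mantissa bits, for inputs well inside the ±2^22 range (illustrative helper, not part of the library):

	#include <cstdint>
	#include <cstring>

	// Adding 3 << 22 (12582912.0f) moves the value into [2^23, 2^24), where a float's ulp is 1,
	// so the addition rounds to the nearest integer and the mantissa stores it. The full bit
	// pattern equals the integer plus 0x4B40_0000; that offset has zero low 16 bits, so the
	// low 8 or 16 bits can be used directly without subtracting it back out.
	inline int snap_low_bits(float v)
	{
		float biased = v + 12582912.0f; // 3 << 22
		uint32_t bits;
		std::memcpy(&bits, &biased, sizeof(bits));
		return int(bits & 0xffff);
	}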
@@ -835,14 +847,14 @@ static void decodeFilterColorSimd16(unsigned short* data, size_t count)
 		// note: the result is offset by 0x4B40_0000, but we only need the low 16 bits so we can omit the subtraction
 		const float32x4_t fsnap = vdupq_n_f32(3 << 22);
 
-		int32x4_t rr = vreinterpretq_s32_f32(vaddq_f32(vmulq_f32(vcvtq_f32_s32(rf), ss), fsnap));
-		int32x4_t gr = vreinterpretq_s32_f32(vaddq_f32(vmulq_f32(vcvtq_f32_s32(gf), ss), fsnap));
-		int32x4_t br = vreinterpretq_s32_f32(vaddq_f32(vmulq_f32(vcvtq_f32_s32(bf), ss), fsnap));
-		int32x4_t ar = vreinterpretq_s32_f32(vaddq_f32(vmulq_f32(vcvtq_f32_s32(af), ss), fsnap));
+		int32x4_t rr = vreinterpretq_s32_f32(vfmaq_f32(fsnap, vcvtq_f32_s32(rf), ss));
+		int32x4_t gr = vreinterpretq_s32_f32(vfmaq_f32(fsnap, vcvtq_f32_s32(gf), ss));
+		int32x4_t br = vreinterpretq_s32_f32(vfmaq_f32(fsnap, vcvtq_f32_s32(bf), ss));
+		int32x4_t ar = vreinterpretq_s32_f32(vfmaq_f32(fsnap, vcvtq_f32_s32(af), ss));
 
 		// mix r/b and g/a to make 16-bit unpack easier
-		int32x4_t rbr = vorrq_s32(vandq_s32(rr, vdupq_n_s32(0xffff)), vshlq_n_s32(br, 16));
-		int32x4_t gar = vorrq_s32(vandq_s32(gr, vdupq_n_s32(0xffff)), vshlq_n_s32(ar, 16));
+		int32x4_t rbr = vsliq_n_s32(rr, br, 16);
+		int32x4_t gar = vsliq_n_s32(gr, ar, 16);
 
 		// pack r/g/b/a using 16-bit unpacks
 		int32x4_t res_0 = vreinterpretq_s32_s16(vzipq_s16(vreinterpretq_s16_s32(rbr), vreinterpretq_s16_s32(gar)).val[0]);
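
Note: the 16-bit repack now builds the mixed r/b and g/a lanes with a single SLI each; the zip that follows interleaves them into r, g, b, a order. A scalar model of the vzipq_s16 interleave (illustrative, not part of the patch):

	#include <cstdint>

	// Each 32-bit lane of rbr holds (r | b << 16) and each lane of gar holds (g | a << 16);
	// zipping their 16-bit views alternates rbr/gar elements, yielding r, g, b, a per pixel.
	inline void zip16_model(const uint16_t rb[8], const uint16_t ga[8], uint16_t lo[8], uint16_t hi[8])
	{
		for (int k = 0; k < 4; ++k)
		{
			lo[2 * k + 0] = rb[k];
			lo[2 * k + 1] = ga[k];
			hi[2 * k + 0] = rb[4 + k];
			hi[2 * k + 1] = ga[4 + k];
		}
	}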
@@ -1145,7 +1157,7 @@ static void decodeFilterColorSimd16(unsigned short* data, size_t count)
 		v128_t bf = wasm_i32x4_sub(yf, wasm_i32x4_add(cof, cgf));
 
 		// fast rounded signed float->int: addition triggers renormalization after which mantissa stores the integer value
-		// note: the result is offset by 0x4B40_0000, but we only need the low 8 bits so we can omit the subtraction
+		// note: the result is offset by 0x4B40_0000, but we only need the low 16 bits so we can omit the subtraction
 		const v128_t fsnap = wasm_f32x4_splat(3 << 22);
 
 		v128_t rr = wasm_f32x4_add(wasm_f32x4_mul(wasm_f32x4_convert_i32x4(rf), ss), fsnap);
