Skip to content

Commit 1fa8143

Browse files
committed
fix(android): add support for armeabi-v7a (arm NEON 32bit); fixes #30
1 parent c066d9e commit 1fa8143

File tree

3 files changed

+144
-8
lines changed

3 files changed

+144
-8
lines changed

.github/workflows/main.yml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ jobs:
1111
build:
1212
runs-on: ${{ matrix.os }}
1313
container: ${{ matrix.container && matrix.container || '' }}
14-
name: ${{ matrix.name }}${{ matrix.arch && format('-{0}', matrix.arch) || '' }} build${{ matrix.arch != 'arm64-v8a' && matrix.name != 'ios-sim' && matrix.name != 'ios' && matrix.name != 'apple-xcframework' && matrix.name != 'android-aar' && ( matrix.name != 'macos' || matrix.arch != 'x86_64' ) && ' + test' || ''}}
14+
name: ${{ matrix.name }}${{ matrix.arch && format('-{0}', matrix.arch) || '' }} build${{ matrix.arch != 'arm64-v8a' && matrix.arch != 'armeabi-v7a' && matrix.name != 'ios-sim' && matrix.name != 'ios' && matrix.name != 'apple-xcframework' && matrix.name != 'android-aar' && ( matrix.name != 'macos' || matrix.arch != 'x86_64' ) && ' + test' || ''}}
1515
timeout-minutes: 20
1616
strategy:
1717
fail-fast: false
@@ -47,6 +47,10 @@ jobs:
4747
arch: arm64-v8a
4848
name: android
4949
make: PLATFORM=android ARCH=arm64-v8a
50+
- os: ubuntu-22.04
51+
arch: armeabi-v7a
52+
name: android
53+
make: PLATFORM=android ARCH=armeabi-v7a
5054
- os: ubuntu-22.04
5155
arch: x86_64
5256
name: android
@@ -140,7 +144,7 @@ jobs:
140144
security delete-keychain build.keychain
141145
142146
- name: android setup test environment
143-
if: matrix.name == 'android' && matrix.arch != 'arm64-v8a'
147+
if: matrix.name == 'android' && matrix.arch != 'arm64-v8a' && matrix.arch != 'armeabi-v7a'
144148
run: |
145149
146150
echo "::group::enable kvm group perms"
@@ -168,7 +172,7 @@ jobs:
168172
echo "::endgroup::"
169173
170174
- name: android test sqlite-vector
171-
if: matrix.name == 'android' && matrix.arch != 'arm64-v8a'
175+
if: matrix.name == 'android' && matrix.arch != 'arm64-v8a' && matrix.arch != 'armeabi-v7a'
172176
uses: reactivecircus/[email protected]
173177
with:
174178
api-level: 26

Makefile

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,22 @@ else ifeq ($(PLATFORM),macos)
5959
STRIP = strip -x -S $@
6060
else ifeq ($(PLATFORM),android)
6161
ifndef ARCH # Set ARCH to find Android NDK's Clang compiler, the user should set the ARCH
62-
$(error "Android ARCH must be set to ARCH=x86_64 or ARCH=arm64-v8a")
62+
$(error "Android ARCH must be set to ARCH=x86_64, ARCH=arm64-v8a, or ARCH=armeabi-v7a")
6363
endif
6464
ifndef ANDROID_NDK # Set ANDROID_NDK path to find android build tools; e.g. on MacOS: export ANDROID_NDK=/Users/username/Library/Android/sdk/ndk/25.2.9519653
6565
$(error "Android NDK must be set")
6666
endif
6767
BIN = $(ANDROID_NDK)/toolchains/llvm/prebuilt/$(HOST)-x86_64/bin
6868
ifneq (,$(filter $(ARCH),arm64 arm64-v8a))
6969
override ARCH := aarch64
70+
ANDROID_ABI := android26
71+
else ifeq ($(ARCH),armeabi-v7a)
72+
override ARCH := armv7a
73+
ANDROID_ABI := androideabi26
74+
else
75+
ANDROID_ABI := android26
7076
endif
71-
CC = $(BIN)/$(ARCH)-linux-android26-clang
77+
CC = $(BIN)/$(ARCH)-linux-$(ANDROID_ABI)-clang
7278
TARGET := $(DIST_DIR)/vector.so
7379
LDFLAGS += -lm -shared
7480
STRIP = $(BIN)/llvm-strip --strip-unneeded $@
@@ -184,11 +190,14 @@ $(DIST_DIR)/%.xcframework: $(LIB_NAMES)
184190

185191
xcframework: $(DIST_DIR)/vector.xcframework
186192

187-
AAR_ARM = packages/android/src/main/jniLibs/arm64-v8a/
193+
AAR_ARM64 = packages/android/src/main/jniLibs/arm64-v8a/
194+
AAR_ARM = packages/android/src/main/jniLibs/armeabi-v7a/
188195
AAR_X86 = packages/android/src/main/jniLibs/x86_64/
189196
aar:
190-
mkdir -p $(AAR_ARM) $(AAR_X86)
197+
mkdir -p $(AAR_ARM64) $(AAR_ARM) $(AAR_X86)
191198
$(MAKE) clean && $(MAKE) PLATFORM=android ARCH=arm64-v8a
199+
mv $(DIST_DIR)/vector.so $(AAR_ARM64)
200+
$(MAKE) clean && $(MAKE) PLATFORM=android ARCH=armeabi-v7a
192201
mv $(DIST_DIR)/vector.so $(AAR_ARM)
193202
$(MAKE) clean && $(MAKE) PLATFORM=android ARCH=x86_64
194203
mv $(DIST_DIR)/vector.so $(AAR_X86)
@@ -208,7 +217,7 @@ help:
208217
@echo " linux (default on Linux)"
209218
@echo " macos (default on macOS)"
210219
@echo " windows (default on Windows)"
211-
@echo " android (needs ARCH to be set to x86_64 or arm64-v8a and ANDROID_NDK to be set)"
220+
@echo " android (needs ARCH to be set to x86_64, arm64-v8a, or armeabi-v7a and ANDROID_NDK to be set)"
212221
@echo " ios (only on macOS)"
213222
@echo " ios-sim (only on macOS)"
214223
@echo ""

src/distance-neon.c

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,17 @@
1818
extern distance_function_t dispatch_distance_table[VECTOR_DISTANCE_MAX][VECTOR_TYPE_MAX];
1919
extern char *distance_backend_name;
2020

21+
// Helper for 32-bit ARM (AArch32/ARMv7): the horizontal-maximum intrinsic
// vmaxv_u16 is provided only by AArch64 NEON, so emulate it here with
// pairwise maxima.
// NOTE(review): pointer size is used as the 32-bit-ARM probe; an AArch64
// ILP32 target would also take this path — confirm that is intended.
22+
#if __SIZEOF_POINTER__ == 4
23+
static inline uint16_t vmaxv_u16_compat(uint16x4_t v) {
24+
// Reduce all 4 lanes with two pairwise-max steps, then read lane 0.
25+
uint16x4_t m = vpmax_u16(v, v); // [max(v0,v1), max(v2,v3), max(v0,v1), max(v2,v3)]
26+
m = vpmax_u16(m, m); // [max(all), max(all), max(all), max(all)]
27+
return vget_lane_u16(m, 0);
28+
}
29+
// Route existing vmaxv_u16 call sites through the shim on 32-bit builds.
#define vmaxv_u16 vmaxv_u16_compat
30+
#endif
31+
2132
// MARK: FLOAT32 -
2233

2334
float float32_distance_l2_impl_neon (const void *v1, const void *v2, int n, bool use_sqrt) {
@@ -158,6 +169,31 @@ float bfloat16_distance_l2_impl_neon (const void *v1, const void *v2, int n, boo
158169
const uint16_t *a = (const uint16_t *)v1;
159170
const uint16_t *b = (const uint16_t *)v2;
160171

172+
#if __SIZEOF_POINTER__ == 4
173+
// 32-bit ARM: use scalar double accumulation (no float64x2_t in NEON)
174+
double sum = 0.0;
175+
int i = 0;
176+
177+
for (; i <= n - 4; i += 4) {
178+
uint16x4_t av16 = vld1_u16(a + i);
179+
uint16x4_t bv16 = vld1_u16(b + i);
180+
181+
float32x4_t va = bf16x4_to_f32x4_u16(av16);
182+
float32x4_t vb = bf16x4_to_f32x4_u16(bv16);
183+
float32x4_t d = vsubq_f32(va, vb);
184+
// mask-out NaNs: m = (d==d)
185+
uint32x4_t m = vceqq_f32(d, d);
186+
d = vbslq_f32(m, d, vdupq_n_f32(0.0f));
187+
188+
// Store and accumulate in scalar double
189+
float tmp[4];
190+
vst1q_f32(tmp, d);
191+
for (int j = 0; j < 4; j++) {
192+
double dj = (double)tmp[j];
193+
sum = fma(dj, dj, sum);
194+
}
195+
}
196+
#else
161197
// Accumulate in f64 to avoid overflow from huge bf16 values.
162198
float64x2_t acc0 = vdupq_n_f64(0.0), acc1 = vdupq_n_f64(0.0);
163199
int i = 0;
@@ -205,6 +241,7 @@ float bfloat16_distance_l2_impl_neon (const void *v1, const void *v2, int n, boo
205241
}
206242

207243
double sum = vaddvq_f64(vaddq_f64(acc0, acc1));
244+
#endif
208245

209246
// scalar tail; treat NaN as 0, Inf as +Inf result
210247
for (; i < n; ++i) {
@@ -409,8 +446,15 @@ float float16_distance_l2_impl_neon (const void *v1, const void *v2, int n, bool
409446
const uint16x4_t SIGN_MASK = vdup_n_u16(0x8000u);
410447
const uint16x4_t ZERO16 = vdup_n_u16(0);
411448

449+
#if __SIZEOF_POINTER__ == 4
450+
// 32-bit ARM: use scalar double accumulation
451+
double sum = 0.0;
452+
int i = 0;
453+
#else
454+
// 64-bit ARM: use float64x2_t NEON intrinsics
412455
float64x2_t acc0 = vdupq_n_f64(0.0), acc1 = vdupq_n_f64(0.0);
413456
int i = 0;
457+
#endif
414458

415459
for (; i <= n - 4; i += 4) {
416460
uint16x4_t av16 = vld1_u16(a + i);
@@ -443,6 +487,16 @@ float float16_distance_l2_impl_neon (const void *v1, const void *v2, int n, bool
443487
uint32x4_t m = vceqq_f32(d32, d32); /* true where not-NaN */
444488
d32 = vbslq_f32(m, d32, vdupq_n_f32(0.0f));
445489

490+
#if __SIZEOF_POINTER__ == 4
491+
// 32-bit ARM: accumulate in scalar double
492+
float tmp[4];
493+
vst1q_f32(tmp, d32);
494+
for (int j = 0; j < 4; j++) {
495+
double dj = (double)tmp[j];
496+
sum = fma(dj, dj, sum);
497+
}
498+
#else
499+
// 64-bit ARM: use NEON f64 operations
446500
float64x2_t dlo = vcvt_f64_f32(vget_low_f32(d32));
447501
float64x2_t dhi = vcvt_f64_f32(vget_high_f32(d32));
448502
#if defined(__ARM_FEATURE_FMA)
@@ -451,10 +505,13 @@ float float16_distance_l2_impl_neon (const void *v1, const void *v2, int n, bool
451505
#else
452506
acc0 = vaddq_f64(acc0, vmulq_f64(dlo, dlo));
453507
acc1 = vaddq_f64(acc1, vmulq_f64(dhi, dhi));
508+
#endif
454509
#endif
455510
}
456511

512+
#if __SIZEOF_POINTER__ != 4
457513
double sum = vaddvq_f64(vaddq_f64(acc0, acc1));
514+
#endif
458515

459516
/* tail (scalar; same Inf/NaN policy) */
460517
for (; i < n; ++i) {
@@ -487,10 +544,17 @@ float float16_distance_cosine_neon (const void *v1, const void *v2, int n) {
487544
const uint16x4_t FRAC_MASK = vdup_n_u16(0x03FFu);
488545
const uint16x4_t ZERO16 = vdup_n_u16(0);
489546

547+
#if __SIZEOF_POINTER__ == 4
548+
// 32-bit ARM: use scalar double accumulation
549+
double dot = 0.0, normx = 0.0, normy = 0.0;
550+
int i = 0;
551+
#else
552+
// 64-bit ARM: use float64x2_t NEON intrinsics
490553
float64x2_t acc_dot_lo = vdupq_n_f64(0.0), acc_dot_hi = vdupq_n_f64(0.0);
491554
float64x2_t acc_a2_lo = vdupq_n_f64(0.0), acc_a2_hi = vdupq_n_f64(0.0);
492555
float64x2_t acc_b2_lo = vdupq_n_f64(0.0), acc_b2_hi = vdupq_n_f64(0.0);
493556
int i = 0;
557+
#endif
494558

495559
for (; i <= n - 4; i += 4) {
496560
uint16x4_t av16 = vld1_u16(a + i);
@@ -512,6 +576,19 @@ float float16_distance_cosine_neon (const void *v1, const void *v2, int n) {
512576
ax = vbslq_f32(mx, ax, vdupq_n_f32(0.0f));
513577
by = vbslq_f32(my, by, vdupq_n_f32(0.0f));
514578

579+
#if __SIZEOF_POINTER__ == 4
580+
// 32-bit ARM: accumulate in scalar double
581+
float ax_tmp[4], by_tmp[4];
582+
vst1q_f32(ax_tmp, ax);
583+
vst1q_f32(by_tmp, by);
584+
for (int j = 0; j < 4; j++) {
585+
double x = (double)ax_tmp[j];
586+
double y = (double)by_tmp[j];
587+
dot += x * y;
588+
normx += x * x;
589+
normy += y * y;
590+
}
591+
#else
515592
/* widen to f64 and accumulate */
516593
float64x2_t ax_lo = vcvt_f64_f32(vget_low_f32(ax)), ax_hi = vcvt_f64_f32(vget_high_f32(ax));
517594
float64x2_t by_lo = vcvt_f64_f32(vget_low_f32(by)), by_hi = vcvt_f64_f32(vget_high_f32(by));
@@ -530,12 +607,15 @@ float float16_distance_cosine_neon (const void *v1, const void *v2, int n) {
530607
acc_a2_hi = vaddq_f64(acc_a2_hi, vmulq_f64(ax_hi, ax_hi));
531608
acc_b2_lo = vaddq_f64(acc_b2_lo, vmulq_f64(by_lo, by_lo));
532609
acc_b2_hi = vaddq_f64(acc_b2_hi, vmulq_f64(by_hi, by_hi));
610+
#endif
533611
#endif
534612
}
535613

614+
#if __SIZEOF_POINTER__ != 4
536615
double dot = vaddvq_f64(vaddq_f64(acc_dot_lo, acc_dot_hi));
537616
double normx= vaddvq_f64(vaddq_f64(acc_a2_lo, acc_a2_hi));
538617
double normy= vaddvq_f64(vaddq_f64(acc_b2_lo, acc_b2_hi));
618+
#endif
539619

540620
/* tail (scalar) */
541621
for (; i < n; ++i) {
@@ -569,8 +649,15 @@ float float16_distance_dot_neon (const void *v1, const void *v2, int n) {
569649
const uint16x4_t FRAC_MASK = vdup_n_u16(0x03FFu);
570650
const uint16x4_t ZERO16 = vdup_n_u16(0);
571651

652+
#if __SIZEOF_POINTER__ == 4
653+
// 32-bit ARM: use scalar double accumulation
654+
double dot = 0.0;
655+
int i = 0;
656+
#else
657+
// 64-bit ARM: use float64x2_t NEON intrinsics
572658
float64x2_t acc_lo = vdupq_n_f64(0.0), acc_hi = vdupq_n_f64(0.0);
573659
int i = 0;
660+
#endif
574661

575662
for (; i <= n - 4; i += 4) {
576663
uint16x4_t av16 = vld1_u16(a + i);
@@ -588,7 +675,11 @@ float float16_distance_dot_neon (const void *v1, const void *v2, int n) {
588675
if (isnan(x) || isnan(y)) continue;
589676
double p = (double)x * (double)y;
590677
if (isinf(p)) return (p>0)? -INFINITY : INFINITY;
678+
#if __SIZEOF_POINTER__ == 4
679+
dot += p;
680+
#else
591681
acc_lo = vsetq_lane_f64(vgetq_lane_f64(acc_lo,0)+p, acc_lo, 0); /* cheap add */
682+
#endif
592683
}
593684
continue;
594685
}
@@ -603,13 +694,26 @@ float float16_distance_dot_neon (const void *v1, const void *v2, int n) {
603694
by = vbslq_f32(my, by, vdupq_n_f32(0.0f));
604695

605696
float32x4_t prod = vmulq_f32(ax, by);
697+
698+
#if __SIZEOF_POINTER__ == 4
699+
// 32-bit ARM: accumulate in scalar double
700+
float prod_tmp[4];
701+
vst1q_f32(prod_tmp, prod);
702+
for (int j = 0; j < 4; j++) {
703+
dot += (double)prod_tmp[j];
704+
}
705+
#else
706+
// 64-bit ARM: use NEON f64 operations
606707
float64x2_t lo = vcvt_f64_f32(vget_low_f32(prod));
607708
float64x2_t hi = vcvt_f64_f32(vget_high_f32(prod));
608709
acc_lo = vaddq_f64(acc_lo, lo);
609710
acc_hi = vaddq_f64(acc_hi, hi);
711+
#endif
610712
}
611713

714+
#if __SIZEOF_POINTER__ != 4
612715
double dot = vaddvq_f64(vaddq_f64(acc_lo, acc_hi));
716+
#endif
613717

614718
for (; i < n; ++i) {
615719
float x = float16_to_float32(a[i]);
@@ -635,8 +739,15 @@ float float16_distance_l1_neon (const void *v1, const void *v2, int n) {
635739
const uint16x4_t SIGN_MASK = vdup_n_u16(0x8000u);
636740
const uint16x4_t ZERO16 = vdup_n_u16(0);
637741

742+
#if __SIZEOF_POINTER__ == 4
743+
// 32-bit ARM: use scalar double accumulation
744+
double sum = 0.0;
745+
int i = 0;
746+
#else
747+
// 64-bit ARM: use float64x2_t NEON intrinsics
638748
float64x2_t acc = vdupq_n_f64(0.0);
639749
int i = 0;
750+
#endif
640751

641752
for (; i <= n - 4; i += 4) {
642753
uint16x4_t av16 = vld1_u16(a + i);
@@ -665,13 +776,25 @@ float float16_distance_l1_neon (const void *v1, const void *v2, int n) {
665776
uint32x4_t m = vceqq_f32(d, d); /* mask NaNs -> 0 */
666777
d = vbslq_f32(m, d, vdupq_n_f32(0.0f));
667778

779+
#if __SIZEOF_POINTER__ == 4
780+
// 32-bit ARM: accumulate in scalar double
781+
float tmp[4];
782+
vst1q_f32(tmp, d);
783+
for (int j = 0; j < 4; j++) {
784+
sum += (double)tmp[j];
785+
}
786+
#else
787+
// 64-bit ARM: use NEON f64 operations
668788
float64x2_t lo = vcvt_f64_f32(vget_low_f32(d));
669789
float64x2_t hi = vcvt_f64_f32(vget_high_f32(d));
670790
acc = vaddq_f64(acc, lo);
671791
acc = vaddq_f64(acc, hi);
792+
#endif
672793
}
673794

795+
#if __SIZEOF_POINTER__ != 4
674796
double sum = vaddvq_f64(acc);
797+
#endif
675798

676799
for (; i < n; ++i) {
677800
uint16_t ai=a[i], bi=b[i];

0 commit comments

Comments
 (0)