1818extern distance_function_t dispatch_distance_table [VECTOR_DISTANCE_MAX ][VECTOR_TYPE_MAX ];
1919extern char * distance_backend_name ;
2020
// Compatibility shim: vmaxv_u16 (horizontal max of a uint16x4_t) is an
// AArch64-only ACLE intrinsic, so it is missing on 32-bit ARM (ARMv7 NEON).
// Gate on !__aarch64__ rather than pointer width: AArch64 ILP32 (arm64_32)
// has 4-byte pointers but still provides the native intrinsic, and pointer
// size says nothing about the instruction set on non-ARM targets.
#if !defined(__aarch64__)
// Returns the maximum of the four u16 lanes of v, emulated with two
// pairwise-max folds (behaviorally identical to A64 vmaxv_u16).
static inline uint16_t vmaxv_u16_compat (uint16x4_t v ) {
    // First fold: [max(v0,v1), max(v2,v3), max(v0,v1), max(v2,v3)]
    uint16x4_t m = vpmax_u16 (v , v );
    // Second fold: every lane now holds max(v0..v3)
    m = vpmax_u16 (m , m );
    return vget_lane_u16 (m , 0 );
}
// Let the rest of the file use the A64 spelling unconditionally.
#define vmaxv_u16 vmaxv_u16_compat
#endif
31+
2132// MARK: FLOAT32 -
2233
2334float float32_distance_l2_impl_neon (const void * v1 , const void * v2 , int n , bool use_sqrt ) {
@@ -158,6 +169,31 @@ float bfloat16_distance_l2_impl_neon (const void *v1, const void *v2, int n, boo
158169 const uint16_t * a = (const uint16_t * )v1 ;
159170 const uint16_t * b = (const uint16_t * )v2 ;
160171
172+ #if __SIZEOF_POINTER__ == 4
173+ // 32-bit ARM: use scalar double accumulation (no float64x2_t in NEON)
174+ double sum = 0.0 ;
175+ int i = 0 ;
176+
177+ for (; i <= n - 4 ; i += 4 ) {
178+ uint16x4_t av16 = vld1_u16 (a + i );
179+ uint16x4_t bv16 = vld1_u16 (b + i );
180+
181+ float32x4_t va = bf16x4_to_f32x4_u16 (av16 );
182+ float32x4_t vb = bf16x4_to_f32x4_u16 (bv16 );
183+ float32x4_t d = vsubq_f32 (va , vb );
184+ // mask-out NaNs: m = (d==d)
185+ uint32x4_t m = vceqq_f32 (d , d );
186+ d = vbslq_f32 (m , d , vdupq_n_f32 (0.0f ));
187+
188+ // Store and accumulate in scalar double
189+ float tmp [4 ];
190+ vst1q_f32 (tmp , d );
191+ for (int j = 0 ; j < 4 ; j ++ ) {
192+ double dj = (double )tmp [j ];
193+ sum = fma (dj , dj , sum );
194+ }
195+ }
196+ #else
161197 // Accumulate in f64 to avoid overflow from huge bf16 values.
162198 float64x2_t acc0 = vdupq_n_f64 (0.0 ), acc1 = vdupq_n_f64 (0.0 );
163199 int i = 0 ;
@@ -205,6 +241,7 @@ float bfloat16_distance_l2_impl_neon (const void *v1, const void *v2, int n, boo
205241 }
206242
207243 double sum = vaddvq_f64 (vaddq_f64 (acc0 , acc1 ));
244+ #endif
208245
209246 // scalar tail; treat NaN as 0, Inf as +Inf result
210247 for (; i < n ; ++ i ) {
@@ -409,8 +446,15 @@ float float16_distance_l2_impl_neon (const void *v1, const void *v2, int n, bool
409446 const uint16x4_t SIGN_MASK = vdup_n_u16 (0x8000u );
410447 const uint16x4_t ZERO16 = vdup_n_u16 (0 );
411448
449+ #if __SIZEOF_POINTER__ == 4
450+ // 32-bit ARM: use scalar double accumulation
451+ double sum = 0.0 ;
452+ int i = 0 ;
453+ #else
454+ // 64-bit ARM: use float64x2_t NEON intrinsics
412455 float64x2_t acc0 = vdupq_n_f64 (0.0 ), acc1 = vdupq_n_f64 (0.0 );
413456 int i = 0 ;
457+ #endif
414458
415459 for (; i <= n - 4 ; i += 4 ) {
416460 uint16x4_t av16 = vld1_u16 (a + i );
@@ -443,6 +487,16 @@ float float16_distance_l2_impl_neon (const void *v1, const void *v2, int n, bool
443487 uint32x4_t m = vceqq_f32 (d32 , d32 ); /* true where not-NaN */
444488 d32 = vbslq_f32 (m , d32 , vdupq_n_f32 (0.0f ));
445489
490+ #if __SIZEOF_POINTER__ == 4
491+ // 32-bit ARM: accumulate in scalar double
492+ float tmp [4 ];
493+ vst1q_f32 (tmp , d32 );
494+ for (int j = 0 ; j < 4 ; j ++ ) {
495+ double dj = (double )tmp [j ];
496+ sum = fma (dj , dj , sum );
497+ }
498+ #else
499+ // 64-bit ARM: use NEON f64 operations
446500 float64x2_t dlo = vcvt_f64_f32 (vget_low_f32 (d32 ));
447501 float64x2_t dhi = vcvt_f64_f32 (vget_high_f32 (d32 ));
448502#if defined(__ARM_FEATURE_FMA )
@@ -451,10 +505,13 @@ float float16_distance_l2_impl_neon (const void *v1, const void *v2, int n, bool
451505#else
452506 acc0 = vaddq_f64 (acc0 , vmulq_f64 (dlo , dlo ));
453507 acc1 = vaddq_f64 (acc1 , vmulq_f64 (dhi , dhi ));
508+ #endif
454509#endif
455510 }
456511
512+ #if __SIZEOF_POINTER__ != 4
457513 double sum = vaddvq_f64 (vaddq_f64 (acc0 , acc1 ));
514+ #endif
458515
459516 /* tail (scalar; same Inf/NaN policy) */
460517 for (; i < n ; ++ i ) {
@@ -487,10 +544,17 @@ float float16_distance_cosine_neon (const void *v1, const void *v2, int n) {
487544 const uint16x4_t FRAC_MASK = vdup_n_u16 (0x03FFu );
488545 const uint16x4_t ZERO16 = vdup_n_u16 (0 );
489546
547+ #if __SIZEOF_POINTER__ == 4
548+ // 32-bit ARM: use scalar double accumulation
549+ double dot = 0.0 , normx = 0.0 , normy = 0.0 ;
550+ int i = 0 ;
551+ #else
552+ // 64-bit ARM: use float64x2_t NEON intrinsics
490553 float64x2_t acc_dot_lo = vdupq_n_f64 (0.0 ), acc_dot_hi = vdupq_n_f64 (0.0 );
491554 float64x2_t acc_a2_lo = vdupq_n_f64 (0.0 ), acc_a2_hi = vdupq_n_f64 (0.0 );
492555 float64x2_t acc_b2_lo = vdupq_n_f64 (0.0 ), acc_b2_hi = vdupq_n_f64 (0.0 );
493556 int i = 0 ;
557+ #endif
494558
495559 for (; i <= n - 4 ; i += 4 ) {
496560 uint16x4_t av16 = vld1_u16 (a + i );
@@ -512,6 +576,19 @@ float float16_distance_cosine_neon (const void *v1, const void *v2, int n) {
512576 ax = vbslq_f32 (mx , ax , vdupq_n_f32 (0.0f ));
513577 by = vbslq_f32 (my , by , vdupq_n_f32 (0.0f ));
514578
579+ #if __SIZEOF_POINTER__ == 4
580+ // 32-bit ARM: accumulate in scalar double
581+ float ax_tmp [4 ], by_tmp [4 ];
582+ vst1q_f32 (ax_tmp , ax );
583+ vst1q_f32 (by_tmp , by );
584+ for (int j = 0 ; j < 4 ; j ++ ) {
585+ double x = (double )ax_tmp [j ];
586+ double y = (double )by_tmp [j ];
587+ dot += x * y ;
588+ normx += x * x ;
589+ normy += y * y ;
590+ }
591+ #else
515592 /* widen to f64 and accumulate */
516593 float64x2_t ax_lo = vcvt_f64_f32 (vget_low_f32 (ax )), ax_hi = vcvt_f64_f32 (vget_high_f32 (ax ));
517594 float64x2_t by_lo = vcvt_f64_f32 (vget_low_f32 (by )), by_hi = vcvt_f64_f32 (vget_high_f32 (by ));
@@ -530,12 +607,15 @@ float float16_distance_cosine_neon (const void *v1, const void *v2, int n) {
530607 acc_a2_hi = vaddq_f64 (acc_a2_hi , vmulq_f64 (ax_hi , ax_hi ));
531608 acc_b2_lo = vaddq_f64 (acc_b2_lo , vmulq_f64 (by_lo , by_lo ));
532609 acc_b2_hi = vaddq_f64 (acc_b2_hi , vmulq_f64 (by_hi , by_hi ));
610+ #endif
533611#endif
534612 }
535613
614+ #if __SIZEOF_POINTER__ != 4
536615 double dot = vaddvq_f64 (vaddq_f64 (acc_dot_lo , acc_dot_hi ));
537616 double normx = vaddvq_f64 (vaddq_f64 (acc_a2_lo , acc_a2_hi ));
538617 double normy = vaddvq_f64 (vaddq_f64 (acc_b2_lo , acc_b2_hi ));
618+ #endif
539619
540620 /* tail (scalar) */
541621 for (; i < n ; ++ i ) {
@@ -569,8 +649,15 @@ float float16_distance_dot_neon (const void *v1, const void *v2, int n) {
569649 const uint16x4_t FRAC_MASK = vdup_n_u16 (0x03FFu );
570650 const uint16x4_t ZERO16 = vdup_n_u16 (0 );
571651
652+ #if __SIZEOF_POINTER__ == 4
653+ // 32-bit ARM: use scalar double accumulation
654+ double dot = 0.0 ;
655+ int i = 0 ;
656+ #else
657+ // 64-bit ARM: use float64x2_t NEON intrinsics
572658 float64x2_t acc_lo = vdupq_n_f64 (0.0 ), acc_hi = vdupq_n_f64 (0.0 );
573659 int i = 0 ;
660+ #endif
574661
575662 for (; i <= n - 4 ; i += 4 ) {
576663 uint16x4_t av16 = vld1_u16 (a + i );
@@ -588,7 +675,11 @@ float float16_distance_dot_neon (const void *v1, const void *v2, int n) {
588675 if (isnan (x ) || isnan (y )) continue ;
589676 double p = (double )x * (double )y ;
590677 if (isinf (p )) return (p > 0 )? - INFINITY : INFINITY ;
678+ #if __SIZEOF_POINTER__ == 4
679+ dot += p ;
680+ #else
591681 acc_lo = vsetq_lane_f64 (vgetq_lane_f64 (acc_lo ,0 )+ p , acc_lo , 0 ); /* cheap add */
682+ #endif
592683 }
593684 continue ;
594685 }
@@ -603,13 +694,26 @@ float float16_distance_dot_neon (const void *v1, const void *v2, int n) {
603694 by = vbslq_f32 (my , by , vdupq_n_f32 (0.0f ));
604695
605696 float32x4_t prod = vmulq_f32 (ax , by );
697+
698+ #if __SIZEOF_POINTER__ == 4
699+ // 32-bit ARM: accumulate in scalar double
700+ float prod_tmp [4 ];
701+ vst1q_f32 (prod_tmp , prod );
702+ for (int j = 0 ; j < 4 ; j ++ ) {
703+ dot += (double )prod_tmp [j ];
704+ }
705+ #else
706+ // 64-bit ARM: use NEON f64 operations
606707 float64x2_t lo = vcvt_f64_f32 (vget_low_f32 (prod ));
607708 float64x2_t hi = vcvt_f64_f32 (vget_high_f32 (prod ));
608709 acc_lo = vaddq_f64 (acc_lo , lo );
609710 acc_hi = vaddq_f64 (acc_hi , hi );
711+ #endif
610712 }
611713
714+ #if __SIZEOF_POINTER__ != 4
612715 double dot = vaddvq_f64 (vaddq_f64 (acc_lo , acc_hi ));
716+ #endif
613717
614718 for (; i < n ; ++ i ) {
615719 float x = float16_to_float32 (a [i ]);
@@ -635,8 +739,15 @@ float float16_distance_l1_neon (const void *v1, const void *v2, int n) {
635739 const uint16x4_t SIGN_MASK = vdup_n_u16 (0x8000u );
636740 const uint16x4_t ZERO16 = vdup_n_u16 (0 );
637741
742+ #if __SIZEOF_POINTER__ == 4
743+ // 32-bit ARM: use scalar double accumulation
744+ double sum = 0.0 ;
745+ int i = 0 ;
746+ #else
747+ // 64-bit ARM: use float64x2_t NEON intrinsics
638748 float64x2_t acc = vdupq_n_f64 (0.0 );
639749 int i = 0 ;
750+ #endif
640751
641752 for (; i <= n - 4 ; i += 4 ) {
642753 uint16x4_t av16 = vld1_u16 (a + i );
@@ -665,13 +776,25 @@ float float16_distance_l1_neon (const void *v1, const void *v2, int n) {
665776 uint32x4_t m = vceqq_f32 (d , d ); /* mask NaNs -> 0 */
666777 d = vbslq_f32 (m , d , vdupq_n_f32 (0.0f ));
667778
779+ #if __SIZEOF_POINTER__ == 4
780+ // 32-bit ARM: accumulate in scalar double
781+ float tmp [4 ];
782+ vst1q_f32 (tmp , d );
783+ for (int j = 0 ; j < 4 ; j ++ ) {
784+ sum += (double )tmp [j ];
785+ }
786+ #else
787+ // 64-bit ARM: use NEON f64 operations
668788 float64x2_t lo = vcvt_f64_f32 (vget_low_f32 (d ));
669789 float64x2_t hi = vcvt_f64_f32 (vget_high_f32 (d ));
670790 acc = vaddq_f64 (acc , lo );
671791 acc = vaddq_f64 (acc , hi );
792+ #endif
672793 }
673794
795+ #if __SIZEOF_POINTER__ != 4
674796 double sum = vaddvq_f64 (acc );
797+ #endif
675798
676799 for (; i < n ; ++ i ) {
677800 uint16_t ai = a [i ], bi = b [i ];
0 commit comments