Skip to content

Commit 0dcbcaa

Browse files
committed
Finalized cross-platform distance functions
1 parent 9b92520 commit 0dcbcaa

File tree

7 files changed

+363
-210
lines changed

7 files changed

+363
-210
lines changed

src/distance-avx2.c

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
//
22
// distance-avx2.c
3-
// sqlitevector_test
3+
// sqlitevector
44
//
55
// Created by Marco Bambini on 20/06/25.
66
//
77

88
#include "distance-avx2.h"
9+
#include "distance-cpu.h"
910

1011
#if defined(__AVX2__) || (defined(_MSC_VER) && defined(__AVX2__))
1112
#include <immintrin.h>
@@ -19,7 +20,7 @@ extern char *distance_backend_name;
1920

2021
// MARK: - FLOAT32 -
2122

22-
float float32_distance_l2_impl_avx2 (const void *v1, const void *v2, int n, bool use_sqrt) {
23+
static inline float float32_distance_l2_impl_avx2 (const void *v1, const void *v2, int n, bool use_sqrt) {
2324
const float *a = (const float *)v1;
2425
const float *b = (const float *)v2;
2526

@@ -47,11 +48,11 @@ float float32_distance_l2_impl_avx2 (const void *v1, const void *v2, int n, bool
4748
}
4849

4950
float float32_distance_l2_avx2 (const void *v1, const void *v2, int n) {
50-
return float_distance_l2_impl_avx2(v1, v2, n, true);
51+
return float32_distance_l2_impl_avx2(v1, v2, n, true);
5152
}
5253

5354
float float32_distance_l2_squared_avx2 (const void *v1, const void *v2, int n) {
54-
return float_distance_l2_impl_avx2(v1, v2, n, false);
55+
return float32_distance_l2_impl_avx2(v1, v2, n, false);
5556
}
5657

5758
float float32_distance_l1_avx2 (const void *v1, const void *v2, int n) {
@@ -102,24 +103,23 @@ float float32_distance_dot_avx2 (const void *v1, const void *v2, int n) {
102103
total += a[i] * b[i];
103104
}
104105

105-
return total;
106+
return -total;
106107
}
107108

108109
float float32_distance_cosine_avx2 (const void *a, const void *b, int n) {
109-
float dot = float32_distance_dot_avx2(a, b, n);
110-
float norm_a = sqrtf(float32_distance_dot_avx2(a, a, n));
111-
float norm_b = sqrtf(float32_distance_dot_avx2(b, b, n));
110+
float dot = -float32_distance_dot_avx2(a, b, n);
111+
float norm_a = sqrtf(-float32_distance_dot_avx2(a, a, n));
112+
float norm_b = sqrtf(-float32_distance_dot_avx2(b, b, n));
112113

113114
if (norm_a == 0.0f || norm_b == 0.0f) return 1.0f;
114115

115116
float cosine_similarity = dot / (norm_a * norm_b);
116117
return 1.0f - cosine_similarity;
117118
}
118119

119-
120120
// MARK: - UINT8 -
121121

122-
float uint8_distance_l2_impl_avx2 (const void *v1, const void *v2, int n, bool use_sqrt) {
122+
static inline float uint8_distance_l2_impl_avx2 (const void *v1, const void *v2, int n, bool use_sqrt) {
123123
const uint8_t *a = (const uint8_t *)v1;
124124
const uint8_t *b = (const uint8_t *)v2;
125125

@@ -232,7 +232,7 @@ float uint8_distance_dot_avx2 (const void *v1, const void *v2, int n) {
232232
total += a[i] * b[i];
233233
}
234234

235-
return (float)total;
235+
return -(float)total;
236236
}
237237

238238
float uint8_distance_l1_avx2 (const void *v1, const void *v2, int n) {
@@ -278,9 +278,9 @@ float uint8_distance_l1_avx2 (const void *v1, const void *v2, int n) {
278278
}
279279

280280
float uint8_distance_cosine_avx2 (const void *a, const void *b, int n) {
281-
float dot = uint8_distance_dot_avx2(a, b, n);
282-
float norm_a = sqrtf(uint8_distance_dot_avx2(a, a, n));
283-
float norm_b = sqrtf(uint8_distance_dot_avx2(b, b, n));
281+
float dot = -uint8_distance_dot_avx2(a, b, n);
282+
float norm_a = sqrtf(-uint8_distance_dot_avx2(a, a, n));
283+
float norm_b = sqrtf(-uint8_distance_dot_avx2(b, b, n));
284284

285285
if (norm_a == 0.0f || norm_b == 0.0f) return 1.0f;
286286

@@ -290,7 +290,7 @@ float uint8_distance_cosine_avx2 (const void *a, const void *b, int n) {
290290

291291
// MARK: - INT8 -
292292

293-
float int8_distance_l2_impl_avx2 (const void *v1, const void *v2, int n, bool use_sqrt) {
293+
static inline float int8_distance_l2_impl_avx2 (const void *v1, const void *v2, int n, bool use_sqrt) {
294294
const int8_t *a = (const int8_t *)v1;
295295
const int8_t *b = (const int8_t *)v2;
296296

@@ -417,7 +417,7 @@ float int8_distance_dot_avx2 (const void *v1, const void *v2, int n) {
417417
total += (int)a[i] * (int)b[i];
418418
}
419419

420-
return (float)total;
420+
return -(float)total;
421421
}
422422

423423
float int8_distance_l1_avx2 (const void *v1, const void *v2, int n) {
@@ -475,9 +475,9 @@ float int8_distance_l1_avx2 (const void *v1, const void *v2, int n) {
475475
}
476476

477477
float int8_distance_cosine_avx2 (const void *a, const void *b, int n) {
478-
float dot = int8_distance_dot_avx2(a, b, n);
479-
float norm_a = sqrtf(int8_distance_dot_avx2(a, a, n));
480-
float norm_b = sqrtf(int8_distance_dot_avx2(b, b, n));
478+
float dot = -int8_distance_dot_avx2(a, b, n);
479+
float norm_a = sqrtf(-int8_distance_dot_avx2(a, a, n));
480+
float norm_b = sqrtf(-int8_distance_dot_avx2(b, b, n));
481481

482482
if (norm_a == 0.0f || norm_b == 0.0f) return 1.0f;
483483

src/distance-cpu.c

Lines changed: 56 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,12 @@
1313
#include <string.h>
1414
#include <math.h>
1515

16-
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
1716
#include "distance-neon.h"
18-
#define USE_ARM_NEON 1
19-
#endif
20-
21-
#if defined(__SSE2__) || (defined(_MSC_VER) && (defined(_M_X64) || (_M_IX86_FP >= 2)))
2217
#include "distance-sse2.h"
23-
#define USE_INTEL_SEE2 1
24-
#endif
25-
26-
#if defined(__AVX2__) || (defined(_MSC_VER) && defined(__AVX2__))
2718
#include "distance-avx2.h"
28-
#define USE_INTEL_AVX2 1
29-
#undef USE_INTEL_SEE2
30-
#endif
3119

3220
char *distance_backend_name = "CPU";
21+
distance_function_t dispatch_distance_table[VECTOR_DISTANCE_MAX][VECTOR_TYPE_MAX] = {0};
3322

3423
// MARK: - FLOAT32 -
3524

@@ -129,7 +118,7 @@ float float32_distance_dot_cpu (const void *v1, const void *v2, int n) {
129118
dot += x * y;
130119
}
131120

132-
return 1.0f - dot;
121+
return -dot;
133122
}
134123

135124
float float32_distance_l1_cpu (const void *v1, const void *v2, int n) {
@@ -157,7 +146,7 @@ float float32_distance_l1_cpu (const void *v1, const void *v2, int n) {
157146

158147
// MARK: - UINT8 -
159148

160-
inline float uint8_distance_l2_imp_cpu (const void *v1, const void *v2, int n, bool use_sqrt) {
149+
static inline float uint8_distance_l2_imp_cpu (const void *v1, const void *v2, int n, bool use_sqrt) {
161150
const uint8_t *a = (const uint8_t *)v1;
162151
const uint8_t *b = (const uint8_t *)v2;
163152

@@ -395,11 +384,15 @@ float int8_distance_l1_cpu (const void *v1, const void *v2, int n) {
395384
}
396385

397386
bool cpu_supports_avx2 (void) {
387+
#if FORCE_AVX2
388+
return true;
389+
#else
398390
int eax, ebx, ecx, edx;
399391
x86_cpuid(0, 0, &eax, &ebx, &ecx, &edx);
400392
if (eax < 7) return false;
401393
x86_cpuid(7, 0, &eax, &ebx, &ecx, &edx);
402394
return (ebx & (1 << 5)) != 0; // AVX2
395+
#endif
403396
}
404397

405398
bool cpu_supports_sse2 (void) {
@@ -429,51 +422,62 @@ float int8_distance_l1_cpu (const void *v1, const void *v2, int n) {
429422

430423
// MARK: -
431424

432-
void init_distance_functions (void) {
433-
static bool table_inited = false;
434-
if (table_inited) return;
435-
table_inited = true;
425+
void init_cpu_functions (void) {
426+
distance_function_t cpu_table[VECTOR_DISTANCE_MAX][VECTOR_TYPE_MAX] = {
427+
[VECTOR_DISTANCE_L2] = {
428+
[VECTOR_TYPE_F32] = float32_distance_l2_cpu,
429+
[VECTOR_TYPE_F16] = NULL,
430+
[VECTOR_TYPE_BF16] = NULL,
431+
[VECTOR_TYPE_U8] = uint8_distance_l2_cpu,
432+
[VECTOR_TYPE_I8] = int8_distance_l2_cpu,
433+
},
434+
[VECTOR_DISTANCE_SQUARED_L2] = {
435+
[VECTOR_TYPE_F32] = float32_distance_l2_squared_cpu,
436+
[VECTOR_TYPE_F16] = NULL,
437+
[VECTOR_TYPE_BF16] = NULL,
438+
[VECTOR_TYPE_U8] = uint8_distance_l2_squared_cpu,
439+
[VECTOR_TYPE_I8] = int8_distance_l2_squared_cpu,
440+
},
441+
[VECTOR_DISTANCE_COSINE] = {
442+
[VECTOR_TYPE_F32] = float32_distance_cosine_cpu,
443+
[VECTOR_TYPE_F16] = NULL,
444+
[VECTOR_TYPE_BF16] = NULL,
445+
[VECTOR_TYPE_U8] = uint8_distance_cosine_cpu,
446+
[VECTOR_TYPE_I8] = int8_distance_cosine_cpu,
447+
},
448+
[VECTOR_DISTANCE_DOT] = {
449+
[VECTOR_TYPE_F32] = float32_distance_dot_cpu,
450+
[VECTOR_TYPE_F16] = NULL,
451+
[VECTOR_TYPE_BF16] = NULL,
452+
[VECTOR_TYPE_U8] = uint8_distance_dot_cpu,
453+
[VECTOR_TYPE_I8] = int8_distance_dot_cpu,
454+
},
455+
[VECTOR_DISTANCE_L1] = {
456+
[VECTOR_TYPE_F32] = float32_distance_l1_cpu,
457+
[VECTOR_TYPE_F16] = NULL,
458+
[VECTOR_TYPE_BF16] = NULL,
459+
[VECTOR_TYPE_U8] = uint8_distance_l1_cpu,
460+
[VECTOR_TYPE_I8] = int8_distance_l1_cpu,
461+
}
462+
};
436463

437-
#if defined(__x86_64__) || defined(_M_X64) || defined(__i386__) || defined(_M_IX86)
438-
if (cpu_supports_avx2()) {
464+
memcpy(dispatch_distance_table, cpu_table, sizeof(cpu_table));
465+
}
466+
467+
void init_distance_functions (bool force_cpu) {
468+
init_cpu_functions();
469+
if (force_cpu) return;
470+
471+
#if defined(__x86_64__) || defined(_M_X64) || defined(__i386__) || defined(_M_IX86)
472+
if (1/*cpu_supports_avx2()*/) {
439473
init_distance_functions_avx2();
440474
} else if (cpu_supports_sse2()) {
441475
init_distance_functions_sse2();
442476
}
443-
#elif defined(__ARM_NEON) || defined(__aarch64__)
477+
#elif defined(__ARM_NEON) || defined(__aarch64__)
444478
if (cpu_supports_neon()) {
445479
init_distance_functions_neon();
446480
}
447-
#else
448-
// DO NOTHING
449-
#endif
481+
#endif
450482
}
451483

452-
distance_function_t dispatch_distance_table[VECTOR_DISTANCE_MAX][VECTOR_TYPE_MAX] = {
453-
[VECTOR_DISTANCE_L2] = {
454-
[VECTOR_TYPE_F32] = float32_distance_l2_cpu,
455-
[VECTOR_TYPE_U8] = uint8_distance_l2_cpu,
456-
[VECTOR_TYPE_I8] = int8_distance_l2_cpu,
457-
},
458-
[VECTOR_DISTANCE_SQUARED_L2] = {
459-
[VECTOR_TYPE_F32] = float32_distance_l2_squared_cpu,
460-
[VECTOR_TYPE_U8] = uint8_distance_l2_squared_cpu,
461-
[VECTOR_TYPE_I8] = int8_distance_l2_squared_cpu,
462-
},
463-
[VECTOR_DISTANCE_COSINE] = {
464-
[VECTOR_TYPE_F32] = float32_distance_cosine_cpu,
465-
[VECTOR_TYPE_U8] = uint8_distance_cosine_cpu,
466-
[VECTOR_TYPE_I8] = int8_distance_cosine_cpu,
467-
},
468-
[VECTOR_DISTANCE_DOT] = {
469-
[VECTOR_TYPE_F32] = float32_distance_dot_cpu,
470-
[VECTOR_TYPE_U8] = uint8_distance_dot_cpu,
471-
[VECTOR_TYPE_I8] = int8_distance_dot_cpu,
472-
},
473-
[VECTOR_DISTANCE_L1] = {
474-
[VECTOR_TYPE_F32] = float32_distance_l1_cpu,
475-
[VECTOR_TYPE_U8] = uint8_distance_l1_cpu,
476-
[VECTOR_TYPE_I8] = int8_distance_l1_cpu,
477-
}
478-
};
479-

src/distance-cpu.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#define __VECTOR_DISTANCE_CPU__
1010

1111
#include <inttypes.h>
12+
#include <stdbool.h>
1213

1314
typedef enum {
1415
VECTOR_TYPE_F32 = 1,
@@ -35,6 +36,6 @@ typedef enum {
3536
typedef float (*distance_function_t)(const void *v1, const void *v2, int n);
3637

3738
// ENTRYPOINT
38-
void init_distance_functions (void);
39+
void init_distance_functions (bool force_cpu);
3940

4041
#endif

0 commit comments

Comments
 (0)