|
13 | 13 | #include <string.h> |
14 | 14 | #include <math.h> |
15 | 15 |
|
16 | | -#if defined(__ARM_NEON) || defined(__ARM_NEON__) |
17 | 16 | #include "distance-neon.h" |
18 | | -#define USE_ARM_NEON 1 |
19 | | -#endif |
20 | | - |
21 | | -#if defined(__SSE2__) || (defined(_MSC_VER) && (defined(_M_X64) || (_M_IX86_FP >= 2))) |
22 | 17 | #include "distance-sse2.h" |
23 | | -#define USE_INTEL_SEE2 1 |
24 | | -#endif |
25 | | - |
26 | | -#if defined(__AVX2__) || (defined(_MSC_VER) && defined(__AVX2__)) |
27 | 18 | #include "distance-avx2.h" |
28 | | -#define USE_INTEL_AVX2 1 |
29 | | -#undef USE_INTEL_SEE2 |
30 | | -#endif |
31 | 19 |
|
32 | 20 | char *distance_backend_name = "CPU"; |
| 21 | +distance_function_t dispatch_distance_table[VECTOR_DISTANCE_MAX][VECTOR_TYPE_MAX] = {0}; |
33 | 22 |
|
34 | 23 | // MARK: - FLOAT32 - |
35 | 24 |
|
@@ -129,7 +118,7 @@ float float32_distance_dot_cpu (const void *v1, const void *v2, int n) { |
129 | 118 | dot += x * y; |
130 | 119 | } |
131 | 120 |
|
132 | | - return 1.0f - dot; |
| 121 | + return -dot; |
133 | 122 | } |
134 | 123 |
|
135 | 124 | float float32_distance_l1_cpu (const void *v1, const void *v2, int n) { |
@@ -157,7 +146,7 @@ float float32_distance_l1_cpu (const void *v1, const void *v2, int n) { |
157 | 146 |
|
158 | 147 | // MARK: - UINT8 - |
159 | 148 |
|
160 | | -inline float uint8_distance_l2_imp_cpu (const void *v1, const void *v2, int n, bool use_sqrt) { |
| 149 | +static inline float uint8_distance_l2_imp_cpu (const void *v1, const void *v2, int n, bool use_sqrt) { |
161 | 150 | const uint8_t *a = (const uint8_t *)v1; |
162 | 151 | const uint8_t *b = (const uint8_t *)v2; |
163 | 152 |
|
@@ -395,11 +384,15 @@ float int8_distance_l1_cpu (const void *v1, const void *v2, int n) { |
395 | 384 | } |
396 | 385 |
|
397 | 386 | bool cpu_supports_avx2 (void) { |
| 387 | + #if FORCE_AVX2 |
| 388 | + return true; |
| 389 | + #else |
398 | 390 | int eax, ebx, ecx, edx; |
399 | 391 | x86_cpuid(0, 0, &eax, &ebx, &ecx, &edx); |
400 | 392 | if (eax < 7) return false; |
401 | 393 | x86_cpuid(7, 0, &eax, &ebx, &ecx, &edx); |
402 | 394 | return (ebx & (1 << 5)) != 0; // AVX2 |
| 395 | + #endif |
403 | 396 | } |
404 | 397 |
|
405 | 398 | bool cpu_supports_sse2 (void) { |
@@ -429,51 +422,62 @@ float int8_distance_l1_cpu (const void *v1, const void *v2, int n) { |
429 | 422 |
|
430 | 423 | // MARK: - |
431 | 424 |
|
432 | | -void init_distance_functions (void) { |
433 | | - static bool table_inited = false; |
434 | | - if (table_inited) return; |
435 | | - table_inited = true; |
| 425 | +void init_cpu_functions (void) { |
| 426 | + distance_function_t cpu_table[VECTOR_DISTANCE_MAX][VECTOR_TYPE_MAX] = { |
| 427 | + [VECTOR_DISTANCE_L2] = { |
| 428 | + [VECTOR_TYPE_F32] = float32_distance_l2_cpu, |
| 429 | + [VECTOR_TYPE_F16] = NULL, |
| 430 | + [VECTOR_TYPE_BF16] = NULL, |
| 431 | + [VECTOR_TYPE_U8] = uint8_distance_l2_cpu, |
| 432 | + [VECTOR_TYPE_I8] = int8_distance_l2_cpu, |
| 433 | + }, |
| 434 | + [VECTOR_DISTANCE_SQUARED_L2] = { |
| 435 | + [VECTOR_TYPE_F32] = float32_distance_l2_squared_cpu, |
| 436 | + [VECTOR_TYPE_F16] = NULL, |
| 437 | + [VECTOR_TYPE_BF16] = NULL, |
| 438 | + [VECTOR_TYPE_U8] = uint8_distance_l2_squared_cpu, |
| 439 | + [VECTOR_TYPE_I8] = int8_distance_l2_squared_cpu, |
| 440 | + }, |
| 441 | + [VECTOR_DISTANCE_COSINE] = { |
| 442 | + [VECTOR_TYPE_F32] = float32_distance_cosine_cpu, |
| 443 | + [VECTOR_TYPE_F16] = NULL, |
| 444 | + [VECTOR_TYPE_BF16] = NULL, |
| 445 | + [VECTOR_TYPE_U8] = uint8_distance_cosine_cpu, |
| 446 | + [VECTOR_TYPE_I8] = int8_distance_cosine_cpu, |
| 447 | + }, |
| 448 | + [VECTOR_DISTANCE_DOT] = { |
| 449 | + [VECTOR_TYPE_F32] = float32_distance_dot_cpu, |
| 450 | + [VECTOR_TYPE_F16] = NULL, |
| 451 | + [VECTOR_TYPE_BF16] = NULL, |
| 452 | + [VECTOR_TYPE_U8] = uint8_distance_dot_cpu, |
| 453 | + [VECTOR_TYPE_I8] = int8_distance_dot_cpu, |
| 454 | + }, |
| 455 | + [VECTOR_DISTANCE_L1] = { |
| 456 | + [VECTOR_TYPE_F32] = float32_distance_l1_cpu, |
| 457 | + [VECTOR_TYPE_F16] = NULL, |
| 458 | + [VECTOR_TYPE_BF16] = NULL, |
| 459 | + [VECTOR_TYPE_U8] = uint8_distance_l1_cpu, |
| 460 | + [VECTOR_TYPE_I8] = int8_distance_l1_cpu, |
| 461 | + } |
| 462 | + }; |
436 | 463 |
|
437 | | -#if defined(__x86_64__) || defined(_M_X64) || defined(__i386__) || defined(_M_IX86) |
438 | | - if (cpu_supports_avx2()) { |
| 464 | + memcpy(dispatch_distance_table, cpu_table, sizeof(cpu_table)); |
| 465 | +} |
| 466 | + |
| 467 | +void init_distance_functions (bool force_cpu) { |
| 468 | + init_cpu_functions(); |
| 469 | + if (force_cpu) return; |
| 470 | + |
| 471 | + #if defined(__x86_64__) || defined(_M_X64) || defined(__i386__) || defined(_M_IX86) |
| 472 | + if (1/*cpu_supports_avx2()*/) { |
439 | 473 | init_distance_functions_avx2(); |
440 | 474 | } else if (cpu_supports_sse2()) { |
441 | 475 | init_distance_functions_sse2(); |
442 | 476 | } |
443 | | -#elif defined(__ARM_NEON) || defined(__aarch64__) |
| 477 | + #elif defined(__ARM_NEON) || defined(__aarch64__) |
444 | 478 | if (cpu_supports_neon()) { |
445 | 479 | init_distance_functions_neon(); |
446 | 480 | } |
447 | | -#else |
448 | | - // DO NOTHING |
449 | | -#endif |
| 481 | + #endif |
450 | 482 | } |
451 | 483 |
|
452 | | -distance_function_t dispatch_distance_table[VECTOR_DISTANCE_MAX][VECTOR_TYPE_MAX] = { |
453 | | - [VECTOR_DISTANCE_L2] = { |
454 | | - [VECTOR_TYPE_F32] = float32_distance_l2_cpu, |
455 | | - [VECTOR_TYPE_U8] = uint8_distance_l2_cpu, |
456 | | - [VECTOR_TYPE_I8] = int8_distance_l2_cpu, |
457 | | - }, |
458 | | - [VECTOR_DISTANCE_SQUARED_L2] = { |
459 | | - [VECTOR_TYPE_F32] = float32_distance_l2_squared_cpu, |
460 | | - [VECTOR_TYPE_U8] = uint8_distance_l2_squared_cpu, |
461 | | - [VECTOR_TYPE_I8] = int8_distance_l2_squared_cpu, |
462 | | - }, |
463 | | - [VECTOR_DISTANCE_COSINE] = { |
464 | | - [VECTOR_TYPE_F32] = float32_distance_cosine_cpu, |
465 | | - [VECTOR_TYPE_U8] = uint8_distance_cosine_cpu, |
466 | | - [VECTOR_TYPE_I8] = int8_distance_cosine_cpu, |
467 | | - }, |
468 | | - [VECTOR_DISTANCE_DOT] = { |
469 | | - [VECTOR_TYPE_F32] = float32_distance_dot_cpu, |
470 | | - [VECTOR_TYPE_U8] = uint8_distance_dot_cpu, |
471 | | - [VECTOR_TYPE_I8] = int8_distance_dot_cpu, |
472 | | - }, |
473 | | - [VECTOR_DISTANCE_L1] = { |
474 | | - [VECTOR_TYPE_F32] = float32_distance_l1_cpu, |
475 | | - [VECTOR_TYPE_U8] = uint8_distance_l1_cpu, |
476 | | - [VECTOR_TYPE_I8] = int8_distance_l1_cpu, |
477 | | - } |
478 | | -}; |
479 | | - |
|
0 commit comments