@@ -296,34 +296,37 @@ float int8_distance_cosine_cpu (const void *v1, const void *v2, int n) {
296296 const int8_t * a = (const int8_t * )v1 ;
297297 const int8_t * b = (const int8_t * )v2 ;
298298
299- float dot = 0.0f , norm_a2 = 0.0f , norm_b2 = 0.0f ;
299+ int32_t dot = 0 ;
300+ int32_t norm_a2 = 0 ;
301+ int32_t norm_b2 = 0 ;
302+
300303 int i = 0 ;
301-
302304 for (; i <= n - 4 ; i += 4 ) {
303- float a0 = ( float ) a [i + 0 ], b0 = ( float ) b [i + 0 ];
304- float a1 = ( float ) a [i + 1 ], b1 = ( float ) b [i + 1 ];
305- float a2 = ( float ) a [i + 2 ], b2 = ( float ) b [i + 2 ];
306- float a3 = ( float ) a [i + 3 ], b3 = ( float ) b [i + 3 ];
305+ int32_t a0 = a [i + 0 ], b0 = b [i + 0 ];
306+ int32_t a1 = a [i + 1 ], b1 = b [i + 1 ];
307+ int32_t a2 = a [i + 2 ], b2 = b [i + 2 ];
308+ int32_t a3 = a [i + 3 ], b3 = b [i + 3 ];
307309
308310 dot += a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3 ;
309311 norm_a2 += a0 * a0 + a1 * a1 + a2 * a2 + a3 * a3 ;
310312 norm_b2 += b0 * b0 + b1 * b1 + b2 * b2 + b3 * b3 ;
311313 }
312314
315+ // tail loop
313316 for (; i < n ; ++ i ) {
314- float ai = ( float ) a [i ];
315- float bi = ( float ) b [i ];
317+ int32_t ai = a [i ];
318+ int32_t bi = b [i ];
316319 dot += ai * bi ;
317320 norm_a2 += ai * ai ;
318321 norm_b2 += bi * bi ;
319322 }
320323
321- if (norm_a2 == 0.0f || norm_b2 == 0.0f ) {
324+ if (norm_a2 == 0 || norm_b2 == 0 ) {
322325 return 1.0f ;
323326 }
324327
325- float cosine_sim = dot / (sqrtf (norm_a2 ) * sqrtf (norm_b2 ));
326- return 1.0f - cosine_sim ;
328+ float cosine_similarity = dot / (sqrtf (( float ) norm_a2 ) * sqrtf (( float ) norm_b2 ));
329+ return 1.0f - cosine_similarity ;
327330}
328331
329332float int8_distance_dot_cpu (const void * v1 , const void * v2 , int n ) {
0 commit comments