|
1 | | -use core::{DenseVector, PointId}; |
2 | | -use std::fmt::Error; |
| 1 | +use core::{DbError, DenseVector, IndexedVector, PointId, Similarity}; |
| 2 | + |
| 3 | +pub mod flat; |
3 | 4 |
|
4 | 5 | pub trait VectorIndex { |
5 | | - fn insert(&self, vector: DenseVector) -> Result<(), Error>; |
6 | | - fn delete(&self, point_id: PointId) -> Result<(), Error>; |
7 | | - fn search(&self, query_vector: DenseVector) -> Result<DenseVector, Error>; |
8 | | - // fn build() -> Result<(), Error>; move this to impl for dyn compatibility |
| 6 | + fn insert(&mut self, vector: IndexedVector) -> Result<(), DbError>; |
| 7 | + |
| 8 | + // Returns true if point id existed and is deleted, else returns false |
| 9 | + fn delete(&mut self, point_id: PointId) -> Result<bool, DbError>; |
| 10 | + |
| 11 | + fn search( |
| 12 | + &self, |
| 13 | + query_vector: DenseVector, |
| 14 | + similarity: Similarity, |
| 15 | + k: usize, |
| 16 | + ) -> Result<Vec<PointId>, DbError>; // Return a Vec of ids of closest vectors (length max k) |
| 17 | + |
| 18 | + // fn build() -> Result<(), DbError>; move this to impl for dyn compatibility |
| 19 | +} |
| 20 | + |
| 21 | +/// Distance function to get the distance between two vectors (taken from old version) |
| 22 | +pub fn distance(a: DenseVector, b: DenseVector, dist_type: Similarity) -> f32 { |
| 23 | + assert_eq!(a.len(), b.len()); |
| 24 | + match dist_type { |
| 25 | + Similarity::Euclidean => { |
| 26 | + let score: Vec<f32> = a |
| 27 | + .iter() |
| 28 | + .zip(b.iter()) |
| 29 | + .map(|(&x, &y)| (x - y) * (x - y)) |
| 30 | + .collect(); |
| 31 | + score.iter().sum::<f32>().sqrt() |
| 32 | + } |
| 33 | + Similarity::Manhattan => { |
| 34 | + let score: Vec<f32> = a |
| 35 | + .iter() |
| 36 | + .zip(b.iter()) |
| 37 | + .map(|(&x, &y)| (x - y).abs()) |
| 38 | + .collect(); |
| 39 | + score.iter().sum::<f32>() |
| 40 | + } |
| 41 | + Similarity::Hamming => { |
| 42 | + let score: Vec<f32> = a |
| 43 | + .iter() |
| 44 | + .zip(b.iter()) |
| 45 | + .map(|(&x, &y)| (if x != y { 1f32 } else { 0f32 })) |
| 46 | + .collect(); |
| 47 | + score.iter().sum::<f32>() |
| 48 | + } |
| 49 | + Similarity::Cosine => { |
| 50 | + let p_score: Vec<f32> = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).collect(); |
| 51 | + let p = p_score.iter().sum::<f32>(); |
| 52 | + let q_score: Vec<f32> = a.iter().map(|&n| n * n).collect(); |
| 53 | + let q = q_score.iter().sum::<f32>().sqrt(); |
| 54 | + let r_score: Vec<f32> = b.iter().map(|&n| n * n).collect(); |
| 55 | + let r = r_score.iter().sum::<f32>().sqrt(); |
| 56 | + p / (q * r) |
| 57 | + } |
| 58 | + } |
9 | 59 | } |
10 | 60 |
|
11 | 61 | pub enum IndexType { |
|
0 commit comments