Skip to content

Commit 3c37ca6

Browse files
committed
Implement Flat Index
1 parent 4a50811 commit 3c37ca6

File tree

4 files changed

+140
-12
lines changed

4 files changed

+140
-12
lines changed

crates/core/src/error.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ pub enum DbError {
44
StorageError(String),
55
SerializationError(String),
66
DeserializationError,
7+
IndexError(String),
78
}

crates/core/src/types.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,21 @@ pub struct Point {
2727
pub payload: Option<Payload>,
2828
}
2929

30+
/// Struct which will be stored in the vector index
31+
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
32+
pub struct IndexedVector {
33+
pub id: PointId,
34+
pub vector: DenseVector,
35+
}
36+
37+
#[derive(Copy, Clone)]
38+
pub enum Similarity {
39+
Euclidean,
40+
Manhattan,
41+
Hamming,
42+
Cosine,
43+
}
44+
3045
// Query Vector. Basically the type of query results that can be generated. Not implementing this but referencing here for furture reference
3146
// #[derive(Debug, Clone)]
3247
// pub enum QueryVector {

crates/index/src/flat.rs

Lines changed: 68 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,73 @@
1-
use core::DenseVector;
1+
use core::{DbError, DenseVector, IndexedVector, PointId, Similarity};
22

3-
struct FlatIndex {
4-
index: Vec<DenseVector>
3+
use crate::{distance, VectorIndex};
4+
5+
pub struct FlatIndex {
6+
index: Vec<IndexedVector>,
7+
}
8+
9+
impl FlatIndex {
10+
pub fn new() -> Self {
11+
Self { index: Vec::new() }
12+
}
13+
14+
pub fn build(vectors: Vec<IndexedVector>) -> Self {
15+
FlatIndex { index: vectors }
16+
}
17+
}
18+
19+
impl Default for FlatIndex {
20+
fn default() -> Self {
21+
Self::new()
22+
}
523
}
624

725
impl VectorIndex for FlatIndex {
8-
fn insert() {
9-
26+
fn insert(&mut self, vector: IndexedVector) -> Result<(), DbError> {
27+
self.index.push(vector);
28+
Ok(())
29+
}
30+
31+
fn delete(&mut self, point_id: PointId) -> Result<bool, DbError> {
32+
if let Some(pos) = self.index.iter().position(|vector| vector.id == point_id) {
33+
self.index.remove(pos);
34+
Ok(true)
35+
} else {
36+
Ok(false)
37+
}
38+
}
39+
40+
fn search(
41+
&self,
42+
query_vector: DenseVector,
43+
similarity: Similarity,
44+
k: usize,
45+
) -> Result<Vec<PointId>, DbError> {
46+
let mut scores = self
47+
.index
48+
.iter()
49+
.map(|point| {
50+
(
51+
point.id,
52+
distance(point.vector.clone(), query_vector.clone(), similarity),
53+
)
54+
})
55+
.collect::<Vec<_>>();
56+
57+
// Sorting logic according to type of metric used
58+
match similarity {
59+
Similarity::Euclidean | Similarity::Manhattan | Similarity::Hamming => {
60+
scores.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
61+
}
62+
Similarity::Cosine => {
63+
scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
64+
}
65+
}
66+
67+
Ok(scores
68+
.into_iter()
69+
.take(k)
70+
.map(|(id, _)| id)
71+
.collect::<Vec<_>>())
1072
}
11-
}
73+
}

crates/index/src/lib.rs

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,61 @@
1-
use core::{DenseVector, PointId};
2-
use std::fmt::Error;
1+
use core::{DbError, DenseVector, IndexedVector, PointId, Similarity};
2+
3+
pub mod flat;
34

45
pub trait VectorIndex {
5-
fn insert(&self, vector: DenseVector) -> Result<(), Error>;
6-
fn delete(&self, point_id: PointId) -> Result<(), Error>;
7-
fn search(&self, query_vector: DenseVector) -> Result<DenseVector, Error>;
8-
// fn build() -> Result<(), Error>; move this to impl for dyn compatibility
6+
fn insert(&mut self, vector: IndexedVector) -> Result<(), DbError>;
7+
8+
// Returns true if point id existed and is deleted, else returns false
9+
fn delete(&mut self, point_id: PointId) -> Result<bool, DbError>;
10+
11+
fn search(
12+
&self,
13+
query_vector: DenseVector,
14+
similarity: Similarity,
15+
k: usize,
16+
) -> Result<Vec<PointId>, DbError>; // Return a Vec of ids of closest vectors (length max k)
17+
18+
// fn build() -> Result<(), DbError>; move this to impl for dyn compatibility
19+
}
20+
21+
/// Distance function to get the distance between two vectors (taken from old version)
22+
pub fn distance(a: DenseVector, b: DenseVector, dist_type: Similarity) -> f32 {
23+
assert_eq!(a.len(), b.len());
24+
match dist_type {
25+
Similarity::Euclidean => {
26+
let score: Vec<f32> = a
27+
.iter()
28+
.zip(b.iter())
29+
.map(|(&x, &y)| (x - y) * (x - y))
30+
.collect();
31+
score.iter().sum::<f32>().sqrt()
32+
}
33+
Similarity::Manhattan => {
34+
let score: Vec<f32> = a
35+
.iter()
36+
.zip(b.iter())
37+
.map(|(&x, &y)| (x - y).abs())
38+
.collect();
39+
score.iter().sum::<f32>()
40+
}
41+
Similarity::Hamming => {
42+
let score: Vec<f32> = a
43+
.iter()
44+
.zip(b.iter())
45+
.map(|(&x, &y)| (if x != y { 1f32 } else { 0f32 }))
46+
.collect();
47+
score.iter().sum::<f32>()
48+
}
49+
Similarity::Cosine => {
50+
let p_score: Vec<f32> = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).collect();
51+
let p = p_score.iter().sum::<f32>();
52+
let q_score: Vec<f32> = a.iter().map(|&n| n * n).collect();
53+
let q = q_score.iter().sum::<f32>().sqrt();
54+
let r_score: Vec<f32> = b.iter().map(|&n| n * n).collect();
55+
let r = r_score.iter().sum::<f32>().sqrt();
56+
p / (q * r)
57+
}
58+
}
959
}
1060

1161
pub enum IndexType {

0 commit comments

Comments
 (0)