Skip to content

Commit caacef0

Browse files
committed
Add tests and fix hamming distance
1 parent 3c37ca6 commit caacef0

File tree

2 files changed

+185
-1
lines changed

2 files changed

+185
-1
lines changed

crates/index/src/flat.rs

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,187 @@ impl VectorIndex for FlatIndex {
7171
.collect::<Vec<_>>())
7272
}
7373
}
74+
75+
#[cfg(test)]
76+
mod tests {
77+
use super::*;
78+
79+
#[test]
80+
fn test_flat_index_new() {
81+
let index = FlatIndex::new();
82+
assert_eq!(index.index.len(), 0);
83+
}
84+
85+
#[test]
86+
fn test_flat_index_build() {
87+
let vectors = vec![
88+
IndexedVector {
89+
id: 1,
90+
vector: vec![1.0, 2.0, 3.0],
91+
},
92+
IndexedVector {
93+
id: 2,
94+
vector: vec![4.0, 5.0, 6.0],
95+
},
96+
];
97+
let index = FlatIndex::build(vectors.clone());
98+
assert_eq!(index.index, vectors);
99+
}
100+
101+
#[test]
102+
fn test_insert() {
103+
let mut index = FlatIndex::new();
104+
let vector = IndexedVector {
105+
id: 1,
106+
vector: vec![1.0, 2.0, 3.0],
107+
};
108+
109+
assert!(index.insert(vector.clone()).is_ok());
110+
assert_eq!(index.index.len(), 1);
111+
assert_eq!(index.index[0], vector);
112+
}
113+
114+
#[test]
115+
fn test_delete_existing() {
116+
let mut index = FlatIndex::new();
117+
let vector = IndexedVector {
118+
id: 1,
119+
vector: vec![1.0, 2.0, 3.0],
120+
};
121+
index.insert(vector).unwrap();
122+
123+
let result = index.delete(1).unwrap();
124+
assert!(result);
125+
assert_eq!(index.index.len(), 0);
126+
}
127+
128+
#[test]
129+
fn test_delete_non_existing() {
130+
let mut index = FlatIndex::new();
131+
let vector = IndexedVector {
132+
id: 1,
133+
vector: vec![1.0, 2.0, 3.0],
134+
};
135+
index.insert(vector).unwrap();
136+
137+
let result = index.delete(999).unwrap();
138+
assert!(!result);
139+
assert_eq!(index.index.len(), 1);
140+
}
141+
142+
#[test]
143+
fn test_search_euclidean() {
144+
let mut index = FlatIndex::new();
145+
index
146+
.insert(IndexedVector {
147+
id: 1,
148+
vector: vec![1.0, 1.0],
149+
})
150+
.unwrap();
151+
index
152+
.insert(IndexedVector {
153+
id: 2,
154+
vector: vec![2.0, 2.0],
155+
})
156+
.unwrap();
157+
index
158+
.insert(IndexedVector {
159+
id: 3,
160+
vector: vec![10.0, 10.0],
161+
})
162+
.unwrap();
163+
164+
let results = index
165+
.search(vec![0.0, 0.0], Similarity::Euclidean, 2)
166+
.unwrap();
167+
assert_eq!(results, vec![1, 2]);
168+
}
169+
170+
#[test]
171+
fn test_search_cosine() {
172+
let mut index = FlatIndex::new();
173+
index
174+
.insert(IndexedVector {
175+
id: 1,
176+
vector: vec![1.0, 0.0],
177+
})
178+
.unwrap();
179+
index
180+
.insert(IndexedVector {
181+
id: 2,
182+
vector: vec![0.5, 0.5],
183+
})
184+
.unwrap();
185+
index
186+
.insert(IndexedVector {
187+
id: 3,
188+
vector: vec![0.0, 1.0],
189+
})
190+
.unwrap();
191+
192+
let results = index.search(vec![1.0, 1.0], Similarity::Cosine, 2).unwrap();
193+
assert_eq!(results, vec![2, 1]);
194+
}
195+
196+
#[test]
197+
fn test_search_manhattan() {
198+
let mut index = FlatIndex::new();
199+
index
200+
.insert(IndexedVector {
201+
id: 1,
202+
vector: vec![1.0, 1.0],
203+
})
204+
.unwrap();
205+
index
206+
.insert(IndexedVector {
207+
id: 2,
208+
vector: vec![2.0, 2.0],
209+
})
210+
.unwrap();
211+
index
212+
.insert(IndexedVector {
213+
id: 3,
214+
vector: vec![5.0, 5.0],
215+
})
216+
.unwrap();
217+
218+
let results = index
219+
.search(vec![0.0, 0.0], Similarity::Manhattan, 2)
220+
.unwrap();
221+
assert_eq!(results, vec![1, 2]);
222+
}
223+
224+
#[test]
225+
fn test_search_hamming() {
226+
let mut index = FlatIndex::new();
227+
index
228+
.insert(IndexedVector {
229+
id: 1,
230+
vector: vec![1.0, 0.0, 1.0, 0.0],
231+
})
232+
.unwrap();
233+
index
234+
.insert(IndexedVector {
235+
id: 2,
236+
vector: vec![1.0, 0.0, 0.0, 0.0],
237+
})
238+
.unwrap();
239+
index
240+
.insert(IndexedVector {
241+
id: 3,
242+
vector: vec![0.0, 0.0, 0.0, 0.0],
243+
})
244+
.unwrap();
245+
246+
let results = index
247+
.search(vec![1.0, 0.0, 0.0, 0.0], Similarity::Hamming, 2)
248+
.unwrap();
249+
assert_eq!(results, vec![2, 3]);
250+
}
251+
252+
#[test]
253+
fn test_default() {
254+
let index = FlatIndex::default();
255+
assert_eq!(index.index.len(), 0);
256+
}
257+
}

crates/index/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ pub fn distance(a: DenseVector, b: DenseVector, dist_type: Similarity) -> f32 {
4242
let score: Vec<f32> = a
4343
.iter()
4444
.zip(b.iter())
45-
.map(|(&x, &y)| (if x != y { 1f32 } else { 0f32 }))
45+
.map(|(&x, &y)| (if (x - y) > 1e-8 { 1f32 } else { 0f32 }))
4646
.collect();
4747
score.iter().sum::<f32>()
4848
}

0 commit comments

Comments
 (0)