diff --git a/Cargo.lock b/Cargo.lock index ba5b23b..cf44b74 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -68,9 +68,9 @@ checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" [[package]] name = "axum" -version = "0.8.7" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b098575ebe77cb6d14fc7f32749631a6e44edbef6b796f89b020e99ba20d425" +checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" dependencies = [ "axum-core", "bytes", @@ -101,9 +101,9 @@ dependencies = [ [[package]] name = "axum-core" -version = "0.5.5" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59446ce19cd142f8833f856eb31f3eb097812d1479ab224f54d72428ca21ea22" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" dependencies = [ "bytes", "futures-core", @@ -238,9 +238,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.49" +version = "1.2.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90583009037521a116abf44494efecd645ba48b6622457080f080b85544e2215" +checksum = "7a0aeaff4ff1a90589618835a598e545176939b97874f7abc7851caa0618f203" dependencies = [ "find-msvc-tools", "jobserver", @@ -429,9 +429,9 @@ checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" [[package]] name = "find-msvc-tools" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a3076410a55c90011c298b04d0cfa770b00fa04e1e3c97d3f6c9de105a03844" +checksum = "645cbb3a84e60b7531617d5ae4e57f7e27308f6445f5abf653209ea76dec8dff" [[package]] name = "fixedbitset" @@ -948,9 +948,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.15" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" +checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" [[package]] name = "jobserver" @@ -1512,9 +1512,9 @@ checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" [[package]] name = "reqwest" -version = "0.12.26" +version = "0.12.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b4c14b2d9afca6a60277086b0cc6a6ae0b568f6f7916c943a8cdc79f8be240f" +checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" dependencies = [ "base64", "bytes", @@ -1597,9 +1597,9 @@ checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" [[package]] name = "rustix" -version = "1.1.2" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e" +checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" dependencies = [ "bitflags 2.10.0", "errno", @@ -1649,9 +1649,9 @@ checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" [[package]] name = "ryu" -version = "1.0.20" +version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +checksum = "a50f4cf475b65d88e057964e0e9bb1f0aa9bbb2036dc65c64596b42932536984" [[package]] name = "schannel" @@ -1723,15 +1723,15 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.145" +version = "1.0.148" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" +checksum = "3084b546a1dd6289475996f182a22aba973866ea8e8b02c51d9f46b1336a22da" dependencies = [ "itoa", "memchr", - "ryu", "serde", "serde_core", + "zmij", ] [[package]] @@ -1819,10 +1819,11 @@ dependencies = [ [[package]] name = "signal-hook-registry" -version = "1.4.7" +version = "1.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7664a098b8e616bdfcc2dc0e9ac44eb231eedf41db4e9fe95d8d32ec728dedad" +checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" dependencies = [ + "errno", "libc", ] @@ -1963,9 +1964,9 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.23.0" +version = "3.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d31c77bdf42a745371d260a26ca7163f1e0924b64afa0b688e61b5a9fa02f16" +checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c" dependencies = [ "fastrand", "getrandom 0.3.4", @@ -2824,6 +2825,12 @@ dependencies = [ "syn", ] +[[package]] +name = "zmij" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d6085d62852e35540689d1f97ad663e3971fc19cf5eceab364d62c646ea167" + [[package]] name = "zstd-sys" version = "2.0.16+zstd.1.5.7" diff --git a/crates/defs/src/error.rs b/crates/defs/src/error.rs index 8f929c3..6ebe861 100644 --- a/crates/defs/src/error.rs +++ b/crates/defs/src/error.rs @@ -7,6 +7,8 @@ pub enum DbError { DeserializationError, IndexError(String), LockError, + IndexInitError, //TODO: Change this + UnsupportedSimilarity, DimensionMismatch, } diff --git a/crates/index/src/flat.rs b/crates/index/src/flat.rs index a3a62bf..c0910e3 100644 --- a/crates/index/src/flat.rs +++ b/crates/index/src/flat.rs @@ -47,7 +47,7 @@ impl VectorIndex for FlatIndex { .index .iter() .map(|point| DistanceOrderedVector { - distance: distance(point.vector.clone(), query_vector.clone(), similarity), + distance: distance(&point.vector, &query_vector, similarity), query_vector: &query_vector, point_id: Some(point.id), }) diff --git a/crates/index/src/kd_tree.rs b/crates/index/src/kd_tree.rs deleted file mode 100644 index 444f578..0000000 --- a/crates/index/src/kd_tree.rs +++ /dev/null @@ -1,252 +0,0 @@ -use std::cmp::Ordering; -use std::cmp::Ordering::Less; - -use serde_derive::{Deserialize, Serialize}; - -#[derive(Serialize, Deserialize)] -pub struct KDTreeInternals { - pub kd_tree_allow_update: bool, - pub current_number_of_kd_tree_nodes: usize, - pub rebuild_threshold: f32, - pub previous_tree_size: usize, - pub rebuild_counter: usize, -} - -#[derive(Serialize, Deserialize)] -pub struct KDTreeNode { - pub left: Option>, - pub right: Option>, - pub key: String, - pub vector: Vec, - pub dim: usize, -} - -impl KDTreeNode { - // Add the logic here to create a new db and insert the tree into the database - fn new(data: (String, Vec), dim: usize) -> KDTreeNode { - KDTreeNode { - left: None, - right: None, - key: data.0, - vector: data.1, - dim, - } - } -} - -pub struct KDTree { - pub _root: Option>, - pub _internals: KDTreeInternals, - pub is_debug_run: bool, - pub dim: usize, -} - -impl KDTree { - // Create an empty tree with default values - pub fn new() -> KDTree { - KDTree { - _root: None, - _internals: KDTreeInternals { - kd_tree_allow_update: true, - current_number_of_kd_tree_nodes: 0, - rebuild_threshold: 2.0f32, - previous_tree_size: 0, - rebuild_counter: 0, - }, - is_debug_run: true, - dim: 0, - } - } - - // Add a node - // If the dimension of the tree is zero, then it becomes equal to the input data - pub fn add_node(&mut self, data: (String, Vec), depth: usize) { - if self._root.is_none() { - self.dim = data.1.len(); - self._root = Some(Box::new(KDTreeNode::new(data, 0))); - self._internals.current_number_of_kd_tree_nodes += 1; - return; - } - - assert_eq!(self.dim, data.1.len()); - - if !self._internals.kd_tree_allow_update { - println!("KDTree is locked for rebuild"); - return; - } - - if self._internals.previous_tree_size != 0 { - let current_ratio: f32 = self._internals.current_number_of_kd_tree_nodes as f32 - / self._internals.previous_tree_size as f32; - if current_ratio > self._internals.rebuild_threshold { - self._internals.previous_tree_size = - self._internals.current_number_of_kd_tree_nodes; - self.rebuild(); - } - } else { - self._internals.previous_tree_size = self._internals.current_number_of_kd_tree_nodes; - } - - self._internals.current_number_of_kd_tree_nodes += 1; - - let mut current_node = self._root.as_deref_mut().unwrap(); - let mut current_depth = depth; - loop { - let current_dimension = current_depth % self.dim; - if data.1[current_dimension] < current_node.vector[current_dimension] { - if current_node.left.is_none() { - current_node.left = Some(Box::new(KDTreeNode::new(data, current_dimension))); - break; - } else { - current_node = current_node.left.as_deref_mut().unwrap(); - current_depth += 1; - } - } else { - if current_node.right.is_none() { - current_node.right = Some(Box::new(KDTreeNode::new(data, current_dimension))); - break; - } else { - current_node = current_node.right.as_deref_mut().unwrap(); - current_depth += 1; - } - } - } - } - - // rebuild tree - fn rebuild(&mut self) { - self._internals.kd_tree_allow_update = false; - self._internals.rebuild_counter += 1; - if self.is_debug_run { - println!( - "Rebuilding tree..., Rebuild counter: {:?}", - self._internals.rebuild_counter - ); - } - let mut points = Vec::into_boxed_slice(self.traversal(0)); - self._root = Some(Box::new(create_tree_helper(points.as_mut(), 0))); - self._internals.kd_tree_allow_update = true; - } - - // traversal - pub fn traversal(&self, k_value: usize) -> Vec<(String, Vec)> { - let mut result: Vec<(String, Vec)> = Vec::new(); - inorder_traversal_helper(self._root.as_deref(), &mut result, k_value); - result - } - - // delete a node - pub fn delete_node(&mut self, data: String) { - self._internals.kd_tree_allow_update = false; - let mut points = self.traversal(0); - let index = points.iter().position(|x| *x.0 == data).unwrap(); - points.remove(index); - let mut points = Vec::into_boxed_slice(points); - self._root = Some(Box::new(create_tree_helper(points.as_mut(), 0))); - self._internals.kd_tree_allow_update = true; - } - - // print data for debug - pub fn print_tree_for_debug(&self) { - let iterated: Vec<(String, Vec)> = self.traversal(0); - for iter in iterated { - println!("{}", iter.0); - } - } - - // different methods of knn -} - -// Traversal helper function -fn inorder_traversal_helper( - node: Option<&KDTreeNode>, - result: &mut Vec<(String, Vec)>, - k_value: usize, -) -> Option { - if node.is_none() { - return None; - } - if k_value != 0 && k_value <= result.len() { - return None; - } - let current_node = node.unwrap(); - inorder_traversal_helper(current_node.to_owned().left.as_deref(), result, k_value); - result.push((current_node.key.clone(), current_node.vector.clone())); - inorder_traversal_helper(current_node.to_owned().right.as_deref(), result, k_value); - - Some(true) -} - -// Rebuild tree helper functions -fn create_tree_helper(points: &mut [(String, Vec)], dim: usize) -> KDTreeNode { - let points_len = points.len(); - if points_len == 1 { - return KDTreeNode { - key: points[0].0.clone(), - vector: points[0].1.clone(), - left: None, - right: None, - dim, - }; - } - - // Split around the median - let pivot = quickselect_by(points, points_len / 2, &|a, b| { - a.1[dim].partial_cmp(&b.1[dim]).unwrap() - }); - - let left = Some(Box::new(create_tree_helper( - &mut points[0..points_len / 2], - (dim + 1) % pivot.1.len(), - ))); - let right = if points.len() >= 3 { - Some(Box::new(create_tree_helper( - &mut points[points_len / 2 + 1..points_len], - (dim + 1) % pivot.1.len(), - ))) - } else { - None - }; - - KDTreeNode { - key: pivot.0, - vector: pivot.1, - left, - right, - dim, - } -} - -fn quickselect_by(arr: &mut [T], position: usize, cmp: &dyn Fn(&T, &T) -> Ordering) -> T -where - T: Clone, -{ - let mut pivot_index = 0; - // Need to wrap in another closure or we get ownership complaints. - // Tried using an unboxed closure to get around this but couldn't get it to work. - pivot_index = partition_by(arr, pivot_index, &|a: &T, b: &T| cmp(a, b)); - let array_len = arr.len(); - match position.cmp(&pivot_index) { - Ordering::Equal => arr[position].clone(), - Ordering::Less => quickselect_by(&mut arr[0..pivot_index], position, cmp), - Ordering::Greater => quickselect_by( - &mut arr[pivot_index + 1..array_len], - position - pivot_index - 1, - cmp, - ), - } -} - -fn partition_by(arr: &mut [T], pivot_index: usize, cmp: &dyn Fn(&T, &T) -> Ordering) -> usize { - let array_len = arr.len(); - arr.swap(pivot_index, array_len - 1); - let mut store_index = 0; - for i in 0..array_len - 1 { - if cmp(&arr[i], &arr[array_len - 1]) == Less { - arr.swap(i, store_index); - store_index += 1; - } - } - arr.swap(array_len - 1, store_index); - store_index -} diff --git a/crates/index/src/kd_tree/index.rs b/crates/index/src/kd_tree/index.rs new file mode 100644 index 0000000..dc623a2 --- /dev/null +++ b/crates/index/src/kd_tree/index.rs @@ -0,0 +1,466 @@ +use super::types::{KDTreeNode, Neighbor}; +use crate::{VectorIndex, distance}; +use defs::{DbError, DenseVector, IndexedVector, PointId, Similarity}; +use std::{ + cmp::Ordering, + collections::{BinaryHeap, HashSet}, + vec, +}; +use uuid::Uuid; + +pub struct KDTree { + pub dim: usize, + pub root: Option>, + // In memory point ids, to check existence before O(n) deletion logic + pub point_ids: HashSet, + // Rebuild tracking + pub total_nodes: usize, + pub deleted_count: usize, +} + +impl KDTree { + // Rebuild threshold + const BALANCE_THRESHOLD: f32 = 0.7; + const DELETE_REBUILD_RATIO: f32 = 0.25; + + // Build an empty index with no points + pub fn build_empty(dim: usize) -> Self { + KDTree { + dim, + root: None, + point_ids: HashSet::new(), + total_nodes: 0, + deleted_count: 0, + } + } + + // Builds the vector index from provided vectors, there should atleast be single vector for dim calculation + pub fn build(mut vectors: Vec) -> Result { + if vectors.is_empty() { + Err(DbError::IndexInitError) + } else { + let dim = vectors[0].vector.len(); + + let mut point_ids = HashSet::with_capacity(vectors.len()); + for indexed_vector in vectors.iter() { + point_ids.insert(indexed_vector.id); + } + + let root_node = Self::build_recursive(&mut vectors, 0, dim); + Ok(KDTree { + dim, + root: Some(root_node), + point_ids, + total_nodes: vectors.len(), + deleted_count: 0, + }) + } + } + + // Builds the tree recursively with given vectors and returns the pointer of the root node + pub fn build_recursive( + vectors: &mut [IndexedVector], + depth: usize, + dim: usize, + ) -> Box { + if vectors.is_empty() { + panic!("Cannot build from an empty slice recursively"); + } + + let axis = depth % dim; + let mid_idx = vectors.len() / 2; + + vectors.select_nth_unstable_by(mid_idx, |a, b| { + let a_at_axis = a.vector[axis]; + let b_at_axis = b.vector[axis]; + a_at_axis.partial_cmp(&b_at_axis).unwrap_or(Ordering::Equal) + }); + + // Using swap so that we don't need to clone the whole vector + let mut median_vec = IndexedVector { + id: Uuid::new_v4(), + vector: vec![], + }; // dummy + std::mem::swap(&mut vectors[mid_idx], &mut median_vec); + + let (left_points, right_points_with_median) = vectors.split_at_mut(mid_idx); + let right_points = &mut right_points_with_median[1..]; // Exclude the swapped-out median + + let left = if left_points.is_empty() { + None + } else { + Some(Self::build_recursive(left_points, depth + 1, dim)) + }; + + let right = if right_points.is_empty() { + None + } else { + Some(Self::build_recursive(right_points, depth + 1, dim)) + }; + + let left_size = left.as_ref().map_or(0, |n| n.subtree_size); + let right_size = right.as_ref().map_or(0, |n| n.subtree_size); + + Box::new(KDTreeNode { + indexed_vector: median_vec, + left, + right, + is_deleted: false, + subtree_size: left_size + right_size + 1, + }) + } + + pub fn insert_point(&mut self, new_vector: IndexedVector) { + // Add to point_ids + self.point_ids.insert(new_vector.id); + self.total_nodes += 1; + + // use a traverse function to get the final leaf where this belongs + if self.root.is_none() { + self.root = Some(Box::new(KDTreeNode { + indexed_vector: new_vector, + left: None, + right: None, + is_deleted: false, + subtree_size: 1, + })); + return; + } + + let mut path: Vec<(usize, bool)> = Vec::new(); + let dim = self.dim; + + let mut current_link = &mut self.root; + let mut depth = 0; + + while let Some(node_box) = current_link { + let axis = depth % dim; + let current_node = node_box.as_mut(); + + current_node.subtree_size += 1; + + let va = new_vector.vector[axis]; + let vb = current_node.indexed_vector.vector[axis]; + + let go_left = va <= vb; + path.push((depth, go_left)); + + if go_left { + current_link = &mut current_node.left; + } else { + current_link = &mut current_node.right; + } + depth += 1; + } + + // Assign the new node to current link which is &mut Option> + let new_node = Box::new(KDTreeNode { + indexed_vector: new_vector, + left: None, + right: None, + is_deleted: false, + subtree_size: 1, + }); + + *current_link = Some(new_node); + + self.check_and_rebalance(&path); + } + + // Rebuild helper methods + fn is_unbalanced(node: &KDTreeNode) -> bool { + let left_size = node.left.as_ref().map_or(0, |n| n.subtree_size); + let right_size = node.right.as_ref().map_or(0, |n| n.subtree_size); + let max_child = left_size.max(right_size); + + max_child as f32 > Self::BALANCE_THRESHOLD * node.subtree_size as f32 + } + + fn collect_recursive(node: KDTreeNode, result: &mut Vec) { + if !node.is_deleted { + result.push(node.indexed_vector); + } + if let Some(left) = node.left { + Self::collect_recursive(*left, result); + } + if let Some(right) = node.right { + Self::collect_recursive(*right, result); + } + } + + fn collect_active_vectors(node: KDTreeNode) -> Vec { + let mut result = Vec::with_capacity(node.subtree_size); + Self::collect_recursive(node, &mut result); + result + } + + fn rebuild_at_depth(&mut self, path: &[(usize, bool)], target_depth: usize) { + let dim = self.dim; + + // Navigate to parent of target node + if target_depth == 0 { + // Rebuild root + if let Some(root) = self.root.take() { + let old_size = root.subtree_size; + let mut vectors = Self::collect_active_vectors(*root); + let new_size = vectors.len(); + if !vectors.is_empty() { + self.root = Some(Self::build_recursive(&mut vectors, 0, dim)); + } + // Update global counts as deleted nodes were purged + self.total_nodes -= old_size - new_size; + self.deleted_count = 0; + } + } else { + // Navigate to target node + let mut current_link = &mut self.root; + for (_depth, go_left) in path.iter().take(target_depth) { + let node = current_link.as_mut().unwrap(); + current_link = if *go_left { + &mut node.left + } else { + &mut node.right + }; + } + + // Rebuild tree at current link + if let Some(subtree_root) = current_link.take() { + let old_size = subtree_root.subtree_size; + let mut vectors = Self::collect_active_vectors(*subtree_root); + let new_size = vectors.len(); + + if !vectors.is_empty() { + *current_link = Some(Self::build_recursive(&mut vectors, target_depth, dim)); + } + + // Only update ancestors if size changed (deleted nodes were purged) + if old_size != new_size { + let size_diff = old_size - new_size; + self.subtract_size_from_ancestors(path, target_depth, size_diff); + + self.total_nodes -= size_diff; + self.deleted_count = self.deleted_count.saturating_sub(size_diff); + } + } + } + } + + fn subtract_size_from_ancestors( + &mut self, + path: &[(usize, bool)], + up_to_depth: usize, + diff: usize, + ) { + let mut current = &mut self.root; + for (_, go_left) in path.iter().take(up_to_depth) { + if let Some(node) = current { + node.subtree_size -= diff; + current = if *go_left { + &mut node.left + } else { + &mut node.right + }; + } + } + } + + fn check_and_rebalance(&mut self, path: &[(usize, bool)]) { + // Find the shallowest (closest to root) depth where imbalance occurs + // so that rebuilding fixes the largest unbalanced subtree + let mut unbalanced_depth: Option = None; + + let mut current = self.root.as_ref(); + + // Check root first (depth 0) + if let Some(node) = current + && Self::is_unbalanced(node) + { + unbalanced_depth = Some(0); + } + + // Then traverse the path and check each node + // Once we find the shallowest unbalanced node, break immediately + for (idx, (_depth, go_left)) in path.iter().enumerate() { + if unbalanced_depth.is_some() { + break; + } + + if let Some(node) = current { + current = if *go_left { + node.left.as_ref() + } else { + node.right.as_ref() + }; + + // Check the child node we just moved to (at depth idx + 1) + if let Some(child) = current + && Self::is_unbalanced(child) + { + unbalanced_depth = Some(idx + 1); + break; + } + } + } + + if let Some(target_depth) = unbalanced_depth { + self.rebuild_at_depth(path, target_depth); + } + } + + fn should_rebuild_global(&self) -> bool { + self.total_nodes > 0 + && (self.deleted_count as f32 / self.total_nodes as f32) > Self::DELETE_REBUILD_RATIO + } + + // Returns true if point found and deleted, else false + pub fn delete_point(&mut self, point_id: &PointId) -> bool { + if self.point_ids.contains(point_id) { + let deleted = Self::find_and_mark_deleted(&mut self.root, *point_id); + if deleted { + self.deleted_count += 1; + self.point_ids.remove(point_id); + } + + if Self::should_rebuild_global(self) + && let Some(root) = self.root.take() + { + let mut vectors = Self::collect_active_vectors(*root); + if !vectors.is_empty() { + self.root = Some(Self::build_recursive(&mut vectors, 0, self.dim)); + } + + self.total_nodes = vectors.len(); + self.deleted_count = 0; + } + + return deleted; + } + false + } + + fn find_and_mark_deleted(node_opt: &mut Option>, target_id: PointId) -> bool { + if let Some(node) = node_opt { + if node.indexed_vector.id == target_id { + node.is_deleted = true; + return true; + } + + // Search left first then right + Self::find_and_mark_deleted(&mut node.left, target_id) + || Self::find_and_mark_deleted(&mut node.right, target_id) + } else { + false + } + } + + pub fn search_top_k( + &self, + query_vector: DenseVector, + k: usize, + dist_type: Similarity, + ) -> Vec<(PointId, f32)> { + //Searches for top k closest vectors according to specified metric + + if self.root.is_none() || k == 0 { + return Vec::new(); + } + + let mut best_neighbours = BinaryHeap::with_capacity(k); + + self.search_recursive( + &self.root, + &query_vector, + k, + &mut best_neighbours, + 0, + dist_type, + ); + + best_neighbours + .into_sorted_vec() + .iter() + .map(|neighbor| (neighbor.id, neighbor.distance)) + .collect() + } + + fn search_recursive( + &self, + node_opt: &Option>, + query_vector: &DenseVector, + k: usize, + heap: &mut BinaryHeap, + depth: usize, + dist_type: Similarity, + ) { + // Base case is that we hit a leaf node don't do anything + if let Some(node) = node_opt { + let axis = depth % self.dim; + + let (near_side, far_side) = if query_vector[axis] <= node.indexed_vector.vector[axis] { + (&node.left, &node.right) + } else { + (&node.right, &node.left) + }; + + // Recurse on near side first + self.search_recursive(near_side, query_vector, k, heap, depth + 1, dist_type); + + if !node.is_deleted { + // TODO: Possible overhead, here heap stores sqrt euclidean distance, we can eliminate that by storing squared distances in case of euclidean + let distance = distance(query_vector, &node.indexed_vector.vector, dist_type); + if heap.len() < k { + heap.push(Neighbor { + id: node.indexed_vector.id, + distance, + }); + } else if distance < heap.peek().unwrap().distance { + heap.pop(); + heap.push(Neighbor { + id: node.indexed_vector.id, + distance, + }); + } + } + + // Pruning on the farther side to check if there are better candidates + // Use <= to handle ties: when axis_diff == current worst distance, there could be + // a point on the far side with the same distance that should be included + let axis_diff = (query_vector[axis] - node.indexed_vector.vector[axis]).abs(); + let should_search_far = match dist_type { + Similarity::Euclidean | Similarity::Manhattan => { + heap.len() < k || axis_diff <= heap.peek().unwrap().distance + } + _ => true, // Cosine/Hamming - no effective pruning, always search + }; + + if should_search_far { + self.search_recursive(far_side, query_vector, k, heap, depth + 1, dist_type); + } + } + } +} + +impl VectorIndex for KDTree { + fn insert(&mut self, vector: IndexedVector) -> Result<(), DbError> { + self.insert_point(vector); + Ok(()) + } + + fn delete(&mut self, point_id: PointId) -> Result { + Ok(self.delete_point(&point_id)) + } + + fn search( + &self, + query_vector: DenseVector, + similarity: Similarity, + k: usize, + ) -> Result, DbError> { + if matches!(similarity, Similarity::Cosine | Similarity::Hamming) { + return Err(DbError::UnsupportedSimilarity); + } + + let results = self.search_top_k(query_vector, k, similarity); + Ok(results.into_iter().map(|(id, _)| id).collect()) + } +} diff --git a/crates/index/src/kd_tree/mod.rs b/crates/index/src/kd_tree/mod.rs new file mode 100644 index 0000000..8765acf --- /dev/null +++ b/crates/index/src/kd_tree/mod.rs @@ -0,0 +1,5 @@ +pub mod index; +pub mod types; + +#[cfg(test)] +mod tests; diff --git a/crates/index/src/kd_tree/tests.rs b/crates/index/src/kd_tree/tests.rs new file mode 100644 index 0000000..faefae5 --- /dev/null +++ b/crates/index/src/kd_tree/tests.rs @@ -0,0 +1,703 @@ +use super::index::KDTree; +use crate::VectorIndex; +use crate::distance; +use crate::flat::FlatIndex; +use defs::{DbError, IndexedVector, Similarity}; +use std::collections::HashSet; +use uuid::Uuid; + +fn make_vector(vector: Vec) -> IndexedVector { + IndexedVector { + id: Uuid::new_v4(), + vector, + } +} + +fn make_vector_with_id(id: Uuid, vector: Vec) -> IndexedVector { + IndexedVector { id, vector } +} + +// Build Tests + +#[test] +fn test_build_empty() { + let tree = KDTree::build_empty(3); + assert!(tree.root.is_none()); + assert_eq!(tree.dim, 3); + assert_eq!(tree.total_nodes, 0); + assert!(tree.point_ids.is_empty()); +} + +#[test] +fn test_build_with_empty_vectors_returns_error() { + let result = KDTree::build(vec![]); + assert!(result.is_err()); +} + +#[test] +fn test_build_single_vector() { + let id = Uuid::new_v4(); + let vectors = vec![make_vector_with_id(id, vec![1.0, 2.0, 3.0])]; + let tree = KDTree::build(vectors).unwrap(); + + assert!(tree.root.is_some()); + assert_eq!(tree.dim, 3); + assert_eq!(tree.total_nodes, 1); + assert!(tree.point_ids.contains(&id)); +} + +#[test] +fn test_build_multiple_vectors() { + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let id3 = Uuid::new_v4(); + let vectors = vec![ + make_vector_with_id(id1, vec![1.0, 2.0]), + make_vector_with_id(id2, vec![3.0, 4.0]), + make_vector_with_id(id3, vec![5.0, 6.0]), + ]; + let tree = KDTree::build(vectors).unwrap(); + + assert!(tree.root.is_some()); + assert_eq!(tree.dim, 2); + assert_eq!(tree.total_nodes, 3); + assert!(tree.point_ids.contains(&id1)); + assert!(tree.point_ids.contains(&id2)); + assert!(tree.point_ids.contains(&id3)); +} + +// Insert Tests + +#[test] +fn test_insert_into_empty_tree() { + let mut tree = KDTree::build_empty(2); + let id = Uuid::new_v4(); + let vector = make_vector_with_id(id, vec![1.0, 2.0]); + + let result = tree.insert(vector); + assert!(result.is_ok()); + assert_eq!(tree.total_nodes, 1); + assert!(tree.point_ids.contains(&id)); + assert!(tree.root.is_some()); +} + +#[test] +fn test_insert_multiple_vectors() { + let mut tree = KDTree::build_empty(2); + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let id3 = Uuid::new_v4(); + + tree.insert(make_vector_with_id(id1, vec![1.0, 2.0])) + .unwrap(); + tree.insert(make_vector_with_id(id2, vec![3.0, 4.0])) + .unwrap(); + tree.insert(make_vector_with_id(id3, vec![5.0, 6.0])) + .unwrap(); + + assert_eq!(tree.total_nodes, 3); + assert!(tree.point_ids.contains(&id1)); + assert!(tree.point_ids.contains(&id2)); + assert!(tree.point_ids.contains(&id3)); +} + +// Delete Tests + +#[test] +fn test_delete_existing_point() { + let mut ids = Vec::new(); + let mut vectors = Vec::new(); + + // Create enough vectors so deleting one doesn't trigger global rebuild + for i in 0..10 { + let id = Uuid::new_v4(); + ids.push(id); + vectors.push(make_vector_with_id(id, vec![i as f32, i as f32])); + } + + let mut tree = KDTree::build(vectors).unwrap(); + + let result = tree.delete(ids[0]).unwrap(); + assert!(result); + assert!(!tree.point_ids.contains(&ids[0])); + assert_eq!(tree.deleted_count, 1); +} + +#[test] +fn test_delete_non_existing_point() { + let id1 = Uuid::new_v4(); + let vectors = vec![make_vector_with_id(id1, vec![1.0, 2.0])]; + let mut tree = KDTree::build(vectors).unwrap(); + + let non_existing_id = Uuid::new_v4(); + let result = tree.delete(non_existing_id).unwrap(); + assert!(!result); + assert_eq!(tree.deleted_count, 0); +} + +#[test] +fn test_delete_from_empty_tree() { + let mut tree = KDTree::build_empty(2); + let result = tree.delete(Uuid::new_v4()).unwrap(); + assert!(!result); +} + +#[test] +fn test_deleted_point_not_in_search_results() { + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let id3 = Uuid::new_v4(); + let vectors = vec![ + make_vector_with_id(id1, vec![0.0, 0.0]), + make_vector_with_id(id2, vec![1.0, 1.0]), + make_vector_with_id(id3, vec![10.0, 10.0]), + ]; + let mut tree = KDTree::build(vectors).unwrap(); + + // Delete the closest point + tree.delete(id1).unwrap(); + + // Search should not return the deleted point + let results = tree + .search(vec![0.0, 0.0], Similarity::Euclidean, 2) + .unwrap(); + assert!(!results.contains(&id1)); + assert!(results.contains(&id2)); +} + +// Search Tests (VectorIndex trait) + +#[test] +fn test_search_empty_tree() { + let tree = KDTree::build_empty(2); + let results = tree + .search(vec![1.0, 2.0], Similarity::Euclidean, 5) + .unwrap(); + assert!(results.is_empty()); +} + +#[test] +fn test_search_euclidean() { + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let id3 = Uuid::new_v4(); + let vectors = vec![ + make_vector_with_id(id1, vec![1.0, 1.0]), + make_vector_with_id(id2, vec![2.0, 2.0]), + make_vector_with_id(id3, vec![10.0, 10.0]), + ]; + let tree = KDTree::build(vectors).unwrap(); + + let results = tree + .search(vec![0.0, 0.0], Similarity::Euclidean, 2) + .unwrap(); + assert_eq!(results.len(), 2); + assert_eq!(results[0], id1); // Closest + assert_eq!(results[1], id2); // Second closest +} + +#[test] +fn test_search_manhattan() { + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let id3 = Uuid::new_v4(); + let vectors = vec![ + make_vector_with_id(id1, vec![1.0, 1.0]), + make_vector_with_id(id2, vec![2.0, 2.0]), + make_vector_with_id(id3, vec![5.0, 5.0]), + ]; + let tree = KDTree::build(vectors).unwrap(); + + let results = tree + .search(vec![0.0, 0.0], Similarity::Manhattan, 2) + .unwrap(); + assert_eq!(results.len(), 2); + assert_eq!(results[0], id1); + assert_eq!(results[1], id2); +} + +#[test] +fn test_search_unsupported_similarity_cosine() { + let vectors = vec![make_vector(vec![1.0, 2.0])]; + let tree = KDTree::build(vectors).unwrap(); + + let result = tree.search(vec![1.0, 2.0], Similarity::Cosine, 1); + assert!(matches!(result, Err(DbError::UnsupportedSimilarity))); +} + +#[test] +fn test_search_unsupported_similarity_hamming() { + let vectors = vec![make_vector(vec![1.0, 2.0])]; + let tree = KDTree::build(vectors).unwrap(); + + let result = tree.search(vec![1.0, 2.0], Similarity::Hamming, 1); + assert!(matches!(result, Err(DbError::UnsupportedSimilarity))); +} + +#[test] +fn test_search_k_larger_than_tree_size() { + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let vectors = vec![ + make_vector_with_id(id1, vec![1.0, 1.0]), + make_vector_with_id(id2, vec![2.0, 2.0]), + ]; + let tree = KDTree::build(vectors).unwrap(); + + let results = tree + .search(vec![0.0, 0.0], Similarity::Euclidean, 10) + .unwrap(); + assert_eq!(results.len(), 2); // Should return all available points +} + +#[test] +fn test_search_exact_match() { + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let vectors = vec![ + make_vector_with_id(id1, vec![5.0, 5.0]), + make_vector_with_id(id2, vec![10.0, 10.0]), + ]; + let tree = KDTree::build(vectors).unwrap(); + + let results = tree + .search(vec![5.0, 5.0], Similarity::Euclidean, 1) + .unwrap(); + assert_eq!(results.len(), 1); + assert_eq!(results[0], id1); +} + +// Search Correctness Tests + +#[test] +fn test_search_correctness_3d() { + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let id3 = Uuid::new_v4(); + let id4 = Uuid::new_v4(); + let vectors = vec![ + make_vector_with_id(id1, vec![0.0, 0.0, 0.0]), + make_vector_with_id(id2, vec![1.0, 1.0, 1.0]), + make_vector_with_id(id3, vec![2.0, 2.0, 2.0]), + make_vector_with_id(id4, vec![10.0, 10.0, 10.0]), + ]; + let tree = KDTree::build(vectors).unwrap(); + + let results = tree + .search(vec![0.5, 0.5, 0.5], Similarity::Euclidean, 2) + .unwrap(); + // id1 at distance sqrt(0.75) ≈ 0.866 + // id2 at distance sqrt(0.75) ≈ 0.866 + // Both are equidistant, should return both + assert_eq!(results.len(), 2); + assert!(results.contains(&id1) || results.contains(&id2)); +} + +#[test] +fn test_search_after_insert() { + let mut tree = KDTree::build_empty(2); + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let id3 = Uuid::new_v4(); + + tree.insert(make_vector_with_id(id1, vec![10.0, 10.0])) + .unwrap(); + tree.insert(make_vector_with_id(id2, vec![1.0, 1.0])) + .unwrap(); + tree.insert(make_vector_with_id(id3, vec![5.0, 5.0])) + .unwrap(); + + let results = tree + .search(vec![0.0, 0.0], Similarity::Euclidean, 2) + .unwrap(); + assert_eq!(results[0], id2); // Closest to origin + assert_eq!(results[1], id3); // Second closest +} + +#[test] +fn test_search_high_dimensional() { + let dim = 10; + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + + let vectors = vec![ + make_vector_with_id(id1, vec![0.0; dim]), + make_vector_with_id(id2, vec![1.0; dim]), + ]; + let tree = KDTree::build(vectors).unwrap(); + + let query = vec![0.1; dim]; + let results = tree.search(query, Similarity::Euclidean, 1).unwrap(); + assert_eq!(results[0], id1); // Closer to all-zeros +} + +// Rebalancing Tests + +#[test] +fn test_many_inserts_maintains_searchability() { + let mut tree = KDTree::build_empty(2); + let mut ids = Vec::new(); + + // Insert many points that would cause imbalance + for i in 0..20 { + let id = Uuid::new_v4(); + ids.push(id); + tree.insert(make_vector_with_id(id, vec![i as f32, i as f32])) + .unwrap(); + } + + // Search should still work correctly + let results = tree + .search(vec![0.0, 0.0], Similarity::Euclidean, 5) + .unwrap(); + assert_eq!(results.len(), 5); + // First result should be the point at (0, 0) + assert_eq!(results[0], ids[0]); +} + +#[test] +fn test_delete_triggers_rebuild() { + let mut ids = Vec::new(); + let mut vectors = Vec::new(); + + for i in 0..10 { + let id = Uuid::new_v4(); + ids.push(id); + vectors.push(make_vector_with_id(id, vec![i as f32, i as f32])); + } + + let mut tree = KDTree::build(vectors).unwrap(); + + // Delete enough points to trigger rebuild (> 25%) + for id in ids.iter().take(3) { + tree.delete(*id).unwrap(); + } + + // Tree should still function correctly + let results = tree + .search(vec![5.0, 5.0], Similarity::Euclidean, 3) + .unwrap(); + assert_eq!(results.len(), 3); + // Deleted points should not appear + for id in ids.iter().take(3) { + assert!(!results.contains(id)); + } +} + +// Edge Cases + +#[test] +fn test_single_point_search() { + let id = Uuid::new_v4(); + let vectors = vec![make_vector_with_id(id, vec![5.0, 5.0])]; + let tree = KDTree::build(vectors).unwrap(); + + let results = tree + .search(vec![0.0, 0.0], Similarity::Euclidean, 1) + .unwrap(); + assert_eq!(results.len(), 1); + assert_eq!(results[0], id); +} + +#[test] +fn test_duplicate_coordinates() { + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let id3 = Uuid::new_v4(); + let vectors = vec![ + make_vector_with_id(id1, vec![1.0, 1.0]), + make_vector_with_id(id2, vec![1.0, 1.0]), // Same coordinates + make_vector_with_id(id3, vec![2.0, 2.0]), + ]; + let tree = KDTree::build(vectors).unwrap(); + + let results = tree + .search(vec![1.0, 1.0], Similarity::Euclidean, 2) + .unwrap(); + assert_eq!(results.len(), 2); + // Both id1 and id2 should be in results (both at distance 0) + assert!(results.contains(&id1) || results.contains(&id2)); +} + +#[test] +fn test_negative_coordinates() { + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let vectors = vec![ + make_vector_with_id(id1, vec![-1.0, -1.0]), + make_vector_with_id(id2, vec![1.0, 1.0]), + ]; + let tree = KDTree::build(vectors).unwrap(); + + let results = tree + .search(vec![-0.5, -0.5], Similarity::Euclidean, 1) + .unwrap(); + assert_eq!(results[0], id1); +} + +#[test] +fn test_search_with_zero_k() { + let vectors = vec![make_vector(vec![1.0, 2.0])]; + let tree = KDTree::build(vectors).unwrap(); + + let results = tree + .search(vec![1.0, 2.0], Similarity::Euclidean, 0) + .unwrap(); + assert!(results.is_empty()); +} + +// Comparison Tests: KDTree vs FlatIndex + +/// Helper to create a fixed set of 10 vectors with known UUIDs for comparison tests +fn create_test_vectors_2d() -> Vec { + let ids: Vec = (0..10).map(|_| Uuid::new_v4()).collect(); + vec![ + make_vector_with_id(ids[0], vec![0.5, 0.5]), + make_vector_with_id(ids[1], vec![2.3, 1.7]), + make_vector_with_id(ids[2], vec![-1.0, 3.0]), + make_vector_with_id(ids[3], vec![4.5, -2.0]), + make_vector_with_id(ids[4], vec![7.0, 7.0]), + make_vector_with_id(ids[5], vec![-3.5, -1.5]), + make_vector_with_id(ids[6], vec![1.0, 5.0]), + make_vector_with_id(ids[7], vec![6.0, 2.0]), + make_vector_with_id(ids[8], vec![-2.0, -4.0]), + make_vector_with_id(ids[9], vec![3.0, 3.0]), + ] +} + +fn create_test_vectors_3d() -> Vec { + let ids: Vec = (0..10).map(|_| Uuid::new_v4()).collect(); + vec![ + make_vector_with_id(ids[0], vec![1.0, 2.0, 3.0]), + make_vector_with_id(ids[1], vec![-1.5, 0.5, 2.0]), + make_vector_with_id(ids[2], vec![4.0, 4.0, 4.0]), + make_vector_with_id(ids[3], vec![0.0, 0.0, 0.0]), + make_vector_with_id(ids[4], vec![2.5, -1.0, 3.5]), + make_vector_with_id(ids[5], vec![-2.0, 3.0, -1.0]), + make_vector_with_id(ids[6], vec![5.0, 1.0, 2.0]), + make_vector_with_id(ids[7], vec![3.0, 3.0, 3.0]), + make_vector_with_id(ids[8], vec![-0.5, -0.5, 1.0]), + make_vector_with_id(ids[9], vec![1.5, 2.5, 0.5]), + ] +} + +/// Helper to verify that two result sets are valid k-nearest neighbor results +/// Both should return the k closest points (by distance), but may differ on tie-breaking +fn verify_same_results( + tree_results: &[Uuid], + flat_results: &[Uuid], + vectors: &[IndexedVector], + query: &[f32], + similarity: Similarity, + k: usize, +) { + // Same length + assert_eq!( + tree_results.len(), + flat_results.len(), + "Result lengths differ" + ); + + // Both should return at most k results + assert!(tree_results.len() <= k); + + // Get distances for all results + let query_vec = query.to_vec(); + let get_distance = |id: &Uuid| -> f32 { + let vec = vectors.iter().find(|v| v.id == *id).unwrap(); + distance(&vec.vector, &query_vec, similarity) + }; + + // Verify tree results are sorted by distance + for i in 1..tree_results.len() { + let d1 = get_distance(&tree_results[i - 1]); + let d2 = get_distance(&tree_results[i]); + assert!( + d1 <= d2 + 1e-6, + "KDTree results not sorted: {} > {}", + d1, + d2 + ); + } + + // Verify flat results are sorted by distance + for i in 1..flat_results.len() { + let d1 = get_distance(&flat_results[i - 1]); + let d2 = get_distance(&flat_results[i]); + assert!(d1 <= d2 + 1e-6, "Flat results not sorted: {} > {}", d1, d2); + } + + // The maximum distance in both result sets should be the same (k-th nearest distance) + if !tree_results.is_empty() { + let tree_max_dist = get_distance(tree_results.last().unwrap()); + let flat_max_dist = get_distance(flat_results.last().unwrap()); + assert!( + (tree_max_dist - flat_max_dist).abs() < 1e-6, + "Max distances differ: tree={}, flat={}", + tree_max_dist, + flat_max_dist + ); + } + + // Verify that for each result in tree_results, either: + // 1. It's also in flat_results, OR + // 2. It has the same distance as the last element (tie-breaking difference) + let flat_set: HashSet<_> = flat_results.iter().collect(); + let flat_max_dist = if flat_results.is_empty() { + 0.0 + } else { + get_distance(flat_results.last().unwrap()) + }; + + for id in tree_results { + if !flat_set.contains(id) { + // This ID is not in flat results, verify it's a tie + let dist = get_distance(id); + assert!( + (dist - flat_max_dist).abs() < 1e-6, + "KDTree returned {:?} with distance {} but it's not in flat results and not a tie (flat max: {})", + id, + dist, + flat_max_dist + ); + } + } + + // Similarly verify flat_results + let tree_set: HashSet<_> = tree_results.iter().collect(); + let tree_max_dist = if tree_results.is_empty() { + 0.0 + } else { + get_distance(tree_results.last().unwrap()) + }; + + for id in flat_results { + if !tree_set.contains(id) { + let dist = get_distance(id); + assert!( + (dist - tree_max_dist).abs() < 1e-6, + "Flat returned {:?} with distance {} but it's not in tree results and not a tie (tree max: {})", + id, + dist, + tree_max_dist + ); + } + } +} + +#[test] +fn test_kdtree_vs_flat_euclidean_2d() { + let vectors = create_test_vectors_2d(); + let tree = KDTree::build(vectors.clone()).unwrap(); + let flat = FlatIndex::build(vectors.clone()); + + // Test multiple query points and different k values + let queries = vec![ + vec![0.0, 0.0], + vec![3.0, 3.0], + vec![-1.0, 2.0], + vec![5.0, 5.0], + ]; + + for query in queries { + for k in [1, 3, 5, 10] { + let tree_results = tree + .search(query.clone(), Similarity::Euclidean, k) + .unwrap(); + let flat_results = flat + .search(query.clone(), Similarity::Euclidean, k) + .unwrap(); + + verify_same_results( + &tree_results, + &flat_results, + &vectors, + &query, + Similarity::Euclidean, + k, + ); + } + } +} + +#[test] +fn test_kdtree_vs_flat_euclidean_3d() { + let vectors = create_test_vectors_3d(); + let tree = KDTree::build(vectors.clone()).unwrap(); + let flat = FlatIndex::build(vectors.clone()); + + let queries = vec![ + vec![0.0, 0.0, 0.0], + vec![2.0, 2.0, 2.0], + vec![-1.0, 1.0, 1.0], + vec![4.0, 3.0, 3.0], + ]; + + for query in queries { + for k in [1, 3, 5, 10] { + let tree_results = tree + .search(query.clone(), Similarity::Euclidean, k) + .unwrap(); + let flat_results = flat + .search(query.clone(), Similarity::Euclidean, k) + .unwrap(); + + verify_same_results( + &tree_results, + &flat_results, + &vectors, + &query, + Similarity::Euclidean, + k, + ); + } + } +} + +#[test] +fn test_kdtree_vs_flat_euclidean_5d() { + // Test with higher dimensionality + let ids: Vec = (0..10).map(|_| Uuid::new_v4()).collect(); + let vectors = vec![ + make_vector_with_id(ids[0], vec![1.0, 2.0, 3.0, 4.0, 5.0]), + make_vector_with_id(ids[1], vec![-1.0, 0.0, 1.0, 2.0, 3.0]), + make_vector_with_id(ids[2], vec![5.0, 4.0, 3.0, 2.0, 1.0]), + make_vector_with_id(ids[3], vec![0.0, 0.0, 0.0, 0.0, 0.0]), + make_vector_with_id(ids[4], vec![2.5, 2.5, 2.5, 2.5, 2.5]), + make_vector_with_id(ids[5], vec![-2.0, -1.0, 0.0, 1.0, 2.0]), + make_vector_with_id(ids[6], vec![3.0, 3.0, 3.0, 3.0, 3.0]), + make_vector_with_id(ids[7], vec![1.0, 1.0, 1.0, 1.0, 1.0]), + make_vector_with_id(ids[8], vec![4.0, 0.0, -1.0, 2.0, 5.0]), + make_vector_with_id(ids[9], vec![-0.5, 1.5, 2.5, 3.5, 4.5]), + ]; + + let tree = KDTree::build(vectors.clone()).unwrap(); + let flat = FlatIndex::build(vectors.clone()); + + let queries = vec![ + vec![0.0, 0.0, 0.0, 0.0, 0.0], + vec![2.0, 2.0, 2.0, 2.0, 2.0], + vec![1.0, 2.0, 3.0, 4.0, 5.0], + vec![-1.0, -1.0, 0.0, 1.0, 1.0], + ]; + + for query in queries { + for k in [1, 3, 5, 10] { + let tree_results = tree + .search(query.clone(), Similarity::Euclidean, k) + .unwrap(); + let flat_results = flat + .search(query.clone(), Similarity::Euclidean, k) + .unwrap(); + + verify_same_results( + &tree_results, + &flat_results, + &vectors, + &query, + Similarity::Euclidean, + k, + ); + } + } +} diff --git a/crates/index/src/kd_tree/types.rs b/crates/index/src/kd_tree/types.rs new file mode 100644 index 0000000..9999cb3 --- /dev/null +++ b/crates/index/src/kd_tree/types.rs @@ -0,0 +1,37 @@ +use std::cmp::Ordering; + +use defs::{IndexedVector, PointId}; + +// the node which will be the part of the KD Tree +pub struct KDTreeNode { + pub indexed_vector: IndexedVector, + pub left: Option>, + pub right: Option>, + pub is_deleted: bool, + + pub subtree_size: usize, +} + +// The struct definition which is present in max heap while search +#[derive(Debug, Clone, PartialEq)] +pub struct Neighbor { + pub id: PointId, + pub distance: f32, +} + +impl Eq for Neighbor {} + +// Custom Ord implementation for the max-heap +impl Ord for Neighbor { + fn cmp(&self, other: &Self) -> Ordering { + self.distance + .partial_cmp(&other.distance) + .unwrap_or(Ordering::Equal) + } +} + +impl PartialOrd for Neighbor { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} diff --git a/crates/index/src/lib.rs b/crates/index/src/lib.rs index bd802b2..0dfb39c 100644 --- a/crates/index/src/lib.rs +++ b/crates/index/src/lib.rs @@ -1,6 +1,7 @@ use defs::{DbError, DenseVector, IndexedVector, PointId, Similarity}; pub mod flat; +pub mod kd_tree; pub trait VectorIndex: Send + Sync { fn insert(&mut self, vector: IndexedVector) -> Result<(), DbError>; @@ -19,7 +20,7 @@ pub trait VectorIndex: Send + Sync { } /// Distance function to get the distance between two vectors (taken from old version) -pub fn distance(a: DenseVector, b: DenseVector, dist_type: Similarity) -> f32 { +pub fn distance(a: &DenseVector, b: &DenseVector, dist_type: Similarity) -> f32 { assert_eq!(a.len(), b.len()); match dist_type { Similarity::Euclidean => {