diff --git a/Cargo.lock b/Cargo.lock index ba5b23b..69efb5c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -32,6 +32,15 @@ version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "anyhow" version = "1.0.100" @@ -44,6 +53,7 @@ version = "0.1.0" dependencies = [ "defs", "index", + "snapshot", "storage", "tempfile", "uuid", @@ -66,11 +76,17 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + [[package]] name = "axum" -version = "0.8.7" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b098575ebe77cb6d14fc7f32749631a6e44edbef6b796f89b020e99ba20d425" +checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" dependencies = [ "axum-core", "bytes", @@ -101,9 +117,9 @@ dependencies = [ [[package]] name = "axum-core" -version = "0.5.5" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59446ce19cd142f8833f856eb31f3eb097812d1479ab224f54d72428ca21ea22" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" dependencies = [ "bytes", "futures-core", @@ -199,6 +215,15 @@ version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "bumpalo" version = "3.19.1" @@ -238,9 +263,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.49" +version = "1.2.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90583009037521a116abf44494efecd645ba48b6622457080f080b85544e2215" +checksum = "7a0aeaff4ff1a90589618835a598e545176939b97874f7abc7851caa0618f203" dependencies = [ "find-msvc-tools", "jobserver", @@ -263,6 +288,20 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" +[[package]] +name = "chrono" +version = "0.4.42" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "145052bdd345b87320e369255277e3fb5152762ad123a901ef5c262dd38fe8d2" +dependencies = [ + "iana-time-zone", + "js-sys", + "num-traits", + "serde", + "wasm-bindgen", + "windows-link", +] + [[package]] name = "clang-sys" version = "1.8.1" @@ -330,6 +369,24 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = 
[ + "libc", +] + +[[package]] +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +dependencies = [ + "cfg-if", +] + [[package]] name = "crossterm" version = "0.27.0" @@ -355,6 +412,22 @@ dependencies = [ "winapi", ] +[[package]] +name = "crypto-common" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "data-encoding" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476" + [[package]] name = "defs" version = "0.1.0" @@ -363,6 +436,16 @@ dependencies = [ "uuid", ] +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + [[package]] name = "displaydoc" version = "0.2.5" @@ -427,11 +510,23 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +[[package]] +name = "filetime" +version = "0.2.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc0505cd1b6fa6580283f6bdf70a73fcf4aba1184038c90902b92b3dd0df63ed" +dependencies = [ + "cfg-if", + "libc", + "libredox", + "windows-sys 0.60.2", +] + [[package]] name = "find-msvc-tools" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a3076410a55c90011c298b04d0cfa770b00fa04e1e3c97d3f6c9de105a03844" +checksum = "645cbb3a84e60b7531617d5ae4e57f7e27308f6445f5abf653209ea76dec8dff" [[package]] name = "fixedbitset" @@ -439,6 +534,16 @@ version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" +[[package]] +name = "flate2" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfe33edd8e85a12a67454e37f8c75e730830d83e313556ab9ebf9ee7fbeb3bfb" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + [[package]] name = "fnv" version = "1.0.7" @@ -475,6 +580,16 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fs2" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9564fc758e15025b46aa6643b1b77d047d1a56a1aea6e01002ac0c7026876213" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "futures-channel" version = "0.3.31" @@ -525,6 +640,16 @@ dependencies = [ "slab", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.2.16" @@ -777,6 +902,30 @@ dependencies = [ "windows-registry", ] +[[package]] +name = "iana-time-zone" +version = "0.1.64" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33e57f83510bb73707521ebaffa789ec8caf86f9657cad665b092b581d40e9fb" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", 
+ "log", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + [[package]] name = "icu_collections" version = "2.1.1" @@ -889,7 +1038,10 @@ checksum = "964de6e86d545b246d84badc0fef527924ace5134f30641c203ef52ba83f58d5" name = "index" version = "0.1.0" dependencies = [ + "bincode", "defs", + "serde", + "storage", "uuid", ] @@ -948,9 +1100,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.15" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" +checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" [[package]] name = "jobserver" @@ -1000,6 +1152,17 @@ dependencies = [ "windows-link", ] +[[package]] +name = "libredox" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d0b95e02c851351f877147b7deea7b1afb1df71b63aa5f8270716e0c5720616" +dependencies = [ + "bitflags 2.10.0", + "libc", + "redox_syscall 0.7.0", +] + [[package]] name = "librocksdb-sys" version = "0.11.0+8.1.1" @@ -1114,6 +1277,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" dependencies = [ "adler2", + "simd-adler32", ] [[package]] @@ -1181,6 +1345,15 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + [[package]] name = "object" version = "0.37.3" @@ -1264,7 +1437,7 @@ checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" dependencies = [ "cfg-if", "libc", - "redox_syscall", + "redox_syscall 0.5.18", "smallvec", "windows-link", ] @@ -1481,6 +1654,15 @@ dependencies = [ "bitflags 2.10.0", ] +[[package]] +name = "redox_syscall" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49f3fe0889e69e2ae9e41f4d6c4c0181701d00e4697b356fb1f74173a5e0ee27" +dependencies = [ + "bitflags 2.10.0", +] + [[package]] name = "regex" version = "1.12.2" @@ -1512,9 +1694,9 @@ checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" [[package]] name = "reqwest" -version = "0.12.26" +version = "0.12.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b4c14b2d9afca6a60277086b0cc6a6ae0b568f6f7916c943a8cdc79f8be240f" +checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" dependencies = [ "base64", "bytes", @@ -1597,9 +1779,9 @@ checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" [[package]] name = "rustix" -version = "1.1.2" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e" +checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" dependencies = [ "bitflags 2.10.0", "errno", @@ -1649,9 +1831,9 @@ checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" [[package]] name = "ryu" -version = "1.0.20" +version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +checksum = "a50f4cf475b65d88e057964e0e9bb1f0aa9bbb2036dc65c64596b42932536984" [[package]] name = "schannel" @@ -1691,6 +1873,12 @@ dependencies = [ "libc", ] +[[package]] +name = "semver" +version = "1.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" + [[package]] name = "serde" version = "1.0.228" @@ -1723,15 +1911,15 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.145" +version = "1.0.148" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" +checksum = "3084b546a1dd6289475996f182a22aba973866ea8e8b02c51d9f46b1336a22da" dependencies = [ "itoa", "memchr", - "ryu", "serde", "serde_core", + "zmij", ] [[package]] @@ -1781,6 +1969,17 @@ dependencies = [ "uuid", ] +[[package]] +name = "sha2" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -1819,13 +2018,20 @@ dependencies = [ [[package]] name = "signal-hook-registry" -version = "1.4.7" +version = "1.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7664a098b8e616bdfcc2dc0e9ac44eb231eedf41db4e9fe95d8d32ec728dedad" +checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" dependencies = [ + "errno", "libc", ] +[[package]] +name = "simd-adler32" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" + [[package]] name = "slab" version = "0.4.11" @@ -1838,6 +2044,26 @@ version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" +[[package]] +name = "snapshot" +version = "0.1.0" +dependencies = [ + "chrono", + "data-encoding", + "defs", + "flate2", + "fs2", + "index", + "semver", + "serde", + "serde_json", + "sha2", + "storage", + "tar", + "tempfile", + "uuid", +] + [[package]] name = "socket2" version = "0.6.1" @@ -1876,7 +2102,10 @@ version = "0.1.0" dependencies = [ "bincode", "defs", + "flate2", "rocksdb", + "serde", + "tar", "tempfile", "uuid", ] @@ -1961,11 +2190,22 @@ dependencies = [ "libc", ] +[[package]] +name = "tar" +version = "0.4.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d863878d212c87a19c1a610eb53bb01fe12951c0501cf5a0d65f724914a667a" +dependencies = [ + "filetime", + "libc", + "xattr", +] + [[package]] name = "tempfile" -version = "3.23.0" +version = "3.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d31c77bdf42a745371d260a26ca7163f1e0924b64afa0b688e61b5a9fa02f16" +checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c" dependencies = [ "fastrand", "getrandom 0.3.4", @@ -2275,6 +2515,12 @@ dependencies = [ "uuid", ] +[[package]] +name = "typenum" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" + [[package]] name = "unicase" version = "2.8.1" @@ -2358,6 +2604,12 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" 
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + [[package]] name = "want" version = "0.3.1" @@ -2472,6 +2724,41 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-core" +version = "0.62.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-implement" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-interface" +version = "0.59.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "windows-link" version = "0.2.1" @@ -2741,6 +3028,16 @@ version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" +[[package]] +name = "xattr" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32e45ad4206f6d2479085147f02bc2ef834ac85886624a23575ae137c8aa8156" +dependencies = [ + "libc", + "rustix", +] + [[package]] name = "yoke" version = "0.8.1" @@ -2824,6 +3121,12 @@ dependencies = [ "syn", ] +[[package]] +name = "zmij" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d6085d62852e35540689d1f97ad663e3971fc19cf5eceab364d62c646ea167" + [[package]] name = "zstd-sys" version = "2.0.16+zstd.1.5.7" diff --git a/Cargo.toml b/Cargo.toml index 091c725..cd35677 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ members = [ "crates/http", "crates/tui", "crates/grpc", + "crates/snapshot", ] [workspace.package] @@ -49,5 +50,6 @@ grpc = { path = "crates/grpc" } http = { path = "crates/http" } index = { path = "crates/index" } server = { path = "crates/server" } +snapshot = { path = "crates/snapshot" } storage = { path = "crates/storage" } tui = { path = "crates/tui" } diff --git a/crates/api/Cargo.toml b/crates/api/Cargo.toml index 8ade9d8..b6e3e74 100644 --- a/crates/api/Cargo.toml +++ b/crates/api/Cargo.toml @@ -9,6 +9,7 @@ license.workspace = true [dependencies] defs.workspace = true index.workspace = true +snapshot.workspace = true storage.workspace = true tempfile.workspace = true uuid.workspace = true diff --git a/crates/api/src/lib.rs b/crates/api/src/lib.rs index 72739ac..661a8d2 100644 --- a/crates/api/src/lib.rs +++ b/crates/api/src/lib.rs @@ -1,12 +1,15 @@ -use defs::{DbError, IndexedVector, Similarity}; +use defs::{DbError, IndexedVector, Similarity, SnapshottableDb}; use defs::{DenseVector, Payload, Point, PointId}; -use std::path::PathBuf; +use index::kd_tree::index::KDTree; +use std::path::{Path, PathBuf}; +use tempfile::tempdir; // use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{Arc, RwLock}; -use 
index::flat::FlatIndex; +use index::flat::index::FlatIndex; use index::{IndexType, VectorIndex}; +use snapshot::Snapshot; use storage::rocks_db::RocksDbStorage; use storage::{StorageEngine, StorageType, VectorPage}; @@ -131,6 +134,31 @@ impl VectorDb { } } +impl SnapshottableDb for VectorDb { + fn create_snapshot(&self, dir_path: &Path) -> Result<PathBuf, DbError> { + if !dir_path.is_dir() { + return Err(DbError::SnapshotError(format!( + "Invalid path: {}", + dir_path.display() + ))); + } + + let index_snapshot = self + .index + .read() + .map_err(|_| DbError::LockError)? + .snapshot()?; + + let tempdir = tempdir().unwrap(); + let storage_checkpoint = self.storage.checkpoint_at(tempdir.path())?; + + let snapshot = Snapshot::new(index_snapshot, storage_checkpoint, self.dimension)?; + let snapshot_path = snapshot.save(dir_path)?; + + Ok(snapshot_path) + } +} + #[derive(Debug)] pub struct DbConfig { pub storage_type: StorageType, @@ -139,6 +167,28 @@ pub struct DbConfig { pub dimension: usize, } +#[derive(Debug)] +pub struct DbRestoreConfig { + pub data_path: PathBuf, + pub snapshot_path: PathBuf, +} + +impl DbRestoreConfig { + pub fn new(data_path: PathBuf, snapshot_path: PathBuf) -> Self { + Self { + data_path, + snapshot_path, + } + } +} + +pub fn restore_from_snapshot(config: &DbRestoreConfig) -> Result<VectorDb, DbError> { + // restore the index from the snapshot + let (storage_engine, index, dimensions) = + Snapshot::load(&config.snapshot_path, &config.data_path)?; + Ok(VectorDb::_new(storage_engine, index, dimensions)) +} + pub fn init_api(config: DbConfig) -> Result<VectorDb, DbError> { // Initialize the storage engine let storage = match config.storage_type { @@ -149,7 +199,8 @@ pub fn init_api(config: DbConfig) -> Result<VectorDb, DbError> { // Initialize the vector index let index: Arc<RwLock<dyn VectorIndex + Send + Sync>> = match config.index_type { IndexType::Flat => Arc::new(RwLock::new(FlatIndex::new())), - _ => Arc::new(RwLock::new(FlatIndex::new())), + IndexType::KDTree => Arc::new(RwLock::new(KDTree::build_empty(config.dimension))), + _ => Arc::new(RwLock::new(FlatIndex::new())), // TODO: add hnsw here }; // Init the db @@ -166,12 +217,15 @@ mod tests { // TODO: Add more exhaustive tests + use std::sync::Mutex; + use super::*; use defs::ContentType; - use tempfile::tempdir; + use snapshot::{engine::SnapshotEngine, registry::local::LocalRegistry}; + use tempfile::{TempDir, tempdir}; // Helper function to create a test database - fn create_test_db() -> VectorDb { + fn create_test_db() -> (VectorDb, TempDir) { let temp_dir = tempdir().unwrap(); let config = DbConfig { storage_type: StorageType::RocksDb, @@ -179,12 +233,12 @@ mod tests { data_path: temp_dir.path().to_path_buf(), dimension: 3, }; - init_api(config).unwrap() + (init_api(config).unwrap(), temp_dir) } #[test] fn test_insert_and_get() { - let db = create_test_db(); + let (db, _temp_dir) = create_test_db(); let vector = vec![1.0, 2.0, 3.0]; let payload = Payload { content_type: ContentType::Text, @@ -209,7 +263,7 @@ mod tests { #[test] fn test_dimension_mismatch() { - let db = create_test_db(); + let (db, _temp_dir) = create_test_db(); let v1 = vec![1.0, 2.0, 3.0]; let v2 = vec![1.0, 2.0]; let payload = defs::Payload { @@ -228,7 +282,7 @@ mod tests { #[test] fn test_delete() { - let db = create_test_db(); + let (db, _temp_dir) = create_test_db(); let vector = vec![1.0, 2.0, 3.0]; let payload = Payload { content_type: ContentType::Text, @@ -251,7 +305,7 @@ mod tests { #[test] fn test_search() { - let db = create_test_db(); + let (db, _temp_dir) = create_test_db(); // Insert some points let vectors = vec![ @@ -280,7 +334,7 @@ mod 
tests { #[test] fn test_search_limit() { - let db = create_test_db(); + let (db, _temp_dir) = create_test_db(); // Insert 5 points let mut ids = Vec::new(); @@ -307,7 +361,7 @@ mod tests { #[test] fn test_empty_database() { - let db = create_test_db(); + let (db, _temp_dir) = create_test_db(); // Get non-existent point assert!(db.get(Uuid::new_v4()).unwrap().is_none()); @@ -319,7 +373,7 @@ mod tests { #[test] fn test_list_vectors() { - let db = create_test_db(); + let (db, _temp_dir) = create_test_db(); // insert some points let mut ids = Vec::new(); for i in 0..10 { @@ -350,7 +404,7 @@ mod tests { #[test] fn test_build_index() { - let db = create_test_db(); + let (db, _temp_dir) = create_test_db(); // insert some points for i in 0..10 { @@ -370,4 +424,143 @@ mod tests { let inserted = db.build_index().unwrap(); assert_eq!(inserted, 10); } + + #[test] + fn test_create_and_load_snapshot() { + let (old_db, temp_dir) = create_test_db(); + + let v1 = vec![0.0, 1.0, 2.0]; + let v2 = vec![3.0, 4.0, 5.0]; + let v3 = vec![6.0, 7.0, 8.0]; + + let id1 = old_db + .insert( + v1.clone(), + Payload { + content_type: ContentType::Text, + content: "test".to_string(), + }, + ) + .unwrap(); + + let id2 = old_db + .insert( + v2.clone(), + Payload { + content_type: ContentType::Text, + content: "test".to_string(), + }, + ) + .unwrap(); + + let temp_snapshot_dir = tempdir().unwrap(); + let snapshot_path = old_db.create_snapshot(temp_snapshot_dir.path()).unwrap(); + + // insert v3 after snapshot + let id3 = old_db + .insert( + v3.clone(), + Payload { + content_type: ContentType::Text, + content: "test".to_string(), + }, + ) + .unwrap(); + + let reload_config = DbRestoreConfig { + data_path: temp_dir.path().to_path_buf(), + snapshot_path, + }; + + std::mem::drop(old_db); + let loaded_db = restore_from_snapshot(&reload_config).unwrap(); + + assert!(loaded_db.get(id1).unwrap_or(None).is_some()); + assert!(loaded_db.get(id2).unwrap_or(None).is_some()); + assert!(loaded_db.get(id3).unwrap_or(None).is_none()); // v3 was inserted after snapshot was taken + + // vector restore check + assert!(loaded_db.get(id1).unwrap().unwrap().vector.unwrap() == v1); + assert!(loaded_db.get(id2).unwrap().unwrap().vector.unwrap() == v2); + } + + #[test] + fn test_snapshot_engine() { + let (_db, _temp_dir) = create_test_db(); + let db = Arc::new(Mutex::new(_db)); + + let registry_tempdir = tempdir().unwrap(); + + let registry = Arc::new(Mutex::new( + LocalRegistry::new(registry_tempdir.path()).unwrap(), + )); + + let last_k = 4; + let mut se = SnapshotEngine::new(last_k, db.clone(), registry.clone()); + + let v1 = vec![0.0, 1.0, 2.0]; + let v2 = vec![3.0, 4.0, 5.0]; + let v3 = vec![6.0, 7.0, 8.0]; + + let test_vectors = vec![v1.clone(), v2.clone(), v3.clone()]; + let mut inserted_ids = Vec::new(); + + for (i, vector) in test_vectors.clone().into_iter().enumerate() { + se.snapshot().unwrap(); + let id = db + .lock() + .unwrap() + .insert( + vector.clone(), + Payload { + content_type: ContentType::Text, + content: format!("{}", i), + }, + ) + .unwrap(); + inserted_ids.push(id); + } + se.snapshot().unwrap(); + let snapshots = se.list_alive_snapshots().unwrap(); + + // asserting these cases: + // snapshot 0 : no vectors + // snapshot 1 : v1 + // snapshot 2 : v1, v2 + // snapshot 3 : v1, v2, v3 + + std::mem::drop(db); + std::mem::drop(se); + + for (i, snapshot) in snapshots.iter().enumerate() { + let temp_dir = tempdir().unwrap(); + let db = restore_from_snapshot(&DbRestoreConfig { + data_path: temp_dir.path().to_path_buf(), + 
snapshot_path: snapshot.path.clone(), + }) + .unwrap(); + for j in 0..i { + // test if point is present + assert!(db.get(inserted_ids[j]).unwrap_or(None).is_some()); + // test vector restore + assert!( + db.get(inserted_ids[j]).unwrap().unwrap().vector.unwrap() == test_vectors[j] + ); + // test payload restore + assert!( + db.get(inserted_ids[j]) + .unwrap() + .unwrap() + .payload + .unwrap() + .content + == format!("{}", j) + ); + } + for absent_id in inserted_ids.iter().skip(i) { + assert!(db.get(*absent_id).unwrap_or(None).is_none()); + } + std::mem::drop(db); + } + } } diff --git a/crates/defs/src/error.rs b/crates/defs/src/error.rs index 8f929c3..5f35c1e 100644 --- a/crates/defs/src/error.rs +++ b/crates/defs/src/error.rs @@ -7,7 +7,16 @@ pub enum DbError { DeserializationError, IndexError(String), LockError, + IndexInitError, //TODO: Change this + UnsupportedSimilarity, DimensionMismatch, + SnapshotError(String), + StorageInitializationError, + StorageCheckpointError(String), + InvalidMagicBytes(String), + VectorNotFound(uuid::Uuid), + SnapshotRegistryError(String), + StorageEngineError(String), } #[derive(Debug)] diff --git a/crates/defs/src/lib.rs b/crates/defs/src/lib.rs index c2a79bf..074c352 100644 --- a/crates/defs/src/lib.rs +++ b/crates/defs/src/lib.rs @@ -1,6 +1,13 @@ pub mod error; pub mod types; +use std::path::{Path, PathBuf}; + // Without re-exports, users would need to write defs::types::SomeType instead of just defs::SomeType. Re-exports simplify the API by flattening the module hierarchy. The * means "everything public" from that module. pub use error::*; pub use types::*; + +// hoisted trait so it can be used by the snapshots crate +pub trait SnapshottableDb: Send + Sync { + fn create_snapshot(&self, dir_path: &Path) -> Result<PathBuf, DbError>; +} diff --git a/crates/defs/src/types.rs b/crates/defs/src/types.rs index ae69f17..faccf92 100644 --- a/crates/defs/src/types.rs +++ b/crates/defs/src/types.rs @@ -51,6 +51,8 @@ pub enum Similarity { Cosine, } +pub type Magic = [u8; 4]; + // Struct which stores the distance between a vector and query vector and implements ordering traits #[derive(Copy, Clone)] pub struct DistanceOrderedVector<'q> { diff --git a/crates/index/Cargo.toml b/crates/index/Cargo.toml index 35f9957..a8c9ca9 100644 --- a/crates/index/Cargo.toml +++ b/crates/index/Cargo.toml @@ -7,5 +7,8 @@ edition.workspace = true license.workspace = true [dependencies] +bincode.workspace = true defs.workspace = true +serde.workspace = true +storage.workspace = true uuid.workspace = true diff --git a/crates/index/src/flat.rs b/crates/index/src/flat.rs deleted file mode 100644 index a3a62bf..0000000 --- a/crates/index/src/flat.rs +++ /dev/null @@ -1,270 +0,0 @@ -use defs::{DbError, DenseVector, DistanceOrderedVector, IndexedVector, PointId, Similarity}; - -use crate::{VectorIndex, distance}; - -pub struct FlatIndex { - index: Vec<IndexedVector>, -} - -impl FlatIndex { - pub fn new() -> Self { - Self { index: Vec::new() } - } - - pub fn build(vectors: Vec<IndexedVector>) -> Self { - FlatIndex { index: vectors } - } -} - -impl Default for FlatIndex { - fn default() -> Self { - Self::new() - } -} - -impl VectorIndex for FlatIndex { - fn insert(&mut self, vector: IndexedVector) -> Result<(), DbError> { - self.index.push(vector); - Ok(()) - } - - fn delete(&mut self, point_id: PointId) -> Result<bool, DbError> { - if let Some(pos) = self.index.iter().position(|vector| vector.id == point_id) { - self.index.remove(pos); - Ok(true) - } else { - Ok(false) - } - } - - fn search( - &self, - query_vector: DenseVector, - similarity: 
Similarity, - k: usize, - ) -> Result<Vec<PointId>, DbError> { - let scores = self - .index - .iter() - .map(|point| DistanceOrderedVector { - distance: distance(point.vector.clone(), query_vector.clone(), similarity), - query_vector: &query_vector, - point_id: Some(point.id), - }) - .collect::<Vec<_>>(); - - // select k smallest elements in scores using a max heap - let mut heap = std::collections::BinaryHeap::<DistanceOrderedVector>::new(); - for score in scores { - if heap.len() < k { - heap.push(score); - } else if score < *heap.peek().unwrap() { - heap.pop(); - heap.push(score); - } - } - Ok(heap - .into_sorted_vec() - .into_iter() - .map(|v| v.point_id.unwrap()) - .collect()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use uuid::Uuid; - - #[test] - fn test_flat_index_new() { - let index = FlatIndex::new(); - assert_eq!(index.index.len(), 0); - } - - #[test] - fn test_flat_index_build() { - let vectors = vec![ - IndexedVector { - id: Uuid::new_v4(), - vector: vec![1.0, 2.0, 3.0], - }, - IndexedVector { - id: Uuid::new_v4(), - vector: vec![4.0, 5.0, 6.0], - }, - ]; - let index = FlatIndex::build(vectors.clone()); - assert_eq!(index.index, vectors); - } - - #[test] - fn test_insert() { - let mut index = FlatIndex::new(); - let vector = IndexedVector { - id: Uuid::new_v4(), - vector: vec![1.0, 2.0, 3.0], - }; - - assert!(index.insert(vector.clone()).is_ok()); - assert_eq!(index.index.len(), 1); - assert_eq!(index.index[0], vector); - } - - #[test] - fn test_delete_existing() { - let mut index = FlatIndex::new(); - let existing_id = Uuid::new_v4(); - let vector = IndexedVector { - id: existing_id, - vector: vec![1.0, 2.0, 3.0], - }; - index.insert(vector).unwrap(); - - let result = index.delete(existing_id).unwrap(); - assert!(result); - assert_eq!(index.index.len(), 0); - } - - #[test] - fn test_delete_non_existing() { - let mut index = FlatIndex::new(); - let vector = IndexedVector { - id: Uuid::new_v4(), - vector: vec![1.0, 2.0, 3.0], - }; - index.insert(vector).unwrap(); - - let result = index.delete(Uuid::new_v4()).unwrap(); - assert!(!result); - assert_eq!(index.index.len(), 1); - } - - #[test] - fn test_search_euclidean() { - let mut index = FlatIndex::new(); - let id1 = Uuid::new_v4(); - let id2 = Uuid::new_v4(); - let id3 = Uuid::new_v4(); - index - .insert(IndexedVector { - id: id1, - vector: vec![1.0, 1.0], - }) - .unwrap(); - index - .insert(IndexedVector { - id: id2, - vector: vec![2.0, 2.0], - }) - .unwrap(); - index - .insert(IndexedVector { - id: id3, - vector: vec![10.0, 10.0], - }) - .unwrap(); - - let results = index - .search(vec![0.0, 0.0], Similarity::Euclidean, 2) - .unwrap(); - assert_eq!(results, vec![id1, id2]); - } - - #[test] - fn test_search_cosine() { - let mut index = FlatIndex::new(); - let id1 = Uuid::new_v4(); - let id2 = Uuid::new_v4(); - let id3 = Uuid::new_v4(); - index - .insert(IndexedVector { - id: id1, - vector: vec![1.0, 0.0], - }) - .unwrap(); - index - .insert(IndexedVector { - id: id2, - vector: vec![0.5, 0.5], - }) - .unwrap(); - index - .insert(IndexedVector { - id: id3, - vector: vec![0.0, 1.0], - }) - .unwrap(); - - let results = index.search(vec![1.0, 1.0], Similarity::Cosine, 2).unwrap(); - assert_eq!(results, vec![id2, id1]); - } - - #[test] - fn test_search_manhattan() { - let mut index = FlatIndex::new(); - let id1 = Uuid::new_v4(); - let id2 = Uuid::new_v4(); - let id3 = Uuid::new_v4(); - index - .insert(IndexedVector { - id: id1, - vector: vec![1.0, 1.0], - }) - .unwrap(); - index - .insert(IndexedVector { - id: id2, - vector: vec![2.0, 2.0], - }) - .unwrap(); - 
index - .insert(IndexedVector { - id: id3, - vector: vec![5.0, 5.0], - }) - .unwrap(); - - let results = index - .search(vec![0.0, 0.0], Similarity::Manhattan, 2) - .unwrap(); - assert_eq!(results, vec![id1, id2]); - } - - #[test] - fn test_search_hamming() { - let mut index = FlatIndex::new(); - let id1 = Uuid::new_v4(); - let id2 = Uuid::new_v4(); - let id3 = Uuid::new_v4(); - index - .insert(IndexedVector { - id: id1, - vector: vec![1.0, 0.0, 1.0, 1.0], - }) - .unwrap(); - index - .insert(IndexedVector { - id: id2, - vector: vec![1.0, 0.0, 0.0, 0.0], - }) - .unwrap(); - index - .insert(IndexedVector { - id: id3, - vector: vec![0.0, 0.0, 0.0, 0.0], - }) - .unwrap(); - - let results = index - .search(vec![1.0, 0.0, 0.0, 0.0], Similarity::Hamming, 2) - .unwrap(); - assert_eq!(results, vec![id2, id3]); - } - - #[test] - fn test_default() { - let index = FlatIndex::default(); - assert_eq!(index.index.len(), 0); - } -} diff --git a/crates/index/src/flat/index.rs b/crates/index/src/flat/index.rs new file mode 100644 index 0000000..87f814f --- /dev/null +++ b/crates/index/src/flat/index.rs @@ -0,0 +1,71 @@ +use crate::{VectorIndex, distance}; +use defs::{DbError, DenseVector, DistanceOrderedVector, IndexedVector, PointId, Similarity}; + +pub struct FlatIndex { + pub index: Vec<IndexedVector>, +} + +impl FlatIndex { + pub fn new() -> Self { + Self { index: Vec::new() } + } + + pub fn build(vectors: Vec<IndexedVector>) -> Self { + FlatIndex { index: vectors } + } +} + +impl Default for FlatIndex { + fn default() -> Self { + Self::new() + } +} + +impl VectorIndex for FlatIndex { + fn insert(&mut self, vector: IndexedVector) -> Result<(), DbError> { + self.index.push(vector); + Ok(()) + } + + fn delete(&mut self, point_id: PointId) -> Result<bool, DbError> { + if let Some(pos) = self.index.iter().position(|vector| vector.id == point_id) { + self.index.remove(pos); + Ok(true) + } else { + Ok(false) + } + } + + fn search( + &self, + query_vector: DenseVector, + similarity: Similarity, + k: usize, + ) -> Result<Vec<PointId>, DbError> { + let scores = self + .index + .iter() + .map(|point| DistanceOrderedVector { + distance: distance(&point.vector, &query_vector, similarity), + query_vector: &query_vector, + point_id: Some(point.id), + }) + .collect::<Vec<_>>(); + + // select k smallest elements in scores using a max heap + let mut heap = std::collections::BinaryHeap::<DistanceOrderedVector>::new(); + for score in scores { + if heap.len() < k { + heap.push(score); + } else if score < *heap.peek().unwrap() { + heap.pop(); + heap.push(score); + } + } + Ok(heap + .into_sorted_vec() + .into_iter() + .map(|v| v.point_id.unwrap()) + .collect()) + } +} diff --git a/crates/index/src/flat/mod.rs b/crates/index/src/flat/mod.rs new file mode 100644 index 0000000..5e3f726 --- /dev/null +++ b/crates/index/src/flat/mod.rs @@ -0,0 +1,9 @@ +use defs::Magic; + +pub mod index; +mod serialize; + +#[cfg(test)] +mod tests; + +pub const FLAT_MAGIC_BYTES: Magic = [0x00, 0x00, 0x00, 0x01]; diff --git a/crates/index/src/flat/serialize.rs b/crates/index/src/flat/serialize.rs new file mode 100644 index 0000000..32af97e --- /dev/null +++ b/crates/index/src/flat/serialize.rs @@ -0,0 +1,108 @@ +use super::FLAT_MAGIC_BYTES; +use crate::IndexType; +use crate::flat::index::FlatIndex; +use crate::{IndexSnapshot, SerializableIndex}; +use defs::{DbError, IndexedVector}; +use serde::{Deserialize, Serialize}; +use std::io::{Cursor, Read}; +use storage::StorageEngine; +use uuid::Uuid; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FlatIndexMetadata { + total_points: usize, +} + +impl FlatIndex { + pub fn 
deserialize( + IndexSnapshot { + index_type, + magic, + topology_b, + metadata_b, + }: &IndexSnapshot, + ) -> Result<Self, DbError> { + if index_type != &IndexType::Flat { + return Err(DbError::SerializationError( + "Invalid index type".to_string(), + )); + } + + if magic != &FLAT_MAGIC_BYTES { + return Err(DbError::SerializationError( + "Invalid magic bytes".to_string(), + )); + } + + let metadata: FlatIndexMetadata = bincode::deserialize(metadata_b).map_err(|e| { + DbError::SerializationError(format!("Failed to deserialize FlatIndex Metadata: {}", e)) + })?; + let total_points = metadata.total_points; + + let mut cursor = Cursor::new(topology_b); + let mut vectors = Vec::new(); + + for _ in 0..total_points { + let mut uuid_slice = [0u8; 16]; + cursor.read_exact(&mut uuid_slice).map_err(|e| { + DbError::SerializationError(format!( + "Failed to deserialize FlatIndex Topology: {}", + e + )) + })?; + let id = Uuid::from_bytes_le(uuid_slice); + vectors.push(IndexedVector { + id, + vector: Vec::new(), + }); + } + + Ok(FlatIndex { index: vectors }) + } +} + +impl SerializableIndex for FlatIndex { + fn serialize_topology(&self) -> Result<Vec<u8>, DbError> { + let mut buffer: Vec<u8> = Vec::new(); + for point in &self.index { + buffer.extend_from_slice(&point.id.to_bytes_le()); + } + + Ok(buffer) + } + + fn serialize_metadata(&self) -> Result<Vec<u8>, DbError> { + let mut buffer: Vec<u8> = Vec::new(); + let metadata = FlatIndexMetadata { + total_points: self.index.len(), + }; + + let metadata_bytes = bincode::serialize(&metadata).map_err(|e| { + DbError::SerializationError(format!("Failed to serialize FlatIndex Metadata: {}", e)) + })?; + + buffer.extend_from_slice(&metadata_bytes); + Ok(buffer) + } + + fn populate_vectors(&mut self, storage: &dyn StorageEngine) -> Result<(), DbError> { + for item in &mut self.index { + item.vector = storage + .get_vector(item.id)? 
+ .ok_or(DbError::VectorNotFound(item.id))?; + } + Ok(()) + } + + fn snapshot(&self) -> Result<IndexSnapshot, DbError> { + let topology = self.serialize_topology()?; + let metadata = self.serialize_metadata()?; + + Ok(IndexSnapshot { + metadata_b: metadata, + topology_b: topology, + magic: FLAT_MAGIC_BYTES, + index_type: IndexType::Flat, + }) + } +} diff --git a/crates/index/src/flat/tests.rs b/crates/index/src/flat/tests.rs new file mode 100644 index 0000000..6d43c3d --- /dev/null +++ b/crates/index/src/flat/tests.rs @@ -0,0 +1,238 @@ +use super::index::FlatIndex; +use crate::{SerializableIndex, VectorIndex}; +use defs::{IndexedVector, Similarity}; +use uuid::Uuid; + +#[test] +fn test_flat_index_new() { + let index = FlatIndex::new(); + assert_eq!(index.index.len(), 0); +} + +#[test] +fn test_flat_index_build() { + let vectors = vec![ + IndexedVector { + id: Uuid::new_v4(), + vector: vec![1.0, 2.0, 3.0], + }, + IndexedVector { + id: Uuid::new_v4(), + vector: vec![4.0, 5.0, 6.0], + }, + ]; + let index = FlatIndex::build(vectors.clone()); + assert_eq!(index.index, vectors); +} + +#[test] +fn test_insert() { + let mut index = FlatIndex::new(); + let vector = IndexedVector { + id: Uuid::new_v4(), + vector: vec![1.0, 2.0, 3.0], + }; + + assert!(index.insert(vector.clone()).is_ok()); + assert_eq!(index.index.len(), 1); + assert_eq!(index.index[0], vector); +} + +#[test] +fn test_delete_existing() { + let mut index = FlatIndex::new(); + let existing_id = Uuid::new_v4(); + let vector = IndexedVector { + id: existing_id, + vector: vec![1.0, 2.0, 3.0], + }; + index.insert(vector).unwrap(); + + let result = index.delete(existing_id).unwrap(); + assert!(result); + assert_eq!(index.index.len(), 0); +} + +#[test] +fn test_delete_non_existing() { + let mut index = FlatIndex::new(); + let vector = IndexedVector { + id: Uuid::new_v4(), + vector: vec![1.0, 2.0, 3.0], + }; + index.insert(vector).unwrap(); + + let result = index.delete(Uuid::new_v4()).unwrap(); + assert!(!result); + assert_eq!(index.index.len(), 1); +} + +#[test] +fn test_search_euclidean() { + let mut index = FlatIndex::new(); + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let id3 = Uuid::new_v4(); + index + .insert(IndexedVector { + id: id1, + vector: vec![1.0, 1.0], + }) + .unwrap(); + index + .insert(IndexedVector { + id: id2, + vector: vec![2.0, 2.0], + }) + .unwrap(); + index + .insert(IndexedVector { + id: id3, + vector: vec![10.0, 10.0], + }) + .unwrap(); + + let results = index + .search(vec![0.0, 0.0], Similarity::Euclidean, 2) + .unwrap(); + assert_eq!(results, vec![id1, id2]); +} + +#[test] +fn test_search_cosine() { + let mut index = FlatIndex::new(); + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let id3 = Uuid::new_v4(); + index + .insert(IndexedVector { + id: id1, + vector: vec![1.0, 0.0], + }) + .unwrap(); + index + .insert(IndexedVector { + id: id2, + vector: vec![0.5, 0.5], + }) + .unwrap(); + index + .insert(IndexedVector { + id: id3, + vector: vec![0.0, 1.0], + }) + .unwrap(); + + let results = index.search(vec![1.0, 1.0], Similarity::Cosine, 2).unwrap(); + assert_eq!(results, vec![id2, id1]); +} + +#[test] +fn test_search_manhattan() { + let mut index = FlatIndex::new(); + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let id3 = Uuid::new_v4(); + index + .insert(IndexedVector { + id: id1, + vector: vec![1.0, 1.0], + }) + .unwrap(); + index + .insert(IndexedVector { + id: id2, + vector: vec![2.0, 2.0], + }) + .unwrap(); + index + .insert(IndexedVector { + id: id3, + vector: vec![5.0, 5.0], + }) + .unwrap(); + + 
let results = index + .search(vec![0.0, 0.0], Similarity::Manhattan, 2) + .unwrap(); + assert_eq!(results, vec![id1, id2]); +} + +#[test] +fn test_search_hamming() { + let mut index = FlatIndex::new(); + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let id3 = Uuid::new_v4(); + index + .insert(IndexedVector { + id: id1, + vector: vec![1.0, 0.0, 1.0, 1.0], + }) + .unwrap(); + index + .insert(IndexedVector { + id: id2, + vector: vec![1.0, 0.0, 0.0, 0.0], + }) + .unwrap(); + index + .insert(IndexedVector { + id: id3, + vector: vec![0.0, 0.0, 0.0, 0.0], + }) + .unwrap(); + + let results = index + .search(vec![1.0, 0.0, 0.0, 0.0], Similarity::Hamming, 2) + .unwrap(); + assert_eq!(results, vec![id2, id3]); +} + +#[test] +fn test_default() { + let index = FlatIndex::default(); + assert_eq!(index.index.len(), 0); +} + +#[test] +fn test_serialize_and_deserialize_topo() { + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let id3 = Uuid::new_v4(); + let id4 = Uuid::new_v4(); + + let v1 = IndexedVector { + id: id1, + vector: vec![0.0, 0.0, 0.0, 0.0], + }; + let v2 = IndexedVector { + id: id2, + vector: vec![1.0, 0.0, 0.0, 0.0], + }; + let v3 = IndexedVector { + id: id3, + vector: vec![2.0, 0.0, 0.0, 0.0], + }; + let v4 = IndexedVector { + id: id4, + vector: vec![3.0, 0.0, 0.0, 0.0], + }; + + let vectors = vec![v1.clone(), v2.clone(), v3.clone(), v4.clone()]; + let mut index_before = FlatIndex::build(vectors); + index_before.insert(v4.clone()).unwrap(); + + index_before.delete(id1).unwrap(); + + let snapshot = index_before.snapshot().unwrap(); + + let idx = FlatIndex::deserialize(&snapshot).unwrap(); + + assert_eq!(idx.index.len(), 4); + assert!(!idx.index.iter().any(|v| v.id == id1)); + assert!(idx.index.iter().any(|v| v.id == id2)); + assert!(idx.index.iter().any(|v| v.id == id3)); + assert!(idx.index.iter().any(|v| v.id == id4)); +} diff --git a/crates/index/src/kd_tree.rs b/crates/index/src/kd_tree.rs deleted file mode 100644 index 444f578..0000000 --- a/crates/index/src/kd_tree.rs +++ /dev/null @@ -1,252 +0,0 @@ -use std::cmp::Ordering; -use std::cmp::Ordering::Less; - -use serde_derive::{Deserialize, Serialize}; - -#[derive(Serialize, Deserialize)] -pub struct KDTreeInternals { - pub kd_tree_allow_update: bool, - pub current_number_of_kd_tree_nodes: usize, - pub rebuild_threshold: f32, - pub previous_tree_size: usize, - pub rebuild_counter: usize, -} - -#[derive(Serialize, Deserialize)] -pub struct KDTreeNode { - pub left: Option<Box<KDTreeNode>>, - pub right: Option<Box<KDTreeNode>>, - pub key: String, - pub vector: Vec<f32>, - pub dim: usize, -} - -impl KDTreeNode { - // Add the logic here to create a new db and insert the tree into the database - fn new(data: (String, Vec<f32>), dim: usize) -> KDTreeNode { - KDTreeNode { - left: None, - right: None, - key: data.0, - vector: data.1, - dim, - } - } -} - -pub struct KDTree { - pub _root: Option<Box<KDTreeNode>>, - pub _internals: KDTreeInternals, - pub is_debug_run: bool, - pub dim: usize, -} - -impl KDTree { - // Create an empty tree with default values - pub fn new() -> KDTree { - KDTree { - _root: None, - _internals: KDTreeInternals { - kd_tree_allow_update: true, - current_number_of_kd_tree_nodes: 0, - rebuild_threshold: 2.0f32, - previous_tree_size: 0, - rebuild_counter: 0, - }, - is_debug_run: true, - dim: 0, - } - } - - // Add a node - // If the dimension of the tree is zero, then it becomes equal to the input data - pub fn add_node(&mut self, data: (String, Vec<f32>), depth: usize) { - if self._root.is_none() { - self.dim = 
data.1.len(); - self._root = Some(Box::new(KDTreeNode::new(data, 0))); - self._internals.current_number_of_kd_tree_nodes += 1; - return; - } - - assert_eq!(self.dim, data.1.len()); - - if !self._internals.kd_tree_allow_update { - println!("KDTree is locked for rebuild"); - return; - } - - if self._internals.previous_tree_size != 0 { - let current_ratio: f32 = self._internals.current_number_of_kd_tree_nodes as f32 - / self._internals.previous_tree_size as f32; - if current_ratio > self._internals.rebuild_threshold { - self._internals.previous_tree_size = - self._internals.current_number_of_kd_tree_nodes; - self.rebuild(); - } - } else { - self._internals.previous_tree_size = self._internals.current_number_of_kd_tree_nodes; - } - - self._internals.current_number_of_kd_tree_nodes += 1; - - let mut current_node = self._root.as_deref_mut().unwrap(); - let mut current_depth = depth; - loop { - let current_dimension = current_depth % self.dim; - if data.1[current_dimension] < current_node.vector[current_dimension] { - if current_node.left.is_none() { - current_node.left = Some(Box::new(KDTreeNode::new(data, current_dimension))); - break; - } else { - current_node = current_node.left.as_deref_mut().unwrap(); - current_depth += 1; - } - } else { - if current_node.right.is_none() { - current_node.right = Some(Box::new(KDTreeNode::new(data, current_dimension))); - break; - } else { - current_node = current_node.right.as_deref_mut().unwrap(); - current_depth += 1; - } - } - } - } - - // rebuild tree - fn rebuild(&mut self) { - self._internals.kd_tree_allow_update = false; - self._internals.rebuild_counter += 1; - if self.is_debug_run { - println!( - "Rebuilding tree..., Rebuild counter: {:?}", - self._internals.rebuild_counter - ); - } - let mut points = Vec::into_boxed_slice(self.traversal(0)); - self._root = Some(Box::new(create_tree_helper(points.as_mut(), 0))); - self._internals.kd_tree_allow_update = true; - } - - // traversal - pub fn traversal(&self, k_value: usize) -> Vec<(String, Vec<f32>)> { - let mut result: Vec<(String, Vec<f32>)> = Vec::new(); - inorder_traversal_helper(self._root.as_deref(), &mut result, k_value); - result - } - - // delete a node - pub fn delete_node(&mut self, data: String) { - self._internals.kd_tree_allow_update = false; - let mut points = self.traversal(0); - let index = points.iter().position(|x| *x.0 == data).unwrap(); - points.remove(index); - let mut points = Vec::into_boxed_slice(points); - self._root = Some(Box::new(create_tree_helper(points.as_mut(), 0))); - self._internals.kd_tree_allow_update = true; - } - - // print data for debug - pub fn print_tree_for_debug(&self) { - let iterated: Vec<(String, Vec<f32>)> = self.traversal(0); - for iter in iterated { - println!("{}", iter.0); - } - } - - // different methods of knn -} - -// Traversal helper function -fn inorder_traversal_helper( - node: Option<&KDTreeNode>, - result: &mut Vec<(String, Vec<f32>)>, - k_value: usize, -) -> Option<bool> { - if node.is_none() { - return None; - } - if k_value != 0 && k_value <= result.len() { - return None; - } - let current_node = node.unwrap(); - inorder_traversal_helper(current_node.to_owned().left.as_deref(), result, k_value); - result.push((current_node.key.clone(), current_node.vector.clone())); - inorder_traversal_helper(current_node.to_owned().right.as_deref(), result, k_value); - - Some(true) -} - -// Rebuild tree helper functions -fn create_tree_helper(points: &mut [(String, Vec<f32>)], dim: usize) -> KDTreeNode { - let points_len = points.len(); - if points_len == 1 { - return KDTreeNode { - 
key: points[0].0.clone(), - vector: points[0].1.clone(), - left: None, - right: None, - dim, - }; - } - - // Split around the median - let pivot = quickselect_by(points, points_len / 2, &|a, b| { - a.1[dim].partial_cmp(&b.1[dim]).unwrap() - }); - - let left = Some(Box::new(create_tree_helper( - &mut points[0..points_len / 2], - (dim + 1) % pivot.1.len(), - ))); - let right = if points.len() >= 3 { - Some(Box::new(create_tree_helper( - &mut points[points_len / 2 + 1..points_len], - (dim + 1) % pivot.1.len(), - ))) - } else { - None - }; - - KDTreeNode { - key: pivot.0, - vector: pivot.1, - left, - right, - dim, - } -} - -fn quickselect_by<T>(arr: &mut [T], position: usize, cmp: &dyn Fn(&T, &T) -> Ordering) -> T -where - T: Clone, -{ - let mut pivot_index = 0; - // Need to wrap in another closure or we get ownership complaints. - // Tried using an unboxed closure to get around this but couldn't get it to work. - pivot_index = partition_by(arr, pivot_index, &|a: &T, b: &T| cmp(a, b)); - let array_len = arr.len(); - match position.cmp(&pivot_index) { - Ordering::Equal => arr[position].clone(), - Ordering::Less => quickselect_by(&mut arr[0..pivot_index], position, cmp), - Ordering::Greater => quickselect_by( - &mut arr[pivot_index + 1..array_len], - position - pivot_index - 1, - cmp, - ), - } -} - -fn partition_by<T>(arr: &mut [T], pivot_index: usize, cmp: &dyn Fn(&T, &T) -> Ordering) -> usize { - let array_len = arr.len(); - arr.swap(pivot_index, array_len - 1); - let mut store_index = 0; - for i in 0..array_len - 1 { - if cmp(&arr[i], &arr[array_len - 1]) == Less { - arr.swap(i, store_index); - store_index += 1; - } - } - arr.swap(array_len - 1, store_index); - store_index -} diff --git a/crates/index/src/kd_tree/index.rs b/crates/index/src/kd_tree/index.rs new file mode 100644 index 0000000..dc623a2 --- /dev/null +++ b/crates/index/src/kd_tree/index.rs @@ -0,0 +1,466 @@ +use super::types::{KDTreeNode, Neighbor}; +use crate::{VectorIndex, distance}; +use defs::{DbError, DenseVector, IndexedVector, PointId, Similarity}; +use std::{ + cmp::Ordering, + collections::{BinaryHeap, HashSet}, + vec, +}; +use uuid::Uuid; + +pub struct KDTree { + pub dim: usize, + pub root: Option<Box<KDTreeNode>>, + // In memory point ids, to check existence before O(n) deletion logic + pub point_ids: HashSet<PointId>, + // Rebuild tracking + pub total_nodes: usize, + pub deleted_count: usize, +} + +impl KDTree { + // Rebuild threshold + const BALANCE_THRESHOLD: f32 = 0.7; + const DELETE_REBUILD_RATIO: f32 = 0.25; + + // Build an empty index with no points + pub fn build_empty(dim: usize) -> Self { + KDTree { + dim, + root: None, + point_ids: HashSet::new(), + total_nodes: 0, + deleted_count: 0, + } + } + + // Builds the vector index from the provided vectors; there must be at least one vector for dim calculation + pub fn build(mut vectors: Vec<IndexedVector>) -> Result<Self, DbError> { + if vectors.is_empty() { + Err(DbError::IndexInitError) + } else { + let dim = vectors[0].vector.len(); + + let mut point_ids = HashSet::with_capacity(vectors.len()); + for indexed_vector in vectors.iter() { + point_ids.insert(indexed_vector.id); + } + + let root_node = Self::build_recursive(&mut vectors, 0, dim); + Ok(KDTree { + dim, + root: Some(root_node), + point_ids, + total_nodes: vectors.len(), + deleted_count: 0, + }) + } + } + + // Builds the tree recursively from the given vectors and returns the root node + pub fn build_recursive( + vectors: &mut [IndexedVector], + depth: usize, + dim: usize, + ) -> Box<KDTreeNode> { + if vectors.is_empty() { + panic!("Cannot build from an empty slice 
recursively"); + } + + let axis = depth % dim; + let mid_idx = vectors.len() / 2; + + vectors.select_nth_unstable_by(mid_idx, |a, b| { + let a_at_axis = a.vector[axis]; + let b_at_axis = b.vector[axis]; + a_at_axis.partial_cmp(&b_at_axis).unwrap_or(Ordering::Equal) + }); + + // Using swap so that we don't need to clone the whole vector + let mut median_vec = IndexedVector { + id: Uuid::new_v4(), + vector: vec![], + }; // dummy + std::mem::swap(&mut vectors[mid_idx], &mut median_vec); + + let (left_points, right_points_with_median) = vectors.split_at_mut(mid_idx); + let right_points = &mut right_points_with_median[1..]; // Exclude the swapped-out median + + let left = if left_points.is_empty() { + None + } else { + Some(Self::build_recursive(left_points, depth + 1, dim)) + }; + + let right = if right_points.is_empty() { + None + } else { + Some(Self::build_recursive(right_points, depth + 1, dim)) + }; + + let left_size = left.as_ref().map_or(0, |n| n.subtree_size); + let right_size = right.as_ref().map_or(0, |n| n.subtree_size); + + Box::new(KDTreeNode { + indexed_vector: median_vec, + left, + right, + is_deleted: false, + subtree_size: left_size + right_size + 1, + }) + } + + pub fn insert_point(&mut self, new_vector: IndexedVector) { + // Add to point_ids + self.point_ids.insert(new_vector.id); + self.total_nodes += 1; + + // use a traverse function to get the final leaf where this belongs + if self.root.is_none() { + self.root = Some(Box::new(KDTreeNode { + indexed_vector: new_vector, + left: None, + right: None, + is_deleted: false, + subtree_size: 1, + })); + return; + } + + let mut path: Vec<(usize, bool)> = Vec::new(); + let dim = self.dim; + + let mut current_link = &mut self.root; + let mut depth = 0; + + while let Some(node_box) = current_link { + let axis = depth % dim; + let current_node = node_box.as_mut(); + + current_node.subtree_size += 1; + + let va = new_vector.vector[axis]; + let vb = current_node.indexed_vector.vector[axis]; + + let go_left = va <= vb; + path.push((depth, go_left)); + + if go_left { + current_link = &mut current_node.left; + } else { + current_link = &mut current_node.right; + } + depth += 1; + } + + // Assign the new node to current link which is &mut Option> + let new_node = Box::new(KDTreeNode { + indexed_vector: new_vector, + left: None, + right: None, + is_deleted: false, + subtree_size: 1, + }); + + *current_link = Some(new_node); + + self.check_and_rebalance(&path); + } + + // Rebuild helper methods + fn is_unbalanced(node: &KDTreeNode) -> bool { + let left_size = node.left.as_ref().map_or(0, |n| n.subtree_size); + let right_size = node.right.as_ref().map_or(0, |n| n.subtree_size); + let max_child = left_size.max(right_size); + + max_child as f32 > Self::BALANCE_THRESHOLD * node.subtree_size as f32 + } + + fn collect_recursive(node: KDTreeNode, result: &mut Vec) { + if !node.is_deleted { + result.push(node.indexed_vector); + } + if let Some(left) = node.left { + Self::collect_recursive(*left, result); + } + if let Some(right) = node.right { + Self::collect_recursive(*right, result); + } + } + + fn collect_active_vectors(node: KDTreeNode) -> Vec { + let mut result = Vec::with_capacity(node.subtree_size); + Self::collect_recursive(node, &mut result); + result + } + + fn rebuild_at_depth(&mut self, path: &[(usize, bool)], target_depth: usize) { + let dim = self.dim; + + // Navigate to parent of target node + if target_depth == 0 { + // Rebuild root + if let Some(root) = self.root.take() { + let old_size = root.subtree_size; + let mut vectors = 
Self::collect_active_vectors(*root); + let new_size = vectors.len(); + if !vectors.is_empty() { + self.root = Some(Self::build_recursive(&mut vectors, 0, dim)); + } + // Update global counts as deleted nodes were purged + self.total_nodes -= old_size - new_size; + self.deleted_count = 0; + } + } else { + // Navigate to target node + let mut current_link = &mut self.root; + for (_depth, go_left) in path.iter().take(target_depth) { + let node = current_link.as_mut().unwrap(); + current_link = if *go_left { + &mut node.left + } else { + &mut node.right + }; + } + + // Rebuild tree at current link + if let Some(subtree_root) = current_link.take() { + let old_size = subtree_root.subtree_size; + let mut vectors = Self::collect_active_vectors(*subtree_root); + let new_size = vectors.len(); + + if !vectors.is_empty() { + *current_link = Some(Self::build_recursive(&mut vectors, target_depth, dim)); + } + + // Only update ancestors if size changed (deleted nodes were purged) + if old_size != new_size { + let size_diff = old_size - new_size; + self.subtract_size_from_ancestors(path, target_depth, size_diff); + + self.total_nodes -= size_diff; + self.deleted_count = self.deleted_count.saturating_sub(size_diff); + } + } + } + } + + fn subtract_size_from_ancestors( + &mut self, + path: &[(usize, bool)], + up_to_depth: usize, + diff: usize, + ) { + let mut current = &mut self.root; + for (_, go_left) in path.iter().take(up_to_depth) { + if let Some(node) = current { + node.subtree_size -= diff; + current = if *go_left { + &mut node.left + } else { + &mut node.right + }; + } + } + } + + fn check_and_rebalance(&mut self, path: &[(usize, bool)]) { + // Find the shallowest (closest to root) depth where imbalance occurs + // so that rebuilding fixes the largest unbalanced subtree + let mut unbalanced_depth: Option<usize> = None; + + let mut current = self.root.as_ref(); + + // Check root first (depth 0) + if let Some(node) = current + && Self::is_unbalanced(node) + { + unbalanced_depth = Some(0); + } + + // Then traverse the path and check each node + // Once we find the shallowest unbalanced node, break immediately + for (idx, (_depth, go_left)) in path.iter().enumerate() { + if unbalanced_depth.is_some() { + break; + } + + if let Some(node) = current { + current = if *go_left { + node.left.as_ref() + } else { + node.right.as_ref() + }; + + // Check the child node we just moved to (at depth idx + 1) + if let Some(child) = current + && Self::is_unbalanced(child) + { + unbalanced_depth = Some(idx + 1); + break; + } + } + } + + if let Some(target_depth) = unbalanced_depth { + self.rebuild_at_depth(path, target_depth); + } + } + + fn should_rebuild_global(&self) -> bool { + self.total_nodes > 0 + && (self.deleted_count as f32 / self.total_nodes as f32) > Self::DELETE_REBUILD_RATIO + } + + // Returns true if point found and deleted, else false + pub fn delete_point(&mut self, point_id: &PointId) -> bool { + if self.point_ids.contains(point_id) { + let deleted = Self::find_and_mark_deleted(&mut self.root, *point_id); + if deleted { + self.deleted_count += 1; + self.point_ids.remove(point_id); + } + + if Self::should_rebuild_global(self) + && let Some(root) = self.root.take() + { + let mut vectors = Self::collect_active_vectors(*root); + if !vectors.is_empty() { + self.root = Some(Self::build_recursive(&mut vectors, 0, self.dim)); + } + + self.total_nodes = vectors.len(); + self.deleted_count = 0; + } + + return deleted; + } + false + } + + fn find_and_mark_deleted(node_opt: &mut Option<Box<KDTreeNode>>, target_id: PointId) -> bool { + 
if let Some(node) = node_opt { + if node.indexed_vector.id == target_id { + node.is_deleted = true; + return true; + } + + // Search left first then right + Self::find_and_mark_deleted(&mut node.left, target_id) + || Self::find_and_mark_deleted(&mut node.right, target_id) + } else { + false + } + } + + pub fn search_top_k( + &self, + query_vector: DenseVector, + k: usize, + dist_type: Similarity, + ) -> Vec<(PointId, f32)> { + // Searches for the top-k closest vectors according to the specified metric + + if self.root.is_none() || k == 0 { + return Vec::new(); + } + + let mut best_neighbours = BinaryHeap::with_capacity(k); + + self.search_recursive( + &self.root, + &query_vector, + k, + &mut best_neighbours, + 0, + dist_type, + ); + + best_neighbours + .into_sorted_vec() + .iter() + .map(|neighbor| (neighbor.id, neighbor.distance)) + .collect() + } + + fn search_recursive( + &self, + node_opt: &Option<Box<KDTreeNode>>, + query_vector: &DenseVector, + k: usize, + heap: &mut BinaryHeap<Neighbor>, + depth: usize, + dist_type: Similarity, + ) { + // Base case: an empty link means we've walked past a leaf, so there is nothing to do + if let Some(node) = node_opt { + let axis = depth % self.dim; + + let (near_side, far_side) = if query_vector[axis] <= node.indexed_vector.vector[axis] { + (&node.left, &node.right) + } else { + (&node.right, &node.left) + }; + + // Recurse on near side first + self.search_recursive(near_side, query_vector, k, heap, depth + 1, dist_type); + + if !node.is_deleted { + // TODO: Possible overhead; the heap stores the sqrt of the Euclidean distance, which could be avoided by storing squared distances in the Euclidean case + let distance = distance(query_vector, &node.indexed_vector.vector, dist_type); + if heap.len() < k { + heap.push(Neighbor { + id: node.indexed_vector.id, + distance, + }); + } else if distance < heap.peek().unwrap().distance { + heap.pop(); + heap.push(Neighbor { + id: node.indexed_vector.id, + distance, + }); + } + } + + // Pruning on the farther side to check if there are better candidates + // Use <= to handle ties: when axis_diff == current worst distance, there could be + // a point on the far side with the same distance that should be included + let axis_diff = (query_vector[axis] - node.indexed_vector.vector[axis]).abs(); + let should_search_far = match dist_type { + Similarity::Euclidean | Similarity::Manhattan => { + heap.len() < k || axis_diff <= heap.peek().unwrap().distance + } + _ => true, // Cosine/Hamming - no effective pruning, always search + }; + + if should_search_far { + self.search_recursive(far_side, query_vector, k, heap, depth + 1, dist_type); + } + } + } +} + +impl VectorIndex for KDTree { + fn insert(&mut self, vector: IndexedVector) -> Result<(), DbError> { + self.insert_point(vector); + Ok(()) + } + + fn delete(&mut self, point_id: PointId) -> Result<bool, DbError> { + Ok(self.delete_point(&point_id)) + } + + fn search( + &self, + query_vector: DenseVector, + similarity: Similarity, + k: usize, + ) -> Result<Vec<PointId>, DbError> { + if matches!(similarity, Similarity::Cosine | Similarity::Hamming) { + return Err(DbError::UnsupportedSimilarity); + } + + let results = self.search_top_k(query_vector, k, similarity); + Ok(results.into_iter().map(|(id, _)| id).collect()) + } +} diff --git a/crates/index/src/kd_tree/mod.rs b/crates/index/src/kd_tree/mod.rs new file mode 100644 index 0000000..6ff5fb0 --- /dev/null +++ b/crates/index/src/kd_tree/mod.rs @@ -0,0 +1,10 @@ +use defs::Magic; + +pub mod index; +mod serialize; +pub mod types; + +#[cfg(test)] +mod tests; + +pub const KD_TREE_MAGIC_BYTES: Magic = [0x00, 0x00, 0x00, 0x00];
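As an orientation aid, here is a hedged sketch of how the KD-tree API introduced above fits together. It assumes `KDTree::build` returns `Result<KDTree, DbError>` and that the `defs` types behave as shown in this diff; it is illustrative, not part of the change:

```rust
// Illustrative usage of the new KD-tree index (assumptions noted above).
use defs::{DbError, IndexedVector, Similarity};
use index::VectorIndex;
use index::kd_tree::index::KDTree;
use uuid::Uuid;

fn example() -> Result<(), DbError> {
    // Bulk build balances the tree by median-splitting on each axis in turn.
    let points = vec![
        IndexedVector { id: Uuid::new_v4(), vector: vec![1.0, 2.0] },
        IndexedVector { id: Uuid::new_v4(), vector: vec![3.0, 4.0] },
    ];
    let mut tree = KDTree::build(points)?;

    // Single inserts walk to a leaf and may trigger a partial rebuild when a
    // subtree exceeds BALANCE_THRESHOLD of its parent's size.
    tree.insert(IndexedVector { id: Uuid::new_v4(), vector: vec![0.5, 0.5] })?;

    // Only Euclidean/Manhattan are accepted; Cosine/Hamming return
    // DbError::UnsupportedSimilarity because axis pruning is invalid for them.
    let nearest = tree.search(vec![0.0, 0.0], Similarity::Euclidean, 2)?;
    println!("{nearest:?}");
    Ok(())
}
```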
diff --git a/crates/index/src/kd_tree/serialize.rs b/crates/index/src/kd_tree/serialize.rs new file mode 100644 index 0000000..99ccc20 --- /dev/null +++ b/crates/index/src/kd_tree/serialize.rs @@ -0,0 +1,200 @@ +use std::collections::HashSet; +use std::io::{Cursor, Read, Write}; + +use super::KD_TREE_MAGIC_BYTES; +use super::index::KDTree; +use super::types::KDTreeNode; +use crate::{IndexSnapshot, IndexType, SerializableIndex}; +use bincode; +use defs::{DbError, IndexedVector, PointId}; +use serde::{Deserialize, Serialize}; +use storage::StorageEngine; +use uuid::Uuid; + +#[derive(Serialize, Deserialize)] +pub struct KDTreeMetadata { + pub dim: usize, + pub total_nodes: usize, + pub deleted_count: usize, +} + +impl SerializableIndex for KDTree { + fn serialize_topology(&self) -> Result<Vec<u8>, DbError> { + let mut buffer = Vec::new(); + let mut cursor = Cursor::new(&mut buffer); + serialize_topology_recursive(&self.root, &mut cursor)?; + Ok(buffer) + } + + fn serialize_metadata(&self) -> Result<Vec<u8>, DbError> { + let mut buffer = Vec::new(); + let km = KDTreeMetadata { + dim: self.dim, + total_nodes: self.total_nodes, + deleted_count: self.deleted_count, + }; + let metadata_bytes = bincode::serialize(&km).map_err(|e| { + DbError::SerializationError(format!("Failed to serialize KD Tree Metadata: {}", e)) + })?; + buffer.extend_from_slice(metadata_bytes.as_slice()); + Ok(buffer) + } + + fn snapshot(&self) -> Result<IndexSnapshot, DbError> { + let topology_bytes = self.serialize_topology()?; + let metadata_bytes = self.serialize_metadata()?; + Ok(IndexSnapshot { + index_type: crate::IndexType::KDTree, + magic: KD_TREE_MAGIC_BYTES, + topology_b: topology_bytes, + metadata_b: metadata_bytes, + }) + } + + fn populate_vectors(&mut self, storage: &dyn StorageEngine) -> Result<(), DbError> { + populate_vectors_recursive(&mut self.root, storage)?; + Ok(()) + } +} + +const NODE_MARKER_BYTE: u8 = 1u8; +const SKIP_MARKER_BYTE: u8 = 0u8; + +const DELETED_MASK: u8 = 2u8; + +impl KDTree { + pub fn deserialize( + IndexSnapshot { + index_type, + magic, + topology_b, + metadata_b, + }: &IndexSnapshot, + ) -> Result<KDTree, DbError> { + if index_type != &IndexType::KDTree { + return Err(DbError::SerializationError( + "Invalid index type".to_string(), + )); + } + + if magic != &KD_TREE_MAGIC_BYTES { + return Err(DbError::SerializationError( + "Invalid magic bytes".to_string(), + )); + } + + let metadata: KDTreeMetadata = + bincode::deserialize(metadata_b.as_slice()).map_err(|e| { + DbError::SerializationError(format!( + "Failed to deserialize KD Tree Metadata: {}", + e + )) + })?; + + let mut buf = Cursor::new(topology_b); + let mut non_deleted = HashSet::new(); + let root = deserialize_topology_recursive(&mut buf, &mut non_deleted)?; + + Ok(KDTree { + dim: metadata.dim, + root, + point_ids: non_deleted, + total_nodes: metadata.total_nodes, + deleted_count: metadata.deleted_count, + }) + } +} + +// helper functions + +fn serialize_topology_recursive( + current_opt: &Option<Box<KDTreeNode>>, + buffer: &mut Cursor<&mut Vec<u8>>, +) -> Result<(), DbError> { + if let Some(current) = current_opt { + let mut marker = NODE_MARKER_BYTE; + if current.is_deleted { + marker |= DELETED_MASK; + } + buffer + .write_all(&[marker]) + .map_err(|e| DbError::SerializationError(e.to_string()))?; + + let uuid_bytes = current.indexed_vector.id.to_bytes_le(); + buffer + .write_all(&uuid_bytes) + .map_err(|e| DbError::SerializationError(e.to_string()))?; + + // serialize left subtree topology + serialize_topology_recursive(&current.left, buffer)?; + // serialize right subtree topology +
serialize_topology_recursive(&current.right, buffer)?; + } else { + buffer + .write_all(&[SKIP_MARKER_BYTE]) + .map_err(|e| DbError::SerializationError(e.to_string()))?; + } + Ok(()) +} + +fn populate_vectors_recursive( + node: &mut Option<Box<KDTreeNode>>, + storage: &dyn StorageEngine, +) -> Result<(), DbError> { + if let Some(node) = node { + let vector = storage + .get_vector(node.indexed_vector.id)? + .ok_or(DbError::VectorNotFound(node.indexed_vector.id))?; + node.indexed_vector.vector = vector; + + populate_vectors_recursive(&mut node.left, storage)?; + populate_vectors_recursive(&mut node.right, storage)?; + } + Ok(()) +} + +fn deserialize_topology_recursive( + buffer: &mut Cursor<&Vec<u8>>, + non_deleted: &mut HashSet<PointId>, + ) -> Result<Option<Box<KDTreeNode>>, DbError> { + let mut current_marker: [u8; 1] = [0u8; 1]; + buffer.read_exact(&mut current_marker).map_err(|e| { + DbError::SerializationError(format!("Failed to deserialize KD Topology: {}", e)) + })?; + + if current_marker[0] == SKIP_MARKER_BYTE { + return Ok(None); + } + + let mut uuid_bytes = [0u8; 16]; + buffer.read_exact(&mut uuid_bytes).map_err(|e| { + DbError::SerializationError(format!("Failed to deserialize KD Topology: {}", e)) + })?; + let uuid = Uuid::from_bytes_le(uuid_bytes); + let indexed_vector = IndexedVector { + id: uuid, + vector: Vec::new(), + }; + + let is_deleted = current_marker[0] & DELETED_MASK == DELETED_MASK; + if !is_deleted { + non_deleted.insert(uuid); + } + + // pre order deserialization + let left_node = deserialize_topology_recursive(buffer, non_deleted)?; + let right_node = deserialize_topology_recursive(buffer, non_deleted)?; + + let left_size = left_node.as_ref().map_or(0, |n| n.subtree_size); + let right_size = right_node.as_ref().map_or(0, |n| n.subtree_size); + + let current_node = KDTreeNode { + indexed_vector, + left: left_node, + right: right_node, + is_deleted, + subtree_size: left_size + right_size + 1, + }; + + Ok(Some(Box::new(current_node))) +}
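The topology bytes above use a pre-order layout: one marker byte per node (bit 0 = present, bit 1 = tombstoned), then the 16 little-endian UUID bytes, then the left and right subtrees; an absent child is a single zero byte. A standalone, illustrative sketch of that layout (constants mirrored from the file above):

```rust
// Standalone sketch of the topology wire format; not part of the diff.
// Layout per present node: [marker: 1 byte][uuid: 16 bytes LE] left... right...
fn encode_leaf(uuid_le: [u8; 16], deleted: bool) -> Vec<u8> {
    let mut out = Vec::with_capacity(19);
    let mut marker = 1u8; // NODE_MARKER_BYTE: node is present
    if deleted {
        marker |= 2; // DELETED_MASK: node is a tombstone
    }
    out.push(marker);
    out.extend_from_slice(&uuid_le); // Uuid::to_bytes_le() in the real code
    out.push(0); // SKIP_MARKER_BYTE: empty left subtree
    out.push(0); // SKIP_MARKER_BYTE: empty right subtree
    out
}

fn main() {
    let buf = encode_leaf([0xAB; 16], true);
    assert_eq!(buf.len(), 19);
    assert_eq!(buf[0], 0b11); // present + deleted
}
```

Note that vectors themselves are not stored in the topology; `populate_vectors` re-reads them from the storage engine after deserialization, which keeps snapshots small.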
diff --git a/crates/index/src/kd_tree/tests.rs b/crates/index/src/kd_tree/tests.rs new file mode 100644 index 0000000..5b30952 --- /dev/null +++ b/crates/index/src/kd_tree/tests.rs @@ -0,0 +1,734 @@ +use super::index::KDTree; +use crate::SerializableIndex; +use crate::VectorIndex; +use crate::distance; +use crate::flat::index::FlatIndex; +use defs::{DbError, IndexedVector, Similarity}; +use std::collections::HashSet; +use uuid::Uuid; + +fn make_vector(vector: Vec<f32>) -> IndexedVector { + IndexedVector { + id: Uuid::new_v4(), + vector, + } +} + +fn make_vector_with_id(id: Uuid, vector: Vec<f32>) -> IndexedVector { + IndexedVector { id, vector } +} + +// Build Tests + +#[test] +fn test_build_empty() { + let tree = KDTree::build_empty(3); + assert!(tree.root.is_none()); + assert_eq!(tree.dim, 3); + assert_eq!(tree.total_nodes, 0); + assert!(tree.point_ids.is_empty()); +} + +#[test] +fn test_build_with_empty_vectors_returns_error() { + let result = KDTree::build(vec![]); + assert!(result.is_err()); +} + +#[test] +fn test_build_single_vector() { + let id = Uuid::new_v4(); + let vectors = vec![make_vector_with_id(id, vec![1.0, 2.0, 3.0])]; + let tree = KDTree::build(vectors).unwrap(); + + assert!(tree.root.is_some()); + assert_eq!(tree.dim, 3); + assert_eq!(tree.total_nodes, 1); + assert!(tree.point_ids.contains(&id)); +} + +#[test] +fn test_build_multiple_vectors() { + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let id3 = Uuid::new_v4(); + let vectors = vec![ + make_vector_with_id(id1, vec![1.0, 2.0]), + make_vector_with_id(id2, vec![3.0, 4.0]), + make_vector_with_id(id3, vec![5.0, 6.0]), + ]; + let tree = KDTree::build(vectors).unwrap(); + + assert!(tree.root.is_some()); + assert_eq!(tree.dim, 2); + assert_eq!(tree.total_nodes, 3); + assert!(tree.point_ids.contains(&id1)); + assert!(tree.point_ids.contains(&id2)); + assert!(tree.point_ids.contains(&id3)); +} + +// Insert Tests + +#[test] +fn test_insert_into_empty_tree() { + let mut tree = KDTree::build_empty(2); + let id = Uuid::new_v4(); + let vector = make_vector_with_id(id, vec![1.0, 2.0]); + + let result = tree.insert(vector); + assert!(result.is_ok()); + assert_eq!(tree.total_nodes, 1); + assert!(tree.point_ids.contains(&id)); + assert!(tree.root.is_some()); +} + +#[test] +fn test_insert_multiple_vectors() { + let mut tree = KDTree::build_empty(2); + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let id3 = Uuid::new_v4(); + + tree.insert(make_vector_with_id(id1, vec![1.0, 2.0])) + .unwrap(); + tree.insert(make_vector_with_id(id2, vec![3.0, 4.0])) + .unwrap(); + tree.insert(make_vector_with_id(id3, vec![5.0, 6.0])) + .unwrap(); + + assert_eq!(tree.total_nodes, 3); + assert!(tree.point_ids.contains(&id1)); + assert!(tree.point_ids.contains(&id2)); + assert!(tree.point_ids.contains(&id3)); +} + +// Delete Tests + +#[test] +fn test_delete_existing_point() { + let mut ids = Vec::new(); + let mut vectors = Vec::new(); + + // Create enough vectors so deleting one doesn't trigger global rebuild + for i in 0..10 { + let id = Uuid::new_v4(); + ids.push(id); + vectors.push(make_vector_with_id(id, vec![i as f32, i as f32])); + } + + let mut tree = KDTree::build(vectors).unwrap(); + + let result = tree.delete(ids[0]).unwrap(); + assert!(result); + assert!(!tree.point_ids.contains(&ids[0])); + assert_eq!(tree.deleted_count, 1); +} + +#[test] +fn test_delete_non_existing_point() { + let id1 = Uuid::new_v4(); + let vectors = vec![make_vector_with_id(id1, vec![1.0, 2.0])]; + let mut tree = KDTree::build(vectors).unwrap(); + + let non_existing_id = Uuid::new_v4(); + let result = tree.delete(non_existing_id).unwrap(); + assert!(!result); + assert_eq!(tree.deleted_count, 0); +} + +#[test] +fn test_delete_from_empty_tree() { + let mut tree = KDTree::build_empty(2); + let result = tree.delete(Uuid::new_v4()).unwrap(); + assert!(!result); +} + +#[test] +fn test_deleted_point_not_in_search_results() { + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let id3 = Uuid::new_v4(); + let vectors = vec![ + make_vector_with_id(id1, vec![0.0, 0.0]), + make_vector_with_id(id2, vec![1.0, 1.0]), + make_vector_with_id(id3, vec![10.0, 10.0]), + ]; + let mut tree = KDTree::build(vectors).unwrap(); + + // Delete the closest point + tree.delete(id1).unwrap(); + + // Search should not return the deleted point + let results = tree + .search(vec![0.0, 0.0], Similarity::Euclidean, 2) + .unwrap(); + assert!(!results.contains(&id1)); + assert!(results.contains(&id2)); +} + +// Search Tests (VectorIndex trait) + +#[test] +fn test_search_empty_tree() { + let tree = KDTree::build_empty(2); + let results = tree + .search(vec![1.0, 2.0], Similarity::Euclidean, 5) + .unwrap(); + assert!(results.is_empty()); +} + +#[test] +fn test_search_euclidean() { + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let id3 = Uuid::new_v4(); + let vectors = vec![ + make_vector_with_id(id1, vec![1.0, 1.0]), + make_vector_with_id(id2, vec![2.0, 2.0]), + make_vector_with_id(id3, vec![10.0, 10.0]), + ]; + let tree = KDTree::build(vectors).unwrap(); + + let results = tree + .search(vec![0.0, 0.0], Similarity::Euclidean, 2) + .unwrap(); + assert_eq!(results.len(), 2);
+ assert_eq!(results[0], id1); // Closest + assert_eq!(results[1], id2); // Second closest +} + +#[test] +fn test_search_manhattan() { + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let id3 = Uuid::new_v4(); + let vectors = vec![ + make_vector_with_id(id1, vec![1.0, 1.0]), + make_vector_with_id(id2, vec![2.0, 2.0]), + make_vector_with_id(id3, vec![5.0, 5.0]), + ]; + let tree = KDTree::build(vectors).unwrap(); + + let results = tree + .search(vec![0.0, 0.0], Similarity::Manhattan, 2) + .unwrap(); + assert_eq!(results.len(), 2); + assert_eq!(results[0], id1); + assert_eq!(results[1], id2); +} + +#[test] +fn test_search_unsupported_similarity_cosine() { + let vectors = vec![make_vector(vec![1.0, 2.0])]; + let tree = KDTree::build(vectors).unwrap(); + + let result = tree.search(vec![1.0, 2.0], Similarity::Cosine, 1); + assert!(matches!(result, Err(DbError::UnsupportedSimilarity))); +} + +#[test] +fn test_search_unsupported_similarity_hamming() { + let vectors = vec![make_vector(vec![1.0, 2.0])]; + let tree = KDTree::build(vectors).unwrap(); + + let result = tree.search(vec![1.0, 2.0], Similarity::Hamming, 1); + assert!(matches!(result, Err(DbError::UnsupportedSimilarity))); +} + +#[test] +fn test_search_k_larger_than_tree_size() { + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let vectors = vec![ + make_vector_with_id(id1, vec![1.0, 1.0]), + make_vector_with_id(id2, vec![2.0, 2.0]), + ]; + let tree = KDTree::build(vectors).unwrap(); + + let results = tree + .search(vec![0.0, 0.0], Similarity::Euclidean, 10) + .unwrap(); + assert_eq!(results.len(), 2); // Should return all available points +} + +#[test] +fn test_search_exact_match() { + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let vectors = vec![ + make_vector_with_id(id1, vec![5.0, 5.0]), + make_vector_with_id(id2, vec![10.0, 10.0]), + ]; + let tree = KDTree::build(vectors).unwrap(); + + let results = tree + .search(vec![5.0, 5.0], Similarity::Euclidean, 1) + .unwrap(); + assert_eq!(results.len(), 1); + assert_eq!(results[0], id1); +} + +// Search Correctness Tests + +#[test] +fn test_search_correctness_3d() { + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let id3 = Uuid::new_v4(); + let id4 = Uuid::new_v4(); + let vectors = vec![ + make_vector_with_id(id1, vec![0.0, 0.0, 0.0]), + make_vector_with_id(id2, vec![1.0, 1.0, 1.0]), + make_vector_with_id(id3, vec![2.0, 2.0, 2.0]), + make_vector_with_id(id4, vec![10.0, 10.0, 10.0]), + ]; + let tree = KDTree::build(vectors).unwrap(); + + let results = tree + .search(vec![0.5, 0.5, 0.5], Similarity::Euclidean, 2) + .unwrap(); + // id1 at distance sqrt(0.75) ≈ 0.866 + // id2 at distance sqrt(0.75) ≈ 0.866 + // Both are equidistant, should return both + assert_eq!(results.len(), 2); + assert!(results.contains(&id1) || results.contains(&id2)); +} + +#[test] +fn test_search_after_insert() { + let mut tree = KDTree::build_empty(2); + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let id3 = Uuid::new_v4(); + + tree.insert(make_vector_with_id(id1, vec![10.0, 10.0])) + .unwrap(); + tree.insert(make_vector_with_id(id2, vec![1.0, 1.0])) + .unwrap(); + tree.insert(make_vector_with_id(id3, vec![5.0, 5.0])) + .unwrap(); + + let results = tree + .search(vec![0.0, 0.0], Similarity::Euclidean, 2) + .unwrap(); + assert_eq!(results[0], id2); // Closest to origin + assert_eq!(results[1], id3); // Second closest +} + +#[test] +fn test_search_high_dimensional() { + let dim = 10; + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + + let vectors = vec![ + 
make_vector_with_id(id1, vec![0.0; dim]), + make_vector_with_id(id2, vec![1.0; dim]), + ]; + let tree = KDTree::build(vectors).unwrap(); + + let query = vec![0.1; dim]; + let results = tree.search(query, Similarity::Euclidean, 1).unwrap(); + assert_eq!(results[0], id1); // Closer to all-zeros +} + +// Rebalancing Tests + +#[test] +fn test_many_inserts_maintains_searchability() { + let mut tree = KDTree::build_empty(2); + let mut ids = Vec::new(); + + // Insert many points that would cause imbalance + for i in 0..20 { + let id = Uuid::new_v4(); + ids.push(id); + tree.insert(make_vector_with_id(id, vec![i as f32, i as f32])) + .unwrap(); + } + + // Search should still work correctly + let results = tree + .search(vec![0.0, 0.0], Similarity::Euclidean, 5) + .unwrap(); + assert_eq!(results.len(), 5); + // First result should be the point at (0, 0) + assert_eq!(results[0], ids[0]); +} + +#[test] +fn test_delete_triggers_rebuild() { + let mut ids = Vec::new(); + let mut vectors = Vec::new(); + + for i in 0..10 { + let id = Uuid::new_v4(); + ids.push(id); + vectors.push(make_vector_with_id(id, vec![i as f32, i as f32])); + } + + let mut tree = KDTree::build(vectors).unwrap(); + + // Delete enough points to trigger rebuild (> 25%) + for id in ids.iter().take(3) { + tree.delete(*id).unwrap(); + } + + // Tree should still function correctly + let results = tree + .search(vec![5.0, 5.0], Similarity::Euclidean, 3) + .unwrap(); + assert_eq!(results.len(), 3); + // Deleted points should not appear + for id in ids.iter().take(3) { + assert!(!results.contains(id)); + } +} + +// Edge Cases + +#[test] +fn test_single_point_search() { + let id = Uuid::new_v4(); + let vectors = vec![make_vector_with_id(id, vec![5.0, 5.0])]; + let tree = KDTree::build(vectors).unwrap(); + + let results = tree + .search(vec![0.0, 0.0], Similarity::Euclidean, 1) + .unwrap(); + assert_eq!(results.len(), 1); + assert_eq!(results[0], id); +} + +#[test] +fn test_duplicate_coordinates() { + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let id3 = Uuid::new_v4(); + let vectors = vec![ + make_vector_with_id(id1, vec![1.0, 1.0]), + make_vector_with_id(id2, vec![1.0, 1.0]), // Same coordinates + make_vector_with_id(id3, vec![2.0, 2.0]), + ]; + let tree = KDTree::build(vectors).unwrap(); + + let results = tree + .search(vec![1.0, 1.0], Similarity::Euclidean, 2) + .unwrap(); + assert_eq!(results.len(), 2); + // Both id1 and id2 should be in results (both at distance 0) + assert!(results.contains(&id1) || results.contains(&id2)); +} + +#[test] +fn test_negative_coordinates() { + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let vectors = vec![ + make_vector_with_id(id1, vec![-1.0, -1.0]), + make_vector_with_id(id2, vec![1.0, 1.0]), + ]; + let tree = KDTree::build(vectors).unwrap(); + + let results = tree + .search(vec![-0.5, -0.5], Similarity::Euclidean, 1) + .unwrap(); + assert_eq!(results[0], id1); +} + +#[test] +fn test_search_with_zero_k() { + let vectors = vec![make_vector(vec![1.0, 2.0])]; + let tree = KDTree::build(vectors).unwrap(); + + let results = tree + .search(vec![1.0, 2.0], Similarity::Euclidean, 0) + .unwrap(); + assert!(results.is_empty()); +} + +// Comparison Tests: KDTree vs FlatIndex + +/// Helper to create a fixed set of 10 vectors with known UUIDs for comparison tests +fn create_test_vectors_2d() -> Vec<IndexedVector> { + let ids: Vec<Uuid> = (0..10).map(|_| Uuid::new_v4()).collect(); + vec![ + make_vector_with_id(ids[0], vec![0.5, 0.5]), + make_vector_with_id(ids[1], vec![2.3, 1.7]), + make_vector_with_id(ids[2],
vec![-1.0, 3.0]), + make_vector_with_id(ids[3], vec![4.5, -2.0]), + make_vector_with_id(ids[4], vec![7.0, 7.0]), + make_vector_with_id(ids[5], vec![-3.5, -1.5]), + make_vector_with_id(ids[6], vec![1.0, 5.0]), + make_vector_with_id(ids[7], vec![6.0, 2.0]), + make_vector_with_id(ids[8], vec![-2.0, -4.0]), + make_vector_with_id(ids[9], vec![3.0, 3.0]), + ] +} + +fn create_test_vectors_3d() -> Vec<IndexedVector> { + let ids: Vec<Uuid> = (0..10).map(|_| Uuid::new_v4()).collect(); + vec![ + make_vector_with_id(ids[0], vec![1.0, 2.0, 3.0]), + make_vector_with_id(ids[1], vec![-1.5, 0.5, 2.0]), + make_vector_with_id(ids[2], vec![4.0, 4.0, 4.0]), + make_vector_with_id(ids[3], vec![0.0, 0.0, 0.0]), + make_vector_with_id(ids[4], vec![2.5, -1.0, 3.5]), + make_vector_with_id(ids[5], vec![-2.0, 3.0, -1.0]), + make_vector_with_id(ids[6], vec![5.0, 1.0, 2.0]), + make_vector_with_id(ids[7], vec![3.0, 3.0, 3.0]), + make_vector_with_id(ids[8], vec![-0.5, -0.5, 1.0]), + make_vector_with_id(ids[9], vec![1.5, 2.5, 0.5]), + ] +} + +/// Helper to verify that two result sets are valid k-nearest neighbor results +/// Both should return the k closest points (by distance), but may differ on tie-breaking +fn verify_same_results( + tree_results: &[Uuid], + flat_results: &[Uuid], + vectors: &[IndexedVector], + query: &[f32], + similarity: Similarity, + k: usize, +) { + // Same length + assert_eq!( + tree_results.len(), + flat_results.len(), + "Result lengths differ" + ); + + // Both should return at most k results + assert!(tree_results.len() <= k); + + // Get distances for all results + let query_vec = query.to_vec(); + let get_distance = |id: &Uuid| -> f32 { + let vec = vectors.iter().find(|v| v.id == *id).unwrap(); + distance(&vec.vector, &query_vec, similarity) + }; + + // Verify tree results are sorted by distance + for i in 1..tree_results.len() { + let d1 = get_distance(&tree_results[i - 1]); + let d2 = get_distance(&tree_results[i]); + assert!( + d1 <= d2 + 1e-6, + "KDTree results not sorted: {} > {}", + d1, + d2 + ); + } + + // Verify flat results are sorted by distance + for i in 1..flat_results.len() { + let d1 = get_distance(&flat_results[i - 1]); + let d2 = get_distance(&flat_results[i]); + assert!(d1 <= d2 + 1e-6, "Flat results not sorted: {} > {}", d1, d2); + } + + // The maximum distance in both result sets should be the same (k-th nearest distance) + if !tree_results.is_empty() { + let tree_max_dist = get_distance(tree_results.last().unwrap()); + let flat_max_dist = get_distance(flat_results.last().unwrap()); + assert!( + (tree_max_dist - flat_max_dist).abs() < 1e-6, + "Max distances differ: tree={}, flat={}", + tree_max_dist, + flat_max_dist + ); + } + + // Verify that for each result in tree_results, either: + // 1. It's also in flat_results, OR + // 2.
It has the same distance as the last element (tie-breaking difference) + let flat_set: HashSet<_> = flat_results.iter().collect(); + let flat_max_dist = if flat_results.is_empty() { + 0.0 + } else { + get_distance(flat_results.last().unwrap()) + }; + + for id in tree_results { + if !flat_set.contains(id) { + // This ID is not in flat results, verify it's a tie + let dist = get_distance(id); + assert!( + (dist - flat_max_dist).abs() < 1e-6, + "KDTree returned {:?} with distance {} but it's not in flat results and not a tie (flat max: {})", + id, + dist, + flat_max_dist + ); + } + } + + // Similarly verify flat_results + let tree_set: HashSet<_> = tree_results.iter().collect(); + let tree_max_dist = if tree_results.is_empty() { + 0.0 + } else { + get_distance(tree_results.last().unwrap()) + }; + + for id in flat_results { + if !tree_set.contains(id) { + let dist = get_distance(id); + assert!( + (dist - tree_max_dist).abs() < 1e-6, + "Flat returned {:?} with distance {} but it's not in tree results and not a tie (tree max: {})", + id, + dist, + tree_max_dist + ); + } + } +} + +#[test] +fn test_kdtree_vs_flat_euclidean_2d() { + let vectors = create_test_vectors_2d(); + let tree = KDTree::build(vectors.clone()).unwrap(); + let flat = FlatIndex::build(vectors.clone()); + + // Test multiple query points and different k values + let queries = vec![ + vec![0.0, 0.0], + vec![3.0, 3.0], + vec![-1.0, 2.0], + vec![5.0, 5.0], + ]; + + for query in queries { + for k in [1, 3, 5, 10] { + let tree_results = tree + .search(query.clone(), Similarity::Euclidean, k) + .unwrap(); + let flat_results = flat + .search(query.clone(), Similarity::Euclidean, k) + .unwrap(); + + verify_same_results( + &tree_results, + &flat_results, + &vectors, + &query, + Similarity::Euclidean, + k, + ); + } + } +} + +#[test] +fn test_kdtree_vs_flat_euclidean_3d() { + let vectors = create_test_vectors_3d(); + let tree = KDTree::build(vectors.clone()).unwrap(); + let flat = FlatIndex::build(vectors.clone()); + + let queries = vec![ + vec![0.0, 0.0, 0.0], + vec![2.0, 2.0, 2.0], + vec![-1.0, 1.0, 1.0], + vec![4.0, 3.0, 3.0], + ]; + + for query in queries { + for k in [1, 3, 5, 10] { + let tree_results = tree + .search(query.clone(), Similarity::Euclidean, k) + .unwrap(); + let flat_results = flat + .search(query.clone(), Similarity::Euclidean, k) + .unwrap(); + + verify_same_results( + &tree_results, + &flat_results, + &vectors, + &query, + Similarity::Euclidean, + k, + ); + } + } +} + +#[test] +fn test_kdtree_vs_flat_euclidean_5d() { + // Test with higher dimensionality + let ids: Vec<Uuid> = (0..10).map(|_| Uuid::new_v4()).collect(); + let vectors = vec![ + make_vector_with_id(ids[0], vec![1.0, 2.0, 3.0, 4.0, 5.0]), + make_vector_with_id(ids[1], vec![-1.0, 0.0, 1.0, 2.0, 3.0]), + make_vector_with_id(ids[2], vec![5.0, 4.0, 3.0, 2.0, 1.0]), + make_vector_with_id(ids[3], vec![0.0, 0.0, 0.0, 0.0, 0.0]), + make_vector_with_id(ids[4], vec![2.5, 2.5, 2.5, 2.5, 2.5]), + make_vector_with_id(ids[5], vec![-2.0, -1.0, 0.0, 1.0, 2.0]), + make_vector_with_id(ids[6], vec![3.0, 3.0, 3.0, 3.0, 3.0]), + make_vector_with_id(ids[7], vec![1.0, 1.0, 1.0, 1.0, 1.0]), + make_vector_with_id(ids[8], vec![4.0, 0.0, -1.0, 2.0, 5.0]), + make_vector_with_id(ids[9], vec![-0.5, 1.5, 2.5, 3.5, 4.5]), + ]; + + let tree = KDTree::build(vectors.clone()).unwrap(); + let flat = FlatIndex::build(vectors.clone()); + + let queries = vec![ + vec![0.0, 0.0, 0.0, 0.0, 0.0], + vec![2.0, 2.0, 2.0, 2.0, 2.0], + vec![1.0, 2.0, 3.0, 4.0, 5.0], + vec![-1.0, -1.0, 0.0, 1.0, 1.0], + ]; +
for query in queries { + for k in [1, 3, 5, 10] { + let tree_results = tree + .search(query.clone(), Similarity::Euclidean, k) + .unwrap(); + let flat_results = flat + .search(query.clone(), Similarity::Euclidean, k) + .unwrap(); + + verify_same_results( + &tree_results, + &flat_results, + &vectors, + &query, + Similarity::Euclidean, + k, + ); + } + } +} + +#[test] +fn test_serialize_and_deserialize_topo() { + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let id3 = Uuid::new_v4(); + let id4 = Uuid::new_v4(); + + let vectors = vec![ + make_vector_with_id(id1, vec![1.0, 2.0, 3.0]), + make_vector_with_id(id2, vec![4.0, 5.0, 6.0]), + make_vector_with_id(id3, vec![7.0, 8.0, 9.0]), + ]; + let mut tree_before = KDTree::build(vectors).unwrap(); + tree_before + .insert(make_vector_with_id(id4, vec![10.0, 11.0, 12.0])) + .unwrap(); + tree_before.delete(id1).unwrap(); + + let snapshot = tree_before.snapshot().unwrap(); + let tree = KDTree::deserialize(&snapshot).unwrap(); + + assert!(tree.root.is_some()); + assert_eq!(tree.dim, 3); + assert_eq!(tree.total_nodes, 4); + assert!(!tree.point_ids.contains(&id1)); + assert!(tree.point_ids.contains(&id2)); + assert!(tree.point_ids.contains(&id3)); + assert!(tree.point_ids.contains(&id4)); +} diff --git a/crates/index/src/kd_tree/types.rs b/crates/index/src/kd_tree/types.rs new file mode 100644 index 0000000..9999cb3 --- /dev/null +++ b/crates/index/src/kd_tree/types.rs @@ -0,0 +1,37 @@ +use std::cmp::Ordering; + +use defs::{IndexedVector, PointId}; + +// A node of the KD-tree +pub struct KDTreeNode { + pub indexed_vector: IndexedVector, + pub left: Option<Box<KDTreeNode>>, + pub right: Option<Box<KDTreeNode>>, + pub is_deleted: bool, + + pub subtree_size: usize, +} + +// Entry stored in the max-heap during search +#[derive(Debug, Clone, PartialEq)] +pub struct Neighbor { + pub id: PointId, + pub distance: f32, +} + +impl Eq for Neighbor {} + +// Custom Ord implementation for the max-heap +impl Ord for Neighbor { + fn cmp(&self, other: &Self) -> Ordering { + self.distance + .partial_cmp(&other.distance) + .unwrap_or(Ordering::Equal) + } +} + +impl PartialOrd for Neighbor { + fn partial_cmp(&self, other: &Self) -> Option<Ordering> { + Some(self.cmp(other)) + } +}
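Because `Neighbor` orders ascending by distance and `std::collections::BinaryHeap` is a max-heap, `peek()` always exposes the *worst* of the current top-k candidates, which is what the search evicts against. A standalone, illustrative sketch of that behavior (using a minimal stand-in for the `Neighbor` type):

```rust
use std::cmp::Ordering;
use std::collections::BinaryHeap;

// Minimal stand-in for the Neighbor type above; illustrative only.
#[derive(PartialEq)]
struct Neighbor {
    distance: f32,
}
impl Eq for Neighbor {}
impl Ord for Neighbor {
    fn cmp(&self, other: &Self) -> Ordering {
        self.distance
            .partial_cmp(&other.distance)
            .unwrap_or(Ordering::Equal)
    }
}
impl PartialOrd for Neighbor {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

fn main() {
    let mut heap = BinaryHeap::new();
    for d in [3.0, 1.0, 2.0] {
        heap.push(Neighbor { distance: d });
    }
    // The heap root is the current worst candidate, ready to be evicted.
    assert_eq!(heap.peek().unwrap().distance, 3.0);
    // into_sorted_vec() then yields candidates nearest-first.
    let sorted = heap.into_sorted_vec();
    assert_eq!(sorted[0].distance, 1.0);
}
```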
diff --git a/crates/index/src/lib.rs b/crates/index/src/lib.rs index bd802b2..e494591 100644 --- a/crates/index/src/lib.rs +++ b/crates/index/src/lib.rs @@ -1,8 +1,11 @@ -use defs::{DbError, DenseVector, IndexedVector, PointId, Similarity}; +use defs::{DbError, DenseVector, IndexedVector, Magic, PointId, Similarity}; +use serde::{Deserialize, Serialize}; +use storage::StorageEngine; pub mod flat; +pub mod kd_tree; -pub trait VectorIndex: Send + Sync { +pub trait VectorIndex: Send + Sync + SerializableIndex { fn insert(&mut self, vector: IndexedVector) -> Result<(), DbError>; // Returns true if point id existed and is deleted, else returns false @@ -19,7 +22,7 @@ pub trait VectorIndex: Send + Sync { } /// Distance function to get the distance between two vectors (taken from old version) -pub fn distance(a: DenseVector, b: DenseVector, dist_type: Similarity) -> f32 { +pub fn distance(a: &DenseVector, b: &DenseVector, dist_type: Similarity) -> f32 { assert_eq!(a.len(), b.len()); match dist_type { Similarity::Euclidean => { @@ -58,9 +61,25 @@ pub fn distance(a: DenseVector, b: DenseVector, dist_type: Similarity) -> f32 { } } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] pub enum IndexType { Flat, KDTree, HNSW, } + +pub struct IndexSnapshot { + pub index_type: IndexType, + pub magic: Magic, + pub topology_b: Vec<u8>, + pub metadata_b: Vec<u8>, +} + +pub trait SerializableIndex { + fn serialize_topology(&self) -> Result<Vec<u8>, DbError>; + fn serialize_metadata(&self) -> Result<Vec<u8>, DbError>; + + fn snapshot(&self) -> Result<IndexSnapshot, DbError>; + + fn populate_vectors(&mut self, storage: &dyn StorageEngine) -> Result<(), DbError>; +} diff --git a/crates/snapshot/Cargo.toml b/crates/snapshot/Cargo.toml new file mode 100644 index 0000000..10328f1 --- /dev/null +++ b/crates/snapshot/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "snapshot" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +chrono.workspace = true +data-encoding = "2.9.0" +defs.workspace = true +flate2 = "1.1.5" +fs2 = "0.4.3" +index.workspace = true +semver = "1.0.27" +serde.workspace = true +serde_json.workspace = true +sha2 = "0.10.9" +storage.workspace = true +tar = "0.4.44" +tempfile.workspace = true +uuid.workspace = true diff --git a/crates/snapshot/README.md b/crates/snapshot/README.md new file mode 100644 index 0000000..e69de29 diff --git a/crates/snapshot/src/constants.rs b/crates/snapshot/src/constants.rs new file mode 100644 index 0000000..3dd46d3 --- /dev/null +++ b/crates/snapshot/src/constants.rs @@ -0,0 +1,5 @@ +use semver::Version; + +pub const SNAPSHOT_PARSER_VER: Version = Version::new(0, 1, 0); +pub const SMALL_ID_LEN: usize = 8; +pub const MANIFEST_FILE: &str = "manifest.json"; diff --git a/crates/snapshot/src/engine/mod.rs b/crates/snapshot/src/engine/mod.rs new file mode 100644 index 0000000..fe138f8 --- /dev/null +++ b/crates/snapshot/src/engine/mod.rs @@ -0,0 +1,172 @@ +use std::{ + collections::VecDeque, + sync::{Arc, Condvar, Mutex}, + time::Duration, +}; + +use defs::{DbError, SnapshottableDb}; + +use crate::{metadata::Metadata, registry::SnapshotRegistry}; + +pub struct SnapshotEngine { + last_k: usize, // only retain the last k snapshots on disk. old/stale snapshots are marked as dead on the registry + snapshot_queue: Arc<Mutex<VecDeque<Metadata>>>, + db: Arc<Mutex<dyn SnapshottableDb>>, + registry: Arc<Mutex<dyn SnapshotRegistry>>, + worker_cv: Arc<Condvar>, + worker_running: Arc<Mutex<bool>>, +} +impl SnapshotEngine { + pub fn new( + last_k: usize, + db: Arc<Mutex<dyn SnapshottableDb>>, + registry: Arc<Mutex<dyn SnapshotRegistry>>, + ) -> Self { + Self { + last_k, + snapshot_queue: Arc::new(Mutex::new(VecDeque::new())), + db, + registry, + worker_cv: Arc::new(Condvar::new()), + worker_running: Arc::new(Mutex::new(false)), + } + } + + pub fn stop_worker(&mut self) -> Result<(), DbError> { + // acquire lock for worker_running + let mut worker_running = self.worker_running.lock().map_err(|_| DbError::LockError)?; + if !*worker_running { + return Err(DbError::StorageEngineError( + "Worker thread not running".to_string(), + )); + } + *worker_running = false; + self.worker_cv.notify_one(); + Ok(()) + } + + // notify the worker thread to take a snapshot now + pub fn worker_snapshot(&mut self) -> Result<(), DbError> { + // acquire lock for worker_running + let worker_running = self.worker_running.lock().map_err(|_| DbError::LockError)?; + if !*worker_running { + return Err(DbError::StorageEngineError( + "Worker thread not running".to_string(), + )); + } + self.worker_cv.notify_one(); + Ok(()) + } + + // take a snapshot on the caller's thread + pub fn snapshot(&mut self) -> Result<(), DbError> { + Self::take_snapshot( + &mut self.db, + &mut self.registry, + &mut self.snapshot_queue, + self.last_k, + ) + } + + pub fn list_alive_snapshots(&mut self) -> Result<Vec<Metadata>, DbError> { + Ok(self + .snapshot_queue + .lock() + .map_err(|_| DbError::LockError)?
+ .iter() + .cloned() + .collect()) + } + + pub fn start_worker(&mut self, interval: i64) -> Result<(), DbError> { + // acquire lock for worker_running + let mut worker_running = self.worker_running.lock().map_err(|_| DbError::LockError)?; + if *worker_running { + return Err(DbError::StorageEngineError( + "Worker thread already running".to_string(), + )); + } + *worker_running = true; + + let worker_running_clone = Arc::clone(&self.worker_running); + let db_clone = Arc::clone(&self.db); + let registry_clone = Arc::clone(&self.registry); + let worker_cv_clone = Arc::clone(&self.worker_cv); + let snapshot_queue_clone = Arc::clone(&self.snapshot_queue); + let last_k_clone = self.last_k; + + let dur_interval = Duration::from_secs(interval as u64); + let _ = std::thread::spawn(move || { + Self::worker( + dur_interval, + last_k_clone, + worker_running_clone, + db_clone, + registry_clone, + worker_cv_clone, + snapshot_queue_clone, + ); + }); + Ok(()) + } + + // helper function to take snapshot + fn take_snapshot( + db: &mut Arc<Mutex<dyn SnapshottableDb>>, + registry: &mut Arc<Mutex<dyn SnapshotRegistry>>, + snapshot_queue: &mut Arc<Mutex<VecDeque<Metadata>>>, + last_k: usize, + ) -> Result<(), DbError> { + let snapshot_path = db + .lock() + .unwrap() + .create_snapshot(registry.lock().unwrap().dir().as_path()) + .unwrap(); + let snapshot_metadata = Metadata::parse(&snapshot_path).unwrap(); + + // add the snapshot to registry + registry + .lock() + .unwrap() + .add_snapshot(&snapshot_path) + .unwrap(); + + { + let mut queue = snapshot_queue.lock().unwrap(); + queue.push_back(snapshot_metadata); + + while queue.len() > last_k { + let old = queue.pop_front().unwrap(); + registry.lock().unwrap().mark_dead(old.small_id).unwrap(); + } + // drop queue lock + } + Ok(()) + } + + // TODO: fix sync issues if any (I don't think there are any) + fn worker( + interval: Duration, + last_k: usize, + worker_running: Arc<Mutex<bool>>, + mut db: Arc<Mutex<dyn SnapshottableDb>>, + mut registry: Arc<Mutex<dyn SnapshotRegistry>>, + worker_cv: Arc<Condvar>, + mut snapshot_queue: Arc<Mutex<VecDeque<Metadata>>>, + ) { + loop { + // acquire the lock and exit if its false + let worker_running = worker_running + .lock() + .map_err(|_| DbError::LockError) + .unwrap(); + if !*worker_running { + break; + } + + Self::take_snapshot(&mut db, &mut registry, &mut snapshot_queue, last_k).unwrap(); + + let _ = worker_cv.wait_timeout(worker_running, interval).unwrap(); + } + } +}
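A hedged sketch of how the engine is meant to be driven. It assumes the trait-object signatures reconstructed above (`Arc<Mutex<dyn SnapshottableDb>>`, `Arc<Mutex<dyn SnapshotRegistry>>`) and a caller-provided database handle; the paths are hypothetical:

```rust
// Illustrative wiring only; `db` is some caller-owned SnapshottableDb handle.
use std::path::Path;
use std::sync::{Arc, Mutex};

use defs::{DbError, SnapshottableDb};
use snapshot::engine::SnapshotEngine;
use snapshot::registry::{SnapshotRegistry, local::LocalRegistry};

fn run(db: Arc<Mutex<dyn SnapshottableDb>>) -> Result<(), DbError> {
    // LocalRegistry coerces to the dyn SnapshotRegistry the engine expects.
    let registry: Arc<Mutex<dyn SnapshotRegistry>> =
        Arc::new(Mutex::new(LocalRegistry::new(Path::new("./snapshots"))?));

    // Keep only the 5 most recent snapshots; older ones are marked dead.
    let mut engine = SnapshotEngine::new(5, db, registry);

    engine.start_worker(3600)?; // background snapshot every hour
    engine.worker_snapshot()?;  // ...or nudge the worker to snapshot now
    engine.snapshot()?;         // ...or snapshot synchronously on this thread
    engine.stop_worker()?;
    Ok(())
}
```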
diff --git a/crates/snapshot/src/lib.rs b/crates/snapshot/src/lib.rs new file mode 100644 index 0000000..32311d2 --- /dev/null +++ b/crates/snapshot/src/lib.rs @@ -0,0 +1,294 @@ +pub mod constants; +pub mod engine; +pub mod manifest; +pub mod metadata; +pub mod registry; +mod util; + +use crate::{ + constants::{MANIFEST_FILE, SNAPSHOT_PARSER_VER}, + manifest::Manifest, + util::{compress_archive, save_index_metadata, save_topology}, +}; + +use chrono::{DateTime, Local}; +use defs::DbError; +use flate2::read::GzDecoder; +use index::{ + IndexSnapshot, IndexType, VectorIndex, flat::index::FlatIndex, kd_tree::index::KDTree, +}; +use semver::Version; +use std::{ + fs::File, + path::{Path, PathBuf}, + sync::{Arc, RwLock}, + time::SystemTime, +}; +use storage::{ + StorageEngine, StorageType, checkpoint::StorageCheckpoint, rocks_db::RocksDbStorage, +}; +use tar::Archive; +use tempfile::tempdir; +use uuid::Uuid; + +type VectorDbRestore = (Arc<dyn StorageEngine>, Arc<RwLock<dyn VectorIndex>>, usize); + +pub struct Snapshot { + pub id: Uuid, + pub date: SystemTime, + pub sem_ver: Version, + pub index_snapshot: IndexSnapshot, + pub storage_snapshot: StorageCheckpoint, + pub dimensions: usize, +} + +impl Snapshot { + pub fn new( + index_snapshot: IndexSnapshot, + storage_snapshot: StorageCheckpoint, + dimensions: usize, + ) -> Result<Snapshot, DbError> { + let id = Uuid::new_v4(); + let date = SystemTime::now(); + + Ok(Snapshot { + id, + date, + sem_ver: SNAPSHOT_PARSER_VER, + index_snapshot, + storage_snapshot, + dimensions, + }) + } + + pub fn save(&self, dir_path: &Path) -> Result<PathBuf, DbError> { + if !dir_path.is_dir() { + return Err(DbError::SnapshotError(format!( + "Invalid path: {}", + dir_path.display() + ))); + } + + let temp_dir = tempdir().map_err(|e| DbError::SnapshotError(e.to_string()))?; + + // save index snapshots + let index_metadata_path = save_index_metadata( + temp_dir.path(), + self.id, + &self.index_snapshot.metadata_b, + &self.index_snapshot.magic, + )?; + + let topology_path = save_topology( + temp_dir.path(), + self.id, + &self.index_snapshot.topology_b, + &self.index_snapshot.magic, + )?; + + // take checksums + let index_metadata_checksum = util::sha256_digest(&index_metadata_path) + .map_err(|e| DbError::SnapshotError(e.to_string()))?; + let index_topo_checksum = util::sha256_digest(&topology_path) + .map_err(|e| DbError::SnapshotError(e.to_string()))?; + let storage_checkpoint_checksum = util::sha256_digest(&self.storage_snapshot.path) + .map_err(|e| DbError::SnapshotError(e.to_string()))?; + + let dt_now_local: DateTime<Local> = self.date.into(); + + // need this for manifest + let storage_checkpoint_filename = self + .storage_snapshot + .path + .file_name() + .ok_or(DbError::SnapshotError( + "Storage checkpoint was not properly made".to_string(), + ))? + .to_str() + .unwrap() + .to_string(); + + // create manifest file + let manifest = Manifest { + id: self.id, + date: dt_now_local.timestamp(), + sem_ver: constants::SNAPSHOT_PARSER_VER.to_string(), + index_metadata_checksum, + index_topo_checksum, + storage_checkpoint_checksum, + storage_type: self.storage_snapshot.storage_type, + index_type: self.index_snapshot.index_type, + dimensions: self.dimensions, + storage_checkpoint_filename, + }; + + let manifest_path = manifest + .save(temp_dir.path()) + .map_err(|e| DbError::SnapshotError(e.to_string()))?; + + let tar_filename = format!( + "{}.tar.gz", + metadata::Metadata::new( + self.id, + self.date, + index_metadata_path.clone(), + constants::SNAPSHOT_PARSER_VER + ) + ); + let tar_gz_path = dir_path.join(tar_filename); + + compress_archive( + &tar_gz_path, + &[ + &index_metadata_path, + &topology_path, + &self.storage_snapshot.path, + &manifest_path, + ], + ) + .map_err(|e| DbError::SnapshotError(e.to_string()))?; + Ok(tar_gz_path.to_path_buf()) + } + + pub fn load(path: &Path, storage_data_path: &Path) -> Result<VectorDbRestore, DbError> { + let tar_gz = File::open(path) + .map_err(|e| DbError::SnapshotError(format!("Couldn't open snapshot: {}", e)))?; + + let tar = GzDecoder::new(tar_gz); + let mut archive = Archive::new(tar); + + let snapshot_filename = path.file_name().ok_or(DbError::SnapshotError( + "Invalid snapshot filename".to_string(), + ))?; + let temp_dir = std::env::temp_dir().join(snapshot_filename); + + // remove any existing data + if temp_dir.exists() && !temp_dir.is_dir() { + std::fs::remove_file(temp_dir.clone()).map_err(|e| { + DbError::SnapshotError(format!("Couldn't remove existing file: {}", e)) + })?; + } else if temp_dir.is_dir() { + std::fs::remove_dir_all(temp_dir.clone()).map_err(|e| { + DbError::SnapshotError(format!("Couldn't remove existing directory: {}", e)) + })?; + } + + std::fs::create_dir(temp_dir.clone()).map_err(|e| { + DbError::SnapshotError(format!("Couldn't create temporary directory: {}", e)) + })?; + + archive + .unpack(temp_dir.clone()) + .map_err(|e|
DbError::SnapshotError(format!("Couldn't unpack archive: {}", e)))?; + + // read manifest and validate + let manifest_path = temp_dir.join(MANIFEST_FILE); + if !manifest_path.is_file() { + return Err(DbError::SnapshotError( + "Manifest file not found".to_string(), + )); + } + + let manifest = Manifest::load(&manifest_path) + .map_err(|e| DbError::SnapshotError(format!("Couldn't load manifest: {}", e)))?; + + if manifest.sem_ver != SNAPSHOT_PARSER_VER.to_string() { + return Err(DbError::SnapshotError( + "Incompatible snapshot version".to_string(), + )); + } + + // only rocksdb is supported for snapshots as of now + let mut storage_engine: Box = match manifest.storage_type { + StorageType::RocksDb => Box::new(RocksDbStorage::new(storage_data_path)?), + _ => { + return Err(DbError::SnapshotError( + "Unsupported storage type".to_string(), + )); + } + }; + + let id = manifest.id; + let index_metadata_path = temp_dir.join(util::metadata_filename(&id)); + let topology_path = temp_dir.join(util::topology_filename(&id)); + let storage_checkpoint_path = temp_dir.join(manifest.storage_checkpoint_filename); + + if !index_metadata_path.exists() + || !topology_path.exists() + || !storage_checkpoint_path.exists() + { + return Err(DbError::SnapshotError(format!( + "Missing snapshot files {} , {}, {}", + index_metadata_path.display(), + topology_path.display(), + storage_checkpoint_path.display() + ))); + } + + // match checksums + if util::sha256_digest(&index_metadata_path).map_err(|_| { + DbError::SnapshotError("Could not calculate index metadata hash".to_string()) + })? != manifest.index_metadata_checksum + { + return Err(DbError::SnapshotError( + "Index metadata hash mismatch".to_string(), + )); + } + if util::sha256_digest(&topology_path) + .map_err(|_| DbError::SnapshotError("Could not calculate topology hash".to_string()))? + != manifest.index_topo_checksum + { + return Err(DbError::SnapshotError("Topology hash mismatch".to_string())); + } + if util::sha256_digest(&storage_checkpoint_path).map_err(|_| { + DbError::SnapshotError("Could not calculate storage checkpoint hash".to_string()) + })? 
!= manifest.storage_checkpoint_checksum + { + return Err(DbError::SnapshotError( + "Storage checkpoint hash mismatch".to_string(), + )); + } + + let (mgmeta, meta_bytes) = util::read_index_metadata(&index_metadata_path) + .map_err(|_| DbError::SnapshotError("Could not read metadata".to_string()))?; + let (mgtopo, topo_bytes) = util::read_index_topology(&topology_path) + .map_err(|_| DbError::SnapshotError("Could not read topology".to_string()))?; + + if mgtopo != mgmeta { + return Err(DbError::InvalidMagicBytes( + "Magic bytes don't match".to_string(), + )); + } + + // validates if manifest storage type matches that in the filename of storage checkpoint + let storage_checkpoint = StorageCheckpoint::open(storage_checkpoint_path.as_path())?; + if storage_checkpoint.storage_type != manifest.storage_type { + return Err(DbError::SnapshotError( + "Storage type mismatch from manifest and checkpoint".to_string(), + )); + } + + storage_engine.restore_checkpoint(&storage_checkpoint)?; + + let index_snapshot = IndexSnapshot { + index_type: manifest.index_type, + magic: mgmeta, + metadata_b: meta_bytes, + topology_b: topo_bytes, + }; + + // dynamic dispatch based on index type + let vector_index: Arc<RwLock<dyn VectorIndex>> = match manifest.index_type { + IndexType::Flat => Arc::new(RwLock::new(FlatIndex::deserialize(&index_snapshot)?)), + IndexType::KDTree => Arc::new(RwLock::new(KDTree::deserialize(&index_snapshot)?)), + _ => return Err(DbError::SnapshotError("Unsupported index type".to_string())), + }; + + vector_index + .write() + .map_err(|_| DbError::LockError)? + .populate_vectors(&*storage_engine)?; + + Ok((storage_engine.into(), vector_index, manifest.dimensions)) + } +} diff --git a/crates/snapshot/src/manifest.rs b/crates/snapshot/src/manifest.rs new file mode 100644 index 0000000..0993435 --- /dev/null +++ b/crates/snapshot/src/manifest.rs @@ -0,0 +1,47 @@ +use index::IndexType; +use serde::{Deserialize, Serialize}; +use std::path::Path; +use std::{ + io::{BufReader, BufWriter, Error, Write}, + path::PathBuf, +}; +use storage::StorageType; +use uuid::Uuid; + +use crate::constants::MANIFEST_FILE; + +type UnixTimestamp = i64; + +#[derive(Serialize, Deserialize)] +pub struct Manifest { + pub id: Uuid, + pub date: UnixTimestamp, + pub sem_ver: String, + pub index_metadata_checksum: String, + pub index_topo_checksum: String, + pub storage_checkpoint_checksum: String, + pub index_type: IndexType, + pub storage_type: StorageType, + pub dimensions: usize, + pub storage_checkpoint_filename: String, +} + +impl Manifest { + pub fn save(&self, path: &Path) -> Result<PathBuf, Error> { + let manifest_path = path.join(MANIFEST_FILE); + + let file = std::fs::File::create(manifest_path.clone())?; + let mut writer = BufWriter::new(file); + serde_json::to_writer(&mut writer, self)?; + writer.flush()?; + + Ok(manifest_path) + } + + pub fn load(path: &Path) -> Result<Manifest, Error> { + let file = std::fs::File::open(path)?; + let mut reader = BufReader::new(file); + let manifest: Manifest = serde_json::from_reader(&mut reader)?; + Ok(manifest) + } +}
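A hedged round-trip sketch for the manifest, assuming the `Result<PathBuf, Error>`/`Result<Manifest, Error>` signatures reconstructed above (with `Error` being `std::io::Error`; `serde_json::Error` converts into it). All field values below are placeholders, not real checksums:

```rust
// Illustrative only; placeholder values throughout.
use index::IndexType;
use snapshot::manifest::Manifest;
use storage::StorageType;
use uuid::Uuid;

fn example() -> std::io::Result<()> {
    let manifest = Manifest {
        id: Uuid::new_v4(),
        date: 1_700_000_000,                                  // placeholder Unix timestamp
        sem_ver: "0.1.0".to_string(),
        index_metadata_checksum: "aa...".to_string(),         // placeholder
        index_topo_checksum: "bb...".to_string(),             // placeholder
        storage_checkpoint_checksum: "cc...".to_string(),     // placeholder
        index_type: IndexType::KDTree,
        storage_type: StorageType::RocksDb,
        dimensions: 128,
        storage_checkpoint_filename: "checkpoint.tar".to_string(), // placeholder
    };

    let dir = std::env::temp_dir();
    let path = manifest.save(&dir)?; // writes <dir>/manifest.json
    let loaded = Manifest::load(&path)?;
    assert_eq!(loaded.id, manifest.id);
    Ok(())
}
```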
diff --git a/crates/snapshot/src/metadata.rs b/crates/snapshot/src/metadata.rs new file mode 100644 index 0000000..cd73185 --- /dev/null +++ b/crates/snapshot/src/metadata.rs @@ -0,0 +1,114 @@ +use crate::constants::SMALL_ID_LEN; +use chrono::DateTime; +use chrono::Local; +use defs::DbError; +use semver::Version; +use std::{fmt::Display, path::PathBuf, time::SystemTime}; +use std::{fs, path::Path}; +use uuid::Uuid; + +pub type SmallID = String; + +// Metadata is the data that can be parsed from the snapshot filename +#[derive(Debug, Clone)] +pub struct Metadata { + pub small_id: SmallID, + pub date: SystemTime, + pub path: PathBuf, + pub sem_ver: Version, +} + +const FILENAME_METADATA_SEPARATOR: &str = "-x"; + +impl Metadata { + pub fn new(id: Uuid, date: SystemTime, path: PathBuf, sem_ver: Version) -> Self { + Metadata { + small_id: id.to_string()[..SMALL_ID_LEN].to_string(), + date, + path, + sem_ver, + } + } + + pub fn parse(path: &Path) -> Result<Metadata, DbError> { + if !path.is_file() { + return Err(DbError::SnapshotError("File not found".to_string())); + } + let filename = path + .file_name() + .ok_or(DbError::SnapshotError("No filename".to_string()))? + .to_str() + .ok_or(DbError::SnapshotError( + "Invalid UTF-8 in filename".to_string(), + ))? + .strip_suffix(".tar.gz") + .ok_or(DbError::SnapshotError( + "Snapshot filename doesn't end with .tar.gz".to_string(), + ))?; + + let parts = filename + .split(FILENAME_METADATA_SEPARATOR) + .collect::<Vec<&str>>(); + + if parts.len() != 3 { + return Err(DbError::SnapshotError("Invalid filename".to_string())); + } + + let id = parts[1]; + if id.len() != SMALL_ID_LEN { + return Err(DbError::SnapshotError("Invalid UUID".to_string())); + } + + let date = chrono::DateTime::parse_from_rfc3339(parts[0]) + .map_err(|_| DbError::SnapshotError("Invalid date".to_string()))?; + let version = Version::parse(parts[2]) + .map_err(|_| DbError::SnapshotError("Invalid version".to_string()))?; + + Ok(Metadata { + small_id: id.to_string(), + date: date.into(), + path: path.to_path_buf(), + sem_ver: version, + }) + } + + pub fn snapshot_dir_metadata(path: &Path) -> Result<Vec<Metadata>, DbError> { + if !path.is_dir() { + return Err(DbError::SnapshotError( + "Path is not a directory".to_string(), + )); + } + + let mut metadata_vec = Vec::new(); + + for item in fs::read_dir(path).map_err(|_| { + DbError::SnapshotError(format!("Cannot read directory: {}", path.display())) + })?
{ + let entry = item.map_err(|_| { + DbError::SnapshotError(format!("Invalid entry: {}", path.display())) + })?; + let path = entry.path(); + if path.is_file() + && let Ok(metadata) = Self::parse(&path) + { + metadata_vec.push(metadata); + } + } + Ok(metadata_vec) + } +} + +impl Display for Metadata { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let dt_now_local: DateTime<Local> = self.date.into(); + write!( + f, + "{}{}{}{}{}", + dt_now_local.to_rfc3339_opts(chrono::SecondsFormat::Secs, true), + FILENAME_METADATA_SEPARATOR, + self.small_id, + FILENAME_METADATA_SEPARATOR, + self.sem_ver + ) + } +}
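The `Display` impl and `parse` above imply a filename convention of `<rfc3339-date>-x<8-char-small-id>-x<semver>.tar.gz`. A standalone sketch of that round trip (the sample name is made up for illustration):

```rust
// Illustrative only: mimics Metadata::parse's splitting logic on a sample name.
fn main() {
    let name = "2024-01-02T03:04:05Z-x1a2b3c4d-x0.1.0.tar.gz";

    // parse() first strips the archive suffix...
    let stem = name.strip_suffix(".tar.gz").unwrap();

    // ...then splits on the "-x" separator and expects exactly three parts:
    // RFC 3339 date, 8-hex-char small id (UUID prefix), and semver.
    let parts: Vec<&str> = stem.split("-x").collect();
    assert_eq!(parts, ["2024-01-02T03:04:05Z", "1a2b3c4d", "0.1.0"]);
}
```

The `-x` separator is chosen so it cannot collide with the hex small id or the date; plain `-` would be ambiguous since RFC 3339 dates contain hyphens.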
diff --git a/crates/snapshot/src/registry/constants.rs b/crates/snapshot/src/registry/constants.rs new file mode 100644 index 0000000..9138454 --- /dev/null +++ b/crates/snapshot/src/registry/constants.rs @@ -0,0 +1 @@ +pub const LOCAL_REGISTRY_LOCKFILE: &str = "LOCKFILE"; diff --git a/crates/snapshot/src/registry/local.rs b/crates/snapshot/src/registry/local.rs new file mode 100644 index 0000000..8f79a73 --- /dev/null +++ b/crates/snapshot/src/registry/local.rs @@ -0,0 +1,271 @@ +use std::{ + collections::HashMap, + fs, + path::{Path, PathBuf}, +}; + +use crate::registry::{INFINITY_LIMIT, NO_OFFSET, SnapshotRegistry}; +use crate::registry::{SnapshotMetaPage, constants::LOCAL_REGISTRY_LOCKFILE}; +use crate::{ + Snapshot, VectorDbRestore, + metadata::{Metadata, SmallID}, +}; +use defs::DbError; +use fs2::FileExt; + +pub struct LocalRegistry { + pub dir: PathBuf, + filename_cache: HashMap<SmallID, String>, +} + +impl LocalRegistry { + pub fn new(dir: &Path) -> Result<LocalRegistry, DbError> { + fs::create_dir_all(dir).map_err(|e| DbError::SnapshotRegistryError(e.to_string()))?; + let lock_file_path = dir.join(LOCAL_REGISTRY_LOCKFILE); + let lock_file = if !lock_file_path.exists() { + fs::File::create(&lock_file_path).map_err(|e| { + DbError::SnapshotRegistryError(format!("Couldn't create LOCKFILE : {}", e)) + })? + } else { + fs::OpenOptions::new() + .read(true) + .write(true) + .open(&lock_file_path) + .map_err(|e| { + DbError::SnapshotRegistryError(format!("Couldn't open LOCKFILE : {}", e)) + })? + }; + + // try to acquire lockfile + lock_file + .try_lock_exclusive() + .map_err(|_| DbError::SnapshotRegistryError("Couldn't acquire LOCKFILE".to_string()))?; + + Ok(LocalRegistry { + dir: dir.to_path_buf(), + filename_cache: HashMap::new(), + }) + } +} + +impl SnapshotRegistry for LocalRegistry { + fn add_snapshot(&mut self, snapshot_path: &Path) -> Result<Metadata, DbError> { + // move the snapshot file to the directory and cache its metadata + + let filename = snapshot_path + .file_name() + .ok_or(DbError::SnapshotRegistryError( + "Invalid snapshot path".to_string(), + ))?; + let final_snapshot_path = self.dir.join(filename); + + // if the snapshot is already in the managed directory then do nothing + if snapshot_path != final_snapshot_path.as_path() { + fs::rename(snapshot_path, final_snapshot_path.clone()).map_err(|e| { + DbError::SnapshotRegistryError(format!("Failed to move snapshot: {}", e)) + })?; + } + + let metadata = Metadata::parse(final_snapshot_path.as_path())?; + self.filename_cache.insert( + metadata.small_id.clone(), + filename.to_string_lossy().to_string(), + ); + Ok(metadata) + } + + fn list_snapshots(&mut self, limit: usize, offset: usize) -> Result<SnapshotMetaPage, DbError> { + let mut res = Vec::new(); + let filtered_files = fs::read_dir(self.dir.as_path()) + .map_err(|e| { + DbError::SnapshotRegistryError(format!("Cannot read local registry dir: {}", e)) + })? + .skip(offset) + .take(limit); + + for file in filtered_files { + let file = match file { + Ok(file) => file, + Err(_) => continue, + }; + let file_path = file.path(); + + if let Ok(metadata) = Metadata::parse(file_path.as_path()) { + let filename = file_path + .file_name() + .ok_or(DbError::SnapshotRegistryError( + "Could not load filename of snapshot".to_string(), + ))? + .to_string_lossy(); + self.filename_cache + .insert(metadata.small_id.clone(), filename.to_string()); + + res.push(metadata); + } + } + Ok(res) + } + + fn get_latest_snapshot(&mut self) -> Result<Metadata, DbError> { + let mut latest_record: Option<Metadata> = None; + for file in fs::read_dir(self.dir.as_path()).map_err(|e| { + DbError::SnapshotRegistryError(format!("Cannot read local registry dir: {}", e)) + })? { + let file = match file { + Ok(file) => file, + Err(_) => continue, + }; + let file_path = file.path(); + + if let Ok(metadata) = Metadata::parse(file_path.as_path()) { + let filename = file_path + .file_name() + .ok_or(DbError::SnapshotRegistryError( + "Could not load filename of snapshot".to_string(), + ))? + .to_string_lossy(); + self.filename_cache + .insert(metadata.small_id.clone(), filename.to_string()); + + latest_record = match latest_record { + None => Some(metadata), + Some(existing) => { + if metadata.date > existing.date { + Some(metadata) + } else { + Some(existing) + } + } + }; + } + } + match latest_record { + Some(metadata) => Ok(metadata), + None => Err(DbError::SnapshotRegistryError( + "No snapshots found".to_string(), + )), + } + } + + fn list_alive_snapshots(&mut self) -> Result<SnapshotMetaPage, DbError> { + self.list_snapshots(INFINITY_LIMIT, NO_OFFSET) + } + + fn remove_snapshot(&mut self, small_id: SmallID) -> Result<Metadata, DbError> { + if let Some(filename) = self.filename_cache.get(&small_id) { + let snapshot_filepath = self.dir.join(filename); + + let metadata = Metadata::parse(snapshot_filepath.as_path())?; + fs::remove_file(snapshot_filepath.as_path()).map_err(|e| { + DbError::SnapshotRegistryError(format!("Failed to remove snapshot: {}", e)) + })?; + self.filename_cache.remove_entry(&small_id); + Ok(metadata) + } else { + for file in fs::read_dir(self.dir.as_path()).map_err(|e| { + DbError::SnapshotRegistryError(format!("Cannot read local registry dir: {}", e)) + })? { + let file = match file { + Ok(file) => file, + Err(_) => continue, + }; + let file_path = file.path(); + if let Ok(metadata) = Metadata::parse(file_path.as_path()) + && metadata.small_id == small_id + { + fs::remove_file(metadata.path.as_path()).map_err(|e| { + DbError::SnapshotRegistryError(format!("Failed to remove snapshot: {}", e)) + })?; + return Ok(metadata); + } + } + Err(DbError::SnapshotRegistryError( + "Snapshot not found".to_string(), + )) + } + } + + fn get_metadata(&mut self, small_id: SmallID) -> Result<Metadata, DbError> { + if let Some(filename) = self.filename_cache.get(&small_id) { + let snapshot_filepath = self.dir.join(filename); + let metadata = Metadata::parse(snapshot_filepath.as_path())?; + Ok(metadata) + } else { + for file in fs::read_dir(self.dir.as_path()).map_err(|e| { + DbError::SnapshotRegistryError(format!("Cannot read local registry dir: {}", e)) + })?
{ + let file = match file { + Ok(file) => file, + Err(_) => continue, + }; + let file_path = file.path(); + if let Ok(metadata) = Metadata::parse(file_path.as_path()) + && metadata.small_id == small_id + { + return Ok(metadata); + } + } + Err(DbError::SnapshotRegistryError( + "Snapshot not found".to_string(), + )) + } + } + + fn mark_dead(&mut self, small_id: String) -> Result<Metadata, DbError> { + self.remove_snapshot(small_id) + } + + fn load( + &mut self, + small_id: String, + storage_data_path: &Path, + ) -> Result<VectorDbRestore, DbError> { + if let Some(filename) = self.filename_cache.get(&small_id) { + let snapshot_filepath = self.dir.join(filename); + Snapshot::load(snapshot_filepath.as_path(), storage_data_path) + } else { + for file in fs::read_dir(self.dir.as_path()).map_err(|e| { + DbError::SnapshotRegistryError(format!("Cannot read local registry dir: {}", e)) + })? { + let file = match file { + Ok(file) => file, + Err(_) => continue, + }; + let file_path = file.path(); + let metadata = Metadata::parse(file_path.as_path())?; + let filename = file_path + .file_name() + .ok_or(DbError::SnapshotRegistryError( + "Could not load filename of snapshot".to_string(), + ))? + .to_string_lossy(); + self.filename_cache + .insert(metadata.small_id.clone(), filename.to_string()); + if metadata.small_id == small_id { + return Snapshot::load(file_path.as_path(), storage_data_path); + } + } + Err(DbError::SnapshotRegistryError( + "Snapshot not found".to_string(), + )) + } + } + + fn dir(&self) -> PathBuf { + self.dir.clone() + } +} + +impl Drop for LocalRegistry { + fn drop(&mut self) { + // remove exclusive lock on lockfile + let lock_file_path = self.dir.join(LOCAL_REGISTRY_LOCKFILE); + if let Ok(lock_file) = fs::OpenOptions::new() + .read(true) + .write(true) + .open(&lock_file_path) + { + let _ = lock_file.unlock(); + } + } +} diff --git a/crates/snapshot/src/registry/mod.rs b/crates/snapshot/src/registry/mod.rs new file mode 100644 index 0000000..6513d97 --- /dev/null +++ b/crates/snapshot/src/registry/mod.rs @@ -0,0 +1,32 @@ +use std::path::{Path, PathBuf}; + +use defs::DbError; +pub mod constants; +pub mod local; +use crate::{VectorDbRestore, metadata::Metadata}; + +pub type SnapshotMetaPage = Vec<Metadata>; + +pub const INFINITY_LIMIT: usize = 100000; +pub const NO_OFFSET: usize = 0; + +pub trait SnapshotRegistry: Send + Sync { + fn add_snapshot(&mut self, snapshot_path: &Path) -> Result<Metadata, DbError>; + + fn list_snapshots(&mut self, limit: usize, offset: usize) -> Result<SnapshotMetaPage, DbError>; + fn get_latest_snapshot(&mut self) -> Result<Metadata, DbError>; + + fn get_metadata(&mut self, small_id: String) -> Result<Metadata, DbError>; + fn remove_snapshot(&mut self, small_id: String) -> Result<Metadata, DbError>; + + fn load( + &mut self, + small_id: String, + storage_data_path: &Path, + ) -> Result<VectorDbRestore, DbError>; + fn dir(&self) -> PathBuf; + + // in the future this could be used to maybe move an old/stale snapshot to cold storage or to a remote registry + fn mark_dead(&mut self, small_id: String) -> Result<Metadata, DbError>; // current behaviour is to call remove_snapshot; + fn list_alive_snapshots(&mut self) -> Result<SnapshotMetaPage, DbError>; // current behaviour is to call list_snapshots; +}
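A hedged usage sketch for the registry API defined above; the directory and archive paths are hypothetical, and the trait signatures are as reconstructed in this diff:

```rust
// Illustrative only: paths below are made up.
use std::path::Path;

use defs::DbError;
use snapshot::registry::{SnapshotRegistry, local::LocalRegistry};

fn example() -> Result<(), DbError> {
    // Acquires an exclusive LOCKFILE in the directory; a second LocalRegistry
    // on the same dir would fail until this one is dropped.
    let mut registry = LocalRegistry::new(Path::new("./snapshots"))?;

    // Move a freshly written archive under registry management.
    let meta = registry
        .add_snapshot(Path::new("./tmp/2024-01-02T03:04:05Z-x1a2b3c4d-x0.1.0.tar.gz"))?;
    println!("registered {}", meta.small_id);

    // Page through what's on disk, or grab the newest snapshot directly.
    let first_page = registry.list_snapshots(10, 0)?;
    let latest = registry.get_latest_snapshot()?;
    assert!(first_page.iter().any(|m| m.small_id == latest.small_id));
    Ok(())
}
```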
diff --git a/crates/snapshot/src/util.rs b/crates/snapshot/src/util.rs
new file mode 100644
index 0000000..7d0510e
--- /dev/null
+++ b/crates/snapshot/src/util.rs
@@ -0,0 +1,148 @@
+use data_encoding::HEXLOWER;
+use sha2::{Digest, Sha256};
+use std::fs::File;
+use std::io::{BufReader, Error, Read};
+use std::path::PathBuf;
+
+use defs::{DbError, Magic};
+use flate2::{Compression, write::GzEncoder};
+use std::{io::Write, path::Path};
+use tar::Builder;
+use uuid::Uuid;
+
+type BinFileContent = (Magic, Vec<u8>);
+
+#[inline]
+pub fn metadata_filename(id: &Uuid) -> String {
+    format!("{}-index-meta.bin", id)
+}
+
+#[inline]
+pub fn topology_filename(id: &Uuid) -> String {
+    format!("{}-index-topo.bin", id)
+}
+
+// source: https://stackoverflow.com/questions/69787906/how-to-hash-a-binary-file-in-rust
+pub fn sha256_digest(path: &PathBuf) -> Result<String, Error> {
+    let input = File::open(path)?;
+    let mut reader = BufReader::new(input);
+
+    let digest = {
+        let mut hasher = Sha256::new();
+        let mut buffer = [0; 1024];
+        loop {
+            let count = reader.read(&mut buffer)?;
+            if count == 0 {
+                break;
+            }
+            hasher.update(&buffer[..count]);
+        }
+        hasher.finalize()
+    };
+    Ok(HEXLOWER.encode(digest.as_ref()))
+}
+
+pub fn save_index_metadata(
+    path: &Path,
+    uuid: Uuid,
+    bytes: &[u8],
+    magic: &Magic,
+) -> Result<PathBuf, DbError> {
+    let file_name = metadata_filename(&uuid);
+    let metadata_file_path = path.join(file_name);
+
+    let mut file = std::fs::File::create(metadata_file_path.clone())
+        .map_err(|e| DbError::SnapshotError(format!("Could not create metadata file: {}", e)))?;
+
+    file.write_all(magic)
+        .map_err(|e| DbError::SnapshotError(format!("Could not write metadata file: {}", e)))?;
+    file.write_all(&bytes.len().to_le_bytes())
+        .map_err(|e| DbError::SnapshotError(format!("Could not write metadata file: {}", e)))?;
+    file.write_all(bytes)
+        .map_err(|e| DbError::SnapshotError(format!("Could not write metadata file: {}", e)))?;
+
+    Ok(metadata_file_path)
+}
+
+pub fn save_topology(
+    path: &Path,
+    uuid: Uuid,
+    bytes: &[u8],
+    magic: &Magic,
+) -> Result<PathBuf, DbError> {
+    let file_name = topology_filename(&uuid);
+    let topology_file_path = path.join(file_name);
+
+    let mut file = std::fs::File::create(topology_file_path.clone())
+        .map_err(|e| DbError::SnapshotError(format!("Could not create topology file: {}", e)))?;
+
+    file.write_all(magic)
+        .map_err(|e| DbError::SnapshotError(format!("Could not write topology file: {}", e)))?;
+    file.write_all(&bytes.len().to_le_bytes())
+        .map_err(|e| DbError::SnapshotError(format!("Could not write topology file: {}", e)))?;
+    file.write_all(bytes)
+        .map_err(|e| DbError::SnapshotError(format!("Could not write topology file: {}", e)))?;
+
+    Ok(topology_file_path)
+}
+
+pub fn compress_archive(path: &Path, files: &[&Path]) -> Result<(), Error> {
+    let tar_gz = File::create(path)?;
+    let enc = GzEncoder::new(tar_gz, Compression::default());
+    let mut tar = Builder::new(enc);
+
+    for file in files {
+        let rel_path = file.file_name().unwrap();
+        let mut f = File::open(file)?;
+        tar.append_file(rel_path, &mut f)?;
+    }
+
+    tar.into_inner()?;
+    Ok(())
+}
+
+pub fn read_index_topology(path: &Path) -> Result<BinFileContent, DbError> {
+    let mut file = File::open(path)
+        .map_err(|e| DbError::SnapshotError(format!("Couldn't open topology file: {}", e)))?;
+
+    let mut magic = Magic::default();
+    file.read_exact(&mut magic).map_err(|e| {
+        DbError::SnapshotError(format!("Couldn't read magic from topology file: {}", e))
+    })?;
+
+    let mut len_bytes = [0u8; size_of::<usize>()];
+    file.read_exact(&mut len_bytes).map_err(|e| {
+        DbError::SnapshotError(format!("Couldn't read length from topology file: {}", e))
+    })?;
+    let len = usize::from_le_bytes(len_bytes);
+
+    let mut bytes = vec![0u8; len];
+    file.read_exact(&mut bytes).map_err(|e| {
+        DbError::SnapshotError(format!("Couldn't read bytes from topology file: {}", e))
+    })?;
+
+    Ok((magic, bytes))
+}
+
+pub fn read_index_metadata(path: &Path) -> Result<BinFileContent, DbError> {
+    let mut file = File::open(path)
+        .map_err(|e| DbError::SnapshotError(format!("Couldn't open metadata file: {}", e)))?;
+
+    let mut magic = Magic::default();
+    file.read_exact(&mut magic).map_err(|e| {
+        DbError::SnapshotError(format!("Couldn't read magic from metadata file: {}", e))
+    })?;
+
+    let mut len_bytes = [0u8; size_of::<usize>()];
+    file.read_exact(&mut len_bytes).map_err(|e| {
+        DbError::SnapshotError(format!("Couldn't read length from metadata file: {}", e))
+    })?;
+
+    let len = usize::from_le_bytes(len_bytes);
+    let mut bytes = vec![0u8; len];
+    file.read_exact(&mut bytes).map_err(|e| {
+        DbError::SnapshotError(format!("Couldn't read bytes from metadata file: {}", e))
+    })?;
+
+    Ok((magic, bytes))
+}
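The save/read helpers above define one small on-disk layout: the magic bytes, then the payload length as a little-endian `usize`, then the payload. A round-trip sketch, assuming `Magic` is a fixed-size byte array implementing `Default`, `PartialEq`, and `Debug` (its use with `Magic::default()` and `write_all` suggests as much):

use defs::Magic;
use uuid::Uuid;

fn metadata_roundtrip() -> Result<(), defs::DbError> {
    let dir = tempfile::tempdir().expect("temp dir");
    let id = Uuid::new_v4();
    let magic = Magic::default();
    let payload = b"serialized index metadata";

    // Writes {id}-index-meta.bin as [magic | len (LE usize) | payload].
    let path = save_index_metadata(dir.path(), id, payload, &magic)?;
    let (read_magic, read_bytes) = read_index_metadata(&path)?;
    assert_eq!(read_magic, magic);
    assert_eq!(read_bytes.as_slice(), payload.as_slice());
    Ok(())
}

Note that the length prefix is a native `usize`, so these files are architecture-dependent (an 8-byte prefix on 64-bit targets); a fixed-width `u64` would make them portable across platforms.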
{}", e)))?; + + let mut magic = Magic::default(); + file.read_exact(&mut magic).map_err(|e| { + DbError::SnapshotError(format!("Couldn't read magic from metadata file: {}", e)) + })?; + + let mut len_bytes = [0u8; size_of::()]; + file.read_exact(&mut len_bytes).map_err(|e| { + DbError::SnapshotError(format!("Couldn't read length from metadata file: {}", e)) + })?; + + let len = usize::from_le_bytes(len_bytes); + let mut bytes = vec![0u8; len]; + file.read_exact(&mut bytes).map_err(|e| { + DbError::SnapshotError(format!("Couldn't read bytes from metadata file: {}", e)) + })?; + + Ok((magic, bytes)) +} diff --git a/crates/storage/Cargo.toml b/crates/storage/Cargo.toml index c786373..4d2afb3 100644 --- a/crates/storage/Cargo.toml +++ b/crates/storage/Cargo.toml @@ -9,6 +9,9 @@ license.workspace = true [dependencies] bincode.workspace = true defs.workspace = true +flate2 = "1.1.5" rocksdb.workspace = true +serde.workspace = true +tar = "0.4.44" tempfile.workspace = true uuid.workspace = true diff --git a/crates/storage/src/checkpoint.rs b/crates/storage/src/checkpoint.rs new file mode 100644 index 0000000..09827fd --- /dev/null +++ b/crates/storage/src/checkpoint.rs @@ -0,0 +1,51 @@ +use crate::StorageType; +use crate::in_memory::INMEMORY_CHECKPOINT_FILENAME_MARKER; +use crate::rocks_db::ROCKSDB_CHECKPOINT_FILENAME_MARKER; +use defs::DbError; +use std::path::{Path, PathBuf}; + +impl StorageType { + #[inline] + pub fn checkpoint_filename_marker(&self) -> &str { + match self { + StorageType::InMemory => INMEMORY_CHECKPOINT_FILENAME_MARKER, + StorageType::RocksDb => ROCKSDB_CHECKPOINT_FILENAME_MARKER, + } + } +} + +pub struct StorageCheckpoint { + pub path: PathBuf, + pub storage_type: StorageType, +} + +impl StorageCheckpoint { + pub fn open(path: &Path) -> Result { + let filename = path + .file_name() + .ok_or_else(|| DbError::StorageCheckpointError("Invalid filename".to_string()))? + .to_str() + .ok_or_else(|| { + DbError::StorageCheckpointError("Invalid UTF-8 in filename".to_string()) + })? + .to_owned(); + let marker = filename + .split_once("-") + .ok_or_else(|| DbError::StorageCheckpointError("Invalid filename".to_string()))? 
diff --git a/crates/storage/src/in_memory.rs b/crates/storage/src/in_memory.rs
index 5190082..647627d 100644
--- a/crates/storage/src/in_memory.rs
+++ b/crates/storage/src/in_memory.rs
@@ -1,5 +1,9 @@
-use crate::{StorageEngine, VectorPage};
+use crate::StorageType;
+use crate::{StorageEngine, VectorPage, checkpoint::StorageCheckpoint};
 use defs::{DbError, DenseVector, Payload, PointId};
+use std::path::{Path, PathBuf};
+
+pub const INMEMORY_CHECKPOINT_FILENAME_MARKER: &str = "inmemory";
 
 pub struct MemoryStorage {
     // define here how MemoryStorage will be defined
@@ -41,4 +45,15 @@ impl StorageEngine for MemoryStorage {
     fn list_vectors(&self, _offset: PointId, _limit: usize) -> Result<Option<VectorPage>, DbError> {
         Ok(None)
     }
+
+    fn checkpoint_at(&self, _path: &Path) -> Result<StorageCheckpoint, DbError> {
+        Ok(StorageCheckpoint {
+            path: PathBuf::default(),
+            storage_type: StorageType::InMemory,
+        })
+    }
+
+    fn restore_checkpoint(&mut self, _checkpoint: &StorageCheckpoint) -> Result<(), DbError> {
+        Ok(())
+    }
 }
diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs
index f7c067e..8228f72 100644
--- a/crates/storage/src/lib.rs
+++ b/crates/storage/src/lib.rs
@@ -1,8 +1,9 @@
+use crate::rocks_db::RocksDbStorage;
 use defs::{DbError, DenseVector, Payload, PointId};
-use std::path::PathBuf;
+use serde::{Deserialize, Serialize};
+use std::path::{Path, PathBuf};
 use std::sync::Arc;
-
-use crate::rocks_db::RocksDbStorage;
+pub mod checkpoint;
 
 pub type VectorPage = (Vec<(PointId, DenseVector)>, PointId);
 
@@ -18,12 +19,18 @@ pub trait StorageEngine: Send + Sync {
     fn delete_point(&self, id: PointId) -> Result<(), DbError>;
     fn contains_point(&self, id: PointId) -> Result<bool, DbError>;
     fn list_vectors(&self, offset: PointId, limit: usize) -> Result<Option<VectorPage>, DbError>;
+
+    fn checkpoint_at(&self, path: &Path) -> Result<checkpoint::StorageCheckpoint, DbError>;
+    fn restore_checkpoint(
+        &mut self,
+        checkpoint: &checkpoint::StorageCheckpoint,
+    ) -> Result<(), DbError>;
 }
 
 pub mod in_memory;
 pub mod rocks_db;
 
-#[derive(Debug, Clone, Copy)]
+#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
 pub enum StorageType {
     InMemory,
     RocksDb,
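The two new `StorageEngine` methods give every engine a uniform checkpoint/rollback surface, and the trait stays object-safe. A sketch of the intended call pattern; the helper name is illustrative, not part of this patch:

use std::path::Path;

use defs::DbError;
use storage::StorageEngine;

fn checkpoint_then_rollback(
    engine: &mut dyn StorageEngine,
    checkpoint_dir: &Path,
) -> Result<(), DbError> {
    // Materialize a point-in-time copy of the engine's data under `checkpoint_dir`.
    let checkpoint = engine.checkpoint_at(checkpoint_dir)?;

    // ... writes that should be rolled back happen here ...

    // Rewind the engine to the checkpointed state.
    engine.restore_checkpoint(&checkpoint)
}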
diff --git a/crates/storage/src/rocks_db.rs b/crates/storage/src/rocks_db.rs
index f9c80ab..7d184d4 100644
--- a/crates/storage/src/rocks_db.rs
+++ b/crates/storage/src/rocks_db.rs
@@ -1,26 +1,44 @@
 // Rewrite needed
-use crate::{StorageEngine, VectorPage};
+use crate::StorageType;
+use crate::{StorageEngine, VectorPage, checkpoint::StorageCheckpoint};
 use bincode::{deserialize, serialize};
 use defs::{DbError, DenseVector, Payload, Point, PointId};
+use flate2::{Compression, read::GzDecoder, write::GzEncoder};
 use rocksdb::{DB, Error, Options};
-use std::path::PathBuf;
+use std::{
+    fs::File,
+    path::{Path, PathBuf},
+};
+use tar::{Archive, Builder};
+use tempfile::tempdir;
 
 //TODO: Implement RocksDbStorage with necessary fields and implementations
 //TODO: Optimize the basic design
 pub struct RocksDbStorage {
     pub path: PathBuf,
-    pub db: DB,
+    pub db: Option<DB>,
 }
 
 pub enum RocksDBStorageError {
     RocksDBError(Error),
-    SerializationError,
 }
 
+pub const ROCKSDB_CHECKPOINT_FILENAME_MARKER: &str = "rocksdb";
+
 impl RocksDbStorage {
     // Creates new db or switches to existing db
     pub fn new(path: impl Into<PathBuf>) -> Result<Self, DbError> {
+        let converted_path = path.into();
+        let db = Self::initialize_db(&converted_path)?;
+
+        Ok(RocksDbStorage {
+            path: converted_path,
+            db: Some(db),
+        })
+    }
+
+    fn initialize_db(path: &Path) -> Result<DB, DbError> {
         // Initialize a db at the given location
 
         let mut options = Options::default();
@@ -30,15 +48,8 @@ impl RocksDbStorage {
 
         options.create_if_missing(true);
 
-        let converted_path = path.into();
-
-        let db = DB::open(&options, converted_path.clone())
-            .map_err(|e| DbError::StorageError(e.into_string()))?;
-
-        Ok(RocksDbStorage {
-            path: converted_path,
-            db,
-        })
+        let db = DB::open(&options, path).map_err(|e| DbError::StorageError(e.into_string()))?;
+        Ok(db)
     }
 
     pub fn get_current_path(&self) -> PathBuf {
@@ -60,7 +71,12 @@ impl StorageEngine for RocksDbStorage {
             payload,
         };
         let value = serialize(&point).map_err(|e| DbError::SerializationError(e.to_string()))?;
-        match self.db.put(key.as_bytes(), value.as_slice()) {
+        match self
+            .db
+            .as_ref()
+            .ok_or(DbError::StorageInitializationError)?
+            .put(key.as_bytes(), value.as_slice())
+        {
             Ok(_) => Ok(()),
             Err(e) => Err(DbError::StorageError(e.into_string())),
         }
@@ -69,9 +85,16 @@
     fn contains_point(&self, id: PointId) -> Result<bool, DbError> {
         // Efficient lookup inspired from https://github.com/facebook/rocksdb/issues/11586#issuecomment-1890429488
         let key = id.to_string();
-        if self.db.key_may_exist(key.clone()) {
+        if self
+            .db
+            .as_ref()
+            .ok_or(DbError::StorageInitializationError)?
+            .key_may_exist(key.clone())
+        {
             let key_exist = self
                 .db
+                .as_ref()
+                .ok_or(DbError::StorageInitializationError)?
                 .get(key)
                 .map_err(|e| DbError::StorageError(e.into_string()))?
                 .is_some();
@@ -84,6 +107,8 @@
     fn delete_point(&self, id: PointId) -> Result<(), DbError> {
         let key = id.to_string();
         self.db
+            .as_ref()
+            .ok_or(DbError::StorageInitializationError)?
            .delete(key)
             .map_err(|e| DbError::StorageError(e.into_string()))?;
 
@@ -94,6 +119,8 @@
         let key = id.to_string();
         let Some(value_serialized) = self
             .db
+            .as_ref()
+            .ok_or(DbError::StorageInitializationError)?
            .get(key)
             .map_err(|e| DbError::StorageError(e.into_string()))?
         else {
@@ -110,6 +137,8 @@
         let key = id.to_string();
         let Some(value_serialized) = self
             .db
+            .as_ref()
+            .ok_or(DbError::StorageInitializationError)?
             .get(key)
             .map_err(|e| DbError::StorageError(e.into_string()))?
         else {
@@ -128,10 +157,14 @@
         }
 
         let mut result = Vec::with_capacity(limit);
-        let iter = self.db.iterator(rocksdb::IteratorMode::From(
-            offset.to_string().as_bytes(),
-            rocksdb::Direction::Forward,
-        ));
+        let iter = self
+            .db
+            .as_ref()
+            .ok_or(DbError::StorageInitializationError)?
+            .iterator(rocksdb::IteratorMode::From(
+                offset.to_string().as_bytes(),
+                rocksdb::Direction::Forward,
+            ));
 
         let mut last_id = offset;
         for item in iter {
@@ -152,6 +185,124 @@
             }
         }
         Ok(Some((result, last_id)))
     }
+
+    fn checkpoint_at(&self, path: &Path) -> Result<StorageCheckpoint, DbError> {
+        // flush db first for durability
+        self.db
+            .as_ref()
+            .ok_or(DbError::StorageInitializationError)?
+            .flush()
+            .map_err(|e| {
+                DbError::StorageCheckpointError(format!(
+                    "Failed to flush database: {}",
+                    e.into_string()
+                ))
+            })?;
+
+        // filename is rocksdb-{uuid}.tar.gz
+        let checkpoint_filename = format!(
+            "{}-{}.tar.gz",
+            ROCKSDB_CHECKPOINT_FILENAME_MARKER,
+            uuid::Uuid::new_v4()
+        );
+        let checkpoint_path = path.join(checkpoint_filename);
+
+        let temp_dir_parent = tempdir().map_err(|e| {
+            DbError::StorageCheckpointError(format!("Couldn't create temp dir: {}", e))
+        })?;
+        let temp_dir = temp_dir_parent.path().join("checkpoint");
+
+        let db_ref = self
+            .db
+            .as_ref()
+            .ok_or(DbError::StorageInitializationError)?;
+        let checkpoint = rocksdb::checkpoint::Checkpoint::new(db_ref)
+            .map_err(|e| DbError::StorageCheckpointError(e.into_string()))?;
+        checkpoint
+            .create_checkpoint(temp_dir.clone())
+            .map_err(|e| DbError::StorageCheckpointError(e.into_string()))?;
+
+        // compress the checkpoint into an archive
+        let tar_gz = File::create(checkpoint_path.clone()).map_err(|e| {
+            DbError::StorageCheckpointError(format!("Couldn't create tar archive file: {}", e))
+        })?;
+        let enc = GzEncoder::new(tar_gz, Compression::default());
+        let mut archive = Builder::new(enc);
+
+        archive.append_dir_all("", temp_dir).map_err(|e| {
+            DbError::StorageCheckpointError(format!("Couldn't append directory to archive: {}", e))
+        })?;
+
+        let enc = archive.into_inner().map_err(|e| {
+            DbError::StorageCheckpointError(format!("Couldn't compress tar archive: {}", e))
+        })?;
+
+        enc.finish().map_err(|e| {
+            DbError::StorageCheckpointError(format!("Couldn't compress tar archive: {}", e))
+        })?;
+
+        Ok(StorageCheckpoint {
+            path: checkpoint_path,
+            storage_type: crate::StorageType::RocksDb,
+        })
+    }
+
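`checkpoint_at` leaves a self-contained `rocksdb-{uuid}.tar.gz` under the caller-supplied `path`, so higher layers can treat a checkpoint as a single file. For example, the snapshot crate (which depends on `storage`) could pin the archive's contents with the `sha256_digest` helper from `util.rs` above; a hypothetical sketch:

use std::path::Path;

use defs::DbError;
use storage::StorageEngine;

use crate::util::sha256_digest;

// Hypothetical: checkpoint the engine, then hash the resulting archive so
// snapshot metadata can record a content digest.
fn hash_checkpoint(engine: &dyn StorageEngine, dir: &Path) -> Result<String, DbError> {
    let checkpoint = engine.checkpoint_at(dir)?;
    sha256_digest(&checkpoint.path)
        .map_err(|e| DbError::SnapshotError(format!("Could not hash checkpoint: {}", e)))
}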
+    fn restore_checkpoint(&mut self, checkpoint: &StorageCheckpoint) -> Result<(), DbError> {
+        // enforce storage type
+        if checkpoint.storage_type != StorageType::RocksDb {
+            return Err(DbError::StorageCheckpointError(
+                "Invalid storage type".to_string(),
+            ));
+        }
+        // enforce filename marker - should have been enforced during StorageCheckpoint::open anyway
+        let checkpoint_filename = checkpoint
+            .path
+            .file_name()
+            .ok_or(DbError::StorageCheckpointError(
+                "Could not read checkpoint filename".to_string(),
+            ))?
+            .to_str()
+            .ok_or(DbError::StorageCheckpointError(
+                "Could not read checkpoint filename".to_string(),
+            ))?;
+        if !checkpoint_filename.ends_with(".tar.gz")
+            || !checkpoint_filename.starts_with(ROCKSDB_CHECKPOINT_FILENAME_MARKER)
+        {
+            return Err(DbError::StorageCheckpointError(
+                "Invalid filename".to_string(),
+            ));
+        }
+
+        let tar_gz = File::open(&checkpoint.path).map_err(|e| {
+            DbError::StorageCheckpointError(format!("Couldn't open rocksdb checkpoint: {}", e))
+        })?;
+        let tar = GzDecoder::new(tar_gz);
+        let mut archive = Archive::new(tar);
+
+        // remove existing stuff in data path
+        self.db
+            .as_ref()
+            .ok_or(DbError::StorageInitializationError)?
+            .cancel_all_background_work(true);
+        // drop db early
+        self.db = None;
+
+        std::fs::remove_dir_all(&self.path).map_err(|e| {
+            DbError::StorageCheckpointError(format!("Couldn't remove existing data: {}", e))
+        })?;
+
+        // create new data path
+        std::fs::create_dir_all(&self.path).map_err(|e| {
+            DbError::StorageCheckpointError(format!("Couldn't create data path: {}", e))
+        })?;
+
+        archive.unpack(&self.path).map_err(|e| {
+            DbError::StorageCheckpointError(format!("Couldn't unpack tar.gz archive: {}", e))
+        })?;
+
+        // reinitialize db
+        self.db = Some(Self::initialize_db(&self.path)?);
+
+        Ok(())
+    }
 }
 
 #[cfg(test)]
@@ -160,26 +311,24 @@ mod tests {
     use defs::ContentType;
     use uuid::Uuid;
 
-    use tempfile::tempdir;
+    use tempfile::{TempDir, tempdir};
 
-    fn create_test_db() -> (RocksDbStorage, String) {
+    fn create_test_db() -> (RocksDbStorage, TempDir) {
         let temp_dir = tempdir().unwrap();
-        let temp_dir_path = temp_dir.path().to_str().unwrap().to_string();
-        let db = RocksDbStorage::new(temp_dir_path.clone()).expect("Failed to create RocksDB");
-        (db, temp_dir_path)
+        let db = RocksDbStorage::new(temp_dir.path()).expect("Failed to create RocksDB");
+        (db, temp_dir)
     }
 
     #[test]
     fn test_new_rocksdb_storage() {
-        let (db, path) = create_test_db();
-        assert_eq!(db.get_current_path(), PathBuf::from(path.clone()));
-        std::fs::remove_dir_all(path).unwrap_or_default();
+        let (db, temp_dir) = create_test_db();
+        assert_eq!(db.get_current_path(), temp_dir.path());
     }
 
     #[test]
     fn test_insert_and_get_vector() {
-        let (db, path) = create_test_db();
+        let (db, _temp_dir) = create_test_db();
         let id = Uuid::new_v4();
         let vector = Some(vec![0.1, 0.2, 0.3]);
         let payload = Some(Payload {
@@ -190,13 +339,11 @@ mod tests {
         assert!(db.insert_point(id, vector.clone(), payload).is_ok());
         let result = db.get_vector(id).unwrap();
         assert_eq!(result, vector);
-
-        std::fs::remove_dir_all(path).unwrap_or_default();
     }
 
     #[test]
     fn test_insert_and_get_payload() {
-        let (db, path) = create_test_db();
+        let (db, _temp_dir) = create_test_db();
         let id = Uuid::new_v4();
         let payload = Some(Payload {
             content_type: ContentType::Text,
@@ -212,13 +359,11 @@ mod tests {
             content: "Test".to_string(),
         });
         assert_eq!(result, expected);
-
-        std::fs::remove_dir_all(path).unwrap_or_default();
     }
 
     #[test]
     fn test_contains_point() {
-        let (db, path) = create_test_db();
+        let (db, _temp_dir) = create_test_db();
         let id = Uuid::new_v4();
         let payload = Some(Payload {
             content_type: ContentType::Text,
@@ -231,13 +376,11 @@ mod tests {
         db.insert_point(id, vector, payload).unwrap();
 
         assert!(db.contains_point(id).unwrap());
-
-        std::fs::remove_dir_all(path).unwrap_or_default();
     }
 
     #[test]
     fn test_delete_point() {
-        let (db, path) = create_test_db();
+        let (db, _temp_dir) = create_test_db();
         let id = Uuid::new_v4();
         let payload = Some(Payload {
             content_type: ContentType::Text,
@@ -255,27 +398,54 @@ mod tests {
         assert!(!db.contains_point(id).unwrap());
         assert_eq!(db.get_vector(id).unwrap(), None);
         assert_eq!(db.get_payload(id).unwrap(), None);
-
-        std::fs::remove_dir_all(path).unwrap_or_default();
     }
 
     #[test]
     fn test_get_nonexistent_vector() {
-        let (db, path) = create_test_db();
+        let (db, _temp_dir) = create_test_db();
         let id = Uuid::new_v4();
         assert_eq!(db.get_vector(id).unwrap(), None);
-
-        std::fs::remove_dir_all(path).unwrap_or_default();
     }
 
     #[test]
     fn test_get_nonexistent_payload() {
-        let (db, path) = create_test_db();
+        let (db, _temp_dir) = create_test_db();
         let id = Uuid::new_v4();
         assert_eq!(db.get_payload(id).unwrap(), None);
+    }
+
+    #[test]
+    fn test_create_and_load_checkpoint() {
+        let (mut db, temp_dir) = create_test_db();
+
+        let id1 = Uuid::new_v4();
+        let id2 = Uuid::new_v4();
+
+        let vector = Some(vec![0.1, 0.2, 0.3]);
+        let payload = Some(Payload {
+            content_type: ContentType::Text,
+            content: "Test".to_string(),
+        });
+
+        assert!(
+            db.insert_point(id1, vector.clone(), payload.clone())
+                .is_ok()
+        );
+
+        let checkpoint = db
+            .checkpoint_at(temp_dir.path())
+            .expect("Failed to create checkpoint");
+
+        assert!(
+            db.insert_point(id2, vector.clone(), payload.clone())
+                .is_ok()
+        );
+
+        db.restore_checkpoint(&checkpoint).unwrap();
 
-        std::fs::remove_dir_all(path).unwrap_or_default();
+        assert!(db.contains_point(id1).unwrap());
+        assert!(!db.contains_point(id2).unwrap());
     }
 }
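One subtlety in `test_create_and_load_checkpoint`: the archive is written into the RocksDB data directory itself, and `restore_checkpoint` then deletes that directory. This still works because the archive is opened before `remove_dir_all`, so on Unix the open descriptor keeps its bytes readable; on platforms where deleting open files fails, keeping checkpoints outside the data path is safer. A sketch of that layout:

use tempfile::tempdir;

#[test]
fn checkpoint_outside_data_dir() {
    let data_dir = tempdir().unwrap(); // RocksDB files live here
    let checkpoint_dir = tempdir().unwrap(); // archives live elsewhere

    let mut db = RocksDbStorage::new(data_dir.path()).unwrap();
    let checkpoint = db.checkpoint_at(checkpoint_dir.path()).unwrap();
    db.restore_checkpoint(&checkpoint).unwrap();
}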