diff --git a/.env.example b/.env.example index 0240126..8f8dce1 100644 --- a/.env.example +++ b/.env.example @@ -1,10 +1,26 @@ -NUMBER_OF_DATABASE=3 -DATABASE_PATH1=Enter-path-1-here -DATABASE_NAME1=Enter-name-1-here -DATABASE_PATH2=Enter-path-2-here -DATABASE_NAME2=Enter-name-2-here -DATABASE_PATH3=Enter-path-3-here -DATABASE_NAME3=Enter-name-3-here +# HTTP Server +HTTP_HOST=127.0.0.1 +HTTP_PORT=3000 +# gRPC Server +GRPC_HOST=127.0.0.1 +GRPC_PORT=50051 +# required +GRPC_ROOT_PASSWORD=your-secure-password + +# Database Configuration +# Storage: inmemory, rocksdb +# Index: flat, kdtree, hnsw +STORAGE_TYPE=rocksdb +INDEX_TYPE=flat +DIMENSION=512 + +DATA_PATH=./data + +# Server Options +LOGGING=true +DISABLE_HTTP=false + +# Embedding Services (TUI) TEXT_EMBEDDING_URL=http://localhost:8080/vectors IMAGE_EMBEDDING_URL=http://localhost:8080/vectors_img diff --git a/.gitignore b/.gitignore index ca03682..7eb2e3b 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ /data/ .TODO /databases +.env diff --git a/Cargo.lock b/Cargo.lock index e5bc60d..00dd523 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -91,7 +91,7 @@ dependencies = [ "bytes", "form_urlencoded", "futures-util", - "http", + "http 1.4.0", "http-body", "http-body-util", "hyper", @@ -122,7 +122,7 @@ checksum = "59446ce19cd142f8833f856eb31f3eb097812d1479ab224f54d72428ca21ea22" dependencies = [ "bytes", "futures-core", - "http", + "http 1.4.0", "http-body", "http-body-util", "mime", @@ -145,7 +145,7 @@ dependencies = [ "bytesize", "cookie", "expect-json", - "http", + "http 1.4.0", "http-body-util", "hyper", "hyper-util", @@ -706,7 +706,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" [[package]] -name = "grpc_server" +name = "grpc" version = "0.1.0" dependencies = [ "api", @@ -739,7 +739,7 @@ dependencies = [ "fnv", "futures-core", "futures-sink", - "http", + "http 1.4.0", "indexmap", "slab", "tokio", @@ -770,6 +770,25 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "http" +version = "0.1.0" +dependencies = [ + "api", + "axum", + "axum-test", + "defs", + "dotenvy", + "index", + "serde", + "serde_json", + "storage", + "tempfile", + "tokio", + "tracing", + "tracing-subscriber", +] + [[package]] name = "http" version = "1.4.0" @@ -787,7 +806,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http", + "http 1.4.0", ] [[package]] @@ -798,30 +817,11 @@ checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" dependencies = [ "bytes", "futures-core", - "http", + "http 1.4.0", "http-body", "pin-project-lite", ] -[[package]] -name = "http_server" -version = "0.1.0" -dependencies = [ - "api", - "axum", - "axum-test", - "defs", - "dotenvy", - "index", - "serde", - "serde_json", - "storage", - "tempfile", - "tokio", - "tracing", - "tracing-subscriber", -] - [[package]] name = "httparse" version = "1.10.1" @@ -845,7 +845,7 @@ dependencies = [ "futures-channel", "futures-core", "h2", - "http", + "http 1.4.0", "http-body", "httparse", "httpdate", @@ -863,7 +863,7 @@ version = "0.27.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" dependencies = [ - "http", + "http 1.4.0", "hyper", "hyper-util", "rustls", @@ -913,7 +913,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "http", + "http 1.4.0", "http-body", "hyper", "ipnet", @@ -1840,7 +1840,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", + "http 1.4.0", "http-body", "http-body-util", "hyper", @@ -1912,7 +1912,7 @@ dependencies = [ "bytes", "futures-core", "futures-util", - "http", + "http 1.4.0", "mime", "rand", "thiserror", @@ -2101,6 +2101,26 @@ dependencies = [ [[package]] name = "server" version = "0.1.0" +dependencies = [ + "api", + "axum", + "defs", + "dotenv", + "grpc", + "http 0.1.0", + "index", + "prost", + "serde", + "serde_json", + "storage", + "tempfile", + "tokio", + "tokio-stream", + "tonic", + "tracing", + "tracing-subscriber", + "uuid", +] [[package]] name = "sharded-slab" @@ -2449,7 +2469,7 @@ dependencies = [ "base64", "bytes", "h2", - "http", + "http 1.4.0", "http-body", "http-body-util", "hyper", @@ -2534,7 +2554,7 @@ dependencies = [ "bitflags 2.10.0", "bytes", "futures-util", - "http", + "http 1.4.0", "http-body", "iri-string", "pin-project-lite", diff --git a/Cargo.toml b/Cargo.toml index 5825d64..397facf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,9 +5,9 @@ members = [ "crates/storage", "crates/index", "crates/server", - "crates/http_server", + "crates/http", "crates/tui", - "crates/grpc_server", + "crates/grpc", ] # You can define shared dependencies for all crates here diff --git a/crates/api/src/lib.rs b/crates/api/src/lib.rs index f190682..72739ac 100644 --- a/crates/api/src/lib.rs +++ b/crates/api/src/lib.rs @@ -131,6 +131,7 @@ impl VectorDb { } } +#[derive(Debug)] pub struct DbConfig { pub storage_type: StorageType, pub index_type: IndexType, diff --git a/crates/defs/src/error.rs b/crates/defs/src/error.rs index 3cbdac9..8f929c3 100644 --- a/crates/defs/src/error.rs +++ b/crates/defs/src/error.rs @@ -29,3 +29,6 @@ impl std::fmt::Display for DbError { } impl std::error::Error for DbError {} + +// Error type for server +pub type BoxError = Box; diff --git a/crates/grpc_server/Cargo.toml b/crates/grpc/Cargo.toml similarity index 96% rename from crates/grpc_server/Cargo.toml rename to crates/grpc/Cargo.toml index 8b2030c..ddb7c79 100644 --- a/crates/grpc_server/Cargo.toml +++ b/crates/grpc/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "grpc_server" +name = "grpc" version = "0.1.0" edition = "2024" diff --git a/crates/grpc_server/README.md b/crates/grpc/README.md similarity index 100% rename from crates/grpc_server/README.md rename to crates/grpc/README.md diff --git a/crates/grpc_server/build.rs b/crates/grpc/build.rs similarity index 100% rename from crates/grpc_server/build.rs rename to crates/grpc/build.rs diff --git a/crates/grpc_server/proto/vector-db.proto b/crates/grpc/proto/vector-db.proto similarity index 100% rename from crates/grpc_server/proto/vector-db.proto rename to crates/grpc/proto/vector-db.proto diff --git a/crates/grpc_server/src/constants.rs b/crates/grpc/src/constants.rs similarity index 100% rename from crates/grpc_server/src/constants.rs rename to crates/grpc/src/constants.rs diff --git a/crates/grpc_server/src/errors.rs b/crates/grpc/src/errors.rs similarity index 100% rename from crates/grpc_server/src/errors.rs rename to crates/grpc/src/errors.rs diff --git a/crates/grpc_server/src/interceptors.rs b/crates/grpc/src/interceptors.rs similarity index 100% rename from crates/grpc_server/src/interceptors.rs rename to crates/grpc/src/interceptors.rs diff --git a/crates/grpc/src/lib.rs b/crates/grpc/src/lib.rs new file mode 100644 index 0000000..c1a1930 --- /dev/null +++ b/crates/grpc/src/lib.rs @@ -0,0 +1,32 @@ +pub mod constants; +pub mod errors; +pub mod interceptors; +pub mod service; +pub mod utils; + +use api::VectorDb; +use defs::BoxError; +use service::{VectorDBService, run_server}; +use std::net::SocketAddr; +use std::sync::Arc; +use utils::ServerEndpoint; + +/// Runs the gRPC server on the specified address. +pub async fn run_grpc_server( + db: Arc, + addr: SocketAddr, + root_password: String, + logging: bool, +) -> Result<(), BoxError> { + let vector_db_service = VectorDBService::new(db, logging); + run_server( + vector_db_service, + ServerEndpoint::Address(addr), + root_password, + ) + .await + .map_err(|e| -> BoxError { e.to_string().into() }) +} + +#[cfg(test)] +mod tests; diff --git a/crates/grpc_server/src/service.rs b/crates/grpc/src/service.rs similarity index 96% rename from crates/grpc_server/src/service.rs rename to crates/grpc/src/service.rs index 9eeea96..605242d 100644 --- a/crates/grpc_server/src/service.rs +++ b/crates/grpc/src/service.rs @@ -1,4 +1,5 @@ use std::str::FromStr; +use std::sync::Arc; use crate::interceptors; use crate::service::vectordb::{ContentType, Uuid}; @@ -18,10 +19,16 @@ pub mod vectordb { } pub struct VectorDBService { - pub vector_db: api::VectorDb, + pub vector_db: Arc, pub logging: bool, } +impl VectorDBService { + pub fn new(vector_db: Arc, logging: bool) -> Self { + Self { vector_db, logging } + } +} + #[tonic::async_trait] impl VectorDb for VectorDBService { async fn insert_vector( diff --git a/crates/grpc_server/src/tests.rs b/crates/grpc/src/tests.rs similarity index 93% rename from crates/grpc_server/src/tests.rs rename to crates/grpc/src/tests.rs index 820a4ea..6f77329 100644 --- a/crates/grpc_server/src/tests.rs +++ b/crates/grpc/src/tests.rs @@ -1,16 +1,14 @@ -use crate::config::GRPCServerConfig; use crate::constants::AUTHORIZATION_HEADER_KEY; use crate::service::vectordb::vector_db_client::VectorDbClient; use crate::service::vectordb::{DenseVector, InsertVectorRequest, Payload, PointId, SearchRequest}; use crate::service::{VectorDBService, run_server}; use crate::utils::ServerEndpoint; -use api; use api::DbConfig; use index::IndexType; use std::net::SocketAddr; +use std::sync::Arc; use storage::StorageType; use tempfile::tempdir; -use tokio; use tonic::transport::Channel; // Inspired from https://github.com/hyperium/tonic/discussions/924#discussioncomment-9854088 @@ -35,28 +33,18 @@ async fn start_test_server() -> Result> { dimension: 3, }; - let config = GRPCServerConfig { - addr: "127.0.0.1:0".parse()?, - root_password: TEST_AUTH_BEARER_TOKEN.to_string(), - logging: false, - db_config, - }; - - let vector_db_api = api::init_api(config.db_config)?; + let vector_db_api = api::init_api(db_config)?; - let vector_db_service = VectorDBService { - vector_db: vector_db_api, - logging: config.logging, - }; + let vector_db_service = VectorDBService::new(Arc::new(vector_db_api), false); - let listener = tokio::net::TcpListener::bind(config.addr).await?; + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; let listener_addr = listener.local_addr()?; tokio::spawn(async move { let _ = run_server( vector_db_service, ServerEndpoint::Listener(listener), - config.root_password, + TEST_AUTH_BEARER_TOKEN.to_string(), ) .await .inspect_err(|err| panic!("Could not start test server : {:?}", err)); diff --git a/crates/grpc_server/src/utils.rs b/crates/grpc/src/utils.rs similarity index 100% rename from crates/grpc_server/src/utils.rs rename to crates/grpc/src/utils.rs diff --git a/crates/grpc_server/.sample.env b/crates/grpc_server/.sample.env deleted file mode 100644 index 85c9030..0000000 --- a/crates/grpc_server/.sample.env +++ /dev/null @@ -1,9 +0,0 @@ -GRPC_SERVER_ROOT_PASSWORD=123 # required -GRPC_SERVER_DIMENSION=3 # required - -GRPC_SERVER_HOST=127.0.0.1 # defaults to 127.0.0.1 aka localhost -GRPC_SERVER_PORT=8080 # defaults to 8080 -GRPC_SERVER_STORAGE_TYPE=inmemory # (inmemory/rocksdb) defaults to 'inmemory' -GRPC_SERVER_INDEX_TYPE=flat # defaults to flat -GRPC_SERVER_DATA_PATH=data # defaults to a temporary directory -GRPC_SERVER_LOGGING=true # defaults to true diff --git a/crates/grpc_server/src/config.rs b/crates/grpc_server/src/config.rs deleted file mode 100644 index 6d9137c..0000000 --- a/crates/grpc_server/src/config.rs +++ /dev/null @@ -1,125 +0,0 @@ -use crate::constants::{ - self, DEFAULT_PORT, ENV_DATA_PATH, ENV_DIMENSION, ENV_INDEX_TYPE, ENV_LOGGING, ENV_PORT, - ENV_ROOT_PASSWORD, ENV_STORAGE_TYPE, -}; -use crate::errors; -use api; -use dotenv::dotenv; -use index::IndexType; -use std::net::SocketAddr; -use std::path::PathBuf; -use std::{env, fs}; -use storage; -use tempfile::tempdir; -use tracing::{Level, event}; - -pub struct GRPCServerConfig { - pub addr: SocketAddr, - pub root_password: String, - pub db_config: api::DbConfig, - pub logging: bool, -} - -impl GRPCServerConfig { - pub fn load_config() -> Result> { - dotenv().ok(); - - // fetch server host; default to localhost if not defined - let host = env::var(constants::ENV_HOST) - .inspect_err(|_| { - event!(Level::WARN, "Host not defined, defaulting to 'localhost'"); - }) - .unwrap_or("127.0.0.1".to_string()); - - // fetch server port; default to 8080 if not defined - let port: u32 = env::var(ENV_PORT) - .inspect_err(|_| { - event!( - Level::WARN, - "Port not defined, defaulting to {}", - DEFAULT_PORT - ); - }) - .unwrap_or(DEFAULT_PORT.to_string()) - .parse() - .unwrap_or(DEFAULT_PORT.parse::().unwrap()); - - // fetch server root password; return err if not defined - let root_password = env::var(ENV_ROOT_PASSWORD).map_err(|_| { - errors::ConfigError::MissingRequiredEnvVar(ENV_ROOT_PASSWORD.to_string()) - })?; - - // fetch server storage type - let storage_type_str = env::var(ENV_STORAGE_TYPE) - .inspect_err(|_| { - event!( - Level::WARN, - "Storage Type not defined, defaulting to InMemory" - ) - }) - .unwrap_or_default(); - let storage_type = match storage_type_str.as_str() { - "inmemory" => storage::StorageType::InMemory, - "rocksdb" => storage::StorageType::RocksDb, - _ => storage::StorageType::InMemory, // default to InMemory if not specified - }; - - // fetch server index type - let index_type_str = env::var(ENV_INDEX_TYPE) - .inspect_err(|_| event!(Level::WARN, "Index Type not defined, defaulting to flat")) - .unwrap_or("flat".to_string()) - .to_lowercase(); - let index_type = match index_type_str.as_str() { - "flat" => IndexType::Flat, - "kdtree" => IndexType::KDTree, - "hnsw" => IndexType::HNSW, - _ => IndexType::Flat, // default to Flat if not specified - }; - - // fetch dimension size - let dimension: usize = env::var(ENV_DIMENSION) - .map_err(|_| errors::ConfigError::MissingRequiredEnvVar(ENV_DIMENSION.to_string()))? - .parse() - .map_err(|_| errors::ConfigError::InvalidDimension)?; - - // fetch data path; create tempdir if not specified - let data_path: PathBuf; - if let Ok(data_path_str) = env::var(ENV_DATA_PATH) { - data_path = PathBuf::from(data_path_str); - fs::create_dir_all(&data_path).map_err(|_| errors::ConfigError::InvalidDataPath)?; - } else { - let tempbuf = tempdir().unwrap().path().to_path_buf().join("vectordb"); - fs::create_dir_all(&tempbuf)?; - event!( - Level::WARN, - "Data Path not specified, using temporary directory: {:?}", - tempbuf.clone() - ); - data_path = tempbuf; - } - - // create db config for api - let db_config = api::DbConfig { - storage_type, - index_type, - data_path, - dimension, - }; - - // create socket address for grpc server - let addr: SocketAddr = format!("{}:{}", host, port).parse()?; - - // check if logging is enabled - let mut logging: bool = true; // default to logging enabled - if let Ok(logging_str) = env::var(ENV_LOGGING) { - logging = logging_str.parse().unwrap_or(true); - } - - Ok(GRPCServerConfig { - addr, - root_password, - db_config, - logging, - }) - } -} diff --git a/crates/grpc_server/src/lib.rs b/crates/grpc_server/src/lib.rs deleted file mode 100644 index a14557d..0000000 --- a/crates/grpc_server/src/lib.rs +++ /dev/null @@ -1,9 +0,0 @@ -pub mod config; -pub mod constants; -pub mod errors; -pub mod interceptors; -pub mod service; -pub mod utils; - -#[cfg(test)] -pub mod tests; diff --git a/crates/grpc_server/src/main.rs b/crates/grpc_server/src/main.rs deleted file mode 100644 index 83db6dd..0000000 --- a/crates/grpc_server/src/main.rs +++ /dev/null @@ -1,30 +0,0 @@ -use grpc_server::config::GRPCServerConfig; -use grpc_server::service::{VectorDBService, run_server}; -use grpc_server::utils::ServerEndpoint; -use std::panic; - -#[tokio::main] -async fn main() -> Result<(), Box> { - tracing_subscriber::fmt::init(); - - // load config from environment from environment variables - let config = GRPCServerConfig::load_config() - .inspect_err(|err| panic!("Failed to load config: {}", err)) - .unwrap(); - - let vector_db_api = api::init_api(config.db_config) - .inspect_err(|err| panic!("Failed to Init API: {:?}", err)) - .unwrap(); - - let vector_db_service = VectorDBService { - vector_db: vector_db_api, - logging: config.logging, - }; - run_server( - vector_db_service, - ServerEndpoint::Address(config.addr), - config.root_password, - ) - .await?; - Ok(()) -} diff --git a/crates/http_server/Cargo.toml b/crates/http/Cargo.toml similarity index 95% rename from crates/http_server/Cargo.toml rename to crates/http/Cargo.toml index a98a2cb..ce94b9f 100644 --- a/crates/http_server/Cargo.toml +++ b/crates/http/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "http_server" +name = "http" version = "0.1.0" edition = "2024" diff --git a/crates/http_server/src/handler.rs b/crates/http/src/handler.rs similarity index 98% rename from crates/http_server/src/handler.rs rename to crates/http/src/handler.rs index 8b94008..3b2cab0 100644 --- a/crates/http_server/src/handler.rs +++ b/crates/http/src/handler.rs @@ -24,6 +24,10 @@ pub async fn root_handler() -> &'static str { "Vector Database server is running!" } +pub async fn health_handler() -> &'static str { + "OK" +} + pub async fn insert_point_handler( State(app_state): State, Json(request): Json, diff --git a/crates/http/src/lib.rs b/crates/http/src/lib.rs new file mode 100644 index 0000000..8554157 --- /dev/null +++ b/crates/http/src/lib.rs @@ -0,0 +1,46 @@ +pub mod handler; + +use api::VectorDb; +use axum::{ + Router, + routing::{get, post}, +}; +use defs::BoxError; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::net::TcpListener; +use tracing::info; + +use handler::{ + delete_point_handler, get_point_handler, health_handler, insert_point_handler, root_handler, + search_points_handler, +}; + +#[derive(Clone)] +pub struct AppState { + pub db: Arc, +} + +/// Creates the HTTP router with all VectorDB routes. +pub fn create_router(db: Arc) -> Router { + let app_state = AppState { db }; + Router::new() + .route("/", get(root_handler)) + .route("/health", get(health_handler)) + .route("/points", post(insert_point_handler)) + .route( + "/points/{id}", + get(get_point_handler).delete(delete_point_handler), + ) + .route("/points/search", post(search_points_handler)) + .with_state(app_state) +} + +/// Runs the HTTP server on the specified address. +pub async fn run_http_server(db: Arc, addr: SocketAddr) -> Result<(), BoxError> { + let app = create_router(db); + let listener = TcpListener::bind(addr).await?; + info!("HTTP server listening on http://{}", addr); + axum::serve(listener, app.into_make_service()).await?; + Ok(()) +} diff --git a/crates/http_server/src/config.rs b/crates/http_server/src/config.rs deleted file mode 100644 index c4c9101..0000000 --- a/crates/http_server/src/config.rs +++ /dev/null @@ -1,36 +0,0 @@ -use std::env; -use std::net::SocketAddr; -use std::path::PathBuf; - -pub struct Config { - pub listen_addr: SocketAddr, - pub db_path: PathBuf, - pub vector_dimension: usize, -} - -impl Config { - pub fn from_env() -> Self { - // Load listen address - let listen_addr_str = - env::var("LISTEN_ADDR").unwrap_or_else(|_| "127.0.0.1:3000".to_string()); - let listen_addr = listen_addr_str - .parse() - .expect("Failed to parse LISTEN_ADDR"); - - // Load database path - let db_path_str = env::var("DB_PATH").unwrap_or_else(|_| "./data/vectordb".to_string()); - let db_path = PathBuf::from(db_path_str); - - // Load vector dimension - let vector_dimension_str = env::var("VECTOR_DIMENSION").unwrap_or_else(|_| "3".to_string()); - let vector_dimension = vector_dimension_str - .parse() - .expect("Failed to parse VECTOR_DIMENSION"); - - Self { - listen_addr, - db_path, - vector_dimension, - } - } -} diff --git a/crates/http_server/src/main.rs b/crates/http_server/src/main.rs deleted file mode 100644 index 62d39e2..0000000 --- a/crates/http_server/src/main.rs +++ /dev/null @@ -1,153 +0,0 @@ -mod config; -mod handler; - -use api::{DbConfig, VectorDb, init_api}; -use axum::{ - Router, - routing::{get, post}, -}; -use config::Config; -use defs::{AppError, ServerError}; -use index::IndexType; -use storage::StorageType; -use tokio::net::TcpListener; -use tracing::info; - -use handler::{ - delete_point_handler, get_point_handler, insert_point_handler, root_handler, - search_points_handler, -}; -use std::sync::Arc; - -#[derive(Clone)] -struct AppState { - db: Arc, -} - -pub fn app(db: Arc) -> Router { - let app_state = AppState { db }; - Router::new() - .route("/", get(root_handler)) - .route("/points", post(insert_point_handler)) - .route( - "/points/{id}", - get(get_point_handler).delete(delete_point_handler), - ) - .route("/points/search", post(search_points_handler)) - .with_state(app_state) -} -#[tokio::main] -async fn main() -> Result<(), AppError> { - tracing_subscriber::fmt::init(); - - let config = Config::from_env(); - info!("Loaded configuration: {:?}", config.db_path); - info!("Vector dimension set to: {}", config.vector_dimension); - - if let Some(parent) = config.db_path.parent() { - std::fs::create_dir_all(parent).expect("Failed to create database directory"); - } - - // db init - let db_config = DbConfig { - storage_type: StorageType::RocksDb, - index_type: IndexType::Flat, - data_path: config.db_path, - dimension: config.vector_dimension, - }; - - let db = init_api(db_config).map_err(AppError::DbError)?; - - // axum init - info!(" Server listening on http://{}", config.listen_addr); - - let app = app(Arc::new(db)); - - let listener = TcpListener::bind(config.listen_addr) - .await - .map_err(|err| AppError::ServerError(ServerError::Bind(err)))?; - axum::serve(listener, app.into_make_service()) - .await - .map_err(|err| AppError::ServerError(ServerError::Serve(err)))?; - - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::handler::SearchResponse; - use axum::http::StatusCode; - use axum_test::TestServer; - use defs::{DenseVector, Point}; - use serde_json::json; - use tempfile::tempdir; - - fn create_test_db() -> VectorDb { - let temp_dir = tempdir().unwrap(); - let config = DbConfig { - storage_type: StorageType::RocksDb, - index_type: IndexType::Flat, - data_path: temp_dir.path().to_path_buf(), - dimension: 2, - }; - init_api(config).unwrap() - } - - fn setup_test_server() -> TestServer { - let db = Arc::new(create_test_db()); - let test_app = app(db); - TestServer::new(test_app).unwrap() - } - - #[tokio::test] - async fn test_all_routes() { - let server = setup_test_server(); - // 1 Insert a point - let insert_response = server - .post("/points") - .json( - &json!({"vector": [0.1, 0.2], "payload": {"content_type": "Image", "content": "tester"}}), - ) - .await; - assert_eq!(insert_response.status_code(), StatusCode::CREATED); - println!("Insert Test passed"); - - let insert_result: handler::InsertResponse = insert_response.json(); - let point_id = insert_result.point_id; - - // 2 Get the point back - let get_response = server.get(&format!("/points/{}", point_id)).await; - get_response.assert_status_ok(); - let point: Point = get_response.json(); - - let expected_vec: DenseVector = vec![0.1, 0.2]; - assert_eq!(point.vector.unwrap(), expected_vec); - println!("Retrieval Test passed"); - - println!("Deletion Test passed"); - - // 3 Search for the point - let search_response = server - .post("/points/search") - .json(&json!({ - "vector": [0.11, 0.22], - "similarity": "Cosine", - "limit": 1 - })) - .await; - search_response.assert_status_ok(); - println!("{:?}", search_response); - let search_results: SearchResponse = search_response.json(); - assert_eq!(search_results.results.len(), 1); - assert_eq!(search_results.results[0].to_string(), point_id.to_string()); - println!("Search Test passed"); - - // 4 Delete the point - let delete_response = server.delete(&format!("/points/{}", point_id)).await; - assert_eq!(delete_response.status_code(), StatusCode::NO_CONTENT); - - let get_after_delete_response = server.get(&format!("/points/{}", point_id)).await; - get_after_delete_response.assert_status_not_found(); - } -} diff --git a/crates/index/src/lib.rs b/crates/index/src/lib.rs index cd363b0..bd802b2 100644 --- a/crates/index/src/lib.rs +++ b/crates/index/src/lib.rs @@ -58,6 +58,7 @@ pub fn distance(a: DenseVector, b: DenseVector, dist_type: Similarity) -> f32 { } } +#[derive(Debug, Clone, Copy)] pub enum IndexType { Flat, KDTree, diff --git a/crates/server/.env.example b/crates/server/.env.example deleted file mode 100644 index 2f97e60..0000000 --- a/crates/server/.env.example +++ /dev/null @@ -1,10 +0,0 @@ -# ./.env - -# The IP address and port for the server to listen on -LISTEN_ADDR="127.0.0.1:3000" - -# The file system path for the database storage -DB_PATH="./data/vectordb" - -# The dimension of the vectors database will store -VECTOR_DIMENSION="3" \ No newline at end of file diff --git a/crates/server/Cargo.toml b/crates/server/Cargo.toml index a35055c..4a91211 100644 --- a/crates/server/Cargo.toml +++ b/crates/server/Cargo.toml @@ -1,6 +1,42 @@ [package] name = "server" version = "0.1.0" -edition = "2021" +edition = "2024" [dependencies] +# Async runtime +tokio = { version = "1", features = ["full"] } + +# HTTP server (Axum) +axum = "0.8" + +# gRPC server (Tonic) +tonic = "0.14.2" +prost = "0.14.1" + +# Serialization +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" + +# Logging +tracing = "0.1" +tracing-subscriber = "0.3" + +# Config +dotenv = "0.15.0" + +# Local crates +api = { path = "../api" } +defs = { path = "../defs" } +index = { path = "../index" } +storage = { path = "../storage" } +grpc = { path = "../grpc" } +http = { path = "../http" } + +# Other +tempfile = "3.23.0" +uuid.workspace = true +tokio-stream = "0.1.17" + +[dev-dependencies] +tonic = "0.14.2" diff --git a/crates/server/src/config.rs b/crates/server/src/config.rs new file mode 100644 index 0000000..481d761 --- /dev/null +++ b/crates/server/src/config.rs @@ -0,0 +1,191 @@ +use api::DbConfig; +use dotenv::dotenv; +use index::IndexType; +use std::env; +use std::fs; +use std::net::SocketAddr; +use std::path::PathBuf; +use storage::StorageType; +use tempfile::tempdir; +use tracing::{Level, event}; + +const DEFAULT_HTTP_PORT: &str = "3000"; +const DEFAULT_GRPC_PORT: &str = "50051"; + +#[derive(Debug)] +pub struct ServerConfig { + pub http_addr: SocketAddr, + pub grpc_addr: SocketAddr, + pub grpc_root_password: String, + pub db_config: DbConfig, + pub logging: bool, + pub disable_http: bool, +} + +#[derive(Debug)] +pub enum ConfigError { + MissingRequiredEnvVar(String), + InvalidDimension, + InvalidDataPath, + InvalidAddress(String), + IoError(String), +} + +impl std::fmt::Display for ConfigError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ConfigError::MissingRequiredEnvVar(var) => { + write!(f, "Missing required environment variable: {}", var) + } + ConfigError::InvalidDimension => write!(f, "Invalid dimension value"), + ConfigError::InvalidDataPath => write!(f, "Invalid data path"), + ConfigError::InvalidAddress(addr) => write!(f, "Invalid address: {}", addr), + ConfigError::IoError(err) => write!(f, "IO error: {}", err), + } + } +} + +impl std::error::Error for ConfigError {} + +impl ServerConfig { + pub fn load_config() -> Result { + dotenv().ok(); + + // HTTP server configuration + let http_host = env::var("HTTP_HOST") + .inspect_err(|_| { + event!( + Level::WARN, + "HTTP_HOST not defined, defaulting to '127.0.0.1'" + ); + }) + .unwrap_or_else(|_| "127.0.0.1".to_string()); + + let http_port = env::var("HTTP_PORT") + .inspect_err(|_| { + event!( + Level::WARN, + "HTTP_PORT not defined, defaulting to {}", + DEFAULT_HTTP_PORT + ); + }) + .unwrap_or_else(|_| DEFAULT_HTTP_PORT.to_string()); + + let http_addr: SocketAddr = format!("{}:{}", http_host, http_port) + .parse() + .map_err(|_| ConfigError::InvalidAddress(format!("{}:{}", http_host, http_port)))?; + + // gRPC server configuration + let grpc_host = env::var("GRPC_HOST") + .inspect_err(|_| { + event!( + Level::WARN, + "GRPC_HOST not defined, defaulting to '127.0.0.1'" + ); + }) + .unwrap_or_else(|_| "127.0.0.1".to_string()); + + let grpc_port = env::var("GRPC_PORT") + .inspect_err(|_| { + event!( + Level::WARN, + "GRPC_PORT not defined, defaulting to {}", + DEFAULT_GRPC_PORT + ); + }) + .unwrap_or_else(|_| DEFAULT_GRPC_PORT.to_string()); + + let grpc_addr: SocketAddr = format!("{}:{}", grpc_host, grpc_port) + .parse() + .map_err(|_| ConfigError::InvalidAddress(format!("{}:{}", grpc_host, grpc_port)))?; + + // gRPC root password (required) + let grpc_root_password = env::var("GRPC_ROOT_PASSWORD") + .map_err(|_| ConfigError::MissingRequiredEnvVar("GRPC_ROOT_PASSWORD".to_string()))?; + + // Storage type + let storage_type_str = env::var("STORAGE_TYPE") + .inspect_err(|_| { + event!( + Level::WARN, + "STORAGE_TYPE not defined, defaulting to InMemory" + ); + }) + .unwrap_or_default(); + + let storage_type = match storage_type_str.to_lowercase().as_str() { + "inmemory" => StorageType::InMemory, + "rocksdb" => StorageType::RocksDb, + _ => StorageType::InMemory, + }; + + // Index type + let index_type_str = env::var("INDEX_TYPE") + .inspect_err(|_| { + event!(Level::WARN, "INDEX_TYPE not defined, defaulting to flat"); + }) + .unwrap_or_else(|_| "flat".to_string()) + .to_lowercase(); + + let index_type = match index_type_str.as_str() { + "flat" => IndexType::Flat, + "kdtree" => IndexType::KDTree, + "hnsw" => IndexType::HNSW, + _ => IndexType::Flat, + }; + + // Dimension (required) + let dimension: usize = env::var("DIMENSION") + .map_err(|_| ConfigError::MissingRequiredEnvVar("DIMENSION".to_string()))? + .parse() + .map_err(|_| ConfigError::InvalidDimension)?; + + // Data path + let data_path: PathBuf = if let Ok(data_path_str) = env::var("DATA_PATH") { + let path = PathBuf::from(data_path_str); + fs::create_dir_all(&path).map_err(|_| ConfigError::InvalidDataPath)?; + path + } else { + let tempbuf = tempdir() + .map_err(|e| ConfigError::IoError(e.to_string()))? + .path() + .to_path_buf() + .join("vectordb"); + fs::create_dir_all(&tempbuf).map_err(|e| ConfigError::IoError(e.to_string()))?; + event!( + Level::WARN, + "DATA_PATH not specified, using temporary directory: {:?}", + tempbuf + ); + tempbuf + }; + + // Logging + let logging = env::var("LOGGING") + .unwrap_or_else(|_| "true".to_string()) + .parse() + .unwrap_or(true); + + // HTTP server disable flag (default to false, set to true to run only gRPC) + let disable_http = env::var("DISABLE_HTTP") + .unwrap_or_else(|_| "false".to_string()) + .parse() + .unwrap_or(false); + + let db_config = DbConfig { + storage_type, + index_type, + data_path, + dimension, + }; + + Ok(ServerConfig { + http_addr, + grpc_addr, + grpc_root_password, + db_config, + logging, + disable_http, + }) + } +} diff --git a/crates/server/src/main.rs b/crates/server/src/main.rs index 10f5b73..e10ba39 100644 --- a/crates/server/src/main.rs +++ b/crates/server/src/main.rs @@ -1,15 +1,113 @@ -use std::fmt::Error; -// Import from other crates -// use defs; // Import the entire crate -// use api::init_api_server; -// use index::some_module; // Import specific module -// use storage::{Type1, Type2}; // Import specific types -// use api::prelude::*; // Import everything from prelude - -fn main() -> Result<(), Error> { - // Start tracing - // Load configs for DB - // Start API and/or gRPC server - // let _ = init_api_server(); +mod config; + +use std::sync::Arc; + +use config::ServerConfig; +use defs::BoxError; +use grpc::run_grpc_server; +use http::run_http_server; +use tokio::signal; +use tracing::{error, info}; + +#[tokio::main] +async fn main() -> Result<(), BoxError> { + tracing_subscriber::fmt::init(); + + info!("Starting VortexDB unified server..."); + + let config = + ServerConfig::load_config().inspect_err(|err| error!("Failed to load config: {}", err))?; + + info!("Configuration loaded successfully"); + if !config.disable_http { + info!("HTTP server will listen on: {}", config.http_addr); + } + info!("gRPC server will listen on: {}", config.grpc_addr); + + let vector_db = api::init_api(config.db_config) + .inspect_err(|err| error!("Failed to init API: {:?}", err))?; + + let shared_db = Arc::new(vector_db); + info!("VectorDb initialized successfully"); + + // Spawn HTTP server task if enabled + let http_handle = if !config.disable_http { + let db = Arc::clone(&shared_db); + let addr = config.http_addr; + Some(tokio::spawn(async move { run_http_server(db, addr).await })) + } else { + info!("HTTP server is disabled"); + None + }; + + // Spawn gRPC server task + let grpc_handle = { + let db = Arc::clone(&shared_db); + let addr = config.grpc_addr; + let password = config.grpc_root_password; + let logging = config.logging; + tokio::spawn(async move { run_grpc_server(db, addr, password, logging).await }) + }; + + if let Some(http) = http_handle { + tokio::select! { + _ = shutdown_signal() => { + info!("Stopping servers"); + } + result = http => { + match result { + Err(e) => error!("HTTP server task error: {}", e), + Ok(Err(e)) => error!("HTTP server error: {}", e), + Ok(Ok(())) => {} + } + } + result = grpc_handle => { + match result { + Err(e) => error!("gRPC server task error: {}", e), + Ok(Err(e)) => error!("gRPC server error: {}", e), + Ok(Ok(())) => {} + } + } + } + } else { + tokio::select! { + _ = shutdown_signal() => { + info!("Stopping servers"); + } + result = grpc_handle => { + match result { + Err(e) => error!("gRPC server task error: {}", e), + Ok(Err(e)) => error!("gRPC server error: {}", e), + Ok(Ok(())) => {} + } + } + } + } + + info!("Shutdown complete"); Ok(()) } + +async fn shutdown_signal() { + let ctrl_c = async { + signal::ctrl_c() + .await + .expect("Failed to install Ctrl+C handler"); + }; + + #[cfg(unix)] + let terminate = async { + signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("Failed to install SIGTERM handler") + .recv() + .await; + }; + + #[cfg(not(unix))] + let terminate = std::future::pending::<()>(); + + tokio::select! { + _ = ctrl_c => {} + _ = terminate => {} + } +} diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index 47ed2c7..f7c067e 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -23,6 +23,7 @@ pub trait StorageEngine: Send + Sync { pub mod in_memory; pub mod rocks_db; +#[derive(Debug, Clone, Copy)] pub enum StorageType { InMemory, RocksDb,