diff --git a/Cargo.toml b/Cargo.toml index 0cde4c8d..b7fa8c06 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,13 +20,15 @@ anyhow = "1.0.89" serde = { version = "1.0.210", features = ["derive"], optional = true } serde_json = { version = "1.0.128", optional = true } reqwest = { version = "0.12.8", optional = true, default-features = false, features = ["stream", "rustls-tls", "http2"] } +futures = { version = "0.3.31" } +tokio = { version = "1.40.0", features = ["rt-multi-thread"] } futures-util = { version = "0.3.31", optional = true } derive_builder = { version = "0.20.2" } thiserror = "1.0.64" +semver = "1.0.24" [dev-dependencies] tonic-build = { version = "0.12.3", features = ["prost"] } -tokio = { version = "1.40.0", features = ["rt-multi-thread"] } [features] default = ["download_snapshots", "serde", "generate-snippets"] diff --git a/src/qdrant_client/config.rs b/src/qdrant_client/config.rs index d1d214a4..9ed54810 100644 --- a/src/qdrant_client/config.rs +++ b/src/qdrant_client/config.rs @@ -16,6 +16,7 @@ use crate::{Qdrant, QdrantError}; /// .compression(Some(CompressionEncoding::Gzip)) /// .build(); /// ``` +#[derive(Clone)] pub struct QdrantConfig { /// Qdrant server URI to connect to pub uri: String, @@ -34,6 +35,9 @@ pub struct QdrantConfig { /// Optional compression schema to use for API requests pub compression: Option, + + /// Whether to check compatibility between the client and server versions + pub check_compatibility: bool, } impl QdrantConfig { @@ -169,6 +173,11 @@ impl QdrantConfig { pub fn build(self) -> Result { Qdrant::new(self) } + + pub fn skip_compatibility_check(mut self) -> Self { + self.check_compatibility = false; + self + } } /// Default Qdrant client configuration. @@ -183,6 +192,7 @@ impl Default for QdrantConfig { keep_alive_while_idle: true, api_key: None, compression: None, + check_compatibility: true, } } } diff --git a/src/qdrant_client/mod.rs b/src/qdrant_client/mod.rs index adfe9924..550fb005 100644 --- a/src/qdrant_client/mod.rs +++ b/src/qdrant_client/mod.rs @@ -10,8 +10,10 @@ mod query; mod search; mod sharding_keys; mod snapshot; +mod version_check; use std::future::Future; +use std::thread; use tonic::codegen::InterceptedService; use tonic::transport::{Channel, Uri}; @@ -21,6 +23,7 @@ use crate::auth::TokenInterceptor; use crate::channel_pool::ChannelPool; use crate::qdrant::{qdrant_client, HealthCheckReply, HealthCheckRequest}; use crate::qdrant_client::config::QdrantConfig; +use crate::qdrant_client::version_check::is_compatible; use crate::QdrantError; /// [`Qdrant`] client result @@ -95,6 +98,52 @@ impl Qdrant { /// /// Constructs the client and connects based on the given [`QdrantConfig`](config::QdrantConfig). pub fn new(config: QdrantConfig) -> QdrantResult { + if config.check_compatibility { + // create a temporary client to check compatibility + let channel = ChannelPool::new( + config.uri.parse::()?, + config.timeout, + config.connect_timeout, + config.keep_alive_while_idle, + ); + let client = Self { + channel, + config: config.clone(), + }; + + // We're in sync context, spawn temporary runtime in thread to do async health check + let server_version = thread::scope(|s| { + s.spawn(|| { + tokio::runtime::Builder::new_current_thread() + .enable_io() + .enable_time() + .build() + .map_err(QdrantError::Io)? + .block_on(client.health_check()) + }) + .join() + .expect("Failed to join health check thread") + }) + .ok() + .map(|info| info.version); + + let client_version = env!("CARGO_PKG_VERSION").to_string(); + if let Some(server_version) = server_version { + let is_compatible = is_compatible(Some(&client_version), Some(&server_version)); + if !is_compatible { + println!("Client version {client_version} is not compatible with server version {server_version}. \ + Major versions should match and minor version difference must not exceed 1. \ + Set check_compatibility=false to skip version check."); + } + } else { + println!( + "Failed to obtain server version. \ + Unable to check client-server compatibility. \ + Set check_compatibility=false to skip version check." + ); + } + } + let channel = ChannelPool::new( config.uri.parse::()?, config.timeout, diff --git a/src/qdrant_client/version_check.rs b/src/qdrant_client/version_check.rs new file mode 100644 index 00000000..8c7b8861 --- /dev/null +++ b/src/qdrant_client/version_check.rs @@ -0,0 +1,128 @@ +use std::error::Error; +use std::fmt; + +use semver::Version; + +pub fn parse(version: &str) -> Result { + if version.is_empty() { + return Err(VersionParseError::EmptyVersion); + } + match Version::parse(version) { + Ok(v) => Ok(v), + Err(_) => Err(VersionParseError::InvalidFormat(version.to_string())), + } +} + +#[derive(Debug)] +pub enum VersionParseError { + EmptyVersion, + InvalidFormat(String), +} + +impl fmt::Display for VersionParseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + VersionParseError::EmptyVersion => write!(f, "Version is empty"), + VersionParseError::InvalidFormat(version) => { + write!( + f, + "Unable to parse version, expected format: x.y[.z], found: {}", + version + ) + } + } + } +} + +impl Error for VersionParseError {} + +pub fn is_compatible(client_version: Option<&str>, server_version: Option<&str>) -> bool { + if client_version.is_none() || server_version.is_none() { + println!( + "Unable to compare versions, client_version: {:?}, server_version: {:?}", + client_version, server_version + ); + return false; + } + + let client_version = client_version.unwrap(); + let server_version = server_version.unwrap(); + + if client_version == server_version { + return true; + } + + match (parse(client_version), parse(server_version)) { + (Ok(client), Ok(server)) => { + let major_dif = (client.major as i32 - server.major as i32).abs(); + if major_dif >= 1 { + return false; + } + (client.minor as i32 - server.minor as i32).abs() <= 1 + } + (Err(e), _) | (_, Err(e)) => { + println!("Unable to compare versions: {}", e); + false + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_is_compatible() { + let test_cases = vec![ + (Some("1.9.3.dev0"), Some("2.8.1-dev12"), false), + (Some("1.9"), Some("2.8"), false), + (Some("1"), Some("2"), false), + (Some("1.9.0"), Some("2.9.0"), false), + (Some("1.1.0"), Some("1.2.9"), true), + (Some("1.2.7"), Some("1.1.8-dev0"), true), + (Some("1.2.1"), Some("1.2.29"), true), + (Some("1.2.0"), Some("1.2.0"), true), + (Some("1.2.0"), Some("1.4.0"), false), + (Some("1.4.0"), Some("1.2.0"), false), + (Some("1.9.0"), Some("3.7.0"), false), + (Some("3.0.0"), Some("1.0.0"), false), + (None, Some("1.0.0"), false), + (Some("1.0.0"), None, false), + (None, None, false), + ]; + + for (client_version, server_version, expected_result) in test_cases { + let result = is_compatible(client_version, server_version); + assert_eq!( + result, expected_result, + "Failed for client: {:?}, server: {:?}", + client_version, server_version + ); + } + } + + #[test] + fn test_version_parse_errors() { + let test_cases = vec![ + ("1", VersionParseError::InvalidFormat("1".to_string())), + ("1.", VersionParseError::InvalidFormat("1.".to_string())), + (".1", VersionParseError::InvalidFormat(".1".to_string())), + (".1.", VersionParseError::InvalidFormat(".1.".to_string())), + ( + "1.a.1", + VersionParseError::InvalidFormat("1.a.1".to_string()), + ), + ( + "a.1.1", + VersionParseError::InvalidFormat("a.1.1".to_string()), + ), + ("", VersionParseError::EmptyVersion), + ]; + + for (input, expected_error) in test_cases { + let result = parse(input); + assert!(result.is_err()); + assert_eq!(result.unwrap_err().to_string(), expected_error.to_string()); + } + } +}