Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ anyhow = "1.0.89"
serde = { version = "1.0.210", features = ["derive"], optional = true }
serde_json = { version = "1.0.128", optional = true }
reqwest = { version = "0.12.8", optional = true, default-features = false, features = ["stream", "rustls-tls", "http2"] }
futures = { version = "0.3.31" }
tokio = { version = "1.40.0", features = ["rt-multi-thread"] }
futures-util = { version = "0.3.31", optional = true }
derive_builder = { version = "0.20.2" }
thiserror = "1.0.64"
semver = "1.0.24"

[dev-dependencies]
tonic-build = { version = "0.12.3", features = ["prost"] }
tokio = { version = "1.40.0", features = ["rt-multi-thread"] }

[features]
default = ["download_snapshots", "serde", "generate-snippets"]
Expand Down
10 changes: 10 additions & 0 deletions src/qdrant_client/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use crate::{Qdrant, QdrantError};
/// .compression(Some(CompressionEncoding::Gzip))
/// .build();
/// ```
#[derive(Clone)]
pub struct QdrantConfig {
/// Qdrant server URI to connect to
pub uri: String,
Expand All @@ -34,6 +35,9 @@ pub struct QdrantConfig {

/// Optional compression schema to use for API requests
pub compression: Option<CompressionEncoding>,

/// Whether to check compatibility between the client and server versions
pub check_compatibility: bool,
}

impl QdrantConfig {
Expand Down Expand Up @@ -169,6 +173,11 @@ impl QdrantConfig {
pub fn build(self) -> Result<Qdrant, QdrantError> {
Qdrant::new(self)
}

pub fn skip_compatibility_check(mut self) -> Self {
self.check_compatibility = false;
self
}
}

/// Default Qdrant client configuration.
Expand All @@ -183,6 +192,7 @@ impl Default for QdrantConfig {
keep_alive_while_idle: true,
api_key: None,
compression: None,
check_compatibility: true,
}
}
}
Expand Down
49 changes: 49 additions & 0 deletions src/qdrant_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ mod query;
mod search;
mod sharding_keys;
mod snapshot;
mod version_check;

use std::future::Future;
use std::thread;

use tonic::codegen::InterceptedService;
use tonic::transport::{Channel, Uri};
Expand All @@ -21,6 +23,7 @@ use crate::auth::TokenInterceptor;
use crate::channel_pool::ChannelPool;
use crate::qdrant::{qdrant_client, HealthCheckReply, HealthCheckRequest};
use crate::qdrant_client::config::QdrantConfig;
use crate::qdrant_client::version_check::is_compatible;
use crate::QdrantError;

/// [`Qdrant`] client result
Expand Down Expand Up @@ -95,6 +98,52 @@ impl Qdrant {
///
/// Constructs the client and connects based on the given [`QdrantConfig`](config::QdrantConfig).
pub fn new(config: QdrantConfig) -> QdrantResult<Self> {
if config.check_compatibility {
// create a temporary client to check compatibility
let channel = ChannelPool::new(
config.uri.parse::<Uri>()?,
config.timeout,
config.connect_timeout,
config.keep_alive_while_idle,
);
let client = Self {
channel,
config: config.clone(),
};

// We're in sync context, spawn temporary runtime in thread to do async health check
let server_version = thread::scope(|s| {
s.spawn(|| {
tokio::runtime::Builder::new_current_thread()
.enable_io()
.enable_time()
.build()
.map_err(QdrantError::Io)?
.block_on(client.health_check())
})
.join()
.expect("Failed to join health check thread")
})
.ok()
.map(|info| info.version);

let client_version = env!("CARGO_PKG_VERSION").to_string();
if let Some(server_version) = server_version {
let is_compatible = is_compatible(Some(&client_version), Some(&server_version));
if !is_compatible {
println!("Client version {client_version} is not compatible with server version {server_version}. \
Major versions should match and minor version difference must not exceed 1. \
Set check_compatibility=false to skip version check.");
}
} else {
println!(
"Failed to obtain server version. \
Unable to check client-server compatibility. \
Set check_compatibility=false to skip version check."
);
}
}

let channel = ChannelPool::new(
config.uri.parse::<Uri>()?,
config.timeout,
Expand Down
128 changes: 128 additions & 0 deletions src/qdrant_client/version_check.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
use std::error::Error;
use std::fmt;

use semver::Version;

pub fn parse(version: &str) -> Result<Version, VersionParseError> {
if version.is_empty() {
return Err(VersionParseError::EmptyVersion);
}
match Version::parse(version) {
Ok(v) => Ok(v),
Err(_) => Err(VersionParseError::InvalidFormat(version.to_string())),
}
}

#[derive(Debug)]
pub enum VersionParseError {
EmptyVersion,
InvalidFormat(String),
}

impl fmt::Display for VersionParseError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
VersionParseError::EmptyVersion => write!(f, "Version is empty"),
VersionParseError::InvalidFormat(version) => {
write!(
f,
"Unable to parse version, expected format: x.y[.z], found: {}",
version
)
}
}
}
}

impl Error for VersionParseError {}

pub fn is_compatible(client_version: Option<&str>, server_version: Option<&str>) -> bool {
if client_version.is_none() || server_version.is_none() {
println!(
"Unable to compare versions, client_version: {:?}, server_version: {:?}",
client_version, server_version
);
return false;
}

let client_version = client_version.unwrap();
let server_version = server_version.unwrap();

if client_version == server_version {
return true;
}

match (parse(client_version), parse(server_version)) {
(Ok(client), Ok(server)) => {
let major_dif = (client.major as i32 - server.major as i32).abs();
if major_dif >= 1 {
return false;
}
(client.minor as i32 - server.minor as i32).abs() <= 1
}
(Err(e), _) | (_, Err(e)) => {
println!("Unable to compare versions: {}", e);
false
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_is_compatible() {
let test_cases = vec![
(Some("1.9.3.dev0"), Some("2.8.1-dev12"), false),
(Some("1.9"), Some("2.8"), false),
(Some("1"), Some("2"), false),
(Some("1.9.0"), Some("2.9.0"), false),
(Some("1.1.0"), Some("1.2.9"), true),
(Some("1.2.7"), Some("1.1.8-dev0"), true),
(Some("1.2.1"), Some("1.2.29"), true),
(Some("1.2.0"), Some("1.2.0"), true),
(Some("1.2.0"), Some("1.4.0"), false),
(Some("1.4.0"), Some("1.2.0"), false),
(Some("1.9.0"), Some("3.7.0"), false),
(Some("3.0.0"), Some("1.0.0"), false),
(None, Some("1.0.0"), false),
(Some("1.0.0"), None, false),
(None, None, false),
];

for (client_version, server_version, expected_result) in test_cases {
let result = is_compatible(client_version, server_version);
assert_eq!(
result, expected_result,
"Failed for client: {:?}, server: {:?}",
client_version, server_version
);
}
}

#[test]
fn test_version_parse_errors() {
let test_cases = vec![
("1", VersionParseError::InvalidFormat("1".to_string())),
("1.", VersionParseError::InvalidFormat("1.".to_string())),
(".1", VersionParseError::InvalidFormat(".1".to_string())),
(".1.", VersionParseError::InvalidFormat(".1.".to_string())),
(
"1.a.1",
VersionParseError::InvalidFormat("1.a.1".to_string()),
),
(
"a.1.1",
VersionParseError::InvalidFormat("a.1.1".to_string()),
),
("", VersionParseError::EmptyVersion),
];

for (input, expected_error) in test_cases {
let result = parse(input);
assert!(result.is_err());
assert_eq!(result.unwrap_err().to_string(), expected_error.to_string());
}
}
}
Loading