Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ anyhow = "1.0.89"
serde = { version = "1.0.210", features = ["derive"], optional = true }
serde_json = { version = "1.0.128", optional = true }
reqwest = { version = "0.12.8", optional = true, default-features = false, features = ["stream", "rustls-tls", "http2"] }
futures = { version = "0.3.31" }
tokio = { version = "1.40.0", features = ["rt-multi-thread"] }
futures-util = { version = "0.3.31", optional = true }
derive_builder = { version = "0.20.2" }
thiserror = "1.0.64"
semver = "1.0.24"

[dev-dependencies]
tonic-build = { version = "0.12.3", features = ["prost"] }
tokio = { version = "1.40.0", features = ["rt-multi-thread"] }

[features]
default = ["download_snapshots", "serde", "generate-snippets"]
Expand Down
10 changes: 10 additions & 0 deletions src/qdrant_client/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use crate::{Qdrant, QdrantError};
/// .compression(Some(CompressionEncoding::Gzip))
/// .build();
/// ```
#[derive(Clone)]
pub struct QdrantConfig {
/// Qdrant server URI to connect to
pub uri: String,
Expand All @@ -34,6 +35,9 @@ pub struct QdrantConfig {

/// Optional compression schema to use for API requests
pub compression: Option<CompressionEncoding>,

/// Whether to check compatibility between the client and server versions
pub check_compatibility: bool,
}

impl QdrantConfig {
Expand Down Expand Up @@ -169,6 +173,11 @@ impl QdrantConfig {
pub fn build(self) -> Result<Qdrant, QdrantError> {
Qdrant::new(self)
}

pub fn skip_compatibility_check(mut self) -> Self {
self.check_compatibility = false;
self
}
}

/// Default Qdrant client configuration.
Expand All @@ -183,6 +192,7 @@ impl Default for QdrantConfig {
keep_alive_while_idle: true,
api_key: None,
compression: None,
check_compatibility: true,
}
}
}
Expand Down
49 changes: 49 additions & 0 deletions src/qdrant_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ mod query;
mod search;
mod sharding_keys;
mod snapshot;
mod version_check;

use std::future::Future;
use std::thread;

use tonic::codegen::InterceptedService;
use tonic::transport::{Channel, Uri};
Expand All @@ -21,6 +23,7 @@ use crate::auth::TokenInterceptor;
use crate::channel_pool::ChannelPool;
use crate::qdrant::{qdrant_client, HealthCheckReply, HealthCheckRequest};
use crate::qdrant_client::config::QdrantConfig;
use crate::qdrant_client::version_check::is_compatible;
use crate::QdrantError;

/// [`Qdrant`] client result
Expand Down Expand Up @@ -95,6 +98,52 @@ impl Qdrant {
///
/// Constructs the client and connects based on the given [`QdrantConfig`](config::QdrantConfig).
pub fn new(config: QdrantConfig) -> QdrantResult<Self> {
if config.check_compatibility {
// create a temporary client to check compatibility
let channel = ChannelPool::new(
config.uri.parse::<Uri>()?,
config.timeout,
config.connect_timeout,
config.keep_alive_while_idle,
);
let client = Self {
channel,
config: config.clone(),
};

// We're in sync context, spawn temporary runtime in thread to do async health check
let server_version = thread::scope(|s| {
s.spawn(|| {
tokio::runtime::Builder::new_current_thread()
.enable_io()
.enable_time()
.build()
.map_err(QdrantError::Io)?
.block_on(client.health_check())
})
.join()
.expect("Failed to join health check thread")
})
.ok()
.map(|info| info.version);

let client_version = env!("CARGO_PKG_VERSION").to_string();
if let Some(server_version) = server_version {
let is_compatible = is_compatible(Some(&client_version), Some(&server_version));
if !is_compatible {
println!("Client version {client_version} is not compatible with server version {server_version}. \
Major versions should match and minor version difference must not exceed 1. \
Set check_compatibility=false to skip version check.");
}
} else {
println!(
"Failed to obtain server version. \
Unable to check client-server compatibility. \
Set check_compatibility=false to skip version check."
);
}
}

let channel = ChannelPool::new(
config.uri.parse::<Uri>()?,
config.timeout,
Expand Down
128 changes: 128 additions & 0 deletions src/qdrant_client/version_check.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
use std::error::Error;
use std::fmt;

use semver::Version;

pub fn parse(version: &str) -> Result<Version, VersionParseError> {
if version.is_empty() {
return Err(VersionParseError::EmptyVersion);
}
match Version::parse(version) {
Ok(v) => Ok(v),
Err(_) => Err(VersionParseError::InvalidFormat(version.to_string())),
}
}

#[derive(Debug)]
pub enum VersionParseError {
EmptyVersion,
InvalidFormat(String),
}

impl fmt::Display for VersionParseError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
VersionParseError::EmptyVersion => write!(f, "Version is empty"),
VersionParseError::InvalidFormat(version) => {
write!(
f,
"Unable to parse version, expected format: x.y[.z], found: {}",
version
)
}
}
}
}

impl Error for VersionParseError {}

pub fn is_compatible(client_version: Option<&str>, server_version: Option<&str>) -> bool {
if client_version.is_none() || server_version.is_none() {
println!(
"Unable to compare versions, client_version: {:?}, server_version: {:?}",
client_version, server_version
);
return false;
}

let client_version = client_version.unwrap();
let server_version = server_version.unwrap();

if client_version == server_version {
return true;
}

match (parse(client_version), parse(server_version)) {
(Ok(client), Ok(server)) => {
let major_dif = (client.major as i32 - server.major as i32).abs();
if major_dif >= 1 {
return false;
}
(client.minor as i32 - server.minor as i32).abs() <= 1
}
(Err(e), _) | (_, Err(e)) => {
println!("Unable to compare versions: {}", e);
false
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_is_compatible() {
let test_cases = vec![
(Some("1.9.3.dev0"), Some("2.8.1-dev12"), false),
(Some("1.9"), Some("2.8"), false),
(Some("1"), Some("2"), false),
(Some("1.9.0"), Some("2.9.0"), false),
(Some("1.1.0"), Some("1.2.9"), true),
(Some("1.2.7"), Some("1.1.8-dev0"), true),
(Some("1.2.1"), Some("1.2.29"), true),
(Some("1.2.0"), Some("1.2.0"), true),
(Some("1.2.0"), Some("1.4.0"), false),
(Some("1.4.0"), Some("1.2.0"), false),
(Some("1.9.0"), Some("3.7.0"), false),
(Some("3.0.0"), Some("1.0.0"), false),
(None, Some("1.0.0"), false),
(Some("1.0.0"), None, false),
(None, None, false),
];

for (client_version, server_version, expected_result) in test_cases {
let result = is_compatible(client_version, server_version);
assert_eq!(
result, expected_result,
"Failed for client: {:?}, server: {:?}",
client_version, server_version
);
}
}

#[test]
fn test_version_parse_errors() {
let test_cases = vec![
("1", VersionParseError::InvalidFormat("1".to_string())),
("1.", VersionParseError::InvalidFormat("1.".to_string())),
(".1", VersionParseError::InvalidFormat(".1".to_string())),
(".1.", VersionParseError::InvalidFormat(".1.".to_string())),
(
"1.a.1",
VersionParseError::InvalidFormat("1.a.1".to_string()),
),
(
"a.1.1",
VersionParseError::InvalidFormat("a.1.1".to_string()),
),
("", VersionParseError::EmptyVersion),
];

for (input, expected_error) in test_cases {
let result = parse(input);
assert!(result.is_err());
assert_eq!(result.unwrap_err().to_string(), expected_error.to_string());
}
}
}
Loading