Skip to content

Commit 302aaae

Browse files
committed
validator/tests: add vector_similarity() function test
Add tests to validate the results of vector_similarity() function. Refs: scylladb/scylladb#25993
1 parent a473e5e commit 302aaae

File tree

2 files changed

+214
-0
lines changed

2 files changed

+214
-0
lines changed

crates/validator/src/tests/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ mod crud;
88
mod full_scan;
99
mod reconnect;
1010
mod serde;
11+
mod vector_similarity;
1112

1213
use crate::ServicesSubnet;
1314
use crate::dns::Dns;
@@ -221,6 +222,7 @@ pub(crate) async fn register() -> Vec<(String, TestCase)> {
221222
("full_scan", full_scan::new().await),
222223
("reconnect", reconnect::new().await),
223224
("serde", serde::new().await),
225+
("vector_similarity", vector_similarity::new().await),
224226
]
225227
.into_iter()
226228
.map(|(name, test_case)| (name.to_string(), test_case))
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
/*
2+
* Copyright 2025-present ScyllaDB
3+
* SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
4+
*/
5+
6+
use crate::common::*;
7+
use crate::tests::*;
8+
use scylla::client::session::Session;
9+
use std::time::Duration;
10+
use tracing::info;
11+
12+
pub(crate) async fn new() -> TestCase {
13+
let timeout = Duration::from_secs(30);
14+
TestCase::empty()
15+
.with_init(timeout, init)
16+
.with_cleanup(timeout, cleanup)
17+
.with_test(
18+
"vector_similarity_function_returns_expected_results",
19+
timeout,
20+
vector_similarity_function_returns_expected_results,
21+
)
22+
.with_test(
23+
"vector_similarity_function_with_clustering_key",
24+
timeout,
25+
vector_similarity_function_with_clustering_key,
26+
)
27+
.with_test(
28+
"vector_similarity_function_with_multi_column_partition_key",
29+
timeout,
30+
vector_similarity_function_with_multi_column_partition_key,
31+
)
32+
}
33+
34+
pub(crate) static EMBEDDINGS: [[f32; 3]; 3] = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
35+
36+
async fn check_similarity_function_results(session: &Session, table: &str, key_column: &str) {
37+
let results = get_query_results(
38+
format!(
39+
"SELECT {key_column}, vector_similarity() FROM {table} ORDER BY v ANN OF [0.0, 0.0, 0.0] LIMIT 5"
40+
),
41+
&session,
42+
)
43+
.await;
44+
let rows = results.rows::<(i32, f32)>().expect("failed to get rows");
45+
assert_eq!(rows.rows_remaining(), 3);
46+
47+
// Expected results are calculated using Euclidean distance formula
48+
let expected_distances = [(0, 14.0), (1, 77.0), (2, 194.0)];
49+
for (i, row) in rows.enumerate() {
50+
let row = row.expect("failed to get row");
51+
let (key, distance) = row;
52+
assert_eq!(
53+
(key, distance),
54+
expected_distances[i],
55+
"Row {i} does not match expected result"
56+
);
57+
}
58+
}
59+
60+
async fn vector_similarity_function_returns_expected_results(actors: TestActors) {
61+
info!("started");
62+
63+
let (session, client) = prepare_connection(&actors).await;
64+
65+
let keyspace = create_keyspace(&session).await;
66+
let table = create_table(&session, "pk INT PRIMARY KEY, v VECTOR<FLOAT, 3>", None).await;
67+
68+
// Insert test data
69+
for (i, embedding) in EMBEDDINGS.into_iter().enumerate() {
70+
session
71+
.query_unpaged(
72+
format!("INSERT INTO {table} (pk, v) VALUES (?, ?)"),
73+
(i as i32, embedding.as_slice()),
74+
)
75+
.await
76+
.expect("failed to insert data");
77+
}
78+
79+
// Create index with EUCLIDEAN similarity function
80+
let index = create_index(
81+
&session,
82+
&client,
83+
&table,
84+
"v",
85+
Some("{'similarity_function' : 'EUCLIDEAN'}"),
86+
)
87+
.await;
88+
89+
wait_for(
90+
|| async { client.count(&index.keyspace, &index.index).await == Some(3) },
91+
"Waiting for 3 vectors to be indexed",
92+
Duration::from_secs(5),
93+
)
94+
.await;
95+
96+
// Check if the query returns the expected distances
97+
check_similarity_function_results(&session, &table, "pk").await;
98+
99+
// Drop keyspace
100+
session
101+
.query_unpaged(format!("DROP KEYSPACE {keyspace}"), ())
102+
.await
103+
.expect("failed to drop a keyspace");
104+
105+
info!("finished");
106+
}
107+
108+
async fn vector_similarity_function_with_clustering_key(actors: TestActors) {
109+
info!("started");
110+
111+
let (session, client) = prepare_connection(&actors).await;
112+
113+
let keyspace = create_keyspace(&session).await;
114+
let table = create_table(
115+
&session,
116+
"pk INT, ck INT, v VECTOR<FLOAT, 3>, PRIMARY KEY (pk, ck)",
117+
None,
118+
)
119+
.await;
120+
121+
// Insert test data
122+
for (i, embedding) in EMBEDDINGS.into_iter().enumerate() {
123+
session
124+
.query_unpaged(
125+
format!("INSERT INTO {table} (pk, ck, v) VALUES (?, ?, ?)"),
126+
(123, i as i32, &embedding.as_slice()),
127+
)
128+
.await
129+
.expect("failed to insert data");
130+
}
131+
132+
// Create index with EUCLIDEAN similarity function
133+
let index = create_index(
134+
&session,
135+
&client,
136+
&table,
137+
"v",
138+
Some("{'similarity_function' : 'EUCLIDEAN'}"),
139+
)
140+
.await;
141+
142+
wait_for(
143+
|| async { client.count(&index.keyspace, &index.index).await == Some(3) },
144+
"Waiting for 3 vectors to be indexed",
145+
Duration::from_secs(5),
146+
)
147+
.await;
148+
149+
// Check if the query returns the expected distances
150+
check_similarity_function_results(&session, &table, "ck").await;
151+
152+
// Drop keyspace
153+
session
154+
.query_unpaged(format!("DROP KEYSPACE {keyspace}"), ())
155+
.await
156+
.expect("failed to drop a keyspace");
157+
158+
info!("finished");
159+
}
160+
161+
async fn vector_similarity_function_with_multi_column_partition_key(actors: TestActors) {
162+
info!("started");
163+
164+
let (session, client) = prepare_connection(&actors).await;
165+
166+
let keyspace = create_keyspace(&session).await;
167+
let table = create_table(
168+
&session,
169+
"pk1 INT, pk2 INT, v VECTOR<FLOAT, 3>, PRIMARY KEY ((pk1, pk2))",
170+
None,
171+
)
172+
.await;
173+
174+
// Insert test data
175+
for (i, embedding) in EMBEDDINGS.into_iter().enumerate() {
176+
session
177+
.query_unpaged(
178+
format!("INSERT INTO {table} (pk1, pk2, v) VALUES (?, ?, ?)"),
179+
(123, i as i32, &embedding.as_slice()),
180+
)
181+
.await
182+
.expect("failed to insert data");
183+
}
184+
185+
// Create index with EUCLIDEAN similarity function
186+
let index = create_index(
187+
&session,
188+
&client,
189+
&table,
190+
"v",
191+
Some("{'similarity_function' : 'EUCLIDEAN'}"),
192+
)
193+
.await;
194+
195+
wait_for(
196+
|| async { client.count(&index.keyspace, &index.index).await == Some(3) },
197+
"Waiting for 3 vectors to be indexed",
198+
Duration::from_secs(5),
199+
)
200+
.await;
201+
202+
// Check if the query returns the expected distances
203+
check_similarity_function_results(&session, &table, "pk2").await;
204+
205+
// Drop keyspace
206+
session
207+
.query_unpaged(format!("DROP KEYSPACE {keyspace}"), ())
208+
.await
209+
.expect("failed to drop a keyspace");
210+
211+
info!("finished");
212+
}

0 commit comments

Comments
 (0)