Skip to content

Commit ce25cac

Browse files
authored
Merge pull request #1603 from tursodatabase/vector-search-more-tests
add more tests for vector feature
2 parents b233a11 + da55b5f commit ce25cac

File tree

4 files changed

+124
-0
lines changed

4 files changed

+124
-0
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

libsql-sqlite3/test/libsql_vector_index.test

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,26 @@ do_execsql_test vector-vacuum {
230230
SELECT COUNT(*) FROM t_vacuum_idx_shadow;
231231
} {2 2}
232232

233+
# Creates one table with two FLOAT32(2) vector columns, each with its own
# libsql_vector_idx index, inserts two rows whose e1/e2 vectors are mirrored,
# and queries both indexes for the 2 nearest neighbours of [1,1].
# Expected output: ids "1 2" from the e1 index followed by "2 1" from the e2 index.
do_execsql_test vector-many-columns {
234+
CREATE TABLE t_many ( i INTEGER PRIMARY KEY, e1 FLOAT32(2), e2 FLOAT32(2) );
235+
CREATE INDEX t_many_1_idx ON t_many(libsql_vector_idx(e1));
236+
CREATE INDEX t_many_2_idx ON t_many(libsql_vector_idx(e2));
237+
INSERT INTO t_many VALUES (1, vector('[1,1]'), vector('[-1,-1]')), (2, vector('[-1,-1]'), vector('[1,1]'));
238+
SELECT * FROM vector_top_k('t_many_1_idx', vector('[1,1]'), 2);
239+
SELECT * FROM vector_top_k('t_many_2_idx', vector('[1,1]'), 2);
240+
} {1 2 2 1}
241+
242+
# Inserts rows 1 and 2 outside a transaction, then inserts rows 3 and 4 inside
# a transaction and queries the vector index before rolling back.
# Expected output: "3 4" from the query issued inside the transaction (the
# uncommitted rows are returned), then "1 2" from the query issued after
# ROLLBACK (only the rows inserted before BEGIN are returned).
do_execsql_test vector-transaction {
243+
CREATE TABLE t_transaction ( i INTEGER PRIMARY KEY, e FLOAT32(2) );
244+
CREATE INDEX t_transaction_idx ON t_transaction(libsql_vector_idx(e));
245+
INSERT INTO t_transaction VALUES (1, vector('[1,2]')), (2, vector('[3,4]'));
246+
BEGIN;
247+
INSERT INTO t_transaction VALUES (3, vector('[4,5]')), (4, vector('[5,6]'));
248+
SELECT * FROM vector_top_k('t_transaction_idx', vector('[4,5]'), 2);
249+
ROLLBACK;
250+
SELECT * FROM vector_top_k('t_transaction_idx', vector('[1,2]'), 2);
251+
} {3 4 1 2}
252+
233253
proc error_messages {sql} {
234254
set ret ""
235255
catch {

libsql/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ tokio = { version = "1.29.1", features = ["full"] }
5050
tokio-test = "0.4"
5151
tracing-subscriber = "0.3"
5252
tempfile = { version = "3.7.0" }
53+
rand = "0.8.5"
5354

5455
[features]
5556
default = ["core", "replication", "remote"]

libsql/tests/integration_tests.rs

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ use libsql::{
66
params::{IntoParams, IntoValue},
77
Connection, Database, Value,
88
};
9+
use rand::distributions::Uniform;
10+
use rand::prelude::*;
11+
use std::collections::HashSet;
912

1013
async fn setup() -> Connection {
1114
let db = Database::open(":memory:").unwrap();
@@ -650,3 +653,102 @@ async fn deserialize_row() {
650653
assert_eq!(data.status, Status::Draft);
651654
assert_eq!(data.wrapper, Wrapper(Status::Published));
652655
}
656+
657+
#[tokio::test]
658+
#[ignore]
659+
// fuzz test can be run explicitly with following command:
660+
// cargo test vector_fuzz_test -- --nocapture --include-ignored
661+
async fn vector_fuzz_test() {
662+
let mut global_rng = rand::thread_rng();
663+
for attempt in 0..10000 {
664+
let seed = global_rng.next_u64();
665+
666+
let mut rng =
667+
rand::rngs::StdRng::from_seed(unsafe { std::mem::transmute([seed, seed, seed, seed]) });
668+
let db = Database::open(":memory:").unwrap();
669+
let conn = db.connect().unwrap();
670+
let dim = rng.gen_range(1..=1536);
671+
let operations = rng.gen_range(1..128);
672+
println!(
673+
"============== ATTEMPT {} (seed {}u64, dim {}, operations {}) ================",
674+
attempt, seed, dim, operations
675+
);
676+
677+
let _ = conn
678+
.execute(
679+
&format!(
680+
"CREATE TABLE users (id INTEGER PRIMARY KEY, v FLOAT32({}) )",
681+
dim
682+
),
683+
(),
684+
)
685+
.await;
686+
// println!("CREATE TABLE users (id INTEGER PRIMARY KEY, v FLOAT32({}) );", dim);
687+
let _ = conn
688+
.execute(
689+
"CREATE INDEX users_idx ON users ( libsql_vector_idx(v) );",
690+
(),
691+
)
692+
.await;
693+
// println!("CREATE INDEX users_idx ON users ( libsql_vector_idx(v) );");
694+
695+
let mut next_id = 1;
696+
let mut alive = HashSet::new();
697+
let uniform = Uniform::new(-1.0, 1.0);
698+
for _ in 0..operations {
699+
let operation = rng.gen_range(0..4);
700+
let vector: Vec<f32> = (0..dim).map(|_| rng.sample(uniform)).collect();
701+
let vector_str = format!(
702+
"[{}]",
703+
vector
704+
.iter()
705+
.map(|x| format!("{}", x))
706+
.collect::<Vec<String>>()
707+
.join(",")
708+
);
709+
if operation == 0 {
710+
// println!("INSERT INTO users VALUES ({}, vector('{}') );", next_id, vector_str);
711+
conn.execute(
712+
"INSERT INTO users VALUES (?, vector(?) )",
713+
libsql::params![next_id, vector_str],
714+
)
715+
.await
716+
.unwrap();
717+
alive.insert(next_id);
718+
next_id += 1;
719+
} else if operation == 1 {
720+
let id = rng.gen_range(0..next_id);
721+
// println!("DELETE FROM users WHERE id = {};", id);
722+
conn.execute("DELETE FROM users WHERE id = ?", libsql::params![id])
723+
.await
724+
.unwrap();
725+
alive.remove(&id);
726+
} else if operation == 2 && !alive.is_empty() {
727+
let id = alive.iter().collect::<Vec<_>>()[rng.gen_range(0..alive.len())];
728+
// println!("UPDATE users SET v = vector('{}') WHERE id = {};", vector_str, id);
729+
conn.execute(
730+
"UPDATE users SET v = vector(?) WHERE id = ?",
731+
libsql::params![vector_str, id],
732+
)
733+
.await
734+
.unwrap();
735+
} else if operation == 3 {
736+
let k = rng.gen_range(1..200);
737+
// println!("SELECT * FROM vector_top_k('users_idx', '{}', {});", vector_str, k);
738+
let result = conn
739+
.query(
740+
"SELECT * FROM vector_top_k('users_idx', ?, ?)",
741+
libsql::params![vector_str, k],
742+
)
743+
.await
744+
.unwrap();
745+
let count = result.into_stream().count().await;
746+
assert!(count <= alive.len());
747+
if alive.len() > 0 {
748+
assert!(count > 0);
749+
}
750+
}
751+
}
752+
let _ = conn.execute("REINDEX users;", ()).await.unwrap();
753+
}
754+
}

0 commit comments

Comments
 (0)