Skip to content

Commit 9057bcb

Browse files
committed
add more tests for vector feature
1 parent 1eaa283 commit 9057bcb

File tree

4 files changed

+82
-0
lines changed

4 files changed

+82
-0
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

libsql-sqlite3/test/libsql_vector_index.test

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,26 @@ do_execsql_test vector-vacuum {
230230
SELECT COUNT(*) FROM t_vacuum_idx_shadow;
231231
} {2 2}
232232

233+
do_execsql_test vector-many-columns {
234+
CREATE TABLE t_many ( i INTEGER PRIMARY KEY, e1 FLOAT32(2), e2 FLOAT32(2) );
235+
CREATE INDEX t_many_1_idx ON t_many(libsql_vector_idx(e1));
236+
CREATE INDEX t_many_2_idx ON t_many(libsql_vector_idx(e2));
237+
INSERT INTO t_many VALUES (1, vector('[1,1]'), vector('[-1,-1]')), (2, vector('[-1,-1]'), vector('[1,1]'));
238+
SELECT * FROM vector_top_k('t_many_1_idx', vector('[1,1]'), 2);
239+
SELECT * FROM vector_top_k('t_many_2_idx', vector('[1,1]'), 2);
240+
} {1 2 2 1}
241+
242+
do_execsql_test vector-transaction {
243+
CREATE TABLE t_transaction ( i INTEGER PRIMARY KEY, e FLOAT32(2) );
244+
CREATE INDEX t_transaction_idx ON t_transaction(libsql_vector_idx(e));
245+
INSERT INTO t_transaction VALUES (1, vector('[1,2]')), (2, vector('[3,4]'));
246+
BEGIN;
247+
INSERT INTO t_transaction VALUES (3, vector('[4,5]')), (4, vector('[5,6]'));
248+
SELECT * FROM vector_top_k('t_transaction_idx', vector('[4,5]'), 2);
249+
ROLLBACK;
250+
SELECT * FROM vector_top_k('t_transaction_idx', vector('[1,2]'), 2);
251+
} {3 4 1 2}
252+
233253
proc error_messages {sql} {
234254
set ret ""
235255
catch {

libsql/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ tokio = { version = "1.29.1", features = ["full"] }
5050
tokio-test = "0.4"
5151
tracing-subscriber = "0.3"
5252
tempfile = { version = "3.7.0" }
53+
rand = "0.8.5"
5354

5455
[features]
5556
default = ["core", "replication", "remote"]

libsql/tests/integration_tests.rs

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
#![allow(deprecated)]
22

3+
use rand::prelude::*;
4+
use rand::distributions::Uniform;
5+
use std::collections::HashSet;
36
use futures::{StreamExt, TryStreamExt};
47
use libsql::{
58
named_params, params,
@@ -650,3 +653,60 @@ async fn deserialize_row() {
650653
assert_eq!(data.status, Status::Draft);
651654
assert_eq!(data.wrapper, Wrapper(Status::Published));
652655
}
656+
657+
#[tokio::test]
658+
#[ignore]
659+
// fuzz test can be run explicitly with following command:
660+
// cargo test vector_fuzz_test -- --nocapture --include-ignored
661+
async fn vector_fuzz_test() {
662+
let mut global_rng = rand::thread_rng();
663+
for attempt in 0..10000 {
664+
let seed = global_rng.next_u64();
665+
666+
let mut rng = rand::rngs::StdRng::from_seed(unsafe { std::mem::transmute([seed, seed, seed, seed]) });
667+
let db = Database::open(":memory:").unwrap();
668+
let conn = db.connect().unwrap();
669+
let dim = rng.gen_range(1..=1536);
670+
let operations = rng.gen_range(1..128);
671+
println!("============== ATTEMPT {} (seed {}u64, dim {}, operations {}) ================", attempt, seed, dim, operations);
672+
673+
let _ = conn.execute(&format!("CREATE TABLE users (id INTEGER PRIMARY KEY, v FLOAT32({}) )", dim), ()).await;
674+
// println!("CREATE TABLE users (id INTEGER PRIMARY KEY, v FLOAT32({}) );", dim);
675+
let _ = conn.execute("CREATE INDEX users_idx ON users ( libsql_vector_idx(v) );", ()).await;
676+
// println!("CREATE INDEX users_idx ON users ( libsql_vector_idx(v) );");
677+
678+
let mut next_id = 1;
679+
let mut alive = HashSet::new();
680+
let uniform = Uniform::new(-1.0, 1.0);
681+
for _ in 0..operations {
682+
let operation = rng.gen_range(0..4);
683+
let vector : Vec<f32> = (0..dim).map(|_| rng.sample(uniform)).collect();
684+
let vector_str = format!("[{}]", vector.iter().map(|x| format!("{}", x)).collect::<Vec<String>>().join(","));
685+
if operation == 0 {
686+
// println!("INSERT INTO users VALUES ({}, vector('{}') );", next_id, vector_str);
687+
conn.execute("INSERT INTO users VALUES (?, vector(?) )", libsql::params![next_id, vector_str]).await.unwrap();
688+
alive.insert(next_id);
689+
next_id += 1;
690+
} else if operation == 1 {
691+
let id = rng.gen_range(0..next_id);
692+
// println!("DELETE FROM users WHERE id = {};", id);
693+
conn.execute("DELETE FROM users WHERE id = ?", libsql::params![id]).await.unwrap();
694+
alive.remove(&id);
695+
} else if operation == 2 && !alive.is_empty() {
696+
let id = alive.iter().collect::<Vec<_>>()[rng.gen_range(0..alive.len())];
697+
// println!("UPDATE users SET v = vector('{}') WHERE id = {};", vector_str, id);
698+
conn.execute("UPDATE users SET v = vector(?) WHERE id = ?", libsql::params![vector_str, id]).await.unwrap();
699+
} else if operation == 3 {
700+
let k = rng.gen_range(1..200);
701+
// println!("SELECT * FROM vector_top_k('users_idx', '{}', {});", vector_str, k);
702+
let result = conn.query("SELECT * FROM vector_top_k('users_idx', ?, ?)", libsql::params![vector_str, k]).await.unwrap();
703+
let count = result.into_stream().count().await;
704+
assert!(count <= alive.len());
705+
if alive.len() > 0 {
706+
assert!(count > 0);
707+
}
708+
}
709+
}
710+
let _ = conn.execute("REINDEX users;", ()).await.unwrap();
711+
}
712+
}

0 commit comments

Comments
 (0)