|
1 | 1 | #![allow(deprecated)]
|
2 | 2 |
|
| 3 | +use rand::prelude::*; |
| 4 | +use rand::distributions::Uniform; |
| 5 | +use std::collections::HashSet; |
3 | 6 | use futures::{StreamExt, TryStreamExt};
|
4 | 7 | use libsql::{
|
5 | 8 | named_params, params,
|
@@ -650,3 +653,60 @@ async fn deserialize_row() {
|
650 | 653 | assert_eq!(data.status, Status::Draft);
|
651 | 654 | assert_eq!(data.wrapper, Wrapper(Status::Published));
|
652 | 655 | }
|
| 656 | + |
| 657 | +#[tokio::test] |
| 658 | +#[ignore] |
| 659 | +// fuzz test can be run explicitly with following command: |
| 660 | +// cargo test vector_fuzz_test -- --nocapture --include-ignored |
| 661 | +async fn vector_fuzz_test() { |
| 662 | + let mut global_rng = rand::thread_rng(); |
| 663 | + for attempt in 0..10000 { |
| 664 | + let seed = global_rng.next_u64(); |
| 665 | + |
| 666 | + let mut rng = rand::rngs::StdRng::from_seed(unsafe { std::mem::transmute([seed, seed, seed, seed]) }); |
| 667 | + let db = Database::open(":memory:").unwrap(); |
| 668 | + let conn = db.connect().unwrap(); |
| 669 | + let dim = rng.gen_range(1..=1536); |
| 670 | + let operations = rng.gen_range(1..128); |
| 671 | + println!("============== ATTEMPT {} (seed {}u64, dim {}, operations {}) ================", attempt, seed, dim, operations); |
| 672 | + |
| 673 | + let _ = conn.execute(&format!("CREATE TABLE users (id INTEGER PRIMARY KEY, v FLOAT32({}) )", dim), ()).await; |
| 674 | + // println!("CREATE TABLE users (id INTEGER PRIMARY KEY, v FLOAT32({}) );", dim); |
| 675 | + let _ = conn.execute("CREATE INDEX users_idx ON users ( libsql_vector_idx(v) );", ()).await; |
| 676 | + // println!("CREATE INDEX users_idx ON users ( libsql_vector_idx(v) );"); |
| 677 | + |
| 678 | + let mut next_id = 1; |
| 679 | + let mut alive = HashSet::new(); |
| 680 | + let uniform = Uniform::new(-1.0, 1.0); |
| 681 | + for _ in 0..operations { |
| 682 | + let operation = rng.gen_range(0..4); |
| 683 | + let vector : Vec<f32> = (0..dim).map(|_| rng.sample(uniform)).collect(); |
| 684 | + let vector_str = format!("[{}]", vector.iter().map(|x| format!("{}", x)).collect::<Vec<String>>().join(",")); |
| 685 | + if operation == 0 { |
| 686 | + // println!("INSERT INTO users VALUES ({}, vector('{}') );", next_id, vector_str); |
| 687 | + conn.execute("INSERT INTO users VALUES (?, vector(?) )", libsql::params![next_id, vector_str]).await.unwrap(); |
| 688 | + alive.insert(next_id); |
| 689 | + next_id += 1; |
| 690 | + } else if operation == 1 { |
| 691 | + let id = rng.gen_range(0..next_id); |
| 692 | + // println!("DELETE FROM users WHERE id = {};", id); |
| 693 | + conn.execute("DELETE FROM users WHERE id = ?", libsql::params![id]).await.unwrap(); |
| 694 | + alive.remove(&id); |
| 695 | + } else if operation == 2 && !alive.is_empty() { |
| 696 | + let id = alive.iter().collect::<Vec<_>>()[rng.gen_range(0..alive.len())]; |
| 697 | + // println!("UPDATE users SET v = vector('{}') WHERE id = {};", vector_str, id); |
| 698 | + conn.execute("UPDATE users SET v = vector(?) WHERE id = ?", libsql::params![vector_str, id]).await.unwrap(); |
| 699 | + } else if operation == 3 { |
| 700 | + let k = rng.gen_range(1..200); |
| 701 | + // println!("SELECT * FROM vector_top_k('users_idx', '{}', {});", vector_str, k); |
| 702 | + let result = conn.query("SELECT * FROM vector_top_k('users_idx', ?, ?)", libsql::params![vector_str, k]).await.unwrap(); |
| 703 | + let count = result.into_stream().count().await; |
| 704 | + assert!(count <= alive.len()); |
| 705 | + if alive.len() > 0 { |
| 706 | + assert!(count > 0); |
| 707 | + } |
| 708 | + } |
| 709 | + } |
| 710 | + let _ = conn.execute("REINDEX users;", ()).await.unwrap(); |
| 711 | + } |
| 712 | +} |
0 commit comments