@@ -6,6 +6,9 @@ use libsql::{
6
6
params:: { IntoParams , IntoValue } ,
7
7
Connection , Database , Value ,
8
8
} ;
9
+ use rand:: distributions:: Uniform ;
10
+ use rand:: prelude:: * ;
11
+ use std:: collections:: HashSet ;
9
12
10
13
async fn setup ( ) -> Connection {
11
14
let db = Database :: open ( ":memory:" ) . unwrap ( ) ;
@@ -650,3 +653,102 @@ async fn deserialize_row() {
650
653
assert_eq ! ( data. status, Status :: Draft ) ;
651
654
assert_eq ! ( data. wrapper, Wrapper ( Status :: Published ) ) ;
652
655
}
656
+
657
+ #[ tokio:: test]
658
+ #[ ignore]
659
+ // fuzz test can be run explicitly with following command:
660
+ // cargo test vector_fuzz_test -- --nocapture --include-ignored
661
+ async fn vector_fuzz_test ( ) {
662
+ let mut global_rng = rand:: thread_rng ( ) ;
663
+ for attempt in 0 ..10000 {
664
+ let seed = global_rng. next_u64 ( ) ;
665
+
666
+ let mut rng =
667
+ rand:: rngs:: StdRng :: from_seed ( unsafe { std:: mem:: transmute ( [ seed, seed, seed, seed] ) } ) ;
668
+ let db = Database :: open ( ":memory:" ) . unwrap ( ) ;
669
+ let conn = db. connect ( ) . unwrap ( ) ;
670
+ let dim = rng. gen_range ( 1 ..=1536 ) ;
671
+ let operations = rng. gen_range ( 1 ..128 ) ;
672
+ println ! (
673
+ "============== ATTEMPT {} (seed {}u64, dim {}, operations {}) ================" ,
674
+ attempt, seed, dim, operations
675
+ ) ;
676
+
677
+ let _ = conn
678
+ . execute (
679
+ & format ! (
680
+ "CREATE TABLE users (id INTEGER PRIMARY KEY, v FLOAT32({}) )" ,
681
+ dim
682
+ ) ,
683
+ ( ) ,
684
+ )
685
+ . await ;
686
+ // println!("CREATE TABLE users (id INTEGER PRIMARY KEY, v FLOAT32({}) );", dim);
687
+ let _ = conn
688
+ . execute (
689
+ "CREATE INDEX users_idx ON users ( libsql_vector_idx(v) );" ,
690
+ ( ) ,
691
+ )
692
+ . await ;
693
+ // println!("CREATE INDEX users_idx ON users ( libsql_vector_idx(v) );");
694
+
695
+ let mut next_id = 1 ;
696
+ let mut alive = HashSet :: new ( ) ;
697
+ let uniform = Uniform :: new ( -1.0 , 1.0 ) ;
698
+ for _ in 0 ..operations {
699
+ let operation = rng. gen_range ( 0 ..4 ) ;
700
+ let vector: Vec < f32 > = ( 0 ..dim) . map ( |_| rng. sample ( uniform) ) . collect ( ) ;
701
+ let vector_str = format ! (
702
+ "[{}]" ,
703
+ vector
704
+ . iter( )
705
+ . map( |x| format!( "{}" , x) )
706
+ . collect:: <Vec <String >>( )
707
+ . join( "," )
708
+ ) ;
709
+ if operation == 0 {
710
+ // println!("INSERT INTO users VALUES ({}, vector('{}') );", next_id, vector_str);
711
+ conn. execute (
712
+ "INSERT INTO users VALUES (?, vector(?) )" ,
713
+ libsql:: params![ next_id, vector_str] ,
714
+ )
715
+ . await
716
+ . unwrap ( ) ;
717
+ alive. insert ( next_id) ;
718
+ next_id += 1 ;
719
+ } else if operation == 1 {
720
+ let id = rng. gen_range ( 0 ..next_id) ;
721
+ // println!("DELETE FROM users WHERE id = {};", id);
722
+ conn. execute ( "DELETE FROM users WHERE id = ?" , libsql:: params![ id] )
723
+ . await
724
+ . unwrap ( ) ;
725
+ alive. remove ( & id) ;
726
+ } else if operation == 2 && !alive. is_empty ( ) {
727
+ let id = alive. iter ( ) . collect :: < Vec < _ > > ( ) [ rng. gen_range ( 0 ..alive. len ( ) ) ] ;
728
+ // println!("UPDATE users SET v = vector('{}') WHERE id = {};", vector_str, id);
729
+ conn. execute (
730
+ "UPDATE users SET v = vector(?) WHERE id = ?" ,
731
+ libsql:: params![ vector_str, id] ,
732
+ )
733
+ . await
734
+ . unwrap ( ) ;
735
+ } else if operation == 3 {
736
+ let k = rng. gen_range ( 1 ..200 ) ;
737
+ // println!("SELECT * FROM vector_top_k('users_idx', '{}', {});", vector_str, k);
738
+ let result = conn
739
+ . query (
740
+ "SELECT * FROM vector_top_k('users_idx', ?, ?)" ,
741
+ libsql:: params![ vector_str, k] ,
742
+ )
743
+ . await
744
+ . unwrap ( ) ;
745
+ let count = result. into_stream ( ) . count ( ) . await ;
746
+ assert ! ( count <= alive. len( ) ) ;
747
+ if alive. len ( ) > 0 {
748
+ assert ! ( count > 0 ) ;
749
+ }
750
+ }
751
+ }
752
+ let _ = conn. execute ( "REINDEX users;" , ( ) ) . await . unwrap ( ) ;
753
+ }
754
+ }
0 commit comments