Skip to content

Commit 54ff421

Browse files
authored
Merge pull request #1761 from tursodatabase/vector-search-accept-k-float
accept K parameter as float if there is no loss in the precision after rounding to the integer
2 parents 8abff7b + 80a10f9 commit 54ff421

File tree

4 files changed

+76
-25
lines changed

4 files changed

+76
-25
lines changed

libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -216033,6 +216033,7 @@ int vectorIndexSearch(
216033216033
char **pzErrMsg
216034216034
) {
216035216035
int type, dims, k, rc;
216036+
double kDouble;
216036216037
const char *zIdxName;
216037216038
const char *zErrMsg;
216038216039
Vector *pVector = NULL;
@@ -216063,17 +216064,32 @@ int vectorIndexSearch(
216063216064
rc = SQLITE_ERROR;
216064216065
goto out;
216065216066
}
216066-
if( sqlite3_value_type(argv[2]) != SQLITE_INTEGER ){
216067-
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer");
216068-
rc = SQLITE_ERROR;
216069-
goto out;
216070-
}
216071-
k = sqlite3_value_int(argv[2]);
216072-
if( k < 0 ){
216073-
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer");
216067+
if( sqlite3_value_type(argv[2]) == SQLITE_INTEGER ){
216068+
k = sqlite3_value_int(argv[2]);
216069+
if( k < 0 ){
216070+
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer, but negative value were provided");
216071+
rc = SQLITE_ERROR;
216072+
goto out;
216073+
}
216074+
}else if( sqlite3_value_type(argv[2]) == SQLITE_FLOAT ) {
216075+
kDouble = sqlite3_value_double(argv[2]);
216076+
k = (int)kDouble;
216077+
if( (double)k != kDouble ){
216078+
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be an integer, but float value were provided");
216079+
rc = SQLITE_ERROR;
216080+
goto out;
216081+
}
216082+
if( k < 0 ){
216083+
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer, but negative value were provided");
216084+
rc = SQLITE_ERROR;
216085+
goto out;
216086+
}
216087+
}else{
216088+
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be an integer, but unexpected type of value were provided");
216074216089
rc = SQLITE_ERROR;
216075216090
goto out;
216076216091
}
216092+
216077216093
if( sqlite3_value_type(argv[0]) != SQLITE_TEXT ){
216078216094
*pzErrMsg = sqlite3_mprintf("vector index(search): first parameter (index) must be a string");
216079216095
rc = SQLITE_ERROR;

libsql-ffi/bundled/src/sqlite3.c

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -216033,6 +216033,7 @@ int vectorIndexSearch(
216033216033
char **pzErrMsg
216034216034
) {
216035216035
int type, dims, k, rc;
216036+
double kDouble;
216036216037
const char *zIdxName;
216037216038
const char *zErrMsg;
216038216039
Vector *pVector = NULL;
@@ -216063,17 +216064,32 @@ int vectorIndexSearch(
216063216064
rc = SQLITE_ERROR;
216064216065
goto out;
216065216066
}
216066-
if( sqlite3_value_type(argv[2]) != SQLITE_INTEGER ){
216067-
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer");
216068-
rc = SQLITE_ERROR;
216069-
goto out;
216070-
}
216071-
k = sqlite3_value_int(argv[2]);
216072-
if( k < 0 ){
216073-
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer");
216067+
if( sqlite3_value_type(argv[2]) == SQLITE_INTEGER ){
216068+
k = sqlite3_value_int(argv[2]);
216069+
if( k < 0 ){
216070+
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer, but negative value were provided");
216071+
rc = SQLITE_ERROR;
216072+
goto out;
216073+
}
216074+
}else if( sqlite3_value_type(argv[2]) == SQLITE_FLOAT ) {
216075+
kDouble = sqlite3_value_double(argv[2]);
216076+
k = (int)kDouble;
216077+
if( (double)k != kDouble ){
216078+
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be an integer, but float value were provided");
216079+
rc = SQLITE_ERROR;
216080+
goto out;
216081+
}
216082+
if( k < 0 ){
216083+
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer, but negative value were provided");
216084+
rc = SQLITE_ERROR;
216085+
goto out;
216086+
}
216087+
}else{
216088+
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be an integer, but unexpected type of value were provided");
216074216089
rc = SQLITE_ERROR;
216075216090
goto out;
216076216091
}
216092+
216077216093
if( sqlite3_value_type(argv[0]) != SQLITE_TEXT ){
216078216094
*pzErrMsg = sqlite3_mprintf("vector index(search): first parameter (index) must be a string");
216079216095
rc = SQLITE_ERROR;

libsql-sqlite3/src/vectorIndex.c

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -951,6 +951,7 @@ int vectorIndexSearch(
951951
char **pzErrMsg
952952
) {
953953
int type, dims, k, rc;
954+
double kDouble;
954955
const char *zIdxName;
955956
const char *zErrMsg;
956957
Vector *pVector = NULL;
@@ -981,17 +982,32 @@ int vectorIndexSearch(
981982
rc = SQLITE_ERROR;
982983
goto out;
983984
}
984-
if( sqlite3_value_type(argv[2]) != SQLITE_INTEGER ){
985-
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer");
986-
rc = SQLITE_ERROR;
987-
goto out;
988-
}
989-
k = sqlite3_value_int(argv[2]);
990-
if( k < 0 ){
991-
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer");
985+
if( sqlite3_value_type(argv[2]) == SQLITE_INTEGER ){
986+
k = sqlite3_value_int(argv[2]);
987+
if( k < 0 ){
988+
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer, but negative value were provided");
989+
rc = SQLITE_ERROR;
990+
goto out;
991+
}
992+
}else if( sqlite3_value_type(argv[2]) == SQLITE_FLOAT ) {
993+
kDouble = sqlite3_value_double(argv[2]);
994+
k = (int)kDouble;
995+
if( (double)k != kDouble ){
996+
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be an integer, but float value were provided");
997+
rc = SQLITE_ERROR;
998+
goto out;
999+
}
1000+
if( k < 0 ){
1001+
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer, but negative value were provided");
1002+
rc = SQLITE_ERROR;
1003+
goto out;
1004+
}
1005+
}else{
1006+
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be an integer, but unexpected type of value were provided");
9921007
rc = SQLITE_ERROR;
9931008
goto out;
9941009
}
1010+
9951011
if( sqlite3_value_type(argv[0]) != SQLITE_TEXT ){
9961012
*pzErrMsg = sqlite3_mprintf("vector index(search): first parameter (index) must be a string");
9971013
rc = SQLITE_ERROR;

libsql-sqlite3/test/libsql_vector_index.test

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,8 @@ do_execsql_test vector-simple {
117117
SELECT * FROM vector_top_k('t_simple_idx', '[1,2,3]', 1);
118118
SELECT * FROM vector_top_k('t_simple_idx', '[5,6,7]', 1);
119119
SELECT * FROM vector_top_k('t_simple_idx', vector('[1,2,3]'), 1);
120-
} {{1} {3} {1}}
120+
SELECT * FROM vector_top_k('t_simple_idx', vector('[1,2,3]'), CAST(1 as REAL));
121+
} {{1} {3} {1} {1}}
121122

122123
do_execsql_test vector-empty {
123124
CREATE TABLE t_empty( v FLOAT32(3));
@@ -484,6 +485,7 @@ do_test vector-errors {
484485
lappend ret [error_messages {INSERT INTO t_err3 VALUES (vector('[1, 2, 3, 4, 5]'))}]
485486
lappend ret [error_messages {INSERT INTO t_err3 VALUES (vector64('[1,2,3,4]'))}]
486487
lappend ret [error_messages {SELECT * FROM vector_top_k('t_err3_idx', vector('[1,2]'), 2)}]
488+
lappend ret [error_messages {SELECT * FROM vector_top_k('t_err3_idx', vector('[1,2,3,4]'), 2.5)}]
487489
sqlite3_exec db { CREATE TABLE t_mixed_t( v FLOAT32(3)); }
488490
sqlite3_exec db { INSERT INTO t_mixed_t VALUES('[1]'); }
489491
lappend ret [error_messages {CREATE INDEX t_mixed_t_idx ON t_mixed_t( libsql_vector_idx(v) )}]
@@ -503,5 +505,6 @@ do_test vector-errors {
503505
{vector index(insert): dimensions are different: 5 != 4}
504506
{vector index(insert): vector type differs from column type: 2 != 1}
505507
{vector index(search): dimensions are different: 2 != 4}
508+
{vector index(search): third parameter (k) must be an integer, but float value were provided}
506509
{vector index(insert): dimensions are different: 1 != 3}
507510
}]

0 commit comments

Comments
 (0)