@@ -1175,7 +1175,7 @@ static void vector_cleanup (sqlite3_context *context, int argc, sqlite3_value **
11751175
11761176// MARK: -
11771177
1178- static void * vector_convert_from_json (sqlite3_context * context , sqlite3_vtab * vtab , vector_type type , const char * json , int * size ) {
1178+ static void * vector_convert_from_json (sqlite3_context * context , sqlite3_vtab * vtab , vector_type type , const char * json , int * size , int dimension ) {
11791179 char * blob = NULL ;
11801180
11811181 // skip leading whitespace
@@ -1292,6 +1292,12 @@ static void *vector_convert_from_json (sqlite3_context *context, sqlite3_vtab *v
12921292 }
12931293 }
12941294
1295+ // sanity check vector dimension
1296+ if ((dimension > 0 ) && (dimension != count )) {
1297+ sqlite3_free (blob );
1298+ return sqlite_common_set_error (context , vtab , SQLITE_ERROR , "Invalid JSON vector dimension: expected %d but found %d." , dimension , count );
1299+ }
1300+
12951301 if (size ) * size = (int )(count * item_size );
12961302 return blob ;
12971303}
@@ -1301,12 +1307,23 @@ static void vector_convert (sqlite3_context *context, vector_type type, int argc
13011307 int value_size = sqlite3_value_bytes (value );
13021308 int value_type = sqlite3_value_type (value );
13031309
1310+ // dimension is an optional argument
1311+ int dimension = (argc == 2 ) ? sqlite3_value_int (argv [1 ]) : 0 ;
1312+
13041313 if (value_type == SQLITE_BLOB ) {
13051314 // the only check we can perform is that the blob size is an exact multiplier of the vector type
13061315 if (value_size % vector_type_to_size (type ) != 0 ) {
13071316 context_result_error (context , SQLITE_ERROR , "Invalid BLOB size for format '%s': size must be a multiple of %d bytes." , vector_type_to_name (type ), vector_type_to_size (type ));
13081317 return ;
13091318 }
1319+ if (dimension > 0 ) {
1320+ int expected_size = (int )vector_type_to_size (type ) * dimension ;
1321+ if (value_size != expected_size ) {
1322+ context_result_error (context , SQLITE_ERROR , "Invalid BLOB size for format '%s': expected dimension should be %d (BLOB is %d bytes instead of %d)." , vector_type_to_name (type ), dimension , value_size , expected_size );
1323+ return ;
1324+ }
1325+ }
1326+
13101327 sqlite3_result_value (context , value );
13111328 return ;
13121329 }
@@ -1319,7 +1336,7 @@ static void vector_convert (sqlite3_context *context, vector_type type, int argc
13191336 return ;
13201337 }
13211338
1322- char * blob = vector_convert_from_json (context , NULL , type , json , & value_size );
1339+ char * blob = vector_convert_from_json (context , NULL , type , json , & value_size , dimension );
13231340 if (!blob ) return ; // error is set in the context
13241341
13251342 sqlite3_result_blob (context , (const void * )blob , value_size , sqlite3_free );
@@ -1393,7 +1410,7 @@ static int vCursorFilterCommon (sqlite3_vtab_cursor *cur, int idxNum, const char
13931410 int vsize = 0 ;
13941411 if (sqlite3_value_type (argv [2 ]) == SQLITE_TEXT ) {
13951412 vsize = sqlite3_value_bytes (argv [2 ]);
1396- vector = (const void * )vector_convert_from_json (NULL , & vtab -> base , t_ctx -> options .v_type , (const char * )sqlite3_value_text (argv [2 ]), & vsize );
1413+ vector = (const void * )vector_convert_from_json (NULL , & vtab -> base , t_ctx -> options .v_type , (const char * )sqlite3_value_text (argv [2 ]), & vsize , t_ctx -> options . v_dim );
13971414 if (!vector ) return SQLITE_ERROR ; // error already set inside vector_convert_from_json
13981415 } else {
13991416 vector = (const void * )sqlite3_value_blob (argv [2 ]);
@@ -1896,18 +1913,23 @@ SQLITE_VECTOR_API int sqlite3_vector_init (sqlite3 *db, char **pzErrMsg, const s
18961913 if (rc != SQLITE_OK ) goto cleanup ;
18971914
18981915 rc = sqlite3_create_function (db , "vector_convert_f32" , 1 , SQLITE_UTF8 , ctx , vector_convert_f32 , NULL , NULL );
1916+ rc = sqlite3_create_function (db , "vector_convert_f32" , 2 , SQLITE_UTF8 , ctx , vector_convert_f32 , NULL , NULL );
18991917 if (rc != SQLITE_OK ) goto cleanup ;
19001918
19011919 rc = sqlite3_create_function (db , "vector_convert_f16" , 1 , SQLITE_UTF8 , ctx , vector_convert_f16 , NULL , NULL );
1920+ rc = sqlite3_create_function (db , "vector_convert_f16" , 2 , SQLITE_UTF8 , ctx , vector_convert_f16 , NULL , NULL );
19021921 if (rc != SQLITE_OK ) goto cleanup ;
19031922
19041923 rc = sqlite3_create_function (db , "vector_convert_bf16" , 1 , SQLITE_UTF8 , ctx , vector_convert_bf16 , NULL , NULL );
1924+ rc = sqlite3_create_function (db , "vector_convert_bf16" , 2 , SQLITE_UTF8 , ctx , vector_convert_bf16 , NULL , NULL );
19051925 if (rc != SQLITE_OK ) goto cleanup ;
19061926
19071927 rc = sqlite3_create_function (db , "vector_convert_i8" , 1 , SQLITE_UTF8 , ctx , vector_convert_i8 , NULL , NULL );
1928+ rc = sqlite3_create_function (db , "vector_convert_i8" , 2 , SQLITE_UTF8 , ctx , vector_convert_i8 , NULL , NULL );
19081929 if (rc != SQLITE_OK ) goto cleanup ;
19091930
19101931 rc = sqlite3_create_function (db , "vector_convert_u8" , 1 , SQLITE_UTF8 , ctx , vector_convert_u8 , NULL , NULL );
1932+ rc = sqlite3_create_function (db , "vector_convert_u8" , 2 , SQLITE_UTF8 , ctx , vector_convert_u8 , NULL , NULL );
19111933 if (rc != SQLITE_OK ) goto cleanup ;
19121934
19131935 rc = sqlite3_create_module (db , "vector_full_scan" , & vFullScanModule , ctx );
0 commit comments