Skip to content

Commit d5179a9

Browse files
committed
Added new optional dimension argument to vector_convert_ functions
1 parent 19c6ac2 commit d5179a9

File tree

2 files changed

+26
-4
lines changed

2 files changed

+26
-4
lines changed

src/sqlite-vector.c

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1175,7 +1175,7 @@ static void vector_cleanup (sqlite3_context *context, int argc, sqlite3_value **
11751175

11761176
// MARK: -
11771177

1178-
static void *vector_convert_from_json (sqlite3_context *context, sqlite3_vtab *vtab, vector_type type, const char *json, int *size) {
1178+
static void *vector_convert_from_json (sqlite3_context *context, sqlite3_vtab *vtab, vector_type type, const char *json, int *size, int dimension) {
11791179
char *blob = NULL;
11801180

11811181
// skip leading whitespace
@@ -1292,6 +1292,12 @@ static void *vector_convert_from_json (sqlite3_context *context, sqlite3_vtab *v
12921292
}
12931293
}
12941294

1295+
// sanity check vector dimension
1296+
if ((dimension > 0) && (dimension != count)) {
1297+
sqlite3_free(blob);
1298+
return sqlite_common_set_error(context, vtab, SQLITE_ERROR, "Invalid JSON vector dimension: expected %d but found %d.", dimension, count);
1299+
}
1300+
12951301
if (size) *size = (int)(count * item_size);
12961302
return blob;
12971303
}
@@ -1301,12 +1307,23 @@ static void vector_convert (sqlite3_context *context, vector_type type, int argc
13011307
int value_size = sqlite3_value_bytes(value);
13021308
int value_type = sqlite3_value_type(value);
13031309

1310+
// dimension is an optional argument
1311+
int dimension = (argc == 2) ? sqlite3_value_int(argv[1]) : 0;
1312+
13041313
if (value_type == SQLITE_BLOB) {
13051314
// the only check we can perform is that the blob size is an exact multiplier of the vector type
13061315
if (value_size % vector_type_to_size(type) != 0) {
13071316
context_result_error(context, SQLITE_ERROR, "Invalid BLOB size for format '%s': size must be a multiple of %d bytes.", vector_type_to_name(type), vector_type_to_size(type));
13081317
return;
13091318
}
1319+
if (dimension > 0) {
1320+
int expected_size = (int)vector_type_to_size(type) * dimension;
1321+
if (value_size != expected_size) {
1322+
context_result_error(context, SQLITE_ERROR, "Invalid BLOB size for format '%s': expected dimension should be %d (BLOB is %d bytes instead of %d).", vector_type_to_name(type), dimension, value_size, expected_size);
1323+
return;
1324+
}
1325+
}
1326+
13101327
sqlite3_result_value(context, value);
13111328
return;
13121329
}
@@ -1319,7 +1336,7 @@ static void vector_convert (sqlite3_context *context, vector_type type, int argc
13191336
return;
13201337
}
13211338

1322-
char *blob = vector_convert_from_json(context, NULL, type, json, &value_size);
1339+
char *blob = vector_convert_from_json(context, NULL, type, json, &value_size, dimension);
13231340
if (!blob) return; // error is set in the context
13241341

13251342
sqlite3_result_blob(context, (const void *)blob, value_size, sqlite3_free);
@@ -1393,7 +1410,7 @@ static int vCursorFilterCommon (sqlite3_vtab_cursor *cur, int idxNum, const char
13931410
int vsize = 0;
13941411
if (sqlite3_value_type(argv[2]) == SQLITE_TEXT) {
13951412
vsize = sqlite3_value_bytes(argv[2]);
1396-
vector = (const void *)vector_convert_from_json(NULL, &vtab->base, t_ctx->options.v_type, (const char *)sqlite3_value_text(argv[2]), &vsize);
1413+
vector = (const void *)vector_convert_from_json(NULL, &vtab->base, t_ctx->options.v_type, (const char *)sqlite3_value_text(argv[2]), &vsize, t_ctx->options.v_dim);
13971414
if (!vector) return SQLITE_ERROR; // error already set inside vector_convert_from_json
13981415
} else {
13991416
vector = (const void *)sqlite3_value_blob(argv[2]);
@@ -1896,18 +1913,23 @@ SQLITE_VECTOR_API int sqlite3_vector_init (sqlite3 *db, char **pzErrMsg, const s
18961913
if (rc != SQLITE_OK) goto cleanup;
18971914

18981915
rc = sqlite3_create_function(db, "vector_convert_f32", 1, SQLITE_UTF8, ctx, vector_convert_f32, NULL, NULL);
1916+
rc = sqlite3_create_function(db, "vector_convert_f32", 2, SQLITE_UTF8, ctx, vector_convert_f32, NULL, NULL);
18991917
if (rc != SQLITE_OK) goto cleanup;
19001918

19011919
rc = sqlite3_create_function(db, "vector_convert_f16", 1, SQLITE_UTF8, ctx, vector_convert_f16, NULL, NULL);
1920+
rc = sqlite3_create_function(db, "vector_convert_f16", 2, SQLITE_UTF8, ctx, vector_convert_f16, NULL, NULL);
19021921
if (rc != SQLITE_OK) goto cleanup;
19031922

19041923
rc = sqlite3_create_function(db, "vector_convert_bf16", 1, SQLITE_UTF8, ctx, vector_convert_bf16, NULL, NULL);
1924+
rc = sqlite3_create_function(db, "vector_convert_bf16", 2, SQLITE_UTF8, ctx, vector_convert_bf16, NULL, NULL);
19051925
if (rc != SQLITE_OK) goto cleanup;
19061926

19071927
rc = sqlite3_create_function(db, "vector_convert_i8", 1, SQLITE_UTF8, ctx, vector_convert_i8, NULL, NULL);
1928+
rc = sqlite3_create_function(db, "vector_convert_i8", 2, SQLITE_UTF8, ctx, vector_convert_i8, NULL, NULL);
19081929
if (rc != SQLITE_OK) goto cleanup;
19091930

19101931
rc = sqlite3_create_function(db, "vector_convert_u8", 1, SQLITE_UTF8, ctx, vector_convert_u8, NULL, NULL);
1932+
rc = sqlite3_create_function(db, "vector_convert_u8", 2, SQLITE_UTF8, ctx, vector_convert_u8, NULL, NULL);
19111933
if (rc != SQLITE_OK) goto cleanup;
19121934

19131935
rc = sqlite3_create_module(db, "vector_full_scan", &vFullScanModule, ctx);

src/sqlite-vector.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
extern "C" {
2525
#endif
2626

27-
#define SQLITE_VECTOR_VERSION "0.8.6"
27+
#define SQLITE_VECTOR_VERSION "0.8.7"
2828

2929
SQLITE_VECTOR_API int sqlite3_vector_init (sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi);
3030

0 commit comments

Comments
 (0)