Skip to content

Commit a77d45d

Browse files
authored
Merge pull request #1594 from tursodatabase/vector-search-pk-fix
vector search index key fix
2 parents a964315 + 1006f3b commit a77d45d

File tree

7 files changed

+310
-171
lines changed

7 files changed

+310
-171
lines changed

libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c

Lines changed: 97 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -85093,13 +85093,19 @@ struct VectorOutRows {
8509385093
sqlite3_value **ppValues;
8509485094
};
8509585095

85096+
// limit to the sql part which we render in order to perform operations with shadow tables
85097+
// we render this parts of SQL on stack - thats why we have hard limit on this
85098+
// stack simplify memory managment code and also doesn't impose very strict limits here since 128 bytes for column names should be enough for almost all use cases
85099+
#define VECTOR_INDEX_SQL_RENDER_LIMIT 128
85100+
8509685101
void vectorIdxParamsInit(VectorIdxParams *, u8 *, int);
8509785102
u64 vectorIdxParamsGetU64(const VectorIdxParams *, char);
8509885103
double vectorIdxParamsGetF64(const VectorIdxParams *, char);
8509985104
int vectorIdxParamsPutU64(VectorIdxParams *, char, u64);
8510085105
int vectorIdxParamsPutF64(VectorIdxParams *, char, double);
8510185106

85102-
int vectorIdxKeyGet(Table*, VectorIdxKey *, const char **);
85107+
int vectorIdxKeyGet(const Index *, VectorIdxKey *, const char **);
85108+
int vectorIdxKeyRowidLike(const VectorIdxKey *);
8510385109
int vectorIdxKeyDefsRender(const VectorIdxKey *, const char *, char *, int);
8510485110
int vectorIdxKeyNamesRender(int, const char *, char *, int);
8510585111

@@ -85110,7 +85116,7 @@ i64 vectorInRowLegacyId(const VectorInRow *);
8511085116
int vectorInRowPlaceholderRender(const VectorInRow *, char *, int);
8511185117
void vectorInRowFree(sqlite3 *, VectorInRow *);
8511285118

85113-
int vectorOutRowsAlloc(sqlite3 *, VectorOutRows *, int, int, char);
85119+
int vectorOutRowsAlloc(sqlite3 *, VectorOutRows *, int, int, int);
8511485120
int vectorOutRowsPut(VectorOutRows *, int, int, const u64 *, sqlite3_value *);
8511585121
void vectorOutRowsGet(sqlite3_context *, const VectorOutRows *, int, int);
8511685122
void vectorOutRowsFree(sqlite3 *, VectorOutRows *);
@@ -126101,21 +126107,6 @@ SQLITE_PRIVATE void sqlite3CreateIndex(
126101126107
pIndex->aSortOrder[i] = (u8)requestedSortOrder;
126102126108
}
126103126109

126104-
126105-
#ifndef SQLITE_OMIT_VECTOR
126106-
vectorIdxRc = vectorIndexCreate(pParse, pIndex, db->aDb[iDb].zDbSName, pUsing);
126107-
if( vectorIdxRc < 0 ){
126108-
goto exit_create_index;
126109-
}
126110-
if( vectorIdxRc >= 1 ){
126111-
idxType = SQLITE_IDXTYPE_VECTOR;
126112-
pIndex->idxType = idxType;
126113-
}
126114-
if( vectorIdxRc == 1 ){
126115-
skipRefill = 1;
126116-
}
126117-
#endif
126118-
126119126110
/* Append the table key to the end of the index. For WITHOUT ROWID
126120126111
** tables (when pPk!=0) this will be the declared PRIMARY KEY. For
126121126112
** normal tables (when pPk==0) this will be the rowid.
@@ -126142,6 +126133,26 @@ SQLITE_PRIVATE void sqlite3CreateIndex(
126142126133
sqlite3DefaultRowEst(pIndex);
126143126134
if( pParse->pNewTable==0 ) estimateIndexWidth(pIndex);
126144126135

126136+
#ifndef SQLITE_OMIT_VECTOR
126137+
// we want to have complete information about index columns before invocation of vectorIndexCreate method
126138+
vectorIdxRc = vectorIndexCreate(pParse, pIndex, db->aDb[iDb].zDbSName, pUsing);
126139+
if( vectorIdxRc < 0 ){
126140+
goto exit_create_index;
126141+
}
126142+
if( vectorIdxRc >= 1 ){
126143+
idxType = SQLITE_IDXTYPE_VECTOR;
126144+
/*
126145+
* SQLite can use B-Tree indices in some optimizations (like SELECT COUNT(*) can use any full B-Tree index instead of PK index)
126146+
* But, SQLite pretty conservative about usage of unordered indices - that's what we need here
126147+
*/
126148+
pIndex->bUnordered = 1;
126149+
pIndex->idxType = idxType;
126150+
}
126151+
if( vectorIdxRc == 1 ){
126152+
skipRefill = 1;
126153+
}
126154+
#endif
126155+
126145126156
/* If this index contains every column of its table, then mark
126146126157
** it as a covering index */
126147126158
assert( HasRowid(pTab)
@@ -209858,8 +209869,8 @@ int diskAnnCreateIndex(
209858209869
int type, dims;
209859209870
u64 maxNeighborsParam, blockSizeBytes;
209860209871
char *zSql;
209861-
char columnSqlDefs[DISKANN_SQL_RENDER_LIMIT]; // definition of columns (e.g. index_key INTEGER BINARY, index_key1 TEXT, ...)
209862-
char columnSqlNames[DISKANN_SQL_RENDER_LIMIT]; // just column names (e.g. index_key, index_key1, index_key2, ...)
209872+
char columnSqlDefs[VECTOR_INDEX_SQL_RENDER_LIMIT]; // definition of columns (e.g. index_key INTEGER BINARY, index_key1 TEXT, ...)
209873+
char columnSqlNames[VECTOR_INDEX_SQL_RENDER_LIMIT]; // just column names (e.g. index_key, index_key1, index_key2, ...)
209863209874
if( vectorIdxKeyDefsRender(pKey, "index_key", columnSqlDefs, sizeof(columnSqlDefs)) != 0 ){
209864209875
return SQLITE_ERROR;
209865209876
}
@@ -209912,14 +209923,29 @@ int diskAnnCreateIndex(
209912209923
return SQLITE_ERROR;
209913209924
}
209914209925
}
209915-
zSql = sqlite3MPrintf(
209916-
db,
209917-
"CREATE TABLE IF NOT EXISTS \"%w\".%s_shadow (%s, data BLOB, PRIMARY KEY (%s))",
209918-
zDbSName,
209919-
zIdxName,
209920-
columnSqlDefs,
209921-
columnSqlNames
209922-
);
209926+
// we want to preserve rowid - so it must be explicit in the schema
209927+
// also, we don't want to store redundant set of fields - so the strategy is like that:
209928+
// 1. If we have single PK with INTEGER affinity and BINARY collation we only need single PK of same type
209929+
// 2. In other case we need rowid PK and unique index over other fields
209930+
if( vectorIdxKeyRowidLike(pKey) ){
209931+
zSql = sqlite3MPrintf(
209932+
db,
209933+
"CREATE TABLE IF NOT EXISTS \"%w\".%s_shadow (%s, data BLOB, PRIMARY KEY (%s))",
209934+
zDbSName,
209935+
zIdxName,
209936+
columnSqlDefs,
209937+
columnSqlNames
209938+
);
209939+
}else{
209940+
zSql = sqlite3MPrintf(
209941+
db,
209942+
"CREATE TABLE IF NOT EXISTS \"%w\".%s_shadow (rowid INTEGER PRIMARY KEY, %s, data BLOB, UNIQUE (%s))",
209943+
zDbSName,
209944+
zIdxName,
209945+
columnSqlDefs,
209946+
columnSqlNames
209947+
);
209948+
}
209923209949
rc = sqlite3_exec(db, zSql, 0, 0, 0);
209924209950
sqlite3DbFree(db, zSql);
209925209951
return rc;
@@ -209992,8 +210018,8 @@ static int diskAnnGetShadowRowid(const DiskAnnIndex *pIndex, const VectorInRow *
209992210018
sqlite3_stmt *pStmt = NULL;
209993210019
char *zSql = NULL;
209994210020

209995-
char columnSqlNames[DISKANN_SQL_RENDER_LIMIT]; // just column names (e.g. index_key, index_key1, index_key2, ...)
209996-
char columnSqlPlaceholders[DISKANN_SQL_RENDER_LIMIT]; // just placeholders (e.g. ?,?,?, ...)
210021+
char columnSqlNames[VECTOR_INDEX_SQL_RENDER_LIMIT]; // just column names (e.g. index_key, index_key1, index_key2, ...)
210022+
char columnSqlPlaceholders[VECTOR_INDEX_SQL_RENDER_LIMIT]; // just placeholders (e.g. ?,?,?, ...)
209997210023
if( vectorIdxKeyNamesRender(pInRow->nKeys, "index_key", columnSqlNames, sizeof(columnSqlNames)) != 0 ){
209998210024
rc = SQLITE_ERROR;
209999210025
goto out;
@@ -210050,7 +210076,7 @@ static int diskAnnGetShadowRowKeys(const DiskAnnIndex *pIndex, u64 nRowid, const
210050210076
sqlite3_stmt *pStmt = NULL;
210051210077
char *zSql = NULL;
210052210078

210053-
char columnSqlNames[DISKANN_SQL_RENDER_LIMIT]; // just column names (e.g. index_key, index_key1, index_key2, ...)
210079+
char columnSqlNames[VECTOR_INDEX_SQL_RENDER_LIMIT]; // just column names (e.g. index_key, index_key1, index_key2, ...)
210054210080
if( vectorIdxKeyNamesRender(pKey->nKeyColumns, "index_key", columnSqlNames, sizeof(columnSqlNames)) != 0 ){
210055210081
rc = SQLITE_ERROR;
210056210082
goto out;
@@ -210104,15 +210130,19 @@ static int diskAnnInsertShadowRow(const DiskAnnIndex *pIndex, const VectorInRow
210104210130
sqlite3_stmt *pStmt = NULL;
210105210131
char *zSql = NULL;
210106210132

210107-
char columnSqlPlaceholders[DISKANN_SQL_RENDER_LIMIT]; // just placeholders (e.g. ?,?,?, ...)
210133+
char columnSqlPlaceholders[VECTOR_INDEX_SQL_RENDER_LIMIT]; // just placeholders (e.g. ?,?,?, ...)
210134+
char columnSqlNames[VECTOR_INDEX_SQL_RENDER_LIMIT]; // just column names (e.g. index_key, index_key1, index_key2, ...)
210108210135
if( vectorInRowPlaceholderRender(pVectorInRow, columnSqlPlaceholders, sizeof(columnSqlPlaceholders)) != 0 ){
210109210136
rc = SQLITE_ERROR;
210110210137
goto out;
210111210138
}
210139+
if( vectorIdxKeyNamesRender(pVectorInRow->nKeys, "index_key", columnSqlNames, sizeof(columnSqlNames)) != 0 ){
210140+
return SQLITE_ERROR;
210141+
}
210112210142
zSql = sqlite3MPrintf(
210113210143
pIndex->db,
210114-
"INSERT INTO \"%w\".%s VALUES (%s, ?) RETURNING rowid",
210115-
pIndex->zDbSName, pIndex->zShadow, columnSqlPlaceholders
210144+
"INSERT INTO \"%w\".%s(%s, data) VALUES (%s, ?) RETURNING rowid",
210145+
pIndex->zDbSName, pIndex->zShadow, columnSqlNames, columnSqlPlaceholders
210116210146
);
210117210147
if( zSql == NULL ){
210118210148
rc = SQLITE_NOMEM_BKPT;
@@ -210669,7 +210699,7 @@ int diskAnnSearch(
210669210699
goto out;
210670210700
}
210671210701
nOutRows = MIN(k, ctx.nCandidates);
210672-
rc = vectorOutRowsAlloc(pIndex->db, pRows, nOutRows, pKey->nKeyColumns, pKey->aKeyAffinity[0]);
210702+
rc = vectorOutRowsAlloc(pIndex->db, pRows, nOutRows, pKey->nKeyColumns, vectorIdxKeyRowidLike(pKey));
210673210703
if( rc != SQLITE_OK ){
210674210704
*pzErrMsg = sqlite3_mprintf("vector index(search): failed to allocate output rows");
210675210705
goto out;
@@ -211577,30 +211607,41 @@ int vectorIdxParamsPutF64(VectorIdxParams *pParams, char tag, double value) {
211577211607
** VectorIdxKey utilities
211578211608
****************************************************************************/
211579211609

211580-
int vectorIdxKeyGet(Table *pTable, VectorIdxKey *pKey, const char **pzErrMsg) {
211581-
int i;
211582-
Index *pPk;
211583-
// we actually need to change strategy here and use PK if it's available and fallback to ROWID only if there is no other choice
211584-
// will change this later as it must be done carefully in order to not brake behaviour of existing indices
211585-
if( !HasRowid(pTable) ){
211586-
pPk = sqlite3PrimaryKeyIndex(pTable);
211587-
if( pPk->nKeyCol > VECTOR_INDEX_MAX_KEY_COLUMNS ){
211588-
*pzErrMsg = "exceeded limit for composite columns in primary key index";
211589-
return -1;
211590-
}
211591-
pKey->nKeyColumns = pPk->nKeyCol;
211592-
for(i = 0; i < pPk->nKeyCol; i++){
211593-
pKey->aKeyAffinity[i] = pTable->aCol[pPk->aiColumn[i]].affinity;
211594-
pKey->azKeyCollation[i] = pPk->azColl[i];
211595-
}
211596-
} else{
211610+
int vectorIdxKeyGet(const Index *pIndex, VectorIdxKey *pKey, const char **pzErrMsg) {
211611+
Table *pTable;
211612+
Index *pPkIndex;
211613+
int i, nKeyColumns;
211614+
211615+
assert( pIndex->nKeyCol == 1 );
211616+
assert( pIndex->nColumn > pIndex->nKeyCol );
211617+
211618+
pTable = pIndex->pTable;
211619+
nKeyColumns = pIndex->nColumn - pIndex->nKeyCol;
211620+
if( nKeyColumns == 1 && pIndex->aiColumn[pIndex->nKeyCol] == XN_ROWID ){
211597211621
pKey->nKeyColumns = 1;
211598211622
pKey->aKeyAffinity[0] = SQLITE_AFF_INTEGER;
211599211623
pKey->azKeyCollation[0] = "BINARY";
211624+
return 0;
211625+
}
211626+
if( nKeyColumns > VECTOR_INDEX_MAX_KEY_COLUMNS ){
211627+
*pzErrMsg = "exceeded limit for composite columns in primary key index";
211628+
return -1;
211629+
}
211630+
pPkIndex = sqlite3PrimaryKeyIndex(pIndex->pTable);
211631+
assert( pPkIndex->nKeyCol == nKeyColumns );
211632+
211633+
pKey->nKeyColumns = nKeyColumns;
211634+
for(i = 0; i < pPkIndex->nKeyCol; i++){
211635+
pKey->aKeyAffinity[i] = pTable->aCol[pPkIndex->aiColumn[i]].affinity;
211636+
pKey->azKeyCollation[i] = pPkIndex->azColl[i];
211600211637
}
211601211638
return 0;
211602211639
}
211603211640

211641+
int vectorIdxKeyRowidLike(const VectorIdxKey *pKey){
211642+
return pKey->nKeyColumns == 1 && pKey->aKeyAffinity[0] == SQLITE_AFF_INTEGER && sqlite3StrICmp(pKey->azKeyCollation[0], "BINARY") == 0;
211643+
}
211644+
211604211645
int vectorIdxKeyDefsRender(const VectorIdxKey *pKey, const char *prefix, char *pBuf, int nBufSize) {
211605211646
static const char * const azType[] = {
211606211647
/* SQLITE_AFF_BLOB */ " BLOB",
@@ -211748,7 +211789,7 @@ void vectorInRowFree(sqlite3 *db, VectorInRow *pVectorInRow) {
211748211789
** VectorOutRows utilities
211749211790
****************************************************************************/
211750211791

211751-
int vectorOutRowsAlloc(sqlite3 *db, VectorOutRows *pRows, int nRows, int nCols, char firstColumnAff){
211792+
int vectorOutRowsAlloc(sqlite3 *db, VectorOutRows *pRows, int nRows, int nCols, int rowidLike){
211752211793
assert( nCols > 0 && nRows >= 0 );
211753211794
pRows->nRows = nRows;
211754211795
pRows->nCols = nCols;
@@ -211759,7 +211800,8 @@ int vectorOutRowsAlloc(sqlite3 *db, VectorOutRows *pRows, int nRows, int nCols,
211759211800
return SQLITE_NOMEM_BKPT;
211760211801
}
211761211802

211762-
if( nCols == 1 && firstColumnAff == SQLITE_AFF_INTEGER ){
211803+
if( rowidLike ){
211804+
assert( nCols == 1 );
211763211805
pRows->aIntValues = sqlite3DbMallocRaw(db, nRows * sizeof(i64));
211764211806
if( pRows->aIntValues == NULL ){
211765211807
return SQLITE_NOMEM_BKPT;
@@ -212383,7 +212425,7 @@ int vectorIndexCreate(Parse *pParse, const Index *pIdx, const char *zDbSName, co
212383212425
if( rc != SQLITE_OK ){
212384212426
return CREATE_FAIL;
212385212427
}
212386-
if( vectorIdxKeyGet(pTable, &idxKey, &pzErrMsg) != 0 ){
212428+
if( vectorIdxKeyGet(pIdx, &idxKey, &pzErrMsg) != 0 ){
212387212429
sqlite3ErrorMsg(pParse, "vector index: failed to detect underlying table key: %s", pzErrMsg);
212388212430
return CREATE_FAIL;
212389212431
}
@@ -212480,7 +212522,7 @@ int vectorIndexSearch(sqlite3 *db, const char* zDbSName, int argc, sqlite3_value
212480212522
*pzErrMsg = sqlite3_mprintf("vector index(search): failed to open diskann index");
212481212523
goto out;
212482212524
}
212483-
if( vectorIdxKeyGet(pIndex->pTable, &pKey, &zErrMsg) != 0 ){
212525+
if( vectorIdxKeyGet(pIndex, &pKey, &zErrMsg) != 0 ){
212484212526
*pzErrMsg = sqlite3_mprintf("vector index(search): failed to extract table key: %s", zErrMsg);
212485212527
rc = SQLITE_ERROR;
212486212528
goto out;

0 commit comments

Comments
 (0)