Commit adae11e

Several improvements to embedding generation
1 parent 363a0d3 commit adae11e

File tree: src/sqlite-ai.c, src/sqlite-ai.h

2 files changed: +53 -16 lines

src/sqlite-ai.c

Lines changed: 52 additions & 15 deletions
@@ -50,6 +50,8 @@ SQLITE_EXTENSION_INIT1
 
 #define OPTION_KEY_GENERATE_EMBEDDING   "generate_embedding"
 #define OPTION_KEY_NORMALIZE_EMBEDDING  "normalize_embedding"
+#define OPTION_KEY_POOLING_TYPE         "pooling_type"
+#define OPTION_KEY_ATTENTION_TYPE       "attention_type"
 #define OPTION_KEY_MAX_TOKENS           "max_tokens"
 #define OPTION_KEY_JSON_OUTPUT          "json_output"
 #define OPTION_KEY_GPU_LAYERS           "gpu_layers"
@@ -70,16 +72,16 @@ typedef struct {
     bool log_info;                                      // flag to enable/disable the logging of info
 
     // ** CONTEXT **
-    uint32_t context_size;                              // set both n_ctx and n_batch
+    uint32_t context_size;                              // set both n_ctx and n_batch (*** DONE ***)
     uint32_t n_ctx;                                     // text context, 0 = from model
     uint32_t n_batch;                                   // logical maximum batch size that can be submitted to llama_decode
     uint32_t n_ubatch;                                  // physical maximum batch size
     uint32_t n_seq_max;                                 // max number of sequences (i.e. distinct states for recurrent models)
     int32_t n_threads;                                  // number of threads to use for generation
     int32_t n_threads_batch;                            // number of threads to use for batch processing
     enum llama_rope_scaling_type rope_scaling_type;     // RoPE scaling type, from `enum llama_rope_scaling_type`
-    enum llama_pooling_type pooling_type;               // whether to pool (sum) embedding results by sequence id
-    enum llama_attention_type attention_type;           // attention type to use for embeddings
+    enum llama_pooling_type pooling_type;               // whether to pool (sum) embedding results by sequence id (*** DONE ***)
+    enum llama_attention_type attention_type;           // attention type to use for embeddings (*** DONE ***)
     float rope_freq_base;                               // RoPE base frequency, 0 = from model
     float rope_freq_scale;                              // RoPE frequency scaling factor, 0 = from model
     float yarn_ext_factor;                              // YaRN extrapolation mix factor, negative = from model
@@ -103,6 +105,7 @@ typedef struct {
     bool normalize_embedding;   // if true, embeddings are normalized
     bool json_output;           // if true, embedding result is converted to JSON
 
+
     // ** CUSTOM **
     int32_t max_tokens;         // to control max allowed tokens to generate (to control user's input size)
 
@@ -132,9 +135,6 @@ typedef struct {
     // whisper
     struct whisper_context *whisper;
 
-    // embedding
-    llama_seq_id sequence_id;   // some models requires to be unique across multiple calls to llm_embed_generate
-
     // chat
     struct {
         char uuid[UUID_STR_MAXLEN];
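
Note: the removed sequence_id field is replaced by a per-call pattern in llm_embed_generate_run (later in this diff): each call uses a fixed sequence id and clears the memory module around the work, so no state has to persist across calls. A minimal sketch of that pattern, assuming the llama.cpp memory API used elsewhere in this commit (embed_call_pattern is a hypothetical name, not extension code):

#include <llama.h>

// Hypothetical sketch: the shape of the new per-call pattern.
static void embed_call_pattern (struct llama_context *ctx) {
    llama_seq_id sequence_id = 0;                                   // fixed id, no cross-call state
    llama_memory_t memory = llama_get_memory(ctx);
    if (memory) llama_memory_seq_rm(memory, sequence_id, 0, -1);    // start clean
    // ... build the batch, call llama_decode()/llama_encode(), read embeddings ...
    if (memory) llama_memory_seq_rm(memory, sequence_id, 0, -1);    // leave clean
}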
@@ -208,13 +208,21 @@ void llm_set_context_options (struct llama_context_params *llama_context, llm_options *options) {
     if (options->generate_embedding) {
         // https://github.com/ggml-org/llama.cpp/discussions/15093
         llama_context->embeddings = true;
-        llama_context->pooling_type = LLAMA_POOLING_TYPE_LAST;
+        llama_context->pooling_type = LLAMA_POOLING_TYPE_MEAN;
     }
 
     if (options->context_size) {
         llama_context->n_ctx = options->context_size;
         llama_context->n_batch = options->context_size;
     }
+
+    if (options->pooling_type != LLAMA_POOLING_TYPE_UNSPECIFIED) {
+        llama_context->pooling_type = options->pooling_type;
+    }
+
+    if (options->attention_type != LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
+        llama_context->attention_type = options->attention_type;
+    }
 }
 
 static void llm_options_init (llm_options *options) {
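
Note: after this change the effective pooling type is resolved in two steps inside llm_set_context_options: generate_embedding defaults the context to MEAN pooling (previously LAST), and an explicit pooling_type option overrides that default. A standalone sketch of the precedence (resolve_pooling is a hypothetical helper, not extension code):

#include <llama.h>

// Hypothetical sketch of the precedence implemented above.
static enum llama_pooling_type resolve_pooling (bool generate_embedding,
                                                enum llama_pooling_type requested) {
    enum llama_pooling_type effective = LLAMA_POOLING_TYPE_UNSPECIFIED;      // model decides
    if (generate_embedding) effective = LLAMA_POOLING_TYPE_MEAN;             // new embedding default
    if (requested != LLAMA_POOLING_TYPE_UNSPECIFIED) effective = requested;  // explicit option wins
    return effective;
}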
@@ -223,6 +231,8 @@ static void llm_options_init (llm_options *options) {
     options->normalize_embedding = true;
     options->max_tokens = 0;        // no limits
     options->log_info = false;      // disable INFO messages logging
+    options->pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED;
+    options->attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED;
 }
 
 static bool llm_options_callback (void *xdata, const char *key, int key_len, const char *value, int value_len) {
@@ -252,6 +262,19 @@ static bool llm_options_callback (void *xdata, const char *key, int key_len, const char *value, int value_len) {
         return true;
     }
 
+    if (strncasecmp(key, OPTION_KEY_POOLING_TYPE, key_len) == 0) {
+        if (strcasecmp(buffer, "none") == 0) options->pooling_type = LLAMA_POOLING_TYPE_NONE;
+        else if (strcasecmp(buffer, "mean") == 0) options->pooling_type = LLAMA_POOLING_TYPE_MEAN;
+        else if (strcasecmp(buffer, "cls") == 0) options->pooling_type = LLAMA_POOLING_TYPE_CLS;
+        else if (strcasecmp(buffer, "last") == 0) options->pooling_type = LLAMA_POOLING_TYPE_LAST;
+        else if (strcasecmp(buffer, "rank") == 0) options->pooling_type = LLAMA_POOLING_TYPE_RANK;
+    }
+
+    if (strncasecmp(key, OPTION_KEY_ATTENTION_TYPE, key_len) == 0) {
+        if (strcasecmp(buffer, "causal") == 0) options->attention_type = LLAMA_ATTENTION_TYPE_CAUSAL;
+        else if (strcasecmp(buffer, "non_causal") == 0) options->attention_type = LLAMA_ATTENTION_TYPE_NON_CAUSAL;
+    }
+
     if (strncasecmp(key, OPTION_KEY_MAX_TOKENS, key_len) == 0) {
         int value = (int)strtol(buffer, NULL, 0);
         if (value >= 0) options->max_tokens = value;
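
Note: the callback matches the new keys case-insensitively; pooling_type accepts none/mean/cls/last/rank and attention_type accepts causal/non_causal. A self-contained sketch of feeding key/value pairs from a comma-separated options string (the splitting code is illustrative; how the extension actually tokenizes options is not shown in this diff):

#include <stdio.h>
#include <string.h>

// Illustrative only: split "key=value,key=value" pairs the way
// llm_options_callback expects to receive them.
int main (void) {
    char options[] = "generate_embedding=1,pooling_type=mean,attention_type=non_causal";
    char *save = NULL;
    for (char *pair = strtok_r(options, ",", &save); pair; pair = strtok_r(NULL, ",", &save)) {
        char *eq = strchr(pair, '=');
        if (!eq) continue;
        *eq = '\0';
        // the extension would call llm_options_callback(xdata, key, key_len, value, value_len) here
        printf("key=%s value=%s\n", pair, eq + 1);
    }
    return 0;
}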
@@ -590,34 +613,46 @@ static void llm_embed_generate_run (sqlite3_context *context, const char *text,
         return;
     }
 
+    llama_seq_id sequence_id = 0;
+    llama_memory_t memory = llama_get_memory(ctx);
+
+    // before encoding (defensive if anything lingered)
+    if (memory) llama_memory_seq_rm(memory, sequence_id, 0, -1);
+
     // set up batch for processing
+    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
     llama_batch batch = llama_batch_init(n_tokens, 0, 1);
-    llama_seq_id sequence_id = ai->sequence_id;
     for (int i = 0; i < n_tokens; ++i) {
         batch.token[batch.n_tokens] = tokens[i];
        batch.pos[batch.n_tokens] = i;
         batch.n_seq_id[batch.n_tokens] = 1;
         batch.seq_id[batch.n_tokens][0] = sequence_id;
-        batch.logits[batch.n_tokens] = i == (n_tokens - 1);
+        batch.logits[batch.n_tokens] = (pooling_type == LLAMA_POOLING_TYPE_NONE) ? 1 : 0; // see comment below
         batch.n_tokens++;
     }
-    ai->sequence_id++;
+
+    // This optimization is valid when pooling is enabled (MEAN / CLS / LAST):
+    // you only need one "marked" token per sequence to read the sequence-level embedding.
+    // If pooling_type == NONE and you want token-level embeddings, you must set logits=1 for every token you intend to read.
+
+    // last token of this sequence requests outputs
+    batch.logits[batch.n_tokens - 1] = 1;
 
     // run model (do real processing)
-    llama_memory_t memory = llama_get_memory(ctx);
+    // from ggerganov: if your application is going to support both models with and without a memory, then you should simply call llama_decode() always
+    // https://github.com/ggml-org/llama.cpp/discussions/14454
     int32_t rc = (memory) ? llama_decode(ctx, batch) : llama_encode(ctx, batch);
 
     if (rc < 0) {
         sqlite3_free(tokens);
         sqlite3_free(embedding);
         llama_batch_free(batch);
-        sqlite_context_result_error(context, SQLITE_ERROR, "Model decode failed during embedding generation (%d)", rc);
+        sqlite_context_result_error(context, SQLITE_ERROR, "Model %s failed during embedding generation (%d)", (memory) ? "decode" : "encode", rc);
         return;
     }
 
-    // retrieve embeddings (context set to LLAMA_POOLING_TYPE_LAST in llama_init_from_model)
+    // retrieve embeddings (context set to LLAMA_POOLING_TYPE_MEAN in llama_init_from_model)
     const float *result = NULL;
-    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
     if (pooling_type == LLAMA_POOLING_TYPE_NONE) result = llama_get_embeddings(ctx);
     else result = llama_get_embeddings_seq(ctx, sequence_id);
 
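Note: the logits flag now depends on the context's pooling mode. With pooling enabled (MEAN/CLS/LAST) only the single marked token per sequence is needed to read the pooled vector; with pooling disabled, every token whose embedding will be read must be marked. A minimal sketch of the matching retrieval calls from llama.cpp's public API (read_embedding is a hypothetical helper mirroring the logic above):

#include <llama.h>

// Hypothetical helper: pick the right retrieval call for the pooling mode.
static const float *read_embedding (struct llama_context *ctx, llama_seq_id seq) {
    if (llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_NONE) {
        // token-level output: valid only for tokens whose batch.logits flag was 1
        return llama_get_embeddings(ctx);           // n_tokens * n_embd floats
    }
    // sequence-level output: one pooled vector per sequence id
    return llama_get_embeddings_seq(ctx, seq);      // n_embd floats
}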

@@ -632,9 +667,11 @@ static void llm_embed_generate_run (sqlite3_context *context, const char *text,
     // check if normalization is needed (default true)
     (ai->options.normalize_embedding) ? llm_embed_normalize(result, embedding, dimension) : memcpy(embedding, result, sizeof(float) * dimension);
 
+    // IMPORTANT: clear memory for this sequence so the next call starts clean
     if (memory) {
-        llama_memory_clear(memory, true);
+        // remove tokens in this sequence and optionally compact
         llama_memory_seq_rm(memory, sequence_id, 0, -1);
+        llama_memory_clear(memory, true);
     }
 
     // check if JSON output is set
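
Note: llm_embed_normalize is not part of this diff; assuming it performs the usual L2 normalization (so cosine similarity between stored embeddings reduces to a dot product), a minimal sketch of what such a helper would look like:

#include <math.h>
#include <string.h>

// Assumed behavior of llm_embed_normalize (its body is not in this diff):
// copy src into dst scaled to unit L2 norm; plain copy on a zero vector.
static void embed_normalize_sketch (const float *src, float *dst, int dimension) {
    double norm = 0.0;
    for (int i = 0; i < dimension; ++i) norm += (double)src[i] * (double)src[i];
    norm = sqrt(norm);
    if (norm == 0.0) { memcpy(dst, src, sizeof(float) * (size_t)dimension); return; }
    for (int i = 0; i < dimension; ++i) dst[i] = (float)(src[i] / norm);
}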

src/sqlite-ai.h

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@
 extern "C" {
 #endif
 
-#define SQLITE_AI_VERSION "0.5.8"
+#define SQLITE_AI_VERSION "0.5.9"
 
 SQLITE_AI_API int sqlite3_ai_init (sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi);
 
