
Commit 83ede9f

Improved embedding and chunking
1 parent 932427d commit 83ede9f

File tree: 3 files changed (+225 −48 lines)


notes/EMBEDDING.md

Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
## Are there cases where llama.cpp can produce more than one embedding from a single text input?

In **llama.cpp**, whether you get **one embedding** or **multiple embeddings** from a text input depends on:

1. **Pooling type (`llama_pooling_type`)**

   * `LLAMA_POOLING_TYPE_NONE` → no pooling is applied.
     * You get an embedding **per token** (so length = number of tokens).
     * You retrieve it with `llama_get_embeddings(ctx)` after `llama_encode()` or `llama_decode()`.
   * `LLAMA_POOLING_TYPE_MEAN`, `LLAMA_POOLING_TYPE_CLS`, `LLAMA_POOLING_TYPE_LAST` → pooling is applied.
     * You get **one embedding per sequence ID**.
     * You retrieve it with `llama_get_embeddings_seq(ctx, seq_id)`.

2. **Number of sequence IDs (`seq_id`) in the batch**

   * Each sequence in the batch (distinguished by `seq_id`) can produce its own pooled embedding if pooling is enabled.
   * Example: if you feed 3 sentences in parallel with `seq_id = 0,1,2` and pooling = MEAN, you’ll get **3 embeddings** (see the sketch after this list).

3. **How you mark logits/outputs in the batch**

   * With pooling enabled, you only need to request an output on one token per sequence (often the last token).
   * Without pooling, you must request logits for every token you want embeddings for.
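For reference, here is a minimal sketch of that multi-sequence case. It is not code from this repo: `tokens[s]` / `n_tokens[s]` are assumed to hold the tokenization of each text, the batch capacity (512) is assumed to cover all tokens, and error handling is omitted.

```c
#include "llama.h"

// Sketch: embed three texts in one batch with MEAN pooling -> one vector per seq_id.
void embed_three_texts(struct llama_model *model,
                       llama_token *tokens[3], const int n_tokens[3]) {
    struct llama_context_params cparams = llama_context_default_params();
    cparams.embeddings   = true;                         // ask the context to output embeddings
    cparams.pooling_type = LLAMA_POOLING_TYPE_MEAN;
    struct llama_context *ctx = llama_new_context_with_model(model, cparams);

    struct llama_batch batch = llama_batch_init(512, 0, 3);
    for (int s = 0; s < 3; s++) {
        for (int i = 0; i < n_tokens[s]; i++) {
            batch.token   [batch.n_tokens]    = tokens[s][i];
            batch.pos     [batch.n_tokens]    = i;                      // positions restart per sequence
            batch.n_seq_id[batch.n_tokens]    = 1;
            batch.seq_id  [batch.n_tokens][0] = s;                      // seq_id = 0, 1, 2
            batch.logits  [batch.n_tokens]    = (i == n_tokens[s] - 1); // one marked token per sequence
            batch.n_tokens++;
        }
    }

    llama_decode(ctx, batch);                             // or llama_encode() if the model has no memory

    const int dim = llama_model_n_embd(model);
    for (int s = 0; s < 3; s++) {
        const float *emb = llama_get_embeddings_seq(ctx, s); // one pooled vector of `dim` floats per seq_id
        (void)emb; (void)dim;                                // use the vectors here
    }

    llama_batch_free(batch);
    llama_free(ctx);
}
```

Each `seq_id` gets its own pooled vector, so a single `llama_decode()` call can embed several texts at once.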

---

So, to directly answer:

* **Single text input, pooling enabled** → **1 embedding** (sequence-level).
* **Single text input, pooling disabled** → **N embeddings** (token-level, one per token).
* **Multiple texts in one batch** → **N embeddings** (one per sequence if pooling, or token-level if not).

---

In your embedding use case (sounds like you’re doing what `llama.cpp` calls **sentence embedding**), you probably want **pooling enabled** (MEAN or CLS) and therefore **one embedding per text input**.

## Is there a way to force one embedding (sentence embedding)?

Yes, in **llama.cpp** you can force sentence-level embeddings by enabling a **pooling mode** in the context parameters. That way you’ll always get exactly **one embedding per sequence ID**, regardless of how many tokens the text expands into.

---

### How to do it

When you create the context (`llama_context_params`), set:

```c
struct llama_context_params cparams = llama_context_default_params();

// enable embedding output and choose your pooling mode:
cparams.embeddings   = true;
cparams.pooling_type = LLAMA_POOLING_TYPE_MEAN; // common for sentence embeddings
// or: LLAMA_POOLING_TYPE_CLS
// or: LLAMA_POOLING_TYPE_LAST

struct llama_context *ctx = llama_new_context_with_model(model, cparams);
```

Then, when you process text with `llama_encode()` or `llama_decode()`:

* Use a single `seq_id` for that text (e.g. `seq_id = 0`).
* After the call, retrieve the **sequence embedding**:

```c
const float *embedding = llama_get_embeddings_seq(ctx, seq_id);
// embedding size = llama_n_embd(model)
```

That’s it: you’ll get **one embedding vector** for the entire input string.
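Putting the two snippets above together, a minimal end-to-end flow for a single text might look like the sketch below. This is not the extension's code: it assumes `ctx` was created with a non-NONE pooling mode as shown above, and it skips error handling.

```c
#include <stdlib.h>
#include "llama.h"

// Sketch: single text, single seq_id (0), pooled (sentence) embedding.
const float *embed_one_text(struct llama_model *model, struct llama_context *ctx,
                            const char *text, int32_t text_len) {
    const struct llama_vocab *vocab = llama_model_get_vocab(model);

    // first call with a NULL buffer returns the negative of the required token count
    int32_t n_tokens = -llama_tokenize(vocab, text, text_len, NULL, 0, true, true);
    llama_token *tokens = malloc(n_tokens * sizeof(llama_token));
    llama_tokenize(vocab, text, text_len, tokens, n_tokens, true, true);

    struct llama_batch batch = llama_batch_init(n_tokens, 0, 1);
    for (int32_t i = 0; i < n_tokens; i++) {
        batch.token   [batch.n_tokens]    = tokens[i];
        batch.pos     [batch.n_tokens]    = i;
        batch.n_seq_id[batch.n_tokens]    = 1;
        batch.seq_id  [batch.n_tokens][0] = 0;
        batch.logits  [batch.n_tokens]    = (i == n_tokens - 1); // with pooling, one marked token is enough
        batch.n_tokens++;
    }
    llama_decode(ctx, batch);

    const float *embedding = llama_get_embeddings_seq(ctx, 0);   // llama_n_embd(model) floats, owned by ctx
    llama_batch_free(batch);
    free(tokens);
    return embedding;
}
```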

---

### Contrast with token-level embeddings

* If `pooling_type == LLAMA_POOLING_TYPE_NONE` → `llama_get_embeddings(ctx)` gives you **token-level embeddings** (matrix: tokens × dim).
* If `pooling_type != NONE` → `llama_get_embeddings_seq(ctx, seq_id)` gives you **sentence-level embeddings** (one vector per sequence).

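As an illustration (again a sketch, not code from this repo): with pooling disabled and every token marked with `logits = 1`, the token-level vectors can be read back like this:

```c
// Sketch: assumes the context was created with
// cparams.pooling_type = LLAMA_POOLING_TYPE_NONE and that all
// n_tokens tokens in the batch requested an output (logits = 1).
const int dim = llama_model_n_embd(model);
const float *all = llama_get_embeddings(ctx);        // n_tokens rows of `dim` floats
for (int32_t i = 0; i < n_tokens; i++) {
    const float *tok_emb = all + (size_t)i * dim;     // embedding of token i
    // equivalently: llama_get_embeddings_ith(ctx, i)
    (void)tok_emb;
}
```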

---

**So the way to force one embedding per text is:**

* Set `pooling_type` in the context params (`MEAN`, `CLS`, or `LAST`).
* Use `llama_get_embeddings_seq()` instead of `llama_get_embeddings()`.

---

## Is the pooling_type related to the model? Can I be sure that all models support all the pooling types?

The `pooling_type` in **llama.cpp** is **not a property of the model itself**; it's purely how **llama.cpp** aggregates the **token embeddings** it already computes. All decoder-only LLaMA-style models output a hidden state per token, so pooling is just a post-processing step that the library offers.

---

### What this means

* Any model you load with **llama.cpp** produces **per-token embeddings**.
* Pooling (`MEAN`, `CLS`, `LAST`) is applied on top of those token embeddings to compress them into **one vector**.
* So yes: **all models supported by llama.cpp can use all pooling types**.
* The only requirement is that you create the context with `cparams.pooling_type` set, so llama.cpp knows to allocate memory for pooled embeddings.
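If you want to double-check what a context actually ended up with, you can query it at runtime; a small sketch (`llama_pooling_type()` is the same call used in `llm_embed_generate_run` below):

```c
// Sketch: verify that pooled (sentence-level) embeddings are available.
enum llama_pooling_type pt = llama_pooling_type(ctx);
if (pt == LLAMA_POOLING_TYPE_NONE) {
    // only token-level embeddings in this context;
    // llama_get_embeddings_seq() would return NULL here
}
```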

---

### Differences between pooling modes

* **MEAN** → average of all token embeddings in the sequence.
  * Default / most common for sentence embeddings & semantic search.
* **CLS** → use the embedding of the first token (analogous to `[CLS]` in BERT-style models).
  * Works, but LLaMA models weren’t trained with a `[CLS]` objective, so performance may vary.
* **LAST** → use the embedding of the final token.
  * Sometimes useful for autoregressive tasks; not typically the best for embeddings.
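For intuition, MEAN pooling is nothing more than an element-wise average of the per-token vectors. A naive sketch of what the library computes internally when `pooling_type = LLAMA_POOLING_TYPE_MEAN` (ignoring any normalization) would be:

```c
// Illustration only: element-wise average over n_tokens vectors of size dim.
static void mean_pool(const float *tok_embd, int n_tokens, int dim, float *out) {
    for (int d = 0; d < dim; d++) out[d] = 0.0f;
    for (int i = 0; i < n_tokens; i++)
        for (int d = 0; d < dim; d++)
            out[d] += tok_embd[(size_t)i * dim + d];
    for (int d = 0; d < dim; d++) out[d] /= (float)n_tokens;
}
```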

---

### Important

The quality of the **sentence embedding** you get *does* depend on the **model training objective**:

* Models explicitly trained for embedding tasks (e.g. `all-MiniLM-L6-v2`, `text-embedding-ada-002`, `nomic-embed`) will usually outperform a raw LLaMA model, even if you apply pooling.
* LLaMA-style models with pooling still give you usable vectors, but they weren’t optimized for semantic similarity.

---

So in summary:

* Pooling is **always available** in llama.cpp.
* All models supported by llama.cpp can use **MEAN / CLS / LAST** pooling.
* The *choice of pooling* affects embedding quality, but you won’t get an error from the library.

---

src/sqlite-ai.c

Lines changed: 92 additions & 47 deletions
@@ -95,7 +95,7 @@ SQLITE_EXTENSION_INIT1
 #define AI_COLUMN_REPLY 0
 
 #define AI_DEFAULT_MODEL_OPTIONS "gpu_layers=99"
-#define AI_DEFAULT_CONTEXT_EMBEDDING_OPTIONS "generate_embedding=1,normalize_embedding=1,pooling_type=last"
+#define AI_DEFAULT_CONTEXT_EMBEDDING_OPTIONS "generate_embedding=1,normalize_embedding=1,pooling_type=mean"
 #define AI_DEFAULT_CONTEXT_CHAT_OPTIONS "context_size=4096"
 #define AI_DEFAULT_CONTEXT_TEXTGEN_OPTIONS "context_size=4096"
 
@@ -322,7 +322,11 @@ static bool llm_context_options_callback (void *ctx, void *xdata, const char *ke
 // https://github.com/ggml-org/llama.cpp/discussions/15093
 int value = (int)strtol(buffer, NULL, 0);
 options->embeddings = (value != 0);
-options->pooling_type = LLAMA_POOLING_TYPE_LAST;
+options->pooling_type = LLAMA_POOLING_TYPE_MEAN;
+
+// for non-causal models, batch size must be equal to ubatch size
+// when generating embeddings, always tie them together.
+options->n_ubatch = options->n_batch;
 return true;
 }
 
@@ -372,7 +376,8 @@ static bool llm_context_options_callback (void *ctx, void *xdata, const char *ke
 }
 
 if (strncasecmp(key, OPTION_KEY_POOLING_TYPE, key_len) == 0) {
-if (strcasecmp(buffer, "none") == 0) options->pooling_type = LLAMA_POOLING_TYPE_NONE;
+if (strcasecmp(buffer, "none") == 0) options->pooling_type = LLAMA_POOLING_TYPE_MEAN;
+// pooling_type "none" is not supported in this version: it is forced to MEAN so that exactly ONE embedding is generated
 else if (strcasecmp(buffer, "mean") == 0) options->pooling_type = LLAMA_POOLING_TYPE_MEAN;
 else if (strcasecmp(buffer, "cls") == 0) options->pooling_type = LLAMA_POOLING_TYPE_CLS;
 else if (strcasecmp(buffer, "last") == 0) options->pooling_type = LLAMA_POOLING_TYPE_LAST;
@@ -701,6 +706,23 @@ static void llm_embed_normalize (const float *src, float *dest, int dim) {
 }
 }
 
+static void llm_batch_clear (struct llama_batch *batch) {
+batch->n_tokens = 0;
+}
+
+static void llm_batch_add (struct llama_batch *batch, llama_token id, llama_pos pos, const llama_seq_id *seq_ids, size_t n_seq_ids, bool logits) {
+batch->token [batch->n_tokens] = id;
+batch->pos [batch->n_tokens] = pos;
+batch->n_seq_id[batch->n_tokens] = (int32_t)n_seq_ids;
+
+for (size_t i = 0; i < n_seq_ids; ++i) {
+batch->seq_id[batch->n_tokens][i] = seq_ids[i];
+}
+
+batch->logits[batch->n_tokens] = logits ? 1 : 0;
+batch->n_tokens++;
+}
+
 static void llm_embed_generate_run (sqlite3_context *context, const char *text, int32_t text_len) {
 ai_context *ai = (ai_context *)sqlite3_user_data(context);
 struct llama_model *model = ai->model;
@@ -724,10 +746,12 @@ static void llm_embed_generate_run (sqlite3_context *context, const char *text,
 return;
 }
 
+// pooling is NOT NONE -> one sentence-level embedding
+// more details in notes/EMBEDDING.md
 struct llama_context *ctx = ai->ctx;
 llama_set_embeddings(ctx, true);
 
-// sanity check tokens
+// sanity check context / training window info (warn only)
 const int n_ctx_train = llama_model_n_ctx_train(model);
 const int n_ctx = llama_n_ctx(ctx);
 if (n_ctx > n_ctx_train) {
@@ -736,15 +760,15 @@ static void llm_embed_generate_run (sqlite3_context *context, const char *text,
 ai_logger(GGML_LOG_LEVEL_WARN, buffer, sqlite3_context_db_handle(context));
 }
 
-// sanity check embedding memory
+// allocate embedding buffer
 int dimension = llama_model_n_embd(llama_get_model(ctx));
 float *embedding = (float *)sqlite3_malloc64(sizeof(float) * dimension);
 if (!embedding) {
 sqlite_context_result_error(context, SQLITE_NOMEM, "Out of memory: failed to allocate embedding buffer of dimension %d", dimension);
 return;
 }
 
-// get token count
+// get token count (negative return encodes needed size)
 int32_t n_tokens = -llama_tokenize(vocab, text, text_len, NULL, 0, true, true);
 if (n_tokens == 0) {
 sqlite3_free(embedding);
@@ -757,7 +781,14 @@ static void llm_embed_generate_run (sqlite3_context *context, const char *text,
 return;
 }
 
-// allocate memory for tokens
+// even with chunking, decoder embeddings need the full sequence to be in the KV once
+if (n_tokens > n_ctx) {
+sqlite3_free(embedding);
+sqlite_context_result_error(context, SQLITE_TOOBIG, "Input too large for model context: %d tokens > n_ctx %d. Create a context with an n_ctx value of at least %d.", n_tokens, n_ctx, n_tokens);
+return;
+}
+
+// allocate tokens and tokenize
 llama_token *tokens = sqlite3_malloc64(n_tokens * sizeof(llama_token));
 if (!tokens) {
 sqlite3_free(embedding);
@@ -774,49 +805,57 @@ static void llm_embed_generate_run (sqlite3_context *context, const char *text,
 return;
 }
 
+// max batch size
+uint32_t n_batch = llama_n_batch(ctx);
+
+size_t pos_base = 0; // running position across chunks
 llama_seq_id sequence_id = 0;
 llama_memory_t memory = llama_get_memory(ctx);
 
-// before encoding (defensive if anything lingered)
-if (memory) llama_memory_seq_rm(memory, sequence_id, 0, -1);
+if (memory) {
+// start from a clean slate for this sequence
+llama_memory_seq_rm(memory, sequence_id, 0, -1);
+
+// fresh KV for this prompt (only once!)
+llama_memory_clear(memory, /*clear_kv_cache_only=*/true);
+}
 
-// set up batch for processing
+// LLAMA_POOLING_TYPE_NONE is disabled in this version
 const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
-llama_batch batch = llama_batch_init(n_tokens, 0, 1);
-for (int i = 0; i < n_tokens; ++i) {
-batch.token[batch.n_tokens] = tokens[i];
-batch.pos[batch.n_tokens] = i;
-batch.n_seq_id[batch.n_tokens] = 1;
-batch.seq_id[batch.n_tokens][0] = sequence_id;
-batch.logits[batch.n_tokens] = (pooling_type == LLAMA_POOLING_TYPE_NONE) ? 1 : 0; // see comment below
-batch.n_tokens++;
-}
-
-// This optimization is valid when pooling is enabled (MEAN / CLS / LAST).
-// You only need one “marked” token per sequence to read the sequence-level embedding.
-// If pooling_type == NONE and you want token-level embeddings, you must set logits=1 for every token you intend to read.
-
-// last token of this sequence requests outputs
-batch.logits[batch.n_tokens - 1] = 1;
-
-// run model (do real processing)
-// from ggerganov: If your application is going to support both models with and without a memory, then you should simply call llama_decode() always
-// https://github.com/ggml-org/llama.cpp/discussions/14454
-int32_t rc = (memory) ? llama_decode(ctx, batch) : llama_encode(ctx, batch);
-
-if (rc < 0) {
-sqlite3_free(tokens);
-sqlite3_free(embedding);
-llama_batch_free(batch);
-sqlite_context_result_error(context, SQLITE_ERROR, "Model %s failed during embedding generation (%d)", rc, (memory) ? "decode" : "encode");
-return;
-}
+GGML_ASSERT(pooling_type != LLAMA_POOLING_TYPE_NONE);
 
-// retrieve embeddings (context set to LLAMA_POOLING_TYPE_MEAN in llama_init_from_model)
-const float *result = NULL;
-if (pooling_type == LLAMA_POOLING_TYPE_NONE) result = llama_get_embeddings(ctx);
-else result = llama_get_embeddings_seq(ctx, sequence_id);
+// init batch: n_seq_max = 1 (single prompt), embd = 0
+llama_batch batch = llama_batch_init(n_batch, 0, 1);
+while (pos_base < n_tokens) {
+llm_batch_clear(&batch);
+
+size_t to_feed = (n_tokens - pos_base > n_batch) ? n_batch : (n_tokens - pos_base);
+
+// fill the batch with up to n_batch tokens
+for (size_t i = 0; i < to_feed; ++i) {
+const llama_token tk = (llama_token)tokens[pos_base + i];
+const llama_pos ps = (llama_pos)(pos_base + i);
+const bool want_logits = (i + 1 == to_feed); // last token in this chunk
+llm_batch_add(&batch, tk, ps, &sequence_id, 1, want_logits);
+}
+
+// run model on this chunk
+// from ggerganov: If your application is going to support both models with and without a memory, then you should simply call llama_decode() always
+// https://github.com/ggml-org/llama.cpp/discussions/14454
+int32_t rc = (memory) ? llama_decode(ctx, batch) : llama_encode(ctx, batch);
+if (rc < 0) {
+sqlite3_free(tokens);
+sqlite3_free(embedding);
+llama_batch_free(batch);
+sqlite_context_result_error(context, SQLITE_ERROR, "Model %s failed during embedding generation (%d)", (memory) ? "decode" : "encode", rc);
+return;
+}
+
+pos_base += to_feed;
+}
 
+// retrieve sentence embedding (pooling is enabled)
+const float *result = llama_get_embeddings_seq(ctx, sequence_id);
 if (result == NULL) {
 sqlite3_free(tokens);
 sqlite3_free(embedding);
@@ -840,7 +879,7 @@ static void llm_embed_generate_run (sqlite3_context *context, const char *text,
 sqlite3_str *s = sqlite3_str_new(sqlite3_context_db_handle(context));
 sqlite3_str_appendchar(s, 1, '[');
 for (int i = 0; i < dimension; i++) {
-if (i != 0) sqlite3_str_appendchar(s, 1, ',');
+if (i) sqlite3_str_appendchar(s, 1, ',');
 sqlite3_str_appendf(s, "%.6g", embedding[i]);
 }
 sqlite3_str_appendchar(s, 1, ']');
@@ -849,7 +888,7 @@ static void llm_embed_generate_run (sqlite3_context *context, const char *text,
 (json) ? sqlite3_result_text(context, json, -1, sqlite3_free) : sqlite3_result_null(context);
 sqlite3_free(embedding);
 } else {
-sqlite3_result_blob(context, embedding, sizeof(float) * dimension, sqlite3_free);
+sqlite3_result_blob(context, embedding, (int)sizeof(float) * dimension, sqlite3_free);
 }
 
 sqlite3_free(tokens);
@@ -864,25 +903,31 @@ static void llm_embed_generate (sqlite3_context *context, int argc, sqlite3_valu
 int32_t text_len = (int32_t)sqlite3_value_bytes(argv[0]);
 const char *model_options = (argc == 2) ? (const char *)sqlite3_value_text(argv[1]) : NULL;
 
+// handle NULL input
+if (!text || text_len == 0) {
+sqlite3_result_null(context);
+return;
+}
+
 // passing NULL as xdata because context has been already created
 ai_context *ai = (ai_context *)sqlite3_user_data(context);
 if (parse_keyvalue_string(ai, model_options, llm_context_options_callback, NULL) == false) return;
 
-if (!text || text_len == 0) return;
+// real processing
 llm_embed_generate_run(context, text, text_len);
 }
 
 static void llm_token_count (sqlite3_context *context, int argc, sqlite3_value **argv) {
 // sanity check args and context
 if (llm_check_context(context) == false) return;
 if (llm_common_args_check(context, "llm_token_count", argc, argv, true) == false) return;
-ai_context *ai = (ai_context *)sqlite3_user_data(context);
 
 const char *text = (const char *)sqlite3_value_text(argv[0]);
 int32_t text_len = (int32_t)sqlite3_value_bytes(argv[0]);
 if (!text || text_len == 0) return;
 
 // sanity check vocab
+ai_context *ai = (ai_context *)sqlite3_user_data(context);
 const struct llama_vocab *vocab = llama_model_get_vocab(ai->model);
 if (!vocab) {
 sqlite_context_result_error(context, SQLITE_ERROR, "Failed to extract vocabulary from the model");

src/sqlite-ai.h

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@
 extern "C" {
 #endif
 
-#define SQLITE_AI_VERSION "0.6.5"
+#define SQLITE_AI_VERSION "0.6.7"
 
 SQLITE_AI_API int sqlite3_ai_init (sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi);
 