@@ -50,6 +50,8 @@ SQLITE_EXTENSION_INIT1
 
 #define OPTION_KEY_GENERATE_EMBEDDING   "generate_embedding"
 #define OPTION_KEY_NORMALIZE_EMBEDDING  "normalize_embedding"
+#define OPTION_KEY_POOLING_TYPE         "pooling_type"
+#define OPTION_KEY_ATTENTION_TYPE       "attention_type"
 #define OPTION_KEY_MAX_TOKENS           "max_tokens"
 #define OPTION_KEY_JSON_OUTPUT          "json_output"
 #define OPTION_KEY_GPU_LAYERS           "gpu_layers"
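The two new option keys are parsed by llm_options_callback later in this diff. A minimal sketch of an options string using them follows; the key names and accepted values come from this patch, but the exact key=value separator syntax is an assumption:

    /* hypothetical options string; "mean" and "non_causal" are among the
       values accepted by llm_options_callback below */
    const char *options = "pooling_type=mean,attention_type=non_causal";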
@@ -70,16 +72,16 @@ typedef struct {
     bool log_info;                                  // flag to enable/disable the logging of info
 
     // ** CONTEXT **
-    uint32_t context_size;                          // set both n_ctx and n_batch
+    uint32_t context_size;                          // set both n_ctx and n_batch (*** DONE ***)
     uint32_t n_ctx;                                 // text context, 0 = from model
     uint32_t n_batch;                               // logical maximum batch size that can be submitted to llama_decode
     uint32_t n_ubatch;                              // physical maximum batch size
     uint32_t n_seq_max;                             // max number of sequences (i.e. distinct states for recurrent models)
     int32_t n_threads;                              // number of threads to use for generation
     int32_t n_threads_batch;                        // number of threads to use for batch processing
     enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
-    enum llama_pooling_type pooling_type;           // whether to pool (sum) embedding results by sequence id
-    enum llama_attention_type attention_type;       // attention type to use for embeddings
+    enum llama_pooling_type pooling_type;           // whether to pool (sum) embedding results by sequence id (*** DONE ***)
+    enum llama_attention_type attention_type;       // attention type to use for embeddings (*** DONE ***)
     float rope_freq_base;                           // RoPE base frequency, 0 = from model
     float rope_freq_scale;                          // RoPE frequency scaling factor, 0 = from model
     float yarn_ext_factor;                          // YaRN extrapolation mix factor, negative = from model
@@ -103,6 +105,7 @@ typedef struct {
     bool normalize_embedding;                       // if true, embeddings are normalized
     bool json_output;                               // if true, embedding result is converted to JSON
 
+
     // ** CUSTOM **
     int32_t max_tokens;                             // to control max allowed tokens to generate (to control user's input size)
 
@@ -132,9 +135,6 @@ typedef struct {
     // whisper
     struct whisper_context *whisper;
 
-    // embedding
-    llama_seq_id sequence_id;                       // some models requires to be unique across multiple calls to llm_embed_generate
-
     // chat
     struct {
         char uuid[UUID_STR_MAXLEN];
@@ -208,13 +208,21 @@ void llm_set_context_options (struct llama_context_params *llama_context, llm_op
     if (options->generate_embedding) {
         // https://github.com/ggml-org/llama.cpp/discussions/15093
         llama_context->embeddings = true;
-        llama_context->pooling_type = LLAMA_POOLING_TYPE_LAST;
+        llama_context->pooling_type = LLAMA_POOLING_TYPE_MEAN;
     }
 
     if (options->context_size) {
         llama_context->n_ctx = options->context_size;
         llama_context->n_batch = options->context_size;
     }
+
+    if (options->pooling_type != LLAMA_POOLING_TYPE_UNSPECIFIED) {
+        llama_context->pooling_type = options->pooling_type;
+    }
+
+    if (options->attention_type != LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
+        llama_context->attention_type = options->attention_type;
+    }
 }
 
 static void llm_options_init (llm_options *options) {
@@ -223,6 +231,8 @@ static void llm_options_init (llm_options *options) {
     options->normalize_embedding = true;
     options->max_tokens = 0;                        // no limits
     options->log_info = false;                      // disable INFO messages logging
+    options->pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED;
+    options->attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED;
 }
 
 static bool llm_options_callback (void *xdata, const char *key, int key_len, const char *value, int value_len) {
@@ -252,6 +262,19 @@ static bool llm_options_callback (void *xdata, const char *key, int key_len, con
         return true;
     }
 
+    if (strncasecmp(key, OPTION_KEY_POOLING_TYPE, key_len) == 0) {
+        if (strcasecmp(buffer, "none") == 0) options->pooling_type = LLAMA_POOLING_TYPE_NONE;
+        else if (strcasecmp(buffer, "mean") == 0) options->pooling_type = LLAMA_POOLING_TYPE_MEAN;
+        else if (strcasecmp(buffer, "cls") == 0) options->pooling_type = LLAMA_POOLING_TYPE_CLS;
+        else if (strcasecmp(buffer, "last") == 0) options->pooling_type = LLAMA_POOLING_TYPE_LAST;
+        else if (strcasecmp(buffer, "rank") == 0) options->pooling_type = LLAMA_POOLING_TYPE_RANK;
+    }
+
+    if (strncasecmp(key, OPTION_KEY_ATTENTION_TYPE, key_len) == 0) {
+        if (strcasecmp(buffer, "causal") == 0) options->attention_type = LLAMA_ATTENTION_TYPE_CAUSAL;
+        else if (strcasecmp(buffer, "non_causal") == 0) options->attention_type = LLAMA_ATTENTION_TYPE_NON_CAUSAL;
+    }
+
     if (strncasecmp(key, OPTION_KEY_MAX_TOKENS, key_len) == 0) {
         int value = (int)strtol(buffer, NULL, 0);
         if (value >= 0) options->max_tokens = value;
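A minimal sketch of how the new parsing branches behave. It assumes the callback's xdata is the llm_options struct and that buffer holds the null-terminated value; both are assumptions, since that plumbing sits outside this diff:

    llm_options opts;
    llm_options_init(&opts);                 /* pooling/attention start UNSPECIFIED */

    const char *key = "pooling_type", *val = "mean";
    llm_options_callback(&opts, key, (int)strlen(key), val, (int)strlen(val));
    /* opts.pooling_type == LLAMA_POOLING_TYPE_MEAN; an unrecognized value
       would silently leave it at LLAMA_POOLING_TYPE_UNSPECIFIED */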
@@ -590,34 +613,46 @@ static void llm_embed_generate_run (sqlite3_context *context, const char *text,
         return;
     }
 
+    llama_seq_id sequence_id = 0;
+    llama_memory_t memory = llama_get_memory(ctx);
+
+    // defensive cleanup before encoding, in case anything lingered from a previous call
+    if (memory) llama_memory_seq_rm(memory, sequence_id, 0, -1);
+
     // set up batch for processing
+    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
     llama_batch batch = llama_batch_init(n_tokens, 0, 1);
-    llama_seq_id sequence_id = ai->sequence_id;
     for (int i = 0; i < n_tokens; ++i) {
         batch.token[batch.n_tokens] = tokens[i];
         batch.pos[batch.n_tokens] = i;
         batch.n_seq_id[batch.n_tokens] = 1;
         batch.seq_id[batch.n_tokens][0] = sequence_id;
-        batch.logits[batch.n_tokens] = i == (n_tokens - 1);
+        batch.logits[batch.n_tokens] = (pooling_type == LLAMA_POOLING_TYPE_NONE) ? 1 : 0; // see comment below
         batch.n_tokens++;
     }
-    ai->sequence_id++;
+
+    // Marking a single token is valid only when pooling is enabled (MEAN / CLS / LAST):
+    // one "marked" token per sequence is enough to read the sequence-level embedding.
+    // If pooling_type == NONE and token-level embeddings are wanted, logits must be
+    // set to 1 for every token that will be read.
+
+    // last token of this sequence requests outputs
+    batch.logits[batch.n_tokens - 1] = 1;
 
     // run model (do real processing)
-    llama_memory_t memory = llama_get_memory(ctx);
+    // from ggerganov: if your application is going to support both models with and without a memory, then you should simply always call llama_decode()
+    // https://github.com/ggml-org/llama.cpp/discussions/14454
     int32_t rc = (memory) ? llama_decode(ctx, batch) : llama_encode(ctx, batch);
 
     if (rc < 0) {
         sqlite3_free(tokens);
         sqlite3_free(embedding);
         llama_batch_free(batch);
-        sqlite_context_result_error(context, SQLITE_ERROR, "Model decode failed during embedding generation (%d)", rc);
+        sqlite_context_result_error(context, SQLITE_ERROR, "Model %s failed during embedding generation (%d)", (memory) ? "decode" : "encode", rc);
         return;
     }
 
-    // retrieve embeddings (context set to LLAMA_POOLING_TYPE_LAST in llama_init_from_model)
+    // retrieve embeddings (context pooling defaults to LLAMA_POOLING_TYPE_MEAN unless overridden via options)
     const float *result = NULL;
-    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
     if (pooling_type == LLAMA_POOLING_TYPE_NONE) result = llama_get_embeddings(ctx);
     else result = llama_get_embeddings_seq(ctx, sequence_id);
 
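For the pooling_type == LLAMA_POOLING_TYPE_NONE branch above, a sketch of reading token-level embeddings (not part of the patch; it assumes every token was submitted with logits = 1, as the loop does in that case, and that model is the llama_model the context was created from):

    const int n_embd = llama_model_n_embd(model);
    for (int i = 0; i < n_tokens; ++i) {
        const float *tok = llama_get_embeddings_ith(ctx, i);
        if (tok == NULL) continue;           /* token i was not marked for output */
        /* tok points to n_embd floats: the embedding of token i */
    }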
@@ -632,9 +667,11 @@ static void llm_embed_generate_run (sqlite3_context *context, const char *text,
     // check if normalization is needed (default true)
     (ai->options.normalize_embedding) ? llm_embed_normalize(result, embedding, dimension) : memcpy(embedding, result, sizeof(float) * dimension);
 
+    // IMPORTANT: clear memory for this sequence so the next call starts clean
     if (memory) {
-        llama_memory_clear(memory, true);
+        // remove this sequence's tokens, then clear and free the memory buffers
         llama_memory_seq_rm(memory, sequence_id, 0, -1);
+        llama_memory_clear(memory, true);
     }
 
     // check if JSON output is set
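llm_embed_normalize is referenced above but defined outside this diff; a plausible reconstruction (an assumption, not the project's actual code) is a plain L2 normalization:

    #include <math.h>

    /* hypothetical reconstruction: copy src into dst scaled to unit L2 norm */
    static void llm_embed_normalize (const float *src, float *dst, int n) {
        double sum = 0.0;
        for (int i = 0; i < n; ++i) sum += (double)src[i] * src[i];
        const double norm = (sum > 0.0) ? sqrt(sum) : 1.0;
        for (int i = 0; i < n; ++i) dst[i] = (float)(src[i] / norm);
    }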