@@ -5,12 +5,12 @@ use std::num::NonZeroI32;
 use std::ptr::NonNull;
 use std::slice;
 
-use crate::{DecodeError, EmbeddingsError};
 use crate::llama_batch::LlamaBatch;
 use crate::model::LlamaModel;
 use crate::timing::LlamaTimings;
 use crate::token::data::LlamaTokenData;
 use crate::token::LlamaToken;
+use crate::{DecodeError, EmbeddingsError};
 
 pub mod kv_cache;
 pub mod params;
@@ -92,17 +92,51 @@ impl<'model> LlamaContext<'model> {
     ///
     /// # Errors
     ///
-    /// When the current context was constructed without enabling embeddings.
+    /// - When the current context was constructed without enabling embeddings.
+    /// - If the current model had a pooling type of [`llama_cpp_sys_2::LLAMA_POOLING_TYPE_NONE`].
+    /// - If the given sequence index exceeds the max sequence id.
+    pub fn embeddings_seq_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
+        if !self.embeddings_enabled {
+            return Err(EmbeddingsError::NotEnabled);
+        }
+
+        unsafe {
+            let embedding = llama_cpp_sys_2::llama_get_embeddings_seq(self.context.as_ptr(), i);
+
+            // Technically also possible whenever `i >= max(batch.n_seq)`, but we can't check that here.
+            if embedding.is_null() {
+                Err(EmbeddingsError::NonePoolType)
+            } else {
+                Ok(std::slice::from_raw_parts(embedding, self.model.n_embd() as usize))
+            }
+        }
+    }
+
+    /// Get the embeddings for the `i`th token in the current context.
+    ///
+    /// # Returns
+    ///
+    /// A slice containing the embeddings for the last decoded batch of the given token.
+    /// The size corresponds to the `n_embd` parameter of the context's model.
+    ///
+    /// # Errors
+    ///
+    /// - When the current context was constructed without enabling embeddings.
+    /// - When the given token didn't have logits enabled when it was passed.
+    /// - If the given token index exceeds the max token id.
     pub fn embeddings_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
         if !self.embeddings_enabled {
-            return Err(EmbeddingsError::NotEnabled)
+            return Err(EmbeddingsError::NotEnabled);
         }
 
         unsafe {
-            Ok(std::slice::from_raw_parts(
-                llama_cpp_sys_2::llama_get_embeddings_ith(self.context.as_ptr(), i),
-                self.model.n_embd() as usize,
-            ))
+            let embedding = llama_cpp_sys_2::llama_get_embeddings_ith(self.context.as_ptr(), i);
+            // Technically also possible whenever `i >= batch.n_tokens`, but there is no good way of checking `n_tokens` here.
+            if embedding.is_null() {
+                Err(EmbeddingsError::LogitsNotEnabled)
+            } else {
+                Ok(std::slice::from_raw_parts(embedding, self.model.n_embd() as usize))
+            }
         }
     }
 
@@ -155,6 +189,11 @@ impl<'model> LlamaContext<'model> {
         let timings = unsafe { llama_cpp_sys_2::llama_get_timings(self.context.as_ptr()) };
         LlamaTimings { timings }
     }
+
+    /// Returns a reference to the raw [`llama_cpp_sys_2::llama_context`] pointer.
+    pub fn raw_ctx(&self) -> &NonNull<llama_cpp_sys_2::llama_context> {
+        &self.context
+    }
 }
 
 impl Drop for LlamaContext<'_> {
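A quick usage note on the two accessors added above. This is a minimal sketch, not part of the commit: it assumes `LlamaContext` is reachable at `llama_cpp_2::context::LlamaContext`, that the context was built with embeddings enabled, that a batch with sequence id 0 has already been decoded, and that `EmbeddingsError` implements `Display`. The negative-index fallback mirrors llama.cpp's convention for `llama_get_embeddings_ith`, where `-1` refers to the last output token.

```rust
use llama_cpp_2::context::LlamaContext;

/// Hypothetical helper: pooled embedding for sequence 0, falling back to the
/// per-token accessor when the model has no pooling.
fn sequence_embedding(ctx: &LlamaContext<'_>) -> Option<Vec<f32>> {
    match ctx.embeddings_seq_ith(0) {
        // Pooled path: fails with NonePoolType when the model's pooling type
        // is LLAMA_POOLING_TYPE_NONE.
        Ok(embedding) => Some(embedding.to_vec()),
        Err(err) => {
            eprintln!("no pooled embedding ({err}), trying the last token");
            // Assumed fallback: embedding of the last token that had logits
            // enabled in the decoded batch (index -1, as in llama.cpp).
            ctx.embeddings_ith(-1).ok().map(|e| e.to_vec())
        }
    }
}
```

Both calls return slices borrowed from the context's internal buffers, so copying with `to_vec` keeps the result valid across the next decode.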
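The new `raw_ctx` escape hatch pairs naturally with `llama_cpp_sys_2` calls the safe wrapper doesn't cover yet. A sketch under assumptions: `llama_n_ctx` is a real llama.cpp function, but its exposure by `llama_cpp_sys_2` under this exact name and `u32` return type is assumed here, not confirmed by the diff.

```rust
/// Hypothetical helper: read the context window size straight from the
/// underlying llama.cpp context.
fn context_size(ctx: &llama_cpp_2::context::LlamaContext<'_>) -> u32 {
    // SAFETY: raw_ctx() yields a NonNull pointer owned by a live
    // LlamaContext, so it is valid for the duration of this call.
    unsafe { llama_cpp_sys_2::llama_n_ctx(ctx.raw_ctx().as_ptr()) }
}
```

Returning `&NonNull<...>` rather than a bare raw pointer keeps the borrow tied to the context's lifetime, which makes accidental use-after-free slightly harder.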