
Commit 32b53ed

Update for the (hopefully stable!) llama.cpp changes.
1 parent c0faaef commit 32b53ed

File tree

embeddings/src/main.rs
llama-cpp-2/src/context.rs
llama-cpp-2/src/context/params.rs
llama-cpp-2/src/lib.rs
llama-cpp-2/src/llama_batch.rs
llama-cpp-2/src/model.rs

6 files changed (+66 lines, -23 lines)

embeddings/src/main.rs

Lines changed: 4 additions & 3 deletions
@@ -109,7 +109,7 @@ fn main() -> Result<()> {
     // initialize the context
     let ctx_params = LlamaContextParams::default()
         .with_n_threads_batch(std::thread::available_parallelism()?.get() as u32)
-        .with_embedding(true);
+        .with_embeddings(true);
 
     let mut ctx = model
         .new_context(&backend, ctx_params)
@@ -193,10 +193,9 @@ fn main() -> Result<()> {
 fn batch_decode(ctx: &mut LlamaContext, batch: &mut LlamaBatch, s_batch: i32, output: &mut Vec<Vec<f32>>, normalise: bool) -> Result<()> {
     ctx.clear_kv_cache();
     ctx.decode(batch).with_context(|| "llama_decode() failed")?;
-    batch.clear();
 
     for i in 0..s_batch {
-        let embedding = ctx.embeddings_ith(i).with_context(|| "Failed to get embeddings")?;
+        let embedding = ctx.embeddings_seq_ith(i).with_context(|| "Failed to get embeddings")?;
         let output_embeddings = if normalise {
             normalize(embedding)
         } else {
@@ -206,6 +205,8 @@ fn batch_decode(ctx: &mut LlamaContext, batch: &mut LlamaBatch, s_batch: i32, ou
         output.push(output_embeddings);
     }
 
+    batch.clear();
+
     Ok(())
 }
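
A sketch of how a caller might drive the updated `batch_decode` (the `prompts` values are hypothetical; `str_to_token` and `add_sequence` are the crate's existing helpers, used here as elsewhere in this example binary):

```rust
use llama_cpp_2::model::AddBos;

// Assumes `model`, `ctx`, and `batch` were set up as in `main` above.
let prompts = ["first sentence", "second sentence"];
let mut output: Vec<Vec<f32>> = Vec::new();

for (seq_id, prompt) in prompts.iter().enumerate() {
    // One sequence per prompt; `logits_all` stays false because only the
    // pooled sequence embeddings are read back after decoding.
    let tokens = model.str_to_token(prompt, AddBos::Always)?;
    batch.add_sequence(&tokens, seq_id as i32, false)?;
}

// Decode once, read one pooled embedding per sequence via
// `embeddings_seq_ith`, and only then clear the batch.
batch_decode(&mut ctx, &mut batch, prompts.len() as i32, &mut output, true)?;
```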

llama-cpp-2/src/context.rs

Lines changed: 46 additions & 7 deletions
@@ -5,12 +5,12 @@ use std::num::NonZeroI32;
 use std::ptr::NonNull;
 use std::slice;
 
-use crate::{DecodeError, EmbeddingsError};
 use crate::llama_batch::LlamaBatch;
 use crate::model::LlamaModel;
 use crate::timing::LlamaTimings;
 use crate::token::data::LlamaTokenData;
 use crate::token::LlamaToken;
+use crate::{DecodeError, EmbeddingsError};
 
 pub mod kv_cache;
 pub mod params;
@@ -92,17 +92,51 @@ impl<'model> LlamaContext<'model> {
     ///
     /// # Errors
     ///
-    /// When the current context was constructed without enabling embeddings.
+    /// - When the current context was constructed without enabling embeddings.
+    /// - If the current model had a pooling type of [`llama_cpp_sys_2::LLAMA_POOLING_TYPE_NONE`]
+    /// - If the given sequence index exceeds the max sequence id.
+    pub fn embeddings_seq_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
+        if !self.embeddings_enabled {
+            return Err(EmbeddingsError::NotEnabled);
+        }
+
+        unsafe {
+            let embedding = llama_cpp_sys_2::llama_get_embeddings_seq(self.context.as_ptr(), i);
+
+            // Technically also possible whenever `i >= max(batch.n_seq)`, but can't check that here.
+            if embedding.is_null() {
+                Err(EmbeddingsError::NonePoolType)
+            } else {
+                Ok(std::slice::from_raw_parts(embedding, self.model.n_embd() as usize))
+            }
+        }
+    }
+
+    /// Get the embeddings for the `i`th token in the current context.
+    ///
+    /// # Returns
+    ///
+    /// A slice containing the embeddings for the last decoded batch of the given token.
+    /// The size corresponds to the `n_embd` parameter of the context's model.
+    ///
+    /// # Errors
+    ///
+    /// - When the current context was constructed without enabling embeddings.
+    /// - When the given token didn't have logits enabled when it was passed.
+    /// - If the given token index exceeds the max token id.
     pub fn embeddings_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
         if !self.embeddings_enabled {
-            return Err(EmbeddingsError::NotEnabled)
+            return Err(EmbeddingsError::NotEnabled);
         }
 
         unsafe {
-            Ok(std::slice::from_raw_parts(
-                llama_cpp_sys_2::llama_get_embeddings_ith(self.context.as_ptr(), i),
-                self.model.n_embd() as usize,
-            ))
+            let embedding = llama_cpp_sys_2::llama_get_embeddings_ith(self.context.as_ptr(), i);
+            // Technically also possible whenever `i >= batch.n_tokens`, but no good way of checking `n_tokens` here.
+            if embedding.is_null() {
+                Err(EmbeddingsError::LogitsNotEnabled)
+            } else {
+                Ok(std::slice::from_raw_parts(embedding, self.model.n_embd() as usize))
+            }
         }
     }
 
@@ -155,6 +189,11 @@ impl<'model> LlamaContext<'model> {
         let timings = unsafe { llama_cpp_sys_2::llama_get_timings(self.context.as_ptr()) };
         LlamaTimings { timings }
     }
+
+    /// Returns a reference to the raw [llama_cpp_sys_2::llama_context] pointer.
+    pub fn raw_ctx(&self) -> &NonNull<llama_cpp_sys_2::llama_context> {
+        &self.context
+    }
 }
 
 impl Drop for LlamaContext<'_> {
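
The practical difference between the two getters: `embeddings_seq_ith` returns the pooled embedding for a whole sequence and fails with `NonePoolType` when the model only supports `LLAMA_POOLING_TYPE_NONE`, while `embeddings_ith` returns a single token's embedding and fails with `LogitsNotEnabled` if that token's logits weren't requested. A minimal sketch of handling both, assuming `ctx` has just decoded a batch and that `model` and a hypothetical `last_token_index` binding are in scope:

```rust
use llama_cpp_2::EmbeddingsError;

// Try the pooled sequence embedding first.
let embedding: Vec<f32> = match ctx.embeddings_seq_ith(0) {
    Ok(e) => e.to_vec(),
    // No pooling on this model: fall back to the last token's embedding,
    // which always has logits enabled after the llama_batch.rs change below.
    Err(EmbeddingsError::NonePoolType) => ctx.embeddings_ith(last_token_index)?.to_vec(),
    Err(e) => return Err(e.into()),
};
assert_eq!(embedding.len(), model.n_embd() as usize);
```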

llama-cpp-2/src/context/params.rs

Lines changed: 7 additions & 7 deletions
@@ -319,11 +319,11 @@ impl LlamaContextParams {
     ///
     /// ```rust
     /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
-    /// assert!(!params.embedding());
+    /// assert!(!params.embeddings());
     /// ```
     #[must_use]
-    pub fn embedding(&self) -> bool {
-        self.context_params.embedding
+    pub fn embeddings(&self) -> bool {
+        self.context_params.embeddings
     }
 
     /// Enable the use of embeddings
@@ -333,12 +333,12 @@
     /// ```rust
     /// use llama_cpp_2::context::params::LlamaContextParams;
     /// let params = LlamaContextParams::default()
-    ///     .with_embedding(true);
-    /// assert!(params.embedding());
+    ///     .with_embeddings(true);
+    /// assert!(params.embeddings());
     /// ```
     #[must_use]
-    pub fn with_embedding(mut self, embedding: bool) -> Self {
-        self.context_params.embedding = embedding;
+    pub fn with_embeddings(mut self, embedding: bool) -> Self {
+        self.context_params.embeddings = embedding;
         self
     }
 }

llama-cpp-2/src/lib.rs

Lines changed: 4 additions & 0 deletions
@@ -83,6 +83,10 @@ pub enum DecodeError {
 pub enum EmbeddingsError {
     #[error("Embeddings weren't enabled in the context options")]
     NotEnabled,
+    #[error("Logits were not enabled for the given token")]
+    LogitsNotEnabled,
+    #[error("Can't use sequence embeddings with a model supporting only LLAMA_POOLING_TYPE_NONE")]
+    NonePoolType,
 }
 
 /// Decode a error from llama.cpp into a [`DecodeError`].
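
The two new variants map one-to-one onto the null-return cases added in `context.rs` above. A sketch of exhaustive handling, e.g. for a diagnostic message (the `embeddings_hint` helper is hypothetical):

```rust
use llama_cpp_2::EmbeddingsError;

// Hypothetical helper translating each failure into a suggested fix.
fn embeddings_hint(err: &EmbeddingsError) -> &'static str {
    match err {
        EmbeddingsError::NotEnabled => "build the context with .with_embeddings(true)",
        EmbeddingsError::LogitsNotEnabled => "enable logits for this token when filling the batch",
        EmbeddingsError::NonePoolType => "use embeddings_ith(); this model has no pooled sequence embeddings",
    }
}
```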

llama-cpp-2/src/llama_batch.rs

Lines changed: 4 additions & 5 deletions
@@ -121,14 +121,13 @@ impl LlamaBatch {
             let seq_id_ptr = *self.llama_batch.seq_id.add(j);
             seq_id_ptr.write(seq_id);
             self.llama_batch.n_seq_id.add(j).write(1);
-            self.llama_batch.logits.add(j).write(logits_all as i8)
+
+            let write_logits = logits_all || i == n_tokens - 1;
+            self.llama_batch.logits.add(j).write(write_logits as i8)
         }
     }
 
-    unsafe {
-        self.llama_batch.logits.add(n_tokens - 1).write(true as i8);
-        self.initialized_logits.push(self.llama_batch.n_tokens - 1);
-    }
+    self.initialized_logits.push(self.llama_batch.n_tokens - 1);
 
     Ok(())
 }
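
The logits flag is now computed inside the single write loop (`logits_all || i == n_tokens - 1`) instead of being patched in by a separate `unsafe` block after the loop, so the last token of a sequence still always gets logits. For callers this means (sketch; `tokens` with more than one entry, `ctx`, and `batch` assumed set up as usual):

```rust
// With `logits_all == false`, only the sequence's final token gets logits
// (and therefore a per-token embedding) after decoding.
batch.add_sequence(&tokens, 0, false)?;
ctx.decode(&mut batch)?;

// The last token always has logits enabled, so this succeeds...
let last = tokens.len() as i32 - 1;
let _embedding = ctx.embeddings_ith(last)?;

// ...while earlier tokens hit the new null check in `embeddings_ith`
// and surface as `EmbeddingsError::LogitsNotEnabled`.
assert!(ctx.embeddings_ith(0).is_err());
```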

llama-cpp-2/src/model.rs

Lines changed: 1 addition & 1 deletion
@@ -320,7 +320,7 @@ impl LlamaModel {
         };
         let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;
 
-        Ok(LlamaContext::new(self, context, params.embedding()))
+        Ok(LlamaContext::new(self, context, params.embeddings()))
     }
 }
