
Commit 542a410

Add functionality for creating embeddings
1 parent e74a19e commit 542a410

6 files changed, +174 -22 lines changed

llama-cpp-2/src/context.rs

Lines changed: 29 additions & 3 deletions
@@ -2,15 +2,15 @@
 
 use std::fmt::{Debug, Formatter};
 use std::num::NonZeroI32;
+use std::ptr::NonNull;
+use std::slice;
 
+use crate::{DecodeError, EmbeddingsError};
 use crate::llama_batch::LlamaBatch;
 use crate::model::LlamaModel;
 use crate::timing::LlamaTimings;
 use crate::token::data::LlamaTokenData;
 use crate::token::LlamaToken;
-use crate::DecodeError;
-use std::ptr::NonNull;
-use std::slice;
 
 pub mod kv_cache;
 pub mod params;
@@ -24,6 +24,7 @@ pub struct LlamaContext<'a> {
     /// a reference to the contexts model.
     pub model: &'a LlamaModel,
     initialized_logits: Vec<i32>,
+    embeddings_enabled: bool,
 }
 
 impl Debug for LlamaContext<'_> {
@@ -38,11 +39,13 @@ impl<'model> LlamaContext<'model> {
     pub(crate) fn new(
         llama_model: &'model LlamaModel,
         llama_context: NonNull<llama_cpp_sys_2::llama_context>,
+        embeddings_enabled: bool,
     ) -> Self {
         Self {
             context: llama_context,
             model: llama_model,
             initialized_logits: Vec::new(),
+            embeddings_enabled,
         }
     }
 
@@ -80,6 +83,29 @@ impl<'model> LlamaContext<'model> {
         }
     }
 
+    /// Get the embeddings for the `i`th sequence in the current context.
+    ///
+    /// # Returns
+    ///
+    /// A slice containing the embeddings for the last decoded batch.
+    /// The size corresponds to the `n_embd` parameter of the context's model.
+    ///
+    /// # Errors
+    ///
+    /// When the current context was constructed without enabling embeddings.
+    pub fn embeddings_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
+        if !self.embeddings_enabled {
+            return Err(EmbeddingsError::NotEnabled)
+        }
+
+        unsafe {
+            Ok(std::slice::from_raw_parts(
+                llama_cpp_sys_2::llama_get_embeddings_ith(self.context.as_ptr(), i),
+                self.model.n_embd() as usize,
+            ))
+        }
+    }
+
     /// Get the logits for the ith token in the context.
     ///
     /// # Panics
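
Example (not part of the commit): a minimal sketch of reading embeddings through the new accessor, assuming `ctx` is a `LlamaContext` created from params with `with_embedding(true)` and that a batch has already been decoded. Only `embeddings_ith` and `EmbeddingsError` come from this diff.

    fn print_first_embedding(ctx: &llama_cpp_2::context::LlamaContext<'_>) {
        match ctx.embeddings_ith(0) {
            // The returned slice has `n_embd` elements for the requested token.
            Ok(embedding) => println!("first component: {}", embedding[0]),
            // Returned when the context was built without `with_embedding(true)`.
            Err(llama_cpp_2::EmbeddingsError::NotEnabled) => eprintln!("embeddings not enabled"),
        }
    }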

llama-cpp-2/src/context/params.rs

Lines changed: 60 additions & 1 deletion
@@ -1,8 +1,9 @@
 //! A safe wrapper around `llama_context_params`.
-use llama_cpp_sys_2;
 use std::fmt::Debug;
 use std::num::NonZeroU32;
 
+use llama_cpp_sys_2;
+
 /// A rusty wrapper around `rope_scaling_type`.
 #[repr(i8)]
 #[derive(Copy, Clone, Debug, PartialEq, Eq)]
@@ -267,6 +268,19 @@ impl LlamaContextParams {
         self.context_params.n_threads
     }
 
+    /// Get the number of threads allocated for batches.
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
+    /// assert_eq!(params.n_threads_batch(), 4);
+    /// ```
+    #[must_use]
+    pub fn n_threads_batch(&self) -> u32 {
+        self.context_params.n_threads_batch
+    }
+
     /// Set the number of threads.
     ///
     /// # Examples
@@ -282,6 +296,51 @@ impl LlamaContextParams {
         self.context_params.n_threads = n_threads;
         self
     }
+
+    /// Set the number of threads allocated for batches.
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// use llama_cpp_2::context::params::LlamaContextParams;
+    /// let params = LlamaContextParams::default()
+    ///     .with_n_threads_batch(8);
+    /// assert_eq!(params.n_threads_batch(), 8);
+    /// ```
+    #[must_use]
+    pub fn with_n_threads_batch(mut self, n_threads: u32) -> Self {
+        self.context_params.n_threads_batch = n_threads;
+        self
+    }
+
+    /// Check whether embeddings are enabled
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
+    /// assert!(!params.embedding());
+    /// ```
+    #[must_use]
+    pub fn embedding(&self) -> bool {
+        self.context_params.embedding
+    }
+
+    /// Enable the use of embeddings
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// use llama_cpp_2::context::params::LlamaContextParams;
+    /// let params = LlamaContextParams::default()
+    ///     .with_embedding(true);
+    /// assert!(params.embedding());
+    /// ```
+    #[must_use]
+    pub fn with_embedding(mut self, embedding: bool) -> Self {
+        self.context_params.embedding = embedding;
+        self
+    }
 }
 
 /// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)
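
Example (not part of the commit): the two new builder methods compose with the existing `default()` constructor; the values below are arbitrary.

    use llama_cpp_2::context::params::LlamaContextParams;

    fn main() {
        // Enable embeddings and give batch decoding its own thread count;
        // both setters are added in this commit.
        let params = LlamaContextParams::default()
            .with_n_threads_batch(8)
            .with_embedding(true);

        assert!(params.embedding());
        assert_eq!(params.n_threads_batch(), 8);
    }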

llama-cpp-2/src/lib.rs

Lines changed: 9 additions & 0 deletions
@@ -52,6 +52,8 @@ pub enum LLamaCppError {
     /// There was an error adding a token to a batch.
     #[error["{0}"]]
     BatchAddError(#[from] BatchAddError),
+    #[error(transparent)]
+    EmbeddingError(#[from] EmbeddingsError),
 }
 
 /// Failed to Load context
@@ -76,6 +78,13 @@ pub enum DecodeError {
     Unknown(c_int),
 }
 
+/// When embedding related functions fail
+#[derive(Debug, Eq, PartialEq, thiserror::Error)]
+pub enum EmbeddingsError {
+    #[error("Embeddings weren't enabled in the context options")]
+    NotEnabled,
+}
+
 /// Decode a error from llama.cpp into a [`DecodeError`].
 impl From<NonZeroI32> for DecodeError {
     fn from(value: NonZeroI32) -> Self {
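
Example (not part of the commit): because the new variant uses `#[from]`, an `EmbeddingsError` converts into the crate-wide `LLamaCppError` and can be bubbled up with `?`; a sketch, assuming `ctx` is an embeddings-enabled `LlamaContext`.

    use llama_cpp_2::context::LlamaContext;
    use llama_cpp_2::LLamaCppError;

    // `?` converts EmbeddingsError into LLamaCppError via the derived From impl.
    fn first_component(ctx: &LlamaContext<'_>) -> Result<f32, LLamaCppError> {
        let embedding = ctx.embeddings_ith(0)?;
        Ok(embedding[0])
    }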

llama-cpp-2/src/llama_backend.rs

Lines changed: 14 additions & 0 deletions
@@ -3,6 +3,7 @@
 use crate::LLamaCppError;
 use std::sync::atomic::AtomicBool;
 use std::sync::atomic::Ordering::SeqCst;
+use llama_cpp_sys_2::ggml_log_level;
 
 /// Representation of an initialized llama backend
 /// This is required as a parameter for most llama functions as the backend must be initialized
@@ -68,6 +69,19 @@ impl LlamaBackend {
         }
         Ok(LlamaBackend {})
     }
+
+    /// Change the output of llama.cpp's logging to be voided instead of pushed to `stderr`.
+    pub fn void_logs(&mut self) {
+        unsafe extern "C" fn void_log(
+            _level: ggml_log_level,
+            _text: *const ::std::os::raw::c_char,
+            _user_data: *mut ::std::os::raw::c_void,
+        ) {}
+
+        unsafe {
+            llama_cpp_sys_2::llama_log_set(Some(void_log), std::ptr::null_mut())
+        }
+    }
 }
 
 /// A rusty wrapper around `numa_strategy`.
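
Example (not part of the commit): silencing llama.cpp's stderr output right after backend initialization. `LlamaBackend::init()` and its error type are assumed from the existing API; only `void_logs` is new here.

    use llama_cpp_2::llama_backend::LlamaBackend;

    fn main() -> Result<(), llama_cpp_2::LLamaCppError> {
        // `void_logs` swaps llama.cpp's default stderr logger for a no-op callback.
        let mut backend = LlamaBackend::init()?;
        backend.void_logs();
        Ok(())
    }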

llama-cpp-2/src/llama_batch.rs

Lines changed: 47 additions & 5 deletions
@@ -6,11 +6,11 @@ use llama_cpp_sys_2::{llama_batch, llama_batch_free, llama_batch_init, llama_pos
 /// A safe wrapper around `llama_batch`.
 #[derive(Debug)]
 pub struct LlamaBatch {
-    /// The number of tokens the batch was allocated with. they are safe to write to - but not necessarily read from as they are not necessarily initilized
+    /// The number of tokens the batch was allocated with. they are safe to write to - but not necessarily read from as they are not necessarily initialized
     allocated: usize,
-    /// The logits that are initilized. Used by [`LlamaContext`] to ensure that only initilized logits are accessed.
+    /// The logits that are initialized. Used by [`LlamaContext`] to ensure that only initialized logits are accessed.
     pub(crate) initialized_logits: Vec<i32>,
-    /// The llama_cpp batch. always initilize by `llama_cpp_sys_2::llama_batch_init(allocated, <unknown>, <unknown>)`
+    /// The llama_cpp batch. always initialize by `llama_cpp_sys_2::llama_batch_init(allocated, <unknown>, <unknown>)`
     pub(crate) llama_batch: llama_batch,
 }
 
@@ -31,7 +31,7 @@ impl LlamaBatch {
     }
 
     /// add a token to the batch for sequences [`seq_ids`] at position [pos]. If [logits] is true, the
-    /// token will be initilized and can be read from after the next decode.
+    /// token will be initialized and can be read from after the next decode.
     ///
     /// # Panics
     ///
@@ -90,7 +90,49 @@
 
         Ok(())
     }
-    /// Create a new `LlamaBatch` that cab contain up to `n_tokens` tokens.
+
+    /// Add a sequence of tokens to the batch for the given sequence id. If [logits_all] is true, the
+    /// tokens will be initialized and can be read from after the next decode.
+    ///
+    /// Either way the last token in the sequence will have its logits set to `true`.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if there is insufficient space in the buffer
+    pub fn add_sequence(&mut self, tokens: &[LlamaToken],
+                        seq_id: i32,
+                        logits_all: bool) -> Result<(), BatchAddError> {
+        let n_tokens_0 = self.llama_batch.n_tokens;
+        let n_tokens = tokens.len();
+
+        if self.allocated < n_tokens_0 as usize + n_tokens {
+            return Err(BatchAddError::InsufficientSpace(self.allocated));
+        }
+        if n_tokens == 0 {
+            return Ok(())
+        }
+
+        self.llama_batch.n_tokens += n_tokens as i32;
+        for (i, token) in tokens.iter().enumerate() {
+            let j = n_tokens_0 as usize + i;
+            unsafe {
+                self.llama_batch.token.add(j).write(token.0);
+                self.llama_batch.pos.add(j).write(i as i32);
+                let seq_id_ptr = *self.llama_batch.seq_id.add(j);
+                seq_id_ptr.write(seq_id);
+                self.llama_batch.n_seq_id.add(j).write(1);
+                self.llama_batch.logits.add(j).write(logits_all as i8)
+            }
+        }
+
+        unsafe {
+            self.llama_batch.logits.add(n_tokens - 1).write(true as i8);
+        }
+
+        Ok(())
+    }
+
+    /// Create a new `LlamaBatch` that can contain up to `n_tokens` tokens.
     ///
     /// # Arguments
    ///
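
Example (not part of the commit): filling a batch with one whole prompt via the new `add_sequence`. The arguments of `LlamaBatch::new` are an assumption based on the existing API, and the token ids are made up; in practice they come from the model's tokenizer.

    use llama_cpp_2::llama_batch::LlamaBatch;
    use llama_cpp_2::token::LlamaToken;
    use llama_cpp_2::LLamaCppError;

    fn main() -> Result<(), LLamaCppError> {
        // Hypothetical token ids standing in for a tokenized prompt.
        let tokens = vec![LlamaToken::new(1), LlamaToken::new(42), LlamaToken::new(7)];

        // Assumed `new` signature: capacity in tokens plus the maximum number of sequences.
        let mut batch = LlamaBatch::new(512, 1);

        // Everything goes into sequence 0; `logits_all = false` requests logits
        // only for the last token, as documented on the new method.
        batch.add_sequence(&tokens, 0, false)?;
        Ok(())
    }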

llama-cpp-2/src/model.rs

Lines changed: 15 additions & 13 deletions
@@ -1,15 +1,16 @@
 //! A safe wrapper around `llama_model`.
-use crate::context::params::LlamaContextParams;
+use std::ffi::CString;
+use std::os::raw::c_int;
+use std::path::Path;
+use std::ptr::NonNull;
+
+use crate::{LlamaContextLoadError, LlamaModelLoadError, StringToTokenError, TokenToStringError};
 use crate::context::LlamaContext;
+use crate::context::params::LlamaContextParams;
 use crate::llama_backend::LlamaBackend;
 use crate::model::params::LlamaModelParams;
 use crate::token::LlamaToken;
 use crate::token_type::LlamaTokenType;
-use crate::{LlamaContextLoadError, LlamaModelLoadError, StringToTokenError, TokenToStringError};
-use std::ffi::CString;
-use std::os::raw::c_int;
-use std::path::Path;
-use std::ptr::NonNull;
 
 pub mod params;
 
@@ -29,6 +30,7 @@ pub enum AddBos {
     /// Do not add the beginning of stream token to the start of the string.
     Never,
 }
+
 unsafe impl Send for LlamaModel {}
 
 unsafe impl Sync for LlamaModel {}
@@ -38,12 +40,12 @@ impl LlamaModel {
     ///
     /// # Panics
     ///
-    /// If the number of tokens the model was trained on does not fit into an `u16`. This should be impossible on most
+    /// If the number of tokens the model was trained on does not fit into an `u32`. This should be impossible on most
     /// platforms due to llama.cpp returning a `c_int` (i32 on most platforms) which is almost certainly positive.
     #[must_use]
-    pub fn n_ctx_train(&self) -> u16 {
+    pub fn n_ctx_train(&self) -> u32 {
         let n_ctx_train = unsafe { llama_cpp_sys_2::llama_n_ctx_train(self.model.as_ptr()) };
-        u16::try_from(n_ctx_train).expect("n_ctx_train fits into an u16")
+        u32::try_from(n_ctx_train).expect("n_ctx_train fits into an u32")
     }
 
     /// Get all tokens in the model.
@@ -54,6 +56,7 @@ impl LlamaModel {
             .map(LlamaToken::new)
             .map(|llama_token| (llama_token, self.token_to_str(llama_token)))
     }
+
     /// Get the beginning of stream token.
     #[must_use]
     pub fn token_bos(&self) -> LlamaToken {
@@ -276,7 +279,7 @@ impl LlamaModel {
     /// # Errors
     ///
     /// See [`LlamaModelLoadError`] for more information.
-    #[tracing::instrument(skip_all)]
+    #[tracing::instrument(skip_all, fields(params))]
     pub fn load_from_file(
         _: &LlamaBackend,
         path: impl AsRef<Path>,
@@ -290,13 +293,12 @@
 
         let cstr = CString::new(path)?;
         let llama_model = unsafe {
-            println!("{:?}", params.params);
             llama_cpp_sys_2::llama_load_model_from_file(cstr.as_ptr(), params.params)
         };
 
         let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
 
-        println!("Loaded {path:?}");
+        tracing::debug!(?path, "Loaded model");
         Ok(LlamaModel { model })
     }
 
@@ -318,7 +320,7 @@
         };
         let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;
 
-        Ok(LlamaContext::new(self, context))
+        Ok(LlamaContext::new(self, context, params.embedding()))
     }
 }
 
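
Example (not part of the commit): an end-to-end sketch of how the pieces land together, decoding one tokenized prompt and reading back its last-token embedding. `LlamaContext::decode` and the arguments of `LlamaBatch::new` are assumptions based on the existing crate, not part of this diff; `add_sequence`, `embeddings_ith`, and the embedding flag threaded through `LlamaContext::new` are what the commit adds.

    use llama_cpp_2::context::LlamaContext;
    use llama_cpp_2::llama_batch::LlamaBatch;
    use llama_cpp_2::token::LlamaToken;

    /// Decode a non-empty, already tokenized prompt and return the embedding of
    /// its last token. The context must come from params with `with_embedding(true)`.
    fn embed_prompt(
        ctx: &mut LlamaContext<'_>,
        tokens: &[LlamaToken],
    ) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
        // Assumed `new` signature: token capacity plus maximum sequence count.
        let mut batch = LlamaBatch::new(tokens.len(), 1);

        // New in this commit: queue the whole prompt as sequence 0.
        batch.add_sequence(tokens, 0, false)?;

        // `decode` is the pre-existing entry point that runs the batch.
        ctx.decode(&mut batch)?;

        // New in this commit: the returned slice has `n_embd` components.
        let last = tokens.len() as i32 - 1;
        Ok(ctx.embeddings_ith(last)?.to_vec())
    }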
