Commit 5e6a677

Merge pull request #133 from Hirtol/embeddings
Add Embedding Related Functionality
2 parents 7cc5b85 + ccb434b commit 5e6a677

11 files changed: +440 additions, -26 deletions

Cargo.lock

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default.

Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@ resolver = "2"
 members = [
     "llama-cpp-sys-2",
     "llama-cpp-2",
-    "simple",
+    "simple", "embeddings",
 ]

 [workspace.dependencies]

embeddings/Cargo.toml

Lines changed: 15 additions & 0 deletions
[package]
name = "embeddings"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
llama-cpp-2 = { path = "../llama-cpp-2", version = "0.1.34" }
hf-hub = { workspace = true }
clap = { workspace = true , features = ["derive"] }
anyhow = { workspace = true }

[lints]
workspace = true

embeddings/src/main.rs

Lines changed: 215 additions & 0 deletions
//! This is a translation of embedding.cpp in llama.cpp using llama-cpp-2.
#![allow(
    clippy::cast_possible_wrap,
    clippy::cast_possible_truncation,
    clippy::cast_precision_loss,
    clippy::cast_sign_loss
)]

use std::io::Write;
use std::path::PathBuf;
use std::str::FromStr;
use std::time::Duration;

use anyhow::{bail, Context, Result};
use clap::Parser;
use hf_hub::api::sync::ApiBuilder;

use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::ggml_time_us;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::AddBos;
use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::model::params::LlamaModelParams;

#[derive(clap::Parser, Debug, Clone)]
struct Args {
    /// The path to the model
    #[command(subcommand)]
    model: Model,
    /// The prompt
    #[clap(default_value = "Hello my name is")]
    prompt: String,
    /// Whether to normalise the produced embeddings
    #[clap(short)]
    normalise: bool,
    /// Disable offloading layers to the gpu
    #[cfg(feature = "cublas")]
    #[clap(long)]
    disable_gpu: bool,
}

#[derive(clap::Subcommand, Debug, Clone)]
enum Model {
    /// Use an already downloaded model
    Local {
        /// The path to the model. e.g. `/home/marcus/.cache/huggingface/hub/models--TheBloke--Llama-2-7B-Chat-GGUF/blobs/08a5566d61d7cb6b420c3e4387a39e0078e1f2fe5f055f3a03887385304d4bfa`
        path: PathBuf,
    },
    /// Download a model from huggingface (or use a cached version)
    #[clap(name = "hf-model")]
    HuggingFace {
        /// the repo containing the model. e.g. `BAAI/bge-small-en-v1.5`
        repo: String,
        /// the model name. e.g. `BAAI-bge-small-v1.5.Q4_K_M.gguf`
        model: String,
    },
}

impl Model {
    /// Convert the model to a path - may download from huggingface
    fn get_or_load(self) -> Result<PathBuf> {
        match self {
            Model::Local { path } => Ok(path),
            Model::HuggingFace { model, repo } => ApiBuilder::new()
                .with_progress(true)
                .build()
                .with_context(|| "unable to create huggingface api")?
                .model(repo)
                .get(&model)
                .with_context(|| "unable to download model"),
        }
    }
}

fn main() -> Result<()> {
    let Args {
        model,
        prompt,
        normalise,
        #[cfg(feature = "cublas")]
        disable_gpu,
    } = Args::parse();

    // init LLM
    let backend = LlamaBackend::init()?;

    // offload all layers to the gpu
    let model_params = {
        #[cfg(feature = "cublas")]
        if !disable_gpu {
            LlamaModelParams::default().with_n_gpu_layers(1000)
        } else {
            LlamaModelParams::default()
        }
        #[cfg(not(feature = "cublas"))]
        LlamaModelParams::default()
    };

    let model_path = model
        .get_or_load()
        .with_context(|| "failed to get model from args")?;

    let model = LlamaModel::load_from_file(&backend, model_path, &model_params)
        .with_context(|| "unable to load model")?;

    // initialize the context
    let ctx_params = LlamaContextParams::default()
        .with_n_threads_batch(std::thread::available_parallelism()?.get() as u32)
        .with_embeddings(true);

    let mut ctx = model
        .new_context(&backend, ctx_params)
        .with_context(|| "unable to create the llama_context")?;

    // Split the prompt to display the batching functionality
    let prompt_lines = prompt.lines();

    // tokenize the prompt
    let tokens_lines_list = prompt_lines
        .map(|line| model.str_to_token(&line, AddBos::Always))
        .collect::<Result<Vec<_>, _>>()
        .with_context(|| format!("failed to tokenize {prompt}"))?;

    let n_ctx = ctx.n_ctx() as usize;
    let n_ctx_train = model.n_ctx_train();

    eprintln!("n_ctx = {n_ctx}, n_ctx_train = {n_ctx_train}");

    if tokens_lines_list.iter().any(|tok| n_ctx < tok.len()) {
        bail!("One of the provided prompts exceeds the size of the context window");
    }

    // print the prompt token-by-token
    eprintln!();

    for (i, token_line) in tokens_lines_list.iter().enumerate() {
        eprintln!("Prompt {i}");
        for token in token_line {
            eprintln!(" {} --> {}", token, model.token_to_str(*token)?);
        }
        eprintln!()
    }

    std::io::stderr().flush()?;

    // create a llama_batch with the size of the context
    // we use this object to submit token data for decoding
    let mut batch = LlamaBatch::new(n_ctx, 1);

    let mut max_seq_id_batch = 0;
    let mut output = Vec::with_capacity(tokens_lines_list.len());

    let t_main_start = ggml_time_us();

    for tokens in &tokens_lines_list {
        // Flush the batch if the next prompt would exceed our batch size
        if (batch.n_tokens() as usize + tokens.len()) > n_ctx {
            batch_decode(&mut ctx, &mut batch, max_seq_id_batch, &mut output, normalise)?;
            max_seq_id_batch = 0;
        }

        batch.add_sequence(&tokens, max_seq_id_batch, false)?;
        max_seq_id_batch += 1;
    }
    // Handle final batch
    batch_decode(&mut ctx, &mut batch, max_seq_id_batch, &mut output, normalise)?;

    let t_main_end = ggml_time_us();

    for (i, embeddings) in output.iter().enumerate() {
        eprintln!("Embeddings {i}: {embeddings:?}");
        eprintln!();
    }

    let duration = Duration::from_micros((t_main_end - t_main_start) as u64);
    let total_tokens: usize = tokens_lines_list.iter().map(|v| v.len()).sum();
    eprintln!(
        "Created embeddings for {} tokens in {:.2} s, speed {:.2} t/s\n",
        total_tokens,
        duration.as_secs_f32(),
        total_tokens as f32 / duration.as_secs_f32()
    );

    println!("{}", ctx.timings());

    Ok(())
}

fn batch_decode(ctx: &mut LlamaContext, batch: &mut LlamaBatch, s_batch: i32, output: &mut Vec<Vec<f32>>, normalise: bool) -> Result<()> {
    ctx.clear_kv_cache();
    ctx.decode(batch).with_context(|| "llama_decode() failed")?;

    for i in 0..s_batch {
        let embedding = ctx.embeddings_seq_ith(i).with_context(|| "Failed to get embeddings")?;
        let output_embeddings = if normalise {
            normalize(embedding)
        } else {
            embedding.to_vec()
        };

        output.push(output_embeddings);
    }

    batch.clear();

    Ok(())
}

fn normalize(input: &[f32]) -> Vec<f32> {
    let magnitude = input.iter().fold(0.0, |acc, &val| val.mul_add(val, acc)).sqrt();

    input.iter().map(|&val| val / magnitude).collect()
}
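Since the example normalises embeddings to unit length (the `normalize` helper above divides each component by the vector's L2 norm), cosine similarity between two prompts reduces to a plain dot product. A minimal illustrative sketch (the `cosine_similarity` helper below is hypothetical and not part of this commit):

/// Hypothetical helper: for two embeddings already normalised by `normalize`
/// above, cosine similarity is simply their dot product.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "embeddings must have the same dimension");
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}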

llama-cpp-2/src/context.rs

Lines changed: 63 additions & 3 deletions
@@ -2,15 +2,15 @@

 use std::fmt::{Debug, Formatter};
 use std::num::NonZeroI32;
+use std::ptr::NonNull;
+use std::slice;

 use crate::llama_batch::LlamaBatch;
 use crate::model::LlamaModel;
 use crate::timing::LlamaTimings;
 use crate::token::data::LlamaTokenData;
 use crate::token::LlamaToken;
-use crate::DecodeError;
-use std::ptr::NonNull;
-use std::slice;
+use crate::{DecodeError, EmbeddingsError};

 pub mod kv_cache;
 pub mod params;

@@ -24,6 +24,7 @@ pub struct LlamaContext<'a> {
     /// a reference to the contexts model.
     pub model: &'a LlamaModel,
     initialized_logits: Vec<i32>,
+    embeddings_enabled: bool,
 }

 impl Debug for LlamaContext<'_> {

@@ -38,11 +39,13 @@ impl<'model> LlamaContext<'model> {
     pub(crate) fn new(
         llama_model: &'model LlamaModel,
         llama_context: NonNull<llama_cpp_sys_2::llama_context>,
+        embeddings_enabled: bool,
     ) -> Self {
         Self {
             context: llama_context,
             model: llama_model,
             initialized_logits: Vec::new(),
+            embeddings_enabled,
         }
     }

@@ -80,6 +83,63 @@
         }
     }

+    /// Get the embeddings for the `i`th sequence in the current context.
+    ///
+    /// # Returns
+    ///
+    /// A slice containing the embeddings for the last decoded batch.
+    /// The size corresponds to the `n_embd` parameter of the context's model.
+    ///
+    /// # Errors
+    ///
+    /// - When the current context was constructed without enabling embeddings.
+    /// - If the current model had a pooling type of [`llama_cpp_sys_2::LLAMA_POOLING_TYPE_NONE`]
+    /// - If the given sequence index exceeds the max sequence id.
+    pub fn embeddings_seq_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
+        if !self.embeddings_enabled {
+            return Err(EmbeddingsError::NotEnabled);
+        }
+
+        unsafe {
+            let embedding = llama_cpp_sys_2::llama_get_embeddings_seq(self.context.as_ptr(), i);
+
+            // Technically also possible whenever `i >= max(batch.n_seq)`, but can't check that here.
+            if embedding.is_null() {
+                Err(EmbeddingsError::NonePoolType)
+            } else {
+                Ok(std::slice::from_raw_parts(embedding, self.model.n_embd() as usize))
+            }
+        }
+    }
+
+    /// Get the embeddings for the `i`th token in the current context.
+    ///
+    /// # Returns
+    ///
+    /// A slice containing the embeddings for the last decoded batch of the given token.
+    /// The size corresponds to the `n_embd` parameter of the context's model.
+    ///
+    /// # Errors
+    ///
+    /// - When the current context was constructed without enabling embeddings.
+    /// - When the given token didn't have logits enabled when it was passed.
+    /// - If the given token index exceeds the max token id.
+    pub fn embeddings_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
+        if !self.embeddings_enabled {
+            return Err(EmbeddingsError::NotEnabled);
+        }
+
+        unsafe {
+            let embedding = llama_cpp_sys_2::llama_get_embeddings_ith(self.context.as_ptr(), i);
+            // Technically also possible whenever `i >= batch.n_tokens`, but no good way of checking `n_tokens` here.
+            if embedding.is_null() {
+                Err(EmbeddingsError::LogitsNotEnabled)
+            } else {
+                Ok(std::slice::from_raw_parts(embedding, self.model.n_embd() as usize))
+            }
+        }
+    }
+
     /// Get the logits for the ith token in the context.
     ///
     /// # Panics
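For reference, a minimal usage sketch of the new getters. The helper below is hypothetical (not part of this commit) and assumes a `LlamaContext` created with `LlamaContextParams::default().with_embeddings(true)`, as in embeddings/src/main.rs above:

use anyhow::Result;
use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::token::LlamaToken;

/// Hypothetical helper: decode one tokenised prompt as sequence 0 and return
/// its pooled embedding via the new `embeddings_seq_ith` getter.
fn embed_one(ctx: &mut LlamaContext, tokens: &[LlamaToken]) -> Result<Vec<f32>> {
    let mut batch = LlamaBatch::new(tokens.len(), 1);
    batch.add_sequence(tokens, 0, false)?;
    ctx.clear_kv_cache();
    ctx.decode(&mut batch)?;
    // Returns `EmbeddingsError::NotEnabled` if the context was created without
    // embeddings, or `NonePoolType` if the model's pooling type is NONE.
    Ok(ctx.embeddings_seq_ith(0)?.to_vec())
}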
