Skip to content

Commit c0faaef

Browse files
committed
Add en embeddings example
1 parent 542a410 commit c0faaef

File tree

5 files changed

+243
-1
lines changed

5 files changed

+243
-1
lines changed

Cargo.lock

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ resolver = "2"
33
members = [
44
"llama-cpp-sys-2",
55
"llama-cpp-2",
6-
"simple",
6+
"simple", "embeddings",
77
]
88

99
[workspace.dependencies]

embeddings/Cargo.toml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
[package]
2+
name = "embeddings"
3+
version = "0.1.0"
4+
edition = "2021"
5+
6+
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
7+
8+
[dependencies]
9+
llama-cpp-2 = { path = "../llama-cpp-2", version = "0.1.34" }
10+
hf-hub = { workspace = true }
11+
clap = { workspace = true , features = ["derive"] }
12+
anyhow = { workspace = true }
13+
14+
[lints]
15+
workspace = true

embeddings/src/main.rs

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
//! This is a translation of embedding.cpp in llama.cpp using llama-cpp-2.
2+
#![allow(
3+
clippy::cast_possible_wrap,
4+
clippy::cast_possible_truncation,
5+
clippy::cast_precision_loss,
6+
clippy::cast_sign_loss
7+
)]
8+
9+
use std::io::Write;
10+
use std::path::PathBuf;
11+
use std::str::FromStr;
12+
use std::time::Duration;
13+
14+
use anyhow::{bail, Context, Result};
15+
use clap::Parser;
16+
use hf_hub::api::sync::ApiBuilder;
17+
use llama_cpp_2::context::LlamaContext;
18+
19+
use llama_cpp_2::context::params::LlamaContextParams;
20+
use llama_cpp_2::ggml_time_us;
21+
use llama_cpp_2::llama_backend::LlamaBackend;
22+
use llama_cpp_2::llama_batch::LlamaBatch;
23+
use llama_cpp_2::model::AddBos;
24+
use llama_cpp_2::model::LlamaModel;
25+
use llama_cpp_2::model::params::LlamaModelParams;
26+
27+
#[derive(clap::Parser, Debug, Clone)]
28+
struct Args {
29+
/// The path to the model
30+
#[command(subcommand)]
31+
model: Model,
32+
/// The prompt
33+
#[clap(default_value = "Hello my name is")]
34+
prompt: String,
35+
/// Whether to normalise the produced embeddings
36+
#[clap(short)]
37+
normalise: bool,
38+
/// Disable offloading layers to the gpu
39+
#[cfg(feature = "cublas")]
40+
#[clap(long)]
41+
disable_gpu: bool,
42+
}
43+
44+
45+
#[derive(clap::Subcommand, Debug, Clone)]
46+
enum Model {
47+
/// Use an already downloaded model
48+
Local {
49+
/// The path to the model. e.g. `/home/marcus/.cache/huggingface/hub/models--TheBloke--Llama-2-7B-Chat-GGUF/blobs/08a5566d61d7cb6b420c3e4387a39e0078e1f2fe5f055f3a03887385304d4bfa`
50+
path: PathBuf,
51+
},
52+
/// Download a model from huggingface (or use a cached version)
53+
#[clap(name = "hf-model")]
54+
HuggingFace {
55+
/// the repo containing the model. e.g. `TheBloke/Llama-2-7B-Chat-GGUF`
56+
repo: String,
57+
/// the model name. e.g. `llama-2-7b-chat.Q4_K_M.gguf`
58+
model: String,
59+
},
60+
}
61+
62+
impl Model {
63+
/// Convert the model to a path - may download from huggingface
64+
fn get_or_load(self) -> Result<PathBuf> {
65+
match self {
66+
Model::Local { path } => Ok(path),
67+
Model::HuggingFace { model, repo } => ApiBuilder::new()
68+
.with_progress(true)
69+
.build()
70+
.with_context(|| "unable to create huggingface api")?
71+
.model(repo)
72+
.get(&model)
73+
.with_context(|| "unable to download model"),
74+
}
75+
}
76+
}
77+
78+
fn main() -> Result<()> {
79+
let Args {
80+
model,
81+
prompt,
82+
normalise,
83+
#[cfg(feature = "cublas")]
84+
disable_gpu,
85+
} = Args::parse();
86+
87+
// init LLM
88+
let backend = LlamaBackend::init()?;
89+
90+
// offload all layers to the gpu
91+
let model_params = {
92+
#[cfg(feature = "cublas")]
93+
if !disable_gpu {
94+
LlamaModelParams::default().with_n_gpu_layers(1000)
95+
} else {
96+
LlamaModelParams::default()
97+
}
98+
#[cfg(not(feature = "cublas"))]
99+
LlamaModelParams::default()
100+
};
101+
102+
let model_path = model
103+
.get_or_load()
104+
.with_context(|| "failed to get model from args")?;
105+
106+
let model = LlamaModel::load_from_file(&backend, model_path, &model_params)
107+
.with_context(|| "unable to load model")?;
108+
109+
// initialize the context
110+
let ctx_params = LlamaContextParams::default()
111+
.with_n_threads_batch(std::thread::available_parallelism()?.get() as u32)
112+
.with_embedding(true);
113+
114+
let mut ctx = model
115+
.new_context(&backend, ctx_params)
116+
.with_context(|| "unable to create the llama_context")?;
117+
118+
// Split the prompt to display the batching functionality
119+
let prompt_lines = prompt.lines();
120+
121+
// tokenize the prompt
122+
let tokens_lines_list = prompt_lines.map(|line| model.str_to_token(&line, AddBos::Always))
123+
.collect::<Result<Vec<_>, _>>()
124+
.with_context(|| format!("failed to tokenize {prompt}"))?;
125+
126+
let n_ctx = ctx.n_ctx() as usize;
127+
let n_ctx_train = model.n_ctx_train();
128+
129+
eprintln!("n_ctx = {n_ctx}, n_ctx_train = {n_ctx_train}");
130+
131+
if tokens_lines_list.iter().any(|tok| n_ctx < tok.len()) {
132+
bail!("One of the provided prompts exceeds the size of the context window");
133+
}
134+
135+
// print the prompt token-by-token
136+
eprintln!();
137+
138+
for (i, token_line) in tokens_lines_list.iter().enumerate() {
139+
eprintln!("Prompt {i}");
140+
for token in token_line {
141+
eprintln!(" {} --> {}", token, model.token_to_str(*token)?);
142+
}
143+
eprintln!()
144+
}
145+
146+
std::io::stderr().flush()?;
147+
148+
// create a llama_batch with the size of the context
149+
// we use this object to submit token data for decoding
150+
let mut batch = LlamaBatch::new(n_ctx, tokens_lines_list.len() as i32);
151+
152+
// Amount of tokens in the current batch
153+
let mut s_batch = 0;
154+
let mut output = Vec::with_capacity(tokens_lines_list.len());
155+
156+
let t_main_start = ggml_time_us();
157+
158+
for tokens in &tokens_lines_list {
159+
// Flush the batch if the next prompt would exceed our batch size
160+
if (batch.n_tokens() as usize + tokens.len()) > n_ctx {
161+
batch_decode(&mut ctx, &mut batch, s_batch, &mut output, normalise)?;
162+
s_batch = 0;
163+
}
164+
165+
batch.add_sequence(&tokens, s_batch, false)?;
166+
s_batch += 1;
167+
}
168+
// Handle final batch
169+
batch_decode(&mut ctx, &mut batch, s_batch, &mut output, normalise)?;
170+
171+
let t_main_end = ggml_time_us();
172+
173+
for (i, embeddings) in output.iter().enumerate() {
174+
eprintln!("Embeddings {i}: {embeddings:?}");
175+
eprintln!("\n");
176+
}
177+
178+
let duration = Duration::from_micros((t_main_end - t_main_start) as u64);
179+
let total_tokens: usize = tokens_lines_list.iter().map(|v| v.len()).sum();
180+
181+
eprintln!(
182+
"Created embeddings for {} tokens in {:.2} s, speed {:.2} t/s\n",
183+
total_tokens,
184+
duration.as_secs_f32(),
185+
total_tokens as f32 / duration.as_secs_f32()
186+
);
187+
188+
println!("{}", ctx.timings());
189+
190+
Ok(())
191+
}
192+
193+
fn batch_decode(ctx: &mut LlamaContext, batch: &mut LlamaBatch, s_batch: i32, output: &mut Vec<Vec<f32>>, normalise: bool) -> Result<()> {
194+
ctx.clear_kv_cache();
195+
ctx.decode(batch).with_context(|| "llama_decode() failed")?;
196+
batch.clear();
197+
198+
for i in 0..s_batch {
199+
let embedding = ctx.embeddings_ith(i).with_context(|| "Failed to get embeddings")?;
200+
let output_embeddings = if normalise {
201+
normalize(embedding)
202+
} else {
203+
embedding.to_vec()
204+
};
205+
206+
output.push(output_embeddings);
207+
}
208+
209+
Ok(())
210+
}
211+
212+
fn normalize(input: &[f32]) -> Vec<f32> {
213+
let magnitude = input.iter().fold(0.0, |acc, &val| val.mul_add(val, acc)).sqrt();
214+
215+
input.iter().map(|&val| val / magnitude).collect()
216+
}

llama-cpp-2/src/llama_batch.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ impl LlamaBatch {
127127

128128
unsafe {
129129
self.llama_batch.logits.add(n_tokens - 1).write(true as i8);
130+
self.initialized_logits.push(self.llama_batch.n_tokens - 1);
130131
}
131132

132133
Ok(())

0 commit comments

Comments
 (0)