From d789cace6f125c9e78faf6fd47e8c30ace609f97 Mon Sep 17 00:00:00 2001
From: srv1n
Date: Thu, 6 Feb 2025 18:06:30 +0530
Subject: [PATCH] undid making initialized_logits public

---
 Cargo.lock                        |  11 +
 Cargo.toml                        |   2 +-
 examples/reranker/Cargo.toml      |  20 ++
 examples/reranker/README.md       |  75 +++++++
 examples/reranker/src/main.rs     | 340 ++++++++++++++++++++++++++++++
 llama-cpp-2/src/context/params.rs |   4 +
 6 files changed, 451 insertions(+), 1 deletion(-)
 create mode 100644 examples/reranker/Cargo.toml
 create mode 100644 examples/reranker/README.md
 create mode 100644 examples/reranker/src/main.rs

diff --git a/Cargo.lock b/Cargo.lock
index 1994a720..09445542 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -934,6 +934,17 @@ version = "0.8.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
 
+[[package]]
+name = "reranker"
+version = "0.1.86"
+dependencies = [
+ "anyhow",
+ "clap",
+ "encoding_rs",
+ "hf-hub",
+ "llama-cpp-2",
+]
+
 [[package]]
 name = "ring"
 version = "0.17.8"
diff --git a/Cargo.toml b/Cargo.toml
index 1750d6ff..903bdfab 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -4,7 +4,7 @@ members = [
     "llama-cpp-sys-2",
     "llama-cpp-2",
     "examples/embeddings",
-    "examples/simple",
+    "examples/simple", "examples/reranker",
 ]
 
 [workspace.dependencies]
diff --git a/examples/reranker/Cargo.toml b/examples/reranker/Cargo.toml
new file mode 100644
index 00000000..fa32c2d3
--- /dev/null
+++ b/examples/reranker/Cargo.toml
@@ -0,0 +1,20 @@
+[package]
+name = "reranker"
+version = "0.1.86"
+edition = "2021"
+
+[dependencies]
+llama-cpp-2 = { path = "../../llama-cpp-2", version = "0.1.86" }
+hf-hub = { workspace = true }
+clap = { workspace = true, features = ["derive"] }
+anyhow = { workspace = true }
+encoding_rs = { workspace = true }
+
+[features]
+cuda = ["llama-cpp-2/cuda"]
+metal = ["llama-cpp-2/metal"]
+native = ["llama-cpp-2/native"]
+vulkan = ["llama-cpp-2/vulkan"]
+
+[lints]
+workspace = true
\ No newline at end of file
diff --git a/examples/reranker/README.md b/examples/reranker/README.md
new file mode 100644
index 00000000..935c37ca
--- /dev/null
+++ b/examples/reranker/README.md
@@ -0,0 +1,75 @@
+# Rust Reranker Implementation
+
+A Rust implementation of cross-encoder based reranking using llama-cpp-2. Cross-encoder reranking is a more accurate way to judge the relevance of a document to a query than traditional embedding-based approaches.
+
+## Overview
+
+This implementation adds a new pooling type, `LLAMA_POOLING_TYPE_RANK`, which enables cross-encoder based reranking. Unlike traditional embedding approaches that encode the query and document separately, this method:
+
+- Processes query and document pairs together in a single pass
+- Directly evaluates the semantic relationship between each pair
+- Outputs raw similarity scores indicating relevance
+
+## Installation
+
+```bash
+# Clone the repository, then navigate to the reranker example
+cd examples/reranker
+
+# Build the project
+cargo build --release
+```
+
+## Usage
+
+### Command Line Interface
+
+```bash
+cargo run --release -- \
+    --model-path "models/bge-reranker-v2-m3.gguf" \
+    --query "what is panda?" \
+    --documents "hi" \
+    --documents "it's a bear" \
+    --documents "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." \
+    --pooling rank
+```
+
+This should output (with bge-reranker-v2-m3-Q5_0):
+
+```
+rerank score 0:   -6.551
+rerank score 1:   -3.802
+rerank score 2:    4.522
+```
+
+### CLI Arguments
+
+- `--model-path`: Path to the GGUF model file
+- `--query`: The search query
+- `--documents`: One or more documents to rank against the query
+- `--pooling`: Pooling type (options: none, mean, rank)
+
+### Pooling Types
+
+- `rank`: Performs cross-encoder reranking and prints one relevance score per document (the mode this example demonstrates); `none` and `mean` fall back to embedding-style output
+
+Note: The raw scores are not normalized through a sigmoid function. If you need scores between 0 and 1, you'll need to apply sigmoid normalization in your application code.
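+
+For reference, a minimal sketch of that post-processing step (plain Rust; the `sigmoid` helper is illustrative and not part of this example or of llama-cpp-2):
+
+```rust
+/// Map a raw reranker score onto the 0-1 range with a logistic sigmoid.
+fn sigmoid(score: f32) -> f32 {
+    1.0 / (1.0 + (-score).exp())
+}
+
+fn main() {
+    // Raw scores as printed by this example (see the sample output above).
+    let raw_scores = [-6.551_f32, -3.802, 4.522];
+    for (i, score) in raw_scores.iter().enumerate() {
+        println!("normalized score {i}: {:.3}", sigmoid(*score));
+    }
+}
+```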
+
+## Additional notes
+
+- The query and each document are concatenated into a single prompt of the form `query<eos><sep>document` (see `main.rs`; the separator tokens are currently hardcoded).
+
+## Supported Models
+
+Some tested models:
+
+- [BAAI/bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3)
+- [jinaai/jina-reranker-v1-tiny-en](https://huggingface.co/jinaai/jina-reranker-v1-tiny-en)
+
+Other models have not been tested, but anything supported by llama.cpp should work.
+
+## Implementation Details
+
+This is a close Rust port of the reranking support discussed in [llama.cpp PR #9510](https://github.com/ggerganov/llama.cpp/pull/9510).
+
+## Potential issues
+
+The BOS, EOS, and SEP tokens are currently hardcoded. Ideally they should be read from the model, and the prompts built to match each specific model.
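+
+A minimal sketch of that direction (illustrative only; it assumes `LlamaModel::token_eos()` is available in the crate version you are using, and it leaves the choice of separator to the caller since a dedicated SEP accessor may not exist for every model):
+
+```rust
+use llama_cpp_2::model::{LlamaModel, Special};
+
+/// Illustrative helper, not part of this example: build a reranker prompt using the
+/// model's own EOS token instead of a hard-coded string.
+fn build_prompt(model: &LlamaModel, query: &str, doc: &str, sep: &str) -> anyhow::Result<String> {
+    let eos = model.token_to_str(model.token_eos(), Special::Tokenize)?;
+    Ok(format!("{query}{eos}{sep}{doc}"))
+}
+```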
\ No newline at end of file
diff --git a/examples/reranker/src/main.rs b/examples/reranker/src/main.rs
new file mode 100644
index 00000000..5a6109ef
--- /dev/null
+++ b/examples/reranker/src/main.rs
@@ -0,0 +1,340 @@
+//! Cross-encoder reranking example, adapted from embedding.cpp in llama.cpp using llama-cpp-2.
+#![allow(
+    clippy::cast_possible_wrap,
+    clippy::cast_possible_truncation,
+    clippy::cast_precision_loss,
+    clippy::cast_sign_loss
+)]
+
+use std::io::Write;
+use std::path::PathBuf;
+use std::time::Duration;
+
+use anyhow::{bail, Context, Result};
+use clap::Parser;
+use hf_hub::api::sync::ApiBuilder;
+
+use llama_cpp_2::context::params::{LlamaContextParams, LlamaPoolingType};
+use llama_cpp_2::context::LlamaContext;
+use llama_cpp_2::ggml_time_us;
+use llama_cpp_2::llama_backend::LlamaBackend;
+use llama_cpp_2::llama_batch::LlamaBatch;
+use llama_cpp_2::model::params::LlamaModelParams;
+use llama_cpp_2::model::LlamaModel;
+use llama_cpp_2::model::{AddBos, Special};
+
+#[derive(clap::Parser, Debug, Clone)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Path to the model file
+    #[clap(long)]
+    model_path: PathBuf,
+
+    /// The query to rank the documents against
+    #[clap(long)]
+    query: String,
+
+    /// The documents to rank against the query
+    #[clap(long, num_args = 1..)]
+    documents: Vec<String>,
+
+    /// Pooling type (none, mean, or rank)
+    #[clap(long, default_value = "none")]
+    pooling: String,
+
+    /// Whether to normalise the produced embeddings
+    #[clap(long, default_value_t = true)]
+    normalise: bool,
+
+    /// Disable offloading layers to the gpu
+    #[cfg(any(feature = "cuda", feature = "vulkan"))]
+    #[clap(long)]
+    disable_gpu: bool,
+}
+
+fn main() -> Result<()> {
+    let Args {
+        model_path,
+        query,
+        documents,
+        pooling,
+        normalise,
+        #[cfg(any(feature = "cuda", feature = "vulkan"))]
+        disable_gpu,
+    } = Args::parse();
+
+    // init LLM
+    let backend = LlamaBackend::init()?;
+
+    // offload all layers to the gpu
+    let model_params = {
+        #[cfg(any(feature = "cuda", feature = "vulkan"))]
+        if !disable_gpu {
+            LlamaModelParams::default().with_n_gpu_layers(1000)
+        } else {
+            LlamaModelParams::default()
+        }
+        #[cfg(not(any(feature = "cuda", feature = "vulkan")))]
+        LlamaModelParams::default()
+    };
+
+    let model = LlamaModel::load_from_file(&backend, model_path, &model_params)
+        .with_context(|| "unable to load model")?;
+    // println!("pooling: {}", pooling);
+    let pooling_type = match pooling.as_str() {
+        "mean" => LlamaPoolingType::Mean,
+        "none" => LlamaPoolingType::None,
+        "rank" => LlamaPoolingType::Rank,
+        _ => LlamaPoolingType::Unspecified,
+    };
+
+    let ctx_params = LlamaContextParams::default()
+        .with_n_threads_batch(std::thread::available_parallelism()?.get().try_into()?)
+        .with_embeddings(true)
+        .with_pooling_type(pooling_type);
+    println!("ctx_params: {:?}", ctx_params);
+    let mut ctx = model
+        .new_context(&backend, ctx_params)
+        .with_context(|| "unable to create the llama_context")?;
+
+    let n_embd = model.n_embd();
+
+    let prompt_lines = {
+        let mut lines = Vec::new();
+        for doc in documents {
+            // Todo! update to get eos and sep from model instead of hardcoding
+            lines.push(format!("{query}{eos}{sep}{doc}", sep = "", eos = ""));
+        }
+        lines
+    };
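+
+    // Note: with the hard-coded empty `eos`/`sep` strings above, each prompt is
+    // currently just the query immediately followed by the document; a BOS token is
+    // still prepended at tokenization time via `AddBos::Always`.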
+
+    println!("prompt_lines: {:?}", prompt_lines);
+    // tokenize the prompts
+    let tokens_lines_list = prompt_lines
+        .iter()
+        .map(|line| model.str_to_token(line, AddBos::Always))
+        .collect::<Result<Vec<_>, _>>()
+        .with_context(|| format!("failed to tokenize {:?}", prompt_lines))?;
+
+    let n_ctx = ctx.n_ctx() as usize;
+    let n_ctx_train = model.n_ctx_train();
+
+    eprintln!("n_ctx = {n_ctx}, n_ctx_train = {n_ctx_train}");
+
+    if tokens_lines_list.iter().any(|tok| n_ctx < tok.len()) {
+        bail!("One of the provided prompts exceeds the size of the context window");
+    }
+
+    // print the prompt token-by-token
+    eprintln!();
+
+    for (i, token_line) in tokens_lines_list.iter().enumerate() {
+        eprintln!("Prompt {i} --> {}", prompt_lines[i]);
+        eprintln!("Number of tokens: {}", token_line.len());
+        for token in token_line {
+            // Attempt to convert token to string and print it; if it fails, print the token instead
+            match model.token_to_str(*token, Special::Tokenize) {
+                Ok(token_str) => eprintln!("{token} --> {token_str}"),
+                Err(e) => {
+                    eprintln!("Failed to convert token to string, error: {e}");
+                    eprintln!("Token value: {token}");
+                }
+            }
+        }
+        eprintln!();
+    }
+
+    std::io::stderr().flush()?;
+
+    // create a llama_batch with a capacity of 2048 tokens
+    // we use this object to submit token data for decoding
+    let mut batch = LlamaBatch::new(2048, 1);
+
+    // Todo! update to get n_embd to init vector size for better memory management
+    // let mut n_embd_count = if pooling == "none" {
+    //     tokens_lines_list.iter().map(|tokens| tokens.len()).sum()
+    // } else {
+    //     tokens_lines_list.len()
+    // };
+    let mut embeddings_stored = 0;
+    let mut max_seq_id_batch = 0;
+    let mut output = Vec::with_capacity(tokens_lines_list.len());
+
+    let t_main_start = ggml_time_us();
+
+    for tokens in &tokens_lines_list {
+        // Flush the batch if the next prompt would exceed our batch size
+        if (batch.n_tokens() as usize + tokens.len()) > 2048 {
+            batch_decode(
+                &mut ctx,
+                &mut batch,
+                max_seq_id_batch,
+                n_embd,
+                &mut output,
+                normalise,
+                pooling.clone(),
+            )?;
+            embeddings_stored += if pooling == "none" {
+                batch.n_tokens()
+            } else {
+                max_seq_id_batch
+            };
+            max_seq_id_batch = 0;
+            batch.clear();
+        }
+
+        batch.add_sequence(tokens, max_seq_id_batch, false)?;
+        max_seq_id_batch += 1;
+    }
+    // Handle final batch
+    batch_decode(
+        &mut ctx,
+        &mut batch,
+        max_seq_id_batch,
+        n_embd,
+        &mut output,
+        normalise,
+        pooling.clone(),
+    )?;
+
+    let t_main_end = ggml_time_us();
+
+    for (j, embeddings) in output.iter().enumerate() {
+        if pooling == "none" {
+            eprintln!("embedding {j}: ");
+            for i in 0..n_embd as usize {
+                if !normalise {
+                    eprint!("{:6.5} ", embeddings[i]);
+                } else {
+                    eprint!("{:9.6} ", embeddings[i]);
+                }
+            }
+            eprintln!();
+        } else if pooling == "rank" {
+            eprintln!("rerank score {j}: {:8.3}", embeddings[0]);
+        } else {
+            eprintln!("embedding {j}: ");
+            for i in 0..n_embd as usize {
+                if !normalise {
+                    eprint!("{:6.5} ", embeddings[i]);
+                } else {
+                    eprint!("{:9.6} ", embeddings[i]);
+                }
+            }
+            eprintln!();
+        }
+    }
+
+    let duration = Duration::from_micros((t_main_end - t_main_start) as u64);
+    let total_tokens: usize = tokens_lines_list.iter().map(Vec::len).sum();
+    eprintln!(
+        "Created embeddings for {} tokens in {:.2} s, speed {:.2} t/s\n",
+        total_tokens,
+        duration.as_secs_f32(),
+        total_tokens as f32 / duration.as_secs_f32()
+    );
+
+    println!("{}", ctx.timings());
+
+    Ok(())
+}
+
+fn batch_decode(
+    ctx: &mut LlamaContext,
+    batch: &mut LlamaBatch,
+    s_batch: i32,
+    n_embd: i32,
+    output: &mut Vec<Vec<f32>>,
+    normalise: bool,
+    pooling: String,
+) -> Result<()> {
+    eprintln!(
+        "{}: n_tokens = {}, n_seq = {}",
+        stringify!(batch_decode),
+        batch.n_tokens(),
+        s_batch
+    );
+
+    // Clear previous kv_cache values
+    ctx.clear_kv_cache();
+
+    ctx.decode(batch).with_context(|| "llama_decode() failed")?;
+
+    for i in 0..s_batch {
+        let embeddings = ctx
+            .embeddings_seq_ith(i)
+            .with_context(|| "Failed to get sequence embeddings")?;
+        let normalized = if normalise {
+            if pooling == "rank" {
+                normalize_embeddings(&embeddings, -1)
+            } else {
+                normalize_embeddings(&embeddings, 2)
+            }
+        } else {
+            embeddings.to_vec()
+        };
+        output.push(normalized);
+    }
+
+    batch.clear();
+
+    Ok(())
+}
+
+/// Normalizes embeddings based on different normalization strategies
+fn normalize_embeddings(input: &[f32], embd_norm: i32) -> Vec<f32> {
+    let n = input.len();
+    let mut output = vec![0.0; n];
+
+    let sum = match embd_norm {
+        -1 => 1.0, // no normalization
+        0 => {
+            // max absolute
+            let max_abs = input.iter().map(|x| x.abs()).fold(0.0f32, f32::max) / 32760.0;
+            max_abs as f64
+        }
+        2 => {
+            // euclidean norm
+            input
+                .iter()
+                .map(|x| (*x as f64).powi(2))
+                .sum::<f64>()
+                .sqrt()
+        }
+        p => {
+            // p-norm
+            let sum = input.iter().map(|x| (x.abs() as f64).powi(p)).sum::<f64>();
+            sum.powf(1.0 / p as f64)
+        }
+    };
+
+    let norm = if sum > 0.0 { 1.0 / sum } else { 0.0 };
+
+    for i in 0..n {
+        output[i] = (input[i] as f64 * norm) as f32;
+    }
+
+    output
+}
+
+// /// Calculates cosine similarity between two embedding vectors
+// fn embedding_similarity_cos(embd1: &[f32], embd2: &[f32]) -> f32 {
+//     assert_eq!(embd1.len(), embd2.len(), "Embedding vectors must be the same length");
+
+//     let (sum, sum1, sum2) = embd1.iter().zip(embd2.iter()).fold(
+//         (0.0f64, 0.0f64, 0.0f64),
+//         |(sum, sum1, sum2), (e1, e2)| {
+//             let e1 = *e1 as f64;
+//             let e2 = *e2 as f64;
+//             (
+//                 sum + e1 * e2,
+//                 sum1 + e1 * e1,
+//                 sum2 + e2 * e2
+//             )
+//         }
+//     );
+
+//     // Handle zero vectors
+//     if sum1 == 0.0 || sum2 == 0.0 {
+//         return if sum1 == 0.0 && sum2 == 0.0 {
+//             1.0 // two zero vectors are similar
+//         } else {
+//             0.0
+//         };
+//     }
+
+//     (sum / (sum1.sqrt() * sum2.sqrt())) as f32
+// }
diff --git a/llama-cpp-2/src/context/params.rs b/llama-cpp-2/src/context/params.rs
index cfaf967b..892dc8dc 100644
--- a/llama-cpp-2/src/context/params.rs
+++ b/llama-cpp-2/src/context/params.rs
@@ -55,6 +55,8 @@ pub enum LlamaPoolingType {
     Cls = 2,
     /// Last pooling
     Last = 3,
+    /// Rank pooling
+    Rank = 4,
 }
 
 /// Create a `LlamaPoolingType` from a `c_int` - returns `LlamaPoolingType::Unspecified` if
@@ -66,6 +68,7 @@ impl From<i32> for LlamaPoolingType {
             1 => Self::Mean,
             2 => Self::Cls,
             3 => Self::Last,
+            4 => Self::Rank,
             _ => Self::Unspecified,
         }
     }
@@ -79,6 +82,7 @@ impl From<LlamaPoolingType> for i32 {
             LlamaPoolingType::Mean => 1,
             LlamaPoolingType::Cls => 2,
             LlamaPoolingType::Last => 3,
+            LlamaPoolingType::Rank => 4,
             LlamaPoolingType::Unspecified => -1,
         }
     }