
Commit 4e1c676

feat: reranking (LlamaRankingContext)

1 parent bb33a5d

12 files changed: +543 −4 lines changed


.vitepress/config/apiReferenceSidebar.ts

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,6 @@
 import {DefaultTheme} from "vitepress";
 /* eslint import/no-unresolved: "off" */
-import typedocSidebar from "../../docs/api/typedoc-sidebar.json"; // if this import fails, run `npm run docs:generateTypedoc`
+import typedocSidebar from "../../docs/api/typedoc-sidebar.json";
 
 const categoryOrder = [
     "Functions",
@@ -28,6 +28,7 @@ const classesOrder = [
     "LlamaCompletion",
     "LlamaEmbeddingContext",
     "LlamaEmbedding",
+    "LlamaRankingContext",
     "LlamaGrammar",
     "LlamaJsonSchemaGrammar",
     "LlamaText",

docs/guide/embedding.md

Lines changed: 53 additions & 0 deletions
@@ -138,6 +138,59 @@ const embedding = await context.getEmbeddingFor(text);
 console.log("Embedding vector:", embedding.vector);
 ```
 
+## Reranking Documents {#reranking}
+After you search for the most similar documents using embedding vectors,
+you can use inference to rerank (sort) the documents based on their relevance to the given query.
+
+Doing this allows you to combine the best of both worlds: the speed of embedding and the quality of inference.
+
+```typescript
+import {fileURLToPath} from "url";
+import path from "path";
+import {getLlama} from "node-llama-cpp";
+
+const __dirname = path.dirname(
+    fileURLToPath(import.meta.url)
+);
+
+const llama = await getLlama();
+const model = await llama.loadModel({
+    modelPath: path.join(__dirname, "bge-reranker-v2-m3-Q8_0.gguf")
+});
+const context = await model.createRankingContext();
+
+const documents = [
+    "The sky is clear and blue today",
+    "I love eating pizza with extra cheese",
+    "Dogs love to play fetch with their owners",
+    "The capital of France is Paris",
+    "Drinking water is important for staying hydrated",
+    "Mount Everest is the tallest mountain in the world",
+    "A warm cup of tea is perfect for a cold winter day",
+    "Painting is a form of creative expression",
+    "Not all the things that shine are made of gold",
+    "Cleaning the house is a good way to keep it tidy"
+];
+
+const query = "Tell me a geographical fact";
+const rankedDocuments = await context.rankAndSort(query, documents);
+
+const topDocument = rankedDocuments[0]!;
+const secondDocument = rankedDocuments[1]!;
+
+console.log("query:", query);
+console.log("Top document:", topDocument.document);
+console.log("Second document:", secondDocument.document);
+console.log("Ranked documents:", rankedDocuments);
+```
+> This example will produce this output:
+> ```
+> query: Tell me a geographical fact
+> Top document: Mount Everest is the tallest mountain in the world
+> Second document: The capital of France is Paris
+> ```
+> This example uses [bge-reranker-v2-m3-Q8_0.gguf](https://huggingface.co/gpustack/bge-reranker-v2-m3-GGUF/blob/main/bge-reranker-v2-m3-Q8_0.gguf)
+
 ## Using External Databases
 When you have a large number of documents you want to use with embedding, it's often more efficient to store them with their embedding in an external database and search for the most similar embeddings there.
 
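The new guide section promises "the speed of embedding and the quality of inference", but its example only covers the reranking stage. A minimal sketch of the full two-stage flow it alludes to might look like the following; this is an illustration rather than part of the commit: only `createRankingContext()` and `rankAndSort()` are new here, the retrieval stage reuses the existing `createEmbeddingContext()`/`getEmbeddingFor()` API, the cosine-similarity helper is hand-rolled, and the embedding model file name is a placeholder.

```typescript
import {fileURLToPath} from "url";
import path from "path";
import {getLlama} from "node-llama-cpp";

const __dirname = path.dirname(fileURLToPath(import.meta.url));
const llama = await getLlama();

// Stage 1 model: any embedding-capable GGUF (file name is a placeholder)
const embeddingModel = await llama.loadModel({
    modelPath: path.join(__dirname, "my-embedding-model.gguf")
});
const embeddingContext = await embeddingModel.createEmbeddingContext();

// Stage 2 model: the reranker used in the guide example
const rankerModel = await llama.loadModel({
    modelPath: path.join(__dirname, "bge-reranker-v2-m3-Q8_0.gguf")
});
const rankingContext = await rankerModel.createRankingContext();

const documents = [
    "The capital of France is Paris",
    "Mount Everest is the tallest mountain in the world",
    "I love eating pizza with extra cheese"
    // ...many more documents
];
const query = "Tell me a geographical fact";

// Hand-rolled cosine similarity between two embedding vectors
function cosineSimilarity(a: readonly number[], b: readonly number[]) {
    let dot = 0, normA = 0, normB = 0;
    for (let i = 0; i < a.length; i++) {
        dot += a[i]! * b[i]!;
        normA += a[i]! * a[i]!;
        normB += b[i]! * b[i]!;
    }
    return dot / (Math.sqrt(normA) * Math.sqrt(normB));
}

// Stage 1: embed the query and all documents, keep only the closest candidates
const queryEmbedding = await embeddingContext.getEmbeddingFor(query);
const candidates: {document: string, similarity: number}[] = [];
for (const document of documents) {
    const documentEmbedding = await embeddingContext.getEmbeddingFor(document);
    candidates.push({
        document,
        similarity: cosineSimilarity(queryEmbedding.vector, documentEmbedding.vector)
    });
}
const shortlist = candidates
    .sort((a, b) => b.similarity - a.similarity)
    .slice(0, 10)
    .map(({document}) => document);

// Stage 2: rerank only the shortlisted documents with the ranking context
const rankedDocuments = await rankingContext.rankAndSort(query, shortlist);
console.log("Best match:", rankedDocuments[0]!.document);
```

Reranking only a shortlist keeps the expensive inference pass bounded, regardless of how many documents are indexed.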

docs/index.md

Lines changed: 1 addition & 0 deletions
@@ -97,6 +97,7 @@ npx -y node-llama-cpp inspect gpu
 * [Remote GGUF reader](./api/functions/readGgufFileInfo.md)
 * [User input safety](./guide/llama-text.md#input-safety-in-node-llama-cpp)
 * [Token prediction](./guide/token-prediction.md)
+* [Reranking](./guide/embedding.md#reranking)
 
 </template>
 <template v-slot:simple-code>

llama/addon/AddonContext.cpp

Lines changed: 4 additions & 0 deletions
@@ -415,6 +415,10 @@ AddonContext::AddonContext(const Napi::CallbackInfo& info) : Napi::ObjectWrap<Ad
         context_params.embeddings = options.Get("embeddings").As<Napi::Boolean>().Value();
     }
 
+    if (options.Has("ranking") && options.Get("ranking").As<Napi::Boolean>().Value()) {
+        context_params.pooling_type = LLAMA_POOLING_TYPE_RANK;
+    }
+
     if (options.Has("flashAttention")) {
         context_params.flash_attn = options.Get("flashAttention").As<Napi::Boolean>().Value();
     }

src/bindings/AddonTypes.ts

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@ export type BindingModule = {
         flashAttention?: boolean,
         logitsAll?: boolean,
         embeddings?: boolean,
+        ranking?: boolean,
         threads?: number,
         performanceTracking?: boolean
     }): AddonContext

src/evaluator/LlamaContext/LlamaContext.ts

Lines changed: 3 additions & 1 deletion
@@ -80,7 +80,8 @@ export class LlamaContext {
             itemPrioritizationStrategy: batchingItemsPrioritizationStrategy = "maximumParallelism"
         } = {},
         performanceTracking = false,
-        _embeddings
+        _embeddings,
+        _ranking
     }: LlamaContextOptions & {
         sequences: number,
         contextSize: number,
@@ -121,6 +122,7 @@ export class LlamaContext {
            flashAttention: this._flashAttention,
            threads: this._idealThreads,
            embeddings: _embeddings,
+           ranking: _ranking,
            performanceTracking: this._performanceTracking
        }));
        this._batchingOptions = {

src/evaluator/LlamaContext/types.ts

Lines changed: 7 additions & 1 deletion
@@ -171,7 +171,13 @@ export type LlamaContextOptions = {
      * embedding mode only
      * @internal
      */
-    _embeddings?: boolean
+    _embeddings?: boolean,
+
+    /**
+     * ranking mode
+     * @internal
+     */
+    _ranking?: boolean
 };
 export type LlamaContextSequenceRepeatPenalty = {
     /** Tokens to lower the predication probability of to be the next predicted token */
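Taken together, the native and TypeScript changes above thread a single flag down the stack: the internal `_ranking` option in `LlamaContextOptions` is forwarded by `LlamaContext` to the addon as `ranking: true`, which `AddonContext.cpp` maps to `context_params.pooling_type = LLAMA_POOLING_TYPE_RANK`. How the new `LlamaRankingContext` sets that flag is not visible here (its file is among the 12 changed but not shown in this excerpt), so the following is only an assumed sketch of that wiring:

```typescript
import {getLlama} from "node-llama-cpp";

const llama = await getLlama();
const model = await llama.loadModel({modelPath: "bge-reranker-v2-m3-Q8_0.gguf"});

// Assumed wiring inside LlamaRankingContext._create() — illustration only; the real code
// lives in src/evaluator/LlamaRankingContext.ts, which is not included in this excerpt.
const context = await model.createContext({
    // Internal flag from LlamaContextOptions. LlamaContext forwards it to the native addon
    // as `ranking: true`, and AddonContext.cpp maps that to LLAMA_POOLING_TYPE_RANK.
    _ranking: true
});
```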

src/evaluator/LlamaModel/LlamaModel.ts

Lines changed: 11 additions & 0 deletions
@@ -18,6 +18,7 @@ import {LlamaEmbeddingContext, LlamaEmbeddingContextOptions} from "../LlamaEmbed
 import {GgufArchitectureType, GgufMetadata} from "../../gguf/types/GgufMetadataTypes.js";
 import {OverridesObject} from "../../utils/OverridesObject.js";
 import {maxRecentDetokenizerTokens} from "../../consts.js";
+import {LlamaRankingContext, LlamaRankingContextOptions} from "../LlamaRankingContext.js";
 import {TokenAttribute, TokenAttributes} from "./utils/TokenAttributes.js";
 import type {Llama} from "../../bindings/Llama.js";
 import type {BuiltinSpecialTokenValue} from "../../utils/LlamaText.js";
@@ -532,6 +533,16 @@ export class LlamaModel {
         return await LlamaEmbeddingContext._create({_model: this}, options);
     }
 
+    /**
+     * @see [Reranking Documents](https://node-llama-cpp.withcat.ai/guide/embedding#reranking) tutorial
+     */
+    public async createRankingContext(options: LlamaRankingContextOptions = {}) {
+        if (this._vocabOnly)
+            throw new Error("Model is loaded in vocabOnly mode, so no context can be created");
+
+        return await LlamaRankingContext._create({_model: this}, options);
+    }
+
     /**
      * Get warnings about the model file that would affect its usage.
      *
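Like `createEmbeddingContext()`, the new method refuses to run on a model loaded in vocabulary-only mode. A short sketch of that guard in use, assuming `vocabOnly` is the public `loadModel()` option behind the internal `_vocabOnly` check:

```typescript
import {fileURLToPath} from "url";
import path from "path";
import {getLlama} from "node-llama-cpp";

const __dirname = path.dirname(fileURLToPath(import.meta.url));
const llama = await getLlama();

// Loading only the vocabulary is enough for tokenization, but not for ranking
const vocabOnlyModel = await llama.loadModel({
    modelPath: path.join(__dirname, "bge-reranker-v2-m3-Q8_0.gguf"),
    vocabOnly: true // assumption: public option corresponding to the internal _vocabOnly flag
});

try {
    await vocabOnlyModel.createRankingContext();
} catch (error) {
    // "Model is loaded in vocabOnly mode, so no context can be created"
    console.error(error);
}
```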
