Skip to content

Commit d1b4416

Browse files
authored
fix: reranking probabilities (#412)
1 parent 5d07289 commit d1b4416

File tree

4 files changed

+71
-40
lines changed

4 files changed

+71
-40
lines changed

package-lock.json

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@
167167
"typedoc": "^0.27.6",
168168
"typedoc-plugin-markdown": "^4.4.1",
169169
"typedoc-plugin-mdn-links": "^4.0.7",
170-
"typedoc-vitepress-theme": "^1.1.1",
170+
"typedoc-vitepress-theme": "^1.1.2",
171171
"typescript": "^5.7.2",
172172
"typescript-eslint": "^8.19.1",
173173
"vite-node": "^2.1.8",

src/evaluator/LlamaRankingContext.ts

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ export class LlamaRankingContext {
7676

7777
/**
7878
* Get the ranking score for a document for a query.
79+
*
80+
* A ranking score is a number between 0 and 1 representing the probability that the document is relevant to the query.
81+
* @returns a ranking score between 0 and 1 representing the probability that the document is relevant to the query.
7982
*/
8083
public async rank(query: Token[] | string | LlamaText, document: Token[] | string | LlamaText) {
8184
if (this.model.tokens.bos == null || this.model.tokens.eos == null || this.model.tokens.sep == null)
@@ -96,6 +99,9 @@ export class LlamaRankingContext {
9699

97100
/**
98101
* Get the ranking scores for all the given documents for a query.
102+
*
103+
* A ranking score is a number between 0 and 1 representing the probability that the document is relevant to the query.
104+
* @returns an array of ranking scores between 0 and 1 representing the probability that the document is relevant to the query.
99105
*/
100106
public async rankAll(query: Token[] | string | LlamaText, documents: Array<Token[] | string | LlamaText>): Promise<number[]> {
101107
const resolvedTokens = documents.map((document) => this._getEvaluationInput(query, document));
@@ -120,9 +126,15 @@ export class LlamaRankingContext {
120126

121127
/**
122128
* Get the ranking scores for all the given documents for a query and sort them by score from highest to lowest.
129+
*
130+
* A ranking score is a number between 0 and 1 representing the probability that the document is relevant to the query.
123131
*/
124132
public async rankAndSort<const T extends string>(query: Token[] | string | LlamaText, documents: T[]): Promise<Array<{
125133
document: T,
134+
135+
/**
136+
* A ranking score is a number between 0 and 1 representing the probability that the document is relevant to the query.
137+
*/
126138
score: number
127139
}>> {
128140
const scores = await this.rankAll(query, documents);
@@ -190,7 +202,10 @@ export class LlamaRankingContext {
190202
if (embedding.length === 0)
191203
return 0;
192204

193-
return embedding[0]!;
205+
const logit = embedding[0]!;
206+
const probability = logitToSigmoid(logit);
207+
208+
return probability;
194209
});
195210
}
196211

@@ -249,3 +264,7 @@ function findLayer(tensorInfo: GgufTensorInfo[] | undefined, name: string, suffi
249264

250265
return undefined;
251266
}
267+
268+
function logitToSigmoid(logit: number) {
269+
return 1 / (1 + Math.exp(-logit));
270+
}

test/modelDependent/bgeReranker/rank.test.ts

Lines changed: 45 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -40,19 +40,19 @@ describe("bgeReranker", () => {
4040
const highestRankDocument = documents[highestRankIndex];
4141
expect(highestRankDocument).to.eql("Mount Everest is the tallest mountain in the world");
4242

43-
expect(simplifyRanks([highestRank])[0]).toMatchInlineSnapshot("-4");
43+
expect(simplifyRanks([highestRank])[0]).toMatchInlineSnapshot("0.01798620996209156");
4444
expect(simplifyRanks(ranks)).toMatchInlineSnapshot(`
4545
[
46-
-11,
47-
-11,
48-
-11,
49-
-5.6,
50-
-11,
51-
-4,
52-
-11,
53-
-11,
54-
-11,
55-
-11,
46+
0.00001670142184809518,
47+
0.00001670142184809518,
48+
0.00001670142184809518,
49+
0.003684239899435989,
50+
0.00001670142184809518,
51+
0.01798620996209156,
52+
0.00001670142184809518,
53+
0.00001670142184809518,
54+
0.00001670142184809518,
55+
0.00001670142184809518,
5656
]
5757
`);
5858
});
@@ -91,19 +91,19 @@ describe("bgeReranker", () => {
9191
const highestRankDocument = documents[highestRankIndex];
9292
expect(highestRankDocument).to.eql("Mount Everest is the tallest mountain in the world");
9393

94-
expect(simplifyRanks([highestRank])[0]).toMatchInlineSnapshot("-4");
94+
expect(simplifyRanks([highestRank])[0]).toMatchInlineSnapshot("0.01798620996209156");
9595
expect(simplifyRanks(ranks)).toMatchInlineSnapshot(`
9696
[
97-
-11,
98-
-11,
99-
-11,
100-
-5.6,
101-
-11,
102-
-4,
103-
-11,
104-
-11,
105-
-11,
106-
-11,
97+
0.00001670142184809518,
98+
0.00001670142184809518,
99+
0.00001670142184809518,
100+
0.003684239899435989,
101+
0.00001670142184809518,
102+
0.01798620996209156,
103+
0.00001670142184809518,
104+
0.00001670142184809518,
105+
0.00001670142184809518,
106+
0.00001670142184809518,
107107
]
108108
`);
109109
});
@@ -141,42 +141,42 @@ describe("bgeReranker", () => {
141141
expect(simplifySortedRanks([topDocument])[0]).toMatchInlineSnapshot(`
142142
{
143143
"document": "Mount Everest is the tallest mountain in the world",
144-
"score": -4,
144+
"score": 0.01798620996209156,
145145
}
146146
`);
147147
expect(simplifySortedRanks(rankedDocuments)).toMatchInlineSnapshot(`
148148
[
149149
{
150150
"document": "Mount Everest is the tallest mountain in the world",
151-
"score": -4,
151+
"score": 0.01798620996209156,
152152
},
153153
{
154154
"document": "The capital of France is Paris",
155-
"score": -5.6,
155+
"score": 0.003684239899435989,
156156
},
157157
{
158158
"document": "Not all the things that shine are made of gold",
159-
"score": -11,
159+
"score": 0.00001670142184809518,
160160
},
161161
{
162162
"document": "I love eating pizza with extra cheese",
163-
"score": -11,
163+
"score": 0.00001670142184809518,
164164
},
165165
{
166166
"document": "Dogs love to play fetch with their owners",
167-
"score": -11,
167+
"score": 0.00001670142184809518,
168168
},
169169
{
170170
"document": "The sky is clear and blue today",
171-
"score": -11,
171+
"score": 0.00001670142184809518,
172172
},
173173
{
174174
"document": "Cleaning the house is a good way to keep it tidy",
175-
"score": -11,
175+
"score": 0.00001670142184809518,
176176
},
177177
{
178178
"document": "A warm cup of tea is perfect for a cold winter day",
179-
"score": -11,
179+
"score": 0.00001670142184809518,
180180
},
181181
]
182182
`);
@@ -185,16 +185,28 @@ describe("bgeReranker", () => {
185185
});
186186

187187
function simplifyRanks<const T extends number[]>(ranks: T): T {
188-
return ranks.map((rank) => parseFloat(roundToPrecision(rank, 0.2).toFixed(1))) as T;
188+
return ranks.map((rank) => simplifyScore(rank)) as T;
189189
}
190190

191191
function simplifySortedRanks<const T extends {document: string, score: number}[]>(values: T): T {
192192
return values.map((item) => ({
193193
document: item.document,
194-
score: parseFloat(roundToPrecision(item.score, 0.2).toFixed(1))
194+
score: simplifyScore(item.score)
195195
})) as T;
196196
}
197197

198+
function simplifyScore(score: number) {
199+
return toSigmoid(parseFloat(roundToPrecision(toLogit(score), 0.2).toFixed(1)));
200+
}
201+
198202
function roundToPrecision(value: number, precision: number): number {
199203
return Math.round(value / precision) * precision;
200204
}
205+
206+
function toLogit(sigmoid: number) {
207+
return Math.log(sigmoid / (1 - sigmoid));
208+
}
209+
210+
function toSigmoid(logit: number) {
211+
return 1 / (1 + Math.exp(-logit));
212+
}

0 commit comments

Comments
 (0)