
Commit 00305f7

feat: Qwen3 Reranker support (#506)
* feat: Qwen3 Reranker support
* fix: handle HuggingFace rate limit response
* fix: adapt to `llama.cpp` breaking changes
1 parent eefe78c commit 00305f7

File tree

11 files changed: +281 -51 lines changed
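To put the changes below in context, here is a minimal usage sketch of the ranking API this commit extends. The `getLlama`, `loadModel`, and `createRankingContext` calls are assumed from the library's existing public API and the model path is made up; only the `rank()` side of this flow appears in the diffs below.

```ts
import {getLlama} from "node-llama-cpp";

// Hypothetical usage sketch — the model path is an assumption, not part of this commit.
const llama = await getLlama();
const model = await llama.loadModel({modelPath: "models/qwen3-reranker.gguf"});
const rankingContext = await model.createRankingContext();

// rank() resolves to a 0..1 score (see logitToSigmoid in LlamaRankingContext.ts below)
const score = await rankingContext.rank(
    "What is the capital of France?",
    "Paris is the capital and largest city of France."
);
console.log(score);
```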

src/bindings/utils/compileLLamaCpp.ts

Lines changed: 6 additions & 1 deletion
```diff
@@ -131,9 +131,14 @@ export async function compileLlamaCpp(buildOptions: BuildOptions, compileOptions
     if (!cmakeCustomOptions.has("GGML_CCACHE"))
         cmakeCustomOptions.set("GGML_CCACHE", "OFF");
 
-    if (!cmakeCustomOptions.has("LLAMA_CURL"))
+    if (!cmakeCustomOptions.has("LLAMA_CURL") || isCmakeValueOff(cmakeCustomOptions.get("LLAMA_CURL"))) {
         cmakeCustomOptions.set("LLAMA_CURL", "OFF");
 
+        // avoid linking to extra libraries that we don't use
+        if (!cmakeCustomOptions.has("LLAMA_OPENSSL"))
+            cmakeCustomOptions.set("LLAMA_OPENSSL", "OFF");
+    }
+
     if (buildOptions.platform === "win" && buildOptions.arch === "arm64" && !cmakeCustomOptions.has("GGML_OPENMP"))
         cmakeCustomOptions.set("GGML_OPENMP", "OFF");
```
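The new `LLAMA_CURL` guard relies on an `isCmakeValueOff` helper that is not part of this hunk. A minimal sketch of what such a check could look like, assuming it only needs to recognize CMake's usual false-like spellings:

```ts
// Hypothetical sketch — the helper that actually ships in this repo may differ.
function isCmakeValueOff(value: string | undefined): boolean {
    if (value == null)
        return false;

    // CMake treats these spellings (case-insensitively) as false
    return ["OFF", "NO", "FALSE", "0", "N", "IGNORE", "NOTFOUND", ""].includes(value.trim().toUpperCase());
}
```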

src/evaluator/LlamaRankingContext.ts

Lines changed: 77 additions & 39 deletions
```diff
@@ -1,20 +1,21 @@
-import {AsyncDisposeAggregator, EventRelay, withLock} from "lifecycle-utils";
+import {AsyncDisposeAggregator, EventRelay, splitText, withLock} from "lifecycle-utils";
 import {Token} from "../types.js";
 import {LlamaText} from "../utils/LlamaText.js";
 import {tokenizeInput} from "../utils/tokenizeInput.js";
+import {resolveBeginningTokenToPrepend, resolveEndTokenToAppend} from "../utils/tokenizerUtils.js";
+import {isRankingTemplateValid, parseRankingTemplate} from "../gguf/insights/GgufInsights.js";
 import type {LlamaModel} from "./LlamaModel/LlamaModel.js";
 import type {LlamaContext, LlamaContextSequence} from "./LlamaContext/LlamaContext.js";
-import type {GgufTensorInfo} from "../gguf/types/GgufTensorInfoTypes.js";
 
 export type LlamaRankingContextOptions = {
     /**
      * The number of tokens the model can see at once.
-     * - **`"auto"`** - adapt to the current VRAM state and attemp to set the context size as high as possible up to the size
+     * - **`"auto"`** - adapt to the current VRAM state and attempt to set the context size as high as possible up to the size
      * the model was trained on.
      * - **`number`** - set the context size to a specific number of tokens.
      * If there's not enough VRAM, an error will be thrown.
      * Use with caution.
-     * - **`{min?: number, max?: number}`** - adapt to the current VRAM state and attemp to set the context size as high as possible
+     * - **`{min?: number, max?: number}`** - adapt to the current VRAM state and attempt to set the context size as high as possible
      * up to the size the model was trained on, but at least `min` and at most `max`.
      *
      * Defaults to `"auto"`.
```
```diff
@@ -36,6 +37,22 @@ export type LlamaRankingContextOptions = {
     /** An abort signal to abort the context creation */
     createSignal?: AbortSignal,
 
+    /**
+     * The template to use for the ranking evaluation.
+     * If not provided, the model's template will be used by default.
+     *
+     * The template is tokenized with special tokens enabled, but the provided query and document are not.
+     *
+     * **<span v-pre>`{{query}}`</span>** is replaced with the query content.
+     *
+     * **<span v-pre>`{{document}}`</span>** is replaced with the document content.
+     *
+     * It's recommended to not set this option unless you know what you're doing.
+     *
+     * Defaults to the model's template.
+     */
+    template?: `${string}{{query}}${string}{{document}}${string}` | `${string}{{document}}${string}{{query}}${string}`,
+
     /**
      * Ignore insufficient memory errors and continue with the context creation.
      * Can cause the process to crash if there's not enough VRAM for the new context.
```
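As a usage illustration for the new option, a custom template could be passed when creating the ranking context. This sketch assumes a loaded `model` as in the example near the top of this page; the template string is made up and is not the model's builtin one:

```ts
// Hypothetical example — any string containing both placeholders satisfies the template type.
const rankingContext = await model.createRankingContext({
    template: "Query: {{query}}\nDocument: {{document}}\nRelevant:"
});
```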
```diff
@@ -50,17 +67,21 @@ export type LlamaRankingContextOptions = {
  */
 export class LlamaRankingContext {
     /** @internal */ private readonly _llamaContext: LlamaContext;
+    /** @internal */ private readonly _template: string | undefined;
     /** @internal */ private readonly _sequence: LlamaContextSequence;
     /** @internal */ private readonly _disposeAggregator = new AsyncDisposeAggregator();
 
     public readonly onDispose = new EventRelay<void>();
 
     private constructor({
-        _llamaContext
+        _llamaContext,
+        _template
     }: {
-        _llamaContext: LlamaContext
+        _llamaContext: LlamaContext,
+        _template: string | undefined
     }) {
         this._llamaContext = _llamaContext;
+        this._template = _template;
         this._sequence = this._llamaContext.getSequence();
 
         this._disposeAggregator.add(
@@ -81,9 +102,6 @@ export class LlamaRankingContext {
      * @returns a ranking score between 0 and 1 representing the probability that the document is relevant to the query.
      */
     public async rank(query: Token[] | string | LlamaText, document: Token[] | string | LlamaText) {
-        if (this.model.tokens.bos == null || this.model.tokens.eos == null || this.model.tokens.sep == null)
-            throw new Error("Computing rankings is not supported for this model.");
-
         const resolvedInput = this._getEvaluationInput(query, document);
 
         if (resolvedInput.length > this._llamaContext.contextSize)
```
```diff
@@ -159,7 +177,35 @@ export class LlamaRankingContext {
 
     /** @internal */
     private _getEvaluationInput(query: Token[] | string | LlamaText, document: Token[] | string | LlamaText) {
-        if (this.model.tokens.bos == null || this.model.tokens.eos == null || this.model.tokens.sep == null)
+        if (this._template != null) {
+            const resolvedInput = splitText(this._template, ["{{query}}", "{{document}}"])
+                .flatMap((item) => {
+                    if (typeof item === "string")
+                        return this._llamaContext.model.tokenize(item, true, "trimLeadingSpace");
+                    else if (item.separator === "{{query}}")
+                        return tokenizeInput(query, this._llamaContext.model.tokenizer, "trimLeadingSpace", false);
+                    else if (item.separator === "{{document}}")
+                        return tokenizeInput(document, this._llamaContext.model.tokenizer, "trimLeadingSpace", false);
+                    else
+                        void (item satisfies never);
+
+                    void (item satisfies never);
+                    return [];
+                });
+
+            const beginningTokens = resolveBeginningTokenToPrepend(this.model.vocabularyType, this.model.tokens);
+            const endToken = resolveEndTokenToAppend(this.model.vocabularyType, this.model.tokens);
+
+            if (beginningTokens != null && resolvedInput.at(0) !== beginningTokens)
+                resolvedInput.unshift(beginningTokens);
+
+            if (endToken != null && resolvedInput.at(-1) !== endToken)
+                resolvedInput.unshift(endToken);
+
+            return resolvedInput;
+        }
+
+        if (this.model.tokens.eos == null && this.model.tokens.sep == null)
             throw new Error("Computing rankings is not supported for this model.");
 
         const resolvedQuery = tokenizeInput(query, this._llamaContext.model.tokenizer, "trimLeadingSpace", false);
```
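`splitText` (imported above from `lifecycle-utils`) splits the template around the placeholder separators and yields plain strings interleaved with separator objects, which is what the `flatMap` branches on. A rough illustration of the shape this code relies on; the exact return type of `splitText` is assumed here:

```ts
import {splitText} from "lifecycle-utils";

// Hypothetical illustration of the consumed shape (not taken from the lifecycle-utils docs):
const parts = splitText("Query: {{query}}\nDocument: {{document}}", ["{{query}}", "{{document}}"]);
// parts ≈ ["Query: ", {separator: "{{query}}"}, "\nDocument: ", {separator: "{{document}}"}]
```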
```diff
@@ -169,12 +215,12 @@ export class LlamaRankingContext {
             return [];
 
         const resolvedInput = [
-            this.model.tokens.bos,
+            ...(this.model.tokens.bos == null ? [] : [this.model.tokens.bos]),
             ...resolvedQuery,
-            this.model.tokens.eos,
-            this.model.tokens.sep,
+            ...(this.model.tokens.eos == null ? [] : [this.model.tokens.eos]),
+            ...(this.model.tokens.sep == null ? [] : [this.model.tokens.sep]),
            ...resolvedDocument,
-            this.model.tokens.eos
+            ...(this.model.tokens.eos == null ? [] : [this.model.tokens.eos])
         ];
 
         return resolvedInput;
```
```diff
@@ -218,24 +264,27 @@ export class LlamaRankingContext {
         batchSize,
         threads = 6,
         createSignal,
+        template,
         ignoreMemorySafetyChecks
     }: LlamaRankingContextOptions) {
-        const tensorInfo = _model.fileInfo.tensorInfo;
-
-        if (_model.tokens.bos == null || _model.tokens.eos == null || _model.tokens.sep == null)
-            throw new Error("Computing rankings is not supported for this model.");
-
-        // source: `append_pooling` in `llama.cpp`
-        if (findLayer(tensorInfo, "cls", "weight") == null || findLayer(tensorInfo, "cls", "bias") == null)
-            throw new Error("Computing rankings is not supported for this model.");
-
-        // source: `append_pooling` in `llama.cpp`
-        if (findLayer(tensorInfo, "cls.output", "weight") != null && findLayer(tensorInfo, "cls.output", "bias") == null)
-            throw new Error("Computing rankings is not supported for this model.");
+        const resolvedTemplate = template ?? parseRankingTemplate(_model.fileInfo.metadata?.tokenizer?.["chat_template.rerank"]);
+
+        if (_model.tokens.eos == null && _model.tokens.sep == null) {
+            if (!isRankingTemplateValid(resolvedTemplate)) {
+                if (resolvedTemplate === _model.fileInfo.metadata?.tokenizer?.["chat_template.rerank"])
+                    throw new Error("The model's builtin template is invalid. It must contain both {query} and {document} placeholders.");
+                else
+                    throw new Error("The provided template is invalid. It must contain both {{query}} and {{document}} placeholders.");
+            } else if (resolvedTemplate == null)
+                throw new Error("Computing rankings is not supported for this model.");
+        }
 
         if (_model.fileInsights.hasEncoder && _model.fileInsights.hasDecoder)
             throw new Error("Computing rankings is not supported for encoder-decoder models.");
 
+        if (!_model.fileInsights.supportsRanking)
+            throw new Error("Computing rankings is not supported for this model.");
+
         const llamaContext = await _model.createContext({
             contextSize,
             batchSize,
```
```diff
@@ -247,23 +296,12 @@ export class LlamaRankingContext {
         });
 
         return new LlamaRankingContext({
-            _llamaContext: llamaContext
+            _llamaContext: llamaContext,
+            _template: resolvedTemplate
         });
     }
 }
 
-function findLayer(tensorInfo: GgufTensorInfo[] | undefined, name: string, suffix: string) {
-    if (tensorInfo == null)
-        return undefined;
-
-    for (const tensor of tensorInfo) {
-        if (tensor.name === name + "." + suffix)
-            return tensor;
-    }
-
-    return undefined;
-}
-
 function logitToSigmoid(logit: number) {
     return 1 / (1 + Math.exp(-logit));
 }
```
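For reference, the score returned by `rank()` is the raw ranking logit squashed through this sigmoid, so a logit of 0 maps to 0.5 and larger positive logits approach 1:

```ts
// Illustrative values of the module-private logitToSigmoid helper above:
logitToSigmoid(0);   // 0.5
logitToSigmoid(4);   // ≈ 0.982
logitToSigmoid(-4);  // ≈ 0.018
```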

src/gguf/insights/GgufInsights.ts

Lines changed: 45 additions & 1 deletion
```diff
@@ -6,6 +6,7 @@ import {GgufTensorInfo} from "../types/GgufTensorInfoTypes.js";
 import {GgufArchitectureType} from "../types/GgufMetadataTypes.js";
 import {getReadablePath} from "../../cli/utils/getReadablePath.js";
 import {GgufInsightsConfigurationResolver} from "./GgufInsightsConfigurationResolver.js";
+import {GgufInsightsTokens} from "./GgufInsightsTokens.js";
 
 export type GgufInsightsResourceRequirements = {
     cpuRam: number,
@@ -16,15 +17,18 @@ export class GgufInsights {
     /** @internal */ public readonly _llama: Llama;
     /** @internal */ private readonly _modelSize: number;
     /** @internal */ private _totalFileLayers: number | null = null;
-    /** @internal */ private readonly _ggufFileInfo: GgufFileInfo;
+    /** @internal */ private _supportsRanking?: boolean;
+    /** @internal */ public readonly _ggufFileInfo: GgufFileInfo;
     /** @internal */ private readonly _configurationResolver: GgufInsightsConfigurationResolver;
+    /** @internal */ private readonly _tokens: GgufInsightsTokens;
 
     private constructor(ggufFileInfo: GgufFileInfo, llama: Llama) {
         this._llama = llama;
         this._ggufFileInfo = ggufFileInfo;
 
         this._modelSize = calculateTensorsSize(ggufFileInfo.fullTensorInfo ?? [], llama, true, true);
         this._configurationResolver = GgufInsightsConfigurationResolver._create(this);
+        this._tokens = GgufInsightsTokens._create(this);
     }
 
     /**
@@ -60,6 +64,10 @@ export class GgufInsights {
         return this._configurationResolver;
     }
 
+    public get tokens() {
+        return this._tokens;
+    }
+
     /** The context size the model was trained on */
     public get trainContextSize() {
         return this._ggufFileInfo.architectureMetadata.context_length;
@@ -132,6 +140,29 @@ export class GgufInsights {
         return false;
     }
 
+    public get supportsRanking() {
+        if (this._supportsRanking != null)
+            return this._supportsRanking;
+
+        const layers = this._ggufFileInfo.fullTensorInfo ?? [];
+        for (let i = layers.length - 1; i >= 0; i--) {
+            const tensor = layers[i];
+            if (tensor == null)
+                continue;
+
+            if (tensor.name === "cls.weight" || tensor.name === "cls.output.weight") {
+                this._supportsRanking = this.tokens.sepToken != null || this.tokens.eosToken != null ||
+                    isRankingTemplateValid(parseRankingTemplate(this._ggufFileInfo.metadata?.tokenizer?.["chat_template.rerank"]));
+                this._supportsRanking &&= !(this.hasEncoder && this.hasDecoder); // encoder-decoder models are not supported
+
+                return this._supportsRanking;
+            }
+        }
+
+        this._supportsRanking = false;
+        return this._supportsRanking;
+    }
+
     /**
      * The size of the SWA (Sliding Window Attention).
      *
```
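Downstream code can use this getter as a capability check before building a ranking context. A brief sketch, assuming a loaded `model` and the `fileInsights` property name used by `LlamaRankingContext` above:

```ts
// Hedged sketch — `model` is assumed to be loaded as in the example near the top of this page.
if (!model.fileInsights.supportsRanking)
    throw new Error("This model cannot be used for reranking.");
```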
```diff
@@ -787,3 +818,16 @@ function getSwaPatternForArchitecture(architecture?: GgufArchitectureType): number
 
     return 1;
 }
+
+export function parseRankingTemplate(template: string | undefined | null): string | undefined {
+    if (template == null)
+        return undefined;
+
+    return template
+        .replaceAll("{query}", "{{query}}")
+        .replaceAll("{document}", "{{document}}");
+}
+
+export function isRankingTemplateValid(template: string | undefined | null): boolean {
+    return template != null && template.includes("{{query}}") && template.includes("{{document}}");
+}
```
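To illustrate the two helpers (the template strings below are made up for the example; builtin rerank templates in GGUF metadata use single-brace placeholders, which `parseRankingTemplate` normalizes to the double-brace form used elsewhere):

```ts
parseRankingTemplate("Query: {query}\nDocument: {document}\nRelevant:");
// => "Query: {{query}}\nDocument: {{document}}\nRelevant:"

isRankingTemplateValid("Query: {{query}}\nDocument: {{document}}"); // => true
isRankingTemplateValid("Query: {{query}} only");                    // => false
```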
src/gguf/insights/GgufInsightsTokens.ts

Lines changed: 51 additions & 0 deletions
```diff
@@ -0,0 +1,51 @@
+/* eslint @stylistic/max-statements-per-line: ["warn", {"ignoredNodes": ["BreakStatement"]}] */
+import type {GgufInsights} from "./GgufInsights.js";
+
+export class GgufInsightsTokens {
+    /** @internal */ private readonly _ggufInsights: GgufInsights;
+
+    private constructor(ggufInsights: GgufInsights) {
+        this._ggufInsights = ggufInsights;
+    }
+
+    public get sepToken(): number | null {
+        const tokenizerModel = this._ggufInsights._ggufFileInfo?.metadata?.tokenizer?.ggml?.model;
+        const totalTokens = this._ggufInsights._ggufFileInfo?.metadata?.tokenizer?.ggml?.tokens?.length;
+
+        let sepTokenId = this._ggufInsights._ggufFileInfo?.metadata?.tokenizer?.ggml?.["seperator_token_id"];
+        if (sepTokenId == null && tokenizerModel === "bert") {
+            sepTokenId = 102; // source: `llama_vocab::impl::load` in `llama-vocab.cpp`
+        }
+
+        if (totalTokens != null && sepTokenId != null && sepTokenId >= totalTokens)
+            return null;
+
+        return sepTokenId ?? null;
+    }
+
+    public get eosToken(): number | null {
+        const tokenizerModel = this._ggufInsights._ggufFileInfo?.metadata?.tokenizer?.ggml?.model;
+        const totalTokens = this._ggufInsights._ggufFileInfo?.metadata?.tokenizer?.ggml?.tokens?.length;
+
+        const eosTokenId = this._ggufInsights._ggufFileInfo?.metadata?.tokenizer?.ggml?.["eos_token_id"];
+        if (eosTokenId != null && totalTokens != null && eosTokenId < totalTokens)
+            return eosTokenId;
+
+        switch (tokenizerModel) {
+            case "no_vocab": return null;
+            case "none": return null;
+            case "bert": return null;
+            case "rwkv": return null;
+            case "llama": return 2;
+            case "gpt2": return 11;
+            case "t5": return 1;
+            case "plamo2": return 2;
+        }
+        return 2; // source: `llama_vocab::impl::load` in `llama-vocab.cpp`
+    }
+
+    /** @internal */
+    public static _create(ggufInsights: GgufInsights) {
+        return new GgufInsightsTokens(ggufInsights);
+    }
+}
```

src/gguf/types/GgufMetadataTypes.ts

Lines changed: 4 additions & 3 deletions
```diff
@@ -263,7 +263,7 @@ export const enum GgufMetadataTokenizerTokenType {
 
 export type GgufMetadataTokenizer = {
     readonly ggml: {
-        readonly model: "no_vocab" | "llama" | "gpt2" | "bert" | string,
+        readonly model: "no_vocab" | "none" | "llama" | "gpt2" | "bert" | "rwkv" | "t5" | "plamo2" | string,
         readonly pre?: "default" | "llama3" | "llama-v3" | "llama-bpe" | "deepseek-llm" | "deepseek-coder" | "falcon" | "falcon3" |
             "pixtral" | "mpt" | "starcoder" | "gpt-2" | "phi-2" | "jina-es" | "jina-de" | "jina-v1-en" | "jina-v2-es" | "jina-v2-de" |
             "jina-v2-code" | "refact" | "command-r" | "qwen2" | "stablelm2" | "olmo" | "dbrx" | "smaug-bpe" | "poro-chat" | "chatglm-bpe" |
@@ -279,7 +279,7 @@ export type GgufMetadataTokenizer = {
         readonly eot_token_id?: number,
         readonly eom_token_id?: number,
         readonly unknown_token_id?: number,
-        readonly separator_token_id?: number,
+        readonly seperator_token_id?: number,
         readonly padding_token_id?: number,
         readonly cls_token_id?: number,
         readonly mask_token_id?: number,
@@ -304,7 +304,8 @@ export type GgufMetadataTokenizer = {
     readonly huggingface?: {
         readonly json?: string
    },
-    readonly chat_template?: string
+    readonly chat_template?: string,
+    readonly "chat_template.rerank"?: string
 };
 
 export const enum GgufMetadataArchitecturePoolingType {
```
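The switch from `separator_token_id` to `seperator_token_id` appears intentional rather than a typo: it seems to match the (misspelled) `tokenizer.ggml.seperator_token_id` key that llama.cpp's GGUF writers emit, and it is the key that `GgufInsightsTokens.sepToken` reads above.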
