-import {AsyncDisposeAggregator, EventRelay, withLock} from "lifecycle-utils";
+import {AsyncDisposeAggregator, EventRelay, splitText, withLock} from "lifecycle-utils";
import {Token} from "../types.js";
import {LlamaText} from "../utils/LlamaText.js";
import {tokenizeInput} from "../utils/tokenizeInput.js";
+import {resolveBeginningTokenToPrepend, resolveEndTokenToAppend} from "../utils/tokenizerUtils.js";
+import {isRankingTemplateValid, parseRankingTemplate} from "../gguf/insights/GgufInsights.js";
import type {LlamaModel} from "./LlamaModel/LlamaModel.js";
import type {LlamaContext, LlamaContextSequence} from "./LlamaContext/LlamaContext.js";
-import type {GgufTensorInfo} from "../gguf/types/GgufTensorInfoTypes.js";

export type LlamaRankingContextOptions = {
    /**
     * The number of tokens the model can see at once.
-     * - **`"auto"`** - adapt to the current VRAM state and attemp to set the context size as high as possible up to the size
+     * - **`"auto"`** - adapt to the current VRAM state and attempt to set the context size as high as possible up to the size
     * the model was trained on.
     * - **`number`** - set the context size to a specific number of tokens.
     *   If there's not enough VRAM, an error will be thrown.
     *   Use with caution.
-     * - **`{min?: number, max?: number}`** - adapt to the current VRAM state and attemp to set the context size as high as possible
+     * - **`{min?: number, max?: number}`** - adapt to the current VRAM state and attempt to set the context size as high as possible
     * up to the size the model was trained on, but at least `min` and at most `max`.
     *
     * Defaults to `"auto"`.
@@ -36,6 +37,22 @@ export type LlamaRankingContextOptions = {
    /** An abort signal to abort the context creation */
    createSignal?: AbortSignal,

+    /**
+     * The template to use for the ranking evaluation.
+     * If not provided, the model's template will be used by default.
+     *
+     * The template is tokenized with special tokens enabled, but the provided query and document are not.
+     *
+     * **<span v-pre>`{{query}}`</span>** is replaced with the query content.
+     *
+     * **<span v-pre>`{{document}}`</span>** is replaced with the document content.
+     *
+     * It's recommended to not set this option unless you know what you're doing.
+     *
+     * Defaults to the model's template.
+     */
+    template?: `${string}{{query}}${string}{{document}}${string}` | `${string}{{document}}${string}{{query}}${string}`,
+
    /**
     * Ignore insufficient memory errors and continue with the context creation.
     * Can cause the process to crash if there's not enough VRAM for the new context.
@@ -50,17 +67,21 @@ export type LlamaRankingContextOptions = {
 */
export class LlamaRankingContext {
    /** @internal */ private readonly _llamaContext: LlamaContext;
+    /** @internal */ private readonly _template: string | undefined;
    /** @internal */ private readonly _sequence: LlamaContextSequence;
    /** @internal */ private readonly _disposeAggregator = new AsyncDisposeAggregator();

    public readonly onDispose = new EventRelay<void>();

    private constructor({
-        _llamaContext
+        _llamaContext,
+        _template
    }: {
-        _llamaContext: LlamaContext
+        _llamaContext: LlamaContext,
+        _template: string | undefined
    }) {
        this._llamaContext = _llamaContext;
+        this._template = _template;
        this._sequence = this._llamaContext.getSequence();

        this._disposeAggregator.add(
@@ -81,9 +102,6 @@ export class LlamaRankingContext {
     * @returns a ranking score between 0 and 1 representing the probability that the document is relevant to the query.
     */
    public async rank(query: Token[] | string | LlamaText, document: Token[] | string | LlamaText) {
-        if (this.model.tokens.bos == null || this.model.tokens.eos == null || this.model.tokens.sep == null)
-            throw new Error("Computing rankings is not supported for this model.");
-
        const resolvedInput = this._getEvaluationInput(query, document);

        if (resolvedInput.length > this._llamaContext.contextSize)
@@ -159,7 +177,35 @@ export class LlamaRankingContext {

    /** @internal */
    private _getEvaluationInput(query: Token[] | string | LlamaText, document: Token[] | string | LlamaText) {
-        if (this.model.tokens.bos == null || this.model.tokens.eos == null || this.model.tokens.sep == null)
+        if (this._template != null) {
+            const resolvedInput = splitText(this._template, ["{{query}}", "{{document}}"])
+                .flatMap((item) => {
+                    if (typeof item === "string")
+                        return this._llamaContext.model.tokenize(item, true, "trimLeadingSpace");
+                    else if (item.separator === "{{query}}")
+                        return tokenizeInput(query, this._llamaContext.model.tokenizer, "trimLeadingSpace", false);
+                    else if (item.separator === "{{document}}")
+                        return tokenizeInput(document, this._llamaContext.model.tokenizer, "trimLeadingSpace", false);
+                    else
+                        void (item satisfies never);
+
+                    void (item satisfies never);
+                    return [];
+                });
+
+            const beginningTokens = resolveBeginningTokenToPrepend(this.model.vocabularyType, this.model.tokens);
+            const endToken = resolveEndTokenToAppend(this.model.vocabularyType, this.model.tokens);
+
+            if (beginningTokens != null && resolvedInput.at(0) !== beginningTokens)
+                resolvedInput.unshift(beginningTokens);
+
+            if (endToken != null && resolvedInput.at(-1) !== endToken)
+                resolvedInput.push(endToken);
+
+            return resolvedInput;
+        }
+
+        if (this.model.tokens.eos == null && this.model.tokens.sep == null)
            throw new Error("Computing rankings is not supported for this model.");

        const resolvedQuery = tokenizeInput(query, this._llamaContext.model.tokenizer, "trimLeadingSpace", false);
@@ -169,12 +215,12 @@ export class LlamaRankingContext {
            return [];

        const resolvedInput = [
-            this.model.tokens.bos,
+            ...(this.model.tokens.bos == null ? [] : [this.model.tokens.bos]),
            ...resolvedQuery,
-            this.model.tokens.eos,
-            this.model.tokens.sep,
+            ...(this.model.tokens.eos == null ? [] : [this.model.tokens.eos]),
+            ...(this.model.tokens.sep == null ? [] : [this.model.tokens.sep]),
            ...resolvedDocument,
-            this.model.tokens.eos
+            ...(this.model.tokens.eos == null ? [] : [this.model.tokens.eos])
        ];

        return resolvedInput;
@@ -218,24 +264,27 @@ export class LlamaRankingContext {
        batchSize,
        threads = 6,
        createSignal,
+        template,
        ignoreMemorySafetyChecks
    }: LlamaRankingContextOptions) {
-        const tensorInfo = _model.fileInfo.tensorInfo;
-
-        if (_model.tokens.bos == null || _model.tokens.eos == null || _model.tokens.sep == null)
-            throw new Error("Computing rankings is not supported for this model.");
-
-        // source: `append_pooling` in `llama.cpp`
-        if (findLayer(tensorInfo, "cls", "weight") == null || findLayer(tensorInfo, "cls", "bias") == null)
-            throw new Error("Computing rankings is not supported for this model.");
-
-        // source: `append_pooling` in `llama.cpp`
-        if (findLayer(tensorInfo, "cls.output", "weight") != null && findLayer(tensorInfo, "cls.output", "bias") == null)
-            throw new Error("Computing rankings is not supported for this model.");
+        const resolvedTemplate = template ?? parseRankingTemplate(_model.fileInfo.metadata?.tokenizer?.["chat_template.rerank"]);
+
+        if (_model.tokens.eos == null && _model.tokens.sep == null) {
+            if (!isRankingTemplateValid(resolvedTemplate)) {
+                if (resolvedTemplate === _model.fileInfo.metadata?.tokenizer?.["chat_template.rerank"])
+                    throw new Error("The model's builtin template is invalid. It must contain both {query} and {document} placeholders.");
+                else
+                    throw new Error("The provided template is invalid. It must contain both {{query}} and {{document}} placeholders.");
+            } else if (resolvedTemplate == null)
+                throw new Error("Computing rankings is not supported for this model.");
+        }

        if (_model.fileInsights.hasEncoder && _model.fileInsights.hasDecoder)
            throw new Error("Computing rankings is not supported for encoder-decoder models.");

+        if (!_model.fileInsights.supportsRanking)
+            throw new Error("Computing rankings is not supported for this model.");
+
        const llamaContext = await _model.createContext({
            contextSize,
            batchSize,
@@ -247,23 +296,12 @@ export class LlamaRankingContext {
        });

        return new LlamaRankingContext({
-            _llamaContext: llamaContext
+            _llamaContext: llamaContext,
+            _template: resolvedTemplate
        });
    }
}

-function findLayer(tensorInfo: GgufTensorInfo[] | undefined, name: string, suffix: string) {
-    if (tensorInfo == null)
-        return undefined;
-
-    for (const tensor of tensorInfo) {
-        if (tensor.name === name + "." + suffix)
-            return tensor;
-    }
-
-    return undefined;
-}
-
function logitToSigmoid(logit: number) {
    return 1 / (1 + Math.exp(-logit));
}
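
A minimal usage sketch of the new `template` option (assuming node-llama-cpp's `getLlama` / `loadModel` / `createRankingContext` API; the model path and the template string below are illustrative placeholders, not values taken from this commit):

import {getLlama} from "node-llama-cpp";

const llama = await getLlama();
const model = await llama.loadModel({modelPath: "path/to/reranker.gguf"}); // hypothetical model path

// `template` is optional; when omitted, the model's builtin
// `tokenizer.chat_template.rerank` metadata (if present) is parsed and used instead.
const rankingContext = await model.createRankingContext({
    template: "<query>{{query}}<document>{{document}}" // illustrative template
});

// `rank()` resolves to a score between 0 and 1 (a sigmoid over the model's ranking logit).
const score = await rankingContext.rank(
    "What is the capital of France?",
    "Paris is the capital and largest city of France."
);
console.log(score);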