-import {AsyncDisposeAggregator, EventRelay, withLock} from "lifecycle-utils";
+import {AsyncDisposeAggregator, EventRelay, splitText, withLock} from "lifecycle-utils";
 import {Token} from "../types.js";
 import {LlamaText} from "../utils/LlamaText.js";
 import {tokenizeInput} from "../utils/tokenizeInput.js";
+import {resolveBeginningTokenToPrepend, resolveEndTokenToAppend} from "../utils/tokenizerUtils.js";
+import {isRankingTemplateValid, parseRankingTemplate} from "../gguf/insights/GgufInsights.js";
 import type {LlamaModel} from "./LlamaModel/LlamaModel.js";
 import type {LlamaContext, LlamaContextSequence} from "./LlamaContext/LlamaContext.js";
-import type {GgufTensorInfo} from "../gguf/types/GgufTensorInfoTypes.js";
 
 export type LlamaRankingContextOptions = {
     /**
      * The number of tokens the model can see at once.
-     * - **`"auto"`** - adapt to the current VRAM state and attemp to set the context size as high as possible up to the size
+     * - **`"auto"`** - adapt to the current VRAM state and attempt to set the context size as high as possible up to the size
      * the model was trained on.
      * - **`number`** - set the context size to a specific number of tokens.
      *   If there's not enough VRAM, an error will be thrown.
      *   Use with caution.
-     * - **`{min?: number, max?: number}`** - adapt to the current VRAM state and attemp to set the context size as high as possible
+     * - **`{min?: number, max?: number}`** - adapt to the current VRAM state and attempt to set the context size as high as possible
      * up to the size the model was trained on, but at least `min` and at most `max`.
      *
      * Defaults to `"auto"`.
@@ -36,6 +37,22 @@ export type LlamaRankingContextOptions = {
     /** An abort signal to abort the context creation */
     createSignal?: AbortSignal,
 
+    /**
+     * The template to use for the ranking evaluation.
+     * If not provided, the model's template will be used by default.
+     *
+     * The template is tokenized with special tokens enabled, but the provided query and document are not.
+     *
+     * **<span v-pre>`{{query}}`</span>** is replaced with the query content.
+     *
+     * **<span v-pre>`{{document}}`</span>** is replaced with the document content.
+     *
+     * It's recommended not to set this option unless you know what you're doing.
+     *
+     * Defaults to the model's template.
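+     *
+     * For example (an illustrative template, not one taken from a specific model):
+     * `"Query: {{query}}\nDocument: {{document}}"`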
+     */
+    template?: `${string}{{query}}${string}{{document}}${string}` | `${string}{{document}}${string}{{query}}${string}`,
+
     /**
      * Ignore insufficient memory errors and continue with the context creation.
      * Can cause the process to crash if there's not enough VRAM for the new context.
@@ -50,17 +67,21 @@ export type LlamaRankingContextOptions = {
  */
 export class LlamaRankingContext {
     /** @internal */ private readonly _llamaContext: LlamaContext;
+    /** @internal */ private readonly _template: string | undefined;
     /** @internal */ private readonly _sequence: LlamaContextSequence;
     /** @internal */ private readonly _disposeAggregator = new AsyncDisposeAggregator();
 
     public readonly onDispose = new EventRelay<void>();
 
     private constructor({
-        _llamaContext
+        _llamaContext,
+        _template
     }: {
-        _llamaContext: LlamaContext
+        _llamaContext: LlamaContext,
+        _template: string | undefined
     }) {
         this._llamaContext = _llamaContext;
+        this._template = _template;
         this._sequence = this._llamaContext.getSequence();
 
         this._disposeAggregator.add(
@@ -81,9 +102,6 @@ export class LlamaRankingContext {
      * @returns a ranking score between 0 and 1 representing the probability that the document is relevant to the query.
      */
     public async rank(query: Token[] | string | LlamaText, document: Token[] | string | LlamaText) {
-        if (this.model.tokens.bos == null || this.model.tokens.eos == null || this.model.tokens.sep == null)
-            throw new Error("Computing rankings is not supported for this model.");
-
         const resolvedInput = this._getEvaluationInput(query, document);
 
         if (resolvedInput.length > this._llamaContext.contextSize)
@@ -159,7 +177,35 @@ export class LlamaRankingContext {
 
     /** @internal */
     private _getEvaluationInput(query: Token[] | string | LlamaText, document: Token[] | string | LlamaText) {
-        if (this.model.tokens.bos == null || this.model.tokens.eos == null || this.model.tokens.sep == null)
+        if (this._template != null) {
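+            // split the template on the {{query}} and {{document}} placeholders; tokenize the template's literal text
+            // with special tokens enabled, and the provided query and document without them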
+            const resolvedInput = splitText(this._template, ["{{query}}", "{{document}}"])
+                .flatMap((item) => {
+                    if (typeof item === "string")
+                        return this._llamaContext.model.tokenize(item, true, "trimLeadingSpace");
+                    else if (item.separator === "{{query}}") {
+                        return tokenizeInput(query, this._llamaContext.model.tokenizer, "trimLeadingSpace", false);
+                    } else if (item.separator === "{{document}}") {
+                        return tokenizeInput(document, this._llamaContext.model.tokenizer, "trimLeadingSpace", false);
+                    } else
+                        void (item satisfies never);
+
+                    return [];
+                });
+
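+            // prepend the model's beginning token and append its end token, unless the template already includes them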
+            const beginningTokens = resolveBeginningTokenToPrepend(this.model.vocabularyType, this.model.tokens);
+            const endToken = resolveEndTokenToAppend(this.model.vocabularyType, this.model.tokens);
+
+            if (beginningTokens != null && resolvedInput.at(0) !== beginningTokens)
+                resolvedInput.unshift(beginningTokens);
+
+            if (endToken != null && resolvedInput.at(-1) !== endToken)
+                resolvedInput.push(endToken);
+
+            return resolvedInput;
+        }
+
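+        // without a template, an EOS or SEP token is needed to separate the query from the document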
+        if (this.model.tokens.eos == null && this.model.tokens.sep == null)
             throw new Error("Computing rankings is not supported for this model.");
 
         const resolvedQuery = tokenizeInput(query, this._llamaContext.model.tokenizer, "trimLeadingSpace", false);
@@ -169,12 +215,12 @@ export class LlamaRankingContext {
             return [];
 
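+        // no template: use the classic reranker input layout of [BOS]query[EOS][SEP]document[EOS],
+        // skipping any of these tokens the model does not define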
         const resolvedInput = [
-            this.model.tokens.bos,
+            ...(this.model.tokens.bos == null ? [] : [this.model.tokens.bos]),
             ...resolvedQuery,
-            this.model.tokens.eos,
-            this.model.tokens.sep,
+            ...(this.model.tokens.eos == null ? [] : [this.model.tokens.eos]),
+            ...(this.model.tokens.sep == null ? [] : [this.model.tokens.sep]),
             ...resolvedDocument,
-            this.model.tokens.eos
+            ...(this.model.tokens.eos == null ? [] : [this.model.tokens.eos])
         ];
 
         return resolvedInput;
@@ -218,24 +264,27 @@ export class LlamaRankingContext {
         batchSize,
         threads = 6,
         createSignal,
+        template,
         ignoreMemorySafetyChecks
     }: LlamaRankingContextOptions) {
-        const tensorInfo = _model.fileInfo.tensorInfo;
-
-        if (_model.tokens.bos == null || _model.tokens.eos == null || _model.tokens.sep == null)
-            throw new Error("Computing rankings is not supported for this model.");
-
-        // source: `append_pooling` in `llama.cpp`
-        if (findLayer(tensorInfo, "cls", "weight") == null || findLayer(tensorInfo, "cls", "bias") == null)
-            throw new Error("Computing rankings is not supported for this model.");
-
-        // source: `append_pooling` in `llama.cpp`
-        if (findLayer(tensorInfo, "cls.output", "weight") != null && findLayer(tensorInfo, "cls.output", "bias") == null)
-            throw new Error("Computing rankings is not supported for this model.");
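+        // prefer an explicitly provided template; otherwise fall back to the model's builtin "chat_template.rerank" metadata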
+        const resolvedTemplate = template ?? parseRankingTemplate(_model.fileInfo.metadata?.tokenizer?.["chat_template.rerank"]);
+
+        if (_model.tokens.eos == null && _model.tokens.sep == null) {
+            if (!isRankingTemplateValid(resolvedTemplate)) {
+                if (resolvedTemplate === _model.fileInfo.metadata?.tokenizer?.["chat_template.rerank"])
+                    throw new Error("The model's builtin template is invalid. It must contain both {query} and {document} placeholders.");
+                else
+                    throw new Error("The provided template is invalid. It must contain both {{query}} and {{document}} placeholders.");
+            } else if (resolvedTemplate == null)
+                throw new Error("Computing rankings is not supported for this model.");
+        }
 
         if (_model.fileInsights.hasEncoder && _model.fileInsights.hasDecoder)
             throw new Error("Computing rankings is not supported for encoder-decoder models.");
 
+        if (!_model.fileInsights.supportsRanking)
+            throw new Error("Computing rankings is not supported for this model.");
+
         const llamaContext = await _model.createContext({
             contextSize,
             batchSize,
@@ -247,23 +296,12 @@ export class LlamaRankingContext {
         });
 
         return new LlamaRankingContext({
-            _llamaContext: llamaContext
+            _llamaContext: llamaContext,
+            _template: resolvedTemplate
         });
     }
 }
 
-function findLayer(tensorInfo: GgufTensorInfo[] | undefined, name: string, suffix: string) {
-    if (tensorInfo == null)
-        return undefined;
-
-    for (const tensor of tensorInfo) {
-        if (tensor.name === name + "." + suffix)
-            return tensor;
-    }
-
-    return undefined;
-}
-
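+// convert a raw ranking logit into a 0-1 score using the sigmoid function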
 function logitToSigmoid(logit: number) {
     return 1 / (1 + Math.exp(-logit));
 }
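
A minimal usage sketch of the new `template` option (the model path and the template string below are illustrative placeholders, not values from this change):

```ts
import {getLlama} from "node-llama-cpp";

const llama = await getLlama();
// any reranking model in GGUF format; the path is a placeholder
const model = await llama.loadModel({modelPath: "path/to/reranker-model.gguf"});

// the template's literal text is tokenized with special tokens enabled,
// while the query and document contents are not
const rankingContext = await model.createRankingContext({
    template: "Query: {{query}}\nDocument: {{document}}"
});

const score = await rankingContext.rank(
    "What is the tallest mountain on Earth?",
    "Mount Everest is Earth's highest mountain above sea level."
);
console.log(score); // a relevance score between 0 and 1
```

If `template` is omitted, the context falls back to the model's builtin `chat_template.rerank` metadata, and then to the `[BOS]query[EOS][SEP]document[EOS]` layout.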