diff --git a/src/bindings/utils/compileLLamaCpp.ts b/src/bindings/utils/compileLLamaCpp.ts
index b88c9b5f..fe56d3e4 100644
--- a/src/bindings/utils/compileLLamaCpp.ts
+++ b/src/bindings/utils/compileLLamaCpp.ts
@@ -131,9 +131,14 @@ export async function compileLlamaCpp(buildOptions: BuildOptions, compileOptions
if (!cmakeCustomOptions.has("GGML_CCACHE"))
cmakeCustomOptions.set("GGML_CCACHE", "OFF");
- if (!cmakeCustomOptions.has("LLAMA_CURL"))
+ if (!cmakeCustomOptions.has("LLAMA_CURL") || isCmakeValueOff(cmakeCustomOptions.get("LLAMA_CURL"))) {
cmakeCustomOptions.set("LLAMA_CURL", "OFF");
+ // avoid linking to extra libraries that we don't use
+ if (!cmakeCustomOptions.has("LLAMA_OPENSSL"))
+ cmakeCustomOptions.set("LLAMA_OPENSSL", "OFF");
+ }
+
if (buildOptions.platform === "win" && buildOptions.arch === "arm64" && !cmakeCustomOptions.has("GGML_OPENMP"))
cmakeCustomOptions.set("GGML_OPENMP", "OFF");
diff --git a/src/evaluator/LlamaRankingContext.ts b/src/evaluator/LlamaRankingContext.ts
index 71ee32e9..e7e7d0ff 100644
--- a/src/evaluator/LlamaRankingContext.ts
+++ b/src/evaluator/LlamaRankingContext.ts
@@ -1,20 +1,21 @@
-import {AsyncDisposeAggregator, EventRelay, withLock} from "lifecycle-utils";
+import {AsyncDisposeAggregator, EventRelay, splitText, withLock} from "lifecycle-utils";
import {Token} from "../types.js";
import {LlamaText} from "../utils/LlamaText.js";
import {tokenizeInput} from "../utils/tokenizeInput.js";
+import {resolveBeginningTokenToPrepend, resolveEndTokenToAppend} from "../utils/tokenizerUtils.js";
+import {isRankingTemplateValid, parseRankingTemplate} from "../gguf/insights/GgufInsights.js";
import type {LlamaModel} from "./LlamaModel/LlamaModel.js";
import type {LlamaContext, LlamaContextSequence} from "./LlamaContext/LlamaContext.js";
-import type {GgufTensorInfo} from "../gguf/types/GgufTensorInfoTypes.js";
export type LlamaRankingContextOptions = {
/**
* The number of tokens the model can see at once.
- * - **`"auto"`** - adapt to the current VRAM state and attemp to set the context size as high as possible up to the size
+ * - **`"auto"`** - adapt to the current VRAM state and attempt to set the context size as high as possible up to the size
* the model was trained on.
* - **`number`** - set the context size to a specific number of tokens.
* If there's not enough VRAM, an error will be thrown.
* Use with caution.
- * - **`{min?: number, max?: number}`** - adapt to the current VRAM state and attemp to set the context size as high as possible
+ * - **`{min?: number, max?: number}`** - adapt to the current VRAM state and attempt to set the context size as high as possible
* up to the size the model was trained on, but at least `min` and at most `max`.
*
* Defaults to `"auto"`.
@@ -36,6 +37,22 @@ export type LlamaRankingContextOptions = {
/** An abort signal to abort the context creation */
createSignal?: AbortSignal,
+ /**
+ * The template to use for the ranking evaluation.
+ * If not provided, the model's template will be used by default.
+ *
+ * The template is tokenized with special tokens enabled, but the provided query and document are not.
+ *
+ * **`{{query}}`** is replaced with the query content.
+ *
+ * **`{{document}}`** is replaced with the document content.
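+     *
+     * For example, a custom template might look like this (an illustrative value, not taken from any particular model):
+     * `"Query: {{query}}\nDocument: {{document}}"`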
+ *
+ * It's recommended to not set this option unless you know what you're doing.
+ *
+ * Defaults to the model's template.
+ */
+ template?: `${string}{{query}}${string}{{document}}${string}` | `${string}{{document}}${string}{{query}}${string}`,
+
/**
* Ignore insufficient memory errors and continue with the context creation.
* Can cause the process to crash if there's not enough VRAM for the new context.
@@ -50,17 +67,21 @@ export type LlamaRankingContextOptions = {
*/
export class LlamaRankingContext {
/** @internal */ private readonly _llamaContext: LlamaContext;
+ /** @internal */ private readonly _template: string | undefined;
/** @internal */ private readonly _sequence: LlamaContextSequence;
/** @internal */ private readonly _disposeAggregator = new AsyncDisposeAggregator();
    public readonly onDispose = new EventRelay<void>();
private constructor({
- _llamaContext
+ _llamaContext,
+ _template
}: {
- _llamaContext: LlamaContext
+ _llamaContext: LlamaContext,
+ _template: string | undefined
}) {
this._llamaContext = _llamaContext;
+ this._template = _template;
this._sequence = this._llamaContext.getSequence();
this._disposeAggregator.add(
@@ -81,9 +102,6 @@ export class LlamaRankingContext {
* @returns a ranking score between 0 and 1 representing the probability that the document is relevant to the query.
*/
public async rank(query: Token[] | string | LlamaText, document: Token[] | string | LlamaText) {
- if (this.model.tokens.bos == null || this.model.tokens.eos == null || this.model.tokens.sep == null)
- throw new Error("Computing rankings is not supported for this model.");
-
const resolvedInput = this._getEvaluationInput(query, document);
if (resolvedInput.length > this._llamaContext.contextSize)
@@ -159,7 +177,35 @@ export class LlamaRankingContext {
/** @internal */
private _getEvaluationInput(query: Token[] | string | LlamaText, document: Token[] | string | LlamaText) {
- if (this.model.tokens.bos == null || this.model.tokens.eos == null || this.model.tokens.sep == null)
+ if (this._template != null) {
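+            // the literal parts of the template are tokenized with special tokens enabled,
+            // while the provided query and document are not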
+ const resolvedInput = splitText(this._template, ["{{query}}", "{{document}}"])
+ .flatMap((item) => {
+ if (typeof item === "string")
+ return this._llamaContext.model.tokenize(item, true, "trimLeadingSpace");
+ else if (item.separator === "{{query}}")
+ return tokenizeInput(query, this._llamaContext.model.tokenizer, "trimLeadingSpace", false);
+ else if (item.separator === "{{document}}")
+ return tokenizeInput(document, this._llamaContext.model.tokenizer, "trimLeadingSpace", false);
+ else
+ void (item satisfies never);
+
+ return [];
+ });
+
+ const beginningTokens = resolveBeginningTokenToPrepend(this.model.vocabularyType, this.model.tokens);
+ const endToken = resolveEndTokenToAppend(this.model.vocabularyType, this.model.tokens);
+
+ if (beginningTokens != null && resolvedInput.at(0) !== beginningTokens)
+ resolvedInput.unshift(beginningTokens);
+
+ if (endToken != null && resolvedInput.at(-1) !== endToken)
+            resolvedInput.push(endToken);
+
+ return resolvedInput;
+ }
+
+ if (this.model.tokens.eos == null && this.model.tokens.sep == null)
throw new Error("Computing rankings is not supported for this model.");
const resolvedQuery = tokenizeInput(query, this._llamaContext.model.tokenizer, "trimLeadingSpace", false);
@@ -169,12 +215,12 @@ export class LlamaRankingContext {
return [];
const resolvedInput = [
- this.model.tokens.bos,
+ ...(this.model.tokens.bos == null ? [] : [this.model.tokens.bos]),
...resolvedQuery,
- this.model.tokens.eos,
- this.model.tokens.sep,
+ ...(this.model.tokens.eos == null ? [] : [this.model.tokens.eos]),
+ ...(this.model.tokens.sep == null ? [] : [this.model.tokens.sep]),
...resolvedDocument,
- this.model.tokens.eos
+ ...(this.model.tokens.eos == null ? [] : [this.model.tokens.eos])
];
return resolvedInput;
@@ -218,24 +264,27 @@ export class LlamaRankingContext {
batchSize,
threads = 6,
createSignal,
+ template,
ignoreMemorySafetyChecks
}: LlamaRankingContextOptions) {
- const tensorInfo = _model.fileInfo.tensorInfo;
-
- if (_model.tokens.bos == null || _model.tokens.eos == null || _model.tokens.sep == null)
- throw new Error("Computing rankings is not supported for this model.");
-
- // source: `append_pooling` in `llama.cpp`
- if (findLayer(tensorInfo, "cls", "weight") == null || findLayer(tensorInfo, "cls", "bias") == null)
- throw new Error("Computing rankings is not supported for this model.");
-
- // source: `append_pooling` in `llama.cpp`
- if (findLayer(tensorInfo, "cls.output", "weight") != null && findLayer(tensorInfo, "cls.output", "bias") == null)
- throw new Error("Computing rankings is not supported for this model.");
+ const resolvedTemplate = template ?? parseRankingTemplate(_model.fileInfo.metadata?.tokenizer?.["chat_template.rerank"]);
+
+        if (_model.tokens.eos == null && _model.tokens.sep == null) {
+            if (resolvedTemplate == null)
+                throw new Error("Computing rankings is not supported for this model.");
+
+            if (!isRankingTemplateValid(resolvedTemplate)) {
+                if (resolvedTemplate === _model.fileInfo.metadata?.tokenizer?.["chat_template.rerank"])
+                    throw new Error("The model's builtin template is invalid. It must contain both {query} and {document} placeholders.");
+                else
+                    throw new Error("The provided template is invalid. It must contain both {{query}} and {{document}} placeholders.");
+            }
+        }
if (_model.fileInsights.hasEncoder && _model.fileInsights.hasDecoder)
throw new Error("Computing rankings is not supported for encoder-decoder models.");
+ if (!_model.fileInsights.supportsRanking)
+ throw new Error("Computing rankings is not supported for this model.");
+
const llamaContext = await _model.createContext({
contextSize,
batchSize,
@@ -247,23 +296,12 @@ export class LlamaRankingContext {
});
return new LlamaRankingContext({
- _llamaContext: llamaContext
+ _llamaContext: llamaContext,
+ _template: resolvedTemplate
});
}
}
-function findLayer(tensorInfo: GgufTensorInfo[] | undefined, name: string, suffix: string) {
- if (tensorInfo == null)
- return undefined;
-
- for (const tensor of tensorInfo) {
- if (tensor.name === name + "." + suffix)
- return tensor;
- }
-
- return undefined;
-}
-
function logitToSigmoid(logit: number) {
return 1 / (1 + Math.exp(-logit));
}
diff --git a/src/gguf/insights/GgufInsights.ts b/src/gguf/insights/GgufInsights.ts
index f32ceb0d..73eb94a6 100644
--- a/src/gguf/insights/GgufInsights.ts
+++ b/src/gguf/insights/GgufInsights.ts
@@ -6,6 +6,7 @@ import {GgufTensorInfo} from "../types/GgufTensorInfoTypes.js";
import {GgufArchitectureType} from "../types/GgufMetadataTypes.js";
import {getReadablePath} from "../../cli/utils/getReadablePath.js";
import {GgufInsightsConfigurationResolver} from "./GgufInsightsConfigurationResolver.js";
+import {GgufInsightsTokens} from "./GgufInsightsTokens.js";
export type GgufInsightsResourceRequirements = {
cpuRam: number,
@@ -16,8 +17,10 @@ export class GgufInsights {
/** @internal */ public readonly _llama: Llama;
/** @internal */ private readonly _modelSize: number;
/** @internal */ private _totalFileLayers: number | null = null;
- /** @internal */ private readonly _ggufFileInfo: GgufFileInfo;
+ /** @internal */ private _supportsRanking?: boolean;
+ /** @internal */ public readonly _ggufFileInfo: GgufFileInfo;
/** @internal */ private readonly _configurationResolver: GgufInsightsConfigurationResolver;
+ /** @internal */ private readonly _tokens: GgufInsightsTokens;
private constructor(ggufFileInfo: GgufFileInfo, llama: Llama) {
this._llama = llama;
@@ -25,6 +28,7 @@ export class GgufInsights {
this._modelSize = calculateTensorsSize(ggufFileInfo.fullTensorInfo ?? [], llama, true, true);
this._configurationResolver = GgufInsightsConfigurationResolver._create(this);
+ this._tokens = GgufInsightsTokens._create(this);
}
/**
@@ -60,6 +64,10 @@ export class GgufInsights {
return this._configurationResolver;
}
+ public get tokens() {
+ return this._tokens;
+ }
+
/** The context size the model was trained on */
public get trainContextSize() {
return this._ggufFileInfo.architectureMetadata.context_length;
@@ -132,6 +140,29 @@ export class GgufInsights {
return false;
}
+ public get supportsRanking() {
+ if (this._supportsRanking != null)
+ return this._supportsRanking;
+
+ const layers = this._ggufFileInfo.fullTensorInfo ?? [];
+ for (let i = layers.length - 1; i >= 0; i--) {
+ const tensor = layers[i];
+ if (tensor == null)
+ continue;
+
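+            // ranking models have a classification head; source: `append_pooling` in `llama.cpp`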
+ if (tensor.name === "cls.weight" || tensor.name === "cls.output.weight") {
+ this._supportsRanking = this.tokens.sepToken != null || this.tokens.eosToken != null ||
+ isRankingTemplateValid(parseRankingTemplate(this._ggufFileInfo.metadata?.tokenizer?.["chat_template.rerank"]));
+ this._supportsRanking &&= !(this.hasEncoder && this.hasDecoder); // encoder-decoder models are not supported
+
+ return this._supportsRanking;
+ }
+ }
+
+ this._supportsRanking = false;
+ return this._supportsRanking;
+ }
+
/**
* The size of the SWA (Sliding Window Attention).
*
@@ -787,3 +818,16 @@ function getSwaPatternForArchitecture(architecture?: GgufArchitectureType): numb
return 1;
}
+
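+/**
+ * Convert the `{query}` and `{document}` placeholders of a model's builtin rerank template
+ * into the `{{query}}` and `{{document}}` placeholders used by ranking templates in this library
+ */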
+export function parseRankingTemplate(template: string | undefined | null): string | undefined {
+ if (template == null)
+ return undefined;
+
+ return template
+ .replaceAll("{query}", "{{query}}")
+ .replaceAll("{document}", "{{document}}");
+}
+
+export function isRankingTemplateValid(template: string | undefined | null): boolean {
+ return template != null && template.includes("{{query}}") && template.includes("{{document}}");
+}
diff --git a/src/gguf/insights/GgufInsightsTokens.ts b/src/gguf/insights/GgufInsightsTokens.ts
new file mode 100644
index 00000000..70727413
--- /dev/null
+++ b/src/gguf/insights/GgufInsightsTokens.ts
@@ -0,0 +1,51 @@
+/* eslint @stylistic/max-statements-per-line: ["warn", {"ignoredNodes": ["BreakStatement"]}] */
+import type {GgufInsights} from "./GgufInsights.js";
+
+export class GgufInsightsTokens {
+ /** @internal */ private readonly _ggufInsights: GgufInsights;
+
+ private constructor(ggufInsights: GgufInsights) {
+ this._ggufInsights = ggufInsights;
+ }
+
+ public get sepToken(): number | null {
+ const tokenizerModel = this._ggufInsights._ggufFileInfo?.metadata?.tokenizer?.ggml?.model;
+ const totalTokens = this._ggufInsights._ggufFileInfo?.metadata?.tokenizer?.ggml?.tokens?.length;
+
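+        // note: "seperator_token_id" is the actual (misspelled) key used in GGUF metadata, so it's used here as-is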
+ let sepTokenId = this._ggufInsights._ggufFileInfo?.metadata?.tokenizer?.ggml?.["seperator_token_id"];
+ if (sepTokenId == null && tokenizerModel === "bert") {
+ sepTokenId = 102; // source: `llama_vocab::impl::load` in `llama-vocab.cpp`
+ }
+
+ if (totalTokens != null && sepTokenId != null && sepTokenId >= totalTokens)
+ return null;
+
+ return sepTokenId ?? null;
+ }
+
+ public get eosToken(): number | null {
+ const tokenizerModel = this._ggufInsights._ggufFileInfo?.metadata?.tokenizer?.ggml?.model;
+ const totalTokens = this._ggufInsights._ggufFileInfo?.metadata?.tokenizer?.ggml?.tokens?.length;
+
+ const eosTokenId = this._ggufInsights._ggufFileInfo?.metadata?.tokenizer?.ggml?.["eos_token_id"];
+ if (eosTokenId != null && totalTokens != null && eosTokenId < totalTokens)
+ return eosTokenId;
+
+ switch (tokenizerModel) {
+ case "no_vocab": return null;
+ case "none": return null;
+ case "bert": return null;
+ case "rwkv": return null;
+ case "llama": return 2;
+ case "gpt2": return 11;
+ case "t5": return 1;
+ case "plamo2": return 2;
+ }
+ return 2; // source: `llama_vocab::impl::load` in `llama-vocab.cpp`
+ }
+
+ /** @internal */
+ public static _create(ggufInsights: GgufInsights) {
+ return new GgufInsightsTokens(ggufInsights);
+ }
+}
diff --git a/src/gguf/types/GgufMetadataTypes.ts b/src/gguf/types/GgufMetadataTypes.ts
index 04b3c589..14d08707 100644
--- a/src/gguf/types/GgufMetadataTypes.ts
+++ b/src/gguf/types/GgufMetadataTypes.ts
@@ -263,7 +263,7 @@ export const enum GgufMetadataTokenizerTokenType {
export type GgufMetadataTokenizer = {
readonly ggml: {
- readonly model: "no_vocab" | "llama" | "gpt2" | "bert" | string,
+ readonly model: "no_vocab" | "none" | "llama" | "gpt2" | "bert" | "rwkv" | "t5" | "plamo2" | string,
readonly pre?: "default" | "llama3" | "llama-v3" | "llama-bpe" | "deepseek-llm" | "deepseek-coder" | "falcon" | "falcon3" |
"pixtral" | "mpt" | "starcoder" | "gpt-2" | "phi-2" | "jina-es" | "jina-de" | "jina-v1-en" | "jina-v2-es" | "jina-v2-de" |
"jina-v2-code" | "refact" | "command-r" | "qwen2" | "stablelm2" | "olmo" | "dbrx" | "smaug-bpe" | "poro-chat" | "chatglm-bpe" |
@@ -279,7 +279,7 @@ export type GgufMetadataTokenizer = {
readonly eot_token_id?: number,
readonly eom_token_id?: number,
readonly unknown_token_id?: number,
- readonly separator_token_id?: number,
+ readonly seperator_token_id?: number,
readonly padding_token_id?: number,
readonly cls_token_id?: number,
readonly mask_token_id?: number,
@@ -304,7 +304,8 @@ export type GgufMetadataTokenizer = {
readonly huggingface?: {
readonly json?: string
},
- readonly chat_template?: string
+ readonly chat_template?: string,
+ readonly "chat_template.rerank"?: string
};
export const enum GgufMetadataArchitecturePoolingType {
diff --git a/src/index.ts b/src/index.ts
index ffd02c80..0e52d2ec 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -84,6 +84,7 @@ import {getModuleVersion} from "./utils/getModuleVersion.js";
import {readGgufFileInfo} from "./gguf/readGgufFileInfo.js";
import {GgufInsights, type GgufInsightsResourceRequirements} from "./gguf/insights/GgufInsights.js";
import {GgufInsightsConfigurationResolver} from "./gguf/insights/GgufInsightsConfigurationResolver.js";
+import {GgufInsightsTokens} from "./gguf/insights/GgufInsightsTokens.js";
import {
createModelDownloader, ModelDownloader, type ModelDownloaderOptions, combineModelDownloaders, CombinedModelDownloader,
type CombinedModelDownloaderOptions
@@ -315,6 +316,7 @@ export {
isGgufMetadataOfArchitectureType,
GgufInsights,
type GgufInsightsResourceRequirements,
+ GgufInsightsTokens,
GgufInsightsConfigurationResolver,
createModelDownloader,
ModelDownloader,
diff --git a/src/utils/parseModelUri.ts b/src/utils/parseModelUri.ts
index babea2b6..7cf1696f 100644
--- a/src/utils/parseModelUri.ts
+++ b/src/utils/parseModelUri.ts
@@ -1,4 +1,5 @@
import filenamify from "filenamify";
+import prettyMilliseconds from "pretty-ms";
import {normalizeGgufDownloadUrl} from "../gguf/utils/normalizeGgufDownloadUrl.js";
import {getFilenameForBinarySplitGgufPartUrls, resolveBinarySplitGgufPartUrls} from "../gguf/utils/resolveBinarySplitGgufPartUrls.js";
import {createSplitPartFilename, getGgufSplitPartsInfo} from "../gguf/utils/resolveSplitGgufParts.js";
@@ -7,9 +8,19 @@ import {isUrl} from "./isUrl.js";
import {ModelFileAccessTokens, resolveModelFileAccessTokensTryHeaders} from "./modelFileAccessTokens.js";
import {isHuggingFaceUrl, ModelDownloadEndpoints, resolveHuggingFaceEndpoint} from "./modelDownloadEndpoints.js";
import {parseModelFileName} from "./parseModelFileName.js";
+import {getConsoleLogPrefix} from "./getConsoleLogPrefix.js";
+import {signalSleep} from "./signalSleep.js";
const defaultHuggingFaceBranch = "main";
const defaultHuggingFaceFileQuantization = "Q4_K_M";
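+// retry behavior for Hugging Face rate limit (HTTP 429) responses; wait times are in milliseconds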
+const huggingFaceRateLimit = {
+ wait: {
+ min: 1000,
+ max: 60 * 5 * 1000,
+ default: 1000
+ },
+ retries: 4
+} as const;
export const genericFilePartNumber = "{:\n{number}\n:}" as const;
@@ -208,9 +219,12 @@ async function fetchHuggingFaceModelManifest({
{},
await resolveModelFileAccessTokensTryHeaders(manifestUrl, tokens, endpoints)
];
+ let rateLimitPendingRetries = 0;
- while (headersToTry.length > 0) {
- const headers = headersToTry.shift();
+ for (let i = 0; i < headersToTry.length * (1 + rateLimitPendingRetries); i++) {
+ const headers = headersToTry[i % headersToTry.length];
+ if (headers == null)
+ continue;
        let response: Awaited<ReturnType<typeof fetch>> | undefined;
try {
@@ -226,10 +240,52 @@ async function fetchHuggingFaceModelManifest({
signal
});
} catch (err) {
+ if (signal?.aborted && err === signal?.reason)
+ throw err;
+
throw new Error(`Failed to fetch manifest for resolving URI ${JSON.stringify(fullUri)}: ${err}`);
}
- if ((response.status >= 500 || response.status === 429 || response.status === 401) && headersToTry.length > 0)
+ if (response.status === 429) {
+            const doneRetries = Math.floor(i / headersToTry.length);
+            rateLimitPendingRetries = Math.min(doneRetries + 1, huggingFaceRateLimit.retries);
+
+ if (i % headersToTry.length === headersToTry.length - 1 && i !== headersToTry.length * (1 + rateLimitPendingRetries) - 1) {
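+                // the `ratelimit` header is expected to look something like `"default";r=0;t=30`,
+                // where `t` is the number of seconds until the rate limit resets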
+ const [,secondsUntilResetString] = response.headers.get("ratelimit")
+ ?.split(";")
+ .map((part) => part.split("="))
+ .find(([key, value]) => key === "t" && !isNaN(Number(value))) ?? [];
+
+ if (secondsUntilResetString != null) {
+ const timeToWait = Math.min(
+ huggingFaceRateLimit.wait.max,
+ Math.max(
+ huggingFaceRateLimit.wait.min,
+ Number(secondsUntilResetString) * 1000
+ )
+ );
+ console.info(
+ getConsoleLogPrefix() +
+ "Received a rate limit response from Hugging Face, waiting for " + (
+ prettyMilliseconds(timeToWait, {
+ keepDecimalsOnWholeSeconds: true,
+ secondsDecimalDigits: 0,
+ compact: true,
+ verbose: true
+ })
+ ) + " before retrying..."
+ );
+ await signalSleep(timeToWait, signal);
+ } else
+ await signalSleep(huggingFaceRateLimit.wait.default, signal);
+ }
+
+ continue;
+ }
+
+ if ((response.status >= 500 || response.status === 429 || response.status === 401) &&
+ i < headersToTry.length * (1 + rateLimitPendingRetries) - 1
+ )
continue;
if (response.status === 400 || response.status === 404)
diff --git a/src/utils/signalSleep.ts b/src/utils/signalSleep.ts
new file mode 100644
index 00000000..977661f8
--- /dev/null
+++ b/src/utils/signalSleep.ts
@@ -0,0 +1,22 @@
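+/**
+ * Sleep for `delay` milliseconds.
+ * Resolves once the time has passed, or rejects with the abort reason if `abortSignal` is aborted first.
+ */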
+export function signalSleep(delay: number, abortSignal?: AbortSignal): Promise<void> {
+    return new Promise<void>((accept, reject) => {
+ if (abortSignal?.aborted)
+ return void reject(abortSignal.reason);
+
+        let timeout: ReturnType<typeof setTimeout> | undefined = undefined;
+ function onAbort() {
+ reject(abortSignal?.reason);
+ clearTimeout(timeout);
+ abortSignal?.removeEventListener("abort", onAbort);
+ }
+
+ function onTimeout() {
+ accept();
+ timeout = undefined;
+ abortSignal?.removeEventListener("abort", onAbort);
+ }
+
+ abortSignal?.addEventListener("abort", onAbort);
+ timeout = setTimeout(onTimeout, delay);
+ });
+}
diff --git a/src/utils/tokenizerUtils.ts b/src/utils/tokenizerUtils.ts
index 71e9dbb5..ce459966 100644
--- a/src/utils/tokenizerUtils.ts
+++ b/src/utils/tokenizerUtils.ts
@@ -12,6 +12,10 @@ export function resolveBeginningTokenToPrepend(vocabularyType: LlamaVocabularyTy
if (vocabularyType === LlamaVocabularyType.wpm)
return tokens.bos;
+
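+    // UGM (T5-style) vocabularies don't get a beginning-of-sequence token prepended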
+ if (vocabularyType === LlamaVocabularyType.ugm)
+ return null;
+
if (tokens.shouldPrependBosToken)
return tokens.bos;
@@ -29,6 +33,9 @@ export function resolveEndTokenToAppend(vocabularyType: LlamaVocabularyType, tok
if (vocabularyType === LlamaVocabularyType.wpm)
return tokens.sep;
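+    // UGM (T5-style) vocabularies get an end-of-sequence token appended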
+ if (vocabularyType === LlamaVocabularyType.ugm)
+ return tokens.eos;
+
if (tokens.shouldAppendEosToken)
return tokens.eos;
diff --git a/test/modelDependent/llama3.1/tokenPredictor.test.ts b/test/modelDependent/llama3.1/tokenPredictor.test.ts
index 181f7bde..07f7966c 100644
--- a/test/modelDependent/llama3.1/tokenPredictor.test.ts
+++ b/test/modelDependent/llama3.1/tokenPredictor.test.ts
@@ -6,10 +6,13 @@ import {compareTokens} from "../../../src/utils/compareTokens.js";
describe("llama 3.1", () => {
describe("token predictor", () => {
- test("DraftModelTokenPredictor", {timeout: 1000 * 60 * 60 * 2}, async () => {
+ test("DraftModelTokenPredictor", {timeout: 1000 * 60 * 60 * 2}, async (test) => {
const modelPath = await getModelFile("Meta-Llama-3.1-8B-Instruct.Q4_K_M.gguf");
const llama = await getTestLlama();
+ if (llama.gpu !== "metal")
+                test.skip(); // the outputs differ slightly between platforms, so skip on non-Metal platforms to avoid flakiness
+
const model = await llama.loadModel({
modelPath
});
diff --git a/test/standalone/cli/recommendedModels.test.ts b/test/standalone/cli/recommendedModels.test.ts
index 6ccf727a..56cb67ab 100644
--- a/test/standalone/cli/recommendedModels.test.ts
+++ b/test/standalone/cli/recommendedModels.test.ts
@@ -4,7 +4,7 @@ import {recommendedModels} from "../../../src/cli/recommendedModels.js";
describe("cli", () => {
describe("recommended models", () => {
- test("all URIs resolve correctly", async () => {
+ test("all URIs resolve correctly", {timeout: 1000 * 60 * 6}, async () => {
const unresolvedUris = (
await Promise.all(
recommendedModels
@@ -18,10 +18,11 @@ describe("cli", () => {
try {
await resolveParsedModelUri(parseModelUri(uri));
return null;
- } catch (err) {
+                    } catch (err: any) {
return {
modelName,
- uri
+ uri,
+ error: String(err?.stack ?? err)
};
}
})