diff --git a/src/bindings/utils/compileLLamaCpp.ts b/src/bindings/utils/compileLLamaCpp.ts
index b88c9b5f..fe56d3e4 100644
--- a/src/bindings/utils/compileLLamaCpp.ts
+++ b/src/bindings/utils/compileLLamaCpp.ts
@@ -131,9 +131,14 @@ export async function compileLlamaCpp(buildOptions: BuildOptions, compileOptions
     if (!cmakeCustomOptions.has("GGML_CCACHE"))
         cmakeCustomOptions.set("GGML_CCACHE", "OFF");
 
-    if (!cmakeCustomOptions.has("LLAMA_CURL"))
+    if (!cmakeCustomOptions.has("LLAMA_CURL") || isCmakeValueOff(cmakeCustomOptions.get("LLAMA_CURL"))) {
         cmakeCustomOptions.set("LLAMA_CURL", "OFF");
 
+        // avoid linking to extra libraries that we don't use
+        if (!cmakeCustomOptions.has("LLAMA_OPENSSL"))
+            cmakeCustomOptions.set("LLAMA_OPENSSL", "OFF");
+    }
+
     if (buildOptions.platform === "win" && buildOptions.arch === "arm64" && !cmakeCustomOptions.has("GGML_OPENMP"))
         cmakeCustomOptions.set("GGML_OPENMP", "OFF");
 
diff --git a/src/evaluator/LlamaRankingContext.ts b/src/evaluator/LlamaRankingContext.ts
index 71ee32e9..e7e7d0ff 100644
--- a/src/evaluator/LlamaRankingContext.ts
+++ b/src/evaluator/LlamaRankingContext.ts
@@ -1,20 +1,21 @@
-import {AsyncDisposeAggregator, EventRelay, withLock} from "lifecycle-utils";
+import {AsyncDisposeAggregator, EventRelay, splitText, withLock} from "lifecycle-utils";
 import {Token} from "../types.js";
 import {LlamaText} from "../utils/LlamaText.js";
 import {tokenizeInput} from "../utils/tokenizeInput.js";
+import {resolveBeginningTokenToPrepend, resolveEndTokenToAppend} from "../utils/tokenizerUtils.js";
+import {isRankingTemplateValid, parseRankingTemplate} from "../gguf/insights/GgufInsights.js";
 import type {LlamaModel} from "./LlamaModel/LlamaModel.js";
 import type {LlamaContext, LlamaContextSequence} from "./LlamaContext/LlamaContext.js";
-import type {GgufTensorInfo} from "../gguf/types/GgufTensorInfoTypes.js";
 
 export type LlamaRankingContextOptions = {
     /**
      * The number of tokens the model can see at once.
-     * - **`"auto"`** - adapt to the current VRAM state and attemp to set the context size as high as possible up to the size
+     * - **`"auto"`** - adapt to the current VRAM state and attempt to set the context size as high as possible up to the size
      * the model was trained on.
      * - **`number`** - set the context size to a specific number of tokens.
      * If there's not enough VRAM, an error will be thrown.
      * Use with caution.
-     * - **`{min?: number, max?: number}`** - adapt to the current VRAM state and attemp to set the context size as high as possible
+     * - **`{min?: number, max?: number}`** - adapt to the current VRAM state and attempt to set the context size as high as possible
      * up to the size the model was trained on, but at least `min` and at most `max`.
      *
      * Defaults to `"auto"`.
@@ -36,6 +37,22 @@
     /** An abort signal to abort the context creation */
     createSignal?: AbortSignal,
 
+    /**
+     * The template to use for the ranking evaluation.
+     * If not provided, the model's template will be used by default.
+     *
+     * The template is tokenized with special tokens enabled, but the provided query and document are not.
+     *
+     * **`{{query}}`** is replaced with the query content.
+     *
+     * **`{{document}}`** is replaced with the document content.
+     *
+     * It's recommended to not set this option unless you know what you're doing.
+     *
+     * Defaults to the model's template.
+     */
+    template?: `${string}{{query}}${string}{{document}}${string}` | `${string}{{document}}${string}{{query}}${string}`,
+
     /**
      * Ignore insufficient memory errors and continue with the context creation.
      * Can cause the process to crash if there's not enough VRAM for the new context.
@@ -50,17 +67,21 @@
  */
 export class LlamaRankingContext {
     /** @internal */ private readonly _llamaContext: LlamaContext;
+    /** @internal */ private readonly _template: string | undefined;
     /** @internal */ private readonly _sequence: LlamaContextSequence;
     /** @internal */ private readonly _disposeAggregator = new AsyncDisposeAggregator();
 
     public readonly onDispose = new EventRelay<void>();
 
     private constructor({
-        _llamaContext
+        _llamaContext,
+        _template
     }: {
-        _llamaContext: LlamaContext
+        _llamaContext: LlamaContext,
+        _template: string | undefined
     }) {
         this._llamaContext = _llamaContext;
+        this._template = _template;
         this._sequence = this._llamaContext.getSequence();
 
         this._disposeAggregator.add(
@@ -81,9 +102,6 @@ export class LlamaRankingContext {
      * @returns a ranking score between 0 and 1 representing the probability that the document is relevant to the query.
      */
     public async rank(query: Token[] | string | LlamaText, document: Token[] | string | LlamaText) {
-        if (this.model.tokens.bos == null || this.model.tokens.eos == null || this.model.tokens.sep == null)
-            throw new Error("Computing rankings is not supported for this model.");
-
         const resolvedInput = this._getEvaluationInput(query, document);
 
         if (resolvedInput.length > this._llamaContext.contextSize)
@@ -159,7 +177,35 @@ export class LlamaRankingContext {
 
     /** @internal */
     private _getEvaluationInput(query: Token[] | string | LlamaText, document: Token[] | string | LlamaText) {
-        if (this.model.tokens.bos == null || this.model.tokens.eos == null || this.model.tokens.sep == null)
+        if (this._template != null) {
+            const resolvedInput = splitText(this._template, ["{{query}}", "{{document}}"])
+                .flatMap((item) => {
+                    if (typeof item === "string")
+                        return this._llamaContext.model.tokenize(item, true, "trimLeadingSpace");
+                    else if (item.separator === "{{query}}")
+                        return tokenizeInput(query, this._llamaContext.model.tokenizer, "trimLeadingSpace", false);
+                    else if (item.separator === "{{document}}")
+                        return tokenizeInput(document, this._llamaContext.model.tokenizer, "trimLeadingSpace", false);
+                    else
+                        void (item satisfies never);
+
+                    void (item satisfies never);
+                    return [];
+                });
+
+            const beginningTokens = resolveBeginningTokenToPrepend(this.model.vocabularyType, this.model.tokens);
+            const endToken = resolveEndTokenToAppend(this.model.vocabularyType, this.model.tokens);
+
+            if (beginningTokens != null && resolvedInput.at(0) !== beginningTokens)
+                resolvedInput.unshift(beginningTokens);
+
+            if (endToken != null && resolvedInput.at(-1) !== endToken)
+                resolvedInput.push(endToken);
+
+            return resolvedInput;
+        }
+
+        if (this.model.tokens.eos == null && this.model.tokens.sep == null)
             throw new Error("Computing rankings is not supported for this model.");
 
         const resolvedQuery = tokenizeInput(query, this._llamaContext.model.tokenizer, "trimLeadingSpace", false);
@@ -169,12 +215,12 @@
             return [];
 
         const resolvedInput = [
-            this.model.tokens.bos,
+            ...(this.model.tokens.bos == null ? [] : [this.model.tokens.bos]),
             ...resolvedQuery,
-            this.model.tokens.eos,
-            this.model.tokens.sep,
+            ...(this.model.tokens.eos == null ? [] : [this.model.tokens.eos]),
+            ...(this.model.tokens.sep == null ? [] : [this.model.tokens.sep]),
             ...resolvedDocument,
-            this.model.tokens.eos
+            ...(this.model.tokens.eos == null ? [] : [this.model.tokens.eos])
         ];
 
         return resolvedInput;
@@ -218,24 +264,27 @@
         batchSize,
         threads = 6,
         createSignal,
+        template,
         ignoreMemorySafetyChecks
     }: LlamaRankingContextOptions) {
-        const tensorInfo = _model.fileInfo.tensorInfo;
-
-        if (_model.tokens.bos == null || _model.tokens.eos == null || _model.tokens.sep == null)
-            throw new Error("Computing rankings is not supported for this model.");
-
-        // source: `append_pooling` in `llama.cpp`
-        if (findLayer(tensorInfo, "cls", "weight") == null || findLayer(tensorInfo, "cls", "bias") == null)
-            throw new Error("Computing rankings is not supported for this model.");
-
-        // source: `append_pooling` in `llama.cpp`
-        if (findLayer(tensorInfo, "cls.output", "weight") != null && findLayer(tensorInfo, "cls.output", "bias") == null)
-            throw new Error("Computing rankings is not supported for this model.");
+        const resolvedTemplate = template ?? parseRankingTemplate(_model.fileInfo.metadata?.tokenizer?.["chat_template.rerank"]);
+
+        if (_model.tokens.eos == null && _model.tokens.sep == null) {
+            if (!isRankingTemplateValid(resolvedTemplate)) {
+                if (resolvedTemplate === _model.fileInfo.metadata?.tokenizer?.["chat_template.rerank"])
+                    throw new Error("The model's builtin template is invalid. It must contain both {query} and {document} placeholders.");
+                else
+                    throw new Error("The provided template is invalid. It must contain both {{query}} and {{document}} placeholders.");
+            } else if (resolvedTemplate == null)
+                throw new Error("Computing rankings is not supported for this model.");
+        }
 
         if (_model.fileInsights.hasEncoder && _model.fileInsights.hasDecoder)
             throw new Error("Computing rankings is not supported for encoder-decoder models.");
 
+        if (!_model.fileInsights.supportsRanking)
+            throw new Error("Computing rankings is not supported for this model.");
+
         const llamaContext = await _model.createContext({
             contextSize,
             batchSize,
@@ -247,23 +296,12 @@
         });
 
         return new LlamaRankingContext({
-            _llamaContext: llamaContext
+            _llamaContext: llamaContext,
+            _template: resolvedTemplate
         });
     }
 }
 
-function findLayer(tensorInfo: GgufTensorInfo[] | undefined, name: string, suffix: string) {
-    if (tensorInfo == null)
-        return undefined;
-
-    for (const tensor of tensorInfo) {
-        if (tensor.name === name + "." + suffix)
-            return tensor;
-    }
-
-    return undefined;
-}
-
 function logitToSigmoid(logit: number) {
     return 1 / (1 + Math.exp(-logit));
 }
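Reviewer note (not part of the patch): a minimal usage sketch of the `template` option introduced above. `getLlama`, `loadModel`, `createRankingContext`, and `rank` are existing node-llama-cpp APIs; the model path and the template string are hypothetical, and omitting `template` keeps the model's builtin rerank template, which is the recommended default.

```ts
import {getLlama} from "node-llama-cpp";

const llama = await getLlama();

// hypothetical path to a reranker GGUF model
const model = await llama.loadModel({modelPath: "path/to/reranker-model.gguf"});

// `template` overrides the model's builtin `tokenizer.chat_template.rerank` template;
// it must contain both the {{query}} and {{document}} placeholders
const rankingContext = await model.createRankingContext({
    template: "{{query}}</s><s>{{document}}"
});

const score = await rankingContext.rank(
    "Tell me about llamas",
    "Llamas are domesticated South American camelids"
);
console.log(score); // a value between 0 and 1
```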
diff --git a/src/gguf/insights/GgufInsights.ts b/src/gguf/insights/GgufInsights.ts
index f32ceb0d..73eb94a6 100644
--- a/src/gguf/insights/GgufInsights.ts
+++ b/src/gguf/insights/GgufInsights.ts
@@ -6,6 +6,7 @@ import {GgufTensorInfo} from "../types/GgufTensorInfoTypes.js";
 import {GgufArchitectureType} from "../types/GgufMetadataTypes.js";
 import {getReadablePath} from "../../cli/utils/getReadablePath.js";
 import {GgufInsightsConfigurationResolver} from "./GgufInsightsConfigurationResolver.js";
+import {GgufInsightsTokens} from "./GgufInsightsTokens.js";
 
 export type GgufInsightsResourceRequirements = {
     cpuRam: number,
@@ -16,8 +17,10 @@ export class GgufInsights {
     /** @internal */ public readonly _llama: Llama;
     /** @internal */ private readonly _modelSize: number;
     /** @internal */ private _totalFileLayers: number | null = null;
-    /** @internal */ private readonly _ggufFileInfo: GgufFileInfo;
+    /** @internal */ private _supportsRanking?: boolean;
+    /** @internal */ public readonly _ggufFileInfo: GgufFileInfo;
     /** @internal */ private readonly _configurationResolver: GgufInsightsConfigurationResolver;
+    /** @internal */ private readonly _tokens: GgufInsightsTokens;
 
     private constructor(ggufFileInfo: GgufFileInfo, llama: Llama) {
         this._llama = llama;
@@ -25,6 +28,7 @@ export class GgufInsights {
         this._modelSize = calculateTensorsSize(ggufFileInfo.fullTensorInfo ?? [], llama, true, true);
 
         this._configurationResolver = GgufInsightsConfigurationResolver._create(this);
+        this._tokens = GgufInsightsTokens._create(this);
     }
 
     /**
@@ -60,6 +64,10 @@ export class GgufInsights {
         return this._configurationResolver;
     }
 
+    public get tokens() {
+        return this._tokens;
+    }
+
     /** The context size the model was trained on */
     public get trainContextSize() {
         return this._ggufFileInfo.architectureMetadata.context_length;
@@ -132,6 +140,29 @@ export class GgufInsights {
         return false;
     }
 
+    public get supportsRanking() {
+        if (this._supportsRanking != null)
+            return this._supportsRanking;
+
+        const layers = this._ggufFileInfo.fullTensorInfo ?? [];
+        for (let i = layers.length - 1; i >= 0; i--) {
+            const tensor = layers[i];
+            if (tensor == null)
+                continue;
+
+            if (tensor.name === "cls.weight" || tensor.name === "cls.output.weight") {
+                this._supportsRanking = this.tokens.sepToken != null || this.tokens.eosToken != null ||
+                    isRankingTemplateValid(parseRankingTemplate(this._ggufFileInfo.metadata?.tokenizer?.["chat_template.rerank"]));
+                this._supportsRanking &&= !(this.hasEncoder && this.hasDecoder); // encoder-decoder models are not supported
+
+                return this._supportsRanking;
+            }
+        }
+
+        this._supportsRanking = false;
+        return this._supportsRanking;
+    }
+
     /**
      * The size of the SWA (Sliding Window Attention).
      *
@@ -787,3 +818,16 @@ function getSwaPatternForArchitecture(architecture?: GgufArchitectureType): numb
 
     return 1;
 }
+
+export function parseRankingTemplate(template: string | undefined | null): string | undefined {
+    if (template == null)
+        return undefined;
+
+    return template
+        .replaceAll("{query}", "{{query}}")
+        .replaceAll("{document}", "{{document}}");
+}
+
+export function isRankingTemplateValid(template: string | undefined | null): boolean {
+    return template != null && template.includes("{{query}}") && template.includes("{{document}}");
+}
diff --git a/src/gguf/insights/GgufInsightsTokens.ts b/src/gguf/insights/GgufInsightsTokens.ts
new file mode 100644
index 00000000..70727413
--- /dev/null
+++ b/src/gguf/insights/GgufInsightsTokens.ts
@@ -0,0 +1,51 @@
+/* eslint @stylistic/max-statements-per-line: ["warn", {"ignoredNodes": ["BreakStatement"]}] */
+import type {GgufInsights} from "./GgufInsights.js";
+
+export class GgufInsightsTokens {
+    /** @internal */ private readonly _ggufInsights: GgufInsights;
+
+    private constructor(ggufInsights: GgufInsights) {
+        this._ggufInsights = ggufInsights;
+    }
+
+    public get sepToken(): number | null {
+        const tokenizerModel = this._ggufInsights._ggufFileInfo?.metadata?.tokenizer?.ggml?.model;
+        const totalTokens = this._ggufInsights._ggufFileInfo?.metadata?.tokenizer?.ggml?.tokens?.length;
+
+        let sepTokenId = this._ggufInsights._ggufFileInfo?.metadata?.tokenizer?.ggml?.["seperator_token_id"];
+        if (sepTokenId == null && tokenizerModel === "bert") {
+            sepTokenId = 102; // source: `llama_vocab::impl::load` in `llama-vocab.cpp`
+        }
+
+        if (totalTokens != null && sepTokenId != null && sepTokenId >= totalTokens)
+            return null;
+
+        return sepTokenId ?? null;
+    }
+
+    public get eosToken(): number | null {
+        const tokenizerModel = this._ggufInsights._ggufFileInfo?.metadata?.tokenizer?.ggml?.model;
+        const totalTokens = this._ggufInsights._ggufFileInfo?.metadata?.tokenizer?.ggml?.tokens?.length;
+
+        const eosTokenId = this._ggufInsights._ggufFileInfo?.metadata?.tokenizer?.ggml?.["eos_token_id"];
+        if (eosTokenId != null && totalTokens != null && eosTokenId < totalTokens)
+            return eosTokenId;
+
+        switch (tokenizerModel) {
+            case "no_vocab": return null;
+            case "none": return null;
+            case "bert": return null;
+            case "rwkv": return null;
+            case "llama": return 2;
+            case "gpt2": return 11;
+            case "t5": return 1;
+            case "plamo2": return 2;
+        }
+        return 2; // source: `llama_vocab::impl::load` in `llama-vocab.cpp`
+    }
+
+    /** @internal */
+    public static _create(ggufInsights: GgufInsights) {
+        return new GgufInsightsTokens(ggufInsights);
+    }
+}
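Reviewer note (not part of the patch): a sketch of how the new insights surface could be inspected on its own. It assumes `GgufInsights.from(ggufFileInfo, llama)` as the public factory (the constructor is private) and uses a hypothetical model path.

```ts
import {getLlama, readGgufFileInfo, GgufInsights} from "node-llama-cpp";

const llama = await getLlama();

// hypothetical path to a local GGUF file
const ggufFileInfo = await readGgufFileInfo("path/to/model.gguf");
const insights = await GgufInsights.from(ggufFileInfo, llama);

// `supportsRanking` is true only when a `cls` tensor exists and either a usable
// sep/eos token or a valid `tokenizer.chat_template.rerank` template is available
console.log(insights.supportsRanking);

// resolved token ids, or `null` when the model provides no usable token
console.log(insights.tokens.sepToken, insights.tokens.eosToken);
```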
diff --git a/src/gguf/types/GgufMetadataTypes.ts b/src/gguf/types/GgufMetadataTypes.ts
index 04b3c589..14d08707 100644
--- a/src/gguf/types/GgufMetadataTypes.ts
+++ b/src/gguf/types/GgufMetadataTypes.ts
@@ -263,7 +263,7 @@ export const enum GgufMetadataTokenizerTokenType {
 
 export type GgufMetadataTokenizer = {
     readonly ggml: {
-        readonly model: "no_vocab" | "llama" | "gpt2" | "bert" | string,
+        readonly model: "no_vocab" | "none" | "llama" | "gpt2" | "bert" | "rwkv" | "t5" | "plamo2" | string,
         readonly pre?: "default" | "llama3" | "llama-v3" | "llama-bpe" | "deepseek-llm" | "deepseek-coder" | "falcon" | "falcon3" |
             "pixtral" | "mpt" | "starcoder" | "gpt-2" | "phi-2" | "jina-es" | "jina-de" | "jina-v1-en" | "jina-v2-es" | "jina-v2-de" |
             "jina-v2-code" | "refact" | "command-r" | "qwen2" | "stablelm2" | "olmo" | "dbrx" | "smaug-bpe" | "poro-chat" | "chatglm-bpe" |
@@ -279,7 +279,7 @@
         readonly eot_token_id?: number,
         readonly eom_token_id?: number,
         readonly unknown_token_id?: number,
-        readonly separator_token_id?: number,
+        readonly seperator_token_id?: number,
         readonly padding_token_id?: number,
         readonly cls_token_id?: number,
         readonly mask_token_id?: number,
@@ -304,7 +304,8 @@
     readonly huggingface?: {
         readonly json?: string
     },
-    readonly chat_template?: string
+    readonly chat_template?: string,
+    readonly "chat_template.rerank"?: string
 };
 
 export const enum GgufMetadataArchitecturePoolingType {
diff --git a/src/index.ts b/src/index.ts
index ffd02c80..0e52d2ec 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -84,6 +84,7 @@ import {getModuleVersion} from "./utils/getModuleVersion.js";
 import {readGgufFileInfo} from "./gguf/readGgufFileInfo.js";
 import {GgufInsights, type GgufInsightsResourceRequirements} from "./gguf/insights/GgufInsights.js";
 import {GgufInsightsConfigurationResolver} from "./gguf/insights/GgufInsightsConfigurationResolver.js";
+import {GgufInsightsTokens} from "./gguf/insights/GgufInsightsTokens.js";
 import {
     createModelDownloader, ModelDownloader, type ModelDownloaderOptions, combineModelDownloaders, CombinedModelDownloader,
     type CombinedModelDownloaderOptions
@@ -315,6 +316,7 @@
     isGgufMetadataOfArchitectureType,
     GgufInsights,
     type GgufInsightsResourceRequirements,
+    GgufInsightsTokens,
     GgufInsightsConfigurationResolver,
     createModelDownloader,
     ModelDownloader,
diff --git a/src/utils/parseModelUri.ts b/src/utils/parseModelUri.ts
index babea2b6..7cf1696f 100644
--- a/src/utils/parseModelUri.ts
+++ b/src/utils/parseModelUri.ts
@@ -1,4 +1,5 @@
 import filenamify from "filenamify";
+import prettyMilliseconds from "pretty-ms";
 import {normalizeGgufDownloadUrl} from "../gguf/utils/normalizeGgufDownloadUrl.js";
 import {getFilenameForBinarySplitGgufPartUrls, resolveBinarySplitGgufPartUrls} from "../gguf/utils/resolveBinarySplitGgufPartUrls.js";
 import {createSplitPartFilename, getGgufSplitPartsInfo} from "../gguf/utils/resolveSplitGgufParts.js";
@@ -7,9 +8,19 @@ import {isUrl} from "./isUrl.js";
 import {ModelFileAccessTokens, resolveModelFileAccessTokensTryHeaders} from "./modelFileAccessTokens.js";
 import {isHuggingFaceUrl, ModelDownloadEndpoints, resolveHuggingFaceEndpoint} from "./modelDownloadEndpoints.js";
 import {parseModelFileName} from "./parseModelFileName.js";
+import {getConsoleLogPrefix} from "./getConsoleLogPrefix.js";
+import {signalSleep} from "./signalSleep.js";
 
 const defaultHuggingFaceBranch = "main";
 const defaultHuggingFaceFileQuantization = "Q4_K_M";
+const huggingFaceRateLimit = {
+    wait: {
+        min: 1000,
+        max: 60 * 5 * 1000,
+        default: 1000
+    },
+    retries: 4
+} as const;
 
 export const genericFilePartNumber = "{:\n{number}\n:}" as const;
 
@@ -208,9 +219,12 @@ async function fetchHuggingFaceModelManifest({
         {},
         await resolveModelFileAccessTokensTryHeaders(manifestUrl, tokens, endpoints)
     ];
+    let rateLimitPendingRetries = 0;
 
-    while (headersToTry.length > 0) {
-        const headers = headersToTry.shift();
+    for (let i = 0; i < headersToTry.length * (1 + rateLimitPendingRetries); i++) {
+        const headers = headersToTry[i % headersToTry.length];
+        if (headers == null)
+            continue;
 
         let response: Awaited<ReturnType<typeof fetch>> | undefined;
         try {
@@ -226,10 +240,52 @@ async function fetchHuggingFaceModelManifest({
                 signal
             });
         } catch (err) {
+            if (signal?.aborted && err === signal?.reason)
+                throw err;
+
             throw new Error(`Failed to fetch manifest for resolving URI ${JSON.stringify(fullUri)}: ${err}`);
         }
 
-        if ((response.status >= 500 || response.status === 429 || response.status === 401) && headersToTry.length > 0)
+        if (response.status === 429) {
+            const doneRetries = Math.floor(i / headersToTry.length);
+            rateLimitPendingRetries = Math.min(doneRetries + 1, huggingFaceRateLimit.retries);
+
+            if (i % headersToTry.length === headersToTry.length - 1 && i !== headersToTry.length * (1 + rateLimitPendingRetries) - 1) {
+                const [,secondsUntilResetString] = response.headers.get("ratelimit")
+                    ?.split(";")
+                    .map((part) => part.split("="))
+                    .find(([key, value]) => key === "t" && !isNaN(Number(value))) ?? [];
+
+                if (secondsUntilResetString != null) {
+                    const timeToWait = Math.min(
+                        huggingFaceRateLimit.wait.max,
+                        Math.max(
+                            huggingFaceRateLimit.wait.min,
+                            Number(secondsUntilResetString) * 1000
+                        )
+                    );
+                    console.info(
+                        getConsoleLogPrefix() +
+                        "Received a rate limit response from Hugging Face, waiting for " + (
+                            prettyMilliseconds(timeToWait, {
+                                keepDecimalsOnWholeSeconds: true,
+                                secondsDecimalDigits: 0,
+                                compact: true,
+                                verbose: true
+                            })
+                        ) + " before retrying..."
+                    );
+                    await signalSleep(timeToWait, signal);
+                } else
+                    await signalSleep(huggingFaceRateLimit.wait.default, signal);
+            }
+
+            continue;
+        }
+
+        if ((response.status >= 500 || response.status === 429 || response.status === 401) &&
+            i < headersToTry.length * (1 + rateLimitPendingRetries) - 1
+        )
             continue;
 
         if (response.status === 400 || response.status === 404)
diff --git a/src/utils/signalSleep.ts b/src/utils/signalSleep.ts
new file mode 100644
index 00000000..977661f8
--- /dev/null
+++ b/src/utils/signalSleep.ts
@@ -0,0 +1,22 @@
+export function signalSleep(delay: number, abortSignal?: AbortSignal): Promise<void> {
+    return new Promise<void>((accept, reject) => {
+        if (abortSignal?.aborted)
+            return void reject(abortSignal.reason);
+
+        let timeout: ReturnType<typeof setTimeout> | undefined = undefined;
+        function onAbort() {
+            reject(abortSignal?.reason);
+            clearTimeout(timeout);
+            abortSignal?.removeEventListener("abort", onAbort);
+        }
+
+        function onTimeout() {
+            accept();
+            timeout = undefined;
+            abortSignal?.removeEventListener("abort", onAbort);
+        }
+
+        abortSignal?.addEventListener("abort", onAbort);
+        timeout = setTimeout(onTimeout, delay);
+    });
+}
diff --git a/src/utils/tokenizerUtils.ts b/src/utils/tokenizerUtils.ts
index 71e9dbb5..ce459966 100644
--- a/src/utils/tokenizerUtils.ts
+++ b/src/utils/tokenizerUtils.ts
@@ -12,6 +12,10 @@ export function resolveBeginningTokenToPrepend(vocabularyType: LlamaVocabularyTy
 
     if (vocabularyType === LlamaVocabularyType.wpm)
         return tokens.bos;
+
+    if (vocabularyType === LlamaVocabularyType.ugm)
+        return null;
+
     if (tokens.shouldPrependBosToken)
         return tokens.bos;
 
@@ -29,6 +33,9 @@ export function resolveEndTokenToAppend(vocabularyType: LlamaVocabularyType, tok
     if (vocabularyType === LlamaVocabularyType.wpm)
         return tokens.sep;
 
+    if (vocabularyType === LlamaVocabularyType.ugm)
+        return tokens.eos;
+
     if (tokens.shouldAppendEosToken)
         return tokens.eos;
 
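Reviewer note (not part of the patch): the new `signalSleep` helper resolves after the given delay and rejects with the abort reason if the signal fires first, which is what keeps the Hugging Face rate-limit wait above cancellable. A small standalone sketch, assuming a relative import of the internal module:

```ts
import {signalSleep} from "./src/utils/signalSleep.js";

const controller = new AbortController();

// resolves after ~1.5 seconds, or rejects with the abort reason if aborted earlier
const sleep = signalSleep(1500, controller.signal);

setTimeout(() => controller.abort(new Error("stop waiting")), 500);

try {
    await sleep;
    console.log("waited the full delay");
} catch (err) {
    console.log("aborted:", err); // aborted: Error: stop waiting
}
```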
diff --git a/test/modelDependent/llama3.1/tokenPredictor.test.ts b/test/modelDependent/llama3.1/tokenPredictor.test.ts
index 181f7bde..07f7966c 100644
--- a/test/modelDependent/llama3.1/tokenPredictor.test.ts
+++ b/test/modelDependent/llama3.1/tokenPredictor.test.ts
@@ -6,10 +6,13 @@ import {compareTokens} from "../../../src/utils/compareTokens.js";
 
 describe("llama 3.1", () => {
     describe("token predictor", () => {
-        test("DraftModelTokenPredictor", {timeout: 1000 * 60 * 60 * 2}, async () => {
+        test("DraftModelTokenPredictor", {timeout: 1000 * 60 * 60 * 2}, async (test) => {
             const modelPath = await getModelFile("Meta-Llama-3.1-8B-Instruct.Q4_K_M.gguf");
             const llama = await getTestLlama();
 
+            if (llama.gpu !== "metal")
+                test.skip(); // the outputs are a bit different on different platforms, so skipping on other platforms due to flakiness
+
             const model = await llama.loadModel({
                 modelPath
             });
diff --git a/test/standalone/cli/recommendedModels.test.ts b/test/standalone/cli/recommendedModels.test.ts
index 6ccf727a..56cb67ab 100644
--- a/test/standalone/cli/recommendedModels.test.ts
+++ b/test/standalone/cli/recommendedModels.test.ts
@@ -4,7 +4,7 @@ import {recommendedModels} from "../../../src/cli/recommendedModels.js";
 
 describe("cli", () => {
     describe("recommended models", () => {
-        test("all URIs resolve correctly", async () => {
+        test("all URIs resolve correctly", {timeout: 1000 * 60 * 6}, async () => {
             const unresolvedUris = (
                 await Promise.all(
                     recommendedModels
@@ -18,10 +18,11 @@ describe("cli", () => {
                     try {
                         await resolveParsedModelUri(parseModelUri(uri));
                         return null;
-                    } catch (err) {
+                    } catch (err: any) {
                         return {
                             modelName,
-                            uri
+                            uri,
+                            error: String(err?.stack ?? err)
                         };
                     }
                 })
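Closing reviewer note (not part of the patch): how the rerank template normalization added to `GgufInsights.ts` behaves. Both helpers are exported from that module (not from the package index), and the sample template string is made up for illustration.

```ts
import {parseRankingTemplate, isRankingTemplateValid} from "./src/gguf/insights/GgufInsights.js";

// GGUF `tokenizer.chat_template.rerank` templates use single-brace placeholders;
// they are normalized to the double-brace form used by LlamaRankingContextOptions["template"]
const normalized = parseRankingTemplate("<bos>{query}<sep>{document}<eos>");
console.log(normalized); // "<bos>{{query}}<sep>{{document}}<eos>"

console.log(isRankingTemplateValid(normalized)); // true - both placeholders are present
console.log(isRankingTemplateValid("{{query}} only")); // false - missing {{document}}
console.log(isRankingTemplateValid(undefined)); // false
```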