
Commit c35ff5a

feat: flash attention in model selection (#266)
* feat: flash attention in model selection
* fix: adapt to `llama.cpp` breaking changes
* fix: Llama 3 function calling
1 parent c2e322c commit c35ff5a

9 files changed: +120, -46 lines changed

llama/addon.cpp

Lines changed: 9 additions & 26 deletions
@@ -108,20 +108,6 @@ static void adjustNapiExternalMemorySubtract(Napi::Env env, uint64_t size) {
     }
 }
 
-std::string addon_model_token_to_piece(const struct llama_model* model, llama_token token, bool specialTokens) {
-    std::vector<char> result(8, 0);
-    const int n_tokens = llama_token_to_piece(model, token, result.data(), result.size(), specialTokens);
-    if (n_tokens < 0) {
-        result.resize(-n_tokens);
-        int check = llama_token_to_piece(model, token, result.data(), result.size(), specialTokens);
-        GGML_ASSERT(check == -n_tokens);
-    } else {
-        result.resize(n_tokens);
-    }
-
-    return std::string(result.data(), result.size());
-}
-
 #ifdef GPU_INFO_USE_CUDA
 void logCudaError(const char* message) {
     addonLlamaCppLogCallback(GGML_LOG_LEVEL_ERROR, (std::string("CUDA error: ") + std::string(message)).c_str(), nullptr);
@@ -395,21 +381,18 @@ class AddonModel : public Napi::ObjectWrap<AddonModel> {
             ? info[1].As<Napi::Boolean>().Value()
             : false;
 
-        // Create a stringstream for accumulating the decoded string.
-        std::stringstream ss;
+        std::vector<char> result(8, 0);
+        const int n_length = llama_detokenize(model, (llama_token*)tokens.Data(), tokens.ElementLength(), result.data(), result.size(), false, decodeSpecialTokens);
 
-        // Decode each token and accumulate the result.
-        for (size_t i = 0; i < tokens.ElementLength(); i++) {
-            const std::string piece = addon_model_token_to_piece(model, (llama_token)tokens[i], decodeSpecialTokens);
-
-            if (piece.empty()) {
-                continue;
-            }
-
-            ss << piece;
+        if (n_length < 0) {
+            result.resize(-n_length);
+            int check = llama_detokenize(model, (llama_token*)tokens.Data(), tokens.ElementLength(), result.data(), result.size(), false, decodeSpecialTokens);
+            GGML_ASSERT(check == -n_length);
+        } else {
+            result.resize(n_length);
         }
 
-        return Napi::String::New(info.Env(), ss.str());
+        return Napi::String::New(info.Env(), result.data(), result.size());
     }
 
     Napi::Value GetTrainContextSize(const Napi::CallbackInfo& info) {
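
On the JavaScript side, this native binding backs the model's detokenize() method, which now does the whole conversion in a single llama_detokenize call (growing the buffer once if the first call reports a larger required length) instead of concatenating per-token pieces. A minimal round-trip sketch, assuming the library's public getLlama()/loadModel() entry points and a placeholder model file name:

    import {getLlama} from "node-llama-cpp";

    // "model.gguf" is a placeholder path; any local GGUF model file works here.
    const llama = await getLlama();
    const model = await llama.loadModel({modelPath: "model.gguf"});

    // Tokenize and detokenize; the second detokenize argument controls
    // whether special tokens are rendered as text.
    const tokens = model.tokenize("Hi there");
    console.log(model.detokenize(tokens));       // "Hi there"
    console.log(model.detokenize(tokens, true)); // same output here, since no special tokens are present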

src/chatWrappers/Llama3ChatWrapper.ts

Lines changed: 2 additions & 2 deletions
@@ -28,7 +28,7 @@ export class Llama3ChatWrapper extends ChatWrapper {
         functions: {
             call: {
                 optionalPrefixSpace: true,
-                prefix: "||call:",
+                prefix: "||call: ",
                 paramsPrefix: LlamaText(new SpecialTokensText("(")),
                 suffix: LlamaText(new SpecialTokensText(")"))
             },
@@ -56,7 +56,7 @@ export class Llama3ChatWrapper extends ChatWrapper {
         functions: {
             call: {
                 optionalPrefixSpace: true,
-                prefix: "||call:",
+                prefix: "||call: ",
                 paramsPrefix: LlamaText(new SpecialTokensText("(")),
                 suffix: LlamaText(new SpecialTokensText(")"))
             },
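
With the trailing space added to the prefix, a Llama 3 function call emitted by the model is expected to take the shape below (the function name and parameters are hypothetical; the parentheses come from paramsPrefix and suffix):

    ||call: getWeather({"city": "London"})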

src/chatWrappers/generic/JinjaTemplateChatWrapper.ts

Lines changed: 12 additions & 0 deletions
@@ -9,8 +9,20 @@ import {ChatHistoryFunctionCallMessageTemplate, parseFunctionCallMessageTemplate
 
 export type JinjaTemplateChatWrapperOptions = {
     template: string,
+
+    /**
+     * Defaults to `"assistant"`.
+     */
     modelRoleName?: string,
+
+    /**
+     * Defaults to `"user"`.
+     */
     userRoleName?: string,
+
+    /**
+     * Defaults to `"system"`.
+     */
     systemRoleName?: string,
 
     /**
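
The added JSDoc comments document the defaults for the role-name options. A usage sketch, assuming JinjaTemplateChatWrapper is exported from the package's main entry point and constructed directly with this options object; the template string is a trivial placeholder and may not pass the wrapper's internal template validation for every model:

    import {JinjaTemplateChatWrapper} from "node-llama-cpp";

    const chatWrapper = new JinjaTemplateChatWrapper({
        template: "{% for message in messages %}{{ message.role }}: {{ message.content }}\n{% endfor %}",

        // all three role names are optional and default to
        // "assistant", "user" and "system" respectively
        modelRoleName: "assistant",
        userRoleName: "user",
        systemRoleName: "system"
    });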

src/cli/commands/ChatCommand.ts

Lines changed: 3 additions & 1 deletion
@@ -326,7 +326,9 @@ async function RunChat({
     });
     const logBatchSize = batchSize != null;
 
-    const resolvedModelPath = await resolveCommandGgufPath(modelArg, llama, headers);
+    const resolvedModelPath = await resolveCommandGgufPath(modelArg, llama, headers, {
+        flashAttention
+    });
 
     if (systemInfo)
        console.log(llama.systemInfo);

src/cli/commands/CompleteCommand.ts

Lines changed: 3 additions & 1 deletion
@@ -249,7 +249,9 @@ async function RunCompletion({
     });
     const logBatchSize = batchSize != null;
 
-    const resolvedModelPath = await resolveCommandGgufPath(modelArg, llama, headers);
+    const resolvedModelPath = await resolveCommandGgufPath(modelArg, llama, headers, {
+        flashAttention
+    });
 
     if (systemInfo)
        console.log(llama.systemInfo);

src/cli/commands/InfillCommand.ts

Lines changed: 3 additions & 1 deletion
@@ -259,7 +259,9 @@ async function RunInfill({
     });
     const logBatchSize = batchSize != null;
 
-    const resolvedModelPath = await resolveCommandGgufPath(modelArg, llama, headers);
+    const resolvedModelPath = await resolveCommandGgufPath(modelArg, llama, headers, {
+        flashAttention
+    });
 
     if (systemInfo)
        console.log(llama.systemInfo);

src/cli/utils/interactivelyAskForModel.ts

Lines changed: 31 additions & 10 deletions
@@ -56,12 +56,14 @@ export async function interactivelyAskForModel({
     llama,
     modelsDirectory,
     allowLocalModels = true,
-    downloadIntent = true
+    downloadIntent = true,
+    flashAttention = false
 }: {
     llama: Llama,
     modelsDirectory?: string,
     allowLocalModels?: boolean,
-    downloadIntent?: boolean
+    downloadIntent?: boolean,
+    flashAttention?: boolean
 }): Promise<string> {
     let localModelFileOptions: (ModelOption & { type: "localModel" })[] = [];
     const recommendedModelOptions: (ModelOption & { type: "recommendedModel" })[] = [];
@@ -112,7 +114,9 @@ export async function interactivelyAskForModel({
                 readItems++;
                 progressUpdater.setProgress(readItems / ggufFileNames.length, renderProgress());
 
-                const compatibilityScore = await ggufInsights?.configurationResolver.scoreModelConfigurationCompatibility();
+                const compatibilityScore = await ggufInsights?.configurationResolver.scoreModelConfigurationCompatibility({
+                    flashAttention: flashAttention && ggufInsights?.flashAttentionSupported
+                });
 
                 return {
                     type: "localModel",
@@ -211,7 +215,7 @@ export async function interactivelyAskForModel({
     try {
         // eslint-disable-next-line no-constant-condition
         while (true) {
-            const minWidth = Math.min(80, process.stdout.columns - 1);
+            const minWidth = Math.min(80 + (flashAttention ? 26 : 0), process.stdout.columns - 1);
             const selectedItem = await basicChooseFromListConsoleInteraction({
                 title(item, rerender) {
                     const title = chalk.bold("Select a model:") + " ";
@@ -235,6 +239,17 @@ export async function interactivelyAskForModel({
                            (String(Math.floor((vramState.used / vramState.total) * 100 * 100) / 100) + "%") + " " +
                            chalk.dim("(" + bytes(vramState.used) + "/" + bytes(vramState.total) + ")") +
                            " "
+                        ) + (
+                            !flashAttention
+                                ? ""
+                                : (
+                                    " " +
+                                    chalk.bgGray(
+                                        " " +
+                                        chalk.yellow("Flash attention:") + " " + "enabled" +
+                                        " "
+                                    )
+                                )
                        )
                    );
 
@@ -273,7 +288,7 @@ export async function interactivelyAskForModel({
            },
            items: options,
            renderItem(item, focused, rerender) {
-               return renderSelectionItem(item, focused, rerender, activeInteractionController.signal, llama);
+               return renderSelectionItem(item, focused, rerender, activeInteractionController.signal, llama, flashAttention);
            },
            canFocusItem(item) {
                return item.type === "recommendedModel" || item.type === "localModel" || item.type === "action";
@@ -374,7 +389,9 @@ async function askForModelUrlOrPath(allowLocalModels: boolean): Promise<string |
    );
 }
 
-function renderSelectionItem(item: ModelOption, focused: boolean, rerender: () => void, abortSignal: AbortSignal, llama: Llama) {
+function renderSelectionItem(
+    item: ModelOption, focused: boolean, rerender: () => void, abortSignal: AbortSignal, llama: Llama, flashAttention: boolean
+) {
     if (item.type === "localModel") {
         let modelText = item.title instanceof Function
             ? item.title()
@@ -398,7 +415,8 @@ function renderSelectionItem(item: ModelOption, focused: boolean, rerender: () =
            recommendedModelOption: item,
            abortSignal,
            rerenderOption: rerender,
-           llama
+           llama,
+           flashAttention
        });
    }
 
@@ -542,12 +560,13 @@ function renderCompatibilityPercentageWithColors(percentage: number, {
 }
 
 async function selectFileForModelRecommendation({
-    recommendedModelOption, llama, abortSignal, rerenderOption
+    recommendedModelOption, llama, abortSignal, rerenderOption, flashAttention
 }: {
     recommendedModelOption: ModelOption & { type: "recommendedModel" },
     llama: Llama,
    abortSignal: AbortSignal,
-    rerenderOption(): void
+    rerenderOption(): void,
+    flashAttention: boolean
 }) {
     try {
         let bestScore: number | undefined = undefined;
@@ -567,7 +586,9 @@ async function selectFileForModelRecommendation({
        if (abortSignal.aborted)
            return;
 
-       const compatibilityScore = await ggufInsights.configurationResolver.scoreModelConfigurationCompatibility();
+       const compatibilityScore = await ggufInsights.configurationResolver.scoreModelConfigurationCompatibility({
+           flashAttention
+       });
 
        if (bestScore == null || compatibilityScore.compatibilityScore > bestScore) {
            bestScore = compatibilityScore.compatibilityScore;
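
Model compatibility scoring now takes flash attention into account, and for local models only when the inspected model actually supports it. A sketch of that pattern, with ggufInsights standing in for a GgufInsights instance obtained elsewhere (not shown here):

    // `flashAttention` is the user's preference from the CLI; it is only
    // forwarded to the scorer when the model supports flash attention.
    async function scoreModel(ggufInsights: any, flashAttention: boolean) {
        const compatibilityScore = await ggufInsights.configurationResolver.scoreModelConfigurationCompatibility({
            flashAttention: flashAttention && ggufInsights.flashAttentionSupported
        });

        // a higher compatibilityScore means a better fit for the current machine
        return compatibilityScore.compatibilityScore;
    }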

src/cli/utils/resolveCommandGgufPath.ts

Lines changed: 4 additions & 3 deletions
@@ -12,9 +12,9 @@ import {getReadablePath} from "./getReadablePath.js";
 import {interactivelyAskForModel} from "./interactivelyAskForModel.js";
 
 export async function resolveCommandGgufPath(ggufPath: string | undefined, llama: Llama, fetchHeaders?: Record<string, string>, {
-    targetDirectory = cliModelsDirectory
+    targetDirectory = cliModelsDirectory, flashAttention = false
 }: {
-    targetDirectory?: string
+    targetDirectory?: string, flashAttention?: boolean
 } = {}) {
     let resolvedGgufPath = ggufPath;
 
@@ -23,7 +23,8 @@ export async function resolveCommandGgufPath(ggufPath: string | undefined, llama
         llama,
         modelsDirectory: targetDirectory,
         allowLocalModels: true,
-        downloadIntent: true
+        downloadIntent: true,
+        flashAttention
     });
 
     if (!isUrl(resolvedGgufPath)) {
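
The three CLI commands above (chat, complete, infill) all forward their flashAttention flag through this new options argument, so interactive model selection can score models with flash attention in mind. A caller-side sketch, assuming modelArg, llama and headers are already in scope as they are in those commands:

    const resolvedModelPath = await resolveCommandGgufPath(modelArg, llama, headers, {
        // targetDirectory defaults to cliModelsDirectory when omitted,
        // and flashAttention defaults to false
        flashAttention: true
    });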

test/modelDependent/functionary/sanity.test.ts

Lines changed: 53 additions & 2 deletions
@@ -1,5 +1,5 @@
 import {describe, expect, test} from "vitest";
-import {LlamaChatSession} from "../../../src/index.js";
+import {LlamaChatSession, SpecialTokensText, LlamaText} from "../../../src/index.js";
 import {getModelFile} from "../../utils/modelFiles.js";
 import {getTestLlama} from "../../utils/getTestLlama.js";
 
@@ -86,7 +86,7 @@ describe("functionary", () => {
            `);
        });
 
-        test("tokenizing text and then detokenizing it arrive at the same text", {timeout: 1000 * 60 * 60 * 2}, async () => {
+        test("tokenizing a text and then detokenizing it arrives at the same text", {timeout: 1000 * 60 * 60 * 2}, async () => {
            const modelPath = await getModelFile("functionary-small-v2.5.Q4_0.gguf");
            const llama = await getTestLlama();
 
@@ -178,6 +178,57 @@ describe("functionary", () => {
                expect(textWithSpecialTokens).to.eql(text);
                expect(textNoSpecialTokens).to.eql(text);
            }
+
+            {
+                const text = "Hi there";
+
+                const tokensWithTrim = model.tokenize(text, false, "trimLeadingSpace");
+                const tokensWithoutTrim = model.tokenize(text, false);
+
+                expect(model.detokenize(tokensWithTrim)).to.eql(text);
+                expect(model.detokenize(tokensWithoutTrim)).to.eql(text);
+            }
+            {
+                const text = " Hi there";
+
+                const tokensWithTrim = model.tokenize(text, false, "trimLeadingSpace");
+                const tokensWithoutTrim = model.tokenize(text, false);
+
+                expect(model.detokenize(tokensWithTrim)).to.eql(text);
+                expect(model.detokenize(tokensWithoutTrim)).to.eql(text);
+            }
+        });
+
+        test("tokenizing a LlamaText and then detokenizing it arrives at the same text", {timeout: 1000 * 60 * 60 * 2}, async () => {
+            const modelPath = await getModelFile("functionary-small-v2.5.Q4_0.gguf");
+            const llama = await getTestLlama();
+
+            const model = await llama.loadModel({
+                modelPath
+            });
+
+            {
+                const text = LlamaText([
+                    new SpecialTokensText("<|start_header_id|>system<|end_header_id|>\n\n"),
+                    "How much is 6+6\n"
+                ]);
+
+                const tokens = text.tokenize(model.tokenizer);
+
+                expect(model.detokenize(tokens, true)).to.eql("<|start_header_id|>system<|end_header_id|>\n\nHow much is 6+6\n");
+                expect(model.detokenize(tokens, false)).to.eql("system\n\nHow much is 6+6\n");
+            }
+            {
+                const text = LlamaText([
+                    new SpecialTokensText("Hi <|start_header_id|>there\n\n"),
+                    "How much is 6+6\n"
+                ]);
+
+                const tokens = text.tokenize(model.tokenizer);
+
+                expect(model.detokenize(tokens, true)).to.eql("Hi <|start_header_id|>there\n\nHow much is 6+6\n");
+                expect(model.detokenize(tokens, false)).to.eql("Hi there\n\nHow much is 6+6\n");
+            }
        });
    });
 });
