
Commit c2e322c

feat: flash attention (#264)
* feat: flash attention
* feat: exclude GPU types from `gpu: "auto"`
Parent: 81e0575

20 files changed: +307 -81 lines changed

llama/addon.cpp

Lines changed: 4 additions & 0 deletions
@@ -987,6 +987,10 @@ class AddonContext : public Napi::ObjectWrap<AddonContext> {
             context_params.embeddings = options.Get("embeddings").As<Napi::Boolean>().Value();
         }
 
+        if (options.Has("flashAttention")) {
+            context_params.flash_attn = options.Get("flashAttention").As<Napi::Boolean>().Value();
+        }
+
         if (options.Has("threads")) {
             const auto n_threads = options.Get("threads").As<Napi::Number>().Uint32Value();
             const auto resolved_n_threads = n_threads == 0 ? std::thread::hardware_concurrency() : n_threads;

src/bindings/AddonTypes.ts

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ export type BindingModule = {
         contextSize?: number,
         batchSize?: number,
         sequences?: number,
+        flashAttention?: boolean,
         logitsAll?: boolean,
         embeddings?: boolean,
         threads?: number
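
The new flashAttention flag sits alongside the other per-context options in the BindingModule type and is what addon.cpp reads into context_params.flash_attn above. A minimal sketch of an options object matching that shape; the call site that hands it to the native context is not part of this diff, so the object below is illustrative only:

    // TypeScript — illustrative only; the values are placeholders, not library defaults
    const nativeContextOptions = {
        contextSize: 4096,
        batchSize: 512,
        sequences: 1,
        flashAttention: true, // new in this commit; mapped to llama.cpp's flash_attn by addon.cpp
        logitsAll: false,
        embeddings: false,
        threads: 6
    };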

src/bindings/Llama.ts

Lines changed: 6 additions & 3 deletions
@@ -7,7 +7,7 @@ import {GbnfJsonSchema} from "../utils/gbnfJson/types.js";
 import {LlamaJsonSchemaGrammar} from "../evaluator/LlamaJsonSchemaGrammar.js";
 import {LlamaGrammar, LlamaGrammarOptions} from "../evaluator/LlamaGrammar.js";
 import {BindingModule} from "./AddonTypes.js";
-import {BuildGpu, BuildMetadataFile, LlamaLocks, LlamaLogLevel} from "./types.js";
+import {BuildGpu, BuildMetadataFile, LlamaGpuType, LlamaLocks, LlamaLogLevel} from "./types.js";
 import {MemoryOrchestrator, MemoryReservation} from "./utils/MemoryOrchestrator.js";
 
 const LlamaLogLevelToAddonLogLevel: ReadonlyMap<LlamaLogLevel, number> = new Map([
@@ -31,7 +31,7 @@ export class Llama {
     /** @internal */ public readonly _vramOrchestrator: MemoryOrchestrator;
     /** @internal */ public readonly _vramPadding: MemoryReservation;
     /** @internal */ public readonly _debug: boolean;
-    /** @internal */ private readonly _gpu: BuildGpu;
+    /** @internal */ private readonly _gpu: LlamaGpuType;
     /** @internal */ private readonly _buildType: "localBuild" | "prebuilt";
     /** @internal */ private readonly _cmakeOptions: Readonly<Record<string, string>>;
     /** @internal */ private readonly _supportsGpuOffloading: boolean;
@@ -244,7 +244,10 @@ export class Llama {
         await this._bindings.init();
     }
 
-    /** @internal */
+    /**
+     * Log messages related to the Llama instance
+     * @internal
+     */
     public _log(level: LlamaLogLevel, message: string) {
         this._onAddonLog(LlamaLogLevelToAddonLogLevel.get(level) ?? defaultLogLevel, message + "\n");
     }

src/bindings/getLlama.ts

Lines changed: 8 additions & 2 deletions
@@ -16,7 +16,7 @@ import {
 } from "./utils/compileLLamaCpp.js";
 import {getLastBuildInfo} from "./utils/lastBuildInfo.js";
 import {getClonedLlamaCppRepoReleaseInfo, isLlamaCppRepoCloned} from "./utils/cloneLlamaCppRepo.js";
-import {BuildGpu, BuildMetadataFile, BuildOptions, LlamaLogLevel} from "./types.js";
+import {BuildGpu, BuildMetadataFile, BuildOptions, LlamaGpuType, LlamaLogLevel} from "./types.js";
 import {BinaryPlatform, getPlatform} from "./utils/getPlatform.js";
 import {getBuildFolderNameForBuildOptions} from "./utils/getBuildFolderNameForBuildOptions.js";
 import {resolveCustomCmakeOptions} from "./utils/resolveCustomCmakeOptions.js";
@@ -46,7 +46,10 @@ export type LlamaOptions = {
      *
      * `"auto"` by default.
      */
-    gpu?: "auto" | "metal" | "cuda" | "vulkan" | false,
+    gpu?: "auto" | LlamaGpuType | {
+        type: "auto",
+        exclude?: LlamaGpuType[]
+    },
 
     /**
      * Set the minimum log level for llama.cpp.
@@ -298,6 +301,9 @@ export async function getLlamaForOptions({
         }
     }
 
+    if (buildGpusToTry.length === 0)
+        throw new Error("No GPU types available to try building with");
+
     if (build === "auto" || build === "never") {
         for (let i = 0; i < buildGpusToTry.length; i++) {
             const gpu = buildGpusToTry[i];
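
With the widened `gpu` option type above, backend selection can stay automatic while ruling out specific GPU types. A minimal usage sketch, assuming `getLlama` is imported from the package entry point (the excluded backend here is just an example):

    import {getLlama} from "node-llama-cpp";

    // Pick the best available GPU backend, but never try Vulkan.
    // The CPU-only fallback (`false`) remains available unless it is also excluded.
    const llama = await getLlama({
        gpu: {
            type: "auto",
            exclude: ["vulkan"]
        }
    });

Passing a plain value such as `"cuda"`, `"metal"`, or `false` keeps working as before, since `LlamaGpuType` covers the previous literal union.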

src/bindings/types.ts

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@ import {BinaryPlatform} from "./utils/getPlatform.js";
 import {BinaryPlatformInfo} from "./utils/getPlatformInfo.js";
 
 export const buildGpuOptions = ["metal", "cuda", "vulkan", false] as const;
+export type LlamaGpuType = "metal" | "cuda" | "vulkan" | false;
 export const nodeLlamaCppGpuOptions = [
     "auto",
     ...buildGpuOptions

src/bindings/utils/getGpuTypesToUseForOption.ts

Lines changed: 18 additions & 5 deletions
@@ -1,28 +1,41 @@
 import process from "process";
 import {BuildGpu, buildGpuOptions} from "../types.js";
+import {LlamaOptions} from "../getLlama.js";
 import {BinaryPlatform, getPlatform} from "./getPlatform.js";
 import {getBestComputeLayersAvailable} from "./getBestComputeLayersAvailable.js";
 
-export async function getGpuTypesToUseForOption(gpu: BuildGpu | "auto", {
+export async function getGpuTypesToUseForOption(gpu: Required<LlamaOptions>["gpu"], {
     platform = getPlatform(),
     arch = process.arch
 }: {
     platform?: BinaryPlatform,
     arch?: typeof process.arch
 } = {}): Promise<BuildGpu[]> {
-    const resolvedGpu = resolveValidGpuOptionForPlatform(gpu, {
+    const resolvedGpuOption = typeof gpu === "object"
+        ? gpu.type
+        : gpu;
+
+    function withExcludedGpuTypesRemoved(gpuTypes: BuildGpu[]) {
+        const resolvedExcludeTypes = typeof gpu === "object"
+            ? new Set(gpu.exclude ?? [])
+            : new Set();
+
+        return gpuTypes.filter(gpuType => !resolvedExcludeTypes.has(gpuType));
+    }
+
+    const resolvedGpu = resolveValidGpuOptionForPlatform(resolvedGpuOption, {
         platform,
         arch
     });
 
     if (resolvedGpu === "auto") {
         if (arch === process.arch)
-            return await getBestComputeLayersAvailable();
+            return withExcludedGpuTypesRemoved(await getBestComputeLayersAvailable());
 
-        return [false];
+        return withExcludedGpuTypesRemoved([false]);
     }
 
-    return [resolvedGpu];
+    return withExcludedGpuTypesRemoved([resolvedGpu]);
 }
 
 export function resolveValidGpuOptionForPlatform(gpu: BuildGpu | "auto", {
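
The exclusion itself is a simple set-difference over the resolved GPU type list. A standalone sketch of that filtering step (not the library code itself; types reduced to the essentials):

    type BuildGpu = "metal" | "cuda" | "vulkan" | false;

    // Remove user-excluded GPU types from a candidate list,
    // mirroring withExcludedGpuTypesRemoved above.
    function filterExcludedGpuTypes(gpuTypes: BuildGpu[], exclude: BuildGpu[] = []): BuildGpu[] {
        const excludeSet = new Set<BuildGpu>(exclude);
        return gpuTypes.filter((gpuType) => !excludeSet.has(gpuType));
    }

    // e.g. filterExcludedGpuTypes(["cuda", "vulkan", false], ["vulkan"]) -> ["cuda", false]

After filtering, getLlamaForOptions now throws if the resulting list is empty (see the getLlama.ts hunk above), so excluding every available type fails fast instead of silently having nothing to build.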

src/cli/commands/ChatCommand.ts

Lines changed: 15 additions & 7 deletions
@@ -41,6 +41,7 @@ type ChatCommand = {
     noJinja?: boolean,
     contextSize?: number,
     batchSize?: number,
+    flashAttention?: boolean,
     noTrimWhitespace: boolean,
     grammar: "text" | Parameters<typeof LlamaGrammar.getFor>[1],
     jsonSchemaGrammarFile?: string,
@@ -149,6 +150,12 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
                 type: "number",
                 description: "Batch size to use for the model context. The default value is the context size"
             })
+            .option("flashAttention", {
+                alias: "fa",
+                type: "boolean",
+                default: false,
+                description: "Enable flash attention"
+            })
             .option("noTrimWhitespace", {
                 type: "boolean",
                 alias: ["noTrim"],
@@ -269,7 +276,7 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
     },
     async handler({
         modelPath, header, gpu, systemInfo, systemPrompt, systemPromptFile, prompt,
-        promptFile, wrapper, noJinja, contextSize, batchSize,
+        promptFile, wrapper, noJinja, contextSize, batchSize, flashAttention,
         noTrimWhitespace, grammar, jsonSchemaGrammarFile, threads, temperature, minP, topK,
         topP, gpuLayers, repeatPenalty, lastTokensRepeatPenalty, penalizeRepeatingNewLine,
         repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, noHistory,
@@ -278,9 +285,9 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
         try {
             await RunChat({
                 modelPath, header, gpu, systemInfo, systemPrompt, systemPromptFile, prompt, promptFile, wrapper, noJinja, contextSize,
-                batchSize, noTrimWhitespace, grammar, jsonSchemaGrammarFile, threads, temperature, minP, topK, topP, gpuLayers,
-                lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens,
-                noHistory, environmentFunctions, debug, meter, printTimings
+                batchSize, flashAttention, noTrimWhitespace, grammar, jsonSchemaGrammarFile, threads, temperature, minP, topK, topP,
+                gpuLayers, lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty,
+                maxTokens, noHistory, environmentFunctions, debug, meter, printTimings
             });
         } catch (err) {
             await new Promise((accept) => setTimeout(accept, 0)); // wait for logs to finish printing
@@ -293,9 +300,9 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
 
 async function RunChat({
     modelPath: modelArg, header: headerArg, gpu, systemInfo, systemPrompt, systemPromptFile, prompt, promptFile, wrapper, noJinja,
-    contextSize, batchSize, noTrimWhitespace, grammar: grammarArg, jsonSchemaGrammarFile: jsonSchemaGrammarFilePath, threads, temperature,
-    minP, topK, topP, gpuLayers, lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty,
-    repeatPresencePenalty, maxTokens, noHistory, environmentFunctions, debug, meter, printTimings
+    contextSize, batchSize, flashAttention, noTrimWhitespace, grammar: grammarArg, jsonSchemaGrammarFile: jsonSchemaGrammarFilePath,
+    threads, temperature, minP, topK, topP, gpuLayers, lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine,
+    repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, noHistory, environmentFunctions, debug, meter, printTimings
 }: ChatCommand) {
     if (contextSize === -1) contextSize = undefined;
     if (gpuLayers === -1) gpuLayers = undefined;
@@ -360,6 +367,7 @@ async function RunChat({
             : contextSize != null
                 ? {fitContext: {contextSize}}
                 : undefined,
+        defaultContextFlashAttention: flashAttention,
        ignoreMemorySafetyChecks: gpuLayers != null,
         onLoadProgress(loadProgress: number) {
             progressUpdater.setProgress(loadProgress);
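
The CLI forwards the new flag as defaultContextFlashAttention when loading the model, so contexts created from that model use flash attention by default. A minimal programmatic sketch of the same flow, assuming the public model-loading API accepts this option as the CLI code above suggests (the model path is a placeholder):

    import {getLlama} from "node-llama-cpp";

    const llama = await getLlama();
    const model = await llama.loadModel({
        modelPath: "path/to/model.gguf", // placeholder path
        defaultContextFlashAttention: true // same option the `--flashAttention` / `--fa` CLI flag sets
    });
    const context = await model.createContext(); // inherits flash attention from the model default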

src/cli/commands/CompleteCommand.ts

Lines changed: 11 additions & 3 deletions
@@ -28,6 +28,7 @@ type CompleteCommand = {
     textFile?: string,
     contextSize?: number,
     batchSize?: number,
+    flashAttention?: boolean,
     threads: number,
     temperature: number,
     minP: number,
@@ -104,6 +105,12 @@ export const CompleteCommand: CommandModule<object, CompleteCommand> = {
                 type: "number",
                 description: "Batch size to use for the model context. The default value is the context size"
             })
+            .option("flashAttention", {
+                alias: "fa",
+                type: "boolean",
+                default: false,
+                description: "Enable flash attention"
+            })
             .option("threads", {
                 type: "number",
                 default: 6,
@@ -194,14 +201,14 @@ export const CompleteCommand: CommandModule<object, CompleteCommand> = {
     },
     async handler({
         modelPath, header, gpu, systemInfo, text, textFile, contextSize, batchSize,
-        threads, temperature, minP, topK,
+        flashAttention, threads, temperature, minP, topK,
         topP, gpuLayers, repeatPenalty, lastTokensRepeatPenalty, penalizeRepeatingNewLine,
         repeatFrequencyPenalty, repeatPresencePenalty, maxTokens,
         debug, meter, printTimings
     }) {
         try {
             await RunCompletion({
-                modelPath, header, gpu, systemInfo, text, textFile, contextSize, batchSize,
+                modelPath, header, gpu, systemInfo, text, textFile, contextSize, batchSize, flashAttention,
                 threads, temperature, minP, topK, topP, gpuLayers, lastTokensRepeatPenalty,
                 repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens,
                 debug, meter, printTimings
@@ -216,7 +223,7 @@ export const CompleteCommand: CommandModule<object, CompleteCommand> = {
 
 
 async function RunCompletion({
-    modelPath: modelArg, header: headerArg, gpu, systemInfo, text, textFile, contextSize, batchSize,
+    modelPath: modelArg, header: headerArg, gpu, systemInfo, text, textFile, contextSize, batchSize, flashAttention,
     threads, temperature, minP, topK, topP, gpuLayers,
     lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty,
     maxTokens, debug, meter, printTimings
@@ -276,6 +283,7 @@ async function RunCompletion({
             : contextSize != null
                 ? {fitContext: {contextSize}}
                 : undefined,
+        defaultContextFlashAttention: flashAttention,
         ignoreMemorySafetyChecks: gpuLayers != null,
         onLoadProgress(loadProgress: number) {
             progressUpdater.setProgress(loadProgress);

src/cli/commands/InfillCommand.ts

Lines changed: 11 additions & 3 deletions
@@ -30,6 +30,7 @@ type InfillCommand = {
     suffixFile?: string,
     contextSize?: number,
     batchSize?: number,
+    flashAttention?: boolean,
     threads: number,
     temperature: number,
     minP: number,
@@ -114,6 +115,12 @@ export const InfillCommand: CommandModule<object, InfillCommand> = {
                 type: "number",
                 description: "Batch size to use for the model context. The default value is the context size"
             })
+            .option("flashAttention", {
+                alias: "fa",
+                type: "boolean",
+                default: false,
+                description: "Enable flash attention"
+            })
             .option("threads", {
                 type: "number",
                 default: 6,
@@ -204,14 +211,14 @@ export const InfillCommand: CommandModule<object, InfillCommand> = {
     },
     async handler({
         modelPath, header, gpu, systemInfo, prefix, prefixFile, suffix, suffixFile, contextSize, batchSize,
-        threads, temperature, minP, topK,
+        flashAttention, threads, temperature, minP, topK,
         topP, gpuLayers, repeatPenalty, lastTokensRepeatPenalty, penalizeRepeatingNewLine,
         repeatFrequencyPenalty, repeatPresencePenalty, maxTokens,
         debug, meter, printTimings
     }) {
         try {
             await RunInfill({
-                modelPath, header, gpu, systemInfo, prefix, prefixFile, suffix, suffixFile, contextSize, batchSize,
+                modelPath, header, gpu, systemInfo, prefix, prefixFile, suffix, suffixFile, contextSize, batchSize, flashAttention,
                 threads, temperature, minP, topK, topP, gpuLayers, lastTokensRepeatPenalty,
                 repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens,
                 debug, meter, printTimings
@@ -226,7 +233,7 @@ export const InfillCommand: CommandModule<object, InfillCommand> = {
 
 
 async function RunInfill({
-    modelPath: modelArg, header: headerArg, gpu, systemInfo, prefix, prefixFile, suffix, suffixFile, contextSize, batchSize,
+    modelPath: modelArg, header: headerArg, gpu, systemInfo, prefix, prefixFile, suffix, suffixFile, contextSize, batchSize, flashAttention,
     threads, temperature, minP, topK, topP, gpuLayers,
     lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty,
     maxTokens, debug, meter, printTimings
@@ -300,6 +307,7 @@ async function RunInfill({
             : contextSize != null
                 ? {fitContext: {contextSize}}
                 : undefined,
+        defaultContextFlashAttention: flashAttention,
         ignoreMemorySafetyChecks: gpuLayers != null,
         onLoadProgress(loadProgress: number) {
             progressUpdater.setProgress(loadProgress);
