
Commit cabafea

feat: SWA support
1 parent 1799127 commit cabafea

19 files changed: +404 −138 lines changed

llama/addon/AddonContext.cpp

Lines changed: 19 additions & 0 deletions
@@ -393,6 +393,7 @@ AddonContext::AddonContext(const Napi::CallbackInfo& info) : Napi::ObjectWrap<Ad
     context_params.n_threads = std::max(cpu_get_num_math(), 1);
     context_params.n_threads_batch = context_params.n_threads;
     context_params.no_perf = true;
+    context_params.swa_full = false;
 
     if (info.Length() > 1 && info[1].IsObject()) {
         Napi::Object options = info[1].As<Napi::Object>();
@@ -433,6 +434,10 @@ AddonContext::AddonContext(const Napi::CallbackInfo& info) : Napi::ObjectWrap<Ad
         if (options.Has("performanceTracking")) {
             context_params.no_perf = !(options.Get("performanceTracking").As<Napi::Boolean>().Value());
         }
+
+        if (options.Has("swaFullCache")) {
+            context_params.swa_full = options.Get("swaFullCache").As<Napi::Boolean>().Value();
+        }
     }
 }
 AddonContext::~AddonContext() {
@@ -620,6 +625,19 @@ Napi::Value AddonContext::ShiftSequenceTokenCells(const Napi::CallbackInfo& info
 
     return info.Env().Undefined();
 }
+Napi::Value AddonContext::GetSequenceKvCacheMinPosition(const Napi::CallbackInfo& info) {
+    if (disposed) {
+        Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+        return info.Env().Undefined();
+    }
+
+    int32_t sequenceId = info[0].As<Napi::Number>().Int32Value();
+
+    const auto minPosition = llama_kv_self_seq_pos_min(ctx, sequenceId);
+
+    return Napi::Number::New(info.Env(), minPosition);
+}
 Napi::Value AddonContext::DecodeBatch(const Napi::CallbackInfo& info) {
     AddonContextDecodeBatchWorker* worker = new AddonContextDecodeBatchWorker(info.Env(), this);
     worker->Queue();
@@ -926,6 +944,7 @@ void AddonContext::init(Napi::Object exports) {
         InstanceMethod("disposeSequence", &AddonContext::DisposeSequence),
         InstanceMethod("removeTokenCellsFromSequence", &AddonContext::RemoveTokenCellsFromSequence),
         InstanceMethod("shiftSequenceTokenCells", &AddonContext::ShiftSequenceTokenCells),
+        InstanceMethod("getSequenceKvCacheMinPosition", &AddonContext::GetSequenceKvCacheMinPosition),
         InstanceMethod("decodeBatch", &AddonContext::DecodeBatch),
         InstanceMethod("sampleToken", &AddonContext::SampleToken),
         InstanceMethod("getEmbedding", &AddonContext::GetEmbedding),

llama/addon/AddonContext.h

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ class AddonContext : public Napi::ObjectWrap<AddonContext> {
         Napi::Value DisposeSequence(const Napi::CallbackInfo& info);
         Napi::Value RemoveTokenCellsFromSequence(const Napi::CallbackInfo& info);
         Napi::Value ShiftSequenceTokenCells(const Napi::CallbackInfo& info);
+        Napi::Value GetSequenceKvCacheMinPosition(const Napi::CallbackInfo& info);
         Napi::Value DecodeBatch(const Napi::CallbackInfo& info);
         Napi::Value SampleToken(const Napi::CallbackInfo& info);

llama/addon/addon.cpp

Lines changed: 14 additions & 0 deletions
@@ -73,6 +73,19 @@ Napi::Value addonGetTypeSizeForGgmlType(const Napi::CallbackInfo& info) {
     return Napi::Number::New(info.Env(), typeSize);
 }
 
+Napi::Value addonGetGgmlGraphOverheadCustom(const Napi::CallbackInfo& info) {
+    if (info.Length() < 2 || !info[0].IsNumber() || !info[1].IsBoolean()) {
+        return Napi::Number::New(info.Env(), 0);
+    }
+
+    const size_t size = info[0].As<Napi::Number>().Uint32Value();
+    const bool grads = info[1].As<Napi::Boolean>().Value();
+
+    const auto graphOverhead = ggml_graph_overhead_custom(size, grads);
+
+    return Napi::Number::New(info.Env(), graphOverhead);
+}
+
 Napi::Value addonGetConsts(const Napi::CallbackInfo& info) {
     Napi::Object consts = Napi::Object::New(info.Env());
     consts.Set("ggmlMaxDims", Napi::Number::New(info.Env(), GGML_MAX_DIMS));
@@ -231,6 +244,7 @@ Napi::Object registerCallback(Napi::Env env, Napi::Object exports) {
         Napi::PropertyDescriptor::Function("getMathCores", addonGetMathCores),
         Napi::PropertyDescriptor::Function("getBlockSizeForGgmlType", addonGetBlockSizeForGgmlType),
         Napi::PropertyDescriptor::Function("getTypeSizeForGgmlType", addonGetTypeSizeForGgmlType),
+        Napi::PropertyDescriptor::Function("getGgmlGraphOverheadCustom", addonGetGgmlGraphOverheadCustom),
         Napi::PropertyDescriptor::Function("getConsts", addonGetConsts),
         Napi::PropertyDescriptor::Function("setLogger", setLogger),
         Napi::PropertyDescriptor::Function("setLoggerLogLevel", setLoggerLogLevel),

src/bindings/AddonTypes.ts

Lines changed: 4 additions & 1 deletion
@@ -28,7 +28,8 @@ export type BindingModule = {
             embeddings?: boolean,
             ranking?: boolean,
             threads?: number,
-            performanceTracking?: boolean
+            performanceTracking?: boolean,
+            swaFullCache?: boolean
         }): AddonContext
     },
     AddonGrammar: {
@@ -54,6 +55,7 @@ export type BindingModule = {
     getMathCores(): number,
     getBlockSizeForGgmlType(ggmlType: number): number | undefined,
     getTypeSizeForGgmlType(ggmlType: number): number | undefined,
+    getGgmlGraphOverheadCustom(size: number, grads: boolean): number,
     getConsts(): {
         ggmlMaxDims: number,
         ggmlTypeF16Size: number,
@@ -143,6 +145,7 @@ export type AddonContext = {
     // startPos in inclusive, endPos is exclusive
     shiftSequenceTokenCells(sequenceId: number, startPos: number, endPos: number, shiftDelta: number): void,
 
+    getSequenceKvCacheMinPosition(sequenceId: number): number,
     getEmbedding(inputTokensLength: number, maxVectorSize?: number): Float64Array,
     getStateSize(): number,
     getThreads(): number,
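A sketch of how the new context method might be used, assuming an AddonContext instance created with swaFullCache left at its default (false); everything except the method itself is an assumption for illustration:

import type {AddonContext} from "./AddonTypes.js"; // type-only import, illustrative path

declare const context: AddonContext; // an existing context, created elsewhere

// With a sliding-window KV cache, cells that fall outside the attention window can be dropped,
// so the oldest position still held for a sequence is not necessarily 0. This method wraps
// llama_kv_self_seq_pos_min() on the native side (see AddonContext.cpp above).
const sequenceId = 0;
const minKeptPosition = context.getSequenceKvCacheMinPosition(sequenceId);
console.log(`sequence ${sequenceId} still has KV cells from position ${minKeptPosition} onward`);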

src/cli/commands/ChatCommand.ts

Lines changed: 16 additions & 3 deletions
@@ -45,6 +45,7 @@ type ChatCommand = {
     contextSize?: number,
     batchSize?: number,
     flashAttention?: boolean,
+    swaFullCache?: boolean,
     noTrimWhitespace: boolean,
     grammar: "text" | Parameters<typeof LlamaGrammar.getFor>[1],
     jsonSchemaGrammarFile?: string,
@@ -162,6 +163,12 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
             default: false,
             description: "Enable flash attention"
         })
+        .option("swaFullCache", {
+            alias: "noSwa",
+            type: "boolean",
+            default: false,
+            description: "Disable SWA (Sliding Window Attention) on supported models"
+        })
         .option("noTrimWhitespace", {
             type: "boolean",
             alias: ["noTrim"],
@@ -308,7 +315,7 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
     },
     async handler({
         modelPath, header, gpu, systemInfo, systemPrompt, systemPromptFile, prompt,
-        promptFile, wrapper, noJinja, contextSize, batchSize, flashAttention,
+        promptFile, wrapper, noJinja, contextSize, batchSize, flashAttention, swaFullCache,
         noTrimWhitespace, grammar, jsonSchemaGrammarFile, threads, temperature, minP, topK,
         topP, seed, gpuLayers, repeatPenalty, lastTokensRepeatPenalty, penalizeRepeatingNewLine,
         repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, noHistory,
@@ -317,7 +324,8 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
         try {
             await RunChat({
                 modelPath, header, gpu, systemInfo, systemPrompt, systemPromptFile, prompt, promptFile, wrapper, noJinja, contextSize,
-                batchSize, flashAttention, noTrimWhitespace, grammar, jsonSchemaGrammarFile, threads, temperature, minP, topK, topP, seed,
+                batchSize, flashAttention, swaFullCache, noTrimWhitespace, grammar, jsonSchemaGrammarFile, threads,
+                temperature, minP, topK, topP, seed,
                 gpuLayers, lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty,
                 maxTokens, noHistory, environmentFunctions, tokenPredictionDraftModel, tokenPredictionModelContextSize, debug, meter,
                 timing, noMmap, printTimings
@@ -333,7 +341,8 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
 
 async function RunChat({
     modelPath: modelArg, header: headerArg, gpu, systemInfo, systemPrompt, systemPromptFile, prompt, promptFile, wrapper, noJinja,
-    contextSize, batchSize, flashAttention, noTrimWhitespace, grammar: grammarArg, jsonSchemaGrammarFile: jsonSchemaGrammarFilePath,
+    contextSize, batchSize, flashAttention, swaFullCache, noTrimWhitespace, grammar: grammarArg,
+    jsonSchemaGrammarFile: jsonSchemaGrammarFilePath,
     threads, temperature, minP, topK, topP, seed, gpuLayers, lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine,
     repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, noHistory, environmentFunctions, tokenPredictionDraftModel,
     tokenPredictionModelContextSize, debug, meter, timing, noMmap, printTimings
@@ -363,11 +372,13 @@ async function RunChat({
 
     const resolvedModelPath = await resolveCommandGgufPath(modelArg, llama, headers, {
         flashAttention,
+        swaFullCache,
         useMmap
     });
     const resolvedDraftModelPath = (tokenPredictionDraftModel != null && tokenPredictionDraftModel !== "")
         ? await resolveCommandGgufPath(tokenPredictionDraftModel, llama, headers, {
             flashAttention,
+            swaFullCache,
             useMmap,
             consoleTitle: "Draft model file"
         })
@@ -413,6 +424,7 @@ async function RunChat({
             ? {fitContext: {contextSize}}
             : undefined,
         defaultContextFlashAttention: flashAttention,
+        defaultContextSwaFullCache: swaFullCache,
         useMmap,
         ignoreMemorySafetyChecks: gpuLayers != null,
         onLoadProgress(loadProgress: number) {
@@ -446,6 +458,7 @@ async function RunChat({
         return await llama.loadModel({
             modelPath: resolvedDraftModelPath,
             defaultContextFlashAttention: flashAttention,
+            defaultContextSwaFullCache: swaFullCache,
             useMmap,
             onLoadProgress(loadProgress: number) {
                 progressUpdater.setProgress(loadProgress);

src/cli/commands/CompleteCommand.ts

Lines changed: 14 additions & 3 deletions
@@ -32,6 +32,7 @@ type CompleteCommand = {
     contextSize?: number,
     batchSize?: number,
     flashAttention?: boolean,
+    swaFullCache?: boolean,
     threads?: number,
     temperature: number,
     minP: number,
@@ -119,6 +120,12 @@ export const CompleteCommand: CommandModule<object, CompleteCommand> = {
             default: false,
             description: "Enable flash attention"
         })
+        .option("swaFullCache", {
+            alias: "noSwa",
+            type: "boolean",
+            default: false,
+            description: "Disable SWA (Sliding Window Attention) on supported models"
+        })
         .option("threads", {
             type: "number",
             defaultDescription: "Number of cores that are useful for math on the current machine",
@@ -235,14 +242,14 @@ export const CompleteCommand: CommandModule<object, CompleteCommand> = {
     },
     async handler({
         modelPath, header, gpu, systemInfo, text, textFile, contextSize, batchSize,
-        flashAttention, threads, temperature, minP, topK,
+        flashAttention, swaFullCache, threads, temperature, minP, topK,
         topP, seed, gpuLayers, repeatPenalty, lastTokensRepeatPenalty, penalizeRepeatingNewLine,
         repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, tokenPredictionDraftModel, tokenPredictionModelContextSize,
         debug, meter, timing, noMmap, printTimings
     }) {
         try {
             await RunCompletion({
-                modelPath, header, gpu, systemInfo, text, textFile, contextSize, batchSize, flashAttention,
+                modelPath, header, gpu, systemInfo, text, textFile, contextSize, batchSize, flashAttention, swaFullCache,
                 threads, temperature, minP, topK, topP, seed, gpuLayers, lastTokensRepeatPenalty,
                 repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens,
                 tokenPredictionDraftModel, tokenPredictionModelContextSize, debug, meter, timing, noMmap, printTimings
@@ -257,7 +264,7 @@ export const CompleteCommand: CommandModule<object, CompleteCommand> = {
 
 
 async function RunCompletion({
-    modelPath: modelArg, header: headerArg, gpu, systemInfo, text, textFile, contextSize, batchSize, flashAttention,
+    modelPath: modelArg, header: headerArg, gpu, systemInfo, text, textFile, contextSize, batchSize, flashAttention, swaFullCache,
     threads, temperature, minP, topK, topP, seed, gpuLayers,
     lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty,
     tokenPredictionDraftModel, tokenPredictionModelContextSize, maxTokens, debug, meter, timing, noMmap, printTimings
@@ -286,11 +293,13 @@ async function RunCompletion({
 
     const resolvedModelPath = await resolveCommandGgufPath(modelArg, llama, headers, {
         flashAttention,
+        swaFullCache,
         useMmap
     });
     const resolvedDraftModelPath = (tokenPredictionDraftModel != null && tokenPredictionDraftModel !== "")
         ? await resolveCommandGgufPath(tokenPredictionDraftModel, llama, headers, {
             flashAttention,
+            swaFullCache,
             useMmap,
             consoleTitle: "Draft model file"
         })
@@ -329,6 +338,7 @@ async function RunCompletion({
             ? {fitContext: {contextSize}}
             : undefined,
         defaultContextFlashAttention: flashAttention,
+        defaultContextSwaFullCache: swaFullCache,
         useMmap,
         ignoreMemorySafetyChecks: gpuLayers != null,
         onLoadProgress(loadProgress: number) {
@@ -362,6 +372,7 @@ async function RunCompletion({
         return await llama.loadModel({
             modelPath: resolvedDraftModelPath,
             defaultContextFlashAttention: flashAttention,
+            defaultContextSwaFullCache: swaFullCache,
             useMmap,
             onLoadProgress(loadProgress: number) {
                 progressUpdater.setProgress(loadProgress);

src/cli/commands/InfillCommand.ts

Lines changed: 14 additions & 3 deletions
@@ -34,6 +34,7 @@ type InfillCommand = {
     contextSize?: number,
     batchSize?: number,
     flashAttention?: boolean,
+    swaFullCache?: boolean,
     threads?: number,
     temperature: number,
     minP: number,
@@ -129,6 +130,12 @@ export const InfillCommand: CommandModule<object, InfillCommand> = {
             default: false,
             description: "Enable flash attention"
         })
+        .option("swaFullCache", {
+            alias: "noSwa",
+            type: "boolean",
+            default: false,
+            description: "Disable SWA (Sliding Window Attention) on supported models"
+        })
         .option("threads", {
             type: "number",
             defaultDescription: "Number of cores that are useful for math on the current machine",
@@ -245,15 +252,15 @@ export const InfillCommand: CommandModule<object, InfillCommand> = {
     },
     async handler({
         modelPath, header, gpu, systemInfo, prefix, prefixFile, suffix, suffixFile, contextSize, batchSize,
-        flashAttention, threads, temperature, minP, topK,
+        flashAttention, swaFullCache, threads, temperature, minP, topK,
         topP, seed, gpuLayers, repeatPenalty, lastTokensRepeatPenalty, penalizeRepeatingNewLine,
         repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, tokenPredictionDraftModel, tokenPredictionModelContextSize,
         debug, meter, timing, noMmap, printTimings
     }) {
         try {
             await RunInfill({
                 modelPath, header, gpu, systemInfo, prefix, prefixFile, suffix, suffixFile, contextSize, batchSize, flashAttention,
-                threads, temperature, minP, topK, topP, seed, gpuLayers, lastTokensRepeatPenalty,
+                swaFullCache, threads, temperature, minP, topK, topP, seed, gpuLayers, lastTokensRepeatPenalty,
                 repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens,
                 tokenPredictionDraftModel, tokenPredictionModelContextSize, debug, meter, timing, noMmap, printTimings
             });
@@ -268,7 +275,7 @@ export const InfillCommand: CommandModule<object, InfillCommand> = {
 
 async function RunInfill({
     modelPath: modelArg, header: headerArg, gpu, systemInfo, prefix, prefixFile, suffix, suffixFile, contextSize, batchSize, flashAttention,
-    threads, temperature, minP, topK, topP, seed, gpuLayers,
+    swaFullCache, threads, temperature, minP, topK, topP, seed, gpuLayers,
     lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty,
     tokenPredictionDraftModel, tokenPredictionModelContextSize, maxTokens, debug, meter, timing, noMmap, printTimings
 }: InfillCommand) {
@@ -296,11 +303,13 @@ async function RunInfill({
 
     const resolvedModelPath = await resolveCommandGgufPath(modelArg, llama, headers, {
         flashAttention,
+        swaFullCache,
         useMmap
     });
     const resolvedDraftModelPath = (tokenPredictionDraftModel != null && tokenPredictionDraftModel !== "")
         ? await resolveCommandGgufPath(tokenPredictionDraftModel, llama, headers, {
             flashAttention,
+            swaFullCache,
             useMmap,
             consoleTitle: "Draft model file"
         })
@@ -353,6 +362,7 @@ async function RunInfill({
             ? {fitContext: {contextSize}}
             : undefined,
         defaultContextFlashAttention: flashAttention,
+        defaultContextSwaFullCache: swaFullCache,
         useMmap,
         ignoreMemorySafetyChecks: gpuLayers != null,
         onLoadProgress(loadProgress: number) {
@@ -386,6 +396,7 @@ async function RunInfill({
         return await llama.loadModel({
             modelPath: resolvedDraftModelPath,
             defaultContextFlashAttention: flashAttention,
+            defaultContextSwaFullCache: swaFullCache,
             useMmap,
             onLoadProgress(loadProgress: number) {
                 progressUpdater.setProgress(loadProgress);

src/cli/commands/inspect/commands/InspectEstimateCommand.ts

Lines changed: 11 additions & 3 deletions
@@ -32,7 +32,8 @@ type InspectEstimateCommand = {
     gpuLayers?: number | "max",
     contextSize?: number | "train",
     embedding?: boolean,
-    noMmap?: boolean
+    noMmap?: boolean,
+    swaFullCache?: boolean
 };
 
 export const InspectEstimateCommand: CommandModule<object, InspectEstimateCommand> = {
@@ -115,10 +116,16 @@ export const InspectEstimateCommand: CommandModule<object, InspectEstimateComman
             type: "boolean",
             default: false,
             description: "Disable mmap (memory-mapped file) usage"
+        })
+        .option("swaFullCache", {
+            alias: "noSwa",
+            type: "boolean",
+            default: false,
+            description: "Disable SWA (Sliding Window Attention) on supported models"
         });
     },
     async handler({
-        modelPath: ggufPath, header: headerArg, gpu, gpuLayers, contextSize: contextSizeArg, embedding, noMmap
+        modelPath: ggufPath, header: headerArg, gpu, gpuLayers, contextSize: contextSizeArg, embedding, noMmap, swaFullCache
     }: InspectEstimateCommand) {
         if (gpuLayers === -1) gpuLayers = undefined;
         if (gpuLayers === -2) gpuLayers = "max";
@@ -181,7 +188,8 @@ export const InspectEstimateCommand: CommandModule<object, InspectEstimateComman
         targetContextSize: contextSize,
         targetGpuLayers: gpuLayers,
         embeddingContext: embedding,
-        useMmap
+        useMmap,
+        swaFullCache
     });
 }
