@@ -41,6 +41,7 @@ type ChatCommand = {
     noJinja?: boolean,
     contextSize?: number,
     batchSize?: number,
+    flashAttention?: boolean,
     noTrimWhitespace: boolean,
     grammar: "text" | Parameters<typeof LlamaGrammar.getFor>[1],
     jsonSchemaGrammarFile?: string,
@@ -149,6 +150,12 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
                 type: "number",
                 description: "Batch size to use for the model context. The default value is the context size"
             })
+            .option("flashAttention", {
+                alias: "fa",
+                type: "boolean",
+                default: false,
+                description: "Enable flash attention"
+            })
             .option("noTrimWhitespace", {
                 type: "boolean",
                 alias: ["noTrim"],
@@ -269,7 +276,7 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
     },
     async handler({
         modelPath, header, gpu, systemInfo, systemPrompt, systemPromptFile, prompt,
-        promptFile, wrapper, noJinja, contextSize, batchSize,
+        promptFile, wrapper, noJinja, contextSize, batchSize, flashAttention,
         noTrimWhitespace, grammar, jsonSchemaGrammarFile, threads, temperature, minP, topK,
         topP, gpuLayers, repeatPenalty, lastTokensRepeatPenalty, penalizeRepeatingNewLine,
         repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, noHistory,
@@ -278,9 +285,9 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
         try {
             await RunChat({
                 modelPath, header, gpu, systemInfo, systemPrompt, systemPromptFile, prompt, promptFile, wrapper, noJinja, contextSize,
-                batchSize, noTrimWhitespace, grammar, jsonSchemaGrammarFile, threads, temperature, minP, topK, topP, gpuLayers,
-                lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens,
-                noHistory, environmentFunctions, debug, meter, printTimings
+                batchSize, flashAttention, noTrimWhitespace, grammar, jsonSchemaGrammarFile, threads, temperature, minP, topK, topP,
+                gpuLayers, lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty,
+                maxTokens, noHistory, environmentFunctions, debug, meter, printTimings
             });
         } catch (err) {
             await new Promise((accept) => setTimeout(accept, 0)); // wait for logs to finish printing
@@ -293,9 +300,9 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {

 async function RunChat({
     modelPath: modelArg, header: headerArg, gpu, systemInfo, systemPrompt, systemPromptFile, prompt, promptFile, wrapper, noJinja,
-    contextSize, batchSize, noTrimWhitespace, grammar: grammarArg, jsonSchemaGrammarFile: jsonSchemaGrammarFilePath, threads, temperature,
-    minP, topK, topP, gpuLayers, lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty,
-    repeatPresencePenalty, maxTokens, noHistory, environmentFunctions, debug, meter, printTimings
+    contextSize, batchSize, flashAttention, noTrimWhitespace, grammar: grammarArg, jsonSchemaGrammarFile: jsonSchemaGrammarFilePath,
+    threads, temperature, minP, topK, topP, gpuLayers, lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine,
+    repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, noHistory, environmentFunctions, debug, meter, printTimings
 }: ChatCommand) {
     if (contextSize === -1) contextSize = undefined;
     if (gpuLayers === -1) gpuLayers = undefined;
@@ -360,6 +367,7 @@ async function RunChat({
                     : contextSize != null
                         ? {fitContext: {contextSize}}
                         : undefined,
+                defaultContextFlashAttention: flashAttention,
                 ignoreMemorySafetyChecks: gpuLayers != null,
                 onLoadProgress(loadProgress: number) {
                     progressUpdater.setProgress(loadProgress);
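For context, a minimal sketch of how the same option can be used programmatically with node-llama-cpp, outside the `chat` CLI command this diff touches. `defaultContextFlashAttention` is the model-load option the new `--flashAttention` (`-fa`) flag feeds into here; the model path below is a placeholder, and the surrounding session code is an assumed typical usage, not part of this change:

import {getLlama, LlamaChatSession} from "node-llama-cpp";

const llama = await getLlama();
const model = await llama.loadModel({
    modelPath: "path/to/model.gguf", // placeholder path
    defaultContextFlashAttention: true // same option the new CLI flag maps to
});

// Contexts created from this model now default to flash attention
const context = await model.createContext();
const session = new LlamaChatSession({contextSequence: context.getSequence()});
console.log(await session.prompt("Hi there"));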