
Commit 47c3c5f

feat: threads count setting on a model (#33)
1 parent 9bdef11 commit 47c3c5f

File tree

5 files changed (+26 -5 lines):

README.md
llama/addon.cpp
src/cli/commands/ChatCommand.ts
src/llamaEvaluator/LlamaModel.ts
src/utils/getBin.ts


README.md

Lines changed: 1 addition & 0 deletions
@@ -287,6 +287,7 @@ Optional:
   -c, --contextSize  Context size to use for the model  [number] [default: 4096]
   -g, --grammar      Restrict the model response to a specific grammar, like JSON for example
                      [string] [choices: "text", "json", "list", "arithmetic", "japanese", "chess"] [default: "text"]
+      --threads      Number of threads to use for the evaluation of tokens  [number] [default: 6]
   -t, --temperature  Temperature is a hyperparameter that controls the randomness of the generat
                      ed text. It affects the probability distribution of the model's output toke
                      ns. A higher temperature (e.g., 1.5) makes the output more random and creat
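
Note: the new --threads flag defaults to 6, the same value that was previously hardcoded in the native addon (see llama/addon.cpp below), so default behavior is unchanged; passing --threads 8, for example, makes token evaluation use 8 threads.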

llama/addon.cpp

Lines changed: 7 additions & 1 deletion
@@ -13,6 +13,7 @@ class LLAMAModel : public Napi::ObjectWrap<LLAMAModel> {
     llama_context_params params;
     llama_model* model;
     float temperature;
+    int threads;
     int32_t top_k;
     float top_p;

@@ -21,6 +22,7 @@ class LLAMAModel : public Napi::ObjectWrap<LLAMAModel> {
         params.seed = -1;
         params.n_ctx = 4096;
         temperature = 0.0f;
+        threads = 6;
         top_k = 40;
         top_p = 0.95f;

@@ -74,6 +76,10 @@ class LLAMAModel : public Napi::ObjectWrap<LLAMAModel> {
             params.embedding = options.Get("embedding").As<Napi::Boolean>().Value();
         }

+        if (options.Has("threads")) {
+            threads = options.Get("threads").As<Napi::Number>().Int32Value();
+        }
+
         if (options.Has("temperature")) {
             temperature = options.Get("temperature").As<Napi::Number>().FloatValue();
         }

@@ -283,7 +289,7 @@ class LLAMAContextEvalWorker : Napi::AsyncWorker, Napi::Promise::Deferred {
     protected:
     void Execute() {
         // Perform the evaluation using llama_eval.
-        int r = llama_eval(ctx->ctx, tokens.data(), int(tokens.size()), llama_get_kv_cache_token_count(ctx->ctx), 6);
+        int r = llama_eval(ctx->ctx, tokens.data(), int(tokens.size()), llama_get_kv_cache_token_count(ctx->ctx), (ctx->model)->threads);
         if (r != 0) {
             SetError("Eval has failed");
             return;
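
In short, the thread count passed as the final n_threads argument of llama_eval is no longer the literal 6: it is read from the new threads member, which is initialized to 6 and overridden when the JavaScript side passes a threads option. A rough sketch of the options object the addon now accepts, with illustrative values only (the keys come from the TypeScript diffs below):

// Illustrative sketch, not taken from the commit: any of these keys may be omitted.
const nativeModelOptions = {
    contextSize: 4096,
    threads: 8,      // stored in LLAMAModel::threads and later forwarded to llama_eval
    temperature: 0
};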

src/cli/commands/ChatCommand.ts

Lines changed: 11 additions & 3 deletions
@@ -18,6 +18,7 @@ type ChatCommand = {
     wrapper: "auto" | "general" | "llamaChat" | "chatML",
     contextSize: number,
     grammar: "text" | Parameters<typeof LlamaGrammar.getFor>[0],
+    threads: number,
     temperature: number,
     topK: number,
     topP: number,

@@ -76,6 +77,12 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
             description: "Restrict the model response to a specific grammar, like JSON for example",
             group: "Optional:"
         })
+        .option("threads", {
+            type: "number",
+            default: 6,
+            description: "Number of threads to use for the evaluation of tokens",
+            group: "Optional:"
+        })
         .option("temperature", {
             alias: "t",
             type: "number",

@@ -107,10 +114,10 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
     },
     async handler({
         model, systemInfo, systemPrompt, wrapper, contextSize, grammar,
-        temperature, topK, topP, maxTokens
+        threads, temperature, topK, topP, maxTokens
     }) {
         try {
-            await RunChat({model, systemInfo, systemPrompt, wrapper, contextSize, grammar, temperature, topK, topP, maxTokens});
+            await RunChat({model, systemInfo, systemPrompt, wrapper, contextSize, grammar, threads, temperature, topK, topP, maxTokens});
         } catch (err) {
             console.error(err);
             process.exit(1);

@@ -120,7 +127,7 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {


 async function RunChat({
-    model: modelArg, systemInfo, systemPrompt, wrapper, contextSize, grammar: grammarArg, temperature, topK, topP, maxTokens
+    model: modelArg, systemInfo, systemPrompt, wrapper, contextSize, grammar: grammarArg, threads, temperature, topK, topP, maxTokens
 }: ChatCommand) {
     const {LlamaChatSession} = await import("../../llamaEvaluator/LlamaChatSession.js");
     const {LlamaModel} = await import("../../llamaEvaluator/LlamaModel.js");

@@ -130,6 +137,7 @@ async function RunChat({
     const model = new LlamaModel({
         modelPath: modelArg,
         contextSize,
+        threads,
         temperature,
         topK,
         topP
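
The new option is wired end to end: it is declared with .option("threads", ...), destructured in the yargs handler, forwarded through RunChat, and finally passed into the LlamaModel constructor, so the CLI's --threads flag maps directly onto the threads model option shown in src/llamaEvaluator/LlamaModel.ts below.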

src/llamaEvaluator/LlamaModel.ts

Lines changed: 6 additions & 1 deletion
@@ -21,6 +21,9 @@ export type LlamaModelOptions = {
     /** if true, reduce VRAM usage at the cost of performance */
     lowVram?: boolean,

+    /** number of threads to use to evaluate tokens */
+    threads?: number,
+
     /**
      * Temperature is a hyperparameter that controls the randomness of the generated text.
      * It affects the probability distribution of the model's output tokens.

@@ -85,6 +88,7 @@ export class LlamaModel {
      * @param {number} [options.batchSize] - prompt processing batch size
      * @param {number} [options.gpuLayers] - number of layers to store in VRAM
      * @param {boolean} [options.lowVram] - if true, reduce VRAM usage at the cost of performance
+     * @param {number} [options.threads] - number of threads to use to evaluate tokens
      * @param {number} [options.temperature] - Temperature is a hyperparameter that controls the randomness of the generated text.
      * It affects the probability distribution of the model's output tokens.
      * A higher temperature (e.g., 1.5) makes the output more random and creative,

@@ -114,14 +118,15 @@
      */
     public constructor({
         modelPath, seed = null, contextSize = 1024 * 4, batchSize, gpuLayers,
-        lowVram, temperature = 0, topK = 40, topP = 0.95, f16Kv, logitsAll, vocabOnly, useMmap, useMlock, embedding
+        lowVram, threads = 6, temperature = 0, topK = 40, topP = 0.95, f16Kv, logitsAll, vocabOnly, useMmap, useMlock, embedding
     }: LlamaModelOptions) {
         this._model = new LLAMAModel(modelPath, removeNullFields({
             seed: seed != null ? Math.max(-1, seed) : undefined,
             contextSize,
             batchSize,
             gpuLayers,
             lowVram,
+            threads,
             temperature,
             topK,
             topP,
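
For reference, a minimal usage sketch of the new option from the TypeScript side. The import path is an assumption about the package's public entry point (not part of this commit), and the model path is a placeholder:

// Minimal sketch: construct a model that evaluates tokens with 8 threads instead of the default 6.
import {LlamaModel} from "node-llama-cpp"; // assumed public export

const model = new LlamaModel({
    modelPath: "./models/model.bin", // placeholder path
    threads: 8
});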

src/utils/getBin.ts

Lines changed: 1 addition & 0 deletions
@@ -111,6 +111,7 @@ export type LLAMAModel = {
     useMmap?: boolean,
     useMlock?: boolean,
     embedding?: boolean,
+    threads?: number,
     temperature?: number,
     topK?: number,
     topP?: number
