Commit c76ec48

feat: add support for some llama.cpp params on LlamaModel (#5)
1 parent adfce4c commit c76ec48

7 files changed: +163 additions, −43 deletions

7 files changed

+163
-43
lines changed

README.md

Lines changed: 5 additions & 3 deletions
```diff
@@ -30,7 +30,8 @@ const __dirname = path.dirname(fileURLToPath(import.meta.url));
 const model = new LlamaModel({
     modelPath: path.join(__dirname, "models", "vicuna-13b-v1.5-16k.ggmlv3.q5_1.bin")
 });
-const session = new LlamaChatSession({context: model.createContext()});
+const context = new LlamaContext({model});
+const session = new LlamaChatSession({context});


 const q1 = "Hi there, how are you?";
@@ -73,7 +74,8 @@ const model = new LlamaModel({
     modelPath: path.join(__dirname, "models", "vicuna-13b-v1.5-16k.ggmlv3.q5_1.bin"),
     promptWrapper: new MyCustomChatPromptWrapper() // by default, LlamaChatPromptWrapper is used
 })
-const session = new LlamaChatSession({context: model.createContext()});
+const context = new LlamaContext({model});
+const session = new LlamaChatSession({context});


 const q1 = "Hi there, how are you?";
@@ -102,7 +104,7 @@ const model = new LlamaModel({
     modelPath: path.join(__dirname, "models", "vicuna-13b-v1.5-16k.ggmlv3.q5_1.bin")
 });

-const context = model.createContext();
+const context = new LlamaContext({model});

 const q1 = "Hi there, how are you?";
 console.log("AI: " + q1);
```

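Taken together, the README changes replace `model.createContext()` with an explicitly constructed context. A minimal sketch of the new wiring, assuming the package is imported as `node-llama-cpp` and that `session.prompt()` is used as in the rest of the README (neither is shown in the diff above):

```ts
import {fileURLToPath} from "url";
import path from "path";
// Import path assumed; the diff only shows the usage lines, not the imports.
import {LlamaModel, LlamaContext, LlamaChatSession} from "node-llama-cpp";

const __dirname = path.dirname(fileURLToPath(import.meta.url));

const model = new LlamaModel({
    modelPath: path.join(__dirname, "models", "vicuna-13b-v1.5-16k.ggmlv3.q5_1.bin")
});
const context = new LlamaContext({model}); // the context is now created explicitly
const session = new LlamaChatSession({context});

const q1 = "Hi there, how are you?";
console.log("User: " + q1);

const a1 = await session.prompt(q1); // assumed API, as used elsewhere in the README
console.log("AI: " + a1);
```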
llama/addon.cpp

Lines changed: 73 additions & 14 deletions
```diff
@@ -8,21 +8,80 @@

 class LLAMAModel : public Napi::ObjectWrap<LLAMAModel> {
   public:
-  llama_context_params params;
-  llama_model* model;
-  LLAMAModel(const Napi::CallbackInfo& info) : Napi::ObjectWrap<LLAMAModel>(info) {
-    params = llama_context_default_params();
-    params.seed = -1;
-    params.n_ctx = 4096;
-    model = llama_load_model_from_file(info[0].As<Napi::String>().Utf8Value().c_str(), params);
-
-    if (model == NULL) {
-      Napi::Error::New(info.Env(), "Failed to load model").ThrowAsJavaScriptException();
-      return;
+    llama_context_params params;
+    llama_model* model;
+
+    LLAMAModel(const Napi::CallbackInfo& info) : Napi::ObjectWrap<LLAMAModel>(info) {
+        params = llama_context_default_params();
+        params.seed = -1;
+        params.n_ctx = 4096;
+
+        // Get the model path
+        std::string modelPath = info[0].As<Napi::String>().Utf8Value();
+
+        if (info.Length() > 1 && info[1].IsObject()) {
+            Napi::Object options = info[1].As<Napi::Object>();
+
+            if (options.Has("seed")) {
+                params.seed = options.Get("seed").As<Napi::Number>().Int32Value();
+            }
+
+            if (options.Has("contextSize")) {
+                params.n_ctx = options.Get("contextSize").As<Napi::Number>().Int32Value();
+            }
+
+            if (options.Has("batchSize")) {
+                params.n_batch = options.Get("batchSize").As<Napi::Number>().Int32Value();
+            }
+
+            if (options.Has("gpuCores")) {
+                params.n_gpu_layers = options.Get("gpuCores").As<Napi::Number>().Int32Value();
+            }
+
+            if (options.Has("lowVram")) {
+                params.low_vram = options.Get("lowVram").As<Napi::Boolean>().Value();
+            }
+
+            if (options.Has("f16Kv")) {
+                params.f16_kv = options.Get("f16Kv").As<Napi::Boolean>().Value();
+            }
+
+            if (options.Has("logitsAll")) {
+                params.logits_all = options.Get("logitsAll").As<Napi::Boolean>().Value();
+            }
+
+            if (options.Has("vocabOnly")) {
+                params.vocab_only = options.Get("vocabOnly").As<Napi::Boolean>().Value();
+            }
+
+            if (options.Has("useMmap")) {
+                params.use_mmap = options.Get("useMmap").As<Napi::Boolean>().Value();
+            }
+
+            if (options.Has("useMlock")) {
+                params.use_mlock = options.Get("useMlock").As<Napi::Boolean>().Value();
+            }
+
+            if (options.Has("embedding")) {
+                params.embedding = options.Get("embedding").As<Napi::Boolean>().Value();
+            }
+        }
+
+        model = llama_load_model_from_file(modelPath.c_str(), params);
+
+        if (model == NULL) {
+            Napi::Error::New(info.Env(), "Failed to load model").ThrowAsJavaScriptException();
+            return;
+        }
+    }
+
+    ~LLAMAModel() {
+        llama_free_model(model);
+    }
+
+    static void init(Napi::Object exports) {
+        exports.Set("LLAMAModel", DefineClass(exports.Env(), "LLAMAModel", {}));
     }
-  }
-  ~LLAMAModel() { llama_free_model(model); }
-  static void init(Napi::Object exports) { exports.Set("LLAMAModel", DefineClass(exports.Env(), "LLAMAModel", {})); }
 };

 class LLAMAContext : public Napi::ObjectWrap<LLAMAContext> {
```

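For reference, the constructor above copies each recognized option from the JS options object onto `llama_context_params` before calling `llama_load_model_from_file`. An illustrative TypeScript constant (not part of the commit) summarizing the option-to-field mapping implemented by that C++ code:

```ts
// JS option name -> llama_context_params field set by the addon's LLAMAModel constructor.
const llamaContextParamsMapping = {
    seed: "seed",
    contextSize: "n_ctx",
    batchSize: "n_batch",
    gpuCores: "n_gpu_layers",
    lowVram: "low_vram",
    f16Kv: "f16_kv",
    logitsAll: "logits_all",
    vocabOnly: "vocab_only",
    useMmap: "use_mmap",
    useMlock: "use_mlock",
    embedding: "embedding"
} as const;
```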
src/cli/commands/ChatCommand.ts

Lines changed: 16 additions & 6 deletions
```diff
@@ -11,7 +11,8 @@ type ChatCommand = {
     model: string,
     systemInfo: boolean,
     systemPrompt: string,
-    wrapper: string
+    wrapper: string,
+    contextSize: number
 };

 export const ChatCommand: CommandModule<object, ChatCommand> = {
@@ -46,11 +47,17 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
                 choices: ["general", "llama"],
                 description: "Chat wrapper to use",
                 group: "Optional:"
+            })
+            .option("contextSize", {
+                type: "number",
+                default: 1024 * 4,
+                description: "Context size to use for the model",
+                group: "Optional:"
             });
     },
-    async handler({model, systemInfo, systemPrompt, wrapper}) {
+    async handler({model, systemInfo, systemPrompt, wrapper, contextSize}) {
         try {
-            await RunChat({model, systemInfo, systemPrompt, wrapper});
+            await RunChat({model, systemInfo, systemPrompt, wrapper, contextSize});
         } catch (err) {
             console.error(err);
             process.exit(1);
@@ -59,15 +66,18 @@
 };


-async function RunChat({model: modelArg, systemInfo, systemPrompt, wrapper}: ChatCommand) {
+async function RunChat({model: modelArg, systemInfo, systemPrompt, wrapper, contextSize}: ChatCommand) {
     const {LlamaChatSession} = await import("../../llamaEvaluator/LlamaChatSession.js");
     const {LlamaModel} = await import("../../llamaEvaluator/LlamaModel.js");
+    const {LlamaContext} = await import("../../llamaEvaluator/LlamaContext.js");

     const model = new LlamaModel({
-        modelPath: modelArg
+        modelPath: modelArg,
+        contextSize
     });
+    const context = new LlamaContext({model});
     const session = new LlamaChatSession({
-        context: model.createContext(),
+        context,
         printLLamaSystemInfo: systemInfo,
         systemPrompt,
         promptWrapper: createChatWrapper(wrapper)
```

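In effect, the new `--contextSize` flag (default `1024 * 4`) now travels from the yargs handler into `RunChat` and on into `LlamaModel`. A condensed sketch of that flow, with the CLI plumbing omitted; the import specifiers match the dynamic imports in the diff above, and the helper name is hypothetical:

```ts
// Condensed flow sketch (not the actual command code): how contextSize reaches the model.
import {LlamaModel} from "../../llamaEvaluator/LlamaModel.js";
import {LlamaContext} from "../../llamaEvaluator/LlamaContext.js";
import {LlamaChatSession} from "../../llamaEvaluator/LlamaChatSession.js";

export async function runChatSketch(modelPath: string, contextSize: number = 1024 * 4) {
    const model = new LlamaModel({modelPath, contextSize}); // forwarded to llama.cpp's n_ctx by the addon
    const context = new LlamaContext({model});
    return new LlamaChatSession({context});
}
```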
src/llamaEvaluator/LlamaChatSession.ts

Lines changed: 1 addition & 1 deletion
```diff
@@ -6,7 +6,7 @@ import {GeneralChatPromptWrapper} from "../chatWrappers/GeneralChatPromptWrapper
 import {LlamaModel} from "./LlamaModel.js";
 import {LlamaContext} from "./LlamaContext.js";

-const UNKNOWN_UNICODE_CHAR = "�";
+const UNKNOWN_UNICODE_CHAR = "\ufffd";

 export class LlamaChatSession {
     private readonly _systemPrompt: string;
```

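The constant now spells out U+FFFD (the Unicode replacement character) as an escape instead of an easy-to-miss literal character. A small illustrative snippet (not the session's internals) of why that character is worth naming: it is what UTF-8 decoding emits for bytes it cannot decode, so it can be filtered out of generated text:

```ts
const UNKNOWN_UNICODE_CHAR = "\ufffd"; // U+FFFD REPLACEMENT CHARACTER

// 0xff is not valid UTF-8, so the decoder substitutes U+FFFD for it.
const decoded = new TextDecoder("utf-8").decode(new Uint8Array([0x48, 0x69, 0xff]));
console.log(decoded);                                      // "Hi�"
console.log(decoded.split(UNKNOWN_UNICODE_CHAR).join("")); // "Hi"
```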
src/llamaEvaluator/LlamaContext.ts

Lines changed: 3 additions & 4 deletions
```diff
@@ -1,13 +1,12 @@
 import {LLAMAContext, llamaCppNode} from "./LlamaBins.js";
+import {LlamaModel} from "./LlamaModel.js";

-type LlamaContextConstructorParameters = {prependBos: boolean, ctx: LLAMAContext};
 export class LlamaContext {
     private readonly _ctx: LLAMAContext;
     private _prependBos: boolean;

-    /** @internal */
-    public constructor( {ctx, prependBos}: LlamaContextConstructorParameters ) {
-        this._ctx = ctx;
+    public constructor({model, prependBos = true}: {model: LlamaModel, prependBos?: boolean}) {
+        this._ctx = new LLAMAContext(model._model);
         this._prependBos = prependBos;
     }

```

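`LlamaContext` is now constructed publicly from a `LlamaModel` (the `/** @internal */` marker and the `ctx` parameter are gone), with `prependBos` defaulting to `true`. A minimal sketch, assuming the package-level import path and a placeholder model path:

```ts
import {LlamaModel, LlamaContext} from "node-llama-cpp"; // import path assumed

const model = new LlamaModel({modelPath: "path/to/model.bin"}); // placeholder path
const context = new LlamaContext({model});                       // prependBos defaults to true
const rawContext = new LlamaContext({model, prependBos: false}); // opt out of prepending the BOS token
```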
src/llamaEvaluator/LlamaModel.ts

Lines changed: 52 additions & 14 deletions
```diff
@@ -1,24 +1,62 @@
-import {LlamaContext} from "./LlamaContext.js";
-import {LLAMAContext, llamaCppNode, LLAMAModel} from "./LlamaBins.js";
+import {llamaCppNode, LLAMAModel} from "./LlamaBins.js";


 export class LlamaModel {
-    private readonly _model: LLAMAModel;
-    private readonly _prependBos: boolean;
+    /** @internal */
+    public readonly _model: LLAMAModel;

-    public constructor({modelPath, prependBos = true}: { modelPath: string, prependBos?: boolean }) {
-        this._model = new LLAMAModel(modelPath);
-        this._prependBos = prependBos;
-    }
-
-    public createContext() {
-        return new LlamaContext({
-            ctx: new LLAMAContext(this._model),
-            prependBos: this._prependBos
-        });
+    /**
+     * options source:
+     * https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/llama.h#L102 (struct llama_context_params)
+     * @param {object} options
+     * @param {string} options.modelPath - path to the model on the filesystem
+     * @param {number | null} [options.seed] - If null, a random seed will be used
+     * @param {number} [options.contextSize] - text context size
+     * @param {number} [options.batchSize] - prompt processing batch size
+     * @param {number} [options.gpuCores] - number of layers to store in VRAM
+     * @param {boolean} [options.lowVram] - if true, reduce VRAM usage at the cost of performance
+     * @param {boolean} [options.f16Kv] - use fp16 for KV cache
+     * @param {boolean} [options.logitsAll] - the llama_eval() call computes all logits, not just the last one
+     * @param {boolean} [options.vocabOnly] - only load the vocabulary, no weights
+     * @param {boolean} [options.useMmap] - use mmap if possible
+     * @param {boolean} [options.useMlock] - force system to keep model in RAM
+     * @param {boolean} [options.embedding] - embedding mode only
+     */
+    public constructor({
+        modelPath, seed = null, contextSize = 1024 * 4, batchSize, gpuCores,
+        lowVram, f16Kv, logitsAll, vocabOnly, useMmap, useMlock, embedding
+    }: {
+        modelPath: string, seed?: number | null, contextSize?: number, batchSize?: number, gpuCores?: number,
+        lowVram?: boolean, f16Kv?: boolean, logitsAll?: boolean, vocabOnly?: boolean, useMmap?: boolean, useMlock?: boolean,
+        embedding?: boolean
+    }) {
+        this._model = new LLAMAModel(modelPath, removeNullFields({
+            seed: seed != null ? Math.max(-1, seed) : undefined,
+            contextSize,
+            batchSize,
+            gpuCores,
+            lowVram,
+            f16Kv,
+            logitsAll,
+            vocabOnly,
+            useMmap,
+            useMlock,
+            embedding
+        }));
     }

     public static get systemInfo() {
         return llamaCppNode.systemInfo();
     }
 }
+
+function removeNullFields<T extends object>(obj: T): T {
+    const newObj: T = Object.assign({}, obj);
+
+    for (const key in obj) {
+        if (newObj[key] == null)
+            delete newObj[key];
+    }
+
+    return newObj;
+}
```

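The constructor now accepts the llama.cpp-backed options documented in the JSDoc above, and `removeNullFields` strips the ones left undefined before they reach the native `LLAMAModel` binding. A brief usage sketch; the import path and option values are illustrative, not taken from the commit:

```ts
import {LlamaModel} from "node-llama-cpp"; // import path assumed

const model = new LlamaModel({
    modelPath: "path/to/model.bin", // placeholder path
    seed: 42,                       // omit (or pass null) to let llama.cpp pick a random seed
    contextSize: 1024 * 8,          // text context size (n_ctx)
    gpuCores: 32,                   // number of layers to store in VRAM (n_gpu_layers)
    useMlock: true                  // force the system to keep the model in RAM
});
```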
src/utils/getBin.ts

Lines changed: 13 additions & 1 deletion
```diff
@@ -95,7 +95,19 @@ export type LlamaCppNodeModule = {
 };

 export type LLAMAModel = {
-    new (modelPath: string): LLAMAModel
+    new (modelPath: string, params: {
+        seed?: number,
+        contextSize?: number,
+        batchSize?: number,
+        gpuCores?: number,
+        lowVram?: boolean,
+        f16Kv?: boolean,
+        logitsAll?: boolean,
+        vocabOnly?: boolean,
+        useMmap?: boolean,
+        useMlock?: boolean,
+        embedding?: boolean
+    }): LLAMAModel
 };

 export type LLAMAContext = {
```

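The widened construct signature mirrors what the C++ addon now parses. An illustrative, type-only sketch of a call that satisfies it; `llamaCppNode` is a stand-in for the loaded native addon module, and the import path is hypothetical:

```ts
import type {LLAMAModel} from "./getBin.js"; // illustrative import path

declare const llamaCppNode: {LLAMAModel: LLAMAModel}; // stand-in for the loaded native bindings

const nativeModel = new llamaCppNode.LLAMAModel("path/to/model.bin", {
    seed: -1,          // matches the addon's default (params.seed = -1)
    contextSize: 4096,
    useMmap: true
});
```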