
Commit 8086c5f

fix: llama.cpp interface breaking change (#10)
* Made changes to adapt to new llama.cpp interface breaking changes
* Now saving the release that the binaries were compiled for as part of the final npm package and download it by default

BREAKING CHANGE: only `.gguf` models are supported from now on
1 parent: 54a1c6f · commit: 8086c5f

File tree: 11 files changed (+111, -63 lines)

.github/workflows/build.yml

Lines changed: 9 additions & 1 deletion
@@ -20,12 +20,17 @@ jobs:
       - name: Generate docs
         run: npm run generate-docs
       - name: Download latest llama.cpp release
-        run: node ./dist/cli/cli.js download --release latest --skipBuild
+        run: node ./dist/cli/cli.js download --release latest --skipBuild --updateBinariesReleaseMetadata
       - name: Upload build artifact
         uses: actions/upload-artifact@v3
         with:
           name: "build"
           path: "dist"
+      - name: Upload binariesGithubRelease.json artifact
+        uses: actions/upload-artifact@v3
+        with:
+          name: "binariesGithubRelease"
+          path: "llama/binariesGithubRelease.json"
       - name: Upload build artifact
         uses: actions/upload-artifact@v3
         with:
@@ -227,6 +232,9 @@ jobs:
           mv artifacts/build dist/
           mv artifacts/docs docs/
 
+          rm -f ./llama/binariesGithubRelease
+          mv artifacts/binariesGithubRelease ./llama/binariesGithubRelease.json
+
           echo "Built binaries:"
           ls llamaBins
       - name: Release

README.md

Lines changed: 3 additions & 3 deletions
@@ -30,7 +30,7 @@ import {LlamaModel, LlamaContext, LlamaChatSession} from "node-llama-cpp";
 const __dirname = path.dirname(fileURLToPath(import.meta.url));
 
 const model = new LlamaModel({
-    modelPath: path.join(__dirname, "models", "vicuna-13b-v1.5-16k.ggmlv3.q5_1.bin")
+    modelPath: path.join(__dirname, "models", "codellama-13b.Q3_K_M.gguf")
 });
 const context = new LlamaContext({model});
 const session = new LlamaChatSession({context});
@@ -73,7 +73,7 @@ export class MyCustomChatPromptWrapper extends ChatPromptWrapper {
 }
 
 const model = new LlamaModel({
-    modelPath: path.join(__dirname, "models", "vicuna-13b-v1.5-16k.ggmlv3.q5_1.bin"),
+    modelPath: path.join(__dirname, "models", "codellama-13b.Q3_K_M.gguf"),
     promptWrapper: new MyCustomChatPromptWrapper() // by default, LlamaChatPromptWrapper is used
 })
 const context = new LlamaContext({model});
@@ -103,7 +103,7 @@ import {LlamaModel, LlamaContext, LlamaChatSession} from "node-llama-cpp";
 const __dirname = path.dirname(fileURLToPath(import.meta.url));
 
 const model = new LlamaModel({
-    modelPath: path.join(__dirname, "models", "vicuna-13b-v1.5-16k.ggmlv3.q5_1.bin")
+    modelPath: path.join(__dirname, "models", "codellama-13b.Q3_K_M.gguf")
 });
 
 const context = new LlamaContext({model});
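
For reference, the updated README examples combine into this minimal end-to-end sketch; the `codellama-13b.Q3_K_M.gguf` path is just the example model name from the docs and is assumed to exist under a local `models` directory:

    import {fileURLToPath} from "url";
    import path from "path";
    import {LlamaModel, LlamaContext, LlamaChatSession} from "node-llama-cpp";

    const __dirname = path.dirname(fileURLToPath(import.meta.url));

    // Only .gguf model files are supported from this version onwards.
    const model = new LlamaModel({
        modelPath: path.join(__dirname, "models", "codellama-13b.Q3_K_M.gguf")
    });
    const context = new LlamaContext({model});
    const session = new LlamaChatSession({context});

    const answer = await session.prompt("Write a one-line hello world in TypeScript");
    console.log(answer);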

llama/addon.cpp

Lines changed: 42 additions & 39 deletions
@@ -67,6 +67,7 @@ class LLAMAModel : public Napi::ObjectWrap<LLAMAModel> {
       }
     }
 
+    llama_backend_init(false);
     model = llama_load_model_from_file(modelPath.c_str(), params);
 
     if (model == NULL) {
@@ -124,7 +125,18 @@ class LLAMAContext : public Napi::ObjectWrap<LLAMAContext> {
 
     // Decode each token and accumulate the result.
     for (size_t i = 0; i < tokens.ElementLength(); i++) {
-      const char* str = llama_token_to_str(ctx, (llama_token)tokens[i]);
+      // source: https://github.com/ggerganov/llama.cpp/blob/232caf3c1581a6cb023571780ff41dc2d66d1ca0/llama.cpp#L799-L811
+      std::vector<char> result(8, 0);
+      const int n_tokens = llama_token_to_str(ctx, (llama_token)tokens[i], result.data(), result.size());
+      if (n_tokens < 0) {
+        result.resize(-n_tokens);
+        int check = llama_token_to_str(ctx, (llama_token)tokens[i], result.data(), result.size());
+        GGML_ASSERT(check == -n_tokens);
+      } else {
+        result.resize(n_tokens);
+      }
+
+      const char* str = result.data();
       if (str == nullptr) {
         Napi::Error::New(info.Env(), "Invalid token").ThrowAsJavaScriptException();
         return info.Env().Undefined();
@@ -134,6 +146,15 @@ class LLAMAContext : public Napi::ObjectWrap<LLAMAContext> {
 
     return Napi::String::New(info.Env(), ss.str());
   }
+  Napi::Value TokenBos(const Napi::CallbackInfo& info) {
+    return Napi::Number::From(info.Env(), llama_token_bos(ctx));
+  }
+  Napi::Value TokenEos(const Napi::CallbackInfo& info) {
+    return Napi::Number::From(info.Env(), llama_token_eos(ctx));
+  }
+  Napi::Value GetMaxContextSize(const Napi::CallbackInfo& info) {
+    return Napi::Number::From(info.Env(), llama_n_ctx(ctx));
+  }
   Napi::Value Eval(const Napi::CallbackInfo& info);
   static void init(Napi::Object exports) {
     exports.Set("LLAMAContext",
@@ -142,6 +163,9 @@ class LLAMAContext : public Napi::ObjectWrap<LLAMAContext> {
       {
         InstanceMethod("encode", &LLAMAContext::Encode),
         InstanceMethod("decode", &LLAMAContext::Decode),
+        InstanceMethod("tokenBos", &LLAMAContext::TokenBos),
+        InstanceMethod("tokenEos", &LLAMAContext::TokenEos),
+        InstanceMethod("getMaxContextSize", &LLAMAContext::GetMaxContextSize),
         InstanceMethod("eval", &LLAMAContext::Eval),
       }));
   }
@@ -151,7 +175,6 @@ class LLAMAContext : public Napi::ObjectWrap<LLAMAContext> {
 class LLAMAContextEvalWorker : Napi::AsyncWorker, Napi::Promise::Deferred {
   LLAMAContext* ctx;
   std::vector<llama_token> tokens;
-  std::vector<llama_token> restriction;
   llama_token result;
 
   public:
@@ -160,13 +183,6 @@ class LLAMAContextEvalWorker : Napi::AsyncWorker, Napi::Promise::Deferred {
     Napi::Uint32Array tokens = info[0].As<Napi::Uint32Array>();
     this->tokens.reserve(tokens.ElementLength());
     for (size_t i = 0; i < tokens.ElementLength(); i++) { this->tokens.push_back(static_cast<llama_token>(tokens[i])); }
-
-    if (info.Length() > 1 && info[1].IsTypedArray()) {
-      Napi::Uint32Array restriction = info[1].As<Napi::Uint32Array>();
-      this->restriction.reserve(restriction.ElementLength());
-      for (size_t i = 0; i < restriction.ElementLength(); i++) { this->restriction.push_back(static_cast<llama_token>(restriction[i])); }
-      std::sort(this->restriction.begin(), this->restriction.end());
-    }
   }
   ~LLAMAContextEvalWorker() { ctx->Unref(); }
   using Napi::AsyncWorker::Queue;
@@ -175,39 +191,30 @@ class LLAMAContextEvalWorker : Napi::AsyncWorker, Napi::Promise::Deferred {
   protected:
   void Execute() {
     // Perform the evaluation using llama_eval.
-    int r = llama_eval(ctx->ctx, tokens.data(), tokens.size(), llama_get_kv_cache_token_count(ctx->ctx), 6);
+    int r = llama_eval(ctx->ctx, tokens.data(), int(tokens.size()), llama_get_kv_cache_token_count(ctx->ctx), 6);
     if (r != 0) {
       SetError("Eval has failed");
       return;
     }
 
+    llama_token new_token_id = 0;
+
     // Select the best prediction.
-    float* logits = llama_get_logits(ctx->ctx);
-    int n_vocab = llama_n_vocab(ctx->ctx);
-    llama_token re;
-    if (restriction.empty()) {
-      float max = logits[0];
-      re = 0;
-      for (llama_token id = 1; id < n_vocab; id++) {
-        float logit = logits[id];
-        if (logit > max) {
-          max = logit;
-          re = id;
-        }
-      }
-    } else {
-      float max = logits[restriction[0]];
-      re = 0;
-      for (size_t i = 1; i < restriction.size(); i++) {
-        llama_token id = restriction[i];
-        float logit = logits[id];
-        if (logit > max) {
-          max = logit;
-          re = id;
-        }
-      }
+    auto logits = llama_get_logits(ctx->ctx);
+    auto n_vocab = llama_n_vocab(ctx->ctx);
+
+    std::vector<llama_token_data> candidates;
+    candidates.reserve(n_vocab);
+
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+      candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
     }
-    result = re;
+
+    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+
+    new_token_id = llama_sample_token_greedy(ctx->ctx , &candidates_p);
+
+    result = new_token_id;
   }
   void OnOK() {
     Napi::Env env = Napi::AsyncWorker::Env();
@@ -223,15 +230,11 @@ Napi::Value LLAMAContext::Eval(const Napi::CallbackInfo& info) {
   return worker->Promise();
 }
 
-Napi::Value tokenBos(const Napi::CallbackInfo& info) { return Napi::Number::From(info.Env(), llama_token_bos()); }
-Napi::Value tokenEos(const Napi::CallbackInfo& info) { return Napi::Number::From(info.Env(), llama_token_eos()); }
 Napi::Value systemInfo(const Napi::CallbackInfo& info) { return Napi::String::From(info.Env(), llama_print_system_info()); }
 
 Napi::Object registerCallback(Napi::Env env, Napi::Object exports) {
   llama_backend_init(false);
   exports.DefineProperties({
-    Napi::PropertyDescriptor::Function("tokenBos", tokenBos),
-    Napi::PropertyDescriptor::Function("tokenEos", tokenEos),
     Napi::PropertyDescriptor::Function("systemInfo", systemInfo),
   });
   LLAMAModel::init(exports);
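
In the new llama.cpp interface, `llama_token_bos`, `llama_token_eos` and `llama_n_ctx` are read from a context, so the BOS/EOS accessors moved from module-level exports to instance methods on the native context object. A rough sketch of the surface the addon now exposes to JavaScript; the `NativeLlamaContext` type name is hypothetical and only mirrors the `InstanceMethod` registrations above:

    // Hypothetical shape of the object created by the native LLAMAContext binding.
    type NativeLlamaContext = {
        encode(text: string): Uint32Array,
        decode(tokens: Uint32Array): string,
        tokenBos(): number,
        tokenEos(): number,
        getMaxContextSize(): number,
        eval(tokens: Uint32Array): Promise<number>
    };

    function describeContext(ctx: NativeLlamaContext) {
        // BOS/EOS are now per-context (llama_token_bos(ctx) / llama_token_eos(ctx))
        // instead of module-level functions.
        console.log("BOS token:", ctx.tokenBos());
        console.log("EOS token:", ctx.tokenEos());
        console.log("Max context size:", ctx.getMaxContextSize());
    }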

llama/binariesGithubRelease.json

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+{
+    "release": "latest"
+}

package.json

Lines changed: 2 additions & 3 deletions
@@ -68,15 +68,14 @@
         "node-gyp",
         "prebuilt-binaries",
         "llm",
-        "ggml",
-        "ggmlv3",
+        "gguf",
         "raspberry-pi",
         "self-hosted",
         "local",
         "catai"
     ],
     "author": "Gilad S.",
-    "license": "ISC",
+    "license": "MIT",
     "bugs": {
         "url": "https://github.com/withcatai/node-llama-cpp/issues"
     },

src/chatWrappers/LlamaChatPromptWrapper.ts

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@ import {ChatPromptWrapper} from "../ChatPromptWrapper.js";
 // source: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
 export class LlamaChatPromptWrapper extends ChatPromptWrapper {
     public override wrapPrompt(prompt: string, {systemPrompt, promptIndex}: {systemPrompt: string, promptIndex: number}) {
-        if (promptIndex === 0) {
+        if (promptIndex === 0 && systemPrompt != "") {
             return "<s>[INST] <<SYS>>\n" + systemPrompt + "\n<</SYS>>\n\n" + prompt + " [/INST]\n\n";
         } else {
             return "<s>[INST] " + prompt + " [/INST]\n\n";

src/cli/commands/DownloadCommand.ts

Lines changed: 15 additions & 3 deletions
@@ -11,13 +11,15 @@ import {defaultLlamaCppGitHubRepo, defaultLlamaCppRelease, llamaCppDirectory, te
 import {compileLlamaCpp} from "../../utils/compileLLamaCpp.js";
 import withOra from "../../utils/withOra.js";
 import {clearTempFolder} from "../../utils/clearTempFolder.js";
+import {setBinariesGithubRelease} from "../../utils/binariesGithubRelease.js";
 
 type DownloadCommandArgs = {
     repo: string,
     release: "latest" | string,
     arch?: string,
     nodeTarget?: string,
-    skipBuild?: boolean
+    skipBuild?: boolean,
+    updateBinariesReleaseMetadata?: boolean
 };
 
 export const DownloadCommand: CommandModule<object, DownloadCommandArgs> = {
@@ -33,7 +35,7 @@ export const DownloadCommand: CommandModule<object, DownloadCommandArgs> = {
         .option("release", {
             type: "string",
             default: defaultLlamaCppRelease,
-            description: "The tag of the llama.cpp release to download. Can also be set via the NODE_LLAMA_CPP_REPO_RELEASE environment variable"
+            description: "The tag of the llama.cpp release to download. Set to \"latest\" to download the latest release. Can also be set via the NODE_LLAMA_CPP_REPO_RELEASE environment variable"
         })
         .option("arch", {
             type: "string",
@@ -47,12 +49,18 @@ export const DownloadCommand: CommandModule<object, DownloadCommandArgs> = {
             type: "boolean",
             default: false,
             description: "Skip building llama.cpp after downloading it"
+        })
+        .option("updateBinariesReleaseMetadata", {
+            type: "boolean",
+            hidden: true, // this for the CI to use
+            default: false,
+            description: "Update the binariesGithubRelease.json file with the release of llama.cpp that was downloaded"
         });
     },
     handler: DownloadLlamaCppCommand
 };
 
-export async function DownloadLlamaCppCommand({repo, release, arch, nodeTarget, skipBuild}: DownloadCommandArgs) {
+export async function DownloadLlamaCppCommand({repo, release, arch, nodeTarget, skipBuild, updateBinariesReleaseMetadata}: DownloadCommandArgs) {
     const octokit = new Octokit();
     const [githubOwner, githubRepo] = repo.split("/");
 
@@ -147,6 +155,10 @@ export async function DownloadLlamaCppCommand({repo, release, arch, nodeTarget,
         });
     }
 
+    if (updateBinariesReleaseMetadata) {
+        await setBinariesGithubRelease(githubRelease!.data.tag_name);
+    }
+
     console.log();
     console.log();
    console.log(`${chalk.yellow("Repo:")} ${repo}`);
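
In CI this hidden flag is passed by the workflow change above (`download --release latest --skipBuild --updateBinariesReleaseMetadata`), so the tag the binaries were actually built against gets written to `llama/binariesGithubRelease.json` and shipped with the package. A trimmed, illustrative sketch of that tail of the handler; `githubRelease` and `recordDownloadedRelease` are stand-ins, not names from the diff:

    import {setBinariesGithubRelease} from "../../utils/binariesGithubRelease.js";

    // Illustrative only: in DownloadLlamaCppCommand, the release object is the
    // one fetched via Octokit earlier in the handler.
    async function recordDownloadedRelease(
        githubRelease: {data: {tag_name: string}},
        updateBinariesReleaseMetadata?: boolean
    ) {
        if (updateBinariesReleaseMetadata) {
            // Persist the tag so published packages download the same llama.cpp
            // release that the prebuilt binaries were compiled against.
            await setBinariesGithubRelease(githubRelease.data.tag_name);
        }
    }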

src/config.ts

Lines changed: 3 additions & 1 deletion
@@ -3,6 +3,7 @@ import * as path from "path";
 import * as os from "os";
 import envVar from "env-var";
 import * as uuid from "uuid";
+import {getBinariesGithubRelease} from "./utils/binariesGithubRelease.js";
 
 const __dirname = path.dirname(fileURLToPath(import.meta.url));
 
@@ -14,12 +15,13 @@ export const llamaBinsDirectory = path.join(__dirname, "..", "llamaBins");
 export const llamaCppDirectory = path.join(llamaDirectory, "llama.cpp");
 export const tempDownloadDirectory = path.join(os.tmpdir(), "node-llama-cpp", uuid.v4());
 export const usedBinFlagJsonPath = path.join(llamaDirectory, "usedBin.json");
+export const binariesGithubReleasePath = path.join(llamaDirectory, "binariesGithubRelease.json");
 
 export const defaultLlamaCppGitHubRepo = env.get("NODE_LLAMA_CPP_REPO")
     .default("ggerganov/llama.cpp")
     .asString();
 export const defaultLlamaCppRelease = env.get("NODE_LLAMA_CPP_REPO_RELEASE")
-    .default("latest")
+    .default(await getBinariesGithubRelease())
     .asString();
 export const defaultSkipDownload = env.get("NODE_LLAMA_CPP_SKIP_DOWNLOAD")
     .default("false")

src/llamaEvaluator/LlamaContext.ts

Lines changed: 5 additions & 5 deletions
@@ -1,4 +1,4 @@
-import {LLAMAContext, llamaCppNode} from "./LlamaBins.js";
+import {LLAMAContext} from "./LlamaBins.js";
 import {LlamaModel} from "./LlamaModel.js";
 
 export class LlamaContext {
@@ -18,12 +18,12 @@ export class LlamaContext {
         return this._ctx.decode(tokens);
     }
 
-    public async *evaluate(tokens: Uint32Array, getRestrictions?: () => Uint32Array) {
+    public async *evaluate(tokens: Uint32Array) {
         let evalTokens = tokens;
 
         if (this._prependBos) {
             const tokenArray = Array.from(tokens);
-            tokenArray.unshift(llamaCppNode.tokenBos());
+            tokenArray.unshift(this._ctx.tokenBos());
 
             evalTokens = Uint32Array.from(tokenArray);
             this._prependBos = false;
@@ -32,10 +32,10 @@ export class LlamaContext {
         // eslint-disable-next-line no-constant-condition
         while (true) {
             // Evaluate to get the next token.
-            const nextToken = await this._ctx.eval(evalTokens, getRestrictions?.());
+            const nextToken = await this._ctx.eval(evalTokens);
 
             // the assistant finished answering
-            if (nextToken === llamaCppNode.tokenEos())
+            if (nextToken === this._ctx.tokenEos())
                 break;
 
             yield nextToken;
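
With the `getRestrictions` parameter gone, `evaluate()` is now driven purely by the token stream. A small consumption sketch, assuming `LlamaContext` exposes `encode()`/`decode()` wrappers around the native methods as in this file (roughly what `LlamaChatSession` does internally; `maxTokens` is an illustrative cutoff, not an API parameter):

    import {LlamaContext} from "node-llama-cpp";

    async function generate(context: LlamaContext, text: string, maxTokens = 128) {
        const tokens: number[] = [];

        // evaluate() yields one predicted token at a time and stops at EOS.
        for await (const token of context.evaluate(context.encode(text))) {
            tokens.push(token);

            if (tokens.length >= maxTokens)
                break;
        }

        return context.decode(Uint32Array.from(tokens));
    }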

src/utils/binariesGithubRelease.ts

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+import fs from "fs-extra";
+import {binariesGithubReleasePath} from "../config.js";
+
+type BinariesGithubReleaseFile = {
+    release: "latest" | string
+};
+
+export async function getBinariesGithubRelease() {
+    const binariesGithubRelease: BinariesGithubReleaseFile = await fs.readJson(binariesGithubReleasePath);
+
+    return binariesGithubRelease.release;
+}
+
+export async function setBinariesGithubRelease(release: BinariesGithubReleaseFile["release"]) {
+    const binariesGithubReleaseJson: BinariesGithubReleaseFile = {
+        release: release
+    };
+
+    await fs.writeJson(binariesGithubReleasePath, binariesGithubReleaseJson, {
+        spaces: 4
+    });
+}
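
A quick usage sketch of the new helper module (the release tag value here is hypothetical):

    import {getBinariesGithubRelease, setBinariesGithubRelease} from "./utils/binariesGithubRelease.js";

    // Written by the download command when --updateBinariesReleaseMetadata is passed.
    await setBinariesGithubRelease("some-release-tag");

    // Read back by config.ts at import time to pick the default release.
    const release = await getBinariesGithubRelease();
    console.log(release); // "some-release-tag"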
