@@ -67,6 +67,7 @@ class LLAMAModel : public Napi::ObjectWrap<LLAMAModel> {
67
67
}
68
68
}
69
69
70
+ llama_backend_init (false );
70
71
model = llama_load_model_from_file (modelPath.c_str (), params);
71
72
72
73
if (model == NULL ) {
@@ -124,7 +125,18 @@ class LLAMAContext : public Napi::ObjectWrap<LLAMAContext> {
124
125
125
126
// Decode each token and accumulate the result.
126
127
for (size_t i = 0 ; i < tokens.ElementLength (); i++) {
127
- const char * str = llama_token_to_str (ctx, (llama_token)tokens[i]);
128
+ // source: https://github.com/ggerganov/llama.cpp/blob/232caf3c1581a6cb023571780ff41dc2d66d1ca0/llama.cpp#L799-L811
129
+ std::vector<char > result (8 , 0 );
130
+ const int n_tokens = llama_token_to_str (ctx, (llama_token)tokens[i], result.data (), result.size ());
131
+ if (n_tokens < 0 ) {
132
+ result.resize (-n_tokens);
133
+ int check = llama_token_to_str (ctx, (llama_token)tokens[i], result.data (), result.size ());
134
+ GGML_ASSERT (check == -n_tokens);
135
+ } else {
136
+ result.resize (n_tokens);
137
+ }
138
+
139
+ const char * str = result.data ();
128
140
if (str == nullptr ) {
129
141
Napi::Error::New (info.Env (), " Invalid token" ).ThrowAsJavaScriptException ();
130
142
return info.Env ().Undefined ();
@@ -134,6 +146,15 @@ class LLAMAContext : public Napi::ObjectWrap<LLAMAContext> {
134
146
135
147
return Napi::String::New (info.Env (), ss.str ());
136
148
}
149
+ Napi::Value TokenBos (const Napi::CallbackInfo& info) {
150
+ return Napi::Number::From (info.Env (), llama_token_bos (ctx));
151
+ }
152
+ Napi::Value TokenEos (const Napi::CallbackInfo& info) {
153
+ return Napi::Number::From (info.Env (), llama_token_eos (ctx));
154
+ }
155
+ Napi::Value GetMaxContextSize (const Napi::CallbackInfo& info) {
156
+ return Napi::Number::From (info.Env (), llama_n_ctx (ctx));
157
+ }
137
158
Napi::Value Eval (const Napi::CallbackInfo& info);
138
159
static void init (Napi::Object exports) {
139
160
exports.Set (" LLAMAContext" ,
@@ -142,6 +163,9 @@ class LLAMAContext : public Napi::ObjectWrap<LLAMAContext> {
142
163
{
143
164
InstanceMethod (" encode" , &LLAMAContext::Encode),
144
165
InstanceMethod (" decode" , &LLAMAContext::Decode),
166
+ InstanceMethod (" tokenBos" , &LLAMAContext::TokenBos),
167
+ InstanceMethod (" tokenEos" , &LLAMAContext::TokenEos),
168
+ InstanceMethod (" getMaxContextSize" , &LLAMAContext::GetMaxContextSize),
145
169
InstanceMethod (" eval" , &LLAMAContext::Eval),
146
170
}));
147
171
}
@@ -151,7 +175,6 @@ class LLAMAContext : public Napi::ObjectWrap<LLAMAContext> {
151
175
class LLAMAContextEvalWorker : Napi::AsyncWorker, Napi::Promise::Deferred {
152
176
LLAMAContext* ctx;
153
177
std::vector<llama_token> tokens;
154
- std::vector<llama_token> restriction;
155
178
llama_token result;
156
179
157
180
public:
@@ -160,13 +183,6 @@ class LLAMAContextEvalWorker : Napi::AsyncWorker, Napi::Promise::Deferred {
160
183
Napi::Uint32Array tokens = info[0 ].As <Napi::Uint32Array>();
161
184
this ->tokens .reserve (tokens.ElementLength ());
162
185
for (size_t i = 0 ; i < tokens.ElementLength (); i++) { this ->tokens .push_back (static_cast <llama_token>(tokens[i])); }
163
-
164
- if (info.Length () > 1 && info[1 ].IsTypedArray ()) {
165
- Napi::Uint32Array restriction = info[1 ].As <Napi::Uint32Array>();
166
- this ->restriction .reserve (restriction.ElementLength ());
167
- for (size_t i = 0 ; i < restriction.ElementLength (); i++) { this ->restriction .push_back (static_cast <llama_token>(restriction[i])); }
168
- std::sort (this ->restriction .begin (), this ->restriction .end ());
169
- }
170
186
}
171
187
~LLAMAContextEvalWorker () { ctx->Unref (); }
172
188
using Napi::AsyncWorker::Queue;
@@ -175,39 +191,30 @@ class LLAMAContextEvalWorker : Napi::AsyncWorker, Napi::Promise::Deferred {
175
191
protected:
176
192
void Execute () {
177
193
// Perform the evaluation using llama_eval.
178
- int r = llama_eval (ctx->ctx , tokens.data (), tokens.size (), llama_get_kv_cache_token_count (ctx->ctx ), 6 );
194
+ int r = llama_eval (ctx->ctx , tokens.data (), int ( tokens.size () ), llama_get_kv_cache_token_count (ctx->ctx ), 6 );
179
195
if (r != 0 ) {
180
196
SetError (" Eval has failed" );
181
197
return ;
182
198
}
183
199
200
+ llama_token new_token_id = 0 ;
201
+
184
202
// Select the best prediction.
185
- float * logits = llama_get_logits (ctx->ctx );
186
- int n_vocab = llama_n_vocab (ctx->ctx );
187
- llama_token re;
188
- if (restriction.empty ()) {
189
- float max = logits[0 ];
190
- re = 0 ;
191
- for (llama_token id = 1 ; id < n_vocab; id++) {
192
- float logit = logits[id];
193
- if (logit > max) {
194
- max = logit;
195
- re = id;
196
- }
197
- }
198
- } else {
199
- float max = logits[restriction[0 ]];
200
- re = 0 ;
201
- for (size_t i = 1 ; i < restriction.size (); i++) {
202
- llama_token id = restriction[i];
203
- float logit = logits[id];
204
- if (logit > max) {
205
- max = logit;
206
- re = id;
207
- }
208
- }
203
+ auto logits = llama_get_logits (ctx->ctx );
204
+ auto n_vocab = llama_n_vocab (ctx->ctx );
205
+
206
+ std::vector<llama_token_data> candidates;
207
+ candidates.reserve (n_vocab);
208
+
209
+ for (llama_token token_id = 0 ; token_id < n_vocab; token_id++) {
210
+ candidates.emplace_back (llama_token_data{ token_id, logits[token_id], 0 .0f });
209
211
}
210
- result = re;
212
+
213
+ llama_token_data_array candidates_p = { candidates.data (), candidates.size (), false };
214
+
215
+ new_token_id = llama_sample_token_greedy (ctx->ctx , &candidates_p);
216
+
217
+ result = new_token_id;
211
218
}
212
219
void OnOK () {
213
220
Napi::Env env = Napi::AsyncWorker::Env ();
@@ -223,15 +230,11 @@ Napi::Value LLAMAContext::Eval(const Napi::CallbackInfo& info) {
223
230
return worker->Promise ();
224
231
}
225
232
226
- Napi::Value tokenBos (const Napi::CallbackInfo& info) { return Napi::Number::From (info.Env (), llama_token_bos ()); }
227
- Napi::Value tokenEos (const Napi::CallbackInfo& info) { return Napi::Number::From (info.Env (), llama_token_eos ()); }
228
233
Napi::Value systemInfo (const Napi::CallbackInfo& info) { return Napi::String::From (info.Env (), llama_print_system_info ()); }
229
234
230
235
Napi::Object registerCallback (Napi::Env env, Napi::Object exports) {
231
236
llama_backend_init (false );
232
237
exports.DefineProperties ({
233
- Napi::PropertyDescriptor::Function (" tokenBos" , tokenBos),
234
- Napi::PropertyDescriptor::Function (" tokenEos" , tokenEos),
235
238
Napi::PropertyDescriptor::Function (" systemInfo" , systemInfo),
236
239
});
237
240
LLAMAModel::init (exports);
0 commit comments