Commit 62ee6e3

feat: evaluateWithMetadata, token confidence
1 parent f050fa4 commit 62ee6e3

File tree

13 files changed: +1180 -107 lines changed


.config/typedoc.json

Lines changed: 1 addition & 1 deletion
@@ -27,6 +27,6 @@
    "interfacePropertiesFormat": "list",
    "sort": ["source-order"],
    "docsRoot": "../docs",
-    "intentionallyNotExported": ["MergeOptionalUnionTypes", "GbnfJsonSchemaToTSType", "_LlamaText"],
+    "intentionallyNotExported": ["MergeOptionalUnionTypes", "PickOptions", "GbnfJsonSchemaToTSType", "_LlamaText"],
    "useHTMLEncodedBrackets": true
}

docs/guide/low-level-api.md

Lines changed: 78 additions & 7 deletions
@@ -38,13 +38,25 @@ and you can pass no sampling options to avoid making any adjustments to the prob
It's best to avoid getting the full probabilities list unless you really need it,
as passing it to the JavaScript side can be slow.

+### Context Shift {#context-shift}
+When the context sequence is full and you want to evaluate more tokens onto it,
+some tokens will have to be removed to make room for new ones to be added.
+
+Ideally, you'd want to do that on your logic level, so you can control which content to keep and which to remove.
+> All the high-level APIs of `node-llama-cpp` [automatically do that](./chat-context-shift.md).
+
+If you don't do that, `node-llama-cpp` will automatically remove the oldest tokens from the context sequence state to make room for new ones.
+
+You can customize the context shift strategy `node-llama-cpp` uses for the context sequence by configuring the [`contextShift`](../api/classes/LlamaContext.md#parameters) option when calling [`.getSequence(...)`](../api/classes/LlamaContext.md#getsequence),
+or by passing a customized [`contextShift`](../api/type-aliases/SequenceEvaluateOptions#contextshift) option to the evaluation method you use.
+
## Simple Evaluation {#simple-evaluation}
-You can evaluate the given input tokens onto a context sequence using [`.evaluate`](../api/classes/LlamaContextSequence.md#evaluate)
+You can evaluate the given input tokens onto a context sequence using [`.evaluate(...)`](../api/classes/LlamaContextSequence.md#evaluate)
and generate the next token for the last input token.

On each iteration of the returned iterator, the generated token is then added to the context sequence state and the next token is generated for it, and so on.

-When using [`.evaluate`](../api/classes/LlamaContextSequence.md#evaluate), the configured [token predictor](./token-prediction.md) is used to speed up the generation process.
+When using [`.evaluate(...)`](../api/classes/LlamaContextSequence.md#evaluate), the configured [token predictor](./token-prediction.md) is used to speed up the generation process.

```typescript
import {fileURLToPath} from "url";
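
A rough sketch of what the customized context shift described in the new section above could look like when obtaining a sequence. The `size` and `strategy: "eraseBeginning"` fields used here are assumptions, not taken from this commit; check the `contextShift` API references linked above for the exact shape.

```typescript
// Illustrative sketch only: the `size` and `strategy` fields of `contextShift`
// are assumed here and may differ from the actual API; see the linked references.
const sequence = context.getSequence({
    contextShift: {
        // assumed option: how many tokens to free up once the sequence state is full
        size: 512,

        // assumed option: remove tokens from the beginning of the state
        strategy: "eraseBeginning"
    }
});
```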
@@ -130,9 +142,67 @@ console.log("Result: " + resText);
```
> If you want to adjust the token probabilities when generating output, consider using [token bias](./token-bias.md) instead

+### With Metadata {#evaluation-with-metadata}
+You can use [`.evaluateWithMetadata(...)`](../api/classes/LlamaContextSequence.md#evaluatewithmetadata) to evaluate tokens onto the context sequence state like [`.evaluate(...)`](#simple-evaluation), but with metadata emitted for each token.
+
+```typescript
+import {fileURLToPath} from "url";
+import path from "path";
+import {getLlama, Token, SequenceEvaluateOptions} from "node-llama-cpp";
+
+const __dirname = path.dirname(fileURLToPath(import.meta.url));
+
+const llama = await getLlama();
+const model = await llama.loadModel({
+    modelPath: path.join(__dirname, "models", "Meta-Llama-3-8B-Instruct.Q4_K_M.gguf")
+});
+const context = await model.createContext();
+const sequence = context.getSequence();
+
+const input = "The best way to";
+const tokens = model.tokenize(input);
+const maxTokens = 10;
+const res: Array<{
+    token: Token,
+    confidence: number,
+    probabilities: Map<Token, number>
+}> = [];
+const metadataOptions = {
+    // configure which metadata should be returned
+    confidence: true,
+    probabilities: true
+} as const;
+const options: SequenceEvaluateOptions = {
+    temperature: 0.8
+};
+
+const iterator = sequence.evaluateWithMetadata(
+    tokens,
+    metadataOptions,
+    options
+);
+for await (const item of iterator) {
+    res.push({
+        token: item.token,
+        confidence: item.confidence,
+        probabilities: new Map(
+            // only keep the top 5 probabilities
+            [...item.probabilities.entries()].slice(0, 5)
+        )
+    });
+
+    if (res.length >= maxTokens)
+        break;
+}
+
+const resText = model.detokenize(res.map(({token}) => token));
+console.log("Result: " + resText);
+console.log("With metadata:", res);
+```
+
### No Generation {#evaluation-without-generation}
To evaluate the input tokens onto a context sequence without generating new tokens,
-you can use [`.evaluateWithoutGeneratingNewTokens`](../api/classes/LlamaContextSequence.md#evaluatewithoutgeneratingnewtokens).
+you can use [`.evaluateWithoutGeneratingNewTokens(...)`](../api/classes/LlamaContextSequence.md#evaluatewithoutgeneratingnewtokens).

```typescript
import {fileURLToPath} from "url";
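
Judging by the addon changes in this commit, the `confidence` value emitted by `evaluateWithMetadata` is the probability (between 0 and 1) that the sampler assigned to the token it picked, and `probabilities` maps candidate tokens to their probabilities, ordered from most to least likely. A smaller sketch that continues from the example above and requests only the confidence metadata (assuming the metadata fields can be toggled independently, as the "configure which metadata should be returned" comment suggests):

```typescript
// Sketch: request only the `confidence` metadata for each generated token
let generatedTokens = 0;

for await (const item of sequence.evaluateWithMetadata(tokens, {confidence: true}, {temperature: 0.8})) {
    // `item.confidence` is the probability the sampler assigned to `item.token`
    console.log(model.detokenize([item.token]), item.confidence.toFixed(3));

    if (++generatedTokens >= maxTokens)
        break;
}
```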
@@ -154,7 +224,8 @@ await sequence.evaluateWithoutGeneratingNewTokens(tokens);
```

## Controlled Evaluation {#controlled-evaluation}
-To manually control for which of the input tokens to generate output, you can use [`.controlledEvaluate`](../api/classes/LlamaContextSequence.md#controlledevaluate).
+To manually control for which of the input tokens to generate output,
+you can use [`.controlledEvaluate(...)`](../api/classes/LlamaContextSequence.md#controlledevaluate).

```typescript
import {fileURLToPath} from "url";
@@ -179,8 +250,8 @@ const lastToken = evaluateInput.pop() as Token;
if (lastToken != null)
    evaluateInput.push([lastToken, {
        generateNext: {
-            singleToken: true,
-            probabilitiesList: true,
+            token: true,
+            probabilities: true,
            options: {
                temperature: 0.8
            }
@@ -222,7 +293,7 @@ as it may lead to unexpected results.

### Erase State Ranges {#erase-state-ranges}
To erase a range of tokens from the context sequence state,
-you can use [`.eraseContextTokenRanges`](../api/classes/LlamaContextSequence.md#erasecontexttokenranges).
+you can use [`.eraseContextTokenRanges(...)`](../api/classes/LlamaContextSequence.md#erasecontexttokenranges).

```typescript
import {fileURLToPath} from "url";

llama/addon/AddonContext.cpp

Lines changed: 86 additions & 27 deletions
@@ -191,11 +191,13 @@ class AddonContextSampleTokenWorker : public Napi::AsyncWorker {
    AddonContext* ctx;
    AddonSampler* sampler;
    bool arrayResult = false;
-    bool returnLogprobs = false;
-    bool has_logprobs = false;
-    size_t logprobs_size;
-    llama_token * logprobs_tokens;
-    float * logprobs_probs;
+    bool returnProbabilities = false;
+    bool returnConfidence = false;
+    float tokenConfidence = -1;
+    bool has_probabilities = false;
+    size_t probabilities_size;
+    llama_token * probabilities_tokens;
+    float * probabilities_probs;
    int32_t batchLogitIndex;
    llama_token result;
    bool no_output = false;
@@ -209,16 +211,17 @@ class AddonContextSampleTokenWorker : public Napi::AsyncWorker {
        batchLogitIndex = info[0].As<Napi::Number>().Int32Value();
        sampler = Napi::ObjectWrap<AddonSampler>::Unwrap(info[1].As<Napi::Object>());
        arrayResult = info.Length() > 2 && info[2].IsBoolean();
-        returnLogprobs = arrayResult ? info[2].As<Napi::Boolean>().Value() : false;
+        returnProbabilities = arrayResult ? info[2].As<Napi::Boolean>().Value() : false;
+        returnConfidence = arrayResult && info.Length() > 3 && info[3].IsBoolean() ? info[3].As<Napi::Boolean>().Value() : false;
        sampler->Ref();
    }
    ~AddonContextSampleTokenWorker() {
        ctx->Unref();
        sampler->Unref();

-        if (has_logprobs) {
-            delete[] logprobs_tokens;
-            delete[] logprobs_probs;
+        if (has_probabilities) {
+            delete[] probabilities_tokens;
+            delete[] probabilities_probs;
        }
    }
@@ -264,32 +267,84 @@ class AddonContextSampleTokenWorker : public Napi::AsyncWorker {

        llama_sampler_apply(sampler->chain, &cur_p);

-        if (returnLogprobs) {
+        if (!(cur_p.selected >= 0 && cur_p.selected < (int32_t)cur_p.size)) {
+            no_output = true;
+            return;
+        }
+
+        auto new_token_id = cur_p.data[cur_p.selected].id;
+
+        if (returnProbabilities || returnConfidence) {
            if (!cur_p.sorted) {
                std::sort(cur_p.data, cur_p.data + cur_p.size, [](const llama_token_data & a, const llama_token_data & b) {
                    return a.logit > b.logit;
                });
                cur_p.sorted = true;
+
+                for (size_t i = 0; i < cur_p.size; i++) {
+                    if (cur_p.data[i].id == new_token_id) {
+                        cur_p.selected = i;
+                        break;
+                    }
+                }
            }
+        }

-            logprobs_size = cur_p.size;
-            logprobs_tokens = new llama_token[logprobs_size];
-            logprobs_probs = new float[logprobs_size];
+        if (returnProbabilities) {
+            probabilities_size = cur_p.size;
+            probabilities_tokens = new llama_token[probabilities_size];
+            probabilities_probs = new float[probabilities_size];
+            float maxLogit = cur_p.size > 0 ? cur_p.data[0].logit : -INFINITY;

            for (size_t i = 0; i < cur_p.size; i++) {
-                logprobs_tokens[i] = cur_p.data[i].id;
-                logprobs_probs[i] = cur_p.data[i].logit;
+                auto logit = cur_p.data[i].logit;
+
+                probabilities_tokens[i] = cur_p.data[i].id;
+                probabilities_probs[i] = logit;
+
+                if (logit > maxLogit) {
+                    maxLogit = logit;
+                }
            }
+
+            if (probabilities_size > 0 && maxLogit != -INFINITY) {
+                float sum = 0.0f;
+                for (size_t i = 0; i < probabilities_size; i++) {
+                    float prob = expf(probabilities_probs[i] - maxLogit);
+                    probabilities_probs[i] = prob;
+                    sum += prob;
+                }
+
+                for (size_t i = 0; i < probabilities_size; i++) {
+                    probabilities_probs[i] /= sum;
+                }
            }

-            has_logprobs = true;
+            has_probabilities = true;
        }

-        if (!(cur_p.selected >= 0 && cur_p.selected < (int32_t)cur_p.size)) {
-            no_output = true;
-            return;
+        if (returnConfidence) {
+            if (has_probabilities && cur_p.selected < probabilities_size) {
+                tokenConfidence = probabilities_probs[cur_p.selected];
+            } else {
+                float maxLogit = cur_p.data[0].logit;
+                float sum = 0.0f;
+                for (size_t i = 0; i < cur_p.size; i++) {
+                    auto logit = cur_p.data[i].logit;
+
+                    if (logit > maxLogit) {
+                        maxLogit = logit;
+                    }
+                }
+
+                for (size_t i = 0; i < cur_p.size; i++) {
+                    sum += expf(cur_p.data[i].logit - maxLogit);
+                }
+
+                tokenConfidence = expf(cur_p.data[cur_p.selected].logit - maxLogit) / sum;
+            }
        }

-        auto new_token_id = cur_p.data[cur_p.selected].id;
        sampler->acceptToken(new_token_id);
        result = new_token_id;
    }
@@ -308,14 +363,18 @@ class AddonContextSampleTokenWorker : public Napi::AsyncWorker {

        Napi::Array resultArray = Napi::Array::New(Env(), 2);
        resultArray.Set(Napi::Number::New(Env(), 0), resultToken);
-
-        if (has_logprobs) {
-            Napi::Array logprobs = Napi::Array::New(Env(), logprobs_size * 2);
-            for (size_t i = 0; i < logprobs_size; i++) {
-                logprobs.Set(i * 2, Napi::Number::New(Env(), logprobs_tokens[i]));
-                logprobs.Set(i * 2 + 1, Napi::Number::New(Env(), logprobs_probs[i]));
+
+        if (has_probabilities) {
+            Napi::Array probabilities = Napi::Array::New(Env(), probabilities_size * 2);
+            for (size_t i = 0; i < probabilities_size; i++) {
+                probabilities.Set(i * 2, Napi::Number::New(Env(), probabilities_tokens[i]));
+                probabilities.Set(i * 2 + 1, Napi::Number::New(Env(), probabilities_probs[i]));
            }
-            resultArray.Set(1, logprobs);
+            resultArray.Set(1, probabilities);
+        }
+
+        if (returnConfidence && tokenConfidence != -1) {
+            resultArray.Set(2, Napi::Number::New(Env(), tokenConfidence));
        }

        deferred.Resolve(resultArray);
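
For orientation, both the `probabilities` and the `confidence` values computed above are a softmax over the candidates' logits after the sampler chain has been applied, using the usual max-logit subtraction for numerical stability; `confidence` is simply the resulting probability of the selected token. The same computation expressed in TypeScript (illustrative only, not part of the bindings):

```typescript
// Softmax with max-logit subtraction, mirroring the addon code above (illustrative only)
function softmaxConfidence(logits: number[], selectedIndex: number): number {
    const maxLogit = Math.max(...logits);
    const exps = logits.map((logit) => Math.exp(logit - maxLogit));
    const sum = exps.reduce((acc, value) => acc + value, 0);

    // probability of the selected candidate among all remaining candidates
    return exps[selectedIndex]! / sum;
}

// example: three candidates left after the sampler chain, the first one was selected
console.log(softmaxConfidence([2.1, 1.3, -0.4], 0)); // ~0.65
```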

src/bindings/AddonTypes.ts

Lines changed: 3 additions & 2 deletions
@@ -131,8 +131,9 @@ export type AddonContext = {
    sampleToken(
        batchLogitIndex: BatchLogitIndex,
        sampler: AddonSampler,
-        logprobs: boolean
-    ): Promise<[Token | -1, (Token | number)[] | undefined]>,
+        probabilities: boolean,
+        confidence?: boolean
+    ): Promise<[token: Token | -1, probabilities: (Token | number)[] | undefined, confidence: number | undefined]>,
    disposeSequence(sequenceId: number): void,

    // startPos in inclusive, endPos is exclusive
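
The `probabilities` element of the tuple returned by the native `sampleToken` is a flat array of interleaved `[token, probability, token, probability, ...]` pairs, sorted from most to least probable (see the `probabilities.Set(i * 2, ...)` and `Set(i * 2 + 1, ...)` calls in `AddonContext.cpp` above). A sketch of how such a flat array could be folded into the `Map<Token, number>` that the documentation example exposes; the actual conversion lives elsewhere in the library and may differ:

```typescript
import type {Token} from "node-llama-cpp";

// Sketch: convert the interleaved [token, prob, token, prob, ...] array returned by the
// native binding into a Map, preserving its most-probable-first order (illustrative only)
function probabilitiesArrayToMap(flat: (Token | number)[] | undefined): Map<Token, number> {
    const map = new Map<Token, number>();

    if (flat == null)
        return map;

    for (let i = 0; i + 1 < flat.length; i += 2)
        map.set(flat[i] as Token, flat[i + 1] as number);

    return map;
}
```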
