
Commit 28c7984

fix: align embedding input with WPM vocabulary type models (#393)

1 parent: 4d387de

19 files changed: +570 −35 lines

.github/workflows/build.yml

Lines changed: 4 additions & 4 deletions

@@ -389,7 +389,7 @@ jobs:
   model-dependent-tests:
     name: Model dependent tests
-    runs-on: macos-12
+    runs-on: ubuntu-24.04
     env:
       NODE_LLAMA_CPP_GPU: false
     needs:
@@ -412,10 +412,10 @@ jobs:
           name: llama.cpp
           path: llama
 
-      - name: Install dependencies on macOS
+      - name: Install dependencies on Ubuntu
         run: |
-          brew install cmake ninja
-          alias make=cmake
+          sudo apt-get update
+          sudo apt-get install ninja-build cmake
 
       - name: Install modules
        run: npm ci

llama/addon/AddonContext.cpp

Lines changed: 2 additions & 1 deletion

@@ -531,7 +531,8 @@ Napi::Value AddonContext::GetEmbedding(const Napi::CallbackInfo& info) {
     }
 
     const int n_embd = llama_n_embd(model->model);
-    const auto* embeddings = llama_get_embeddings_seq(ctx, 0);
+    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
+    const auto* embeddings = pooling_type == LLAMA_POOLING_TYPE_NONE ? NULL : llama_get_embeddings_seq(ctx, 0);
     if (embeddings == NULL) {
         embeddings = llama_get_embeddings_ith(ctx, inputTokensLength - 1);
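Note: with LLAMA_POOLING_TYPE_NONE there is no pooled per-sequence embedding for llama_get_embeddings_seq to return, so the lookup is skipped and the existing fallback to the last token's embedding kicks in. A minimal TypeScript sketch of the same fallback shape, where getSeqEmbedding and getTokenEmbedding are hypothetical stand-ins for llama_get_embeddings_seq and llama_get_embeddings_ith:

type PoolingType = "none" | "mean" | "cls" | "last";

// Skip the per-sequence lookup entirely when no pooling is configured,
// then fall back to the last token's embedding (as the C++ above does).
function extractEmbedding(
    pooling: PoolingType,
    getSeqEmbedding: () => Float32Array | null,
    getTokenEmbedding: (tokenIndex: number) => Float32Array | null,
    inputTokensLength: number
): Float32Array | null {
    const pooled = pooling === "none" ? null : getSeqEmbedding();
    return pooled ?? getTokenEmbedding(inputTokensLength - 1);
}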

llama/addon/AddonModel.cpp

Lines changed: 26 additions & 4 deletions

@@ -9,7 +9,7 @@
 #include "AddonModelLora.h"
 
 static Napi::Value getNapiToken(const Napi::CallbackInfo& info, llama_model* model, llama_token token) {
-    if (token < 0) {
+    if (token < 0 || token == LLAMA_TOKEN_NULL) {
         return Napi::Number::From(info.Env(), -1);
     }
 
@@ -565,6 +565,22 @@ Napi::Value AddonModel::EotToken(const Napi::CallbackInfo& info) {
 
     return getNapiToken(info, model, llama_token_eot(model));
 }
+Napi::Value AddonModel::ClsToken(const Napi::CallbackInfo& info) {
+    if (disposed) {
+        Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException();
+        return info.Env().Undefined();
+    }
+
+    return getNapiToken(info, model, llama_token_cls(model));
+}
+Napi::Value AddonModel::SepToken(const Napi::CallbackInfo& info) {
+    if (disposed) {
+        Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException();
+        return info.Env().Undefined();
+    }
+
+    return getNapiToken(info, model, llama_token_sep(model));
+}
 Napi::Value AddonModel::GetTokenString(const Napi::CallbackInfo& info) {
     if (disposed) {
         Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException();
@@ -624,11 +640,14 @@ Napi::Value AddonModel::GetVocabularyType(const Napi::CallbackInfo& info) {
     return Napi::Number::From(info.Env(), int32_t(vocabularyType));
 }
 Napi::Value AddonModel::ShouldPrependBosToken(const Napi::CallbackInfo& info) {
-    const int addBos = llama_add_bos_token(model);
+    const bool addBos = llama_add_bos_token(model);
 
-    bool shouldPrependBos = addBos != -1 ? bool(addBos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);
+    return Napi::Boolean::New(info.Env(), addBos);
+}
+Napi::Value AddonModel::ShouldAppendEosToken(const Napi::CallbackInfo& info) {
+    const bool addEos = llama_add_eos_token(model);
 
-    return Napi::Boolean::New(info.Env(), shouldPrependBos);
+    return Napi::Boolean::New(info.Env(), addEos);
 }
 
 Napi::Value AddonModel::GetModelSize(const Napi::CallbackInfo& info) {
@@ -659,11 +678,14 @@ void AddonModel::init(Napi::Object exports) {
             InstanceMethod("middleToken", &AddonModel::MiddleToken),
             InstanceMethod("suffixToken", &AddonModel::SuffixToken),
             InstanceMethod("eotToken", &AddonModel::EotToken),
+            InstanceMethod("clsToken", &AddonModel::ClsToken),
+            InstanceMethod("sepToken", &AddonModel::SepToken),
             InstanceMethod("getTokenString", &AddonModel::GetTokenString),
             InstanceMethod("getTokenAttributes", &AddonModel::GetTokenAttributes),
             InstanceMethod("isEogToken", &AddonModel::IsEogToken),
             InstanceMethod("getVocabularyType", &AddonModel::GetVocabularyType),
             InstanceMethod("shouldPrependBosToken", &AddonModel::ShouldPrependBosToken),
+            InstanceMethod("shouldAppendEosToken", &AddonModel::ShouldAppendEosToken),
             InstanceMethod("getModelSize", &AddonModel::GetModelSize),
             InstanceMethod("dispose", &AddonModel::Dispose),
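The new clsToken() and sepToken() bindings go through getNapiToken, so a missing token (negative or LLAMA_TOKEN_NULL) reaches JavaScript as -1. A small sketch of the normalization the TypeScript getters in LlamaModel.ts (further down) apply to that sentinel:

// -1 is the addon's "no such token" sentinel (see getNapiToken above);
// mapping it to null lets callers use plain nullish checks.
function tokenOrNull(rawToken: number): number | null {
    return rawToken === -1 ? null : rawToken;
}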

llama/addon/AddonModel.h

Lines changed: 3 additions & 0 deletions

@@ -49,12 +49,15 @@ class AddonModel : public Napi::ObjectWrap<AddonModel> {
         Napi::Value MiddleToken(const Napi::CallbackInfo& info);
         Napi::Value SuffixToken(const Napi::CallbackInfo& info);
         Napi::Value EotToken(const Napi::CallbackInfo& info);
+        Napi::Value ClsToken(const Napi::CallbackInfo& info);
+        Napi::Value SepToken(const Napi::CallbackInfo& info);
         Napi::Value GetTokenString(const Napi::CallbackInfo& info);
 
         Napi::Value GetTokenAttributes(const Napi::CallbackInfo& info);
         Napi::Value IsEogToken(const Napi::CallbackInfo& info);
         Napi::Value GetVocabularyType(const Napi::CallbackInfo& info);
         Napi::Value ShouldPrependBosToken(const Napi::CallbackInfo& info);
+        Napi::Value ShouldAppendEosToken(const Napi::CallbackInfo& info);
         Napi::Value GetModelSize(const Napi::CallbackInfo& info);
 
         static void init(Napi::Object exports);

src/bindings/AddonTypes.ts

Lines changed: 3 additions & 0 deletions

@@ -99,11 +99,14 @@ export type AddonModel = {
     middleToken(): Token,
     suffixToken(): Token,
     eotToken(): Token,
+    clsToken(): Token,
+    sepToken(): Token,
     getTokenString(token: number): string,
     getTokenAttributes(token: Token): number,
     isEogToken(token: Token): boolean,
     getVocabularyType(): number,
     shouldPrependBosToken(): boolean,
+    shouldAppendEosToken(): boolean,
     getModelSize(): number
 };

src/bindings/types.ts

Lines changed: 3 additions & 1 deletion

@@ -85,7 +85,9 @@ export enum LlamaVocabularyType {
     none = "none",
     spm = "spm",
     bpe = "bpe",
-    wpm = "wpm"
+    wpm = "wpm",
+    ugm = "ugm",
+    rwkv = "rwkv"
 }
 export const LlamaVocabularyTypeValues = Object.freeze([
     LlamaVocabularyType.none,
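The two new values keep the enum in sync with llama.cpp's llama_vocab_type. Assuming LlamaVocabularyTypeValues stays ordered to match that C enum (none, spm, bpe, wpm, ugm, rwkv), the addon's numeric getVocabularyType() result can be mapped by index; a hedged sketch, not the repository's actual mapping code:

// Hypothetical mapping from the addon's numeric vocabulary type.
// Assumes array order matches llama.cpp's llama_vocab_type enum.
function vocabularyTypeFromAddon(rawType: number): LlamaVocabularyType {
    const vocabularyType = LlamaVocabularyTypeValues[rawType];
    if (vocabularyType == null)
        throw new Error(`Unknown vocabulary type: ${rawType}`);
    return vocabularyType;
}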

src/evaluator/LlamaCompletion.ts

Lines changed: 15 additions & 10 deletions

@@ -11,6 +11,7 @@ import {getQueuedTokensBeforeStopTrigger} from "../utils/getQueuedTokensBeforeSt
 import {safeEventCallback} from "../utils/safeEventCallback.js";
 import {pushAll} from "../utils/pushAll.js";
 import {GgufArchitectureType} from "../gguf/types/GgufMetadataTypes.js";
+import {resolveBeginningTokenToPrepend} from "../utils/tokenizerUtils.js";
 import {LlamaGrammarEvaluationState} from "./LlamaGrammarEvaluationState.js";
 import {LlamaGrammar} from "./LlamaGrammar.js";
 import {EvaluationPriority} from "./LlamaContext/types.js";
@@ -262,8 +263,10 @@ export class LlamaCompletion {
         if (this._sequence == null || this.disposed)
             throw new DisposedError();
 
-        const bosToken = this._sequence.model.tokens.bos;
-        const shouldPrependBosToken = this._sequence.model.tokens.shouldPrependBosToken;
+        const beginningTokenToPrepend = resolveBeginningTokenToPrepend(
+            this._sequence.model.vocabularyType,
+            this._sequence.model.tokens
+        );
 
         const extraEosTokens = getExtraCompletionEosTokens(this._sequence.model);
 
@@ -274,8 +277,8 @@
     }): Promise<Token[]> {
         const res = [];
 
-        if (shouldPrependBosToken && bosToken != null)
-            res.push(bosToken);
+        if (beginningTokenToPrepend != null)
+            res.push(beginningTokenToPrepend);
 
         const inputTokensSize = Math.max(0, Math.min(maxTokens - res.length, tokens.length));
 
@@ -305,7 +308,7 @@
         const resolvedInput = tokenizeInput(
             input,
             this._sequence.model.tokenizer,
-            (shouldPrependBosToken && bosToken != null)
+            beginningTokenToPrepend != null
                 ? "trimLeadingSpace"
                 : undefined
         );
@@ -406,8 +409,10 @@
         const prefixToken = this._sequence.model.tokens.infill.prefix;
         const suffixToken = this._sequence.model.tokens.infill.suffix;
         const middleToken = this._sequence.model.tokens.infill.middle;
-        const bosToken = this._sequence.model.tokens.bos;
-        const shouldPrependBosToken = this._sequence.model.tokens.shouldPrependBosToken;
+        const beginningTokenToPrepend = resolveBeginningTokenToPrepend(
+            this._sequence.model.vocabularyType,
+            this._sequence.model.tokens
+        );
 
         if (prefixToken == null || suffixToken == null)
             throw new UnsupportedError("Infill completions are not supported by this model");
@@ -425,7 +430,7 @@
         // 2 - InfillPrefix token, InfillSuffix token
         const specialTokensInContext = 2 +
             (middleToken != null ? 1 : 0) +
-            ((shouldPrependBosToken && bosToken != null) ? 1 : 0);
+            (beginningTokenToPrepend != null ? 1 : 0);
         const resolvedMaxTokens = maxTokens - specialTokensInContext;
         let sizeLeftToFill = resolvedMaxTokens;
 
@@ -464,8 +469,8 @@
 
         const newContextState: Token[] = [];
 
-        if (shouldPrependBosToken && bosToken != null)
-            newContextState.push(bosToken);
+        if (beginningTokenToPrepend != null)
+            newContextState.push(beginningTokenToPrepend);
 
         if (middleToken != null) {
             newContextState.push(prefixToken);
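resolveBeginningTokenToPrepend lives in the new src/utils/tokenizerUtils.ts, which is part of this commit but not shown in this excerpt. Judging from its call sites and the commit's purpose, it plausibly returns the CLS token for WPM vocabularies and otherwise falls back to BOS when the model metadata requests it; the following is a hedged sketch only, not the actual implementation:

// Hypothetical sketch of the resolver used above: WPM (BERT-style)
// vocabularies frame input with CLS, while other models prepend BOS
// only when the model metadata asks for it.
function resolveBeginningTokenToPrepend(
    vocabularyType: LlamaVocabularyType,
    tokens: LlamaModelTokens
): Token | null {
    if (vocabularyType === LlamaVocabularyType.wpm)
        return tokens.cls;

    if (tokens.shouldPrependBosToken)
        return tokens.bos;

    return null;
}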

src/evaluator/LlamaEmbeddingContext.ts

Lines changed: 17 additions & 1 deletion

@@ -2,6 +2,7 @@ import {AsyncDisposeAggregator, EventRelay, withLock} from "lifecycle-utils";
 import {Token} from "../types.js";
 import {LlamaText} from "../utils/LlamaText.js";
 import {tokenizeInput} from "../utils/tokenizeInput.js";
+import {resolveBeginningTokenToPrepend, resolveEndTokenToAppend} from "../utils/tokenizerUtils.js";
 import {LlamaEmbedding} from "./LlamaEmbedding.js";
 import type {LlamaModel} from "./LlamaModel/LlamaModel.js";
 import type {LlamaContext, LlamaContextSequence} from "./LlamaContext/LlamaContext.js";
@@ -72,7 +73,7 @@
     }
 
     public async getEmbeddingFor(input: Token[] | string | LlamaText) {
-        const resolvedInput = tokenizeInput(input, this._llamaContext.model.tokenizer);
+        const resolvedInput = tokenizeInput(input, this._llamaContext.model.tokenizer, undefined, true);
 
         if (resolvedInput.length > this._llamaContext.contextSize)
             throw new Error(
@@ -84,6 +85,14 @@
                 vector: []
             });
 
+        const beginningToken = resolveBeginningTokenToPrepend(this.model.vocabularyType, this.model.tokens);
+        if (beginningToken != null && resolvedInput[0] !== beginningToken)
+            resolvedInput.unshift(beginningToken);
+
+        const endToken = resolveEndTokenToAppend(this.model.vocabularyType, this.model.tokens);
+        if (endToken != null && resolvedInput.at(-1) !== endToken)
+            resolvedInput.push(endToken);
+
         return await withLock(this, "evaluate", async () => {
             await this._sequence.eraseContextTokenRanges([{
                 start: 0,
@@ -118,6 +127,10 @@
         return this._llamaContext.disposed;
    }
 
+    public get model() {
+        return this._llamaContext.model;
+    }
+
     /** @internal */
     public static async _create({
         _model
@@ -130,6 +143,9 @@
         createSignal,
         ignoreMemorySafetyChecks
     }: LlamaEmbeddingContextOptions) {
+        if (_model.fileInsights.hasEncoder && _model.fileInsights.hasDecoder)
+            throw new Error("Computing embeddings is not supported for encoder-decoder models.");
+
         const llamaContext = await _model.createContext({
             contextSize,
             batchSize,
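With this change, getEmbeddingFor frames the tokenized input the way the model's vocabulary expects (for a WPM model, [CLS] ... [SEP]) before evaluating it. A usage sketch; the model path is a placeholder:

import {getLlama} from "node-llama-cpp";

const llama = await getLlama();
const model = await llama.loadModel({modelPath: "path/to/embedding-model.gguf"});
const embeddingContext = await model.createEmbeddingContext();

// For a WPM-vocabulary model, the input is now automatically wrapped
// with CLS/SEP before the embedding is computed.
const embedding = await embeddingContext.getEmbeddingFor("Hello world");
console.log(embedding.vector.length);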

src/evaluator/LlamaModel/LlamaModel.ts

Lines changed: 82 additions & 1 deletion

@@ -733,7 +733,7 @@ export class LlamaModel {
             if (modelLoaded)
                 await model._model.dispose();
 
-            throw loadSignal.reason;
+            throw loadSignal!.reason;
         } else if (!modelLoaded)
             throw new Error("Failed to load model");
 
@@ -757,12 +757,17 @@ export class LlamaModelTokens {
     /** @internal */ private _bosToken?: Token;
     /** @internal */ private _eosToken?: Token;
     /** @internal */ private _eotToken?: Token;
+    /** @internal */ private _clsToken?: Token;
+    /** @internal */ private _sepToken?: Token;
     /** @internal */ private _nlToken?: Token;
     /** @internal */ private _bosString?: string;
     /** @internal */ private _eosString?: string;
     /** @internal */ private _eotString?: string;
+    /** @internal */ private _clsString?: string;
+    /** @internal */ private _sepString?: string;
     /** @internal */ private _nlString?: string;
     /** @internal */ private _shouldPrependBosToken?: boolean;
+    /** @internal */ private _shouldAppendEosToken?: boolean;
 
     private constructor(model: AddonModel, disposedState: DisposedState) {
         this._model = model;
@@ -826,6 +831,36 @@
         return this._eotToken;
     }
 
+    /**
+     * @returns The CLS (Classification) token.
+     */
+    public get cls(): Token | null {
+        this._ensureNotDisposed();
+
+        if (this._clsToken == null)
+            this._clsToken = this._model.clsToken();
+
+        if (this._clsToken === -1)
+            return null;
+
+        return this._clsToken;
+    }
+
+    /**
+     * @returns The SEP (Sentence Separator) token.
+     */
+    public get sep(): Token | null {
+        this._ensureNotDisposed();
+
+        if (this._sepToken == null)
+            this._sepToken = this._model.sepToken();
+
+        if (this._sepToken === -1)
+            return null;
+
+        return this._sepToken;
+    }
+
     /**
      * @returns The NL (New Line) token.
      */
@@ -892,6 +927,40 @@
         return this._eotString;
     }
 
+    /**
+     * @returns The CLS (Classification) token text representation.
+     */
+    public get clsString(): string | null {
+        this._ensureNotDisposed();
+
+        const clsToken = this.cls;
+
+        if (clsToken == null)
+            return null;
+
+        if (this._clsString == null)
+            this._clsString = this._model.getTokenString(clsToken);
+
+        return this._clsString;
+    }
+
+    /**
+     * @returns The SEP (Sentence Separator) token text representation.
+     */
+    public get sepString(): string | null {
+        this._ensureNotDisposed();
+
+        const sepToken = this.sep;
+
+        if (sepToken == null)
+            return null;
+
+        if (this._sepString == null)
+            this._sepString = this._model.getTokenString(sepToken);
+
+        return this._sepString;
+    }
+
     /**
      * @returns The NL (New Line) token text representation.
      */
@@ -921,6 +990,18 @@
         return this._shouldPrependBosToken;
     }
 
+    /**
+     * @returns Whether we should append an EOS (End Of Sequence) token for evaluations with this model.
+     */
+    public get shouldAppendEosToken(): boolean {
+        this._ensureNotDisposed();
+
+        if (this._shouldAppendEosToken == null)
+            this._shouldAppendEosToken = this.bos != null && this._model.shouldAppendEosToken();
+
+        return this._shouldAppendEosToken;
+    }
+
     /** @internal */
     private _ensureNotDisposed() {
         if (this._disposedState.disposed)
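The new cls/sep getters follow the class's existing lazy-cache pattern: the addon is called once, its -1 sentinel is normalized to null, and the string form is resolved on demand via getTokenString. Inspecting a loaded model (continuing the `model` from the embedding sketch above; output depends on the model file):

// A BERT-style (WPM) model typically reports its [CLS] and [SEP]
// tokens here; most generative LLMs report null for both.
console.log("cls:", model.tokens.cls, model.tokens.clsString);
console.log("sep:", model.tokens.sep, model.tokens.sepString);
console.log("append EOS:", model.tokens.shouldAppendEosToken);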
