
Commit 4cf1fba

feat: get embedding for text (#144)
* feat: get embedding for text
* feat(minor): improve `resolveChatWrapperBasedOnModel` logic
* style: improve GitHub release notes formatting
1 parent 36c779d commit 4cf1fba

10 files changed: +254 additions, -24 deletions

.releaserc.ts

Lines changed: 4 additions & 4 deletions
@@ -14,12 +14,12 @@ const homepageUrlWithoutTrailingSlash = homepageUrl.endsWith("/")
     ? homepageUrl.slice(0, -1)
     : homepageUrl;

-const newFooterTemplate = defaultFooterTemplate + "\n---\n" +
-    `Shipped with \`llama.cpp\` release: [\`${llamaCppRelease.split("`").join("")}\`](https://github.com/${defaultLlamaCppGitHubRepo}/releases/tag/${encodeURIComponent(llamaCppRelease)}) ` +
-    `(to use the latest \`llama.cpp\` release available, run \`npx --no ${cliBinName} download --release latest\`. [learn more](${homepageUrlWithoutTrailingSlash}/guide/building-from-source))\n`;
+const newFooterTemplate = defaultFooterTemplate + "\n---\n\n" +
+    `Shipped with \`llama.cpp\` release [\`${llamaCppRelease.split("`").join("")}\`](https://github.com/${defaultLlamaCppGitHubRepo}/releases/tag/${encodeURIComponent(llamaCppRelease)})\n\n` +
+    `> To use the latest \`llama.cpp\` release available, run \`npx --no ${cliBinName} download --release latest\`. ([learn more](${homepageUrlWithoutTrailingSlash}/guide/building-from-source#downloading-a-newer-release))\n`;

 /**
- * @type {import('semantic-release').GlobalConfig}
+ * @type {import("semantic-release").GlobalConfig}
  */
 export default {
     "branches": [

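For reference, the reworked footer renders roughly like the following in the GitHub release notes. The release tag, repository, CLI binary name, and homepage URL below are placeholder assumptions for illustration; the real values come from the project configuration and are not part of this diff:

    ---

    Shipped with `llama.cpp` release [`b1892`](https://github.com/ggerganov/llama.cpp/releases/tag/b1892)

    > To use the latest `llama.cpp` release available, run `npx --no node-llama-cpp download --release latest`. ([learn more](https://example.com/guide/building-from-source#downloading-a-newer-release))
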
llama/addon.cpp

Lines changed: 21 additions & 5 deletions
@@ -358,7 +358,9 @@ class AddonContext : public Napi::ObjectWrap<AddonContext> {
         if (info.Length() > 1 && info[1].IsObject()) {
             Napi::Object options = info[1].As<Napi::Object>();

-            if (options.Has("seed")) {
+            if (options.Has("noSeed")) {
+                context_params.seed = time(NULL);
+            } else if (options.Has("seed")) {
                 context_params.seed = options.Get("seed").As<Napi::Number>().Uint32Value();
             }

@@ -370,10 +372,6 @@ class AddonContext : public Napi::ObjectWrap<AddonContext> {
                 context_params.n_batch = options.Get("batchSize").As<Napi::Number>().Uint32Value();
             }

-            if (options.Has("logitsAll")) {
-                context_params.logits_all = options.Get("logitsAll").As<Napi::Boolean>().Value();
-            }
-
             if (options.Has("embedding")) {
                 context_params.embedding = options.Get("embedding").As<Napi::Boolean>().Value();
             }
@@ -544,6 +542,23 @@ class AddonContext : public Napi::ObjectWrap<AddonContext> {
         return info.Env().Undefined();
     }

+    Napi::Value GetEmbedding(const Napi::CallbackInfo& info) {
+        if (disposed) {
+            Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
+            return info.Env().Undefined();
+        }
+
+        const int n_embd = llama_n_embd(model->model);
+        const auto * embeddings = llama_get_embeddings(ctx);
+
+        Napi::Float64Array result = Napi::Float64Array::New(info.Env(), n_embd);
+        for (size_t i = 0; i < n_embd; ++i) {
+            result[i] = embeddings[i];
+        }
+
+        return result;
+    }
+
     static void init(Napi::Object exports) {
         exports.Set(
             "AddonContext",
@@ -560,6 +575,7 @@ class AddonContext : public Napi::ObjectWrap<AddonContext> {
                 InstanceMethod("decodeBatch", &AddonContext::DecodeBatch),
                 InstanceMethod("sampleToken", &AddonContext::SampleToken),
                 InstanceMethod("acceptGrammarEvaluationStateToken", &AddonContext::AcceptGrammarEvaluationStateToken),
+                InstanceMethod("getEmbedding", &AddonContext::GetEmbedding),
                 InstanceMethod("dispose", &AddonContext::Dispose)
             }
         )
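
On the JavaScript side, the new native method surfaces as getEmbedding(): Float64Array on the AddonContext type (see the src/utils/getBin.ts hunk below). A minimal sketch of consuming it, assuming addonContext stands in for a native context instance that was created with embedding enabled (obtaining such an instance is wrapped by LlamaContext and not shown here):

    // sketch only: `addonContext` is a stand-in for a native AddonContext instance
    declare const addonContext: {getEmbedding(): Float64Array};

    const raw = addonContext.getEmbedding(); // one value per embedding dimension (n_embd)
    const vector = Array.from(raw);          // plain number[], as LlamaEmbeddingContext stores it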

src/chatWrappers/resolveChatWrapperBasedOnModel.ts

Lines changed: 4 additions & 2 deletions
@@ -29,8 +29,8 @@ export function resolveChatWrapperBasedOnModel({
     if (fileType?.toLowerCase() === "gguf") {
         const lowercaseName = name?.toLowerCase();
         const lowercaseSubType = subType?.toLowerCase();
-        const splitLowercaseSubType = lowercaseSubType?.split("-");
-        const firstSplitLowercaseSubType = splitLowercaseSubType?.[0];
+        const splitLowercaseSubType = lowercaseSubType?.split("-") ?? [];
+        const firstSplitLowercaseSubType = splitLowercaseSubType[0];

         if (lowercaseName === "llama")
             return LlamaChatWrapper;
@@ -48,6 +48,8 @@ export function resolveChatWrapperBasedOnModel({
             return AlpacaChatWrapper;
         else if (lowercaseName === "functionary")
             return FunctionaryChatWrapper;
+        else if (lowercaseName === "dolphin" && splitLowercaseSubType.includes("mistral"))
+            return ChatMLChatWrapper;
     }
 }
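
A short sketch of how the new branch behaves for a Dolphin-Mistral model, using the fields that the parseModelFileName test added in this commit produces for "dolphin-2.1-mistral-7b.Q4_K_M.gguf" (imports and any other options the resolver accepts are omitted here):

    // sketch: subType "2.1-mistral" splits into ["2.1", "mistral"], so the new
    // `includes("mistral")` check matches and the ChatML wrapper is returned
    const chatWrapper = resolveChatWrapperBasedOnModel({
        name: "dolphin",
        subType: "2.1-mistral",
        fileType: "gguf"
    });
    // chatWrapper === ChatMLChatWrapper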

src/index.ts

Lines changed: 7 additions & 0 deletions
@@ -5,6 +5,9 @@ import {LlamaJsonSchemaGrammar} from "./llamaEvaluator/LlamaJsonSchemaGrammar.js";
 import {LlamaJsonSchemaValidationError} from "./utils/gbnfJson/utils/validateObjectAgainstGbnfSchema.js";
 import {LlamaGrammarEvaluationState, LlamaGrammarEvaluationStateOptions} from "./llamaEvaluator/LlamaGrammarEvaluationState.js";
 import {LlamaContext, LlamaContextSequence} from "./llamaEvaluator/LlamaContext/LlamaContext.js";
+import {
+    LlamaEmbeddingContext, type LlamaEmbeddingContextOptions, LlamaEmbedding, type LlamaEmbeddingJSON
+} from "./llamaEvaluator/LlamaEmbeddingContext.js";
 import {
     type LlamaContextOptions, type BatchingOptions, type LlamaContextSequenceRepeatPenalty, type CustomBatchingDispatchSchedule,
     type CustomBatchingPrioritizeStrategy, type BatchItem, type PrioritizedBatchItem, type ContextShiftOptions,
@@ -70,6 +73,10 @@ export {
     type ContextTokensDeleteRange,
     type EvaluationPriority,
     type LlamaContextSequenceRepeatPenalty,
+    LlamaEmbeddingContext,
+    type LlamaEmbeddingContextOptions,
+    LlamaEmbedding,
+    type LlamaEmbeddingJSON,
     LlamaChatSession,
     defineChatSessionFunction,
     type LlamaChatSessionOptions,

src/llamaEvaluator/LlamaContext/LlamaContext.ts

Lines changed: 6 additions & 6 deletions
@@ -44,13 +44,13 @@ export class LlamaContext {
         seed = null,
         contextSize = model.trainContextSize,
         batchSize = contextSize,
-        logitsAll,
-        embedding,
         threads = 6,
         batching: {
             dispatchSchedule: batchingDispatchSchedule = "nextTick",
             itemsPrioritizingStrategy: batchingItemsPrioritizingStrategy = "maximumParallelism"
-        } = {}
+        } = {},
+        _embedding,
+        _noSeed
     }: LlamaContextOptions) {
         if (model.disposed)
             throw new DisposedError();
@@ -63,9 +63,9 @@ export class LlamaContext {
             seed: seed != null ? Math.max(-1, Math.floor(seed)) : undefined,
             contextSize: contextSize * this._totalSequences, // each sequence needs its own <contextSize> of cells
             batchSize: this._batchSize,
-            logitsAll,
-            embedding,
-            threads: Math.max(0, Math.floor(threads))
+            threads: Math.max(0, Math.floor(threads)),
+            embedding: _embedding,
+            noSeed: _noSeed
         }));
         this._batchingOptions = {
             dispatchSchedule: batchingDispatchSchedule,

src/llamaEvaluator/LlamaContext/types.ts

Lines changed: 13 additions & 7 deletions
@@ -23,20 +23,26 @@ export type LlamaContextOptions = {
     /** prompt processing batch size */
     batchSize?: number,

-    /** the llama_eval() call computes all logits, not just the last one */
-    logitsAll?: boolean,
-
-    /** embedding mode only */
-    embedding?: boolean
-
     /**
      * number of threads to use to evaluate tokens.
      * set to 0 to use the maximum threads supported by the current machine hardware
      */
     threads?: number,

     /** control the parallel sequences processing behavior */
-    batching?: BatchingOptions
+    batching?: BatchingOptions,
+
+    /**
+     * embedding mode only
+     * @internal
+     */
+    _embedding?: boolean,
+
+    /**
+     * disable the seed generation
+     * @internal
+     */
+    _noSeed?: boolean
 };
 export type LlamaContextSequenceRepeatPenalty = {
     /** Tokens to lower the predication probability of to be the next predicted token */
src/llamaEvaluator/LlamaEmbeddingContext.ts

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
1+
import {withLock} from "lifecycle-utils";
2+
import {Token} from "../types.js";
3+
import {isLlamaText, LlamaText} from "../utils/LlamaText.js";
4+
import {LlamaModel} from "./LlamaModel.js";
5+
import {LlamaContext, LlamaContextSequence} from "./LlamaContext/LlamaContext.js";
6+
7+
export type LlamaEmbeddingContextOptions = {
8+
model: LlamaModel,
9+
10+
/** text context size */
11+
contextSize?: number,
12+
13+
/** prompt processing batch size */
14+
batchSize?: number,
15+
16+
/**
17+
* number of threads to use to evaluate tokens.
18+
* set to 0 to use the maximum threads supported by the current machine hardware
19+
*/
20+
threads?: number,
21+
};
22+
23+
export class LlamaEmbeddingContext {
24+
/** @internal */ private readonly _llamaContext: LlamaContext;
25+
/** @internal */ private readonly _sequence: LlamaContextSequence;
26+
27+
public constructor({
28+
model,
29+
contextSize = model.trainContextSize,
30+
batchSize = contextSize,
31+
threads = 6
32+
}: LlamaEmbeddingContextOptions) {
33+
const resolvedContextSize = Math.min(contextSize, model.trainContextSize);
34+
const resolvedBatchSize = Math.min(batchSize, resolvedContextSize);
35+
36+
this._llamaContext = new LlamaContext({
37+
model,
38+
contextSize: resolvedContextSize,
39+
batchSize: resolvedBatchSize,
40+
threads,
41+
_embedding: true,
42+
_noSeed: true
43+
});
44+
this._sequence = this._llamaContext.getSequence();
45+
}
46+
47+
public async getEmbeddingFor(input: Token[] | string | LlamaText) {
48+
const resolvedInput = typeof input === "string"
49+
? this._llamaContext.model.tokenize(input)
50+
: isLlamaText(input)
51+
? input.tokenize(this._llamaContext.model.tokenize)
52+
: input;
53+
54+
if (resolvedInput.length > this._llamaContext.contextSize)
55+
throw new Error(
56+
"Input is longer than the context size. " +
57+
"Try to increase the context size or use another model that supports longer contexts."
58+
);
59+
else if (resolvedInput.length === 0)
60+
return new LlamaEmbedding({vector: []});
61+
62+
return await withLock(this, "evaluate", async () => {
63+
await this._sequence.eraseContextTokenRanges([{
64+
start: 0,
65+
end: this._sequence.nextTokenIndex
66+
}]);
67+
68+
await this._sequence.evaluateWithoutGeneratingNewTokens(resolvedInput);
69+
70+
const embedding = this._llamaContext._ctx.getEmbedding();
71+
const embeddingVector = Array.from(embedding);
72+
73+
return new LlamaEmbedding({vector: embeddingVector});
74+
});
75+
}
76+
77+
public dispose() {
78+
return this._llamaContext.dispose();
79+
}
80+
81+
/** @hidden */
82+
public [Symbol.dispose]() {
83+
return this.dispose();
84+
}
85+
86+
public get disposed() {
87+
return this._llamaContext.disposed;
88+
}
89+
}
90+
91+
export type LlamaEmbeddingJSON = {
92+
type: "LlamaEmbedding",
93+
vector: number[]
94+
};
95+
96+
export class LlamaEmbedding {
97+
public readonly vector: number[];
98+
99+
public constructor({vector}: {vector: number[]}) {
100+
this.vector = vector;
101+
}
102+
103+
public toJSON(): LlamaEmbeddingJSON {
104+
return {
105+
type: "LlamaEmbedding",
106+
vector: this.vector
107+
};
108+
}
109+
110+
public static fromJSON(json: LlamaEmbeddingJSON) {
111+
if (json == null || json.type !== "LlamaEmbedding" || !(json.vector instanceof Array) ||
112+
json.vector.some(v => typeof v !== "number")
113+
)
114+
throw new Error("Invalid LlamaEmbedding JSON");
115+
116+
return new LlamaEmbedding({
117+
vector: json.vector
118+
});
119+
}
120+
}
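
A minimal usage sketch of the new embedding API. The import specifier, the model path, and the cosine-similarity helper below are illustrative assumptions and not part of this commit:

    import {LlamaModel, LlamaEmbeddingContext} from "node-llama-cpp"; // package name assumed

    // cosine similarity between two embedding vectors (helper written for this example)
    function cosineSimilarity(a: number[], b: number[]) {
        let dot = 0, normA = 0, normB = 0;
        for (let i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    const model = new LlamaModel({modelPath: "path/to/model.gguf"}); // placeholder path
    const embeddingContext = new LlamaEmbeddingContext({model, contextSize: 4096});

    const helloWorld = await embeddingContext.getEmbeddingFor("Hello world");
    const helloThere = await embeddingContext.getEmbeddingFor("Hello there");

    console.log(cosineSimilarity(helloWorld.vector, helloThere.vector)); // closer to 1 means more similar

    embeddingContext.dispose();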

src/utils/getBin.ts

Lines changed: 2 additions & 0 deletions
@@ -174,6 +174,8 @@ export type AddonContext = {
     shiftSequenceTokenCells(sequenceId: number, startPos: number, endPos: number, shiftDelta: number): void,

     acceptGrammarEvaluationStateToken(grammarEvaluationState: AddonGrammarEvaluationState, token: Token): void,
+
+    getEmbedding(): Float64Array
 };

 export type BatchLogitIndex = number & {
Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
1+
import {describe, expect, test} from "vitest";
2+
import {LlamaEmbeddingContext, LlamaModel} from "../../../src/index.js";
3+
import {getModelFile} from "../../utils/modelFiles.js";
4+
5+
describe("functionary", () => {
6+
describe("embedding", () => {
7+
test("deterministic", async () => {
8+
const modelPath = await getModelFile("functionary-small-v2.2.q4_0.gguf");
9+
10+
const model = new LlamaModel({
11+
modelPath
12+
});
13+
const embeddingContext = new LlamaEmbeddingContext({
14+
model,
15+
contextSize: 4096
16+
});
17+
18+
const helloWorldEmbedding = await embeddingContext.getEmbeddingFor("Hello world");
19+
20+
const helloThereEmbedding = await embeddingContext.getEmbeddingFor("Hello there");
21+
22+
expect(helloWorldEmbedding.vector).to.not.eql(helloThereEmbedding.vector);
23+
24+
const helloWorld2Embedding = await embeddingContext.getEmbeddingFor("Hello world");
25+
26+
expect(helloWorld2Embedding.vector).to.eql(helloWorldEmbedding.vector);
27+
expect(helloWorld2Embedding.vector).to.not.eql(helloThereEmbedding.vector);
28+
29+
console.log(helloWorld2Embedding.vector);
30+
}, {
31+
timeout: 1000 * 60 * 60
32+
});
33+
34+
test("deterministic between runs", async () => {
35+
const modelPath = await getModelFile("functionary-small-v2.2.q4_0.gguf");
36+
37+
const model = new LlamaModel({
38+
modelPath
39+
});
40+
const embeddingContext = new LlamaEmbeddingContext({
41+
model,
42+
contextSize: 4096
43+
});
44+
45+
const helloWorldEmbedding = await embeddingContext.getEmbeddingFor("Hello world");
46+
const helloThereEmbedding = await embeddingContext.getEmbeddingFor("Hello there");
47+
48+
expect(helloWorldEmbedding.vector).to.not.eql(helloThereEmbedding.vector);
49+
50+
embeddingContext.dispose();
51+
52+
const embeddingContext2 = new LlamaEmbeddingContext({
53+
model,
54+
contextSize: 4096
55+
});
56+
57+
const helloWorldEmbedding2 = await embeddingContext2.getEmbeddingFor("Hello world");
58+
const helloThereEmbedding2 = await embeddingContext2.getEmbeddingFor("Hello there");
59+
60+
expect(helloWorldEmbedding2.vector).to.eql(helloWorldEmbedding.vector);
61+
expect(helloThereEmbedding2.vector).to.eql(helloThereEmbedding.vector);
62+
}, {
63+
timeout: 1000 * 60 * 60
64+
});
65+
});
66+
});

test/standalone/parseModelFileName.test.ts

Lines changed: 11 additions & 0 deletions
@@ -70,4 +70,15 @@ describe("parseModelFileName", () => {
                 parameters: "13B"
             });
     });
+
+    test("dolphin-2.1-mistral-7b.Q4_K_M.gguf", () => {
+        expect(parseModelFileName("dolphin-2.1-mistral-7b.Q4_K_M.gguf"))
+            .toEqual({
+                name: "dolphin",
+                subType: "2.1-mistral",
+                quantization: "Q4_K_M",
+                fileType: "gguf",
+                parameters: "7B"
+            });
+    });
 });
