Skip to content

Commit fb4c4df

Browse files
committed
fix: reranking edge case crash
1 parent 15590c3 commit fb4c4df

File tree

3 files changed

+10
-8
lines changed

3 files changed

+10
-8
lines changed

llama/addon/AddonContext.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,7 @@ Napi::Value AddonContext::GetEmbedding(const Napi::CallbackInfo& info) {
639639
}
640640

641641
int32_t inputTokensLength = info[0].As<Napi::Number>().Int32Value();
642+
int32_t maxVectorSize = (info.Length() > 1 && info[1].IsNumber()) ? info[1].As<Napi::Number>().Int32Value() : 0;
642643

643644
if (inputTokensLength <= 0) {
644645
Napi::Error::New(info.Env(), "Invalid input tokens length").ThrowAsJavaScriptException();
@@ -650,15 +651,16 @@ Napi::Value AddonContext::GetEmbedding(const Napi::CallbackInfo& info) {
650651
const auto* embeddings = pooling_type == LLAMA_POOLING_TYPE_NONE ? NULL : llama_get_embeddings_seq(ctx, 0);
651652
if (embeddings == NULL) {
652653
embeddings = llama_get_embeddings_ith(ctx, inputTokensLength - 1);
654+
}
653655

654-
if (embeddings == NULL) {
655-
Napi::Error::New(info.Env(), std::string("Failed to get embeddings for token ") + std::to_string(inputTokensLength - 1)).ThrowAsJavaScriptException();
656-
return info.Env().Undefined();
657-
}
656+
if (embeddings == NULL) {
657+
Napi::Error::New(info.Env(), std::string("Failed to get embeddings for token ") + std::to_string(inputTokensLength - 1)).ThrowAsJavaScriptException();
658+
return info.Env().Undefined();
658659
}
659660

660-
Napi::Float64Array result = Napi::Float64Array::New(info.Env(), n_embd);
661-
for (size_t i = 0; i < n_embd; ++i) {
661+
size_t resultSize = maxVectorSize == 0 ? n_embd : std::min(n_embd, maxVectorSize);
662+
Napi::Float64Array result = Napi::Float64Array::New(info.Env(), resultSize);
663+
for (size_t i = 0; i < resultSize; i++) {
662664
result[i] = embeddings[i];
663665
}
664666

src/bindings/AddonTypes.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ export type AddonContext = {
143143
// startPos in inclusive, endPos is exclusive
144144
shiftSequenceTokenCells(sequenceId: number, startPos: number, endPos: number, shiftDelta: number): void,
145145

146-
getEmbedding(inputTokensLength: number): Float64Array,
146+
getEmbedding(inputTokensLength: number, maxVectorSize?: number): Float64Array,
147147
getStateSize(): number,
148148
getThreads(): number,
149149
setThreads(threads: number): void,

src/evaluator/LlamaRankingContext.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ export class LlamaRankingContext {
197197
break; // only generate one token to get embeddings
198198
}
199199

200-
const embedding = this._llamaContext._ctx.getEmbedding(input.length);
200+
const embedding = this._llamaContext._ctx.getEmbedding(input.length, 1);
201201
if (embedding.length === 0)
202202
return 0;
203203

0 commit comments

Comments
 (0)